[quant][docs] Add types for scale and zero_point tensors in torch.fake_quantize_per_channel_affine docs (#85733)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/85525
Test Plan:
visual inspection of the docs page
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85733
Approved by: https://github.com/z-a-f
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 70c87c4..85de969 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -11927,14 +11927,14 @@
)
Args:
- input (Tensor): the input value(s), in ``torch.float32``.
- scale (double or Tensor): quantization scale
- zero_point (int64 or Tensor): quantization zero_point
+ input (Tensor): the input value(s), as a ``torch.float32`` tensor
+ scale (double scalar or ``float32`` Tensor): quantization scale
+ zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point
quant_min (int64): lower bound of the quantized domain
quant_max (int64): upper bound of the quantized domain
Returns:
- Tensor: A newly fake_quantized tensor
+ Tensor: A newly fake_quantized ``torch.float32`` tensor
Example::
@@ -11966,15 +11966,15 @@
)
Args:
- input (Tensor): the input value(s), in ``torch.float32``.
- scale (Tensor): quantization scale, per channel
- zero_point (Tensor): quantization zero_point, per channel
+ input (Tensor): the input value(s), in ``torch.float32``
+ scale (Tensor): quantization scale, per channel in ``torch.float32``
+ zero_point (Tensor): quantization zero_point, per channel in ``torch.int32``, ``torch.half``, or ``torch.float32``
axis (int32): channel axis
quant_min (int64): lower bound of the quantized domain
quant_max (int64): upper bound of the quantized domain
Returns:
- Tensor: A newly fake_quantized per channel tensor
+ Tensor: A newly fake_quantized per channel ``torch.float32`` tensor
Example::
@@ -11988,7 +11988,7 @@
>>> scales = (torch.randn(2) + 1) * 0.05
>>> scales
tensor([0.0475, 0.0486])
- >>> zero_points = torch.zeros(2).to(torch.long)
+ >>> zero_points = torch.zeros(2).to(torch.int32)
>>> zero_points
tensor([0, 0])
>>> torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255)
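For anyone who wants to check the documented dtypes beyond visual inspection, here is a minimal sketch (not part of the patch), assuming a PyTorch build recent enough to have the tensor-qparams overloads described above; names like `x`, `scales`, and `zero_points` are illustrative:

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 3)  # input must be a torch.float32 tensor

    # Per-tensor, scalar qparams: double scale, int64 zero_point.
    y = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)

    # Per-tensor, tensor qparams: the updated Args section pins these to a
    # float32 scale and an int32 zero_point.
    y = torch.fake_quantize_per_tensor_affine(
        x,
        torch.tensor(0.1, dtype=torch.float32),
        torch.tensor(0, dtype=torch.int32),
        0,
        255,
    )

    # Per-channel: one scale/zero_point per slice along `axis` (axis=1 here,
    # so 3 entries each). scale must be float32; zero_point must be int32,
    # half, or float32. A torch.long zero_point, as the old example used,
    # is rejected on recent builds.
    scales = (torch.randn(3).abs() + 1) * 0.05        # torch.float32
    zero_points = torch.zeros(3, dtype=torch.int32)
    y = torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255)

    # Fake quantization rounds onto the quantization grid but stays in float32:
    # out = (clamp(round(x / scale) + zero_point, quant_min, quant_max) - zero_point) * scale
    assert y.dtype == torch.float32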