| # mypy: allow-untyped-defs |
| import torch |
| import torch._prims_common as utils |
| |
| # Utilities should come BEFORE this import |
| from torch._decomp import register_decomposition |
| |
| from torch._prims_common import TensorLikeType |
| from torch._prims_common.wrappers import out_wrapper |
| from torch._refs import _broadcast_shapes |
| |
| # Data conversion references. |
| # |
| # Note: this module breaks the usual _refs to torch naming scheme where |
| # _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not |
| # part of _refs/__init__.py to avoid name clashes with Python builtin types |
| # (like int). |
| |
# Public API of this module. Several entries (bool, float, int, complex)
# share their names with Python builtins; see the note above for why these
# definitions live here rather than in _refs/__init__.py.
__all__ = [
    # Tensor dtype-conversion methods (one per dtype shorthand)
    "bfloat16",
    "bool",
    "byte",
    "cdouble",
    "cfloat",
    "chalf",
    "char",
    "double",
    "float",
    "half",
    "int",
    "long",
    "short",
    # complex-tensor constructors
    "complex",
    "polar",
]
| |
| |
def _make_conversion_method(name: str, dtype: torch.dtype):
    """Build a reference for the dtype-conversion method called ``name``.

    The returned function mirrors Tensor methods such as ``Tensor.double()``:
    it takes a tensor (``self``) and an optional ``memory_format`` and returns
    ``self.to(dtype, memory_format=memory_format)``.

    Args:
        name: Public name of the conversion (e.g. ``"double"``). It becomes the
            returned function's ``__name__`` and ``__qualname__``.
        dtype: Target dtype of the conversion.

    Returns:
        The conversion function.
    """

    def fn(
        self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
    ) -> TensorLikeType:
        return self.to(dtype, memory_format=memory_format)  # type: ignore[call-overload]

    fn.__name__ = name
    # Patch __qualname__ as well. Otherwise it stays
    # "_make_conversion_method.<locals>.fn", which leaks into reprs and
    # tracebacks and prevents pickling by reference even though the result is
    # bound at module level under ``name``.
    fn.__qualname__ = name
    fn.__doc__ = f"Convert ``self`` to ``{dtype}`` (reference for ``Tensor.{name}``)."
    return fn
| |
| |
# One conversion reference per Tensor dtype-shorthand method, grouped by dtype
# family. Dtypes are spelled with their bit-width names; the shorthand aliases
# (torch.half, torch.double, torch.long, ...) are the very same objects.
# Note that binding bool/float/int here shadows those builtins module-wide.

# floating point
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)
half = _make_conversion_method("half", torch.float16)
float = _make_conversion_method("float", torch.float32)
double = _make_conversion_method("double", torch.float64)

# complex
chalf = _make_conversion_method("chalf", torch.complex32)
cfloat = _make_conversion_method("cfloat", torch.complex64)
cdouble = _make_conversion_method("cdouble", torch.complex128)

# integral
char = _make_conversion_method("char", torch.int8)
byte = _make_conversion_method("byte", torch.uint8)
short = _make_conversion_method("short", torch.int16)
int = _make_conversion_method("int", torch.int32)
long = _make_conversion_method("long", torch.int64)

# boolean
bool = _make_conversion_method("bool", torch.bool)
| |
| |
@register_decomposition(torch._ops.ops.aten.complex)
# Note: complex has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
    """Reference for ``torch.complex``: combine real and imaginary parts.

    Both inputs must be Half, Float or Double tensors of the same dtype. Their
    shapes are broadcast together, and the result has the matching complex
    dtype, with layout and device taken from ``real``.
    """
    supported_dtypes = (torch.float32, torch.float64, torch.float16)
    torch._check(
        all(t.dtype in supported_dtypes for t in (real, imag)),
        lambda: (
            f"Expected both inputs to be Half, Float or Double tensors but got "
            f"{real.dtype} and {imag.dtype}"
        ),
    )
    torch._check(
        real.dtype == imag.dtype,
        lambda: (
            f"Expected object of scalar type {real.dtype} but got "
            f"scalar type {imag.dtype} for second argument"
        ),
    )

    complex_dtype = utils.corresponding_complex_dtype(real.dtype)  # type: ignore[arg-type]
    out = real.new_empty(
        _broadcast_shapes(real.shape, imag.shape),
        dtype=complex_dtype,
        layout=real.layout,
        device=real.device,
        # TODO: also forward pin_memory=real.is_pinned() once supported (NYI).
    )
    # Assigning through the .real / .imag views copies each input into the
    # freshly allocated complex tensor.
    out.real = real
    out.imag = imag
    return out
| |
| |
@register_decomposition(torch._ops.ops.aten.polar)
# Note: polar has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
    """Reference for ``torch.polar``: ``abs * (cos(angle) + i * sin(angle))``.

    ``torch.complex(abs, angle)`` is called first to get a complex tensor of
    the broadcast shape. Presumably it also reuses that op's input dtype
    checks (cf. the ``complex`` ref in this module). Its contents are then
    overwritten with the polar-form components.
    """
    out = torch.complex(abs, angle)
    real_part = abs * torch.cos(angle)
    out.real = real_part
    imag_part = abs * torch.sin(angle)
    out.imag = imag_part
    return out