| # Copyright (c) Meta Platforms, Inc. and affiliates |
| |
| import torch |
| |
| from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor |
| |
| __all__ = [] # type: ignore[var-annotated] |
| |
| BINARY_NAMES = [ |
| "add", |
| "atan2", |
| "arctan2", |
| "bitwise_and", |
| "bitwise_or", |
| "bitwise_xor", |
| "bitwise_left_shift", |
| "bitwise_right_shift", |
| "div", |
| "divide", |
| "floor_divide", |
| "fmod", |
| "logaddexp", |
| "logaddexp2", |
| "mul", |
| "multiply", |
| "nextafter", |
| "remainder", |
| "sub", |
| "subtract", |
| "true_divide", |
| "eq", |
| "ne", |
| "le", |
| "ge", |
| "greater", |
| "greater_equal", |
| "gt", |
| "less_equal", |
| "lt", |
| "less", |
| "maximum", |
| "minimum", |
| "fmax", |
| "fmin", |
| "not_equal", |
| ] |
| |
| INPLACE_BINARY_NAMES = [ |
| n + "_" |
| for n in ( |
| list( |
| set(BINARY_NAMES) |
| - { |
| "logaddexp", |
| "logaddexp2", |
| "equal", |
| "fmin", |
| "minimum", |
| "maximum", |
| "fmax", |
| } |
| ) |
| ) |
| ] |
| |
| |
def _get_at_least_one_mask(a, b):
    """Return the mask of ``a`` if it is a MaskedTensor, otherwise of ``b``.

    Raises:
        TypeError: if neither ``a`` nor ``b`` is a MaskedTensor.
        ValueError: if ``_masks_match(a, b)`` reports that the masks differ.
    """
    lhs_is_masked = is_masked_tensor(a)
    if not (lhs_is_masked or is_masked_tensor(b)):
        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
    if not _masks_match(a, b):
        raise ValueError("a and b must have matching masks")
    mask_owner = a if lhs_is_masked else b
    return mask_owner.get_mask()
| |
| |
def _binary_helper(fn, args, kwargs, inplace):
    """Apply the aten binary op ``fn`` to the data of MaskedTensor operands.

    ``fn`` runs on the unwrapped data of the lhs (``args[0]``) and rhs
    (``args[1]``). When the lhs is sparse COO or CSR, ``fn`` runs on the
    stored values only, and the lhs indices and size are reused for the
    result.

    Args:
        fn: aten op applied to the unwrapped data.
        args: positional arguments; ``args[0]`` is the lhs and ``args[1]`` the
            rhs. Any further positional arguments must not be tensors.
        kwargs: keyword arguments; must be empty.
        inplace: if True, store the result (with the lhs mask) into
            ``args[0]`` via ``_set_data_mask`` and return ``args[0]``;
            otherwise return ``_wrap_result(result_data, result_mask)``.

    Raises:
        ValueError: if ``kwargs`` is non-empty, the input masks do not match,
            or two same-layout sparse operands differ in indices or size.
        TypeError: if a tensor appears after the rhs, or (out-of-place only)
            neither of the first two arguments is a MaskedTensor.
    """
    if len(kwargs) != 0:
        raise ValueError("len(kwargs) must equal 0")
    for a in args[2:]:
        if torch.is_tensor(a):
            raise TypeError(
                "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
            )

    if not _masks_match(*args[:2]):
        raise ValueError(
            "Input masks must match. If you need support for this, please open an issue on Github."
        )

    # data_args is index-assigned below to swap sparse operands for their values.
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())

    args0_layout = data_args[0].layout
    same_layout = (
        torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
    ) and args0_layout == data_args[1].layout

    if args0_layout == torch.sparse_coo:
        if same_layout:
            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
                raise ValueError(
                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                )
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )

            data_args[1] = data_args[1].values()

        i = data_args[0].indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size)

    elif args0_layout == torch.sparse_csr:
        if same_layout:
            if not (
                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
                and _tensors_match(
                    data_args[0].col_indices(), data_args[1].col_indices()
                )
            ):
                raise ValueError(
                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                )
            # Matching crow/col indices do not pin the number of columns, so
            # compare sizes explicitly, as the sparse_coo branch does.
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )

            data_args[1] = data_args[1].values()

        crow = data_args[0].crow_indices()
        col = data_args[0].col_indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        # Pass size explicitly: without it sparse_csr_tensor infers the
        # smallest shape that holds the stored entries, which can drop
        # trailing empty columns and then disagree with the mask's shape.
        result_data = torch.sparse_csr_tensor(crow, col, v, size=size)

    else:
        result_data = fn(*data_args)

    if inplace:
        # Only the in-place path needs the unwrapped masks.
        mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]

    result_mask = _get_at_least_one_mask(*args[:2])
    # sparse tensors don't have strides so we can only expand if the layout is strided
    if args0_layout == torch.strided:
        result_mask = result_mask.expand_as(result_data)
    return _wrap_result(result_data, result_mask)
| |
| |
| def _torch_binary(fn_name): |
| fn = getattr(torch.ops.aten, fn_name) |
| |
| def binary_fn(*args, **kwargs): |
| return _binary_helper(fn, args, kwargs, inplace=False) |
| |
| return binary_fn |
| |
| |
| def _torch_inplace_binary(fn_name): |
| fn = getattr(torch.ops.aten, fn_name) |
| |
| def binary_fn(*args, **kwargs): |
| return _binary_helper(fn, args, kwargs, inplace=True) |
| |
| return binary_fn |
| |
| |
# Dispatch tables: aten op (``getattr(torch.ops.aten, name)``) -> the
# MaskedTensor-aware handler built for that name.
NATIVE_BINARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_binary(op_name)
    for op_name in BINARY_NAMES
}
NATIVE_INPLACE_BINARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_inplace_binary(op_name)
    for op_name in INPLACE_BINARY_NAMES
}

# The registered ops as plain lists, in the tables' insertion order.
NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP)
NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP)
| |
| |
def _is_native_binary(fn):
    """Return True if ``fn`` has an out-of-place or in-place MaskedTensor handler.

    Membership is checked against the dispatch dicts rather than the
    ``NATIVE_*_FNS`` lists. Those lists are built from the dicts' keys, so the
    answer is the same, but a hashed lookup avoids a linear scan on every
    dispatch. ``fn`` must be hashable (aten ops are).
    """
    return fn in NATIVE_BINARY_MAP or fn in NATIVE_INPLACE_BINARY_MAP
| |
| |
def _apply_native_binary(fn, *args, **kwargs):
    """Run the MaskedTensor handler registered for ``fn``.

    Out-of-place handlers take precedence over in-place ones. Each table is
    consulted with a single ``dict.get`` instead of a list scan followed by a
    second lookup.

    Returns:
        The handler's result, or ``NotImplemented`` if ``fn`` has no handler.
    """
    handler = NATIVE_BINARY_MAP.get(fn)
    if handler is None:
        handler = NATIVE_INPLACE_BINARY_MAP.get(fn)
    if handler is None:
        return NotImplemented
    return handler(*args, **kwargs)