# Copyright (c) Meta Platforms, Inc. and affiliates

import warnings

import torch

from .core import is_masked_tensor
from .creation import as_masked_tensor, masked_tensor

__all__ = []  # type: ignore[var-annotated]


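# Masked "all": masked-out elements are filled with True (the identity of
# logical-and) so that they do not affect the reduction.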
def _masked_all_all(data, mask=None):
    if mask is None:
        return data.all()
    return data.masked_fill(~mask, True).all()


def _masked_all_dim(data, dim, keepdim=False, mask=None):
    if mask is None:
        return torch.all(data, dim=dim, keepdim=keepdim)
    return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)


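# Dispatcher for the masked "all" reduction: a single positional argument plus a
# "mask" keyword means a full reduction; anything else is a dim-wise reduction.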
def _masked_all(*args, **kwargs):
    if len(args) == 1 and len(kwargs) == 1:
        return _masked_all_all(args[0], mask=kwargs["mask"])
    return _masked_all_dim(*args, **kwargs)


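# Reduce the mask with torch.any over one or more dims. Dims are visited in
# descending order so that, when keepdim=False, reducing one dim does not shift
# the positions of the dims still to be reduced.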
def _multidim_any(mask, dim, keepdim):
    if isinstance(dim, int):
        return _multidim_any(mask, [dim], keepdim)
    for d in sorted(dim, reverse=True):
        mask = torch.any(mask, dim=d, keepdim=keepdim)
    return mask


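# Look up the masked implementation of a reduction by name; "all" falls back to
# the local helpers above.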
def _get_masked_fn(fn):
    if fn == "all":
        return _masked_all
    return getattr(torch.masked, fn)


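# Full reduction (over all elements) for a MaskedTensor: the reduction is applied
# to the data restricted to the mask, and the result is wrapped in a MaskedTensor
# whose mask is True iff any input element was unmasked.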
def _torch_reduce_all(fn):
    def reduce_all(self):
        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask().values() if self.is_sparse() else self.get_mask()
        if fn == "all":
            result_data = masked_fn(data, mask=mask)

        # When reducing over all elements, torch.argmin/torch.argmax need to return
        # the flat index of the min/max element, but this isn't supported correctly
        # for sparse layouts, so the index is reconstructed here from the COO
        # indices and the strides of the dense shape.
        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
            indices = (
                data.to_sparse_coo().indices()
                if not self.is_sparse_coo()
                else data.indices()
            )
            idx = indices.unbind(1)[sparse_idx]
            stride = data.size().numel() / torch.tensor(
                data.size(), device=data.device
            ).cumprod(0)
            result_data = torch.sum(idx * stride)

        # we simply pass in the values for sparse COO/CSR tensors
        elif self.is_sparse():
            result_data = masked_fn(masked_tensor(data.values(), mask))

        else:
            result_data = masked_fn(self, mask=mask)

        return as_masked_tensor(result_data, torch.any(mask))

    return reduce_all


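# Dim-wise reduction for a MaskedTensor. Sparse layouts are not supported; the
# output mask is the input mask reduced with torch.any over the same dims.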
def _torch_reduce_dim(fn):
    def reduce_dim(self, dim, keepdim=False, dtype=None):
        if self.is_sparse():
            msg = (
                f"The sparse version of {fn} is not implemented in reductions.\n"
                "If you would like this operator to be supported, please file an issue for a feature request at "
                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
                "In the case that the semantics for the operator are not trivial, it would be appreciated "
                "to also include a proposal for the semantics."
            )
            warnings.warn(msg)
            return NotImplemented
        if not is_masked_tensor(self):
            raise TypeError("Input to reduce_dim must be a MaskedTensor")

        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask()
        if fn == "all":
            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
        else:
            result_data = masked_fn(
                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=mask
            )
        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))

    return reduce_dim


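# Dispatch between the full and the dim-wise reduction based on the call signature.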
def _torch_reduce(fn):
    def reduce_fn(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        return _torch_reduce_dim(fn)(*args, **kwargs)

    return reduce_fn


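# Normalize the arguments of a dim-wise reduction into a fixed positional tuple.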
def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
    return input, dim, keepdim, dtype


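# Same dispatch as _torch_reduce, but with the dim-wise arguments normalized to
# positional form first (see the TODO below on autograd.Function and kwargs).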
def _torch_grad_reduce(fn):
    def grad_reduce(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        # TODO: autograd.Function doesn't support kwargs
        input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
        return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)

    return grad_reduce


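# Reductions overridden for MaskedTensor, mapped below from their aten, torch,
# and torch.Tensor entry points to the wrappers above.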
REDUCE_NAMES = [
    "sum",
    "mean",
    "amin",
    "amax",
    "argmin",
    "argmax",
    "prod",
    "all",
    "norm",
    "var",
    "std",
]

NATIVE_REDUCE_MAP = {
    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
}
TORCH_REDUCE_MAP = {
    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
TENSOR_REDUCE_MAP = {
    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}

NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())


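# Check whether a callable is one of the registered reduction entry points, and
# apply the corresponding wrapper if so.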
def _is_reduction(fn):
    return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP


def _apply_reduction(fn, *args, **kwargs):
    if fn in NATIVE_REDUCE_MAP:
        return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TORCH_REDUCE_MAP:
        return TORCH_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TENSOR_REDUCE_MAP:
        return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
    return NotImplemented