# blob: dfd26f057484d708e1c95ac9209f0e7fe03fc3eb [file] [log] [blame]
# Owner(s): ["module: masked operators"]
import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
make_tensor,
instantiate_parametrized_tests,
)
from torch.testing._internal.common_methods_invocations import (
SampleInput,
)
from torch.masked.maskedtensor.core import _masks_match, _tensors_match
def _compare_mt_t(mt_result, t_result):
mask = mt_result.get_mask()
mt_result_data = mt_result.get_data()
if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
mask = mask.to_dense()
if mt_result_data.layout in {torch.sparse_coo, torch.sparse_csr}:
mt_result_data = mt_result_data.to_dense()
a = mt_result_data.detach().masked_fill_(~mask, 0)
b = t_result.detach().masked_fill_(~mask, 0)
if not _tensors_match(a, b, exact=False):
raise ValueError("The data in MaskedTensor a and Tensor b do not match")
def _compare_mts(mt1, mt2):
mt_data1 = mt1.get_data()
mt_data2 = mt2.get_data()
if mt_data1.layout != mt_data2.layout:
raise ValueError("mt1's data and mt2's data do not have the same layout. "
f"mt1.get_data().layout = {mt_data1.layout} while mt2.get_data().layout = {mt_data2.layout}")
mask = mt1.get_mask()
mask2 = mt2.get_mask()
if not _masks_match(mt1, mt2):
raise ValueError("mt1 and mt2 must have matching masks")
if mask.layout != mask2.layout:
raise ValueError("mt1's mask and mt2's mask do not have the same layout. "
f"mt1.get_mask().layout = {mask.layout} while mt2.get_mask().layout = {mask2.layout}")
if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
mask = mask.to_dense()
if mt_data1.layout in {torch.sparse_coo, torch.sparse_csr}:
mt_data1 = mt_data1.to_dense()
mt_data2 = mt_data2.to_dense()
a = mt_data1.detach().masked_fill_(~mask, 0)
b = mt_data2.detach().masked_fill_(~mask, 0)
if not _tensors_match(a, b, exact=False):
raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")
def _create_random_mask(shape, device):
return make_tensor(
shape, device=device, dtype=torch.bool, low=0, high=1, requires_grad=False
)
def _generate_sample_data(
    device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided
):
    """Build a list of ``SampleInput(data, kwargs={"mask": mask})`` samples.

    For each shape in ``[]``, ``[2]``, ``[3, 5]``, ``[3, 2, 1, 2]``, a random
    data tensor is drawn with ``make_tensor`` and then a random boolean mask
    with ``_create_random_mask``, in that order. The pair is then adapted to
    ``layout``:

    * ``torch.strided``: data and mask are used unchanged.
    * ``torch.sparse_coo``: the mask becomes a coalesced COO tensor, the data
      is restricted to it with ``sparse_mask``, and
      ``requires_grad_(requires_grad)`` is applied to the result.
    * ``torch.sparse_csr``: a shape is skipped when neither data nor mask is
      2-D. Since both share the shape, only ``[3, 5]`` survives. The mask
      becomes CSR and the data is restricted with ``sparse_mask``.
      NOTE(review): unlike the COO path, ``requires_grad_`` is not applied
      here -- confirm that is intended.

    Raises:
        AssertionError: if ``layout`` is not strided/sparse_coo/sparse_csr.
    """
    supported_layouts = {torch.strided, torch.sparse_coo, torch.sparse_csr}
    assert layout in supported_layouts, "Layout must be strided/sparse_coo/sparse_csr"

    samples = []
    for shape in ([], [2], [3, 5], [3, 2, 1, 2]):
        # Data is drawn before the mask, even for shapes later skipped, so the
        # sequence of random draws is unchanged.
        values = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)  # type: ignore[arg-type]
        keep = _create_random_mask(shape, device)

        if layout == torch.sparse_csr:
            # CSR samples are only built for 2-D inputs.
            neither_is_2d = values.ndim != 2 and keep.ndim != 2
            if neither_is_2d:
                continue
            keep = keep.to_sparse_csr()
            values = values.sparse_mask(keep)
        elif layout == torch.sparse_coo:
            keep = keep.to_sparse_coo().coalesce()
            values = values.sparse_mask(keep).requires_grad_(requires_grad)

        samples.append(SampleInput(values, kwargs={"mask": keep}))
    return samples
class TestBasics(TestCase):
    """Basic MaskedTensor test case; currently holds only a placeholder method."""

    def sample_test(self):
        """Placeholder that does nothing and returns None.

        NOTE(review): the name does not start with ``test``, so unittest-style
        discovery will not collect it -- confirm whether that is intended.
        """
# Hand TestBasics to instantiate_parametrized_tests (presumably to expand any
# parametrized test methods declared on it).
instantiate_parametrized_tests(TestBasics)

# Run the tests only when this file is executed directly, not when imported.
if __name__ == "__main__":
    run_tests()