Added a new test sample to the interpolate op in OpInfo (#104181)

Description:
- Added a new test sample to the interpolate op in OpInfo
- Fixed a silent issue with the zero-tensor test sample for the uint8 dtype

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104181
Approved by: https://github.com/pmeier, https://github.com/lezcano
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index c76be8d..cad4be5 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py
@@ -2924,6 +2924,7 @@ decorate('svd_lowrank', decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)})), decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')), decorate('linalg.pinv', 'singular', decorator=toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)})), + decorate('nn.functional.interpolate', 'bicubic', decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)})), # conv2d sometimes nondeterministic in this config? decorate('nn.functional.conv2d', decorator=unittest.skipIf(IS_ARM64, "flaky")), }
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 06c6457..427905a 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py
@@ -354,6 +354,13 @@ ("special.log_ndtr", "cuda", f64): {"atol": 1e-6, "rtol": 1e-5}, ("std_mean.unbiased", "cuda", f16): {"reference_in_float": True}, ("uniform", "cuda"): {"reference_in_float": True}, + # Temporarily skip interpolate bilinear and bicubic tests: + "nn.functional.interpolate.bicubic": { + "assert_equal": False, + "check_gradient": False, + }, + "nn.functional.interpolate.bilinear": {"assert_equal": False}, + "nn.functional.upsample_bilinear": {"assert_equal": False}, } # Always test with all sample for following ops
diff --git a/test/test_mps.py b/test/test_mps.py index 2961787..a6926ce 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -386,6 +386,9 @@ # cpu not giving nan for x/0.0 'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + + # inconsistency errors between cpu and mps, max seen atol is 2 + 'nn.functional.interpolatebilinear': [torch.uint8], } MACOS_BEFORE_13_3_XFAILLIST = { @@ -433,6 +436,8 @@ MACOS_AFTER_13_1_XFAILLIST = { # before macOS 13.2 it falls back to cpu and pass the forward pass 'grid_sampler_2d': [torch.float32], # Unsupported Border padding mode + # inconsistency errors between cpu and mps, max seen atol is 2 + 'nn.functional.interpolatebilinear': [torch.uint8], } MACOS_13_3_XFAILLIST = { @@ -10988,6 +10993,12 @@ elif op.name in ["pow", "__rpow__"]: atol = 1e-6 rtol = 4e-6 + elif op.name == "nn.functional.interpolate": + atol = 1e-3 + rtol = 1e-4 + elif op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8: + atol = 1.0 + rtol = 0.0 else: atol = None rtol = None @@ -11047,6 +11058,9 @@ rtol = 1.5e-3 elif op.name == "unique" and cpu_kwargs["sorted"] is False: continue + elif op.name == "nn.functional.interpolate": + atol = 1e-3 + rtol = 1e-4 else: atol = None rtol = None
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3eb5560..47897aa 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4293,22 +4293,78 @@ return tuple([N, C] + ([size] * rank)) return tuple([size] * rank) - make_arg = partial(make_tensor, device=device, dtype=dtype, - requires_grad=requires_grad, low=-1, high=1) + if mode in ('bilinear', 'bicubic') and dtype == torch.uint8: + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide few samples for a more close to typical image processing usage + rank = 2 + for memory_format in [torch.contiguous_format, torch.channels_last]: + yield SampleInput( + make_arg(shape(270, rank), memory_format=memory_format), + shape(130, rank, False), + scale_factor=None, + mode=mode, + align_corners=False, + ) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) for align_corners in align_corners_options: for rank in ranks_for_mode[mode]: - yield SampleInput(make_arg(shape(D, rank)), - shape(S, rank, False), None, mode, align_corners) - yield SampleInput(make_arg(shape(D, rank)), - shape(L, rank, False), None, mode, align_corners) + yield SampleInput( + make_arg(shape(D, rank)), + shape(S, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + yield SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) for recompute_scale_factor in [False, True]: - yield SampleInput(make_arg(shape(D, rank)), - None, 1.7, mode, align_corners, - recompute_scale_factor=recompute_scale_factor) - yield SampleInput(make_arg(shape(D, rank)), - None, 0.6, mode, align_corners, - recompute_scale_factor=recompute_scale_factor) + for scale_factor in [1.7, 0.6]: + yield SampleInput( + make_arg(shape(D, rank)), + size=None, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) + 
+def reference_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs) + + if mode in ('bilinear', 'bicubic'): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide few samples for more typical image processing usage + for memory_format in [torch.contiguous_format, torch.channels_last]: + for aa in [True, False]: + yield SampleInput( + make_arg((2, 3, 345, 456), memory_format=memory_format), + (270, 270), + scale_factor=None, + mode=mode, + align_corners=False, + antialias=aa, + ) def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): N, C = 2, 3 @@ -4326,8 +4382,7 @@ return torch.Size([N, C] + ([size] * rank)) return torch.Size([size] * rank) - make_arg = partial(make_tensor, device=device, dtype=dtype, - requires_grad=requires_grad, low=-1, high=1) + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) for rank in ranks_for_mode[mode]: yield SampleInput(make_arg(shape(D, rank)), size=shape(S, rank, False)) @@ -4335,8 +4390,26 @@ yield SampleInput(make_arg(shape(D, rank)), scale_factor=1.7) yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6) +def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs) -def sample_inputs_upsample_aten(mode, self, device, dtype, requires_grad, **kwargs): + if mode in ('bilinear', ): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide a single sample for more typical image 
processing usage + for memory_format in [torch.contiguous_format, torch.channels_last]: + yield SampleInput( + make_arg((2, 3, 345, 456), memory_format=memory_format), + (270, 270), + ) + +def sample_inputs_upsample_aa(mode, self, device, dtype, requires_grad, **kwargs): N = 6 C = 3 H = 10 @@ -4344,8 +4417,7 @@ S = 3 L = 5 - input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype, - requires_grad=requires_grad, low=-1, high=1) + input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype, requires_grad=requires_grad) yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None) yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None) @@ -4356,7 +4428,6 @@ yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=1.7, scales_w=0.9) yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=True, scales_h=1.7, scales_w=0.9) - def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs): N = 5 for _ in range(1, N): @@ -12895,6 +12966,7 @@ dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'), + reference_inputs_func=partial(reference_inputs_interpolate, 'bilinear'), skips=( # RuntimeError: false # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, @@ -12911,6 +12983,7 @@ dtypes=floating_types_and(torch.uint8, torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'), + reference_inputs_func=partial(reference_inputs_interpolate, 'bicubic'), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, skips=( # RuntimeError: false @@ -12961,6 +13034,7 @@ dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 
sample_inputs_func=partial(sample_inputs_upsample, 'bilinear'), + reference_inputs_func=partial(reference_inputs_upsample, 'bilinear'), skips=( # RuntimeError: false # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, @@ -12977,7 +13051,7 @@ dtypes=floating_types_and(torch.uint8), dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, - sample_inputs_func=partial(sample_inputs_upsample_aten, 'bilinear'), + sample_inputs_func=partial(sample_inputs_upsample_aa, 'bilinear'), supports_out=False, skips=( DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),