| #!/usr/bin/env python3 |
| # Owner(s): ["oncall: mobile"] |
| # mypy: allow-untyped-defs |
| |
| import io |
| import textwrap |
| from typing import Dict, List, Optional |
| |
| import torch |
| import torch.utils.bundled_inputs |
| from torch.testing._internal.common_utils import run_tests, TestCase |
| |
| |
def model_size(sm):
    """Return the number of bytes ``torch.jit.save`` writes for ``sm``.

    The module is serialized into an in-memory buffer, so nothing touches
    the filesystem.
    """
    with io.BytesIO() as stream:
        torch.jit.save(sm, stream)
        serialized = stream.getvalue()
    return len(serialized)
| |
| |
def save_and_load(sm):
    """Round-trip ``sm`` through ``torch.jit.save`` / ``torch.jit.load`` in memory.

    Returns the module produced by ``torch.jit.load``.
    """
    written = io.BytesIO()
    torch.jit.save(sm, written)
    # A fresh stream over the same bytes starts at offset 0, which is
    # equivalent to rewinding the buffer that was just written.
    return torch.jit.load(io.BytesIO(written.getvalue()))
| |
| |
class TestBundledInputs(TestCase):
    """Tests for attaching bundled inputs to scripted modules.

    Each test scripts a small module, augments it through
    ``torch.utils.bundled_inputs``, and checks either the round-tripped
    inputs or the error raised for invalid usage.
    """

    def test_single_tensors(self):
        """Bundle assorted single-tensor samples and check size and round-trip."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        get_expr: List[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
            # Quantized uniform tensor.
            (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),),
        ]
        # Position of the bundle_randn sample above. Its inflated value is
        # random, so it is checked statistically instead of element-wise.
        random_idx = 5
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr
        )
        # Debugging aids:
        # print(get_expr[0])
        # print(sm._generate_bundled_inputs.code)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))
        # assertIs reports both operands on failure, unlike assertTrue(a is b).
        self.assertIs(loaded(*inflated[0]), inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            if idx != random_idx:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        random_tensor = inflated[random_idx][0]
        self.assertEqual(random_tensor.shape, (1 << 16,))
        self.assertEqual(random_tensor.mean().item(), 0, atol=0.025, rtol=0)
        self.assertEqual(random_tensor.std().item(), 1, atol=0.02, rtol=0)

    def test_large_tensor_with_inflation(self):
        """A large tensor wrapped with ``bundle_large_tensor`` round-trips unchanged."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)

        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        """Tensors that cannot be stored compactly are rejected when bundling."""

        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)]
                )

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))
        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the strides.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        """String and int arguments can be bundled and round-trip unchanged."""

        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        # assertEqual shows the actual value on failure, unlike assertTrue(a == b).
        self.assertEqual(loaded(*inflated[0]), "first 1")

    def test_multiple_methods_with_inputs(self):
        """Bundle inputs for ``forward`` and an exported method and check the helpers."""

        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        mm = torch.jit.script(MultipleMethodModel())
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
        ]
        # One description per entry in `samples`, in the same order.
        info = [
            "Tensor with small numel and small storage.",
            "Tensor with large numel and small storage.",
            "Tensor with small numel and large storage.",
            "Large zero tensor.",
            "Large channels-last ones tensor.",
        ]
        torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
            mm,
            inputs={mm.forward: samples, mm.foo: samples},
            info={mm.forward: info, mm.foo: info},
        )
        loaded = save_and_load(mm)
        inflated = loaded.get_all_bundled_inputs()

        # Make sure these functions are all consistent.
        self.assertEqual(inflated, samples)
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward())
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())

        # Check running and size helpers
        self.assertIs(loaded(*inflated[0]), inflated[0][0])
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

        # Check helper that work on all functions
        all_info = loaded.get_bundled_inputs_functions_and_info()
        self.assertEqual(set(all_info.keys()), {"forward", "foo"})
        self.assertEqual(
            all_info["forward"]["get_inputs_function_name"],
            ["get_all_bundled_inputs_for_forward"],
        )
        self.assertEqual(
            all_info["foo"]["get_inputs_function_name"],
            ["get_all_bundled_inputs_for_foo"],
        )
        self.assertEqual(all_info["forward"]["info"], info)
        self.assertEqual(all_info["foo"]["info"], info)

        # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs
        for func_info in all_info.values():
            input_func_name = func_info["get_inputs_function_name"][0]
            func_to_run = getattr(loaded, input_func_name)
            self.assertEqual(func_to_run(), samples)

    def test_multiple_methods_with_inputs_both_defined_failure(self):
        """Passing inputs for a function that already has a generator must fail."""

        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # Set up outside the assertRaises block so that a failure while
        # scripting or defining the method is not mistaken for the expected
        # error.
        mm = torch.jit.script(MultipleMethodModel())
        definition = textwrap.dedent(
            """
            def _generate_bundled_inputs_for_forward(self):
                return []
            """
        )
        mm.define(definition)

        # inputs defined 2 ways so should fail
        with self.assertRaises(Exception):
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: samples,
                    mm.foo: samples,
                },
            )

    def test_multiple_methods_with_inputs_neither_defined_failure(self):
        """Omitting inputs for a function that has no generator must fail."""

        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        mm = torch.jit.script(MultipleMethodModel())
        # Precondition: no hand-written generator exists for forward. This is
        # checked explicitly rather than by calling the missing method inside
        # the assertRaises block, which would raise AttributeError and end the
        # block before the augmentation call ever ran.
        self.assertFalse(hasattr(mm, "_generate_bundled_inputs_for_forward"))

        # inputs not defined so should fail
        with self.assertRaises(Exception):
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: None,
                    mm.foo: samples,
                },
            )

    def test_bad_inputs(self):
        """Malformed ``inputs`` arguments raise ``TypeError``."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        # Non list for input list
        m = torch.jit.script(SingleTensorModel())
        with self.assertRaises(TypeError):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs="foo",  # type: ignore[arg-type]
            )

        # List of non tuples. Most common error using the api.
        m = torch.jit.script(SingleTensorModel())
        with self.assertRaises(TypeError):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[torch.ones(1, 2)],  # type: ignore[list-item]
            )

    def test_double_augment_fail(self):
        """Augmenting the same model in place a second time raises."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaisesRegex(
            Exception, "Models can only be augmented with bundled inputs once."
        ):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m, inputs=[(torch.ones(1),)]
            )

    def test_double_augment_non_mutator(self):
        """``bundle_inputs`` leaves the model it was given without bundled inputs."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        # The source model must not have gained the bundled-input helpers.
        with self.assertRaises(AttributeError):
            m.get_all_bundled_inputs()
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])
        self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1))

    def test_double_augment_success(self):
        """``bundle_inputs`` can be applied again to an already-bundled model."""

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs={m.forward: [(torch.ones(1),)]}
        )
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])

        bundled_model2 = torch.utils.bundled_inputs.bundle_inputs(
            bundled_model, inputs=[(torch.ones(2),)]
        )
        self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)])

    def test_dict_args(self):
        """Bundle optional dict/list arguments through custom inflation helpers."""

        class MyModel(torch.nn.Module):
            def forward(
                self,
                arg1: Optional[Dict[str, torch.Tensor]],
                arg2: Optional[List[torch.Tensor]],
                arg3: torch.Tensor,
            ):
                if arg1 is None:
                    return arg3
                elif arg2 is None:
                    return arg1["a"] + arg1["b"]
                else:
                    return arg1["a"] + arg1["b"] + arg2[0]

        small_sample = {
            "a": torch.zeros([10, 20]),
            "b": torch.zeros([1, 1]),
            "c": torch.zeros([10, 20]),
        }
        small_list = [torch.zeros([10, 20])]

        big_sample = {
            "a": torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            "b": torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            "c": torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        }
        big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])]

        def condensed(t):
            # Build a tensor with the shape of `t` that is backed by a single
            # stored element, so the bundled stub stays small.
            ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape)
            assert ret.storage().size() == 1
            # ret.storage()[0] = 0
            return ret

        def bundle_optional_dict_of_randn(template):
            # The bundled value is a dict of condensed stubs; the helper given
            # in fmt_fn replaces each stub with torch.randn_like at inflation.
            return torch.utils.bundled_inputs.InflatableArg(
                value=(
                    None
                    if template is None
                    else {k: condensed(v) for (k, v) in template.items()}
                ),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[Dict[str, Tensor]]):
                    if value is None:
                        return None
                    output = {{}}
                    for k, v in value.items():
                        output[k] = torch.randn_like(v)
                    return output
                """,
            )

        def bundle_optional_list_of_randn(template):
            # List counterpart of bundle_optional_dict_of_randn.
            return torch.utils.bundled_inputs.InflatableArg(
                value=(None if template is None else [condensed(v) for v in template]),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[List[Tensor]]):
                    if value is None:
                        return None
                    output = []
                    for v in value:
                        output.append(torch.randn_like(v))
                    return output
                """,
            )

        out: List[str] = []
        sm = torch.jit.script(MyModel())
        original_size = model_size(sm)
        small_inputs = (
            bundle_optional_dict_of_randn(small_sample),
            bundle_optional_list_of_randn(small_list),
            torch.zeros([3, 4]),
        )
        big_inputs = (
            bundle_optional_dict_of_randn(big_sample),
            bundle_optional_list_of_randn(big_list),
            torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm,
            [big_inputs, small_inputs],
            _receive_inflate_expr=out,
        )
        augmented_size = model_size(sm)
        # assert the size has not increased more than 8KB
        self.assertLess(augmented_size, original_size + (1 << 13))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated[0]), len(small_inputs))

        methods, _ = (
            torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
                loaded
            )
        )

        # One Function (forward)
        # two bundled inputs (big_inputs and small_inputs)
        # two args which have InflatableArg with fmt_fn
        # 1 * 2 * 2 = 4
        self.assertEqual(
            sum(method.startswith("_inflate_helper") for method in methods), 4
        )
| |
| |
# When executed as a script, hand control to run_tests (imported from
# torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()