| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| from dataclasses import dataclass |
| from typing import get_args, List, Optional, Sequence, Union |
| |
| import torch |
| |
| from torch.utils._pytree import tree_flatten |
| |
| from typing_extensions import TypeAlias |
| |
| """ |
| The data types currently supported for element to be bundled. It should be |
| consistent with the types in bundled_program.schema.Value. |
| """ |
| ConfigValue: TypeAlias = Union[ |
| torch.Tensor, |
| int, |
| bool, |
| float, |
| ] |
| |
| """ |
| The data type of the input for method single execution. |
| """ |
| MethodInputType: TypeAlias = Sequence[ConfigValue] |
| |
| """ |
| The data type of the output for method single execution. |
| """ |
| MethodOutputType: TypeAlias = Sequence[torch.Tensor] |
| |
| """ |
| All supported types for input/expected output of MethodTestCase. |
| |
| Namedtuple is also supported and listed implicitly since it is a subclass of tuple. |
| """ |
| |
| # pyre-ignore |
| DataContainer: TypeAlias = Union[list, tuple, dict] |
| |
| |
class MethodTestCase:
    """Test case with inputs and, optionally, expected outputs.

    The expected_outputs are optional and only required if the user wants to
    verify model outputs after execution.
    """

    def __init__(
        self,
        inputs: MethodInputType,
        expected_outputs: Optional[MethodOutputType] = None,
    ) -> None:
        """Single test case for verifying a specific method.

        Args:
            inputs: All inputs required by eager_model with specific inference method for one-time execution.

                It is worth mentioning that, although both bundled program and ET runtime apis support setting input
                other than `torch.Tensor` type, only the input in `torch.Tensor` type will be actually updated in
                the method, and the rest of the inputs will just do a sanity check if they match the default value in method.

            expected_outputs: Expected output of given input for verification. It can be None if user only wants to use the test case for profiling.

        Raises:
            AssertionError: If any leaf of `inputs` or `expected_outputs`, after
                pytree flattening, is not an instance of a type in `ConfigValue`.
        """
        # TODO(gasoonjia): Update type check logic.
        # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sanity check.
        self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
        # An empty list (rather than None) means "no expected outputs to verify".
        self.expected_outputs: List[ConfigValue] = []
        if expected_outputs is not None:
            # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sanity check.
            self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)

    def _flatten_and_sanity_check(
        self, unflatten_data: DataContainer
    ) -> List[ConfigValue]:
        """Flatten the given data and check that every leaf has a supported type.

        Args:
            unflatten_data: Possibly nested data (list / tuple / dict /
                namedtuple) to be flattened with `tree_flatten`.

        Returns:
            The flattened leaves, in the order produced by `tree_flatten`.

        Raises:
            AssertionError: If any leaf is not an instance of a type in
                `ConfigValue`. A None leaf, if flattening produces one, is
                rejected by this same check because NoneType is not in
                `ConfigValue`.
        """
        # Loop-invariant: the tuple of accepted leaf types.
        supported_types = get_args(ConfigValue)
        flatten_data, _ = tree_flatten(unflatten_data)

        for data in flatten_data:
            # Raise explicitly instead of using `assert`, so validation is not
            # stripped under `python -O`. AssertionError is kept so existing
            # callers that expect it continue to work.
            if not isinstance(data, supported_types):
                raise AssertionError(
                    "The type of input {} with type {} is not supported.\n".format(
                        data, type(data)
                    )
                )

        return flatten_data
| |
| |
@dataclass
class MethodTestSuite:
    """Collection of test cases used to verify a single method.

    Attributes:
        method_name: Name of the method that the test cases target.
        test_cases: Test cases to run against that method.
    """

    # Name of the method that the test cases below are run against.
    method_name: str
    # Every MethodTestCase used to verify `method_name`.
    test_cases: Sequence[MethodTestCase]