| # Owner(s): ["module: dynamo"] |
| import torch |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| import torch.onnx.operators |
| |
| |
def fn(a, b):
    """Return ``a`` plus ``b`` scaled by 0.67.

    This is the base callable that the tests in this file wrap with FX
    tracing, ``torch.jit.script`` and ``torch.jit.trace``.
    """
    scaled_b = b * 0.67
    return a + scaled_b
| |
| |
class InteropTests(torch._dynamo.test_case.TestCase):
    """Tests that ``torch.compile`` can call into FX-, script- and trace-produced callables."""

    def _common(self, fn):
        """Check that ``fn`` gives the same result eagerly and under ``torch.compile``.

        ``fn`` is called with two ``torch.randn(10)`` tensors. The compiled
        variant uses the ``"eager"`` backend with ``fullgraph=True``.
        """
        args = [torch.randn(10), torch.randn(10)]
        expected = fn(*args)
        actual = torch.compile(fn, backend="eager", fullgraph=True)(*args)
        self.assertEqual(expected, actual)

    def test_fx_fn(self):
        """Compile a function that calls an FX symbolic trace of ``fn``."""
        traced = torch.fx.symbolic_trace(fn)

        def call_plus_one(a, b):
            return traced(a, b) + 1

        self._common(call_plus_one)

    def test_script_fn(self):
        """Compile a function that calls a ``torch.jit.script`` version of ``fn``."""
        scripted = torch.jit.script(fn)

        def call_plus_one(a, b):
            return scripted(a, b) + 1

        self._common(call_plus_one)

    def test_trace_fn(self):
        """Compile a function that calls a ``torch.jit.trace`` version of ``fn``."""
        example_inputs = [torch.zeros(10), torch.zeros(10)]
        jit_traced = torch.jit.trace(fn, example_inputs)

        def call_plus_one(a, b):
            return jit_traced(a, b) + 1

        self._common(call_plus_one)

    def test_vmap_in_graph(self):
        """Compile a ``torch.vmap`` call directly and via an ``allow_in_graph`` wrapper."""
        from functools import wraps

        from torch._dynamo import allow_in_graph

        def traceable(func):
            # Mark ``func`` with allow_in_graph, then hide it behind a plain
            # Python wrapper that forwards all arguments.
            allowed = allow_in_graph(func)

            @wraps(allowed)
            def passthrough(*args, **kwargs):
                return allowed(*args, **kwargs)

            return passthrough

        counter = torch._dynamo.testing.CompileCounter()
        batch = torch.randn(3, 5, 3)

        def transpose_each(x):
            return torch.vmap(torch.Tensor.t)(x)

        compiled_plain = torch.compile(transpose_each, backend=counter, fullgraph=True)
        compiled_traceable = torch.compile(
            traceable(transpose_each), backend=counter, fullgraph=True
        )

        # One frame is expected after the first compiled call ...
        self.assertEqual(transpose_each(batch), compiled_plain(batch))
        self.assertEqual(counter.frame_count, 1)
        # ... and one more once the wrapped variant has been called as well.
        self.assertEqual(compiled_plain(batch), compiled_traceable(batch))
        self.assertEqual(counter.frame_count, 2)
| |
| |
if __name__ == "__main__":
    # Script entry point: hand control to the Dynamo test runner. The module
    # is already imported at the top of the file, so call it by full path.
    torch._dynamo.test_case.run_tests()