| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import tempfile |
| import unittest |
| from pathlib import Path |
| |
| import torch |
| |
| from executorch.extension.pybindings.test.make_test import ( |
| create_program, |
| ModuleAdd, |
| ModuleMulti, |
| ) |
| from executorch.runtime import Runtime, Verification |
| |
| |
| class RuntimeTest(unittest.TestCase): |
| def test_smoke(self): |
| ep, inputs = create_program(ModuleAdd()) |
| runtime = Runtime.get() |
| # Demonstrate that get() returns a singleton. |
| runtime2 = Runtime.get() |
| self.assertTrue(runtime is runtime2) |
| program = runtime.load_program(ep.buffer, verification=Verification.Minimal) |
| method = program.load_method("forward") |
| outputs = method.execute(inputs) |
| self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1])) |
| |
| def test_module_with_multiple_method_names(self): |
| ep, inputs = create_program(ModuleMulti()) |
| runtime = Runtime.get() |
| |
| program = runtime.load_program(ep.buffer, verification=Verification.Minimal) |
| self.assertEqual(program.method_names, set({"forward", "forward2"})) |
| method = program.load_method("forward") |
| outputs = method.execute(inputs) |
| self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1])) |
| |
| method = program.load_method("forward2") |
| outputs = method.execute(inputs) |
| self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1] + 1)) |
| |
| def test_print_operator_names(self): |
| ep, inputs = create_program(ModuleAdd()) |
| runtime = Runtime.get() |
| |
| operator_names = runtime.operator_registry.operator_names |
| self.assertGreater(len(operator_names), 0) |
| |
| self.assertIn("aten::add.out", operator_names) |
| |
| def test_load_program_with_path(self): |
| ep, inputs = create_program(ModuleAdd()) |
| runtime = Runtime.get() |
| |
| def test_add(program): |
| method = program.load_method("forward") |
| outputs = method.execute(inputs) |
| self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1])) |
| |
| with tempfile.NamedTemporaryFile() as f: |
| f.write(ep.buffer) |
| f.flush() |
| # filename |
| program = runtime.load_program(f.name) |
| test_add(program) |
| # pathlib.Path |
| path = Path(f.name) |
| program = runtime.load_program(path) |
| test_add(program) |
| # BytesIO |
| with open(f.name, "rb") as f: |
| program = runtime.load_program(f.read()) |
| test_add(program) |