# blob: e0f7f671147c31b6a3b6767245a34c49bf021507 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import unittest
import torch
import torch._C
from pathlib import Path
from torch.testing._internal.common_utils import IS_FBCODE
# hacky way to skip these tests in fbcode:
# during test execution in fbcode, test_nnapi is available during test discovery,
# but not during test execution. So we can't try-catch here, otherwise it'll think
# it sees tests but then fails when it tries to actually run them.
# Make the helper files in test/ importable.
# This must happen before `test_nnapi` is imported below; otherwise that import
# only succeeds when test/ already happens to be on sys.path (for example when
# the suite is launched through test/test_jit.py).
# NOTE(review): assumes test_nnapi.py lives in test/, the parent of this file's
# directory -- confirm against the repository layout.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if not IS_FBCODE:
    from test_nnapi import TestNNAPI
    HAS_TEST_NNAPI = True
else:
    # Stand-in base class so this module can still be imported; the test class
    # built on top of it is skipped in fbcode anyway.
    from torch.testing._internal.common_utils import TestCase as TestNNAPI
    HAS_TEST_NNAPI = False
if __name__ == "__main__":
    # Running this file as a script is not supported; the message points the
    # user at the test/test_jit.py entry point instead.
    usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(usage_hint)
"""
Unit Tests for Nnapi backend with delegate
Inherits most tests from TestNNAPI, which loads Android NNAPI models
without the delegate API.
"""
# The first skipIf below (libnnapi_backend.so not found) is needed so that the
# tests are also skipped when IS_WINDOWS or IS_MACOS.
# Directory three levels above this file.
torch_root = Path(__file__).resolve().parent.parent.parent
# Location where the NNAPI delegate shared library is looked up.
lib_path = torch_root.joinpath("build", "lib", "libnnapi_backend.so")
@unittest.skipIf(not os.path.exists(lib_path),
                 "Skipping the test as libnnapi_backend.so was not found")
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
class TestNnapiBackend(TestNNAPI):
    """Unit tests for the NNAPI backend used through the delegate API.

    Inherits its test cases from ``TestNNAPI`` and overrides
    ``call_lowering_to_nnapi`` so that modules are lowered with
    ``torch._C._jit_to_backend("nnapi", ...)``.  The class is skipped when
    ``libnnapi_backend.so`` is not found at ``lib_path`` or when ``IS_FBCODE``
    is set.
    """

    def setUp(self):
        """Switch the default dtype to float32 and load the delegate library."""
        super().setUp()

        # Save default dtype
        self.default_dtype = torch.get_default_dtype()
        # unittest does not call tearDown() when setUp() raises, but it still
        # runs registered cleanups.  Registering the restore here keeps the
        # float32 override from leaking into other tests if loading the
        # library below fails.
        self.addCleanup(torch.set_default_dtype, self.default_dtype)

        # Change dtype to float32 (since a different unit test changed dtype to float64,
        # which is not supported by the Android NNAPI delegate)
        # Float32 should typically be the default in other files.
        torch.set_default_dtype(torch.float32)

        # Load nnapi delegate library
        torch.ops.load_library(str(lib_path))

    # Override
    def call_lowering_to_nnapi(self, traced_module, args):
        """Lower ``traced_module`` to the "nnapi" backend.

        ``args`` is passed as the "inputs" entry of the "forward" compile spec.
        Returns whatever ``torch._C._jit_to_backend`` returns.
        """
        compile_spec = {"forward": {"inputs": args}}
        return torch._C._jit_to_backend("nnapi", traced_module, compile_spec)

    def test_tensor_input(self):
        """Lowering accepts both a single Tensor and a list holding a Tensor."""
        # Lower a simple module
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Argument input is a single Tensor
        self.call_lowering_to_nnapi(traced, args)
        # Argument input is a Tensor in a list
        self.call_lowering_to_nnapi(traced, [args])

    # Test exceptions for incorrect compile specs
    # (the misspelled name is kept so that selecting the test by name keeps working)
    def test_compile_spec_santiy(self):
        """Malformed compile specs raise RuntimeError with the expected message."""
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Common tail of every error message.  It is used as part of a regex
        # pattern by assertRaisesRegex below.
        error_msg_tail = r"""
method_compile_spec should contain a Tensor or Tensor List which bundles input parameters: shape, dtype, quantization, and dimorder.
For input shapes, use 0 for run/load time flexible input.
method_compile_spec must use the following format:
{"forward": {"inputs": at::Tensor}} OR {"forward": {"inputs": c10::List<at::Tensor>}}"""

        no_forward_key = "method_compile_spec does not contain the \"forward\" key."
        no_inputs_dict = (
            "method_compile_spec does not contain a dictionary with an \"inputs\" key, "
            "under it's \"forward\" key."
        )
        no_tensor_input = (
            "method_compile_spec does not contain either a Tensor or TensorList, under it's \"inputs\" key."
        )

        # (description, compile spec, expected start of the error message)
        bad_specs = [
            ("no forward key",
             {"backward": {"inputs": args}}, no_forward_key),
            ("no dictionary under the forward key",
             {"forward": 1}, no_inputs_dict),
            ("no inputs key in the dictionary under the forward key",
             {"forward": {"not inputs": args}}, no_inputs_dict),
            ("inputs is not a Tensor or TensorList",
             {"forward": {"inputs": 1}}, no_tensor_input),
            ("inputs is a list that does not hold Tensors",
             {"forward": {"inputs": [1]}}, no_tensor_input),
        ]

        for description, compile_spec, error_head in bad_specs:
            with self.subTest(case=description):
                with self.assertRaisesRegex(RuntimeError, error_head + error_msg_tail):
                    torch._C._jit_to_backend("nnapi", traced, compile_spec)

    def tearDown(self):
        """Restore the default dtype saved in setUp, then run base-class teardown."""
        # Change dtype back to default (Otherwise, other unit tests will complain)
        torch.set_default_dtype(self.default_dtype)
        super().tearDown()