#include <ATen/core/dispatch/OperatorOptions.h>
#include <c10/core/ScalarType.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <stdexcept>
#include "deep_wide_pt.h"
#include "test_utils.h"
using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;
using c10::IValue;
/*
 When adding a test for an operator implemented in static runtime, there are
 several things that you need to pay attention to:

 1) If the op is an out variant, then in the test script of the op,
 instead of:

    def forward(self, input):
        return myop(input)

 do:

    def forward(self, input):
        return myop(input).clone()

 This makes sure that the output of myop is managed by the memory planner and
 exercises the code path in the op impl that otherwise doesn't get exercised.
 The output of the model is not managed by the memory planner, because it
 needs to be returned to the client.

 2) The memory planner rounds up the size of each Tensor's storage to a
 multiple of 64 bytes (the alignment requirement on AVX512). Make sure the
 sizes of the input tensors in args2 are big enough to trigger resizing.

 3) For view ops such as aten::reshape or aten::to, if you want the op to be
 replaced by its copy version via the ReplaceWithCopy pass in passes.h, you
 also need to make sure its output is not returned as the model output. The
 reason is that ReplaceWithCopy only replaces an op whose output is not an
 alias of the model output.
*/
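// A minimal sketch of the pattern described above, using a hypothetical
// `torch.myop` (illustrative only; the real tests below all follow this
// shape):
//
//   const auto myop_script = R"JIT(
//     def forward(self, input):
//         return torch.myop(input).clone()
//   )JIT";
//   std::vector<IValue> args{at::randn({2, 3})};
//   std::vector<IValue> args2{at::randn({8, 9, 10})};  // larger, to force resizing
//   testStaticRuntime(myop_script, args, args2);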
C10_DECLARE_bool(static_runtime_enable_fast_math);
TEST(StaticRuntime, UnaryOps) {
  const auto aten_sum = R"JIT(
    def forward(self, input):
        return torch.sum(input).clone()
  )JIT";
  const auto aten_sum_0 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0).clone()
  )JIT";
  const auto aten_sum_1 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1).clone()
  )JIT";
  const auto aten_sum_0_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0, True).clone()
  )JIT";
  const auto aten_sum_1_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1, True).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 3, 6});
  std::vector<IValue> args{a}, args2{b};

  // sum
  testStaticRuntime(aten_sum, args);
  testStaticRuntime(aten_sum_0, args);
  testStaticRuntime(aten_sum_1, args);
  testStaticRuntime(aten_sum_0_true, args);
  testStaticRuntime(aten_sum_1_true, args);

  testStaticRuntime(
      aten_sum,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(aten_sum_0, args, args2);
  testStaticRuntime(aten_sum_1, args, args2);
  testStaticRuntime(aten_sum_0_true, args, args2);
  testStaticRuntime(aten_sum_1_true, args, args2);
}

TEST(StaticRuntime, Max) {
  auto src_max_reduce = R"JIT(
    def forward(self, input):
        return torch.max(input).clone()
  )JIT";
  auto src_max_dim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim)
        return values.clone(), indices.clone()
  )JIT";
  auto src_max_dim_keepdim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim, keepdim=True)
        return values.clone(), indices.clone()
  )JIT";
  auto src_max_pointwise = R"JIT(
    def forward(self, input, other):
        return torch.max(input, other).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto input_other = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 9, 10});
  auto large_input_other = at::randn({8, 9, 10});

  testStaticRuntime(src_max_reduce, {input});
  testStaticRuntime(src_max_dim, {input, 1});
  testStaticRuntime(src_max_dim, {input, 1}, {large_input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0}, {large_input, 2});
  testStaticRuntime(src_max_pointwise, {input, input_other});
  testStaticRuntime(
      src_max_pointwise, {input, input_other}, {large_input, large_input_other});
}

TEST(StaticRuntime, Mean) {
  const auto src_default = R"JIT(
    def forward(self, input):
        return torch.mean(input).clone()
  )JIT";
  const auto src_dtype = R"JIT(
    def forward(self, input, dtype: int):
        return torch.mean(input, dtype=dtype).clone()
  )JIT";
  const auto src_dim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim).clone()
  )JIT";
  const auto src_dim_keepdim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim, keepdim=True).clone()
  )JIT";
  const auto src_dim_dtype = R"JIT(
    def forward(self, input, dim: List[int], dtype: int):
        return torch.mean(input, dim, dtype=dtype).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 7, 6, 8});

  std::vector<IValue> args_default = {input};
  std::vector<IValue> args_dtype = {input, torch::kFloat};
  std::vector<IValue> args_dim = {input, c10::List<int64_t>{0, 2}};
  std::vector<IValue> args_dim_keepdim = {input, c10::List<int64_t>{1, 2}};
  std::vector<IValue> args_dim_dtype = {
      input, c10::List<int64_t>{0, 1}, torch::kBFloat16};

  testStaticRuntime(src_default, args_default);
  testStaticRuntime(src_dtype, args_dtype);
  testStaticRuntime(src_dim, args_dim);
  testStaticRuntime(src_dim_keepdim, args_dim_keepdim);
  testStaticRuntime(src_dim_dtype, args_dim_dtype);

  std::vector<IValue> large_args_dim = {large_input, c10::List<int64_t>{0, 3}};
  std::vector<IValue> large_args_dim_keepdim = {
      large_input, c10::List<int64_t>{1, 2}};
  std::vector<IValue> large_args_dim_dtype = {
      large_input, c10::List<int64_t>{1, 3}, torch::kBFloat16};

  testStaticRuntime(src_dim, args_dim, large_args_dim);
  testStaticRuntime(src_dim_keepdim, args_dim_keepdim, large_args_dim_keepdim);
  testStaticRuntime(src_dim_dtype, args_dim_dtype, large_args_dim_dtype);
}

TEST(StaticRuntime, Sigmoid) {
  const auto sigmoid_script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        return (b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});
  std::vector<IValue> args{a}, args2{b};

  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, {args2}, /*use_allclose=*/true);
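  // Re-run with fast math disabled so that both sigmoid code paths (with and
  // without the fast-math approximation) get exercised.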
  FLAGS_static_runtime_enable_fast_math = false;
  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, {args2}, /*use_allclose=*/true);
  FLAGS_static_runtime_enable_fast_math = true;
}

TEST(StaticRuntime, Clone) {
  /*
  Clone is called twice to trigger the memory planner for the output of the
  first clone. The output of the last op (the second clone) is not managed by
  the memory planner since it needs to be returned to the client and cannot
  be reused by the planner.
  */
  const auto clone_script_0 = R"JIT(
    def forward(self, input):
        a = torch.clone(input).clone()
        return (a * a)
  )JIT";

  // Case: clone with different sets of memory_formats
  const auto clone_script_1 = R"JIT(
    def forward(self, input: Tensor, memory_format: int):
        a = torch.clone(input, memory_format=memory_format).clone()
        return (a * a)
  )JIT";

  /*
  Case: input stride set to 0 (due to the expand op), which
  calls the native clone instead of the out variant
  */
  const auto clone_script_2 = R"JIT(
    def forward(self, input: Tensor, other: Tensor):
        a = input.expand_as(other)
        return a.clone().clone()
  )JIT";

  /*
  Case: sliced input tensor, to exercise non-contiguous
  tensor storage
  */
  const auto clone_script_3 = R"JIT(
    def forward(self, input: Tensor):
        a = input[:, 0:10:2]
        return a.clone().clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 2}).as_strided({3, 2}, {1, 3});
  auto b_larger = at::randn({30, 20}).as_strided({30, 20}, {1, 3});
  auto c = at::randn({1, 20, 13, 8});
  auto d = at::randn({1, 0, 3, 4});
  auto e = at::randn({2, 1});
  auto f = at::randn({2, 10});
  auto g = at::randn({3, 20});
  std::vector<IValue> args_0{b, c10::MemoryFormat::Contiguous};
  std::vector<IValue> args_1{b_larger, c10::MemoryFormat::Preserve};
  std::vector<IValue> args_2{c, c10::MemoryFormat::ChannelsLast};
  std::vector<IValue> args_3{d, c10::MemoryFormat::ChannelsLast};
  std::vector<IValue> args_4{e, a};
  std::vector<IValue> args_5{e, f};

  testStaticRuntime(clone_script_0, {a});
  testStaticRuntime(clone_script_0, {a}, {b_larger});
  testStaticRuntime(clone_script_1, args_0);
  testStaticRuntime(clone_script_1, args_1);
  testStaticRuntime(clone_script_1, args_2);
  testStaticRuntime(clone_script_1, args_3);
  testStaticRuntime(clone_script_1, args_0, args_1);
  testStaticRuntime(clone_script_1, args_3, args_2);
  testStaticRuntime(clone_script_2, args_4);
  testStaticRuntime(clone_script_2, args_4, args_5);
  testStaticRuntime(clone_script_3, {f});
  testStaticRuntime(clone_script_3, {f}, {g});
}

TEST(StaticRuntime, Clamp) {
  const auto clamp_script_1 = R"JIT(
    def forward(self, inp: Tensor, min: int, max: int):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  const auto clamp_script_2 = R"JIT(
    def forward(self, inp: Tensor, min: Tensor, max: Tensor):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";

  auto a = at::randn({2, 3});
  auto max_t = at::full_like(a, 1);
  auto min_t = at::full_like(a, -1);

  auto b = at::randn({4, 3, 2});
  auto max_t1 = at::full_like(b, 1);
  auto min_t1 = at::full_like(b, -1);

  testStaticRuntime(clamp_script_1, {a, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t});

  testStaticRuntime(clamp_script_1, {a, -1, 1}, {b, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, max_t1, min_t1});
}

TEST(StaticRuntime, ClampMinOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float):
        a = torch.clamp(inp, min, None).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});
  testStaticRuntime(src, {a, 0.5});
  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}

TEST(StaticRuntime, ClampMaxOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, max: float):
        a = torch.clamp(inp, None, max).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});
  testStaticRuntime(src, {a, 0.5});
  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}

TEST(StaticRuntime, ClampIntTensor) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float, max: float):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  auto a = at::randint(0, 20, {2, 3}, at::kFloat);
  auto b = at::randint(0, 20, {4, 3, 2}, at::kFloat);
  auto min = 5.0f;
  auto max = 5.0f;
  testStaticRuntime(src, {a, min, max});
  testStaticRuntime(src, {a, min, max}, {b, min, max});
}

TEST(StaticRuntime, LenWithTuple) {
  const auto src = R"IR(
    graph(%input : int[]):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {c10::List<int64_t>(4)});
}

TEST(StaticRuntime, LenWithTensor) {
  const auto src = R"IR(
    graph(%input : Tensor):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {at::randn({2, 2, 2})});
}

TEST(StaticRuntime, LenWithStr) {
  const auto src = R"IR(
    graph(%input : str):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {"static_runtime"});
}

TEST(StaticRuntime, LenWithDict_str) {
  const auto script = R"JIT(
    def forward(self, input: Dict[str, str]):
        return len(input)
  )JIT";

  c10::Dict<std::string, std::string> dict;
  dict.insert("abc", "123");
  dict.insert("def", "456");
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_int) {
  const auto script = R"JIT(
    def forward(self, input: Dict[int, int]):
        return len(input)
  )JIT";

  c10::Dict<int64_t, int64_t> dict;
  dict.insert(0, 1);
  dict.insert(2, 3);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_bool) {
  const auto script = R"JIT(
    def forward(self, input: Dict[bool, bool]):
        return len(input)
  )JIT";

  c10::Dict<bool, bool> dict;
  dict.insert(true, false);
  dict.insert(false, true);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_float) {
  const auto script = R"JIT(
    def forward(self, input: Dict[float, float]):
        return len(input)
  )JIT";

  c10::Dict<double, double> dict;
  dict.insert(0.1, 0.9);
  dict.insert(0.8, 0.18);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_complex) {
  const auto script = R"JIT(
    def forward(self, input: Dict[complex, complex]):
        return len(input)
  )JIT";

  c10::Dict<c10::complex<double>, c10::complex<double>> dict;
  dict.insert(0.1, 0.4);
  dict.insert(0.9, 0.45);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_Tensor) {
  const auto script = R"JIT(
    def forward(self, input: Dict[Tensor, Tensor]):
        return len(input)
  )JIT";

  c10::Dict<at::Tensor, at::Tensor> dict;
  dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, Logit) {
  // no nnc
  const auto logit_script_1 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp).clone()
        return (a)
  )JIT";

  // with nnc
  const auto logit_script_2 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp, 1e-6).clone()
        return (a)
  )JIT";

  // no nnc
  const auto logit_script_3 = R"JIT(
    def forward(self, inp: Tensor, eps: float):
        a = torch.logit(inp, eps).clone()
        return (a)
  )JIT";

  auto a = at::ones({2, 3});
  double b = 1e-6;
  std::vector<IValue> args_1{a};
  std::vector<IValue> args_2({a, b});

  auto c = at::ones({4, 3, 2});

  // logit
  testStaticRuntime(logit_script_1, args_1);
  testStaticRuntime(logit_script_2, args_1);
  testStaticRuntime(logit_script_3, args_2);

  testStaticRuntime(logit_script_1, args_1, {c});
  testStaticRuntime(logit_script_2, args_1, {c});
  testStaticRuntime(logit_script_3, args_2, {c, b});
}

TEST(StaticRuntime, EmbeddingBag) {
  const std::string embedding_bag_default = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";
  const std::string embedding_bag_mean = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";
  const std::string embedding_bag_max = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 2)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";
  const std::string embedding_bag_sum_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 0, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";
  const std::string embedding_bag_mean_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";
  const std::string embedding_bag_max_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 2, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2, 4});
  std::vector<IValue> args{weight, input, offset};
  testStaticRuntime(embedding_bag_default, args);
  testStaticRuntime(embedding_bag_mean, args);
  testStaticRuntime(embedding_bag_max, args);
  testStaticRuntime(embedding_bag_sum_last_offset, args);
  testStaticRuntime(embedding_bag_mean_last_offset, args);
  testStaticRuntime(embedding_bag_max_last_offset, args);

  at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};
  testStaticRuntime(embedding_bag_default, args, args2);
  testStaticRuntime(embedding_bag_mean, args, args2);
  testStaticRuntime(embedding_bag_max, args, args2);
  testStaticRuntime(embedding_bag_sum_last_offset, args, args2);
  testStaticRuntime(embedding_bag_mean_last_offset, args, args2);
  testStaticRuntime(embedding_bag_max_last_offset, args, args2);
}

TEST(StaticRuntime, EmbeddingBagWithManagedOutput) {
  const std::string embedding_bag_managed_output = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        # The outputs of embedding_bag become intermediate tensors
        # since they are not directly returned from the graph.
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return x + x
  )JIT";

  at::Tensor weight = torch::randn({3, 8}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2});
  std::vector<IValue> args{weight, input, offset};

  at::Tensor weight2 = torch::randn({6, 8}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 3, 4});
  at::Tensor offset2 = torch::tensor({0, 2, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};

  testStaticRuntime(embedding_bag_managed_output, args);
  testStaticRuntime(embedding_bag_managed_output, args, args2);
}

TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) {
  const std::string embedding_bag_default_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
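  // Only the first output is consumed, so RemoveUnnecessaryOutputs should
  // swap in the static_runtime::embedding_bag variant.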
  auto graph = getGraphFromIR(embedding_bag_default_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_mean_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=1]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  graph = getGraphFromIR(embedding_bag_mean_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_max_last_offset_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=2]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=1]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  graph = getGraphFromIR(embedding_bag_max_last_offset_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_normal_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res0 : Tensor = aten::clone(%y0, %none)
        %res1 : Tensor = aten::clone(%y1, %none)
        %res2 : Tensor = aten::clone(%y2, %none)
        %res3 : Tensor = aten::clone(%y3, %none)
        return (%res0, %res1, %res2, %res3)
  )IR";
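  // All four outputs are consumed here, so the replacement must not fire and
  // aten::embedding_bag has to stay in place.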
  graph = getGraphFromIR(embedding_bag_normal_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check_not("static_runtime::embedding_bag")
      ->run(*graph);

  at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2, 4});
  std::vector<IValue> args{weight, input, offset};
  testStaticRuntime(embedding_bag_default_ir, args);
  testStaticRuntime(embedding_bag_mean_ir, args);
  testStaticRuntime(embedding_bag_max_last_offset_ir, args);

  at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};
  testStaticRuntime(embedding_bag_default_ir, args, args2);
  testStaticRuntime(embedding_bag_mean_ir, args, args2);
  testStaticRuntime(embedding_bag_max_last_offset_ir, args, args2);
}

TEST(StaticRuntime, LayerNorm) {
  const std::string layer_norm_with_weights = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
        return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
  )JIT";
  const std::string layer_norm_without_weights = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int]):
        return torch.layer_norm(input, normalized_shape, None, None, 1e-05, False).clone()
  )JIT";

  const auto a = torch::rand({1, 2, 2, 2});
  const auto b = torch::rand({3, 2, 2, 2});
  for (int normalized_size : {2, 3}) {
    std::vector<int64_t> normalized_shape(normalized_size, 2);
    const auto weight = torch::rand(normalized_shape);
    const auto bias = torch::rand(normalized_shape);

    std::vector<IValue> args{a, normalized_shape, weight, bias};
    std::vector<IValue> args1{b, normalized_shape, weight, bias};
    testStaticRuntime(layer_norm_with_weights, args);
    testStaticRuntime(layer_norm_with_weights, args, args1);

    args = {a, normalized_shape};
    testStaticRuntime(layer_norm_without_weights, args);
    testStaticRuntime(layer_norm_without_weights, args, {b, normalized_shape});
  }
}

TEST(StaticRuntime, Bmm) {
  const auto bmm_script = R"JIT(
    def forward(self, inp: Tensor, mat2: Tensor):
        return torch.bmm(inp, mat2).clone()
  )JIT";

  auto a = at::randn({10, 4, 5});
  auto b = at::randn({10, 5, 6});
  auto c = at::randn({12, 5, 6});
  auto d = at::randn({12, 6, 7});

  std::vector<IValue> args{a, b};
  std::vector<IValue> args1{c, d};
  testStaticRuntime(bmm_script, args);
  testStaticRuntime(bmm_script, args1);
  testStaticRuntime(bmm_script, args, args1);
}

TEST(StaticRuntime, Addmm) {
  const auto addmm_script = R"JIT(
    def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float):
        return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone()
  )JIT";

  auto inp1 = at::randn({5});
  auto mat1 = at::randn({3, 4});
  auto mat2 = at::randn({4, 5});
  auto inp2 = at::randn({3, 7});
  auto mat3 = at::randn({3, 6});
  auto mat4 = at::randn({6, 7});

  std::vector<IValue> args{inp1, mat1, mat2, 1.0, 2.0};
  std::vector<IValue> args1{inp2, mat3, mat4, 2.0, 1.0};
  testStaticRuntime(addmm_script, args);
  testStaticRuntime(addmm_script, args1);
  testStaticRuntime(addmm_script, args, args1);
}

TEST(StaticRuntime, Abs) {
  const auto abs_script = R"JIT(
    def forward(self, a):
        return a.abs().clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 2, 3});
  std::vector<IValue> args{a};
  std::vector<IValue> args2{b};
  testStaticRuntime(abs_script, args);
  testStaticRuntime(abs_script, args, args2);
}

TEST(StaticRuntime, Binary) {
  const auto add_script = R"JIT(
    def forward(self, a, b):
        c = a + b
        return (c.clone())
  )JIT";
  const auto add_script_ints = R"JIT(
    def forward(self, a: int, b: int):
        c = a + b
        d = c + 1
        return d
  )JIT";
  const auto add_list_script = R"JIT(
    def forward(self, a: List[int], b: List[int]):
        c = a + b
        return c[::]
  )JIT";
  const auto list_construct_script = R"JIT(
    def forward(self, a, b):
        return [a, b]
  )JIT";
  const auto list_construct_script_2 = R"JIT(
    def forward(self, a, b):
        c = a + a
        return [c, c]
  )JIT";
  const auto list_construct_script_3 = R"JIT(
    def forward(self, a, b):
        c = a + a
        return [c, c.flatten()]
  )JIT";
  const auto list_unpack_script = R"JIT(
    def forward(self, a, b):
        c = [a, b]
        x, y = c
        z = x + y
        return z.clone()
  )JIT";
  const auto list_unpack_script_2 = R"JIT(
    def forward(self, a, b):
        c = [a, b]
        x, y = c
        z = (x, y)
        return z
  )JIT";
  const auto tuple_construct_script = R"JIT(
    def forward(self, a, b):
        return (a, b)
  )JIT";
  const auto tuple_construct_script_2 = R"JIT(
    def forward(self, a, b):
        return (a.flatten(), b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::ones({2, 3});
  auto c = at::randn({4, 2, 3});
  auto d = at::ones({4, 2, 3});

  std::vector<IValue> args{a, b};

  testStaticRuntime(add_script, args);
  testStaticRuntime(add_script_ints, {1, 2});
  testStaticRuntime(add_script, args, {c, d});
  testStaticRuntime(list_construct_script, args);
  testStaticRuntime(list_construct_script_2, args);
  testStaticRuntime(list_construct_script_3, args);
  testStaticRuntime(list_unpack_script, args);
  testStaticRuntime(list_unpack_script_2, args);
  testStaticRuntime(tuple_construct_script, args);
  testStaticRuntime(tuple_construct_script_2, args);

  std::vector<IValue> list_args{
      c10::List<int64_t>{1, 2, 3}, c10::List<int64_t>{4, 5, 6}};
  testStaticRuntime(add_list_script, list_args);
}

TEST(StaticRuntime, MatMul) {
  const auto aten_matmul = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.matmul(a, b).clone()
  )JIT";

  // 1-D, 1-D
  std::vector<IValue> args{at::randn({3}), at::randn({3})};
  testStaticRuntime(aten_matmul, args);
  // 2-D, 2-D
  std::vector<IValue> args1 = {at::randn({3, 2}), at::randn({2, 3})};
  testStaticRuntime(aten_matmul, args1);
  // 1-D, 2-D
  std::vector<IValue> args2 = {at::randn({3}), at::randn({3, 5})};
  testStaticRuntime(aten_matmul, args2);
  // 2-D, 1-D
  std::vector<IValue> args3 = {at::randn({3, 5}), at::randn({5})};
  testStaticRuntime(aten_matmul, args3);
  // > 2-D, > 2-D
  std::vector<IValue> args4 = {at::randn({3, 1, 4, 5}), at::randn({2, 5, 6})};
  testStaticRuntime(aten_matmul, args4);

  testStaticRuntime(aten_matmul, args3, args4);
}

TEST(StaticRuntime, Sign) {
  const auto sign_tensor = R"JIT(
    def forward(self, input: Tensor):
        return torch.sign(input).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  std::vector<IValue> args{a};
  testStaticRuntime(sign_tensor, args);
  testStaticRuntime(sign_tensor, args, {b});
}

TEST(StaticRuntime, Div) {
  const auto div_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.div(a, b).clone()
  )JIT";
  const auto div_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.div(a, b).clone()
  )JIT";
  const auto div_tensor_mode = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: str):
        return torch.div(a, b, rounding_mode=c).clone()
  )JIT";
  const auto div_scalar_mode = R"JIT(
    def forward(self, a: Tensor, b: float, c: str):
        return torch.div(a, b, rounding_mode=c).clone()
  )JIT";
  const auto div_strided = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        a_strided = torch.transpose(a, 0, 1)
        b_strided = torch.transpose(b, 0, 1)
        return torch.div(a_strided, b_strided).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto bs = at::randn({3, 2}).transpose(0, 1);
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});
  auto ds = at::randn({3, 4, 2}).transpose(0, 1);

  std::vector<IValue> args0{a, b};
  testStaticRuntime(div_tensor, args0);
  testStaticRuntime(div_tensor, args0, {c, d});

  testStaticRuntime(div_strided, args0);
  testStaticRuntime(div_strided, args0, {c, d});

  testStaticRuntime(div_tensor, {a, bs});
  testStaticRuntime(div_tensor, {a, bs}, {c, ds});

  std::vector<IValue> args1{a, 3};
  testStaticRuntime(div_scalar, args1);
  testStaticRuntime(div_scalar, args1, {c, 4});

  std::vector<IValue> args2{a, b, "floor"};
  testStaticRuntime(div_tensor_mode, args2);
  testStaticRuntime(div_tensor_mode, args2, {c, d, "floor"});

  std::vector<IValue> args3{a, 2.3, "trunc"};
  testStaticRuntime(div_scalar_mode, args3);
  testStaticRuntime(div_scalar_mode, args3, {c, 1.5, "trunc"});
}

TEST(StaticRuntime, Mul) {
  const auto mul_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.mul(a, b).clone()
  )JIT";
  const auto mul_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.mul(a, b).clone()
  )JIT";
  const auto mul_list = R"JIT(
    def forward(self, a: List[int], n: int):
        b = a * n
        return b[::]
  )JIT";

  auto a = at::randn({3, 3});
  auto b = at::randn({3, 3});
  auto c = at::randn({3, 3, 3});
  auto d = at::randn({3, 3, 3});

  std::vector<IValue> tensor_args1{a, b};
  std::vector<IValue> tensor_args2{c, d};
  testStaticRuntime(mul_tensor, tensor_args1);
  testStaticRuntime(mul_tensor, tensor_args1, tensor_args2);

  std::vector<IValue> scalar_args1{a, 42};
  std::vector<IValue> scalar_args2{c, 42};
  testStaticRuntime(mul_scalar, scalar_args1);
  testStaticRuntime(mul_scalar, scalar_args1, scalar_args2);

  std::vector<IValue> list_args{c10::List<int64_t>{1, 2}, 3};
  testStaticRuntime(mul_list, list_args);
}

TEST(StaticRuntime, Log) {
  const auto log_tensor = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.log(inp).clone()
        return (a)
  )JIT";

  // Ensure that the input values are valid.
  auto a = at::abs(at::randn({2, 3}));
  auto b = at::abs(at::randn({4, 3, 2}));

  std::vector<IValue> args{a};
  testStaticRuntime(log_tensor, args);
  testStaticRuntime(log_tensor, args, {b});
}

TEST(StaticRuntime, Sub) {
  const auto sub_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.sub(a, b).clone()
  )JIT";
  const auto sub_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.sub(a, b).clone()
  )JIT";
  const auto sub_tensor_alpha = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: float):
        return torch.sub(a, b, alpha=c).clone()
  )JIT";
  const auto sub_scalar_alpha = R"JIT(
    def forward(self, a: Tensor, b: float, c: int):
        return torch.sub(a, b, alpha=c).clone()
  )JIT";
  const auto sub_two_scalars = R"JIT(
    def forward(self, a: int, b: int):
        return (a - b - b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});

  std::vector<IValue> args0{a, b};
  testStaticRuntime(sub_tensor, args0);
  testStaticRuntime(sub_tensor, args0, {c, d});

  std::vector<IValue> args1{a, 3};
  testStaticRuntime(sub_scalar, args1);
  testStaticRuntime(sub_scalar, args1, {c, 4});

  std::vector<IValue> args2{a, b, 2.3};
  testStaticRuntime(sub_tensor_alpha, args2);
  testStaticRuntime(sub_tensor_alpha, {c, d, 3.1});

  std::vector<IValue> args3{a, 2.3, 4};
  testStaticRuntime(sub_scalar_alpha, args3);
  testStaticRuntime(sub_scalar_alpha, {c, 1.3, 2});

  std::vector<IValue> args4{1, 2};
  testStaticRuntime(sub_two_scalars, args4);
}

TEST(StaticRuntime, NanToNum) {
  const auto nan_to_num_script = R"JIT(
    def forward(self, a: Tensor, nan: float, posinf: float, neginf: float):
        return torch.nan_to_num(a, nan, posinf, neginf).clone()
  )JIT";

  const auto inf = std::numeric_limits<double>::infinity();
  const auto nan = std::numeric_limits<double>::quiet_NaN();

  auto a = torch::tensor({{1.0, nan}, {-inf, inf}});
  auto b = at::randn({3, 6});
  float* b_data = b.data_ptr<float>();
  b_data[0] = nan;
  b_data[4] = -inf;
  b_data[11] = inf;
  b_data[13] = nan;

  std::vector<IValue> args1{a, 1.0, 2.0, -2.0};
  std::vector<IValue> args2{b, 1.0, 2.0, -2.0};

  testStaticRuntime(
      nan_to_num_script,
      args1,
      /*args2=*/{},
      /*use_allclose=*/true,
      /*use_equalnan=*/true);
  testStaticRuntime(
      nan_to_num_script,
      args1,
      args2,
      /*use_allclose=*/true,
      /*use_equalnan=*/true);
}

TEST(StaticRuntime, Stack) {
  const auto stack_dim = R"JIT(
    def forward(self, a: Tensor, b: Tensor, dim: int):
        inputs = [a]
        inputs.append(b) # mutation to avoid using VarStack
        return torch.stack(inputs, dim = dim).clone()
  )JIT";
  const auto stack_three = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        inputs = [a, b]
        inputs.append(c) # mutation to avoid using VarStack
        return torch.stack(inputs).clone()
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({2, 2});
  auto c = at::randn({2, 2});
  auto d = at::randn({3, 3, 3});
  auto e = at::randn({3, 3, 3});
  auto f = at::randn({3, 3, 3});

  std::vector<IValue> args1_dim{a, b, 0};
  std::vector<IValue> args2_dim{d, e, 1};
  std::vector<IValue> args_dim_negative{d, e, -1};

  std::vector<IValue> args1_three_tensors{a, b, c};
  std::vector<IValue> args2_three_tensors{d, e, f};

  testStaticRuntime(stack_dim, args1_dim);
  testStaticRuntime(stack_dim, args1_dim, args2_dim);
  testStaticRuntime(stack_dim, args_dim_negative);

  testStaticRuntime(stack_three, args1_three_tensors);
  testStaticRuntime(stack_three, args1_three_tensors, args2_three_tensors);
}

TEST(StaticRuntime, ReLU) {
  const auto relu_script = R"JIT(
    def forward(self, a: Tensor):
        return torch.relu(a).clone()
  )JIT";

  auto a = at::randint(-10, 10, {2, 4});
  auto b = at::randint(-10, 10, {3, 6});

  std::vector<IValue> args1{a};
  std::vector<IValue> args2{b};

  testStaticRuntime(relu_script, args1);
  testStaticRuntime(relu_script, args1, args2);
}

TEST(StaticRuntime, Tanh) {
  const auto tanh_script = R"JIT(
    def forward(self, a):
        return torch.tanh(a).clone()
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({3, 3, 3});

  std::vector<IValue> args1{a};
  std::vector<IValue> args2{b};

  testStaticRuntime(tanh_script, args1, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(tanh_script, args1, args2, /*use_allclose=*/true);
}

TEST(StaticRuntime, Norm) {
  const auto norm_2arg = R"JIT(
    def forward(self, a: Tensor, p: int):
        return torch.norm(a, p).clone()
  )JIT";
  const auto norm_3arg = R"JIT(
    def forward(self, a: Tensor, p: int, dtype: int):
        return torch.norm(a, p, dtype=dtype).clone()
  )JIT";
  const auto norm_4arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool):
        return torch.norm(a, p, dim, keepdim).clone()
  )JIT";
  const auto norm_5arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool, dtype: int):
        return torch.norm(a, p, dim, keepdim, dtype=dtype).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 5});
  auto dim = std::vector<int64_t>({1});
  auto dtype = at::ScalarType::Float;

  std::vector<IValue> args2{a, 2};
  testStaticRuntime(norm_2arg, args2);
  testStaticRuntime(
      norm_2arg,
      args2,
      {b, 2},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);

  std::vector<IValue> args3{a, 2, dtype};
  testStaticRuntime(norm_3arg, args3);
  testStaticRuntime(
      norm_3arg,
      args3,
      {b, 2, dtype},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);

  std::vector<IValue> args4{a, 3, dim, false};
  testStaticRuntime(norm_4arg, args4);
  testStaticRuntime(norm_4arg, args4, {b, 3, dim, false});

  std::vector<IValue> args5{a, 4, dim, true, dtype};
  testStaticRuntime(norm_5arg, args5);
  testStaticRuntime(norm_5arg, args5, {b, 4, dim, true, dtype});
}

TEST(StaticRuntime, Reshape) {
  const auto reshape_script_1 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.reshape(shape)
        return b + b
  )JIT";
  const auto reshape_script_2 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        return b.reshape(shape)
  )JIT";
  const auto reshape_script_3 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape)
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return b.reshape(shape), g
  )JIT";
  // exercise reshape_copy and flatten_copy
  const auto reshape_script_4 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        k = inp + inp
        a = k + k
        b = a.reshape(shape)
        c = a.flatten().reshape(shape)
        return b + c
  )JIT";
  // exercise reshape_copy
  const auto reshape_script_5 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape).relu()
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return g
  )JIT";
  const auto reshape_inplace_script = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = b.sigmoid_()
        d = c + c
        e = a + a
        f = b + b
        return (d, e, f)
  )JIT";
  // b is non-contiguous
  const auto reshape_incontiguous_script = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        c = b.reshape(shape)
        c = c.relu()
        return (c)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = std::vector<int64_t>({3, 2});
  std::vector<IValue> args{a, b};

  auto c = at::randn({4, 5});
  auto d = std::vector<int64_t>({5, 1, 2, 2});
  std::vector<IValue> args1{c, d};

  testStaticRuntime(reshape_script_1, args);
  testStaticRuntime(reshape_script_2, args);
  testStaticRuntime(reshape_script_3, args);
  testStaticRuntime(reshape_script_4, args);
  testStaticRuntime(reshape_script_5, args);
  testStaticRuntime(reshape_inplace_script, args);
  testStaticRuntime(reshape_incontiguous_script, args);

  testStaticRuntime(reshape_script_1, args, args1);
  testStaticRuntime(reshape_script_2, args, args1);
  testStaticRuntime(reshape_script_3, args, args1);
  testStaticRuntime(reshape_script_4, args, args1);
  testStaticRuntime(reshape_script_5, args, args1);
  testStaticRuntime(reshape_inplace_script, args, args1);
  testStaticRuntime(reshape_incontiguous_script, args, args1);
}

TEST(StaticRuntime, Repeat) {
  const std::string repeat = R"JIT(
    def forward(self, a: Tensor, repeats: List[int]):
        return torch.repeat(a, repeats).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3});
  auto c = std::vector<int64_t>({1, 2});
  auto d = std::vector<int64_t>({2, 3});

  std::vector<IValue> args1{a, c};
  std::vector<IValue> args2{b, d};

  testStaticRuntime(repeat, args1);
  testStaticRuntime(repeat, args2);
  testStaticRuntime(repeat, args1, args2);
}

TEST(StaticRuntime, Flatten) {
  // exercise flatten_copy
  const auto flatten_script_1 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a * a
        c = torch.flatten(b, start_dim, end_dim)
        d = torch.relu(c)
        return d
  )JIT";
  const auto flatten_script_2 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a.transpose(0, 1)
        return torch.flatten(b, start_dim, end_dim).clone()
  )JIT";
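  // test_flatten below runs each script on `shape`, then on a copy whose
  // first dimension is scaled by 6, so the second run is large enough to make
  // the memory planner resize its buffers (hence check_resize).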
  auto test_flatten =
      [&](std::vector<int64_t> shape, int64_t start_dim, int64_t end_dim) {
        std::vector<int64_t> shape1(shape);
        if (shape1.size() > 0) {
          shape1[0] *= 6;
        }
        auto a = at::randn(shape);
        auto b = at::randn(shape1);
        std::vector<IValue> args{a, start_dim, end_dim};
        bool check_resize = shape1.size() > 0;
        testStaticRuntime(flatten_script_1, args);
        testStaticRuntime(
            flatten_script_1,
            args,
            {b, start_dim, end_dim},
            /*use_allclose=*/false,
            /*use_equalnan=*/false,
            check_resize);
        if (shape.size() > 2) {
          testStaticRuntime(flatten_script_2, args);
          testStaticRuntime(flatten_script_2, args, {b, start_dim, end_dim});
        }
      };

  test_flatten({2, 3}, 0, 1);
  test_flatten({2, 1, 3}, 1, 2);
  test_flatten({0, 1, 3, 0}, 1, 2);
  test_flatten({2, 3}, 1, 1);
  test_flatten({}, 0, 0);
}

TEST(StaticRuntime, pow) {
  const auto pow_script_ten_sca = R"JIT(
    def forward(self, input: Tensor, exponent: int):
        return torch.pow(input, exponent).clone()
  )JIT";
  const auto pow_script_ten_ten = R"JIT(
    def forward(self, input: Tensor, exponent: Tensor):
        return torch.pow(input, exponent).clone()
  )JIT";
  const auto pow_script_sca_ten = R"JIT(
    def forward(self, input: int, exponent: Tensor):
        return torch.pow(input, exponent).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});

  std::vector<IValue> args0{a, 4};
  testStaticRuntime(pow_script_ten_sca, args0);
  testStaticRuntime(pow_script_ten_sca, args0, {c, 4});

  std::vector<IValue> args1{at::abs(a), b};
  testStaticRuntime(pow_script_ten_ten, args1);
  testStaticRuntime(pow_script_ten_ten, args1, {at::abs(c), d});

  std::vector<IValue> args2{5, b};
  testStaticRuntime(pow_script_sca_ten, args2);
  testStaticRuntime(pow_script_sca_ten, args2, {3, d});
}

TEST(StaticRuntime, to) {
  const auto to_script_dtype = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy, memory_format).clone()
  )JIT";
  const auto to_script_dtype_strided = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        b = input.permute(0, 2, 3, 1)
        return torch.to(b, dtype, non_blocking, copy, memory_format).clone()
  )JIT";
  const auto to_script_prim_dtype = R"JIT(
    def forward(self, input: Tensor, dtype: Optional[int], non_blocking: bool, copy: bool):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy).clone()
  )JIT";
  const auto to_script_other = R"JIT(
    def forward(self, input: Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, other, non_blocking, copy, memory_format).clone()
  )JIT";
  // if the input is a float tensor, b could be an alias of a
  const auto to_script_alias = R"JIT(
    def forward(self, input: Tensor):
        a = input + input
        b = a.float()
        c = b * b
        return (c)
  )JIT";
  const auto to_script_fails_managed_output_check = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return e
  )JIT";
  const auto to_script_select_tensor_output_into_tuple = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return (d, e)
  )JIT";
  const auto to_script_memory_planning_fail = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float().relu()
        return e
  )JIT";
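  // In test_to below: b = target dtype, c = non_blocking, d = copy,
  // e = memory_format.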
  auto test_to = [&](at::ScalarType b, bool c, bool d, c10::MemoryFormat e) {
    auto a = at::randn({4, 3, 1, 2});
    auto other = at::randn({4, 3, 1, 2}).to(b);
    auto a2 = at::randn({3, 2, 2, 4});
    auto a2_other = at::randn({3, 2, 2, 4}).to(b);

    std::vector<IValue> args0{a, b, c, d, e};
    std::vector<IValue> args1{a, b, c, d};
    std::vector<IValue> args2{a, other, c, d, e};
    std::vector<IValue> args2WithDifferentOtherType{
        a, at::randn({4, 3, 1, 2}, ScalarType::Double), c, d, e};
    std::vector<IValue> args3{a, c10::nullopt, c, d};
    std::vector<IValue> args0WithInt{a, ScalarType::Int, c, d, e};

    testStaticRuntime(
        to_script_dtype,
        args0,
        args0WithInt,
        /* default for use_allclose */ false,
        /* default for use_equalnan */ false,
        /* check_resize */ false);
    testStaticRuntime(to_script_dtype_strided, args0);
    testStaticRuntime(to_script_prim_dtype, args1);
    if (!d) {
      testStaticRuntime(to_script_prim_dtype, args3);
    }
    // The second set of args tests the case where the `other` tensor's dtype
    // changes between iterations.
    testStaticRuntime(
        to_script_other,
        args2,
        args2WithDifferentOtherType,
        /* default for use_allclose */ false,
        /* default for use_equalnan */ false,
        /* check_resize */ false);
    testStaticRuntime(to_script_alias, {a});

    testStaticRuntime(to_script_memory_planning_fail, {a, a});
    testStaticRuntime(to_script_fails_managed_output_check, {a, a});
    testStaticRuntime(to_script_select_tensor_output_into_tuple, {a, a});

    // dynamic shapes
    testStaticRuntime(to_script_dtype, args0, {a2, b, c, d, e});
    testStaticRuntime(to_script_dtype_strided, args0, {a2, b, c, d, e});
    testStaticRuntime(to_script_prim_dtype, args1, {a2, b, c, d});
    if (!d) {
      testStaticRuntime(to_script_prim_dtype, args3, {a2, c10::nullopt, c, d});
    }
    testStaticRuntime(to_script_other, args2, {a2, a2_other, c, d, e});
    testStaticRuntime(to_script_alias, {a}, {a2});
  };
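  // Sweep non_blocking x copy for several dtype / memory-format combinations.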
  for (const bool non_blocking : {false, true}) {
    for (const bool copy : {false, true}) {
      // float->float, NCHW->NHWC
      test_to(
          at::ScalarType::Float,
          non_blocking,
          copy,
          c10::MemoryFormat::ChannelsLast);
      // float->half
      test_to(
          at::ScalarType::Half,
          non_blocking,
          copy,
          c10::MemoryFormat::Preserve);
      // float->float
      test_to(
          at::ScalarType::Float,
          non_blocking,
          copy,
          c10::MemoryFormat::Contiguous);
      test_to(
          at::ScalarType::Bool,
          non_blocking,
          copy,
          c10::MemoryFormat::Contiguous);
      // TODO: check if fbgemm is enabled properly in this case
      // half->float, NCHW->NHWC
      test_to(
          at::ScalarType::Half,
          non_blocking,
          copy,
          c10::MemoryFormat::ChannelsLast);
    }
  }
}

TEST(StaticRuntime, ExpandAs) {
  const auto expand_as_script = R"JIT(
    def forward(self, input: Tensor, other: Tensor):
        a = input.expand_as(other)
        return a.clone()
  )JIT";

  auto a = at::randn({3, 1});
  auto b = at::randn({3, 2});
  auto c = at::randn({4, 1});
  auto d = at::randn({4, 2});
  std::vector<IValue> args{a, b};
  std::vector<IValue> args2{c, d};
  testStaticRuntime(expand_as_script, args);
  testStaticRuntime(expand_as_script, args, args2);
}

TEST(StaticRuntime, Full) {
  const auto full_script = R"JIT(
    def forward(self,
                size: List[int],
                fill_value: int,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool]):
        a = torch.full(size,
                       fill_value,
                       dtype=dtype,
                       layout=layout,
                       device=device,
                       pin_memory=pin_memory)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  c10::List<int64_t> size0{2, 5};
  std::vector<IValue> args{
      size0, 4, at::ScalarType::Int, at::kStrided, cpu, false};
  std::vector<IValue> args1{
      size0, 4, at::ScalarType::Float, at::kStrided, cpu, false};
  c10::List<int64_t> size1{5, 6};
  std::vector<IValue> args2{
      size1, 5, at::ScalarType::Float, at::kStrided, cpu, false};
  testStaticRuntime(full_script, args);
  testStaticRuntime(
      full_script,
      args,
      args1,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(full_script, args, args2);
}

TEST(StaticRuntime, FullLike) {
  const auto full_like_script = R"JIT(
    def forward(self,
                a: Tensor,
                fill_value: int,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool],
                memory_format: Optional[int]):
        b = torch.full_like(a,
                            fill_value,
                            dtype=dtype,
                            layout=layout,
                            device=device,
                            pin_memory=pin_memory,
                            memory_format=memory_format)
        return (b.clone())
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 4, 2});
  auto cpu = at::Device(DeviceType::CPU);
  std::vector<IValue> args{
      a,
      4,
      at::ScalarType::Int,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  std::vector<IValue> args1{
      a,
      4,
      at::ScalarType::Float,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  std::vector<IValue> args2{
      b,
      4,
      at::ScalarType::Float,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  testStaticRuntime(full_like_script, args);
  testStaticRuntime(
      full_like_script,
      args,
      args1,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(full_like_script, args, args2);
}

TEST(StaticRuntime, Ones) {
  const auto script = R"JIT(
    def forward(self,
                size: List[int],
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool]):
        a = torch.ones(size,
                       dtype=dtype,
                       layout=layout,
                       device=device,
                       pin_memory=pin_memory)
        return (a.clone())
  )JIT";

  auto dtype = at::ScalarType::Int;
  auto cpu = at::Device(DeviceType::CPU);
  c10::List<int64_t> size0{2, 5};
  std::vector<IValue> args{size0, dtype, at::kStrided, cpu, false};
  c10::List<int64_t> size1{5, 6};
  std::vector<IValue> args2{size1, dtype, at::kStrided, cpu, false};
  testStaticRuntime(script, args);
  testStaticRuntime(script, args, args2);
}

TEST(StaticRuntime, OnesLike) {
  const auto script = R"JIT(
    def forward(self,
                input: Tensor,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool],
                memory_format: Optional[int]):
        a = torch.ones_like(input,
                            dtype=dtype,
                            layout=layout,
                            device=device,
                            pin_memory=pin_memory,
                            memory_format=memory_format)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  auto input0 = at::randn({2, 5});
  std::vector<IValue> args{
      input0,
      at::ScalarType::Int,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  std::vector<IValue> args1{
      input0,
      at::ScalarType::Float,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  auto input1 = at::randn({5, 6});
  std::vector<IValue> args2{
      input1,
      at::ScalarType::Float,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  testStaticRuntime(script, args);
  testStaticRuntime(
      script,
      args,
      args1,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(script, args, args2);
}

TEST(StaticRuntime, Zeros) {
  const auto script = R"JIT(
    def forward(self,
                size: List[int],
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool]):
        a = torch.zeros(size,
                        dtype=dtype,
                        layout=layout,
                        device=device,
                        pin_memory=pin_memory)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  c10::List<int64_t> size0{2, 5};
  std::vector<IValue> args{
      size0, at::ScalarType::Int, at::kStrided, cpu, false};
  std::vector<IValue> args1{
      size0, at::ScalarType::Float, at::kStrided, cpu, false};
  c10::List<int64_t> size1{5, 6};
  std::vector<IValue> args2{
      size1, at::ScalarType::Float, at::kStrided, cpu, false};
  testStaticRuntime(script, args);
  testStaticRuntime(
      script,
      args,
      args1,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(script, args, args2);
}

TEST(StaticRuntime, Linear) {
  const auto linear_script = R"JIT(
    def forward(self, inp: Tensor, weights: Tensor, bias: Optional[Tensor]) -> Tensor:
        return torch.linear(inp, weights, bias).clone()
  )JIT";

  auto input = at::randn({1, 2});
  auto weights = at::randn({1, 2});
  auto bias = at::randn({1, 1});

  std::vector<IValue> args{input, weights, bias};
  std::vector<IValue> args_no_bias{input, weights, c10::nullopt};

  auto input2 = at::randn({6, 3});
  auto weights2 = at::randn({6, 3});
  auto bias2 = at::randn({6, 6});

  std::vector<IValue> args2{input2, weights2, bias2};
  std::vector<IValue> args2_no_bias{input2, weights2, c10::nullopt};

  testStaticRuntime(linear_script, args);
  testStaticRuntime(linear_script, args_no_bias);

  testStaticRuntime(linear_script, args, args2);
  testStaticRuntime(linear_script, args, args2_no_bias);
}

TEST(StaticRuntime, VarCat) {
  const auto var_cat_script = R"JIT(
    def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
        return torch.cat([inp1, inp2], dim).clone()
  )JIT";

  // 2D tensors - cat dim = 0
  std::vector<IValue> args1 = {at::randn({4, 6}), at::randn({5, 6}), 0};
  testStaticRuntime(var_cat_script, args1);

  // 3D tensors - cat dim = 1
  std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 8, 6}), 1};
  testStaticRuntime(var_cat_script, args2);

  // 3D tensors - cat dim = 2
  std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2};
  testStaticRuntime(var_cat_script, args3);

  // Negative dim
  std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), -1};
  testStaticRuntime(var_cat_script, args4);

  testStaticRuntime(var_cat_script, args1, args2);
}

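// Unlike most tests in this file, the next one drives StaticModule directly
// rather than going through testStaticRuntime, comparing the result against
// the JIT graph executor and checking for memory leaks.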
TEST(StaticRuntime, LeakyReLU) {
  torch::jit::Module mod = getLeakyReLUConstScriptModel();
  auto inputs = torch::randn({2, 2});

  // run the JIT graph executor
  std::vector<at::IValue> input_ivalues({inputs});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<c10::IValue> input_tensors({inputs});
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors, {}).toTensor();
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}

static ProcessedNodeInputs createProcessedNodeInputs(
    c10::ArrayRef<uint16_t> inputs) {
  ProcessedNodeInputs result(inputs.size());
  for (const auto idx : c10::irange(inputs.size())) {
    result[idx] = inputs[idx];
  }
  return result;
}

static void checkProcessedNodeInputs(
    const ProcessedNodeInputs& io,
    c10::ArrayRef<uint16_t> inputs) {
  ASSERT_EQ(inputs.size(), io.size());
  for (const auto idx : c10::irange(inputs.size())) {
    EXPECT_EQ(inputs[idx], io[idx]);
  }
}

static void testProcessedNodeInputsRoundTrip(c10::ArrayRef<uint16_t> inputs) {
  auto io = createProcessedNodeInputs(inputs);
  checkProcessedNodeInputs(io, inputs);

  ProcessedNodeInputs copied(io);
  checkProcessedNodeInputs(copied, inputs);
  ProcessedNodeInputs moved(std::move(io));
  checkProcessedNodeInputs(moved, inputs);
}

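// Judging by the test cases below, ProcessedNodeInputs keeps up to five
// indices inline and falls back to a heap ("outline") allocation beyond
// that; copy and move assignment are exercised across that boundary.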
TEST(ProcessedNodeInputs, Basic) {
  std::vector<std::vector<uint16_t>> testCases = {
      {}, // empty
      {0xABCD, 0x5a5a}, // inline
      {0x11, 0x22, 0x33, 0x44, 0x55}, // max inline size
      {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}, // minimum outline size
      std::vector<uint16_t>(100, 0x5a), // large outline size
  };
  for (const auto& values : testCases) {
    testProcessedNodeInputsRoundTrip(values);
    for (const auto& values2 : testCases) {
      auto from = createProcessedNodeInputs(values);
      auto to = createProcessedNodeInputs(values2);
      to = from;
      checkProcessedNodeInputs(to, values);

      auto toMoveInto = createProcessedNodeInputs(values2);
      toMoveInto = std::move(from);
      checkProcessedNodeInputs(toMoveInto, values);
    }
  }
}

TEST(StaticRuntime, isinstance) {
  const auto isinstance_int_script = R"JIT(
    def forward(self, a: Any):
        return isinstance(a, int)
  )JIT";
  const auto isinstance_tensor_script = R"JIT(
    def forward(self, a: Any):
        return isinstance(a, torch.Tensor)
  )JIT";
  const auto isinstance_many_types_script = R"JIT(
    def forward(self, a: Any):
        return isinstance(a, (bool, int))
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({2, 2, 2});

  std::vector<at::IValue> args{a};
  std::vector<at::IValue> args2{b};

  testStaticRuntime(isinstance_int_script, args);
  testStaticRuntime(isinstance_int_script, args, args2);
  testStaticRuntime(isinstance_tensor_script, args);
  testStaticRuntime(isinstance_tensor_script, args, args2);
  testStaticRuntime(isinstance_many_types_script, args);
  testStaticRuntime(isinstance_many_types_script, args, args2);
}

TEST(StaticRuntime, TypeCheck) {
  const auto typecheck_ir = R"IR(
    graph(%a.1 : Tensor,
          %b.1 : Tensor):
      %t0 : Float(2, 2, strides=[2, 1], device=cpu), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu), Float(3, 3, strides=[3, 1])]](%a.1, %b.1)
      return (%t0, %t1, %type_matched)
  )IR";
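  // prim::TypeCheck forwards its inputs and emits a flag telling whether they
  // match the guarded types; the second argument set below deliberately
  // mismatches (c has the wrong shape).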
  auto a = at::zeros({2, 2}, at::kFloat);
  a = a.to(at::kCPU);
  auto b = at::ones({3, 3}, at::kFloat);
  auto c = at::ones({2, 2, 2}, at::kFloat);

  std::vector<IValue> args_correct = {a, b};
  std::vector<IValue> args_incorrect = {a, c};

  testStaticRuntime(typecheck_ir, args_correct);
  testStaticRuntime(typecheck_ir, args_correct, args_incorrect);
}

TEST(StaticRuntime, Index) {
  const auto index_without_none_script = R"JIT(
    def forward(self, a: Tensor, idx: Tensor):
        return a[idx].clone()
  )JIT";

  // Index with a boolean mask
  auto a = at::arange(4, at::kFloat).view({2, 2});
  auto idx_a = torch::tensor({{0, 1}, {0, 0}}, at::kBool);
  std::vector<IValue> args_a{a, idx_a};

  // Index with a tensor
  auto b = at::arange(27, at::kFloat).view({3, 3, 3});
  auto idx_b = torch::tensor({0, 1, 2}, at::kLong);
  std::vector<IValue> args_b{b, idx_b};

  testStaticRuntime(index_without_none_script, args_a);
  testStaticRuntime(index_without_none_script, args_a, args_b);

  const auto index_with_none_script = R"JIT(
    def forward(self, a: Tensor, idx: Tensor, none: Optional[Tensor]):
        return a[idx, none].clone()
  )JIT";

  // Index with None
  // When indexing with None, the shape of `f` becomes [2, 1, 2],
  // so the mask must be reshaped appropriately.
  auto f = at::arange(4, at::kFloat).view({2, 1, 2});
  auto idx_f_reshape = torch::tensor({{{0, 1}}, {{0, 0}}}, at::kBool);
  std::vector<IValue> args_f_with_none{f, idx_f_reshape};
  args_f_with_none.emplace_back();

  testStaticRuntime(index_with_none_script, args_f_with_none);
  testStaticRuntime(
      index_with_none_script,
      args_f_with_none,
      {IValue(b), IValue(idx_b), IValue()});

  const auto index_with_two_tensors_script = R"JIT(
    def forward(self, a: Tensor, idx_a: Tensor, idx_b: Tensor):
        return a[idx_a, idx_b].clone()
  )JIT";

  // Index with multiple tensors
  const auto& c = a; // 2x2 tensor
  auto idx_c1 = torch::tensor({0, 0}, at::kLong);
  auto idx_c2 = torch::tensor({0}, at::kLong);
  std::vector<IValue> args_c{c, idx_c1, idx_c2};

  const auto& d = b; // 3x3x3 tensor
  auto idx_d1 = torch::tensor({{0, 0, 2}, {0, 1, 1}}, at::kLong);
  auto idx_d2 = torch::tensor({{1, 1, 0}, {1, 0, 2}}, at::kLong);
  std::vector<IValue> args_d{d, idx_d1, idx_d2};

  testStaticRuntime(index_with_two_tensors_script, args_c, args_d);
}

TEST(StaticRuntime, IndexSelect) {
  const std::string script = R"IR(
    graph(%self: Tensor, %dim: int, %index: Tensor):
        %bias: None = prim::Constant()
        %ret = aten::index_select(%self, %dim, %index)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6});
  auto dim0 = 0;
  auto index0 = at::randint(0, 5, {6}, torch::kInt32);
  std::vector<IValue> args{self0, dim0, index0};
  testStaticRuntime(script, args);

  auto self1 = at::rand({128});
  auto dim1 = 0;
  auto index1 = at::randint(0, 127, {127}, torch::kInt32);
  std::vector<IValue> args2{self1, dim1, index1};
  testStaticRuntime(script, args, args2);
}

TEST(StaticRuntime, ClampMin) {
  const auto clamp_min_int_script = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.clamp_min(a, b).clone()
  )JIT";
  const auto clamp_min_float_script = R"JIT(
    def forward(self, a: Tensor, b: float):
        return torch.clamp_min(a, b).clone()
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({3, 3, 3});
  int scalar_int = 1;
  float scalar_float = 3.14;

  std::vector<IValue> args_a_int{a, scalar_int};
  std::vector<IValue> args_b_int{b, scalar_int};
  testStaticRuntime(clamp_min_int_script, args_a_int);
  testStaticRuntime(clamp_min_int_script, args_a_int, args_b_int);

  std::vector<IValue> args_a_float{a, scalar_float};
  std::vector<IValue> args_b_float{b, scalar_float};
  testStaticRuntime(clamp_min_float_script, args_a_float);
  testStaticRuntime(clamp_min_float_script, args_a_float, args_b_float);
}

TEST(StaticRuntime, Argmin) {
const auto argmin_script = R"JIT(
def forward(self, a: Tensor):
return torch.argmin(a).clone()
)JIT";
const auto argmin_with_dim_script = R"JIT(
def forward(self, a: Tensor, dim: int):
return torch.argmin(a, dim).clone()
)JIT";
const auto argmin_with_keep_dim_script = R"JIT(
def forward(self, a: Tensor, dim: int):
return torch.argmin(a, dim, True).clone()
)JIT";
auto a = at::randn({2, 2});
auto b = at::randn({17, 2, 1});
testStaticRuntime(argmin_script, {a});
testStaticRuntime(
argmin_script,
{a},
{b},
/* use_allclose */ false,
/* use_equalnan */ false,
/* check_resize */ false);
int dim_a = 0;
int dim_b = 1;
std::vector<IValue> args_a{a, dim_a};
std::vector<IValue> args_b{b, dim_b};
testStaticRuntime(argmin_with_dim_script, args_a);
testStaticRuntime(argmin_with_dim_script, args_a, args_b);
testStaticRuntime(argmin_with_keep_dim_script, args_a);
testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b);
}
TEST(StaticRuntime, Softmax) {
const auto softmax_script = R"JIT(
def forward(self, a: Tensor, dim: int):
return torch.softmax(a, dim).clone()
)JIT";
const auto softmax_script_with_dtype = R"JIT(
def forward(self, a: Tensor, dim: int, dtype: int):
return torch.softmax(a, dim, dtype=dtype).clone()
)JIT";
auto a = at::randn({2, 3});
auto b = at::randn({3, 3, 3});
testStaticRuntime(softmax_script, {a, 0});
testStaticRuntime(softmax_script, {a, 1});
testStaticRuntime(softmax_script, {b, 0});
testStaticRuntime(softmax_script, {b, 1});
testStaticRuntime(softmax_script, {b, 2});
testStaticRuntime(softmax_script_with_dtype, {a, 1, at::ScalarType::Float});
testStaticRuntime(softmax_script_with_dtype, {b, 1, at::ScalarType::Float});
}
TEST(StaticRuntime, GetItem_Dict) {
const auto getitem_dict_tensor_script = R"JIT(
def forward(self, key: Tensor):
d = {key: 1}
return d[key]
)JIT";
const auto getitem_dict_int_script = R"JIT(
def forward(self, key: int):
d = {key: 1}
return d[key]
)JIT";
const auto getitem_dict_str_script = R"JIT(
def forward(self, key: str):
d = {key: 1}
return d[key]
)JIT";
int int_key = 0;
std::string str_key = "str";
// No need to test these with a second set of args, since the inputs are
// not tensors
testStaticRuntime(getitem_dict_int_script, {int_key});
testStaticRuntime(getitem_dict_str_script, {str_key});
auto a = torch::tensor({1});
auto b = torch::tensor({1, 1});
testStaticRuntime(getitem_dict_tensor_script, {a});
testStaticRuntime(getitem_dict_tensor_script, {a}, {b});
}
TEST(StaticRuntime, GetItem_List) {
const auto getitem_list_int_script = R"JIT(
def forward(self, idx: int):
lst = [1, 2, 3]
return lst[idx]
)JIT";
const auto getitem_list_tensor_script = R"JIT(
def forward(self, tensor: Tensor, idx: int):
lst = [tensor, tensor]
return lst[idx]
)JIT";
testStaticRuntime(getitem_list_int_script, {1});
testStaticRuntime(getitem_list_int_script, {-1});
auto a = torch::tensor({1});
auto b = torch::tensor({1, 1});
testStaticRuntime(getitem_list_tensor_script, {a, 1});
testStaticRuntime(getitem_list_tensor_script, {a, 1}, {b, -1});
}
TEST(StaticRuntime, Transpose) {
const auto transpose_script = R"JIT(
def forward(self, a: Tensor, dim1: int, dim2: int):
return torch.transpose(a, dim1, dim2).clone()
)JIT";
auto a = at::randn({2, 2});
int dim1_a = 0;
int dim2_a = 1;
std::vector<IValue> args_a{a, dim1_a, dim2_a};
auto b = at::randn({3, 3, 3});
int dim1_b = 0;
int dim2_b = 2;
std::vector<IValue> args_b{b, dim1_b, dim2_b};
testStaticRuntime(transpose_script, args_a);
testStaticRuntime(transpose_script, args_a, args_b);
}
TEST(StaticRuntime, Permute) {
auto permute_script = R"JIT(
def forward(self, a: Tensor, dims: List[int]):
return torch.permute(a, dims).clone()
)JIT";
auto a = at::randn({2, 2});
c10::List<int64_t> dims_a{1, 0};
std::vector<IValue> args_a{a, dims_a};
auto b = at::randn({3, 3, 3});
c10::List<int64_t> dims_b{0, 2, 1};
std::vector<IValue> args_b{b, dims_b};
auto c = at::randn({3, 3, 3});
c10::List<int64_t> dims_c{0, -1, 1};
std::vector<IValue> args_c{c, dims_c};
testStaticRuntime(permute_script, args_a);
testStaticRuntime(permute_script, args_c);
testStaticRuntime(permute_script, args_a, args_b);
permute_script = R"JIT(
def forward(self, a: Tensor, dims: List[int], shape: List[int]):
return torch.permute(a, dims).reshape(shape).clone()
)JIT";
a = at::randn({8, 16, 4});
dims_a = {0, 2, 1};
c10::List<int64_t> shape{-1, 16};
testStaticRuntime(permute_script, {a, dims_a, shape});
}
TEST(StaticRuntime, Slice) {
const auto slice_script = R"JIT(
def forward(self, a: Tensor, dim: int, start: int, end: int, step: int):
return a.slice(dim, start, end, step).clone()
)JIT";
auto a = at::randn({2, 2});
int dim_a = 1;
int start_a = 0;
int end_a = 1;
int step_a = 1;
std::vector<IValue> args_a{a, dim_a, start_a, end_a, step_a};
auto b = at::randn({3, 3, 3});
int dim_b = 2;
int start_b = 0;
int end_b = 1;
int step_b = 2;
std::vector<IValue> args_b{b, dim_b, start_b, end_b, step_b};
testStaticRuntime(slice_script, args_a);
testStaticRuntime(slice_script, args_a, args_b);
const auto slice_script2 = R"JIT(
def forward(self, a: Tensor, dim: int, step: int):
return a.slice(dim, None, None, step).clone()
)JIT";
std::vector<IValue> args_c{b, dim_b, step_b};
testStaticRuntime(slice_script2, args_c);
}
TEST(StaticRuntime, Narrow) {
const auto narrow_with_int_script = R"JIT(
def forward(self, a: Tensor, dim: int, start: int, length: int):
return a.narrow(dim, start, length).clone()
)JIT";
auto a = at::randn({5, 5});
int dim_a = 0;
int start_a_int = 3;
int len_a = 2;
std::vector<IValue> args_a{a, dim_a, start_a_int, len_a};
auto b = at::randn({5, 5, 5});
int dim_b = 1;
int start_b_int = 2;
int len_b = 3;
std::vector<IValue> args_b{b, dim_b, start_b_int, len_b};
testStaticRuntime(narrow_with_int_script, args_a);
testStaticRuntime(narrow_with_int_script, args_a, args_b);
}
TEST(StaticRuntime, TupleUnpack) {
const auto two_tuple_unpack_script = R"JIT(
def forward(self, tup: Tuple[Tensor, Tensor]):
a, b = tup
return (a, b)
)JIT";
const auto three_tuple_unpack_script = R"JIT(
def forward(self, tup: Tuple[Tensor, Tensor, Tensor]):
a, b, c = tup
return (a, b, c)
)JIT";
auto two_tup = c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})});
auto two_tup_large =
c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})});
auto three_tup = c10::ivalue::Tuple::create(
{at::randn({1}), at::randn({1}), at::randn({1})});
auto three_tup_large = c10::ivalue::Tuple::create(
{at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})});
testStaticRuntime(two_tuple_unpack_script, {two_tup});
testStaticRuntime(two_tuple_unpack_script, {two_tup}, {two_tup_large});
testStaticRuntime(three_tuple_unpack_script, {three_tup});
testStaticRuntime(three_tuple_unpack_script, {three_tup}, {three_tup_large});
}
TEST(StaticRuntime, Append) {
const auto append_int_script = R"JIT(
def forward(self, a: int):
lst = [1, 2, 3]
lst.append(a)
return lst
)JIT";
const auto append_tensor_script = R"JIT(
def forward(self, a: Tensor):
lst = []
lst.append(a)
return lst
)JIT";
std::vector<IValue> args_int{1};
testStaticRuntime(append_int_script, args_int);
std::vector<IValue> args_tensor{at::randn({1})};
std::vector<IValue> args_tensor_large{at::randn({2, 2})};
testStaticRuntime(append_tensor_script, args_tensor);
testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large);
}
TEST(StaticRuntime, QuantizedLinear) {
const std::string quantize_script = R"IR(
graph(%input: Tensor, %weights: Tensor):
%scale: float = prim::Constant[value=1.]()
%zero_point: int = prim::Constant[value=1]()
%bias: None = prim::Constant()
%packed_params = quantized::linear_prepack(%weights, %bias)
%1254 = quantized::linear(%input, %packed_params, %scale, %zero_point)
%1249: Tensor = aten::dequantize(%1254)
return (%1249)
)IR";
at::Tensor weight =
at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);
at::Tensor input =
at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);
at::Tensor weight_2 =
at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8);
at::Tensor input_2 =
at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8);
testStaticRuntime(quantize_script, {input, weight}, {input_2, weight_2});
}
TEST(StaticRuntime, QuantizedLinearDynamicFp16) {
const std::string quantized_linear_dynamic_fp16_script = R"IR(
graph(%input: Tensor, %weights: Tensor):
%bias: None = prim::Constant()
%packed_params = quantized::linear_prepack_fp16(%weights, %bias)
%output = quantized::linear_dynamic_fp16(%input, %packed_params)
%ret = aten::clone(%output, %bias)
return (%ret)
)IR";
at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
at::Tensor input = torch::randn({3, 2}, torch::kFloat);
at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
testStaticRuntime(
quantized_linear_dynamic_fp16_script,
{input, weight},
{input_2, weight_2});
}
TEST(StaticRuntime, QuantizedLinearReluDynamicFp16) {
const std::string quantized_linear_relu_dynamic_fp16_script = R"IR(
graph(%input: Tensor, %weights: Tensor):
%bias: None = prim::Constant()
%packed_params = quantized::linear_prepack_fp16(%weights, %bias)
%output = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
%ret = aten::clone(%output, %bias)
return (%ret)
)IR";
at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
at::Tensor input = torch::randn({3, 2}, torch::kFloat);
at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
testStaticRuntime(
quantized_linear_relu_dynamic_fp16_script,
{input, weight},
{input_2, weight_2});
}
TEST(StaticRuntime, VarStack) {
const auto var_stack_script = R"JIT(
def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
return torch.stack([inp1, inp2], dim).clone()
)JIT";
// 2D tensors - stack dim = 0
std::vector<IValue> args1 = {at::randn({6, 6}), at::randn({6, 6}), 0};
testStaticRuntime(var_stack_script, args1);
// 3D tensors - stack dim = 1
std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 1};
testStaticRuntime(var_stack_script, args2);
// 3D tensors - stack dim = 2
std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 2};
testStaticRuntime(var_stack_script, args3);
// Negative dim
std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), -1};
testStaticRuntime(var_stack_script, args4);
// Non-serial path
std::vector<IValue> args5 = {at::randn({1, 2, 3}), at::randn({1, 2, 3}), 3};
testStaticRuntime(var_stack_script, args5);
// Fast path
std::vector<IValue> args6 = {at::randn({1}), at::randn({1}), 0};
testStaticRuntime(var_stack_script, args6);
testStaticRuntime(var_stack_script, args1, args2);
}
TEST(StaticRuntime, FmodTensor) {
const auto fmod_tensor = R"JIT(
def forward(self, a: Tensor, b: Tensor):
return torch.fmod(a, b).clone()
)JIT";
// fmod tensor version
auto a = at::randn({2, 3});
auto b = at::randn({2, 3});
std::vector<IValue> args0{a, b};
testStaticRuntime(fmod_tensor, args0);
// check for dynamic shapes
auto c = at::randn({4, 3, 2});
auto d = at::randn({4, 3, 2});
std::vector<IValue> args1{c, d};
testStaticRuntime(fmod_tensor, args0, args1);
}
TEST(StaticRuntime, FmodScalar) {
const auto fmod_scalar = R"JIT(
def forward(self, a: Tensor, b: int):
return torch.fmod(a, b).clone()
)JIT";
auto a = at::randn({2, 3});
// fmod scalar version
std::vector<IValue> args2{a, 3};
testStaticRuntime(fmod_scalar, args2);
// check for dynamic shapes
auto c = at::randn({4, 3, 2});
std::vector<IValue> args3{c, 4};
testStaticRuntime(fmod_scalar, args2, args3);
// test int32 version
a = at::randint(-100, 100, {2, 3}, at::kInt);
c = at::randint(-100, 100, {4, 3, 2}, at::kInt);
testStaticRuntime(fmod_scalar, {a, 3});
testStaticRuntime(fmod_scalar, {a, 3}, {c, 4});
}
TEST(StaticRuntime, QEmbeddingBagBytePrepack) {
const std::string embedding_bag_byte_prepack_script = R"IR(
graph(%input: Tensor):
%none : None = prim::Constant()
%output: Tensor = quantized::embedding_bag_byte_prepack(%input)
%res: Tensor = aten::clone(%output, %none)
return (%res)
)IR";
auto a = torch::randn({8, 16}, at::ScalarType::Float);
auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);
testStaticRuntime(embedding_bag_byte_prepack_script, {a});
testStaticRuntime(embedding_bag_byte_prepack_script, {a}, {b});
}
TEST(StaticRuntime, QEmbeddingBagByteUnpack) {
const auto src = R"IR(
graph(%input: Tensor):
%none : None = prim::Constant()
%weight: Tensor = quantized::embedding_bag_byte_prepack(%input)
%output: Tensor = quantized::embedding_bag_byte_unpack(%weight)
%res: Tensor = aten::clone(%output, %none)
return (%res)
)IR";
auto a = torch::randn({8, 16}, at::ScalarType::Float);
auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);
testStaticRuntime(src, {a});
testStaticRuntime(src, {a}, {b});
}
TEST(StaticRuntime, LinalgNorm_ScalarOrd) {
const auto linalg_norm_ord_scalar = R"JIT(
def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int):
return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
)JIT";
auto a = at::randn({2, 3});
auto dim = std::vector<int64_t>({1});
auto dtype = at::ScalarType::Float;
std::vector<IValue> args0{a, 4, dim, true, dtype};
testStaticRuntime(linalg_norm_ord_scalar, args0);
auto b = at::randn({3, 2, 6});
std::vector<IValue> args1{b, 4, dim, true, dtype};
testStaticRuntime(linalg_norm_ord_scalar, args0, args1);
}
TEST(StaticRuntime, LinalgNorm_StringOrd) {
const auto linalg_norm_ord_str = R"JIT(
def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
)JIT";
auto a = at::randn({2, 3});
auto dim = std::vector<int64_t>({0, 1});
auto dtype = at::ScalarType::Float;
std::vector<IValue> args0{a, "fro", dim, true, dtype};
testStaticRuntime(linalg_norm_ord_str, args0);
auto b = at::randn({3, 2, 17});
std::vector<IValue> args1{b, "fro", dim, true, dtype};
testStaticRuntime(linalg_norm_ord_str, args0, args1);
}
TEST(StaticRuntime, Index_Put) {
const auto index_put_str = R"JIT(
def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
return torch.index_put(a, indices, values, accumulate).clone()
)JIT";
auto a = at::randn({2});
auto indices_a = std::make_tuple(torch::tensor({0}, at::kLong));
auto values_a = at::randn({1});
std::vector<IValue> args0{a, indices_a, values_a, false};
testStaticRuntime(index_put_str, args0);
const auto index_put_non_optional_str = R"JIT(
def forward(self, a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool):
return torch.index_put(a, indices, values, accumulate).clone()
)JIT";
auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};
std::vector<IValue> args1{a, indices_b, values_a, false};
testStaticRuntime(index_put_non_optional_str, args1);
const auto index_put_list_construct = R"JIT(
def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool):
indices: List[Optional[Tensor]] = [indices]
return torch.index_put(a, indices, values, accumulate).clone()
)JIT";
std::vector<IValue> args2{a, torch::tensor({0}, at::kLong), values_a, false};
testStaticRuntime(index_put_list_construct, args2);
}
TEST(StaticRuntime, Item) {
const auto item_str = R"JIT(
def forward(self, a: Tensor):
return torch.item(a)
)JIT";
auto a = at::randn({1});
std::vector<IValue> args0{a};
testStaticRuntime(item_str, args0);
}
TEST(StaticRuntime, Tensor_Split) {
const auto tensor_split_str1 = R"JIT(
def forward(self, a: Tensor, sections: int, dim: int):
return torch.tensor_split(a, sections, dim)
)JIT";
std::vector<IValue> args1{at::randn({8}), 3, 0};
const auto tensor_split_str2 = R"JIT(
def forward(self, a: Tensor, sections: Tensor, dim: int):
return torch.tensor_split(a, sections, dim)
)JIT";
std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};
const auto tensor_split_str3 = R"JIT(
def forward(self, a: Tensor, indices: List[int], dim: int):
return torch.tensor_split(a, indices, dim)
)JIT";
std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};
testStaticRuntime(tensor_split_str1, args1);
testStaticRuntime(tensor_split_str2, args2);
testStaticRuntime(tensor_split_str3, args3);
}
TEST(StaticRuntime, JIT_Aten_Cpu) {
const std::string script = R"IR(
graph(%a: Tensor):
%1 : int = prim::Constant[value=0]()
%aa: Tensor = aten::add(%a, %a, %1)
%ret: Tensor = aten::cpu(%aa)
return (%ret)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(script, graph.get(), vmap);
torch::jit::StaticModule smodule(graph);
auto a = at::randn({2, 4});
std::vector<IValue> args0{a};
testStaticRuntime(script, args0);
}
TEST(StaticRuntime, JIT_Aten_Numel) {
const std::string script = R"IR(
graph(%a: Tensor):
%1 : int = prim::Constant[value=0]()
%aa: Tensor = aten::add(%a, %a, %1)
%ret: int = aten::numel(%aa)
return (%ret)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(script, graph.get(), vmap);
torch::jit::StaticModule smodule(graph);
auto a = at::randn({2, 4});
std::vector<IValue> args0{a};
testStaticRuntime(script, args0);
}
TEST(StaticRuntime, JIT_Aten_List) {
const auto script_str = R"IR(
graph(%a: str):
%ret: str[] = aten::list(%a)
return (%ret)
)IR";
std::string a = "abcd";
std::vector<IValue> args0{a};
testStaticRuntime(script_str, args0);
// Mutate the result of aten::list and return both the copy and the
// original to verify that a deep copy took place
const auto script_list = R"IR(
graph(%a : int[]):
%idx : int = prim::Constant[value=0]()
%value : int = prim::Constant[value=42]()
%res : int[] = aten::list(%a)
%updated : int[] = aten::_set_item(%res, %idx, %value)
return (%res, %a)
)IR";
std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
testStaticRuntime(script_list, args1);
}
TEST(StaticRuntime, JIT_Aten_Range_Length) {
const std::string script = R"IR(
graph(%lo: int, %hi: int, %step: int):
%ret: int = aten::__range_length(%lo, %hi, %step)
return (%ret)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(script, graph.get(), vmap);
torch::jit::StaticModule smodule(graph);
std::vector<IValue> args0{0, 10, 2};
testStaticRuntime(script, args0);
}
TEST(StaticRuntime, Cat) {
const std::string cat_script = R"IR(
graph(%a: Tensor, %b: Tensor, %dim: int):
%ten_list: Tensor[] = prim::ListConstruct(%a, %b)
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=1]()
%3 : int = prim::Constant[value=1]()
%ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3)
%ret: Tensor = aten::cat(%ten_list2, %dim)
return (%ret)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(cat_script, graph.get(), vmap);
torch::jit::StaticModule smodule(graph);
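// The tensor list is sliced before reaching cat, which is presumably what
// prevents the plain aten::cat from being rewritten to a variadic version.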
ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat"));
auto a = at::randn({2, 4});
auto b = at::randn({3, 4});
std::vector<IValue> args0{a, b, 0};
testStaticRuntime(cat_script, args0);
auto c = at::randn({3, 4});
auto d = at::randn({3, 5});
std::vector<IValue> args1{c, d, 1};
testStaticRuntime(cat_script, args0, args1);
std::vector<IValue> args_dim_negative{c, d, -1};
testStaticRuntime(cat_script, args_dim_negative);
}
TEST(StaticRuntime, Cumsum) {
const auto cumsum_script = R"JIT(
def forward(self, a: Tensor, dim: int):
return torch.cumsum(a, dim).clone()
)JIT";
auto a = at::randn({2, 3});
std::vector<IValue> args0{a, 0};
testStaticRuntime(cumsum_script, args0);
auto b = at::randn({3, 6});
std::vector<IValue> args1{b, 1};
testStaticRuntime(cumsum_script, args0, args1);
}
TEST(StaticRuntime, CumsumDtype) {
const auto cumsum_script_dtype = R"JIT(
def forward(self, a: Tensor, dim: int, dtype: int):
return torch.cumsum(a, dim, dtype=dtype).clone()
)JIT";
auto a = at::randn({1, 2});
auto dtype = at::ScalarType::Float;
std::vector<IValue> args0{a, 0, dtype};
testStaticRuntime(cumsum_script_dtype, args0);
auto b = at::randn({3, 6});
std::vector<IValue> args1{b, 1, dtype};
testStaticRuntime(cumsum_script_dtype, args0, args1);
}
TEST(StaticRuntime, Nonzero) {
const auto nonzero_tensor = R"JIT(
def forward(self, input: Tensor):
a = torch.nonzero(input).clone()
return (a)
)JIT";
auto a = at::randint(0, 2, {2, 3});
testStaticRuntime(nonzero_tensor, {a});
auto b = at::randint(0, 2, {4, 3, 2});
testStaticRuntime(nonzero_tensor, {a}, {b});
}
TEST(StaticRuntime, SignedLog1p) {
const std::string signed_log1p_script = R"IR(
graph(%input):
%0 : Tensor = aten::sign(%input)
%1 : Tensor = aten::abs(%input)
%2 : Tensor = aten::log1p(%1)
%3 : Tensor = aten::mul(%0, %2)
%none : NoneType = prim::Constant()
%res : Tensor = aten::clone(%3, %none)
return (%res)
)IR";
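// The trailing `true` enables allclose comparison, since the fused
// signed_log1p kernel need not be bit-identical to the unfused ops.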
std::vector<IValue> args1 = {at::randn({2, 2})};
testStaticRuntime(signed_log1p_script, args1, {}, true);
std::vector<IValue> args2 = {at::randn({3, 3, 3})};
testStaticRuntime(signed_log1p_script, args1, args2, true);
}
TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithImmutableInputDict) {
const auto getitem_immutable_input_dict_script = R"JIT(
def forward(self, input: Dict[int, Tensor]):
a = input[0]
b = input[1]
c = a + b
return c.clone()
)JIT";
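// The input dict is never mutated, so the pass can replace the
// __getitem__ calls with a single static_runtime::dict_unpack node.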
script::Module module("module");
module.define(getitem_immutable_input_dict_script);
torch::jit::StaticModule smodule(module);
EXPECT_FALSE(hasNodeWithKind(smodule, "aten::__getitem__"));
EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));
auto a = at::randn({2, 4});
auto b = at::randn({2, 4});
c10::Dict<c10::IValue, c10::IValue> dict(
c10::IntType::get(), c10::TensorType::get());
dict.insert(0, a);
dict.insert(1, b);
testStaticRuntime(getitem_immutable_input_dict_script, {dict});
c10::Dict<c10::IValue, c10::IValue> dict0(
c10::IntType::get(), c10::TensorType::get());
auto a0 = at::randn({3, 4});
auto b0 = at::randn({3, 4});
dict0.insert(0, a0);
dict0.insert(1, b0);
testStaticRuntime(getitem_immutable_input_dict_script, {dict0});
}
TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithMutableInputDict) {
const auto getitem_mutable_input_dict_script = R"JIT(
def forward(self, input: Dict[int, Tensor]):
a = input[0]
input[1] = a
b = input[1]
c = a + b
return c.clone()
)JIT";
script::Module module("module");
module.define(getitem_mutable_input_dict_script);
torch::jit::StaticModule smodule(module);
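// The graph writes to the input dict, so the lookups cannot be replaced
// with dict_unpack.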
EXPECT_TRUE(hasNodeWithKind(smodule, "aten::__getitem__"));
EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));
}
TEST(StaticRuntime, VarTupleUnpack) {
const auto var_tuple_unpack_script = R"JIT(
def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
a, b = input_0
c, d = input_1
res = a * c + b * d
return res.clone()
)JIT";
script::Module module("module");
module.define(var_tuple_unpack_script);
torch::jit::StaticModule smodule(module);
EXPECT_FALSE(hasNodeWithKind(smodule, "prim::TupleUnpack"));
EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack"));
auto a = at::randn({2, 2});
auto b = at::randn({3, 3, 3});
std::vector<IValue> args1{
c10::ivalue::Tuple::create(a, a), c10::ivalue::Tuple::create(1, 2)};
std::vector<IValue> args2{
c10::ivalue::Tuple::create(b, b), c10::ivalue::Tuple::create(1, 2)};
testStaticRuntime(var_tuple_unpack_script, args1);
testStaticRuntime(var_tuple_unpack_script, args1, args2);
}
TEST(StaticRuntime, VarTupleUnpack_NotApplied) {
const auto var_tuple_unpack_not_applied_script = R"JIT(
def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
a, b = input_0
x = a + b
c, d = input_1
res = a * c + b * d + x
return res.clone()
)JIT";
script::Module module("module");
// In this script, the optimization is not applied since there is a
// computation between the TupleUnpack nodes.
module.define(var_tuple_unpack_not_applied_script);
torch::jit::StaticModule smodule(module);
EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack"));
EXPECT_TRUE(hasNodeWithKind(smodule, "prim::TupleUnpack"));
}
TEST(StaticRuntime, RemainderTensor) {
const auto remainder_tensor = R"JIT(
def forward(self, x, y):
return torch.remainder(x, y).clone()
)JIT";
std::vector<IValue> args1 = {
at::randint(0, 10, {2, 2}), at::randint(1, 10, {2, 2})};
std::vector<IValue> args2 = {
at::randint(0, 10, {3, 6}), at::randint(1, 10, {3, 6})};
// Use allclose and equalnan since outputs may be NaN.
testStaticRuntime(
remainder_tensor,
args1,
/*args2*/ {},
/*use_allclose*/ true,
/*use_equalnan*/ true);
testStaticRuntime(
remainder_tensor,
args1,
args2,
/*use_allclose*/ true,
/*use_equalnan*/ true);
}
TEST(StaticRuntime, RemainderScalar) {
const auto remainder_scalar = R"JIT(
def forward(self, x, y: int):
return torch.remainder(x, y).clone()
)JIT";
std::vector<IValue> args1 = {at::randint(0, 10, {2, 2}), 4};
std::vector<IValue> args2 = {at::randint(0, 10, {3, 6}), 4};
// Use allclose and equalnan since outputs may be NaN.
testStaticRuntime(
remainder_scalar,
args1,
/*args2*/ {},
/*use_allclose*/ true,
/*use_equalnan*/ true);
testStaticRuntime(
remainder_scalar,
args1,
args2,
/*use_allclose*/ true,
/*use_equalnan*/ true);
}
TEST(StaticRuntime, Where) {
const auto where_script = R"JIT(
def forward(self, x, y):
return torch.where(x > 0, x, y).clone()
)JIT";
std::vector<IValue> args1 = {at::randn({2, 2}), at::randn({2, 2})};
std::vector<IValue> args2 = {at::randn({8, 10}), at::randn({8, 10})};
testStaticRuntime(where_script, args1);
testStaticRuntime(where_script, args1, args2);
}
TEST(StaticRuntime, WhereBroadcast) {
const auto where_script = R"JIT(
def forward(self, cond_1d, x, y):
shape = [-1] + [1] * (x.dim() - 1)
cond = cond_1d.view(shape)
return torch.where(cond, x, y).clone()
)JIT";
std::vector<IValue> args1 = {
at::tensor({0, 1}).to(at::kBool), at::randn({2, 2}), at::randn({2, 2})};
std::vector<IValue> args2 = {
at::tensor({1, 0, 0}).to(at::kBool),
at::randn({3, 6}),
at::randn({3, 6})};
testStaticRuntime(where_script, args1);
testStaticRuntime(where_script, args1, args2);
}
TEST(StaticRuntime, View) {
// Note that clone is not technically necessary here since this is not
// an out variant, but it suppresses warnings about only having one op
// in testStaticRuntime
const auto src = R"IR(
graph(%input : Tensor, %shape : int[]):
%none : NoneType = prim::Constant()
%view : Tensor = aten::view(%input, %shape)
%res : Tensor = aten::clone(%view, %none)
return (%res)
)IR";
std::vector<IValue> args1{at::randn({2, 2}), c10::List<int64_t>{4}};
std::vector<IValue> args2{at::randn({2, 2, 2}), c10::List<int64_t>({4, 2})};
testStaticRuntime(src, args1);
testStaticRuntime(src, args1, args2);
}
TEST(StaticRuntime, Size) {
const auto src_with_dim = R"JIT(
def forward(self, x, dim: int):
return x.size(dim)
)JIT";
const auto src_no_dim = R"JIT(
def forward(self, x):
return x.size()
)JIT";
std::vector<IValue> args1{at::randn({1}), 0};
std::vector<IValue> args2{at::randn({1}), -1};
std::vector<IValue> args3{at::randn({2, 4}), 1};
std::vector<IValue> args_no_dim{at::randn({2, 4})};
testStaticRuntime(src_with_dim, args1);
testStaticRuntime(src_with_dim, args2);
testStaticRuntime(src_with_dim, args1, args3);
testStaticRuntime(src_no_dim, args_no_dim);
}
TEST(StaticRuntime, Squeeze) {
// Note: this is a native op, not an out variant, but we clone anyway
// to silence warnings in testStaticRuntime
const auto src = R"JIT(
def forward(self, inp, dim: int):
return inp.squeeze(dim).clone()
)JIT";
const auto a = at::randn({2, 2});
const auto b = at::randn({3, 2, 3});
testStaticRuntime(src, {a, 0});
testStaticRuntime(src, {a, 1});
testStaticRuntime(src, {a, -1}, {b, 2});
}
TEST(StaticRuntime, NumToTensorScalar) {
const auto num_to_tensor_ir = R"IR(
graph(%1 : int):
%2 : NoneType = prim::Constant()
%3 : Tensor = prim::NumToTensor(%1)
%4 : Tensor = aten::clone(%3, %2)
return (%4)
)IR";
IValue arg{5};
std::vector<IValue> args = {arg};
testStaticRuntime(num_to_tensor_ir, args);
}
TEST(StaticRuntime, NumToTensorFalse) {
const auto num_to_tensor_ir = R"IR(
graph(%1 : bool):
%2 : NoneType = prim::Constant()
%3 : Tensor = prim::NumToTensor(%1)
%4 : Tensor = aten::clone(%3, %2)
return (%4)
)IR";
IValue arg{false};
std::vector<IValue> args = {arg};
testStaticRuntime(num_to_tensor_ir, args);
}
TEST(StaticRuntime, NumToTensorTrue) {
const auto num_to_tensor_ir = R"IR(
graph(%1 : bool):
%2 : NoneType = prim::Constant()
%3 : Tensor = prim::NumToTensor(%1)
%4 : Tensor = aten::clone(%3, %2)
return (%4)
)IR";
IValue arg{true};
std::vector<IValue> args = {arg};
testStaticRuntime(num_to_tensor_ir, args);
}
TEST(StaticRuntime, Split) {
const auto src = R"JIT(
def forward(self, inp, split_size: int, dim: int):
return inp.split(split_size, dim)
)JIT";
const auto a = at::randn({2, 2});
const auto b = at::randn({2, 2, 2});
testStaticRuntime(src, {a, 1, 0});
testStaticRuntime(src, {a, 1, 1});
testStaticRuntime(src, {a, 2, -1}, {b, 2, 2});
}
TEST(StaticRuntime, SplitWithSizes) {
const auto src = R"JIT(
def forward(self, inp, split_sizes: List[int], dim: int):
return inp.split(split_sizes, dim)
)JIT";
const auto a = at::randn({2, 2});
const auto b = at::randn({2, 2, 2});
const auto split_sizes = c10::List<int64_t>{1, 1};
testStaticRuntime(src, {a, split_sizes, 0});
testStaticRuntime(src, {a, split_sizes, 1});
testStaticRuntime(src, {a, split_sizes, -1}, {b, split_sizes, 2});
}
namespace {
void maybe_throw(bool should_throw) {
if (should_throw) {
throw std::runtime_error("test exception");
}
}
TORCH_LIBRARY(static_runtime_tests, m) {
// Use a conservative alias analysis kind so this op doesn't get
// deleted by dead code elimination
m.def(torch::schema(
"static_runtime_tests::maybe_throw(bool throw) -> ()",
at::AliasAnalysisKind::CONSERVATIVE));
m.impl("maybe_throw", maybe_throw);
}
} // namespace
TEST(StaticRuntime, ModelCrashOnFirstRun) {
const auto src = R"JIT(
graph(%0: Tensor, %throw: bool):
%1: Tensor = aten::mul(%0, %0)
static_runtime_tests::maybe_throw(%throw)
%2: Tensor = aten::mul(%1, %1)
%3: Tensor = aten::mul(%2, %2)
return (%3)
)JIT";
auto graph = getGraphFromIR(src);
auto static_module = StaticModule(graph);
auto& runtime = static_module.runtime();
std::vector<IValue> args_crash{at::randn({1}), true};
std::vector<IValue> args_no_crash{at::randn({1}), false};
EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);
// The run didn't finish, so the memory planner was never allocated
EXPECT_EQ(runtime.get_memory_planner(), nullptr);
runtime.check_for_memory_leak();
// We guarantee that the runtime is still usable after the crash.
// Run again to verify this.
compareResultsWithJIT(runtime, graph, args_no_crash);
EXPECT_NE(runtime.get_memory_planner(), nullptr);
}
TEST(StaticRuntime, ModelCrashOnSecondRun) {
const auto src = R"JIT(
graph(%0: Tensor, %throw: bool):
%1: Tensor = aten::mul(%0, %0)
static_runtime_tests::maybe_throw(%throw)
%2: Tensor = aten::mul(%1, %1)
%3: Tensor = aten::mul(%2, %2)
return (%3)
)JIT";
auto graph = getGraphFromIR(src);
auto static_module = StaticModule(graph);
auto& runtime = static_module.runtime();
std::vector<IValue> args_crash{at::randn({1}), true};
std::vector<IValue> args_no_crash{at::randn({1}), false};
runtime(args_no_crash, {});
EXPECT_NE(runtime.get_memory_planner(), nullptr);
runtime.check_for_memory_leak();
EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);
runtime.check_for_memory_leak();
// We guarantee that the runtime is still usable after the crash.
// Run again to verify this.
compareResultsWithJIT(runtime, graph, args_no_crash);
}
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) {
const auto src = R"JIT(
graph(%0: Tensor):
%1: Tensor = aten::mul(%0, %0)
%2: Tensor = aten::mul(%1, %1)
%3: bool = prim::Constant[value=1]()
%4: Tensor = static_runtime::select_tensor(%1, %2, %3)
static_runtime_tests::maybe_throw(%3)
return (%4)
)JIT";
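// static_runtime::select_tensor produces a borrowed IValue; throwing
// right after it exercises cleanup of borrows on the error path.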
auto graph = getGraphFromIR(src);
auto static_module = StaticModule(graph);
auto& runtime = static_module.runtime();
std::vector<IValue> args{at::randn({1})};
EXPECT_THROW(runtime(args), std::runtime_error);
}
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) {
const auto src = R"JIT(
graph(%0: Tensor, %1: Tensor):
%2: bool = prim::Constant[value=1]()
%3: Tensor = static_runtime::select_tensor(%0, %1, %2)
static_runtime_tests::maybe_throw(%2)
return (%3)
)JIT";
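// Same idea as above, except the borrow refers to graph inputs rather
// than an intermediate tensor.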
auto graph = getGraphFromIR(src);
auto static_module = StaticModule(graph);
auto& runtime = static_module.runtime();
std::vector<IValue> args{at::randn({1}), at::randn({1})};
EXPECT_THROW(runtime(std::move(args)), std::runtime_error);
}
TEST(StaticRuntime, ReplaceWithMaybeCopy) {
const std::string to = R"IR(
graph(%0 : Tensor):
%1: int = prim::Constant[value=4]()
%2: bool = prim::Constant[value=0]()
%3: None = prim::Constant()
%res : Tensor = aten::to(%0, %1, %2, %2, %3)
return (%res)
)IR";
at::Tensor a = at::tensor({1.1, 2.2, 3.3, 4.0}, at::ScalarType::Float);
std::vector<IValue> args{a};
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(to, g.get());
// Jit Interpreter.
Stack stack(args);
torch::jit::GraphExecutor graph_exec(g, "");
graph_exec.run(stack);
ASSERT_EQ(stack.size(), 1);
auto expected = stack[0].toTensor();
// Static Runtime.
torch::jit::StaticModule smodule(g);
auto actual = smodule(args, {}).toTensor();
smodule.runtime().check_for_memory_leak();
EXPECT_TRUE(expected.equal(actual));
EXPECT_FALSE(hasProcessedNodeWithName(smodule, "aten::to"));
EXPECT_TRUE(
hasProcessedNodeWithName(smodule, "static_runtime::to_maybe_copy_out"));
}
TEST(StaticRuntime, Int) {
const auto src = R"JIT(
def forward(self, x):
return int(x) + int(x)
)JIT";
std::vector<IValue> args{at::tensor({3.14})};
testStaticRuntime(src, args);
}
TEST(StaticRuntime, ReturnConstant) {
const auto src = R"JIT(
def forward(self):
return 1
)JIT";
testStaticRuntime(src, {});
}
TEST(StaticRuntime, SimpleIf) {
const auto src = R"JIT(
def forward(self, cond: bool, x):
if cond:
return torch.mul(x, 42).clone()
else:
return x.clone()
)JIT";
std::vector<IValue> args_false{false, at::randn({1})};
std::vector<IValue> args_true{true, at::randn({1})};
std::vector<IValue> args_big_tensor{true, at::randn({3, 3, 3})};
testStaticRuntime(src, args_false);
testStaticRuntime(src, args_true);
testStaticRuntime(src, args_true, args_big_tensor);
}
TEST(StaticRuntime, NestedIf) {
const auto src = R"JIT(
def forward(self, cond1: bool, cond2: bool, x):
y = x * 42
if cond1:
y = y * y
if cond2:
y += x
else:
if cond2:
return x.clone()
return y.clone()
)JIT";
for (auto cond1 : {true, false}) {
for (auto cond2 : {true, false}) {
std::vector<IValue> args1{cond1, cond2, at::randn({1})};
std::vector<IValue> args2{cond1, cond2, at::randn({3, 3, 3})};
testStaticRuntime(src, args1, args2);
}
}
}
TEST(StaticRuntime, DeeplyNestedIf) {
const auto src = R"JIT(
def forward(self, cond1: bool, cond2: bool, cond3: bool, x):
y = x * 42
if cond1:
y = y * y
if cond2:
y += x
if cond2 and cond3:
y += 1
if cond2:
if cond3:
y += 2
else:
y = y * y
y += 4
else:
if cond2:
return x.clone()
if cond3 or cond2:
y += 42
return y.clone()
)JIT";
for (auto cond1 : {true, false}) {
for (auto cond2 : {true, false}) {
for (auto cond3 : {true, false}) {
std::vector<IValue> args1{cond1, cond2, cond3, at::randn({1})};
std::vector<IValue> args2{cond1, cond2, cond3, at::randn({3, 3, 3})};
testStaticRuntime(src, args1, args2);
}
}
}
}
TEST(StaticRuntime, BasicForLoop) {
const auto src = R"JIT(
def forward(self, x, loop_max: int):
y = x.clone()
for i in range(loop_max):
y += 1
return y
)JIT";
std::vector<IValue> args1{at::randn({1}), 10};
std::vector<IValue> args2{at::randn({3, 3, 3}), 10};
testStaticRuntime(src, args1, args2);
}
TEST(StaticRuntime, BasicWhileLoop) {
const auto src = R"JIT(
def forward(self, x, loop_max: int):
y = x.clone()
loop_count = 0
while loop_count < loop_max:
y += 1
loop_count += 1
return y
)JIT";
std::vector<IValue> args1{at::randn({1}), 10};
std::vector<IValue> args2{at::randn({3, 3, 3}), 10};
testStaticRuntime(src, args1, args2);
}
TEST(StaticRuntime, NestedLoops) {
const auto src = R"JIT(
def forward(self, x, loop_max: int):
y = x.clone()
even: List[int] = []
odd: List[int] = []
for i in range(loop_max):
if i % 2:
odd.append(i)
else:
even.append(i)
for j in range(i):
y += 1
return y, even, odd
)JIT";
std::vector<IValue> args1{at::randn({1}), 10};
std::vector<IValue> args2{at::randn({3, 3, 3}), 10};
testStaticRuntime(src, args1, args2);
}
TEST(StaticRuntime, TupleIndex) {
const auto src = R"JIT(
def forward(self, idx: int, tup: Tuple[int, int]):
a = tup[idx]
return a * a
)JIT";
const auto tuple = c10::ivalue::Tuple::create({1, 2});
testStaticRuntime(src, {1, tuple}, {-1, tuple});
torch::jit::Module mod("module");
mod.define(src);
StaticModule smod(mod);
EXPECT_THROW(smod({100, tuple}), std::out_of_range);
}
TEST(StaticRuntime, RaiseException) {
const auto src = R"IR(
graph(%str: str):
%none: NoneType = prim::Constant()
prim::RaiseException(%str, %none)
return (%none)
)IR";
auto graph = getGraphFromIR(src);
StaticModule smod(graph);
const auto msg = "exception message";
EXPECT_THROW(
{
try {
smod({msg});
} catch (const std::runtime_error& e) {
EXPECT_STREQ(msg, e.what());
throw;
}
},
std::runtime_error);
}
TEST(StaticRuntime, Uninitialized) {
const auto src = R"IR(
graph():
%0: int = prim::Uninitialized()
return (%0)
)IR";
auto graph = getGraphFromIR(src);
StaticModule smod(graph);
const auto ret = smod({});
// Two uninitialized IValues compare unequal, so there is no expected
// value to check against; just verify that the type is Any
EXPECT_EQ(ret.type()->kind(), c10::TypeKind::AnyType);
}
TEST(StaticRuntime, Format) {
const auto src = R"JIT(
def forward(self, arg1: int, arg2: Tensor, arg3: str):
a = "arg1: {}, arg2: {}, arg3: {}".format(arg1, arg2, arg3)
return a[::]
)JIT";
testStaticRuntime(src, {1, at::randn({3}), "str"});
}
TEST(StaticRuntime, Device) {
const auto src = R"JIT(
def forward(self, x):
return x.device, x.device
)JIT";
testStaticRuntime(src, {at::tensor({1})});
}
TEST(StaticRuntime, Dtype) {
const auto src = R"JIT(
def forward(self, x, y):
return x.dtype, y.dtype
)JIT";
testStaticRuntime(
src, {at::tensor({1}, at::kLong), at::tensor({1}, at::kFloat)});
}
TEST(StaticRuntime, Dim) {
const auto src = R"JIT(
def forward(self, x, y):
return x.dim(), y.dim()
)JIT";
testStaticRuntime(src, {at::randn({2, 2}), at::randn({1})});
}
TEST(StaticRuntime, Not) {
const auto src = R"JIT(
def forward(self, x: bool, y: bool):
return not x, not y
)JIT";
testStaticRuntime(src, {true, false});
}
TEST(StaticRuntime, Bool) {
const auto src = R"JIT(
def forward(self, x: Tensor, y: int, z: float):
return bool(x), bool(y), bool(z)
)JIT";
testStaticRuntime(src, {at::randn({1}), 0, 1.151}, {at::zeros({1}), 1, 0.0});
}
TEST(StaticRuntime, IsCuda) {
const auto src = R"JIT(
def forward(self, x: Tensor, y: Tensor):
return x.is_cuda, y.is_cuda
)JIT";
testStaticRuntime(src, {at::randn({1}), at::randn({1})});
}
TEST(StaticRuntime, ToList) {
const auto src = R"JIT(
graph(%x: Tensor):
%type: int = prim::Constant[value=1]()
%dim: int = aten::dim(%x)
%ret: float[] = prim::tolist(%x, %dim, %type)
return (%ret)
)JIT";
testStaticRuntime(src, {at::randn({2, 2})});
}
TEST(StaticRuntime, IfThenElse) {
const auto src = R"IR(
graph(%cond: bool, %a: Tensor, %b: Tensor):
%none: NoneType = prim::Constant()
%c: Tensor = prim::IfThenElse(%cond, %a, %b)
%d: Tensor = aten::clone(%c, %none)
return (%d)
)IR";
std::vector<IValue> args1{true, at::randn({1}), at::randn({1})};
std::vector<IValue> args2{false, at::randn({1}), at::randn({1})};
testStaticRuntime(src, args1);
testStaticRuntime(src, args2);
}
TEST(StaticRuntime, EmptyIfBlock) {
const auto src =
R"JIT(
def forward(self, cond: bool, a: Tensor, b: Tensor):
l = []
if cond:
l.append((a + b).clone())
return l
)JIT";
testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})});
testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})});
}
TEST(StaticRuntime, EmptyNestedIfBlock) {
const auto src =
R"JIT(
def forward(self, cond: bool, a: Tensor, b: Tensor):
l = []
if cond:
if cond:
l.append((a + b).clone())
return l
)JIT";
testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})});
testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})});
}
TEST(StaticRuntime, StackEmpty) {
const auto src = R"JIT(
def forward(self):
x = torch.stack([])
return x
)JIT";
torch::jit::Module mod("mod");
mod.define(src);
torch::jit::StaticModule smod(mod);
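// torch.stack([]) has no tensors to infer an output from, so the run
// should fail with c10::Error.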
EXPECT_THROW(smod({}), c10::Error);
}
TEST(StaticRuntime, ConcatEmpty) {
const auto src = R"JIT(
def forward(self):
x = torch.concat([])
return x
)JIT";
torch::jit::Module mod("mod");
mod.define(src);
torch::jit::StaticModule smod(mod);
EXPECT_THROW(smod({}), c10::Error);
}
TEST(StaticRuntime, IntImplicit) {
const auto src = R"IR(
graph(%a: Tensor):
%y: int = aten::IntImplicit(%a)
return (%y)
)IR";
testStaticRuntime(src, {at::tensor({1}, at::kInt).squeeze()});
}
TEST(StaticRuntime, IntImplicit_ThrowOnBadInputs) {
const auto src = R"IR(
graph(%a: Tensor):
%y: int = aten::IntImplicit(%a)
return (%y)
)IR";
auto graph = getGraphFromIR(src);
torch::jit::StaticModule smod(graph);
// Not a 0-D tensor
EXPECT_THROW(smod({at::tensor({1, 2}, at::kInt)}), std::runtime_error);
// Wrong dtype
EXPECT_THROW(
smod({at::tensor({1}, at::kFloat).squeeze()}), std::runtime_error);
}
TEST(StaticRuntime, Select) {
const auto src = R"IR(
graph(%a: Tensor, %dim: int, %index: int):
%none: NoneType = prim::Constant()
%b: Tensor = aten::select(%a, %dim, %index)
%c: Tensor = aten::clone(%b, %none)
return (%c)
)IR";
testStaticRuntime(src, {at::randn({2, 2}), 0, 1});
}
TEST(StaticRuntime, ReshapeAs) {
const auto src = R"JIT(
def forward(self, a, b):
return a.reshape_as(b).clone()
)JIT";
testStaticRuntime(src, {at::randn({2, 2}), at::randn({4})});
}
TEST(StaticRuntime, MoveCtor) {
auto mod = getDeepAndWideSciptModel();
std::vector<IValue> args{
at::randn({1, 1, 32}), at::randn({1, 1, 32}), at::randn({1, 50})};
torch::jit::StaticModule smod(mod);
torch::jit::StaticRuntime runtime(smod);
auto expected = runtime(args);
torch::jit::StaticRuntime new_runtime(std::move(runtime));
auto actual = new_runtime(args);
compareResults(expected, actual);
}
TEST(StaticRuntime, SingleBlockIfReturnList) {
const auto src = R"JIT(
def forward(self, a, b, cond: bool):
lst = []
if cond:
lst.append(a + b)
return lst
)JIT";
std::vector<IValue> args1{at::randn({1}), at::randn({1}), true};
std::vector<IValue> args2{at::randn({42, 42}), at::randn({42, 42}), false};
testStaticRuntime(src, args1, args2);
}
TEST(StaticRuntime, NestedBlockIfReturnList) {
const auto src = R"JIT(
def forward(self, a, b, cond1: bool, cond2: bool):
if cond1:
lst = []
if cond2:
lst.append(a + b)
lst.append(a * b)
return lst
return []
)JIT";
std::vector<IValue> args1{at::randn({1}), at::randn({1}), true, true};
std::vector<IValue> args2{
at::randn({42, 42}), at::randn({42, 42}), true, false};
testStaticRuntime(src, args1, args2);
}
TEST(StaticRuntime, ClampNaNToNum) {
const auto src1 = R"JIT(
def forward(self, a):
return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone()
)JIT";
const auto src2 = R"JIT(
def forward(self, a, nan: float):
return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone()
)JIT";
const auto src3 = R"JIT(
def forward(self, a):
return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone()
)JIT";
auto a = at::tensor({
std::numeric_limits<float>::quiet_NaN(),
std::numeric_limits<float>::infinity(),
-std::numeric_limits<float>::infinity(),
0.0f,
3.0f
});
auto b = a.repeat({10, 5});
// We have to use allclose/equalnan even though all NaNs are replaced in
// the output: testStaticRuntime also checks the inputs (which still
// contain NaN) at the end to make sure they were not modified.
testStaticRuntime(src1, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
testStaticRuntime(src1, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);
testStaticRuntime(src2, {a, 42.0}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
testStaticRuntime(src2, {a, 2.0}, {b, 1.0}, /*use_allclose=*/true, /*use_equalnan=*/true);
testStaticRuntime(src3, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
testStaticRuntime(src3, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);
// Non-NNC path
testStaticRuntime(src1, {a.to(at::kDouble)}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true);
}
TEST(StaticRuntime, PrepackWeights) {
const std::string src = R"IR(
graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor):
%none: NoneType = prim::Constant()
%result: Tensor = fb::quantized_linear_unpacked_weight_v2(%input, %weight, %bias, %scale, %zero_point)
%dequantized: Tensor = aten::dequantize(%result)
return (%dequantized)
)IR";
auto graph = getGraphFromIR(src);
PrepackWeights(graph);
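// PrepackWeights should replace the fused fb:: op with
// quantized::linear_prepack + quantized::linear, as asserted below.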
ASSERT_TRUE(graphHasOp(graph, "quantized::linear"));
ASSERT_TRUE(graphHasOp(graph, "quantized::linear_prepack"));
ASSERT_FALSE(graphHasOp(graph, "fb::quantized_linear_unpacked_weight_v2"));
auto scale = at::tensor({2}, at::kFloat);
auto zero_point = at::tensor({3}, at::kLong);
auto weight =
at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);
auto input =
at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);
auto args1 = std::vector<IValue>{input, weight, c10::nullopt, scale, zero_point};
auto weight_2 =
at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8);
auto input_2 =
at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8);
auto bias_2 = torch::randn({8}, torch::kFloat);
auto args2 = std::vector<IValue>{input_2, weight_2, bias_2, scale, zero_point};
testStaticRuntime(src, args1);
testStaticRuntime(src, args2);
}
TEST(StaticRuntime, IfReturningTuple) {
const auto src = R"JIT(
def forward(self, x, y, cond: bool, idx: int):
if cond:
tup = (x, y)
else:
tup = (x, x)
return tup[idx]
)JIT";
std::vector<IValue> args{at::randn({3}), at::randn({3}), true, 0};
testStaticRuntime(src, args);
}