| #include <ATen/core/dispatch/OperatorOptions.h> |
| #include <c10/core/ScalarType.h> |
| #include <gtest/gtest.h> |
| #include <torch/csrc/jit/ir/alias_analysis.h> |
| #include <torch/csrc/jit/ir/irparser.h> |
| #include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h> |
| #include <torch/csrc/jit/runtime/static/impl.h> |
| #include <torch/csrc/jit/runtime/static/passes.h> |
| #include <torch/csrc/jit/testing/file_check.h> |
| #include <stdexcept> |
| |
| #include "deep_wide_pt.h" |
| #include "test_utils.h" |
| |
| using namespace caffe2; |
| using namespace torch; |
| using namespace torch::jit; |
| using namespace torch::jit::test; |
| using c10::IValue; |
| |
| /* |
| When adding a test for an operator implemented in static runtime, there are |
| several things that you need to pay attention to: |
| |
| 1) if the op is an out variant, in the test script of the op, |
| instead of: |
| def forward(self, input): |
| return myop(input) |
| |
| do: |
| def forward(self, input): |
| return myop(input).clone() |
| |
| This makes sure that the output of myop is managed by the memory planner and |
| exercises the code path in the op impl that otherwise doesn't get exercised. The |
| output of the model is not managed by the memory planner, because it needs to |
| be returned to the client. |
| |
| 2) The memory planner rounds up the size of each Tensor's storage to multiples |
| of 64 bytes (alignment requirement on AVX512). Make sure the sizes of the input |
| tensors in args2 are big enough to trigger resizing. |
| |
| 3) for view ops such as aten::reshape or aten::to, if you want it to be |
| replaced by the copy version with the ReplaceWithCopy pass in passes.h, you |
| also want to make sure its output is not returned as the model output. The |
| reason is that ReplaceWithCopy only replaces the op whose output is not an |
| alias of the model output. |
| */ |
| |
| C10_DECLARE_bool(static_runtime_enable_fast_math); |
| |
| // Covers torch.sum under static runtime: the full reduction and reductions |
| // over dim 0 / dim 1, each with and without keepdim. Every script clones the |
| // result (see the note on out variants at the top of this file). |
| TEST(StaticRuntime, UnaryOps) { |
| const std::string sum_all = R"JIT( |
| def forward(self, input): |
| return torch.sum(input).clone() |
| )JIT"; |
| |
| const std::string sum_dim0 = R"JIT( |
| def forward(self, input): |
| return torch.sum(input, 0).clone() |
| )JIT"; |
| |
| const std::string sum_dim1 = R"JIT( |
| def forward(self, input): |
| return torch.sum(input, 1).clone() |
| )JIT"; |
| |
| const std::string sum_dim0_keepdim = R"JIT( |
| def forward(self, input): |
| return torch.sum(input, 0, True).clone() |
| )JIT"; |
| |
| const std::string sum_dim1_keepdim = R"JIT( |
| def forward(self, input): |
| return torch.sum(input, 1, True).clone() |
| )JIT"; |
| |
| const at::Tensor small_input = at::randn({2, 3}); |
| const at::Tensor large_input = at::randn({3, 3, 6}); |
| |
| const std::vector<IValue> first_args{small_input}; |
| const std::vector<IValue> second_args{large_input}; |
| |
| // Single-argument-set runs, one per script. |
| for (const std::string& script : |
| {sum_all, sum_dim0, sum_dim1, sum_dim0_keepdim, sum_dim1_keepdim}) { |
| testStaticRuntime(script, first_args); |
| } |
| |
| // Runs with a second, larger argument set. The full reduction passes the |
| // optional flags explicitly; only the first one is named elsewhere in this |
| // file (use_allclose) - the meaning of the two trailing booleans should be |
| // confirmed against test_utils.h. |
| testStaticRuntime( |
| sum_all, first_args, second_args, /*use_allclose=*/false, false, false); |
| for (const std::string& script : |
| {sum_dim0, sum_dim1, sum_dim0_keepdim, sum_dim1_keepdim}) { |
| testStaticRuntime(script, first_args, second_args); |
| } |
| } |
| |
| // Covers the torch.max forms used by the scripts below: full reduction, |
| // reduction along a dim (with and without keepdim), and the two-tensor form. |
| TEST(StaticRuntime, Max) { |
| const std::string max_all = R"JIT( |
| def forward(self, input): |
| return torch.max(input).clone() |
| )JIT"; |
| |
| const std::string max_along_dim = R"JIT( |
| def forward(self, input, dim: int): |
| values, indices = torch.max(input, dim) |
| return values.clone(), indices.clone() |
| )JIT"; |
| |
| const std::string max_along_dim_keepdim = R"JIT( |
| def forward(self, input, dim: int): |
| values, indices = torch.max(input, dim, keepdim=True) |
| return values.clone(), indices.clone() |
| )JIT"; |
| |
| const std::string max_of_two = R"JIT( |
| def forward(self, input, other): |
| return torch.max(input, other).clone() |
| )JIT"; |
| |
| const at::Tensor lhs = at::randn({2, 3, 2}); |
| const at::Tensor rhs = at::randn({2, 3, 2}); |
| const at::Tensor big_lhs = at::randn({8, 9, 10}); |
| const at::Tensor big_rhs = at::randn({8, 9, 10}); |
| |
| // Full reduction. |
| testStaticRuntime(max_all, {lhs}); |
| |
| // Reduction along a dim; the second argument set also changes the dim. |
| testStaticRuntime(max_along_dim, {lhs, 1}); |
| testStaticRuntime(max_along_dim, {lhs, 1}, {big_lhs, 0}); |
| testStaticRuntime(max_along_dim_keepdim, {lhs, 0}); |
| testStaticRuntime(max_along_dim_keepdim, {lhs, 0}, {big_lhs, 2}); |
| |
| // Two-tensor form. |
| testStaticRuntime(max_of_two, {lhs, rhs}); |
| testStaticRuntime(max_of_two, {lhs, rhs}, {big_lhs, big_rhs}); |
| } |
| |
| // Covers torch.mean under static runtime: default reduction, explicit dtype, |
| // reduction over a dim list, keepdim, and dim list combined with dtype. |
| TEST(StaticRuntime, Mean) { |
| const std::string mean_all = R"JIT( |
| def forward(self, input): |
| return torch.mean(input).clone() |
| )JIT"; |
| const std::string mean_with_dtype = R"JIT( |
| def forward(self, input, dtype: int): |
| return torch.mean(input, dtype=dtype).clone() |
| )JIT"; |
| const std::string mean_over_dims = R"JIT( |
| def forward(self, input, dim: List[int]): |
| return torch.mean(input, dim).clone() |
| )JIT"; |
| const std::string mean_over_dims_keepdim = R"JIT( |
| def forward(self, input, dim: List[int]): |
| return torch.mean(input, dim, keepdim=True).clone() |
| )JIT"; |
| const std::string mean_over_dims_with_dtype = R"JIT( |
| def forward(self, input, dim: List[int], dtype: int): |
| return torch.mean(input, dim, dtype=dtype).clone() |
| )JIT"; |
| |
| const at::Tensor small = at::randn({2, 3, 2}); |
| const at::Tensor large = at::randn({8, 7, 6, 8}); |
| |
| // Argument sets built from the small tensor. |
| const std::vector<IValue> small_all = {small}; |
| const std::vector<IValue> small_dtype = {small, torch::kFloat}; |
| const std::vector<IValue> small_dims = {small, c10::List<int64_t>{0, 2}}; |
| const std::vector<IValue> small_dims_keepdim = {small, c10::List<int64_t>{1, 2}}; |
| const std::vector<IValue> small_dims_dtype = {small, c10::List<int64_t>{0, 1}, torch::kBFloat16}; |
| |
| testStaticRuntime(mean_all, small_all); |
| testStaticRuntime(mean_with_dtype, small_dtype); |
| testStaticRuntime(mean_over_dims, small_dims); |
| testStaticRuntime(mean_over_dims_keepdim, small_dims_keepdim); |
| testStaticRuntime(mean_over_dims_with_dtype, small_dims_dtype); |
| |
| // Second argument sets built from the large tensor; only the dim-list |
| // scripts are re-run with them. |
| const std::vector<IValue> large_dims = {large, c10::List<int64_t>{0, 3}}; |
| const std::vector<IValue> large_dims_keepdim = {large, c10::List<int64_t>{1, 2}}; |
| const std::vector<IValue> large_dims_dtype = {large, c10::List<int64_t>{1, 3}, torch::kBFloat16}; |
| |
| testStaticRuntime(mean_over_dims, small_dims, large_dims); |
| testStaticRuntime(mean_over_dims_keepdim, small_dims_keepdim, large_dims_keepdim); |
| testStaticRuntime(mean_over_dims_with_dtype, small_dims_dtype, large_dims_dtype); |
| } |
| |
| // Checks torch.sigmoid under static runtime twice: once with the fast-math |
| // flag at its current setting and once with it explicitly disabled. All runs |
| // compare results with use_allclose=true. |
| TEST(StaticRuntime, Sigmoid) { |
| const auto sigmoid_script = R"JIT( |
| def forward(self, inp: Tensor): |
| b = torch.sigmoid(inp).clone() |
| return (b) |
| )JIT"; |
| auto a = at::randn({2, 3}); |
| auto b = at::randn({4, 3, 2}); |
| |
| std::vector<IValue> args{a}, args2{b}; |
| |
| // args2 is passed as-is rather than as `{args2}`: brace-wrapping a vector |
| // that is already the parameter type leaves it to list-initialization rules |
| // whether it is copied or treated as a one-element argument list. |
| testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true); |
| testStaticRuntime(sigmoid_script, args, args2, /*use_allclose=*/true); |
| |
| // Scoped override of the fast-math flag. The destructor puts back the value |
| // the flag had on entry (not a hard-coded `true`), and it also runs if one |
| // of the calls below leaves this scope early. |
| struct FastMathOverride { |
| explicit FastMathOverride(bool value) |
| : saved_(FLAGS_static_runtime_enable_fast_math) { |
| FLAGS_static_runtime_enable_fast_math = value; |
| } |
| ~FastMathOverride() { |
| FLAGS_static_runtime_enable_fast_math = saved_; |
| } |
| FastMathOverride(const FastMathOverride&) = delete; |
| FastMathOverride& operator=(const FastMathOverride&) = delete; |
| const bool saved_; |
| }; |
| const FastMathOverride fast_math_disabled(false); |
| testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true); |
| testStaticRuntime(sigmoid_script, args, args2, /*use_allclose=*/true); |
| } |
| |
| // Covers aten::clone under static runtime for plain, strided, empty, |
| // expanded and sliced inputs, with several memory formats. |
| TEST(StaticRuntime, Clone) { |
| // clone is applied twice so that the first clone's output is handled by the |
| // memory planner; the output of the last op is returned to the client and |
| // therefore is not managed (and cannot be reused) by the planner. |
| const std::string clone_plain = R"JIT( |
| def forward(self, input): |
| a = torch.clone(input).clone() |
| return (a * a) |
| )JIT"; |
| |
| // clone with an explicit memory_format argument. |
| const std::string clone_with_format = R"JIT( |
| def forward(self, input: Tensor, memory_format: int): |
| a = torch.clone(input, memory_format=memory_format).clone() |
| return (a * a) |
| )JIT"; |
| |
| // Input whose stride is 0 because of expand: this case calls the native |
| // clone instead of the out variant. |
| const std::string clone_expanded = R"JIT( |
| def forward(self, input: Tensor, other:Tensor): |
| a = input.expand_as(other) |
| return a.clone().clone() |
| )JIT"; |
| |
| // Sliced input, i.e. a tensor with non-contiguous storage. |
| const std::string clone_sliced = R"JIT( |
| def forward(self, input: Tensor): |
| a = input[:, 0:10:2] |
| return a.clone().clone() |
| )JIT"; |
| |
| const at::Tensor base = at::randn({2, 3}); |
| const at::Tensor strided = at::randn({3, 2}).as_strided({3, 2}, {1, 3}); |
| const at::Tensor strided_large = at::randn({30, 20}).as_strided({30, 20}, {1, 3}); |
| const at::Tensor four_dim = at::randn({1, 20, 13, 8}); |
| const at::Tensor four_dim_empty = at::randn({1, 0, 3, 4}); |
| const at::Tensor column = at::randn({2, 1}); |
| const at::Tensor wide = at::randn({2, 10}); |
| const at::Tensor wider = at::randn({3, 20}); |
| |
| const std::vector<IValue> contiguous_args{strided, c10::MemoryFormat::Contiguous}; |
| const std::vector<IValue> preserve_args{strided_large, c10::MemoryFormat::Preserve}; |
| const std::vector<IValue> channels_last_args{four_dim, c10::MemoryFormat::ChannelsLast}; |
| const std::vector<IValue> empty_channels_last_args{four_dim_empty, c10::MemoryFormat::ChannelsLast}; |
| const std::vector<IValue> expand_to_base{column, base}; |
| const std::vector<IValue> expand_to_wide{column, wide}; |
| |
| testStaticRuntime(clone_plain, {base}); |
| testStaticRuntime(clone_plain, {base}, {strided_large}); |
| |
| testStaticRuntime(clone_with_format, contiguous_args); |
| testStaticRuntime(clone_with_format, preserve_args); |
| testStaticRuntime(clone_with_format, channels_last_args); |
| testStaticRuntime(clone_with_format, empty_channels_last_args); |
| testStaticRuntime(clone_with_format, contiguous_args, preserve_args); |
| testStaticRuntime(clone_with_format, empty_channels_last_args, channels_last_args); |
| |
| testStaticRuntime(clone_expanded, expand_to_base); |
| testStaticRuntime(clone_expanded, expand_to_base, expand_to_wide); |
| |
| testStaticRuntime(clone_sliced, {wide}); |
| testStaticRuntime(clone_sliced, {wide}, {wider}); |
| } |
| |
| // Covers torch.clamp under static runtime with scalar bounds (clamp_script_1) |
| // and with tensor bounds (clamp_script_2). Both scripts take their bounds in |
| // (min, max) order. |
| TEST(StaticRuntime, Clamp) { |
| const auto clamp_script_1 = R"JIT( |
| def forward(self, inp: Tensor, min: int, max: int): |
| a = torch.clamp(inp, min, max).clone() |
| return (a) |
| )JIT"; |
| |
| const auto clamp_script_2 = R"JIT( |
| def forward(self, inp: Tensor, min: Tensor, max: Tensor): |
| a = torch.clamp(inp, min, max).clone() |
| return (a) |
| )JIT"; |
| auto a = at::randn({2, 3}); |
| auto max_t = at::full_like(a, 1); |
| auto min_t = at::full_like(a, -1); |
| |
| auto b = at::randn({4, 3, 2}); |
| auto max_t1 = at::full_like(b, 1); |
| auto min_t1 = at::full_like(b, -1); |
| |
| testStaticRuntime(clamp_script_1, {a, -1, 1}); |
| testStaticRuntime(clamp_script_2, {a, min_t, max_t}); |
| |
| testStaticRuntime(clamp_script_1, {a, -1, 1}, {b, -1, 1}); |
| // The second argument set follows the script's (inp, min, max) order, like |
| // the first one. It previously passed max_t1 as `min` and min_t1 as `max`. |
| testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, min_t1, max_t1}); |
| } |
| |
| // torch.clamp with only a lower bound (max is passed as None in the script). |
| TEST(StaticRuntime, ClampMinOnly) { |
| const std::string clamp_lower_only = R"JIT( |
| def forward(self, inp: Tensor, min: float): |
| a = torch.clamp(inp, min, None).clone() |
| return (a) |
| )JIT"; |
| const at::Tensor small = at::randn({2, 3}); |
| const at::Tensor large = at::randn({4, 3, 2}); |
| |
| // One run with a single argument set, then one with a second, larger input |
| // and a different bound. |
| testStaticRuntime(clamp_lower_only, {small, 0.5}); |
| testStaticRuntime(clamp_lower_only, {small, 0.5}, {large, 0.25}); |
| } |
| |
| // torch.clamp with only an upper bound (min is passed as None in the script). |
| TEST(StaticRuntime, ClampMaxOnly) { |
| const std::string clamp_upper_only = R"JIT( |
| def forward(self, inp: Tensor, max: float): |
| a = torch.clamp(inp, None, max).clone() |
| return (a) |
| )JIT"; |
| const at::Tensor small = at::randn({2, 3}); |
| const at::Tensor large = at::randn({4, 3, 2}); |
| |
| // One run with a single argument set, then one with a second, larger input |
| // and a different bound. |
| testStaticRuntime(clamp_upper_only, {small, 0.5}); |
| testStaticRuntime(clamp_upper_only, {small, 0.5}, {large, 0.25}); |
| } |
| |
| // torch.clamp with float bounds on tensors produced by at::randint. |
| // NOTE(review): despite the test name, the tensors are created with |
| // at::kFloat, and both bounds are 5.0f, so the clamped result is presumably |
| // constant - confirm this is the intended coverage. |
| TEST(StaticRuntime, ClampIntTensor) { |
| const std::string clamp_script = R"JIT( |
| def forward(self, inp: Tensor, min: float, max: float): |
| a = torch.clamp(inp, min, max).clone() |
| return (a) |
| )JIT"; |
| const at::Tensor small = at::randint(0, 20, {2, 3}, at::kFloat); |
| const at::Tensor large = at::randint(0, 20, {4, 3, 2}, at::kFloat); |
| const float lower_bound = 5.0f; |
| const float upper_bound = 5.0f; |
| testStaticRuntime(clamp_script, {small, lower_bound, upper_bound}); |
| testStaticRuntime( |
| clamp_script, |
| {small, lower_bound, upper_bound}, |
| {large, lower_bound, upper_bound}); |
| } |
| |
| // aten::len on an int[] graph input. |
| // NOTE(review): the test name mentions a tuple, but the IR declares the input |
| // as int[] and the argument is a c10::List<int64_t> - confirm the naming. |
| TEST(StaticRuntime, LenWithTuple) { |
| const std::string len_graph = R"IR( |
| graph(%input : int[]): |
| %res : int = aten::len(%input) |
| return (%res) |
| )IR"; |
| |
| // Whether List<int64_t>(4) means four elements or the single element 4 |
| // depends on c10::List's constructors, which are not visible here. |
| testStaticRuntime(len_graph, {c10::List<int64_t>(4)}); |
| } |
| |
| // aten::len on a Tensor graph input. |
| TEST(StaticRuntime, LenWithTensor) { |
| const std::string len_graph = R"IR( |
| graph(%input : Tensor): |
| %res : int = aten::len(%input) |
| return (%res) |
| )IR"; |
| |
| // A single 2x2x2 random tensor is the only argument. |
| testStaticRuntime(len_graph, {at::randn({2, 2, 2})}); |
| } |
| |
| // aten::len on a str graph input. |
| TEST(StaticRuntime, LenWithStr) { |
| const std::string len_graph = R"IR( |
| graph(%input : str): |
| %res : int = aten::len(%input) |
| return (%res) |
| )IR"; |
| |
| // The string literal is the only argument. |
| testStaticRuntime(len_graph, {"static_runtime"}); |
| } |
| |
| // len() on a Dict[str, str] input holding two entries. |
| TEST(StaticRuntime, LenWithDict_str) { |
| const std::string len_script = R"JIT( |
| def forward(self, input: Dict[str, str]): |
| return len(input) |
| )JIT"; |
| |
| c10::Dict<std::string, std::string> entries; |
| entries.insert("abc", "123"); |
| entries.insert("def", "456"); |
| testStaticRuntime(len_script, {entries}); |
| } |
| |
| // len() on a Dict[int, int] input holding two entries. |
| TEST(StaticRuntime, LenWithDict_int) { |
| const std::string len_script = R"JIT( |
| def forward(self, input: Dict[int, int]): |
| return len(input) |
| )JIT"; |
| |
| c10::Dict<int64_t, int64_t> entries; |
| entries.insert(0, 1); |
| entries.insert(2, 3); |
| testStaticRuntime(len_script, {entries}); |
| } |
| |
| // len() on a Dict[bool, bool] input holding one entry per key value. |
| TEST(StaticRuntime, LenWithDict_bool) { |
| const std::string len_script = R"JIT( |
| def forward(self, input: Dict[bool, bool]): |
| return len(input) |
| )JIT"; |
| |
| c10::Dict<bool, bool> entries; |
| entries.insert(true, false); |
| entries.insert(false, true); |
| testStaticRuntime(len_script, {entries}); |
| } |
| |
| // len() on a Dict[float, float] input holding two entries. |
| TEST(StaticRuntime, LenWithDict_float) { |
| const std::string len_script = R"JIT( |
| def forward(self, input: Dict[float, float]): |
| return len(input) |
| )JIT"; |
| |
| c10::Dict<double, double> entries; |
| entries.insert(0.1, 0.9); |
| entries.insert(0.8, 0.18); |
| testStaticRuntime(len_script, {entries}); |
| } |
| |
| // len() on a Dict[complex, complex] input holding two entries. |
| TEST(StaticRuntime, LenWithDict_complex) { |
| const std::string len_script = R"JIT( |
| def forward(self, input: Dict[complex, complex]): |
| return len(input) |
| )JIT"; |
| |
| // Keys and values are written as real literals and converted by insert(). |
| c10::Dict<c10::complex<double>, c10::complex<double>> entries; |
| entries.insert(0.1, 0.4); |
| entries.insert(0.9, 0.45); |
| testStaticRuntime(len_script, {entries}); |
| } |
| |
| // len() on a Dict[Tensor, Tensor] input built from two insert() calls. |
| TEST(StaticRuntime, LenWithDict_Tensor) { |
| const std::string len_script = R"JIT( |
| def forward(self, input: Dict[Tensor, Tensor]): |
| return len(input) |
| )JIT"; |
| |
| c10::Dict<at::Tensor, at::Tensor> entries; |
| entries.insert(at::randn({1, 2}), at::randn({1, 2})); |
| entries.insert(at::randn({1, 2}), at::randn({1, 2})); |
| testStaticRuntime(len_script, {entries}); |
| } |
| |
| // Covers torch.logit under static runtime without eps, with a constant eps, |
| // and with eps supplied as a graph input. |
| TEST(StaticRuntime, Logit) { |
| // no nnc |
| const std::string logit_no_eps = R"JIT( |
| def forward(self, inp: Tensor): |
| a = torch.logit(inp).clone() |
| return (a) |
| )JIT"; |
| |
| // with nnc |
| const std::string logit_const_eps = R"JIT( |
| def forward(self, inp: Tensor): |
| a = torch.logit(inp, 1e-6).clone() |
| return (a) |
| )JIT"; |
| |
| // no nnc |
| const std::string logit_input_eps = R"JIT( |
| def forward(self, inp: Tensor, eps: float): |
| a = torch.logit(inp, eps).clone() |
| return (a) |
| )JIT"; |
| const at::Tensor ones_small = at::ones({2, 3}); |
| const double eps = 1e-6; |
| const std::vector<IValue> tensor_only{ones_small}; |
| const std::vector<IValue> tensor_and_eps{ones_small, eps}; |
| |
| const at::Tensor ones_large = at::ones({4, 3, 2}); |
| |
| // Single argument set per script. |
| testStaticRuntime(logit_no_eps, tensor_only); |
| testStaticRuntime(logit_const_eps, tensor_only); |
| testStaticRuntime(logit_input_eps, tensor_and_eps); |
| |
| // Same scripts with a second, larger input. |
| testStaticRuntime(logit_no_eps, tensor_only, {ones_large}); |
| testStaticRuntime(logit_const_eps, tensor_only, {ones_large}); |
| testStaticRuntime(logit_input_eps, tensor_and_eps, {ones_large, eps}); |
| } |
| |
| // Covers torch.embedding_bag under static runtime for mode arguments 0, 1 and |
| // 2 (named sum/mean/max by the scripts), with the trailing boolean argument |
| // left at its default and set to True. All four outputs are cloned. |
| TEST(StaticRuntime, EmbeddingBag) { |
| const std::string bag_default = R"JIT( |
| def forward(self, a: Tensor, b: Tensor, c: Tensor): |
| x, y, z, _ = torch.embedding_bag(a, b, c) |
| return (x.clone(), y.clone(), z.clone(), _.clone()) |
| )JIT"; |
| |
| const std::string bag_mean = R"JIT( |
| def forward(self, a: Tensor, b: Tensor, c: Tensor): |
| x, y, z, _ = torch.embedding_bag(a, b, c, False, 1) |
| return (x.clone(), y.clone(), z.clone(), _.clone()) |
| )JIT"; |
| |
| const std::string bag_max = R"JIT( |
| def forward(self, a: Tensor, b: Tensor, c: Tensor): |
| x, y, z, _ = torch.embedding_bag(a, b, c, False, 2) |
| return (x.clone(), y.clone(), z.clone(), _.clone()) |
| )JIT"; |
| |
| const std::string bag_sum_last_offset = R"JIT( |
| def forward(self, a: Tensor, b: Tensor, c: Tensor): |
| x, y, z, _ = torch.embedding_bag(a, b, c, False, 0, False, None, True) |
| return (x.clone(), y.clone(), z.clone(), _.clone()) |
| )JIT"; |
| |
| const std::string bag_mean_last_offset = R"JIT( |
| def forward(self, a: Tensor, b: Tensor, c: Tensor): |
| x, y, z, _ = torch.embedding_bag(a, b, c, False, 1, False, None, True) |
| return (x.clone(), y.clone(), z.clone(), _.clone()) |
| )JIT"; |
| |
| const std::string bag_max_last_offset = R"JIT( |
| def forward(self, a: Tensor, b: Tensor, c: Tensor): |
| x, y, z, _ = torch.embedding_bag(a, b, c, False, 2, False, None, True) |
| return (x.clone(), y.clone(), z.clone(), _.clone()) |
| )JIT"; |
| |
| // First argument set: 3x11 weights, four indices, three offsets. |
| at::Tensor table = torch::randn({3, 11}, at::ScalarType::Float); |
| at::Tensor indices = torch::tensor({0, 1, 0, 2}); |
| at::Tensor offsets = torch::tensor({0, 2, 4}); |
| const std::vector<IValue> first_args{table, indices, offsets}; |
| for (const std::string& script : |
| {bag_default, |
| bag_mean, |
| bag_max, |
| bag_sum_last_offset, |
| bag_mean_last_offset, |
| bag_max_last_offset}) { |
| testStaticRuntime(script, first_args); |
| } |
| |
| // Second argument set: 10x11 weights, five indices, six offsets. |
| at::Tensor table2 = torch::randn({10, 11}, at::ScalarType::Float); |
| at::Tensor indices2 = torch::tensor({0, 1, 0, 2, 1}); |
| at::Tensor offsets2 = torch::tensor({0, 1, 2, 3, 4, 5}); |
| const std::vector<IValue> second_args{table2, indices2, offsets2}; |
| for (const std::string& script : |
| {bag_default, |
| bag_mean, |
| bag_max, |
| bag_sum_last_offset, |
| bag_mean_last_offset, |
| bag_max_last_offset}) { |
| testStaticRuntime(script, first_args, second_args); |
| } |
| } |
| |
| // torch.embedding_bag whose outputs are only consumed inside the graph |
| // (x + x is returned), so they are intermediates rather than graph outputs. |
| TEST(StaticRuntime, EmbeddingBagWithManagedOutput) { |
| const std::string bag_as_intermediate = R"JIT( |
| def forward(self, a: Tensor, b: Tensor, c: Tensor): |
| # The outputs of embedding_bag become an intermediate tensors |
| # since they are not directly returned from the graph. |
| x, y, z, _ = torch.embedding_bag(a, b, c) |
| return x + x |
| )JIT"; |
| |
| // First argument set: 3x8 weights. |
| at::Tensor table = torch::randn({3, 8}, at::ScalarType::Float); |
| at::Tensor indices = torch::tensor({0, 1, 0, 2}); |
| at::Tensor offsets = torch::tensor({0, 2}); |
| const std::vector<IValue> first_args{table, indices, offsets}; |
| |
| // Second argument set: 6x8 weights with more indices and offsets. |
| at::Tensor table2 = torch::randn({6, 8}, at::ScalarType::Float); |
| at::Tensor indices2 = torch::tensor({0, 1, 0, 2, 3, 4}); |
| at::Tensor offsets2 = torch::tensor({0, 2, 4, 5}); |
| const std::vector<IValue> second_args{table2, indices2, offsets2}; |
| |
| testStaticRuntime(bag_as_intermediate, first_args); |
| testStaticRuntime(bag_as_intermediate, first_args, second_args); |
| } |
| |
| // Checks RemoveUnnecessaryOutputs on aten::embedding_bag graphs. When only the |
| // first output is used, the pass is expected to leave a |
| // static_runtime::embedding_bag node in the graph; when all four outputs are |
| // used, no such node is expected. The single-output graphs are then also run |
| // through testStaticRuntime. |
| TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) { |
| // Builds the graph for `ir` and applies RemoveUnnecessaryOutputs to it. |
| const auto prune_outputs = [](const std::string& ir) { |
| auto pruned = getGraphFromIR(ir); |
| RemoveUnnecessaryOutputs(pruned); |
| return pruned; |
| }; |
| |
| // mode constant 0, include_last_offset constant 0, only %y0 used. |
| const std::string default_ir = R"IR( |
| graph(%weight, %indices, %offsets): |
| %scale_grad_by_freq : bool = prim::Constant[value=0]() |
| %mode : int = prim::Constant[value=0]() |
| %sparse : bool = prim::Constant[value=0]() |
| %per_sample_weights : NoneType = prim::Constant() |
| %include_last_offset : bool = prim::Constant[value=0]() |
| %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset) |
| %none : NoneType = prim::Constant() |
| %res : Tensor = aten::clone(%y0, %none) |
| return (%res) |
| )IR"; |
| const auto default_graph = prune_outputs(default_ir); |
| torch::jit::testing::FileCheck() |
| .check("static_runtime::embedding_bag") |
| ->run(*default_graph); |
| |
| // mode constant 1, only %y0 used. |
| const std::string mean_ir = R"IR( |
| graph(%weight, %indices, %offsets): |
| %scale_grad_by_freq : bool = prim::Constant[value=0]() |
| %mode : int = prim::Constant[value=1]() |
| %sparse : bool = prim::Constant[value=0]() |
| %per_sample_weights : NoneType = prim::Constant() |
| %include_last_offset : bool = prim::Constant[value=0]() |
| %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset) |
| %none : NoneType = prim::Constant() |
| %res : Tensor = aten::clone(%y0, %none) |
| return (%res) |
| )IR"; |
| const auto mean_graph = prune_outputs(mean_ir); |
| torch::jit::testing::FileCheck() |
| .check("static_runtime::embedding_bag") |
| ->run(*mean_graph); |
| |
| // mode constant 2, include_last_offset constant 1, only %y0 used. |
| const std::string max_last_offset_ir = R"IR( |
| graph(%weight, %indices, %offsets): |
| %scale_grad_by_freq : bool = prim::Constant[value=0]() |
| %mode : int = prim::Constant[value=2]() |
| %sparse : bool = prim::Constant[value=0]() |
| %per_sample_weights : NoneType = prim::Constant() |
| %include_last_offset : bool = prim::Constant[value=1]() |
| %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset) |
| %none : NoneType = prim::Constant() |
| %res : Tensor = aten::clone(%y0, %none) |
| return (%res) |
| )IR"; |
| const auto max_last_offset_graph = prune_outputs(max_last_offset_ir); |
| torch::jit::testing::FileCheck() |
| .check("static_runtime::embedding_bag") |
| ->run(*max_last_offset_graph); |
| |
| // All four outputs are used, so the rewritten op must not appear. |
| const std::string all_outputs_ir = R"IR( |
| graph(%weight, %indices, %offsets): |
| %scale_grad_by_freq : bool = prim::Constant[value=0]() |
| %mode : int = prim::Constant[value=0]() |
| %sparse : bool = prim::Constant[value=0]() |
| %per_sample_weights : NoneType = prim::Constant() |
| %include_last_offset : bool = prim::Constant[value=0]() |
| %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset) |
| %none : NoneType = prim::Constant() |
| %res0 : Tensor = aten::clone(%y0, %none) |
| %res1 : Tensor = aten::clone(%y1, %none) |
| %res2 : Tensor = aten::clone(%y2, %none) |
| %res3 : Tensor = aten::clone(%y3, %none) |
| return (%res0, %res1, %res2, %res3) |
| )IR"; |
| const auto all_outputs_graph = prune_outputs(all_outputs_ir); |
| torch::jit::testing::FileCheck() |
| .check_not("static_runtime::embedding_bag") |
| ->run(*all_outputs_graph); |
| |
| // End-to-end runs of the three single-output graphs. |
| at::Tensor table = torch::randn({3, 11}, at::ScalarType::Float); |
| at::Tensor indices = torch::tensor({0, 1, 0, 2}); |
| at::Tensor offsets = torch::tensor({0, 2, 4}); |
| const std::vector<IValue> first_args{table, indices, offsets}; |
| testStaticRuntime(default_ir, first_args); |
| testStaticRuntime(mean_ir, first_args); |
| testStaticRuntime(max_last_offset_ir, first_args); |
| |
| at::Tensor table2 = torch::randn({10, 11}, at::ScalarType::Float); |
| at::Tensor indices2 = torch::tensor({0, 1, 0, 2, 1}); |
| at::Tensor offsets2 = torch::tensor({0, 1, 2, 3, 4, 5}); |
| const std::vector<IValue> second_args{table2, indices2, offsets2}; |
| testStaticRuntime(default_ir, first_args, second_args); |
| testStaticRuntime(mean_ir, first_args, second_args); |
| testStaticRuntime(max_last_offset_ir, first_args, second_args); |
| } |
| |
| // Covers torch.layer_norm under static runtime with and without weight/bias, |
| // for normalized shapes made of two and of three trailing dims of size 2. |
| TEST(StaticRuntime, LayerNorm) { |
| const std::string norm_with_affine = R"JIT( |
| def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor): |
| return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone() |
| )JIT"; |
| |
| const std::string norm_without_affine = R"JIT( |
| def forward(self, input: Tensor, normalized_shape: List[int]): |
| return torch.layer_norm(input, normalized_shape, None, None, 1e-05, False).clone() |
| )JIT"; |
| |
| const at::Tensor small = torch::rand({1, 2, 2, 2}); |
| const at::Tensor large = torch::rand({3, 2, 2, 2}); |
| for (int rank : {2, 3}) { |
| // `rank` dims, each of size 2. |
| std::vector<int64_t> shape(rank, 2); |
| const at::Tensor scale = torch::rand(shape); |
| const at::Tensor shift = torch::rand(shape); |
| |
| // With weight and bias. |
| const std::vector<IValue> affine_small{small, shape, scale, shift}; |
| const std::vector<IValue> affine_large{large, shape, scale, shift}; |
| testStaticRuntime(norm_with_affine, affine_small); |
| testStaticRuntime(norm_with_affine, affine_small, affine_large); |
| |
| // Without weight and bias. |
| const std::vector<IValue> plain_small{small, shape}; |
| testStaticRuntime(norm_without_affine, plain_small); |
| testStaticRuntime(norm_without_affine, plain_small, {large, shape}); |
| } |
| } |
| |
| // torch.bmm under static runtime with two differently sized batches. |
| TEST(StaticRuntime, Bmm) { |
| const std::string batched_matmul = R"JIT( |
| def forward(self, inp: Tensor, mat2: Tensor): |
| return torch.bmm(inp, mat2).clone() |
| )JIT"; |
| |
| // (10, 4, 5) x (10, 5, 6) |
| const at::Tensor lhs = at::randn({10, 4, 5}); |
| const at::Tensor rhs = at::randn({10, 5, 6}); |
| |
| // (12, 5, 6) x (12, 6, 7) |
| const at::Tensor lhs_alt = at::randn({12, 5, 6}); |
| const at::Tensor rhs_alt = at::randn({12, 6, 7}); |
| |
| const std::vector<IValue> first_args{lhs, rhs}; |
| const std::vector<IValue> second_args{lhs_alt, rhs_alt}; |
| testStaticRuntime(batched_matmul, first_args); |
| testStaticRuntime(batched_matmul, second_args); |
| testStaticRuntime(batched_matmul, first_args, second_args); |
| } |
| |
| // torch.addmm under static runtime with alpha/beta supplied as inputs. The |
| // script takes (inp, mat1, mat2, beta, alpha) in that order. |
| TEST(StaticRuntime, Addmm) { |
| const std::string addmm_with_scalars = R"JIT( |
| def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float): |
| return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone() |
| )JIT"; |
| // First set: 1-D `inp` of size 5 with a (3, 4) x (4, 5) product. |
| const at::Tensor bias_1d = at::randn({5}); |
| const at::Tensor lhs = at::randn({3, 4}); |
| const at::Tensor rhs = at::randn({4, 5}); |
| |
| // Second set: 2-D `inp` of shape (3, 7) with a (3, 6) x (6, 7) product. |
| const at::Tensor bias_2d = at::randn({3, 7}); |
| const at::Tensor lhs_alt = at::randn({3, 6}); |
| const at::Tensor rhs_alt = at::randn({6, 7}); |
| |
| const std::vector<IValue> first_args{bias_1d, lhs, rhs, 1.0, 2.0}; |
| const std::vector<IValue> second_args{bias_2d, lhs_alt, rhs_alt, 2.0, 1.0}; |
| testStaticRuntime(addmm_with_scalars, first_args); |
| testStaticRuntime(addmm_with_scalars, second_args); |
| testStaticRuntime(addmm_with_scalars, first_args, second_args); |
| } |
| |
| // Tensor.abs() under static runtime, with a second, larger input. |
| TEST(StaticRuntime, Abs) { |
| const std::string abs_of_input = R"JIT( |
| def forward(self, a): |
| return a.abs().clone() |
| )JIT"; |
| const at::Tensor small = at::randn({2, 3}); |
| const at::Tensor large = at::randn({4, 2, 3}); |
| const std::vector<IValue> first_args{small}; |
| const std::vector<IValue> second_args{large}; |
| testStaticRuntime(abs_of_input, first_args); |
| testStaticRuntime(abs_of_input, first_args, second_args); |
| } |
| |
| // Covers `+` on tensors, ints and int lists, plus list/tuple construction and |
| // list unpacking, under static runtime. |
| TEST(StaticRuntime, Binary) { |
| const std::string add_tensors = R"JIT( |
| def forward(self, a, b): |
| c = a + b |
| return (c.clone()) |
| )JIT"; |
| |
| const std::string add_ints = R"JIT( |
| def forward(self, a: int, b: int): |
| c = a + b |
| d = c + 1 |
| return d |
| )JIT"; |
| |
| const std::string add_int_lists = R"JIT( |
| def forward(self, a: List[int], b: List[int]): |
| c = a + b |
| return c[::] |
| )JIT"; |
| |
| const std::string list_of_inputs = R"JIT( |
| def forward(self, a, b): |
| return [a, b] |
| )JIT"; |
| |
| const std::string list_of_same_sum = R"JIT( |
| def forward(self, a, b): |
| c = a + a |
| return [c, c] |
| )JIT"; |
| |
| const std::string list_of_sum_and_flat = R"JIT( |
| def forward(self, a, b): |
| c = a + a |
| return [c, c.flatten()] |
| )JIT"; |
| |
| const std::string unpack_then_add = R"JIT( |
| def forward(self, a, b): |
| c = [a, b] |
| x, y = c |
| z = x + y |
| return z.clone() |
| )JIT"; |
| |
| const std::string unpack_then_tuple = R"JIT( |
| def forward(self, a, b): |
| c = [a, b] |
| x, y = c |
| z = (x, y) |
| return z |
| )JIT"; |
| |
| const std::string tuple_of_inputs = R"JIT( |
| def forward(self, a, b): |
| return (a, b) |
| )JIT"; |
| |
| const std::string tuple_with_flat = R"JIT( |
| def forward(self, a, b): |
| return (a.flatten(), b) |
| )JIT"; |
| |
| const at::Tensor lhs = at::randn({2, 3}); |
| const at::Tensor rhs = at::ones({2, 3}); |
| |
| const at::Tensor big_lhs = at::randn({4, 2, 3}); |
| const at::Tensor big_rhs = at::ones({4, 2, 3}); |
| |
| const std::vector<IValue> tensor_pair{lhs, rhs}; |
| |
| // Addition. |
| testStaticRuntime(add_tensors, tensor_pair); |
| testStaticRuntime(add_ints, {1, 2}); |
| testStaticRuntime(add_tensors, tensor_pair, {big_lhs, big_rhs}); |
| |
| // Container construction and unpacking. |
| testStaticRuntime(list_of_inputs, tensor_pair); |
| testStaticRuntime(list_of_same_sum, tensor_pair); |
| testStaticRuntime(list_of_sum_and_flat, tensor_pair); |
| testStaticRuntime(unpack_then_add, tensor_pair); |
| testStaticRuntime(unpack_then_tuple, tensor_pair); |
| testStaticRuntime(tuple_of_inputs, tensor_pair); |
| testStaticRuntime(tuple_with_flat, tensor_pair); |
| |
| // List concatenation. |
| const std::vector<IValue> int_list_pair{ |
| c10::List<int64_t>{1, 2, 3}, c10::List<int64_t>{4, 5, 6}}; |
| testStaticRuntime(add_int_lists, int_list_pair); |
| } |
| |
| // torch.matmul under static runtime for each combination of operand ranks |
| // listed below. |
| TEST(StaticRuntime, MatMul) { |
| const std::string matmul_script = R"JIT( |
| def forward(self, a: Tensor, b: Tensor): |
| return torch.matmul(a, b).clone() |
| )JIT"; |
| |
| // 1-D, 1-D |
| const std::vector<IValue> vec_vec{at::randn({3}), at::randn({3})}; |
| testStaticRuntime(matmul_script, vec_vec); |
| // 2-D, 2-D |
| const std::vector<IValue> mat_mat = {at::randn({3, 2}), at::randn({2, 3})}; |
| testStaticRuntime(matmul_script, mat_mat); |
| // 1-D, 2-D |
| const std::vector<IValue> vec_mat = {at::randn({3}), at::randn({3, 5})}; |
| testStaticRuntime(matmul_script, vec_mat); |
| // 2-D, 1-D |
| const std::vector<IValue> mat_vec = {at::randn({3, 5}), at::randn({5})}; |
| testStaticRuntime(matmul_script, mat_vec); |
| // > 2-D , > 2-D |
| const std::vector<IValue> batched = {at::randn({3, 1, 4, 5}), at::randn({2, 5, 6})}; |
| testStaticRuntime(matmul_script, batched); |
| |
| // Same script run with two argument sets of different ranks. |
| testStaticRuntime(matmul_script, mat_vec, batched); |
| } |
| |
| // torch.sign under static runtime, with a second, larger input. |
| TEST(StaticRuntime, Sign) { |
| const std::string sign_script = R"JIT( |
| def forward(self, input: Tensor): |
| return torch.sign(input).clone() |
| )JIT"; |
| |
| const at::Tensor small = at::randn({2, 3}); |
| const at::Tensor large = at::randn({4, 3, 2}); |
| |
| const std::vector<IValue> first_args{small}; |
| testStaticRuntime(sign_script, first_args); |
| testStaticRuntime(sign_script, first_args, {large}); |
| } |
| |
| // Covers torch.div under static runtime: tensor/tensor, tensor/scalar, both |
| // with an optional rounding_mode, and transposed operands. |
| TEST(StaticRuntime, Div) { |
| const std::string div_by_tensor = R"JIT( |
| def forward(self, a: Tensor, b: Tensor): |
| return torch.div(a, b).clone() |
| )JIT"; |
| |
| const std::string div_by_scalar = R"JIT( |
| def forward(self, a: Tensor, b: int): |
| return torch.div(a, b).clone() |
| )JIT"; |
| |
| const std::string div_by_tensor_rounded = R"JIT( |
| def forward(self, a: Tensor, b: Tensor, c: str): |
| return torch.div(a, b, rounding_mode=c).clone() |
| )JIT"; |
| |
| const std::string div_by_scalar_rounded = R"JIT( |
| def forward(self, a: Tensor, b: float, c: str): |
| return torch.div(a, b, rounding_mode=c).clone() |
| )JIT"; |
| |
| const std::string div_transposed = R"JIT( |
| def forward(self, a: Tensor, b: Tensor): |
| a_strided = torch.transpose(a, 0, 1) |
| b_strided = torch.transpose(b, 0, 1) |
| return torch.div(a_strided, b_strided).clone() |
| )JIT"; |
| |
| const at::Tensor lhs = at::randn({2, 3}); |
| const at::Tensor rhs = at::randn({2, 3}); |
| const at::Tensor rhs_transposed = at::randn({3, 2}).transpose(0, 1); |
| const at::Tensor big_lhs = at::randn({4, 3, 2}); |
| const at::Tensor big_rhs = at::randn({4, 3, 2}); |
| const at::Tensor big_rhs_transposed = at::randn({3, 4, 2}).transpose(0, 1); |
| |
| // Tensor divisor. |
| const std::vector<IValue> tensor_pair{lhs, rhs}; |
| testStaticRuntime(div_by_tensor, tensor_pair); |
| testStaticRuntime(div_by_tensor, tensor_pair, {big_lhs, big_rhs}); |
| |
| // Operands transposed inside the script. |
| testStaticRuntime(div_transposed, tensor_pair); |
| testStaticRuntime(div_transposed, tensor_pair, {big_lhs, big_rhs}); |
| |
| // Divisor already transposed before it is passed in. |
| testStaticRuntime(div_by_tensor, {lhs, rhs_transposed}); |
| testStaticRuntime(div_by_tensor, {lhs, rhs_transposed}, {big_lhs, big_rhs_transposed}); |
| |
| // Integer scalar divisor. |
| const std::vector<IValue> scalar_args{lhs, 3}; |
| testStaticRuntime(div_by_scalar, scalar_args); |
| testStaticRuntime(div_by_scalar, scalar_args, {big_lhs, 4}); |
| |
| // Tensor divisor with rounding_mode "floor". |
| const std::vector<IValue> floor_args{lhs, rhs, "floor"}; |
| testStaticRuntime(div_by_tensor_rounded, floor_args); |
| testStaticRuntime(div_by_tensor_rounded, floor_args, {big_lhs, big_rhs, "floor"}); |
| |
| // Float scalar divisor with rounding_mode "trunc". |
| const std::vector<IValue> trunc_args{lhs, 2.3, "trunc"}; |
| testStaticRuntime(div_by_scalar_rounded, trunc_args); |
| testStaticRuntime(div_by_scalar_rounded, trunc_args, {big_lhs, 1.5, "trunc"}); |
| } |
| |
| // Covers torch.mul with a tensor and with an int scalar, plus `*` applied to |
| // an int list, under static runtime. |
| TEST(StaticRuntime, Mul) { |
| const std::string mul_by_tensor = R"JIT( |
| def forward(self, a: Tensor, b: Tensor): |
| return torch.mul(a, b).clone() |
| )JIT"; |
| |
| const std::string mul_by_scalar = R"JIT( |
| def forward(self, a: Tensor, b: int): |
| return torch.mul(a, b).clone() |
| )JIT"; |
| |
| const std::string repeat_list = R"JIT( |
| def forward(self, a: List[int], n: int): |
| b = a * n |
| return b[::] |
| )JIT"; |
| |
| const at::Tensor lhs = at::randn({3, 3}); |
| const at::Tensor rhs = at::randn({3, 3}); |
| const at::Tensor big_lhs = at::randn({3, 3, 3}); |
| const at::Tensor big_rhs = at::randn({3, 3, 3}); |
| |
| // Tensor operand. |
| const std::vector<IValue> tensor_pair{lhs, rhs}; |
| const std::vector<IValue> big_tensor_pair{big_lhs, big_rhs}; |
| |
| testStaticRuntime(mul_by_tensor, tensor_pair); |
| testStaticRuntime(mul_by_tensor, tensor_pair, big_tensor_pair); |
| |
| // Int scalar operand. |
| const std::vector<IValue> with_scalar{lhs, 42}; |
| const std::vector<IValue> big_with_scalar{big_lhs, 42}; |
| |
| testStaticRuntime(mul_by_scalar, with_scalar); |
| testStaticRuntime(mul_by_scalar, with_scalar, big_with_scalar); |
| |
| // Int list times int. |
| const std::vector<IValue> list_and_count{c10::List<int64_t>{1, 2}, 3}; |
| testStaticRuntime(repeat_list, list_and_count); |
| } |
| |
| TEST(StaticRuntime, Log) {
| // torch.log followed by clone().
| const auto log_tensor = R"JIT(
| def forward(self, inp: Tensor):
| a = torch.log(inp).clone()
| return (a)
| )JIT";
|
| // Take the absolute value so that no negative value is fed to log.
| auto input_2d = at::abs(at::randn({2, 3}));
| auto input_3d = at::abs(at::randn({4, 3, 2}));
|
| std::vector<IValue> first_run{input_2d};
| std::vector<IValue> second_run{input_3d};
| testStaticRuntime(log_tensor, first_run);
| testStaticRuntime(log_tensor, first_run, second_run);
| }
| |
| TEST(StaticRuntime, Sub) {
| // torch.sub on two tensors.
| const auto sub_tensor = R"JIT(
| def forward(self, a: Tensor, b: Tensor):
| return torch.sub(a, b).clone()
| )JIT";
|
| // torch.sub on a tensor and an int.
| const auto sub_scalar = R"JIT(
| def forward(self, a: Tensor, b: int):
| return torch.sub(a, b).clone()
| )JIT";
|
| // torch.sub on two tensors with an explicit float alpha.
| const auto sub_tensor_alpha = R"JIT(
| def forward(self, a: Tensor, b: Tensor, c: float):
| return torch.sub(a, b, alpha=c).clone()
| )JIT";
|
| // torch.sub on a tensor and a float with an explicit int alpha.
| const auto sub_scalar_alpha = R"JIT(
| def forward(self, a: Tensor, b: float, c: int):
| return torch.sub(a, b, alpha=c).clone()
| )JIT";
|
| // Subtraction on plain ints; no tensor is involved.
| const auto sub_two_scalars = R"JIT(
| def forward(self, a: int, b: int):
| return (a - b - b)
| )JIT";
|
| // a/b are the first-run operands, c/d the differently shaped second-run
| // operands.
| auto a = at::randn({2, 3});
| auto b = at::randn({2, 3});
| auto c = at::randn({4, 3, 2});
| auto d = at::randn({4, 3, 2});
|
| std::vector<IValue> args0{a, b};
| testStaticRuntime(sub_tensor, args0);
| testStaticRuntime(sub_tensor, args0, {c, d});
|
| std::vector<IValue> args1{a, 3};
| testStaticRuntime(sub_scalar, args1);
| testStaticRuntime(sub_scalar, args1, {c, 4});
|
| // The alpha variants pass the larger inputs as the second argument set,
| // exactly like the non-alpha variants above, so that they are also run
| // against a second set of inputs with a different shape.
| std::vector<IValue> args2{a, b, 2.3};
| testStaticRuntime(sub_tensor_alpha, args2);
| testStaticRuntime(sub_tensor_alpha, args2, {c, d, 3.1});
|
| std::vector<IValue> args3{a, 2.3, 4};
| testStaticRuntime(sub_scalar_alpha, args3);
| testStaticRuntime(sub_scalar_alpha, args3, {c, 1.3, 2});
|
| std::vector<IValue> args4{1, 2};
| testStaticRuntime(sub_two_scalars, args4);
| }
| |
| TEST(StaticRuntime, NanToNum) {
| // torch.nan_to_num with explicit replacement values.
| const auto nan_to_num_script = R"JIT(
| def forward(self, a: Tensor, nan: float, posinf: float, neginf: float):
| return torch.nan_to_num(a, nan, posinf, neginf).clone()
| )JIT";
|
| const auto pos_inf = std::numeric_limits<double>::infinity();
| const auto not_a_number = std::numeric_limits<double>::quiet_NaN();
|
| // First input: a 2x2 tensor holding one NaN, one -inf and one +inf.
| auto special_2x2 = torch::tensor({{1.0, not_a_number}, {-pos_inf, pos_inf}});
|
| // Second input: random 3x6 data with special values written into it.
| auto special_3x6 = at::randn({3, 6});
| float* raw = special_3x6.data_ptr<float>();
| raw[0] = not_a_number;
| raw[13] = not_a_number;
| raw[4] = -pos_inf;
| raw[11] = pos_inf;
|
| std::vector<IValue> first_run{special_2x2, 1.0, 2.0, -2.0};
| std::vector<IValue> second_run{special_3x6, 1.0, 2.0, -2.0};
|
| // Single argument set.
| testStaticRuntime(
| nan_to_num_script,
| first_run,
| /*args2*/ {},
| /*use_allclose*/ true,
| /*use_equalnan*/ true);
| // Two argument sets.
| testStaticRuntime(
| nan_to_num_script,
| first_run,
| second_run,
| /*use_allclose*/ true,
| /*use_equalnan*/ true);
| }
| |
| TEST(StaticRuntime, Stack) {
| // Stack two tensors along a caller-provided dim.
| const auto stack_dim = R"JIT(
| def forward(self, a: Tensor, b: Tensor, dim: int):
| inputs = [a]
| inputs.append(b) # mutation to avoid using VarStack
| return torch.stack(inputs, dim = dim).clone()
| )JIT";
|
| // Stack three tensors along the default dim.
| const auto stack_three = R"JIT(
| def forward(self, a: Tensor, b: Tensor, c: Tensor):
| inputs = [a, b]
| inputs.append(c) # mutation to avoid using VarStack
| return torch.stack(inputs).clone()
| )JIT";
|
| // 2x2 tensors for the first run.
| auto small0 = at::randn({2, 2});
| auto small1 = at::randn({2, 2});
| auto small2 = at::randn({2, 2});
|
| // 3x3x3 tensors for the second run.
| auto large0 = at::randn({3, 3, 3});
| auto large1 = at::randn({3, 3, 3});
| auto large2 = at::randn({3, 3, 3});
|
| std::vector<IValue> dim_first{small0, small1, 0};
| std::vector<IValue> dim_second{large0, large1, 1};
| std::vector<IValue> dim_negative{large0, large1, -1};
| testStaticRuntime(stack_dim, dim_first);
| testStaticRuntime(stack_dim, dim_first, dim_second);
| testStaticRuntime(stack_dim, dim_negative);
|
| std::vector<IValue> three_first{small0, small1, small2};
| std::vector<IValue> three_second{large0, large1, large2};
| testStaticRuntime(stack_three, three_first);
| testStaticRuntime(stack_three, three_first, three_second);
| }
| |
| TEST(StaticRuntime, ReLU) {
| // torch.relu followed by clone().
| const auto relu_script = R"JIT(
| def forward(self, a: Tensor):
| return torch.relu(a).clone()
| )JIT";
|
| // Random values drawn with bounds -10 and 10 so both signs can show up.
| auto input_2x4 = at::randint(-10, 10, {2, 4});
| auto input_3x6 = at::randint(-10, 10, {3, 6});
|
| std::vector<IValue> first_run{input_2x4};
| std::vector<IValue> second_run{input_3x6};
| testStaticRuntime(relu_script, first_run);
| testStaticRuntime(relu_script, first_run, second_run);
| }
| |
| TEST(StaticRuntime, Tanh) {
| // torch.tanh followed by clone().
| const auto tanh_script = R"JIT(
| def forward(self, a):
| return torch.tanh(a).clone()
| )JIT";
|
| auto input_2d = at::randn({2, 2});
| auto input_3d = at::randn({3, 3, 3});
|
| std::vector<IValue> first_run{input_2d};
| std::vector<IValue> second_run{input_3d};
|
| // Both runs compare with allclose rather than exact equality.
| testStaticRuntime(
| tanh_script, first_run, /*args2*/ {}, /*use_allclose*/ true);
| testStaticRuntime(
| tanh_script, first_run, second_run, /*use_allclose*/ true);
| }
| |
| TEST(StaticRuntime, Norm) {
| // torch.norm(a, p)
| const auto norm_2arg = R"JIT(
| def forward(self, a: Tensor, p: int):
| return torch.norm(a, p).clone()
| )JIT";
|
| // torch.norm(a, p, dtype=...)
| const auto norm_3arg = R"JIT(
| def forward(self, a: Tensor, p: int, dtype: int):
| return torch.norm(a, p, dtype=dtype).clone()
| )JIT";
|
| // torch.norm(a, p, dim, keepdim)
| const auto norm_4arg = R"JIT(
| def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool):
| return torch.norm(a, p, dim, keepdim).clone()
| )JIT";
|
| // torch.norm(a, p, dim, keepdim, dtype=...)
| const auto norm_5arg = R"JIT(
| def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool, dtype: int):
| return torch.norm(a, p, dim, keepdim, dtype=dtype).clone()
| )JIT";
|
| auto input_2d = at::randn({2, 3});
| auto input_3d = at::randn({4, 3, 5});
| std::vector<int64_t> reduce_dims{1};
| auto out_dtype = at::ScalarType::Float;
|
| // The 2- and 3-argument forms run their second argument set with
| // check_resize turned off.
| std::vector<IValue> two_arg_inputs{input_2d, 2};
| testStaticRuntime(norm_2arg, two_arg_inputs);
| testStaticRuntime(
| norm_2arg, two_arg_inputs, {input_3d, 2}, false, false, false);
|
| std::vector<IValue> three_arg_inputs{input_2d, 2, out_dtype};
| testStaticRuntime(norm_3arg, three_arg_inputs);
| testStaticRuntime(
| norm_3arg,
| three_arg_inputs,
| {input_3d, 2, out_dtype},
| false,
| false,
| false);
|
| std::vector<IValue> four_arg_inputs{input_2d, 3, reduce_dims, false};
| testStaticRuntime(norm_4arg, four_arg_inputs);
| testStaticRuntime(
| norm_4arg, four_arg_inputs, {input_3d, 3, reduce_dims, false});
|
| std::vector<IValue> five_arg_inputs{
| input_2d, 4, reduce_dims, true, out_dtype};
| testStaticRuntime(norm_5arg, five_arg_inputs);
| testStaticRuntime(
| norm_5arg,
| five_arg_inputs,
| {input_3d, 4, reduce_dims, true, out_dtype});
| }
| |
| TEST(StaticRuntime, Reshape) {
| // reshape whose result is consumed by an add.
| const auto reshape_script_1 = R"JIT(
| def forward(self, a: Tensor, shape: List[int]):
| b = a.reshape(shape)
| return b + b
| )JIT";
|
| // reshape of a transposed tensor, returned directly.
| const auto reshape_script_2 = R"JIT(
| def forward(self, a: Tensor, shape: List[int]):
| b = a.transpose(0, 1)
| return b.reshape(shape)
| )JIT";
|
| // Two reshapes of the same tensor; one feeds an output, one feeds math.
| const auto reshape_script_3 = R"JIT(
| def forward(self, inp: Tensor, shape: List[int]):
| a = inp + inp
| b = a.reshape(shape)
| c = a.reshape(shape)
| d = c + c
| e = d + d
| f = e * e
| g = f * f
| return b.reshape(shape), g
| )JIT";
|
| // exercise reshape_copy and flatten_copy
| const auto reshape_script_4 = R"JIT(
| def forward(self, inp: Tensor, shape: List[int]):
| k = inp + inp
| a = k + k
| b = a.reshape(shape)
| c = a.flatten().reshape(shape)
| return b + c
| )JIT";
|
| // exercise reshape_copy
| const auto reshape_script_5 = R"JIT(
| def forward(self, inp: Tensor, shape: List[int]):
| a = inp + inp
| b = a.reshape(shape)
| c = a.reshape(shape).relu()
| d = c + c
| e = d + d
| f = e * e
| g = f * f
| return g
| )JIT";
|
| // The reshape result is mutated in place by sigmoid_().
| const auto reshape_inplace_script = R"JIT(
| def forward(self, inp: Tensor, shape: List[int]):
| a = inp + inp
| b = a.reshape(shape)
| c = b.sigmoid_()
| d = c + c
| e = a + a
| f = b + b
| return (d, e, f)
| )JIT";
|
| // b is a transposed view when it is reshaped.
| const auto reshape_incontiguous_script = R"JIT(
| def forward(self, a: Tensor, shape: List[int]):
| b = a.transpose(0, 1)
| c = b.reshape(shape)
| c = c.relu()
| return (c)
| )JIT";
|
| // First argument set: 2x3 tensor reshaped to 3x2.
| auto input_2x3 = at::randn({2, 3});
| std::vector<int64_t> shape_3x2{3, 2};
| std::vector<IValue> first_run{input_2x3, shape_3x2};
|
| // Second argument set: 4x5 tensor reshaped to 5x1x2x2.
| auto input_4x5 = at::randn({4, 5});
| std::vector<int64_t> shape_5x1x2x2{5, 1, 2, 2};
| std::vector<IValue> second_run{input_4x5, shape_5x1x2x2};
|
| // Every script runs with one argument set first...
| for (const char* script :
| {reshape_script_1,
| reshape_script_2,
| reshape_script_3,
| reshape_script_4,
| reshape_script_5,
| reshape_inplace_script,
| reshape_incontiguous_script}) {
| testStaticRuntime(script, first_run);
| }
|
| // ...and then, in the same order, with both argument sets.
| for (const char* script :
| {reshape_script_1,
| reshape_script_2,
| reshape_script_3,
| reshape_script_4,
| reshape_script_5,
| reshape_inplace_script,
| reshape_incontiguous_script}) {
| testStaticRuntime(script, first_run, second_run);
| }
| }
| |
| TEST(StaticRuntime, Repeat) {
| // torch.repeat with a caller-provided repeat list.
| const std::string repeat = R"JIT(
| def forward(self, a: Tensor, repeats: List[int]):
| return torch.repeat(a, repeats).clone()
| )JIT";
|
| auto input_2x3 = at::randn({2, 3});
| auto input_4x3 = at::randn({4, 3});
| std::vector<int64_t> repeats_1x2{1, 2};
| std::vector<int64_t> repeats_2x3{2, 3};
|
| std::vector<IValue> first_run{input_2x3, repeats_1x2};
| std::vector<IValue> second_run{input_4x3, repeats_2x3};
|
| // Each argument set on its own, then both together.
| testStaticRuntime(repeat, first_run);
| testStaticRuntime(repeat, second_run);
| testStaticRuntime(repeat, first_run, second_run);
| }
| |
| TEST(StaticRuntime, Flatten) {
| // exercise flatten_copy
| const auto flatten_script_1 = R"JIT(
| def forward(self, a: Tensor, start_dim: int, end_dim: int):
| b = a * a
| c = torch.flatten(b, start_dim, end_dim)
| d = torch.relu(c)
| return d
| )JIT";
|
| // flatten applied to a transposed tensor.
| const auto flatten_script_2 = R"JIT(
| def forward(self, a: Tensor, start_dim: int, end_dim: int):
| b = a.transpose(0, 1)
| return torch.flatten(b, start_dim, end_dim).clone()
| )JIT";
|
| // Runs the scripts for one (shape, start_dim, end_dim) combination. The
| // second argument set uses the same shape with its leading dim multiplied
| // by 6 (when there is a leading dim at all).
| auto test_flatten = [&](std::vector<int64_t> base_shape,
| int64_t start_dim,
| int64_t end_dim) {
| const bool has_dims = !base_shape.empty();
| std::vector<int64_t> scaled_shape(base_shape);
| if (has_dims) {
| scaled_shape[0] *= 6;
| }
|
| auto first_input = at::randn(base_shape);
| auto second_input = at::randn(scaled_shape);
| std::vector<IValue> first_run{first_input, start_dim, end_dim};
|
| testStaticRuntime(flatten_script_1, first_run);
| // check_resize is only requested when the shape has at least one dim.
| testStaticRuntime(
| flatten_script_1,
| first_run,
| {second_input, start_dim, end_dim},
| false, /* use_allclose */
| false, /* use_equalnan */
| has_dims);
|
| // The transpose-based script only runs for shapes with more than two dims.
| if (base_shape.size() > 2) {
| testStaticRuntime(flatten_script_2, first_run);
| testStaticRuntime(
| flatten_script_2, first_run, {second_input, start_dim, end_dim});
| }
| };
|
| test_flatten({2, 3}, 0, 1);
| test_flatten({2, 1, 3}, 1, 2);
| test_flatten({0, 1, 3, 0}, 1, 2);
| test_flatten({2, 3}, 1, 1);
| test_flatten({}, 0, 0);
| }
| |
| TEST(StaticRuntime, pow) {
| // Tensor base, int exponent.
| const auto pow_script_ten_sca = R"JIT(
| def forward(self, input : Tensor, exponent : int):
| return torch.pow(input, exponent).clone()
| )JIT";
|
| // Tensor base, Tensor exponent.
| const auto pow_script_ten_ten = R"JIT(
| def forward(self, input : Tensor, exponent : Tensor):
| return torch.pow(input, exponent).clone()
| )JIT";
|
| // int base, Tensor exponent.
| const auto pow_script_sca_ten = R"JIT(
| def forward(self, input : int, exponent : Tensor):
| return torch.pow(input, exponent).clone()
| )JIT";
|
| auto base_2d = at::randn({2, 3});
| auto exponent_2d = at::randn({2, 3});
| auto base_3d = at::randn({4, 3, 2});
| auto exponent_3d = at::randn({4, 3, 2});
|
| std::vector<IValue> tensor_scalar{base_2d, 4};
| testStaticRuntime(pow_script_ten_sca, tensor_scalar);
| testStaticRuntime(pow_script_ten_sca, tensor_scalar, {base_3d, 4});
|
| // The tensor/tensor case takes the absolute value of the base.
| std::vector<IValue> tensor_tensor{at::abs(base_2d), exponent_2d};
| testStaticRuntime(pow_script_ten_ten, tensor_tensor);
| testStaticRuntime(
| pow_script_ten_ten, tensor_tensor, {at::abs(base_3d), exponent_3d});
|
| std::vector<IValue> scalar_tensor{5, exponent_2d};
| testStaticRuntime(pow_script_sca_ten, scalar_tensor);
| testStaticRuntime(pow_script_sca_ten, scalar_tensor, {3, exponent_3d});
| }
| |
| TEST(StaticRuntime, to) {
| // torch.to with an explicit dtype and memory format.
| const auto to_script_dtype = R"JIT(
| def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
| a = input + input
| return torch.to(a, dtype, non_blocking, copy, memory_format).clone()
| )JIT";
|
| // Same overload, applied to a permuted input.
| const auto to_script_dtype_strided = R"JIT(
| def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
| b = input.permute(0, 2, 3, 1)
| return torch.to(b, dtype, non_blocking, copy, memory_format).clone()
| )JIT";
|
| // torch.to with an optional dtype and no memory format.
| const auto to_script_prim_dtype = R"JIT(
| def forward(self, input:Tensor, dtype: Optional[int], non_blocking: bool, copy: bool):
| a = input + input
| return torch.to(a, dtype, non_blocking, copy).clone()
| )JIT";
|
| // torch.to taking its target from another tensor.
| const auto to_script_other = R"JIT(
| def forward(self, input:Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
| a = input + input
| return torch.to(a, other, non_blocking, copy, memory_format).clone()
| )JIT";
|
| // if input is float tensor, b could be alias of a
| const auto to_script_alias = R"JIT(
| def forward(self, input:Tensor):
| a = input + input
| b = a.float()
| c = b * b
| return (c)
| )JIT";
|
| const auto to_script_fails_managed_output_check = R"JIT(
| def forward(self, a, b):
| d = a.half() * b.half()
| e = d.float()
| return e
| )JIT";
|
| const auto to_script_select_tensor_output_into_tuple = R"JIT(
| def forward(self, a, b):
| d = a.half() * b.half()
| e = d.float()
| return (d, e)
| )JIT";
|
| const auto to_script_memory_planning_fail = R"JIT(
| def forward(self, a, b):
| d = a.half() * b.half()
| e = d.float().relu()
| return e
| )JIT";
|
| // Runs every script above for one (dtype, non_blocking, copy,
| // memory_format) combination.
| auto test_to = [&](at::ScalarType dtype,
| bool non_blocking,
| bool copy,
| c10::MemoryFormat memory_format) {
| auto input = at::randn({4, 3, 1, 2});
| auto other = at::randn({4, 3, 1, 2}).to(dtype);
| auto input2 = at::randn({3, 2, 2, 4});
| auto other2 = at::randn({3, 2, 2, 4}).to(dtype);
|
| std::vector<IValue> dtype_args{
| input, dtype, non_blocking, copy, memory_format};
| std::vector<IValue> prim_dtype_args{input, dtype, non_blocking, copy};
| std::vector<IValue> other_args{
| input, other, non_blocking, copy, memory_format};
| std::vector<IValue> other_args_double{
| input,
| at::randn({4, 3, 1, 2}, ScalarType::Double),
| non_blocking,
| copy,
| memory_format};
| std::vector<IValue> prim_none_args{
| input, c10::nullopt, non_blocking, copy};
| std::vector<IValue> dtype_args_int{
| input, ScalarType::Int, non_blocking, copy, memory_format};
|
| // Second argument set switches the requested dtype to Int.
| testStaticRuntime(
| to_script_dtype,
| dtype_args,
| dtype_args_int,
| /* default for use_allclose */ false,
| /* default for use_equalnan */ false,
| /* check_resize */ false);
| testStaticRuntime(to_script_dtype_strided, dtype_args);
| testStaticRuntime(to_script_prim_dtype, prim_dtype_args);
| // The dtype=None variant is only run when copy is false.
| if (!copy) {
| testStaticRuntime(to_script_prim_dtype, prim_none_args);
| }
| // Second set of args tests case where the `other` tensor's dtype
| // changes between iterations.
| testStaticRuntime(
| to_script_other,
| other_args,
| other_args_double,
| /* default for use_allclose */ false,
| /* default for use_equalnan */ false,
| /* check_resize */ false);
| testStaticRuntime(to_script_alias, {input});
|
| testStaticRuntime(to_script_memory_planning_fail, {input, input});
| testStaticRuntime(to_script_fails_managed_output_check, {input, input});
| testStaticRuntime(
| to_script_select_tensor_output_into_tuple, {input, input});
|
| // dynamic shapes
| testStaticRuntime(
| to_script_dtype,
| dtype_args,
| {input2, dtype, non_blocking, copy, memory_format});
| testStaticRuntime(
| to_script_dtype_strided,
| dtype_args,
| {input2, dtype, non_blocking, copy, memory_format});
| testStaticRuntime(
| to_script_prim_dtype,
| prim_dtype_args,
| {input2, dtype, non_blocking, copy});
| if (!copy) {
| testStaticRuntime(
| to_script_prim_dtype,
| prim_none_args,
| {input2, c10::nullopt, non_blocking, copy});
| }
| testStaticRuntime(
| to_script_other,
| other_args,
| {input2, other2, non_blocking, copy, memory_format});
| testStaticRuntime(to_script_alias, {input}, {input2});
| };
|
| // (target dtype, memory format) pairs, tried for every combination of
| // non_blocking and copy. Inputs are always created with at::randn.
| struct ToCase {
| at::ScalarType dtype;
| c10::MemoryFormat memory_format;
| };
| const ToCase to_cases[] = {
| // float->float, NCHW->NHWC
| {at::ScalarType::Float, c10::MemoryFormat::ChannelsLast},
| // float->half
| {at::ScalarType::Half, c10::MemoryFormat::Preserve},
| // float->float
| {at::ScalarType::Float, c10::MemoryFormat::Contiguous},
| // Bool target dtype
| {at::ScalarType::Bool, c10::MemoryFormat::Contiguous},
| // TODO: check if fbgemm is enabled properly in this case
| // Half target dtype, NCHW->NHWC
| {at::ScalarType::Half, c10::MemoryFormat::ChannelsLast},
| };
|
| for (const bool non_blocking : {false, true}) {
| for (const bool copy : {false, true}) {
| for (const auto& to_case : to_cases) {
| test_to(to_case.dtype, non_blocking, copy, to_case.memory_format);
| }
| }
| }
| }
| |
| TEST(StaticRuntime, ExpandAs) {
| // expand_as followed by clone().
| const auto expand_as_script = R"JIT(
| def forward(self, input: Tensor, other:Tensor):
| a = input.expand_as(other)
| return a.clone()
| )JIT";
|
| // {N, 1} inputs expanded to the shape of {N, 2} targets.
| auto source_3x1 = at::randn({3, 1});
| auto target_3x2 = at::randn({3, 2});
| auto source_4x1 = at::randn({4, 1});
| auto target_4x2 = at::randn({4, 2});
|
| std::vector<IValue> first_run{source_3x1, target_3x2};
| std::vector<IValue> second_run{source_4x1, target_4x2};
| testStaticRuntime(expand_as_script, first_run);
| testStaticRuntime(expand_as_script, first_run, second_run);
| }
| |
| TEST(StaticRuntime, Full) {
| // torch.full with every optional tensor option passed explicitly.
| const auto full_script = R"JIT(
| def forward(self,
| size: List[int],
| fill_value: int,
| dtype: Optional[int],
| layout: Optional[int],
| device: Optional[Device],
| pin_memory: Optional[bool]):
| a = torch.full(size,
| fill_value,
| dtype=dtype,
| layout=layout,
| device=device,
| pin_memory=pin_memory)
| return (a.clone())
| )JIT";
|
| auto cpu = at::Device(DeviceType::CPU);
|
| // Builds one argument list; layout, device and pin_memory never vary.
| const auto make_args = [&](const c10::List<int64_t>& size,
| int fill_value,
| at::ScalarType dtype) {
| return std::vector<IValue>{
| size, fill_value, dtype, at::kStrided, cpu, false};
| };
|
| c10::List<int64_t> size_2x5{2, 5};
| c10::List<int64_t> size_5x6{5, 6};
| const auto int_args = make_args(size_2x5, 4, at::ScalarType::Int);
| const auto float_args = make_args(size_2x5, 4, at::ScalarType::Float);
| const auto larger_float_args = make_args(size_5x6, 5, at::ScalarType::Float);
|
| testStaticRuntime(full_script, int_args);
| // Same size, different dtype: run without the resize check.
| testStaticRuntime(
| full_script,
| int_args,
| float_args,
| /*use_allclose=*/false,
| /*use_equalnan=*/false,
| /*check_resize=*/false);
| testStaticRuntime(full_script, int_args, larger_float_args);
| }
| |
| TEST(StaticRuntime, FullLike) {
| // torch.full_like with every optional argument passed explicitly.
| const auto full_like_script = R"JIT(
| def forward(self,
| a: Tensor,
| fill_value: int,
| dtype: Optional[int],
| layout: Optional[int],
| device: Optional[Device],
| pin_memory: Optional[bool],
| memory_format: Optional[int]):
| b = torch.full_like(a,
| fill_value,
| dtype=dtype,
| layout=layout,
| device=device,
| pin_memory=pin_memory,
| memory_format=memory_format)
| return (b.clone())
| )JIT";
|
| auto like_2d = at::randn({2, 3});
| auto like_3d = at::randn({3, 4, 2});
| auto cpu = at::Device(DeviceType::CPU);
|
| // Builds one argument list; only the template tensor and dtype vary.
| const auto make_args = [&](const at::Tensor& like, at::ScalarType dtype) {
| return std::vector<IValue>{
| like,
| 4,
| dtype,
| at::kStrided,
| cpu,
| false,
| c10::MemoryFormat::Contiguous};
| };
|
| const auto int_args = make_args(like_2d, at::ScalarType::Int);
| const auto float_args = make_args(like_2d, at::ScalarType::Float);
| const auto larger_float_args = make_args(like_3d, at::ScalarType::Float);
|
| testStaticRuntime(full_like_script, int_args);
| // Same template tensor, different dtype: run without the resize check.
| testStaticRuntime(
| full_like_script,
| int_args,
| float_args,
| /*use_allclose=*/false,
| /*use_equalnan=*/false,
| /*check_resize=*/false);
| testStaticRuntime(full_like_script, int_args, larger_float_args);
| }
| |
| TEST(StaticRuntime, Ones) {
| // torch.ones with every optional tensor option passed explicitly.
| const auto script = R"JIT(
| def forward(self,
| size: List[int],
| dtype: Optional[int],
| layout: Optional[int],
| device: Optional[Device],
| pin_memory: Optional[bool]):
| a = torch.ones(size,
| dtype=dtype,
| layout=layout,
| device=device,
| pin_memory=pin_memory)
| return (a.clone())
| )JIT";
|
| auto int_dtype = at::ScalarType::Int;
| auto cpu = at::Device(DeviceType::CPU);
|
| // Only the requested size differs between the two argument sets.
| c10::List<int64_t> size_2x5{2, 5};
| c10::List<int64_t> size_5x6{5, 6};
| std::vector<IValue> first_run{size_2x5, int_dtype, at::kStrided, cpu, false};
| std::vector<IValue> second_run{size_5x6, int_dtype, at::kStrided, cpu, false};
|
| testStaticRuntime(script, first_run);
| testStaticRuntime(script, first_run, second_run);
| }
| |
| TEST(StaticRuntime, OnesLike) {
| // torch.ones_like with every optional argument passed explicitly.
| const auto script = R"JIT(
| def forward(self,
| input: Tensor,
| dtype: Optional[int],
| layout: Optional[int],
| device: Optional[Device],
| pin_memory: Optional[bool],
| memory_format: Optional[int]):
| a = torch.ones_like(input,
| dtype=dtype,
| layout=layout,
| device=device,
| pin_memory=pin_memory,
| memory_format=memory_format)
| return (a.clone())
| )JIT";
|
| auto cpu = at::Device(DeviceType::CPU);
|
| // Builds one argument list; only the template tensor and dtype vary.
| const auto make_args = [&](const at::Tensor& like, at::ScalarType dtype) {
| return std::vector<IValue>{
| like,
| dtype,
| at::kStrided,
| cpu,
| false,
| c10::MemoryFormat::Contiguous};
| };
|
| auto like_2x5 = at::randn({2, 5});
| const auto int_args = make_args(like_2x5, at::ScalarType::Int);
| const auto float_args = make_args(like_2x5, at::ScalarType::Float);
| auto like_5x6 = at::randn({5, 6});
| const auto larger_float_args = make_args(like_5x6, at::ScalarType::Float);
|
| testStaticRuntime(script, int_args);
| // Same template tensor, different dtype: run without the resize check.
| testStaticRuntime(
| script,
| int_args,
| float_args,
| /*use_allclose=*/false,
| /*use_equalnan=*/false,
| /*check_resize=*/false);
| testStaticRuntime(script, int_args, larger_float_args);
| }
| |
| TEST(StaticRuntime, Zeros) {
| // torch.zeros with every optional tensor option passed explicitly.
| const auto script = R"JIT(
| def forward(self,
| size: List[int],
| dtype: Optional[int],
| layout: Optional[int],
| device: Optional[Device],
| pin_memory: Optional[bool]):
| a = torch.zeros(size,
| dtype=dtype,
| layout=layout,
| device=device,
| pin_memory=pin_memory)
| return (a.clone())
| )JIT";
|
| auto cpu = at::Device(DeviceType::CPU);
|
| // Builds one argument list; only the size and dtype vary.
| const auto make_args = [&](const c10::List<int64_t>& size,
| at::ScalarType dtype) {
| return std::vector<IValue>{size, dtype, at::kStrided, cpu, false};
| };
|
| c10::List<int64_t> size_2x5{2, 5};
| c10::List<int64_t> size_5x6{5, 6};
| const auto int_args = make_args(size_2x5, at::ScalarType::Int);
| const auto float_args = make_args(size_2x5, at::ScalarType::Float);
| const auto larger_float_args = make_args(size_5x6, at::ScalarType::Float);
|
| testStaticRuntime(script, int_args);
| // Same size, different dtype: run without the resize check.
| testStaticRuntime(
| script,
| int_args,
| float_args,
| /*use_allclose=*/false,
| /*use_equalnan=*/false,
| /*check_resize=*/false);
| testStaticRuntime(script, int_args, larger_float_args);
| }
| |
| TEST(StaticRuntime, Linear) {
| // torch.linear with an optional bias.
| const auto linear_script = R"JIT(
| def forward(self, inp: Tensor, weights: Tensor, bias: Optional[Tensor]) -> Tensor:
| return torch.linear(inp, weights, bias).clone()
| )JIT";
|
| // First argument set, with and without a bias tensor.
| auto small_input = at::randn({1, 2});
| auto small_weights = at::randn({1, 2});
| auto small_bias = at::randn({1, 1});
| std::vector<IValue> small_with_bias{small_input, small_weights, small_bias};
| std::vector<IValue> small_no_bias{small_input, small_weights, c10::nullopt};
|
| // Second argument set, with and without a bias tensor.
| auto large_input = at::randn({6, 3});
| auto large_weights = at::randn({6, 3});
| auto large_bias = at::randn({6, 6});
| std::vector<IValue> large_with_bias{large_input, large_weights, large_bias};
| std::vector<IValue> large_no_bias{large_input, large_weights, c10::nullopt};
|
| testStaticRuntime(linear_script, small_with_bias);
| testStaticRuntime(linear_script, small_no_bias);
|
| // Both two-set runs start from the biased small inputs.
| testStaticRuntime(linear_script, small_with_bias, large_with_bias);
| testStaticRuntime(linear_script, small_with_bias, large_no_bias);
| }
| |
| TEST(StaticRuntime, VarCat) {
| // torch.cat over a two-element list literal.
| const auto var_cat_script = R"JIT(
| def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
| return torch.cat([inp1, inp2], dim).clone()
| )JIT";
|
| // 2D tensors - cat dim = 0
| std::vector<IValue> cat_2d_dim0{at::randn({4, 6}), at::randn({5, 6}), 0};
| testStaticRuntime(var_cat_script, cat_2d_dim0);
|
| // 3D tensors - cat dim = 1
| std::vector<IValue> cat_3d_dim1{
| at::randn({4, 5, 6}), at::randn({4, 8, 6}), 1};
| testStaticRuntime(var_cat_script, cat_3d_dim1);
|
| // 3D tensors - cat dim = 2
| std::vector<IValue> cat_3d_dim2{
| at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2};
| testStaticRuntime(var_cat_script, cat_3d_dim2);
|
| // Negative dim
| std::vector<IValue> cat_3d_dim_neg{
| at::randn({4, 5, 6}), at::randn({4, 5, 7}), -1};
| testStaticRuntime(var_cat_script, cat_3d_dim_neg);
|
| // Two argument sets: the 2D case followed by the 3D dim = 1 case.
| testStaticRuntime(var_cat_script, cat_2d_dim0, cat_3d_dim1);
| }
| |
| TEST(StaticRuntime, LeakyReLU) {
| // Compares the JIT executor and StaticModule on the same scripted model.
| torch::jit::Module module = getLeakyReLUConstScriptModel();
| auto sample = torch::randn({2, 2});
|
| // Reference result from the JIT graph executor.
| std::vector<at::IValue> jit_inputs({sample});
| at::Tensor expected = module.forward(jit_inputs).toTensor();
|
| // Result from static runtime, called with no keyword arguments.
| std::vector<c10::IValue> static_inputs({sample});
| torch::jit::StaticModule static_module(module);
| at::Tensor actual = static_module(static_inputs, {}).toTensor();
| static_module.runtime().check_for_memory_leak();
|
| EXPECT_TRUE(torch::allclose(expected, actual, 1e-6));
| }
| |
| // Builds a ProcessedNodeInputs of the same length as `inputs` and copies
| // every value into it, element by element.
| static ProcessedNodeInputs createProcessedNodeInputs(
| c10::ArrayRef<uint16_t> inputs) {
| const auto count = inputs.size();
| ProcessedNodeInputs built(count);
| for (size_t pos = 0; pos < count; ++pos) {
| built[pos] = inputs[pos];
| }
| return built;
| }
| |
| // Verifies that `io` has the same length and the same elements as `inputs`.
| // A size mismatch aborts the check (ASSERT); element mismatches are reported
| // individually (EXPECT).
| static void checkProcessedNodeInputs(
| const ProcessedNodeInputs& io,
| c10::ArrayRef<uint16_t> inputs) {
| ASSERT_EQ(inputs.size(), io.size());
| const auto count = inputs.size();
| for (size_t pos = 0; pos < count; ++pos) {
| EXPECT_EQ(inputs[pos], io[pos]);
| }
| }
| |
| // Builds a ProcessedNodeInputs from `inputs` and checks that the object, a
| // copy-constructed object and a move-constructed object all hold `inputs`.
| static void testProcessedNodeInputsRoundTrip(c10::ArrayRef<uint16_t> inputs) {
| auto original = createProcessedNodeInputs(inputs);
| checkProcessedNodeInputs(original, inputs);
|
| // Copy construction.
| ProcessedNodeInputs via_copy(original);
| checkProcessedNodeInputs(via_copy, inputs);
|
| // Move construction; `original` is not used afterwards.
| ProcessedNodeInputs via_move(std::move(original));
| checkProcessedNodeInputs(via_move, inputs);
| }
| |
| TEST(ProcessedNodeInputs, Basic) {
| const std::vector<std::vector<uint16_t>> test_cases = {
| {}, // empty
| {0xABCD, 0x5a5a}, // inline
| {0x11, 0x22, 0x33, 0x44, 0x55}, // max inline size
| {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}, // minimum outline size
| std::vector<uint16_t>(100, 0x5a), // large outline size
| };
|
| for (const auto& source_values : test_cases) {
| // Construction, copy construction and move construction.
| testProcessedNodeInputsRoundTrip(source_values);
|
| // Assignment from every case into every case (including itself).
| for (const auto& target_values : test_cases) {
| auto source = createProcessedNodeInputs(source_values);
|
| // Copy assignment over an object built from the target values.
| auto copy_target = createProcessedNodeInputs(target_values);
| copy_target = source;
| checkProcessedNodeInputs(copy_target, source_values);
|
| // Move assignment over a fresh object built from the target values.
| auto move_target = createProcessedNodeInputs(target_values);
| move_target = std::move(source);
| checkProcessedNodeInputs(move_target, source_values);
| }
| }
| }
| |
| TEST(StaticRuntime, isinstance) {
| // isinstance against a single scalar type.
| const auto isinstance_int_script = R"JIT(
| def forward(self, a: Any):
| return isinstance(a, int)
| )JIT";
|
| // isinstance against torch.Tensor.
| const auto isinstance_tensor_script = R"JIT(
| def forward(self, a: Any):
| return isinstance(a, torch.Tensor)
| )JIT";
|
| // isinstance against a tuple of types.
| const auto isinstance_many_types_script = R"JIT(
| def forward(self, a: Any):
| return isinstance(a, (bool, int))
| )JIT";
|
| // Every script is called with a tensor argument.
| auto tensor_2d = at::randn({2, 2});
| auto tensor_3d = at::randn({2, 2, 2});
| std::vector<at::IValue> first_run{tensor_2d};
| std::vector<at::IValue> second_run{tensor_3d};
|
| // For each script: one argument set, then both argument sets.
| for (const char* script :
| {isinstance_int_script,
| isinstance_tensor_script,
| isinstance_many_types_script}) {
| testStaticRuntime(script, first_run);
| testStaticRuntime(script, first_run, second_run);
| }
| }
| |
| TEST(StaticRuntime, TypeCheck) {
| // prim::TypeCheck expecting a Float(2, 2) and a Float(3, 3) input; the
| // graph returns both refined values plus the match flag.
| const auto typecheck_ir = R"IR(
| graph(%a.1 : Tensor,
| %b.1 : Tensor):
| %t0 : Float(2, 2, strides=[2, 1], device=cpu), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu), Float(3, 3, strides=[3, 1])]](%a.1, %b.1)
| return (%t0, %t1, %type_matched)
| )IR";
|
| // Tensor::to returns its result rather than modifying the tensor it is
| // called on, so the returned tensor has to be kept for the call to matter.
| auto a = at::zeros({2, 2}, at::kFloat).to(at::kCPU);
| auto b = at::ones({3, 3}, at::kFloat);
| // Shape (2, 2, 2) does not match the (3, 3) type in the IR above.
| auto c = at::ones({2, 2, 2}, at::kFloat);
|
| std::vector<IValue> args_correct = {a, b};
| std::vector<IValue> args_incorrect = {a, c};
|
| testStaticRuntime(typecheck_ir, args_correct);
| testStaticRuntime(typecheck_ir, args_correct, args_incorrect);
| }
| |
| TEST(StaticRuntime, Index) {
| // a[idx] with a single index tensor.
| const auto index_without_none_script = R"JIT(
| def forward(self, a: Tensor, idx: Tensor):
| return a[idx].clone()
| )JIT";
|
| // Index with boolean mask
| auto data_2x2 = at::arange(4, at::kFloat).view({2, 2});
| auto mask_2x2 = torch::tensor({{0, 1}, {0, 0}}, at::kBool);
| std::vector<IValue> mask_args{data_2x2, mask_2x2};
|
| // Index with tensor
| auto data_3x3x3 = at::arange(27, at::kFloat).view({3, 3, 3});
| auto long_index = torch::tensor({0, 1, 2}, at::kLong);
| std::vector<IValue> long_index_args{data_3x3x3, long_index};
|
| testStaticRuntime(index_without_none_script, mask_args);
| testStaticRuntime(index_without_none_script, mask_args, long_index_args);
|
| // a[idx, none] with an Optional[Tensor] second index.
| const auto index_with_none_script = R"JIT(
| def forward(self, a: Tensor, idx: Tensor, none: Optional[Tensor]):
| return a[idx, none].clone()
| )JIT";
|
| // Index with None
| // When indexing with none, the shape of the data becomes [2, 1, 2],
| // so the mask must be reshaped appropriately.
| auto data_2x1x2 = at::arange(4, at::kFloat).view({2, 1, 2});
| auto mask_2x1x2 = torch::tensor({{{0, 1}}, {{0, 0}}}, at::kBool);
| // A default-constructed IValue is passed for the `none` argument.
| std::vector<IValue> none_args{data_2x1x2, mask_2x1x2, IValue()};
|
| testStaticRuntime(index_with_none_script, none_args);
| testStaticRuntime(
| index_with_none_script,
| none_args,
| {IValue(data_3x3x3), IValue(long_index), IValue()});
|
| // a[idx_a, idx_b] with two index tensors.
| const auto index_with_two_tensors_script = R"JIT(
| def forward(self, a: Tensor, idx_a: Tensor, idx_b: Tensor):
| return a[idx_a, idx_b].clone()
| )JIT";
|
| // Index with multiple tensors, reusing the 2x2 data from above.
| auto rows_small = torch::tensor({0, 0}, at::kLong);
| auto cols_small = torch::tensor({0}, at::kLong);
| std::vector<IValue> two_index_small{data_2x2, rows_small, cols_small};
|
| // Same, reusing the 3x3x3 data from above.
| auto rows_large = torch::tensor({{0, 0, 2}, {0, 1, 1}}, at::kLong);
| auto cols_large = torch::tensor({{1, 1, 0}, {1, 0, 2}}, at::kLong);
| std::vector<IValue> two_index_large{data_3x3x3, rows_large, cols_large};
|
| testStaticRuntime(
| index_with_two_tensors_script, two_index_small, two_index_large);
| }
| |
| TEST(StaticRuntime, IndexSelect) {
| // aten::index_select followed by aten::clone; %bias is the None constant
| // handed to clone as its second argument.
| const std::string script = R"IR(
| graph(%self: Tensor, %dim: int, %index: Tensor):
| %bias: None = prim::Constant()
| %ret = aten::index_select(%self, %dim, %index)
| %cloned = aten::clone(%ret, %bias)
| return (%cloned)
| )IR";
|
| // First argument set: 6 elements selected by 6 Int32 indices.
| auto source_small = at::rand({6});
| auto dim_small = 0;
| auto indices_small = at::randint(0, 5, {6}, torch::kInt32);
| std::vector<IValue> first_run{source_small, dim_small, indices_small};
| testStaticRuntime(script, first_run);
|
| // Second argument set: 128 elements selected by 127 Int32 indices.
| auto source_large = at::rand({128});
| auto dim_large = 0;
| auto indices_large = at::randint(0, 127, {127}, torch::kInt32);
| std::vector<IValue> second_run{source_large, dim_large, indices_large};
| testStaticRuntime(script, first_run, second_run);
| }
| |
| TEST(StaticRuntime, ClampMin) {
| // torch.clamp_min with an int bound.
| const auto clamp_min_int_script = R"JIT(
| def forward(self, a: Tensor, b: int):
| return torch.clamp_min(a, b).clone()
| )JIT";
|
| // torch.clamp_min with a float bound.
| const auto clamp_min_float_script = R"JIT(
| def forward(self, a: Tensor, b: float):
| return torch.clamp_min(a, b).clone()
| )JIT";
|
| auto input_2d = at::randn({2, 2});
| auto input_3d = at::randn({3, 3, 3});
| int int_bound = 1;
| float float_bound = 3.14;
|
| // int bound.
| std::vector<IValue> int_first{input_2d, int_bound};
| std::vector<IValue> int_second{input_3d, int_bound};
| testStaticRuntime(clamp_min_int_script, int_first);
| testStaticRuntime(clamp_min_int_script, int_first, int_second);
|
| // float bound.
| std::vector<IValue> float_first{input_2d, float_bound};
| std::vector<IValue> float_second{input_3d, float_bound};
| testStaticRuntime(clamp_min_float_script, float_first);
| testStaticRuntime(clamp_min_float_script, float_first, float_second);
| }
| |
| // Checks torch.argmin in three forms: without a dim, along a given dim, and |
| // along a given dim with the third argument set to True. |
| TEST(StaticRuntime, Argmin) { |
| const auto argmin_script = R"JIT( |
| def forward(self, a: Tensor): |
| return torch.argmin(a).clone() |
| )JIT"; |
| |
| const auto argmin_with_dim_script = R"JIT( |
| def forward(self, a: Tensor, dim: int): |
| return torch.argmin(a, dim).clone() |
| )JIT"; |
| |
| const auto argmin_with_keep_dim_script = R"JIT( |
| def forward(self, a: Tensor, dim: int): |
| return torch.argmin(a, dim, True).clone() |
| )JIT"; |
| |
| const auto small = at::randn({2, 2}); |
| const auto large = at::randn({17, 2, 1}); |
| |
| // No-dim variant. The two-input run explicitly turns check_resize off. |
| const std::vector<IValue> small_only{small}; |
| const std::vector<IValue> large_only{large}; |
| testStaticRuntime(argmin_script, small_only); |
| testStaticRuntime( |
| argmin_script, |
| small_only, |
| large_only, |
| /* use_allclose */ false, |
| /* use_equalnan */ false, |
| /* check_resize */ false); |
| |
| // Dim variants: dim 0 for the small input, dim 1 for the large one. |
| const int small_dim = 0; |
| const int large_dim = 1; |
| const std::vector<IValue> small_dim_args{small, small_dim}; |
| const std::vector<IValue> large_dim_args{large, large_dim}; |
| |
| for (const auto* script : |
| {argmin_with_dim_script, argmin_with_keep_dim_script}) { |
| testStaticRuntime(script, small_dim_args); |
| testStaticRuntime(script, small_dim_args, large_dim_args); |
| } |
| } |
| |
| // Runs torch.softmax over every dimension of a 2-D and a 3-D input, then the |
| // dtype-taking overload along dim 1 with a Float dtype. |
| TEST(StaticRuntime, Softmax) { |
| const auto softmax_script = R"JIT( |
| def forward(self, a: Tensor, dim: int): |
| return torch.softmax(a, dim).clone() |
| )JIT"; |
| |
| const auto softmax_script_with_dtype = R"JIT( |
| def forward(self, a: Tensor, dim: int, dtype: int): |
| return torch.softmax(a, dim, dtype=dtype).clone() |
| )JIT"; |
| |
| const auto matrix = at::randn({2, 3}); |
| const auto cube = at::randn({3, 3, 3}); |
| |
| // Each dimension of the 2-D input. |
| for (const int dim : {0, 1}) { |
| testStaticRuntime(softmax_script, {matrix, dim}); |
| } |
| |
| // Each dimension of the 3-D input. |
| for (const int dim : {0, 1, 2}) { |
| testStaticRuntime(softmax_script, {cube, dim}); |
| } |
| |
| // Explicit dtype, dim 1, for both inputs. |
| for (const auto& input : {matrix, cube}) { |
| testStaticRuntime( |
| softmax_script_with_dtype, {input, 1, at::ScalarType::Float}); |
| } |
| } |
| |
| // Dictionary lookup (d[key]) with tensor, int and str keys. Each script builds |
| // a one-entry dict from its argument and reads that entry back. |
| TEST(StaticRuntime, GetItem_Dict) { |
| const auto getitem_dict_tensor_script = R"JIT( |
| def forward(self, key: Tensor): |
| d = {key: 1} |
| return d[key] |
| )JIT"; |
| |
| const auto getitem_dict_int_script = R"JIT( |
| def forward(self, key: int): |
| d = {key: 1} |
| return d[key] |
| )JIT"; |
| |
| const auto getitem_dict_str_script = R"JIT( |
| def forward(self, key: str): |
| d = {key: 1} |
| return d[key] |
| )JIT"; |
| |
| // Non-tensor keys: a single run each is enough since the args are not |
| // tensors. |
| const int int_key = 0; |
| const std::string str_key = "str"; |
| const std::vector<IValue> int_args{int_key}; |
| const std::vector<IValue> str_args{str_key}; |
| testStaticRuntime(getitem_dict_int_script, int_args); |
| testStaticRuntime(getitem_dict_str_script, str_args); |
| |
| // Tensor keys: one-element key alone, then with a two-element key as the |
| // second argument set. |
| const auto key_one = torch::tensor({1}); |
| const auto key_two = torch::tensor({1, 1}); |
| const std::vector<IValue> tensor_args{key_one}; |
| const std::vector<IValue> tensor_args_larger{key_two}; |
| testStaticRuntime(getitem_dict_tensor_script, tensor_args); |
| testStaticRuntime( |
| getitem_dict_tensor_script, tensor_args, tensor_args_larger); |
| } |
| |
| // List indexing (lst[idx]) on an int list and on a list of tensors, using |
| // both a positive and a negative index. |
| TEST(StaticRuntime, GetItem_List) { |
| const auto getitem_list_int_script = R"JIT( |
| def forward(self, idx: int): |
| lst = [1, 2, 3] |
| return lst[idx] |
| )JIT"; |
| |
| const auto getitem_list_tensor_script = R"JIT( |
| def forward(self, tensor: Tensor, idx: int): |
| lst = [tensor, tensor] |
| return lst[idx] |
| )JIT"; |
| |
| // Int list: index 1, then index -1. |
| for (const int idx : {1, -1}) { |
| testStaticRuntime(getitem_list_int_script, {idx}); |
| } |
| |
| // Tensor list: index 1 on a one-element tensor, then the same with a |
| // two-element tensor and index -1 as the second argument set. |
| const auto small = torch::tensor({1}); |
| const auto large = torch::tensor({1, 1}); |
| const std::vector<IValue> small_args{small, 1}; |
| const std::vector<IValue> large_args{large, -1}; |
| |
| testStaticRuntime(getitem_list_tensor_script, small_args); |
| testStaticRuntime(getitem_list_tensor_script, small_args, large_args); |
| } |
| |
| // torch.transpose on a 2x2 input (dims 0 and 1) and on a 3x3x3 input |
| // (dims 0 and 2). |
| TEST(StaticRuntime, Transpose) { |
| const auto transpose_script = R"JIT( |
| def forward(self, a: Tensor, dim1: int, dim2: int): |
| return torch.transpose(a, dim1, dim2).clone() |
| )JIT"; |
| |
| // 2-D case: swap dims 0 and 1. |
| const auto matrix = at::randn({2, 2}); |
| const int matrix_first = 0; |
| const int matrix_second = 1; |
| const std::vector<IValue> matrix_args{matrix, matrix_first, matrix_second}; |
| |
| // 3-D case: swap dims 0 and 2. |
| const auto cube = at::randn({3, 3, 3}); |
| const int cube_first = 0; |
| const int cube_second = 2; |
| const std::vector<IValue> cube_args{cube, cube_first, cube_second}; |
| |
| testStaticRuntime(transpose_script, matrix_args); |
| testStaticRuntime(transpose_script, matrix_args, cube_args); |
| } |
| |
| // torch.permute on 2-D and 3-D inputs (including a negative dim), followed by |
| // a second script that reshapes the permuted tensor. |
| TEST(StaticRuntime, Permute) { |
| const auto permute_script = R"JIT( |
| def forward(self, a: Tensor, dims: List[int]): |
| return torch.permute(a, dims).clone() |
| )JIT"; |
| |
| const auto matrix = at::randn({2, 2}); |
| const c10::List<int64_t> matrix_dims{1, 0}; |
| const std::vector<IValue> matrix_args{matrix, matrix_dims}; |
| |
| const auto cube = at::randn({3, 3, 3}); |
| const c10::List<int64_t> cube_dims{0, 2, 1}; |
| const std::vector<IValue> cube_args{cube, cube_dims}; |
| |
| // Same shape as above, but the permutation uses a negative dim. |
| const auto cube_neg = at::randn({3, 3, 3}); |
| const c10::List<int64_t> cube_neg_dims{0, -1, 1}; |
| const std::vector<IValue> cube_neg_args{cube_neg, cube_neg_dims}; |
| |
| testStaticRuntime(permute_script, matrix_args); |
| testStaticRuntime(permute_script, cube_neg_args); |
| testStaticRuntime(permute_script, matrix_args, cube_args); |
| |
| // Permute followed by reshape. |
| const auto permute_reshape_script = R"JIT( |
| def forward(self, a: Tensor, dims: List[int], shape: List[int]): |
| return torch.permute(a, dims).reshape(shape).clone() |
| )JIT"; |
| |
| const auto reshape_input = at::randn({8, 16, 4}); |
| const c10::List<int64_t> reshape_perm{0, 2, 1}; |
| const c10::List<int64_t> reshape_shape{-1, 16}; |
| const std::vector<IValue> reshape_args{ |
| reshape_input, reshape_perm, reshape_shape}; |
| testStaticRuntime(permute_reshape_script, reshape_args); |
| } |
| |
| // Tensor.slice with explicit start/end/step, then with start and end given |
| // as None. |
| TEST(StaticRuntime, Slice) { |
| const auto slice_script = R"JIT( |
| def forward(self, a: Tensor, dim: int, start: int, end: int, step: int): |
| return a.slice(dim, start, end, step).clone() |
| )JIT"; |
| |
| // 2x2 input: dim 1, range [0, 1), step 1. |
| const auto matrix = at::randn({2, 2}); |
| const int matrix_dim = 1; |
| const int matrix_start = 0; |
| const int matrix_end = 1; |
| const int matrix_step = 1; |
| const std::vector<IValue> matrix_args{ |
| matrix, matrix_dim, matrix_start, matrix_end, matrix_step}; |
| |
| // 3x3x3 input: dim 2, range [0, 1), step 2. |
| const auto cube = at::randn({3, 3, 3}); |
| const int cube_dim = 2; |
| const int cube_start = 0; |
| const int cube_end = 1; |
| const int cube_step = 2; |
| const std::vector<IValue> cube_args{ |
| cube, cube_dim, cube_start, cube_end, cube_step}; |
| |
| testStaticRuntime(slice_script, matrix_args); |
| testStaticRuntime(slice_script, matrix_args, cube_args); |
| |
| // Start and end are None in the script; only dim and step are passed in. |
| const auto slice_script2 = R"JIT( |
| def forward(self, a: Tensor, dim: int, step: int): |
| return a.slice(dim, None, None, step).clone() |
| )JIT"; |
| const std::vector<IValue> open_range_args{cube, cube_dim, cube_step}; |
| testStaticRuntime(slice_script2, open_range_args); |
| } |
| |
| // Tensor.narrow with an int start, on a 5x5 input and a 5x5x5 input. |
| TEST(StaticRuntime, Narrow) { |
| const auto narrow_with_int_script = R"JIT( |
| def forward(self, a: Tensor, dim: int, start: int, length: int): |
| return a.narrow(dim, start, length).clone() |
| )JIT"; |
| |
| // 5x5 input: dim 0, start 3, length 2. |
| const auto matrix = at::randn({5, 5}); |
| const int matrix_dim = 0; |
| const int matrix_start = 3; |
| const int matrix_length = 2; |
| const std::vector<IValue> matrix_args{ |
| matrix, matrix_dim, matrix_start, matrix_length}; |
| |
| // 5x5x5 input: dim 1, start 2, length 3. |
| const auto cube = at::randn({5, 5, 5}); |
| const int cube_dim = 1; |
| const int cube_start = 2; |
| const int cube_length = 3; |
| const std::vector<IValue> cube_args{cube, cube_dim, cube_start, cube_length}; |
| |
| testStaticRuntime(narrow_with_int_script, matrix_args); |
| testStaticRuntime(narrow_with_int_script, matrix_args, cube_args); |
| } |
| |
| // Unpacks tuples of two and of three tensors and returns the elements again |
| // as a tuple. |
| TEST(StaticRuntime, TupleUnpack) { |
| const auto two_tuple_unpack_script = R"JIT( |
| def forward(self, tup: Tuple[Tensor, Tensor]): |
| a, b = tup |
| return (a, b) |
| )JIT"; |
| |
| const auto three_tuple_unpack_script = R"JIT( |
| def forward(self, tup: Tuple[Tensor, Tensor, Tensor]): |
| a, b, c = tup |
| return (a, b, c) |
| )JIT"; |
| |
| // Pairs: one-element tensors, then 2x2 tensors. |
| const auto pair_small = |
| c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})}); |
| const auto pair_large = |
| c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})}); |
| |
| // Triples: one-element tensors, then 2x2 tensors. |
| const auto triple_small = c10::ivalue::Tuple::create( |
| {at::randn({1}), at::randn({1}), at::randn({1})}); |
| const auto triple_large = c10::ivalue::Tuple::create( |
| {at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})}); |
| |
| testStaticRuntime(two_tuple_unpack_script, {pair_small}); |
| testStaticRuntime(two_tuple_unpack_script, {pair_small}, {pair_large}); |
| |
| testStaticRuntime(three_tuple_unpack_script, {triple_small}); |
| testStaticRuntime( |
| three_tuple_unpack_script, {triple_small}, {triple_large}); |
| } |
| |
| // list.append with an int element and with a tensor element. |
| TEST(StaticRuntime, Append) { |
| const auto append_int_script = R"JIT( |
| def forward(self, a: int): |
| lst = [1, 2, 3] |
| lst.append(a) |
| return lst |
| )JIT"; |
| |
| const auto append_tensor_script = R"JIT( |
| def forward(self, a: Tensor): |
| lst = [] |
| lst.append(a) |
| return lst |
| )JIT"; |
| |
| // Int element. |
| const std::vector<IValue> int_args{1}; |
| testStaticRuntime(append_int_script, int_args); |
| |
| // Tensor element: one-element tensor, then with a 2x2 tensor as the second |
| // argument set. |
| const std::vector<IValue> small_tensor_args{at::randn({1})}; |
| const std::vector<IValue> large_tensor_args{at::randn({2, 2})}; |
| testStaticRuntime(append_tensor_script, small_tensor_args); |
| testStaticRuntime( |
| append_tensor_script, small_tensor_args, large_tensor_args); |
| } |
| |
| // quantized::linear_prepack + quantized::linear + aten::dequantize, run with |
| // two sets of quantized input/weight tensors of different shapes. |
| TEST(StaticRuntime, QuantizedLinear) { |
| const std::string quantized_linear_ir = R"IR( |
| graph(%input: Tensor, %weights: Tensor): |
| %scale: float = prim::Constant[value=1.]() |
| %zero_point: int = prim::Constant[value=1]() |
| %bias: None = prim::Constant() |
| %packed_params = quantized::linear_prepack(%weights, %bias) |
| %1254 = quantized::linear(%input, %packed_params, %scale, %zero_point) |
| %1249: Tensor = aten::dequantize(%1254) |
| return (%1249) |
| )IR"; |
| |
| // First set: 3x2 weight (kQInt8) and 3x2 input (kQUInt8). |
| const at::Tensor weight_small = |
| at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8); |
| const at::Tensor input_small = |
| at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8); |
| |
| // Second set: 8x3 weight and 9x3 input. |
| const at::Tensor weight_large = |
| at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8); |
| const at::Tensor input_large = |
| at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8); |
| |
| testStaticRuntime( |
| quantized_linear_ir, |
| {input_small, weight_small}, |
| {input_large, weight_large}); |
| } |
| |
| // quantized::linear_prepack_fp16 + quantized::linear_dynamic_fp16 on float |
| // inputs, run with two sets of shapes. |
| TEST(StaticRuntime, QuantizedLinearDynamicFp16) { |
| const std::string quantized_linear_dynamic_fp16_script = R"IR( |
| graph(%input: Tensor, %weights: Tensor): |
| %bias: None = prim::Constant() |
| %packed_params = quantized::linear_prepack_fp16(%weights, %bias) |
| %output = quantized::linear_dynamic_fp16(%input, %packed_params) |
| %ret = aten::clone(%output, %bias) |
| return (%ret) |
| )IR"; |
| |
| // First set: 3x2 weight and 3x2 input. |
| const at::Tensor weight_small = torch::randn({3, 2}, torch::kFloat); |
| const at::Tensor input_small = torch::randn({3, 2}, torch::kFloat); |
| |
| // Second set: 4x3 weight and 5x3 input. |
| const at::Tensor weight_large = torch::randn({4, 3}, torch::kFloat); |
| const at::Tensor input_large = torch::randn({5, 3}, torch::kFloat); |
| |
| const std::vector<IValue> small_args{input_small, weight_small}; |
| const std::vector<IValue> large_args{input_large, weight_large}; |
| testStaticRuntime( |
| quantized_linear_dynamic_fp16_script, small_args, large_args); |
| } |
| |
| // quantized::linear_prepack_fp16 + quantized::linear_relu_dynamic_fp16 on |
| // float inputs, run with two sets of shapes. |
| TEST(StaticRuntime, QuantizedLinearReluDynamicFp16) { |
| const std::string quantized_linear_relu_dynamic_fp16_script = R"IR( |
| graph(%input: Tensor, %weights: Tensor): |
| %bias: None = prim::Constant() |
| %packed_params = quantized::linear_prepack_fp16(%weights, %bias) |
| %output = quantized::linear_relu_dynamic_fp16(%input, %packed_params) |
| %ret = aten::clone(%output, %bias) |
| return (%ret) |
| )IR"; |
| |
| // First set: 3x2 weight and 3x2 input. |
| const at::Tensor weight_small = torch::randn({3, 2}, torch::kFloat); |
| const at::Tensor input_small = torch::randn({3, 2}, torch::kFloat); |
| |
| // Second set: 4x3 weight and 5x3 input. |
| const at::Tensor weight_large = torch::randn({4, 3}, torch::kFloat); |
| const at::Tensor input_large = torch::randn({5, 3}, torch::kFloat); |
| |
| const std::vector<IValue> small_args{input_small, weight_small}; |
| const std::vector<IValue> large_args{input_large, weight_large}; |
| testStaticRuntime( |
| quantized_linear_relu_dynamic_fp16_script, small_args, large_args); |
| } |
| |
| // torch.stack of two same-shaped tensors along various dims, including a |
| // negative dim and the trailing (new) dim. |
| TEST(StaticRuntime, VarStack) { |
| const auto var_stack_script = R"JIT( |
| def forward(self, inp1: Tensor, inp2: Tensor, dim: int): |
| return torch.stack([inp1, inp2], dim).clone() |
| )JIT"; |
| |
| // 2D tensors - stack dim = 0 |
| const std::vector<IValue> stack_2d_dim0 = { |
| at::randn({6, 6}), at::randn({6, 6}), 0}; |
| testStaticRuntime(var_stack_script, stack_2d_dim0); |
| |
| // 3D tensors - stack dim = 1 |
| const std::vector<IValue> stack_3d_dim1 = { |
| at::randn({4, 5, 6}), at::randn({4, 5, 6}), 1}; |
| testStaticRuntime(var_stack_script, stack_3d_dim1); |
| |
| // 3D tensors - stack dim = 2 |
| const std::vector<IValue> stack_3d_dim2 = { |
| at::randn({4, 5, 6}), at::randn({4, 5, 6}), 2}; |
| testStaticRuntime(var_stack_script, stack_3d_dim2); |
| |
| // Negative dim |
| const std::vector<IValue> stack_3d_neg_dim = { |
| at::randn({4, 5, 6}), at::randn({4, 5, 6}), -1}; |
| testStaticRuntime(var_stack_script, stack_3d_neg_dim); |
| |
| // Non-serial path |
| const std::vector<IValue> stack_non_serial = { |
| at::randn({1, 2, 3}), at::randn({1, 2, 3}), 3}; |
| testStaticRuntime(var_stack_script, stack_non_serial); |
| |
| // Fast path |
| const std::vector<IValue> stack_fast = {at::randn({1}), at::randn({1}), 0}; |
| testStaticRuntime(var_stack_script, stack_fast); |
| |
| // Two argument sets: the 2D dim-0 case followed by the 3D dim-1 case. |
| testStaticRuntime(var_stack_script, stack_2d_dim0, stack_3d_dim1); |
| } |
| |
| // torch.fmod with a tensor divisor, on 2x3 operands and then with 4x3x2 |
| // operands as the second argument set. |
| TEST(StaticRuntime, FmodTensor) { |
| const auto fmod_tensor = R"JIT( |
| def forward(self, a: Tensor, b: Tensor): |
| return torch.fmod(a, b).clone() |
| )JIT"; |
| |
| // Single run on 2x3 operands. |
| const auto lhs_small = at::randn({2, 3}); |
| const auto rhs_small = at::randn({2, 3}); |
| const std::vector<IValue> small_args{lhs_small, rhs_small}; |
| testStaticRuntime(fmod_tensor, small_args); |
| |
| // Second argument set with a different shape (dynamic shapes). |
| const auto lhs_large = at::randn({4, 3, 2}); |
| const auto rhs_large = at::randn({4, 3, 2}); |
| const std::vector<IValue> large_args{lhs_large, rhs_large}; |
| testStaticRuntime(fmod_tensor, small_args, large_args); |
| } |
| |
| // torch.fmod with an int divisor, on float inputs and then on int32 inputs. |
| TEST(StaticRuntime, FmodScalar) { |
| const auto fmod_scalar = R"JIT( |
| def forward(self, a: Tensor, b: int): |
| return torch.fmod(a, b).clone() |
| )JIT"; |
| |
| // Float input, divisor 3. |
| const auto float_small = at::randn({2, 3}); |
| const std::vector<IValue> float_small_args{float_small, 3}; |
| testStaticRuntime(fmod_scalar, float_small_args); |
| |
| // Second argument set with a different shape and divisor 4. |
| const auto float_large = at::randn({4, 3, 2}); |
| const std::vector<IValue> float_large_args{float_large, 4}; |
| testStaticRuntime(fmod_scalar, float_small_args, float_large_args); |
| |
| // Same two runs with int32 inputs. |
| const auto int_small = at::randint(-100, 100, {2, 3}, at::kInt); |
| const auto int_large = at::randint(-100, 100, {4, 3, 2}, at::kInt); |
| const std::vector<IValue> int_small_args{int_small, 3}; |
| const std::vector<IValue> int_large_args{int_large, 4}; |
| testStaticRuntime(fmod_scalar, int_small_args); |
| testStaticRuntime(fmod_scalar, int_small_args, int_large_args); |
| } |
| |
| // quantized::embedding_bag_byte_prepack on an 8x16 float input, then with a |
| // 16x32 input as the second argument set. |
| TEST(StaticRuntime, QEmbeddingBagBytePrepack) { |
| const std::string embedding_bag_byte_prepack_script = R"IR( |
| graph(%input: Tensor): |
| %none : None = prim::Constant() |
| %output: Tensor = quantized::embedding_bag_byte_prepack(%input) |
| %res: Tensor = aten::clone(%output, %none) |
| return (%res) |
| )IR"; |
| |
| const auto small = torch::randn({8, 16}, at::ScalarType::Float); |
| const auto large = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float); |
| const std::vector<IValue> small_args{small}; |
| const std::vector<IValue> large_args{large}; |
| |
| testStaticRuntime(embedding_bag_byte_prepack_script, small_args); |
| testStaticRuntime(embedding_bag_byte_prepack_script, small_args, large_args); |
| } |
| |
| // quantized::embedding_bag_byte_unpack applied to the output of |
| // quantized::embedding_bag_byte_prepack, on 8x16 and 16x32 float inputs. |
| TEST(StaticRuntime, QEmbeddingBagByteUnpack) { |
| const auto unpack_ir = R"IR( |
| graph(%input: Tensor): |
| %none : None = prim::Constant() |
| %weight: Tensor = quantized::embedding_bag_byte_prepack(%input) |
| %output: Tensor = quantized::embedding_bag_byte_unpack(%weight) |
| %res: Tensor = aten::clone(%output, %none) |
| return (%res) |
| )IR"; |
| |
| const auto small = torch::randn({8, 16}, at::ScalarType::Float); |
| const auto large = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float); |
| const std::vector<IValue> small_args{small}; |
| const std::vector<IValue> large_args{large}; |
| |
| testStaticRuntime(unpack_ir, small_args); |
| testStaticRuntime(unpack_ir, small_args, large_args); |
| } |
| |
| // torch.linalg_norm with an int ord (4) over dim 1, keepdim=true and a Float |
| // dtype, on a 2x3 input and then with a 3x2x6 input as the second set. |
| TEST(StaticRuntime, LinalgNorm_ScalarOrd) { |
| const auto linalg_norm_ord_scalar = R"JIT( |
| def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int): |
| return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone() |
| )JIT"; |
| |
| // Arguments shared by both input shapes. |
| auto reduce_dims = std::vector<int64_t>({1}); |
| const auto out_dtype = at::ScalarType::Float; |
| |
| const auto small = at::randn({2, 3}); |
| const std::vector<IValue> small_args{small, 4, reduce_dims, true, out_dtype}; |
| testStaticRuntime(linalg_norm_ord_scalar, small_args); |
| |
| const auto large = at::randn({3, 2, 6}); |
| const std::vector<IValue> large_args{large, 4, reduce_dims, true, out_dtype}; |
| testStaticRuntime(linalg_norm_ord_scalar, small_args, large_args); |
| } |
| |
| // torch.linalg_norm with a string ord ("fro") over dims {0, 1}, keepdim=true |
| // and a Float dtype, on a 2x3 input and then with a 3x2x17 input. |
| TEST(StaticRuntime, LinalgNorm_StringOrd) { |
| const auto linalg_norm_ord_str = R"JIT( |
| def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int): |
| return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone() |
| )JIT"; |
| |
| // Arguments shared by both input shapes. |
| auto reduce_dims = std::vector<int64_t>({0, 1}); |
| const auto out_dtype = at::ScalarType::Float; |
| |
| const auto small = at::randn({2, 3}); |
| const std::vector<IValue> small_args{ |
| small, "fro", reduce_dims, true, out_dtype}; |
| testStaticRuntime(linalg_norm_ord_str, small_args); |
| |
| const auto large = at::randn({3, 2, 17}); |
| const std::vector<IValue> large_args{ |
| large, "fro", reduce_dims, true, out_dtype}; |
| testStaticRuntime(linalg_norm_ord_str, small_args, large_args); |
| } |
| |
| // torch.index_put with the indices passed three ways: as a tuple of optional |
| // tensors, as a list of tensors, and as a list built inside the script. |
| TEST(StaticRuntime, Index_Put) { |
| // Data shared by all three variants. |
| const auto target = at::randn({2}); |
| const auto tuple_indices = std::make_tuple(torch::tensor({0}, at::kLong)); |
| const auto values = at::randn({1}); |
| |
| // Variant 1: indices typed as Tuple[Optional[Tensor]]. |
| const auto index_put_str = R"JIT( |
| def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool): |
| return torch.index_put(a, indices, values, accumulate).clone() |
| )JIT"; |
| |
| const std::vector<IValue> tuple_args{target, tuple_indices, values, false}; |
| testStaticRuntime(index_put_str, tuple_args); |
| |
| // Variant 2: indices typed as List[Tensor]. |
| const auto index_put_non_optional_str = R"JIT( |
| def forward(self, a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool): |
| return torch.index_put(a, indices, values, accumulate).clone() |
| )JIT"; |
| |
| const auto list_indices = |
| c10::List<at::Tensor>{torch::tensor({0}, at::kLong)}; |
| const std::vector<IValue> list_args{target, list_indices, values, false}; |
| testStaticRuntime(index_put_non_optional_str, list_args); |
| |
| // Variant 3: a single index tensor wrapped into a list inside the script. |
| const auto index_put_list_construct = R"JIT( |
| def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool): |
| indices: List[Optional[Tensor]] = [indices] |
| return torch.index_put(a, indices, values, accumulate).clone() |
| )JIT"; |
| |
| const std::vector<IValue> construct_args{ |
| target, torch::tensor({0}, at::kLong), values, false}; |
| testStaticRuntime(index_put_list_construct, construct_args); |
| } |
| |
| // torch.item on a one-element tensor. |
| TEST(StaticRuntime, Item) { |
| const auto item_str = R"JIT( |
| def forward(self, a: Tensor): |
| return torch.item(a) |
| )JIT"; |
| |
| const auto single_value = at::randn({1}); |
| const std::vector<IValue> item_args{single_value}; |
| |
| testStaticRuntime(item_str, item_args); |
| } |
| |
| // torch.tensor_split on an 8-element tensor along dim 0, with the split |
| // given as an int, as a tensor, and as a list of ints. |
| TEST(StaticRuntime, Tensor_Split) { |
| // Split given as an int (3). |
| const auto tensor_split_str1 = R"JIT( |
| def forward(self, a: Tensor, sections: int, dim: int): |
| return torch.tensor_split(a, sections, dim) |
| )JIT"; |
| const std::vector<IValue> int_sections_args{at::randn({8}), 3, 0}; |
| |
| // Split given as a tensor holding 3. |
| const auto tensor_split_str2 = R"JIT( |
| def forward(self, a: Tensor, sections: Tensor, dim: int): |
| return torch.tensor_split(a, sections, dim) |
| )JIT"; |
| const std::vector<IValue> tensor_sections_args{ |
| at::randn({8}), torch::tensor(3), 0}; |
| |
| // Split given as the int list {1, 6}. |
| const auto tensor_split_str3 = R"JIT( |
| def forward(self, a: Tensor, indicies: List[int], dim: int): |
| return torch.tensor_split(a, indicies, dim) |
| )JIT"; |
| const std::vector<IValue> index_list_args{ |
| at::randn({8}), c10::List<int64_t>({1, 6}), 0}; |
| |
| testStaticRuntime(tensor_split_str1, int_sections_args); |
| testStaticRuntime(tensor_split_str2, tensor_sections_args); |
| testStaticRuntime(tensor_split_str3, index_list_args); |
| } |
| |
| // Runs a graph that adds a tensor to itself and passes the result through |
| // aten::cpu. |
| TEST(StaticRuntime, JIT_Aten_Cpu) { |
| const std::string script = R"IR( |
| graph(%a: Tensor): |
| %1 : int = prim::Constant[value=0]() |
| %aa: Tensor = aten::add(%a, %a, %1) |
| %ret: Tensor = aten::cpu(%aa) |
| return (%ret) |
| )IR"; |
| |
| // Parse the IR and build a StaticModule from it. The module is not used |
| // afterwards in this test; constructing it is the only check made here. |
| // The value map starts out empty, so no reserve call is needed before |
| // parsing. |
| auto graph = std::make_shared<Graph>(); |
| std::unordered_map<std::string, Value*> vmap; |
| parseIR(script, graph.get(), vmap); |
| torch::jit::StaticModule smodule(graph); |
| |
| auto a = at::randn({2, 4}); |
| std::vector<IValue> args0{a}; |
| |
| testStaticRuntime(script, args0); |
| } |
| |
| // Runs a graph that adds a tensor to itself and returns aten::numel of the |
| // result. |
| TEST(StaticRuntime, JIT_Aten_Numel) { |
| const std::string script = R"IR( |
| graph(%a: Tensor): |
| %1 : int = prim::Constant[value=0]() |
| %aa: Tensor = aten::add(%a, %a, %1) |
| %ret: int = aten::numel(%aa) |
| return (%ret) |
| )IR"; |
| |
| // Parse the IR and build a StaticModule from it. The module is not used |
| // afterwards in this test; constructing it is the only check made here. |
| // The value map starts out empty, so no reserve call is needed before |
| // parsing. |
| auto graph = std::make_shared<Graph>(); |
| std::unordered_map<std::string, Value*> vmap; |
| parseIR(script, graph.get(), vmap); |
| torch::jit::StaticModule smodule(graph); |
| |
| auto a = at::randn({2, 4}); |
| std::vector<IValue> args0{a}; |
| |
| testStaticRuntime(script, args0); |
| } |
| |
| // aten::list applied to a string and to an int list. |
| TEST(StaticRuntime, JIT_Aten_List) { |
| // String input. |
| const auto script_str = R"IR( |
| graph(%a: str): |
| %ret: str[] = aten::list(%a) |
| return (%ret) |
| )IR"; |
| const std::string text = "abcd"; |
| const std::vector<IValue> str_args{text}; |
| testStaticRuntime(script_str, str_args); |
| |
| // Int-list input. The graph writes into the result of aten::list and |
| // returns both lists, to ensure that a deep copy took place. |
| const auto script_list = R"IR( |
| graph(%a : int[]): |
| %idx : int = prim::Constant[value=0]() |
| %value : int = prim::Constant[value=42]() |
| %res : int[] = aten::list(%a) |
| %updated : int[] = aten::_set_item(%res, %idx, %value) |
| return (%res, %a) |
| )IR"; |
| |
| const std::vector<IValue> list_args{c10::List<int64_t>{1, 2, 3}}; |
| testStaticRuntime(script_list, list_args); |
| } |
| |
| // Runs aten::__range_length with lo=0, hi=10, step=2. |
| TEST(StaticRuntime, JIT_Aten_Range_Length) { |
| const std::string script = R"IR( |
| graph(%lo: int, %hi: int, %step: int): |
| %1 : int = prim::Constant[value=0]() |
| %ret: int = aten::__range_length(%lo, %hi, %step) |
| return (%ret) |
| )IR"; |
| |
| // Parse the IR and build a StaticModule from it. The module is not used |
| // afterwards in this test; constructing it is the only check made here. |
| // The value map starts out empty, so no reserve call is needed before |
| // parsing. |
| auto graph = std::make_shared<Graph>(); |
| std::unordered_map<std::string, Value*> vmap; |
| parseIR(script, graph.get(), vmap); |
| torch::jit::StaticModule smodule(graph); |
| |
| std::vector<IValue> args0{0, 10, 2}; |
| |
| testStaticRuntime(script, args0); |
| } |
| |
| // aten::cat over a slice of a two-tensor list, along dim 0, dim 1 and |
| // dim -1. Also asserts that the StaticModule contains an aten::cat node. |
| TEST(StaticRuntime, Cat) { |
| const std::string cat_script = R"IR( |
| graph(%a: Tensor, %b: Tensor, %dim: int): |
| %ten_list: Tensor[] = prim::ListConstruct(%a, %b) |
| %1 : int = prim::Constant[value=0]() |
| %2 : int = prim::Constant[value=1]() |
| %3 : int = prim::Constant[value=1]() |
| %ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3) |
| %ret: Tensor = aten::cat(%ten_list2, %dim) |
| return (%ret) |
| )IR"; |
| |
| // The module built from the parsed graph must contain an aten::cat node. |
| auto parsed_graph = std::make_shared<Graph>(); |
| std::unordered_map<std::string, Value*> value_map; |
| parseIR(cat_script, parsed_graph.get(), value_map); |
| torch::jit::StaticModule static_module(parsed_graph); |
| ASSERT_TRUE(getNodeWithKind(static_module, "aten::cat")); |
| |
| // dim 0 with 2x4 and 3x4 inputs. |
| const auto first_2x4 = at::randn({2, 4}); |
| const auto second_3x4 = at::randn({3, 4}); |
| const std::vector<IValue> dim0_args{first_2x4, second_3x4, 0}; |
| testStaticRuntime(cat_script, dim0_args); |
| |
| // dim 1 with 3x4 and 3x5 inputs, supplied as the second argument set. |
| const auto first_3x4 = at::randn({3, 4}); |
| const auto second_3x5 = at::randn({3, 5}); |
| const std::vector<IValue> dim1_args{first_3x4, second_3x5, 1}; |
| testStaticRuntime(cat_script, dim0_args, dim1_args); |
| |
| // Same 3x4 / 3x5 inputs with dim -1. |
| const std::vector<IValue> args_dim_negative{first_3x4, second_3x5, -1}; |
| testStaticRuntime(cat_script, args_dim_negative); |
| } |
| |
| // torch.cumsum along dim 0 of a 2x3 input, then with a 3x6 input and dim 1 |
| // as the second argument set. |
| TEST(StaticRuntime, Cumsum) { |
| const auto cumsum_script = R"JIT( |
| def forward(self, a: Tensor, dim: int): |
| return torch.cumsum(a, dim).clone() |
| )JIT"; |
| |
| const auto small = at::randn({2, 3}); |
| const std::vector<IValue> small_args{small, 0}; |
| testStaticRuntime(cumsum_script, small_args); |
| |
| const auto large = at::randn({3, 6}); |
| const std::vector<IValue> large_args{large, 1}; |
| testStaticRuntime(cumsum_script, small_args, large_args); |
| } |
| |
| // torch.cumsum with an explicit Float dtype: dim 0 of a 1x2 input, then with |
| // a 3x6 input and dim 1 as the second argument set. |
| TEST(StaticRuntime, CumsumDtype) { |
| const auto cumsum_script_dtype = R"JIT( |
| def forward(self, a: Tensor, dim: int, dtype: int): |
| return torch.cumsum(a, dim, dtype=dtype).clone() |
| )JIT"; |
| |
| const auto small = at::randn({1, 2}); |
| const auto out_dtype = at::ScalarType::Float; |
| const std::vector<IValue> small_args{small, 0, out_dtype}; |
| testStaticRuntime(cumsum_script_dtype, small_args); |
| |
| const auto large = at::randn({3, 6}); |
| const std::vector<IValue> large_args{large, 1, out_dtype}; |
| testStaticRuntime(cumsum_script_dtype, small_args, large_args); |
| } |
| |
| // torch.nonzero on inputs produced by at::randint(0, 2, ...): a 2x3 input, |
| // then with a 4x3x2 input as the second argument set. |
| TEST(StaticRuntime, Nonzero) { |
| const auto nonzero_tensor = R"JIT( |
| def forward(self, input: Tensor): |
| a = torch.nonzero(input).clone() |
| return (a) |
| )JIT"; |
| |
| const auto small = at::randint(0, 2, {2, 3}); |
| const std::vector<IValue> small_args{small}; |
| testStaticRuntime(nonzero_tensor, small_args); |
| |
| const auto large = at::randint(0, 2, {4, 3, 2}); |
| const std::vector<IValue> large_args{large}; |
| testStaticRuntime(nonzero_tensor, small_args, large_args); |
| } |
| |
| // Graph computing sign(x) * log1p(abs(x)), followed by a clone. Both runs |
| // pass true as the fourth argument of testStaticRuntime. |
| TEST(StaticRuntime, SignedLog1p) { |
| const std::string signed_log1p_script = R"IR( |
| graph(%input): |
| %0 : Tensor = aten::sign(%input) |
| %1 : Tensor = aten::abs(%input) |
| %2 : Tensor = aten::log1p(%1) |
| %3 : Tensor = aten::mul(%0, %2) |
| %none : NoneType = prim::Constant() |
| %res : Tensor = aten::clone(%3, %none) |
| return (%res) |
| )IR"; |
| |
| // 2x2 input with an empty second argument set. |
| const std::vector<IValue> small_args = {at::randn({2, 2})}; |
| testStaticRuntime( |
| signed_log1p_script, |
| small_args, |
| {}, |
| true); |
| |
| // 3x3x3 input as the second argument set. |
| const std::vector<IValue> large_args = {at::randn({3, 3, 3})}; |
| testStaticRuntime( |
| signed_log1p_script, |
| small_args, |
| large_args, |
| true); |
| } |
| |
| // When the input dict is only read, the StaticModule is expected to contain |
| // static_runtime::dict_unpack and no aten::__getitem__ node. The script is |
| // then run on two dicts with differently shaped tensors. |
| TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithImmutableInputDict) { |
| const auto getitem_immutable_input_dict_script = R"JIT( |
| def forward(self, input: Dict[int, Tensor]): |
| a = input[0] |
| b = input[1] |
| c = a + b |
| return c.clone() |
| )JIT"; |
| |
| // Check which node kinds are present in the static module. |
| script::Module scripted("module"); |
| scripted.define(getitem_immutable_input_dict_script); |
| torch::jit::StaticModule static_module(scripted); |
| EXPECT_FALSE(hasNodeWithKind(static_module, "aten::__getitem__")); |
| EXPECT_TRUE(hasNodeWithKind(static_module, "static_runtime::dict_unpack")); |
| |
| // First run: dict of two 2x4 tensors under keys 0 and 1. |
| const auto first_2x4 = at::randn({2, 4}); |
| const auto second_2x4 = at::randn({2, 4}); |
| c10::Dict<c10::IValue, c10::IValue> dict_2x4( |
| c10::IntType::get(), c10::TensorType::get()); |
| dict_2x4.insert(0, first_2x4); |
| dict_2x4.insert(1, second_2x4); |
| testStaticRuntime(getitem_immutable_input_dict_script, {dict_2x4}); |
| |
| // Second run: dict of two 3x4 tensors under keys 0 and 1. |
| const auto first_3x4 = at::randn({3, 4}); |
| const auto second_3x4 = at::randn({3, 4}); |
| c10::Dict<c10::IValue, c10::IValue> dict_3x4( |
| c10::IntType::get(), c10::TensorType::get()); |
| dict_3x4.insert(0, first_3x4); |
| dict_3x4.insert(1, second_3x4); |
| testStaticRuntime(getitem_immutable_input_dict_script, {dict_3x4}); |
| } |
| |
| // When the script writes to the input dict, the StaticModule is expected to |
| // keep aten::__getitem__ and contain no static_runtime::dict_unpack node. |
| TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithMutableInputDict) { |
| const auto getitem_mutable_input_dict_script = R"JIT( |
| def forward(self, input: Dict[int, Tensor]): |
| a = input[0] |
| input[1] = a |
| b = input[1] |
| c = a + b |
| return c.clone() |
| )JIT"; |
| |
| script::Module scripted("module"); |
| scripted.define(getitem_mutable_input_dict_script); |
| torch::jit::StaticModule static_module(scripted); |
| |
| EXPECT_TRUE(hasNodeWithKind(static_module, "aten::__getitem__")); |
| EXPECT_FALSE(hasNodeWithKind(static_module, "static_runtime::dict_unpack")); |
| } |
| |
| // Two adjacent tuple unpacks: the StaticModule is expected to contain |
| // static_runtime::VarTupleUnpack and no prim::TupleUnpack node. The script |
| // is then run on 2x2 and 3x3x3 tensors. |
| TEST(StaticRuntime, VarTupleUnpack) { |
| const auto var_tuple_unpack_script = R"JIT( |
| def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]): |
| a, b = input_0 |
| c, d = input_1 |
| res = a * c + b * d |
| return res.clone() |
| )JIT"; |
| |
| // Check which node kinds are present in the static module. |
| script::Module scripted("module"); |
| scripted.define(var_tuple_unpack_script); |
| torch::jit::StaticModule static_module(scripted); |
| EXPECT_FALSE(hasNodeWithKind(static_module, "prim::TupleUnpack")); |
| EXPECT_TRUE( |
| hasNodeWithKind(static_module, "static_runtime::VarTupleUnpack")); |
| |
| // Each tensor tuple holds the same tensor twice; the int tuple is (1, 2). |
| const auto small = at::randn({2, 2}); |
| const auto large = at::randn({3, 3, 3}); |
| const std::vector<IValue> small_args{ |
| c10::ivalue::Tuple::create(small, small), |
| c10::ivalue::Tuple::create(1, 2)}; |
| const std::vector<IValue> large_args{ |
| c10::ivalue::Tuple::create(large, large), |
| c10::ivalue::Tuple::create(1, 2)}; |
| |
| testStaticRuntime(var_tuple_unpack_script, small_args); |
| testStaticRuntime(var_tuple_unpack_script, small_args, large_args); |
| } |
| |
| // With a computation between the two tuple unpacks, the StaticModule is |
| // expected to keep prim::TupleUnpack and contain no |
| // static_runtime::VarTupleUnpack node. |
| TEST(StaticRuntime, VarTupleUnpack_NotApplied) { |
| const auto var_tuple_unpack_not_applied_script = R"JIT( |
| def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]): |
| a, b = input_0 |
| x = a + b |
| c, d = input_1 |
| res = a * c + b * d + x |
| return res.clone() |
| )JIT"; |
| |
| // In this script, the optimization is not applied since there is a |
| // computation between the TupleUnpack nodes. |
| script::Module scripted("module"); |
| scripted.define(var_tuple_unpack_not_applied_script); |
| torch::jit::StaticModule static_module(scripted); |
| |
| EXPECT_FALSE( |
| hasNodeWithKind(static_module, "static_runtime::VarTupleUnpack")); |
| EXPECT_TRUE(hasNodeWithKind(static_module, "prim::TupleUnpack")); |
| } |
| |
| // torch.remainder with a tensor divisor, on 2x2 operands and then with 3x6 |
| // operands as the second argument set. |
| TEST(StaticRuntime, RemainderTensor) { |
| const auto remainder_tensor = R"JIT( |
| def forward(self, x, y): |
| return torch.remainder(x, y).clone() |
| )JIT"; |
| |
| const std::vector<IValue> small_args = { |
| at::randint(0, 10, {2, 2}), at::randint(1, 10, {2, 2})}; |
| const std::vector<IValue> large_args = { |
| at::randint(0, 10, {3, 6}), at::randint(1, 10, {3, 6})}; |
| |
| // Use allclose and equalnan since outputs may be NaN. |
| const bool use_allclose = true; |
| const bool use_equalnan = true; |
| |
| testStaticRuntime( |
| remainder_tensor, |
| small_args, |
| /*args2*/ {}, |
| use_allclose, |
| use_equalnan); |
| testStaticRuntime( |
| remainder_tensor, small_args, large_args, use_allclose, use_equalnan); |
| } |
| |
| // torch.remainder with an int divisor (4), on a 2x2 input and then with a |
| // 3x6 input as the second argument set. |
| TEST(StaticRuntime, RemainderScalar) { |
| const auto remainder_scalar = R"JIT( |
| def forward(self, x, y: int): |
| return torch.remainder(x, y).clone() |
| )JIT"; |
| |
| const std::vector<IValue> small_args = {at::randint(0, 10, {2, 2}), 4}; |
| const std::vector<IValue> large_args = {at::randint(0, 10, {3, 6}), 4}; |
| |
| // Use allclose and equalnan since outputs may be NaN. |
| const bool use_allclose = true; |
| const bool use_equalnan = true; |
| |
| testStaticRuntime( |
| remainder_scalar, |
| small_args, |
| /*args2*/ {}, |
| use_allclose, |
| use_equalnan); |
| testStaticRuntime( |
| remainder_scalar, small_args, large_args, use_allclose, use_equalnan); |
| } |
| |
| // torch.where(x > 0, x, y) on 2x2 operands, then with 8x10 operands as the |
| // second argument set. |
| TEST(StaticRuntime, Where) { |
| const auto where_script = R"JIT( |
| def forward(self, x, y): |
| return torch.where(x > 0, x, y).clone() |
| )JIT"; |
| |
| const std::vector<IValue> small_args = { |
| at::randn({2, 2}), at::randn({2, 2})}; |
| const std::vector<IValue> large_args = { |
| at::randn({8, 10}), at::randn({8, 10})}; |
| |
| testStaticRuntime(where_script, small_args); |
| testStaticRuntime(where_script, small_args, large_args); |
| } |
| |
| // torch.where with a 1-D bool condition that the script views as |
| // [-1, 1, ...] (one trailing 1 per extra dimension of x) before use. |
| TEST(StaticRuntime, WhereBroadcast) { |
| const auto where_broadcast_script = R"JIT( |
| def forward(self, cond_1d, x, y): |
| shape = [-1] + [1] * (x.dim() - 1) |
| cond = cond_1d.view(shape) |
| return torch.where(cond, x, y).clone() |
| )JIT"; |
| |
| // Two-element condition with 2x2 operands. |
| const std::vector<IValue> small_args = { |
| at::tensor({0, 1}).to(at::kBool), at::randn({2, 2}), at::randn({2, 2})}; |
| |
| // Three-element condition with 3x6 operands. |
| const std::vector<IValue> large_args = { |
| at::tensor({1, 0, 0}).to(at::kBool), |
| at::randn({3, 6}), |
| at::randn({3, 6})}; |
| |
| testStaticRuntime(where_broadcast_script, small_args); |
| testStaticRuntime(where_broadcast_script, small_args, large_args); |
| } |
| |
| // aten::view followed by aten::clone, on a 2x2 input and then with a 2x2x2 |
| // input as the second argument set. |
| TEST(StaticRuntime, View) { |
| // The clone is not technically required because aten::view is not an out |
| // variant; it is there to suppress the warning testStaticRuntime gives |
| // when the graph has only one op. |
| const auto view_ir = R"IR( |
| graph(%input : Tensor, %shape : int[]): |
| %none : NoneType = prim::Constant() |
| %view : Tensor = aten::view(%input, %shape) |
| %res : Tensor = aten::clone(%view, %none) |
| return (%res) |
| )IR"; |
| |
| const std::vector<IValue> small_args{ |
| at::randn({2, 2}), c10::List<int64_t>(4)}; |
| const std::vector<IValue> large_args{ |
| at::randn({2, 2, 2}), c10::List<int64_t>({4, 2})}; |
| |
| testStaticRuntime(view_ir, small_args); |
| testStaticRuntime(view_ir, small_args, large_args); |
| } |
| |
| // Tensor.size with a dim argument (positive and negative) and without one. |
| TEST(StaticRuntime, Size) { |
| const auto src_with_dim = R"JIT( |
| def forward(self, x, dim: int): |
| return x.size(dim) |
| )JIT"; |
| |
| const auto src_no_dim = R"JIT( |
| def forward(self, x): |
| return x.size() |
| )JIT"; |
| |
| // With a dim: one-element tensors for dim 0 and dim -1, a 2x4 tensor for |
| // dim 1. |
| const std::vector<IValue> dim0_args{at::randn({1}), 0}; |
| const std::vector<IValue> neg_dim_args{at::randn({1}), -1}; |
| const std::vector<IValue> dim1_args{at::randn({2, 4}), 1}; |
| |
| // Without a dim: a 2x4 tensor. |
| const std::vector<IValue> args_no_dim{at::randn({2, 4})}; |
| |
| testStaticRuntime(src_with_dim, dim0_args); |
| testStaticRuntime(src_with_dim, neg_dim_args); |
| testStaticRuntime(src_with_dim, dim0_args, dim1_args); |
| testStaticRuntime(src_no_dim, args_no_dim); |
| } |
| |
| // Tensor.squeeze(dim) on a 2x2 input for dims 0, 1 and -1, the last run also |
| // supplying a 3x2x3 input with dim 2 as the second argument set. |
| TEST(StaticRuntime, Squeeze) { |
| // Note: this is a native op, not an out variant, but clone anyways |
| // to silence warnings in testStaticRuntime |
| const auto squeeze_script = R"JIT( |
| def forward(self, inp, dim: int): |
| return inp.squeeze(dim).clone() |
| )JIT"; |
| |
| const auto matrix = at::randn({2, 2}); |
| const auto cube = at::randn({3, 2, 3}); |
| |
| const std::vector<IValue> matrix_dim0{matrix, 0}; |
| const std::vector<IValue> matrix_dim1{matrix, 1}; |
| const std::vector<IValue> matrix_neg_dim{matrix, -1}; |
| const std::vector<IValue> cube_dim2{cube, 2}; |
| |
| testStaticRuntime(squeeze_script, matrix_dim0); |
| testStaticRuntime(squeeze_script, matrix_dim1); |
| testStaticRuntime(squeeze_script, matrix_neg_dim, cube_dim2); |
| } |
| |
| TEST(StaticRuntime, NumToTensorScalar) { |
| const auto num_to_tensor_ir = R"IR( |
| graph(%1 : int): |
| %2 : NoneType = prim::Constant() |
| %3 : Tensor = prim::NumToTensor(%1) |
| %4 : Tensor = aten::clone(%3, %2) |
| return (%4) |
| )IR"; |
| |
| IValue arg{5}; |
| std::vector<IValue> args = {arg}; |
| testStaticRuntime(num_to_tensor_ir, args); |
| } |
| |
| TEST(StaticRuntime, NumToTensorFalse) { |
| const auto num_to_tensor_ir = R"IR( |
| graph(%1 : bool): |
| %2 : NoneType = prim::Constant() |
| %3 : Tensor = prim::NumToTensor(%1) |
| %4 : Tensor = aten::clone(%3, %2) |
| return (%4) |
| )IR"; |
| |
| IValue arg{false}; |
| std::vector<IValue> args = {arg}; |
| testStaticRuntime(num_to_tensor_ir, args); |
| } |
| |
| TEST(StaticRuntime, NumToTensorTrue) { |
| const auto num_to_tensor_ir = R"IR( |
| graph(%1 : bool): |
| %2 : NoneType = prim::Constant() |
| %3 : Tensor = prim::NumToTensor(%1) |
| %4 : Tensor = aten::clone(%3, %2) |
| return (%4) |
| )IR"; |
| |
| IValue arg{true}; |
| std::vector<IValue> args = {arg}; |
| testStaticRuntime(num_to_tensor_ir, args); |
| } |
| |
| TEST(StaticRuntime, Split) { |
| const auto src = R"JIT( |
| def forward(self, inp, split_size: int, dim: int): |
| return inp.split(split_size, dim) |
| )JIT"; |
| |
| const auto a = at::randn({2, 2}); |
| const auto b = at::randn({2, 2, 2}); |
| |
| testStaticRuntime(src, {a, 1, 0}); |
| testStaticRuntime(src, {a, 1, 1}); |
| testStaticRuntime(src, {a, 2, -1}, {b, 2, 2}); |
| } |
| |
| TEST(StaticRuntime, SplitWithSizes) { |
| const auto src = R"JIT( |
| def forward(self, inp, split_sizes: List[int], dim: int): |
| return inp.split(split_sizes, dim) |
| )JIT"; |
| |
| const auto a = at::randn({2, 2}); |
| const auto b = at::randn({2, 2, 2}); |
| const auto split_sizes = c10::List<int64_t>{1, 1}; |
| |
| testStaticRuntime(src, {a, split_sizes, 0}); |
| testStaticRuntime(src, {a, split_sizes, 1}); |
| testStaticRuntime(src, {a, split_sizes, -1}, {b, split_sizes, 2}); |
| } |
| |
| namespace { |
| |
| void maybe_throw(bool should_throw) { |
| if (should_throw) { |
| throw std::runtime_error("test exception"); |
| } |
| } |
| |
| TORCH_LIBRARY(static_runtime_tests, m) { |
| // Conservative so this op doesn't get deleted by dead |
| // code elimination |
| m.def(torch::schema( |
| "static_runtime_tests::maybe_throw(bool throw) -> ()", |
| at::AliasAnalysisKind::CONSERVATIVE)); |
| m.impl("maybe_throw", maybe_throw); |
| } |
| |
| } // namespace |
| |
| TEST(StaticRuntime, ModelCrashOnFirstRun) { |
| const auto src = R"JIT( |
| graph(%0: Tensor, %throw: bool): |
| %1: Tensor = aten::mul(%0, %0) |
| static_runtime_tests::maybe_throw(%throw) |
| %2: Tensor = aten::mul(%1, %1) |
| %3: Tensor = aten::mul(%2, %2) |
| return (%3) |
| )JIT"; |
| |
| auto graph = getGraphFromIR(src); |
| auto static_module = StaticModule(graph); |
| auto& runtime = static_module.runtime(); |
| |
| std::vector<IValue> args_crash{at::randn({1}), true}; |
| std::vector<IValue> args_no_crash{at::randn({1}), false}; |
| EXPECT_THROW(runtime(args_crash, {}), std::runtime_error); |
| |
| // The run didn't finish, we didn't allocate the memory planner |
| EXPECT_EQ(runtime.get_memory_planner(), nullptr); |
| runtime.check_for_memory_leak(); |
| |
| // We guarantee that the runtime is still usable after the crash. |
| // Run again to verify this. |
| compareResultsWithJIT(runtime, graph, args_no_crash); |
| EXPECT_NE(runtime.get_memory_planner(), nullptr); |
| } |
| |
| TEST(StaticRuntime, ModelCrashOnSecondRun) { |
| const auto src = R"JIT( |
| graph(%0: Tensor, %throw: bool): |
| %1: Tensor = aten::mul(%0, %0) |
| static_runtime_tests::maybe_throw(%throw) |
| %2: Tensor = aten::mul(%1, %1) |
| %3: Tensor = aten::mul(%2, %2) |
| return (%3) |
| )JIT"; |
| |
| auto graph = getGraphFromIR(src); |
| auto static_module = StaticModule(graph); |
| auto& runtime = static_module.runtime(); |
| |
| std::vector<IValue> args_crash{at::randn({1}), true}; |
| std::vector<IValue> args_no_crash{at::randn({1}), false}; |
| runtime(args_no_crash, {}); |
| EXPECT_NE(runtime.get_memory_planner(), nullptr); |
| runtime.check_for_memory_leak(); |
| |
| EXPECT_THROW(runtime(args_crash, {}), std::runtime_error); |
| runtime.check_for_memory_leak(); |
| |
| // We guarantee that the runtime is still usable after the crash. |
| // Run again to verify this. |
| compareResultsWithJIT(runtime, graph, args_no_crash); |
| } |
| |
| TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) { |
| const auto src = R"JIT( |
| graph(%0: Tensor): |
| %1: Tensor = aten::mul(%0, %0) |
| %2: Tensor = aten::mul(%1, %1) |
| %3: bool = prim::Constant[value=1]() |
| %4: Tensor = static_runtime::select_tensor(%1, %2, %3) |
| static_runtime_tests::maybe_throw(%3) |
| return (%4) |
| )JIT"; |
| auto graph = getGraphFromIR(src); |
| auto static_module = StaticModule(graph); |
| auto& runtime = static_module.runtime(); |
| |
| std::vector<IValue> args{at::randn({1})}; |
| EXPECT_THROW(runtime(args), std::runtime_error); |
| } |
| |
| TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) { |
| const auto src = R"JIT( |
| graph(%0: Tensor, %1: Tensor): |
| %2: bool = prim::Constant[value=1]() |
| %3: Tensor = static_runtime::select_tensor(%0, %1, %2) |
| static_runtime_tests::maybe_throw(%2) |
| return (%3) |
| )JIT"; |
| auto graph = getGraphFromIR(src); |
| auto static_module = StaticModule(graph); |
| auto& runtime = static_module.runtime(); |
| |
| std::vector<IValue> args{at::randn({1}), at::randn({1})}; |
| EXPECT_THROW(runtime(std::move(args)), std::runtime_error); |
| } |
| |
| TEST(StaticRuntime, ReplaceWithMaybeCopy) { |
| const std::string to = R"IR( |
| graph(%0 : Tensor): |
| %1: int = prim::Constant[value=4]() |
| %2: bool = prim::Constant[value=0]() |
| %3: None = prim::Constant() |
| %res : Tensor = aten::to(%0, %1, %2, %2, %3) |
| return (%res) |
| )IR"; |
| |
| at::Tensor a = at::tensor({1.1, 2.2, 3.3, 4.0}, at::ScalarType::Float); |
| std::vector<IValue> args{a}; |
| auto g = std::make_shared<torch::jit::Graph>(); |
| torch::jit::parseIR(to, g.get()); |
| |
| // Jit Interpreter. |
| Stack stack(args); |
| torch::jit::GraphExecutor graph_exec(g, ""); |
| graph_exec.run(stack); |
| ASSERT_EQ(stack.size(), 1); |
| auto expected = stack[0].toTensor(); |
| |
| // Static Runtime. |
| torch::jit::StaticModule smodule(g); |
| auto actual = smodule(args, {}).toTensor(); |
| smodule.runtime().check_for_memory_leak(); |
| |
| EXPECT_TRUE(expected.equal(actual)); |
| EXPECT_FALSE(hasProcessedNodeWithName(smodule, "aten::to")); |
| EXPECT_TRUE( |
| hasProcessedNodeWithName(smodule, "static_runtime::to_maybe_copy_out")); |
| } |
| |
| TEST(StaticRuntime, Int) { |
| const auto src = R"JIT( |
| def forward(self, x): |
| return int(x) + int(x) |
| )JIT"; |
| std::vector<IValue> args{at::tensor({3.14})}; |
| testStaticRuntime(src, args); |
| } |
| |
| TEST(StaticRuntime, ReturnConstant) { |
| const auto src = R"JIT( |
| def forward(self): |
| return 1 |
| )JIT"; |
| |
| testStaticRuntime(src, {}); |
| } |
| |
| TEST(StaticRuntime, SimpleIf) { |
| const auto src = R"JIT( |
| def forward(self, cond: bool, x): |
| if cond: |
| return torch.mul(x, 42).clone() |
| else: |
| return x.clone() |
| )JIT"; |
| |
| std::vector<IValue> args_false{false, at::randn({1})}; |
| std::vector<IValue> args_true{true, at::randn({1})}; |
| std::vector<IValue> args_big_tensor{true, at::randn({3, 3, 3})}; |
| |
| testStaticRuntime(src, args_false); |
| testStaticRuntime(src, args_true); |
| testStaticRuntime(src, args_true, args_big_tensor); |
| } |
| |
| TEST(StaticRuntime, NestedIf) { |
| const auto src = R"JIT( |
| def forward(self, cond1: bool, cond2: bool, x): |
| y = x * 42 |
| if cond1: |
| y = y * y |
| if cond2: |
| y += x |
| else: |
| if cond2: |
| return x.clone() |
| |
| return y.clone() |
| )JIT"; |
| |
| for (auto cond1 : {true, false}) { |
| for (auto cond2 : {true, false}) { |
| std::vector<IValue> args1{cond1, cond2, at::randn({1})}; |
| std::vector<IValue> args2{cond1, cond2, at::randn({3, 3, 3})}; |
| testStaticRuntime(src, args1, args2); |
| } |
| } |
| } |
| |
| TEST(StaticRuntime, DeeplyNestedIf) { |
| const auto src = R"JIT( |
| def forward(self, cond1: bool, cond2: bool, cond3: bool, x): |
| y = x * 42 |
| if cond1: |
| y = y * y |
| if cond2: |
| y += x |
| |
| if cond2 and cond3: |
| y += 1 |
| |
| if cond2: |
| if cond3: |
| y += 2 |
| else: |
| y = y * y |
| y += 4 |
| else: |
| if cond2: |
| return x.clone() |
| if cond3 or cond2: |
| y += 42 |
| |
| return y.clone() |
| )JIT"; |
| |
| for (auto cond1 : {true, false}) { |
| for (auto cond2 : {true, false}) { |
| for (auto cond3 : {true, false}) { |
| std::vector<IValue> args1{cond1, cond2, cond3, at::randn({1})}; |
| std::vector<IValue> args2{cond1, cond2, cond3, at::randn({3, 3, 3})}; |
| testStaticRuntime(src, args1, args2); |
| } |
| } |
| } |
| } |
| |
| TEST(StaticRuntime, BasicForLoop) { |
| const auto src = R"JIT( |
| def forward(self, x, loop_max: int): |
| y = x.clone() |
| for i in range(loop_max): |
| y += 1 |
| return y |
| )JIT"; |
| |
| std::vector<IValue> args1{at::randn({1}), 10}; |
| std::vector<IValue> args2{at::randn({3, 3, 3}), 10}; |
| |
| testStaticRuntime(src, args1, args2); |
| } |
| |
| TEST(StaticRuntime, BasicWhileLoop) { |
| const auto src = R"JIT( |
| def forward(self, x, loop_max: int): |
| y = x.clone() |
| loop_count = 0 |
| while loop_count < loop_max: |
| y += 1 |
| loop_count += 1 |
| return y |
| )JIT"; |
| |
| std::vector<IValue> args1{at::randn({1}), 10}; |
| std::vector<IValue> args2{at::randn({3, 3, 3}), 10}; |
| |
| testStaticRuntime(src, args1, args2); |
| } |
| |
| TEST(StaticRuntime, NestedLoops) { |
| const auto src = R"JIT( |
| def forward(self, x, loop_max: int): |
| y = x.clone() |
| even: List[int] = [] |
| odd: List[int] = [] |
| |
| for i in range(loop_max): |
| if i % 2: |
| odd.append(i) |
| else: |
| even.append(i) |
| |
| for j in range(i): |
| y += 1 |
| |
| return y, even, odd |
| )JIT"; |
| |
| std::vector<IValue> args1{at::randn({1}), 10}; |
| std::vector<IValue> args2{at::randn({3, 3, 3}), 10}; |
| |
| testStaticRuntime(src, args1, args2); |
| } |
| |
| TEST(StaticRuntime, TupleIndex) { |
| const auto src = R"JIT( |
| def forward(self, idx: int, tup: Tuple[int, int]): |
| a = tup[idx] |
| return a * a |
| )JIT"; |
| const auto tuple = c10::ivalue::Tuple::create({1, 2}); |
| testStaticRuntime(src, {1, tuple}, {-1, tuple}); |
| |
| torch::jit::Module mod("module"); |
| mod.define(src); |
| StaticModule smod(mod); |
| EXPECT_THROW(smod({100, tuple}), std::out_of_range); |
| } |
| |
| TEST(StaticRuntime, RaiseException) { |
| const auto src = R"IR( |
| graph(%str: str): |
| %none: NoneType = prim::Constant() |
| prim::RaiseException(%str, %none) |
| return (%none) |
| )IR"; |
| auto graph = getGraphFromIR(src); |
| StaticModule smod(graph); |
| const auto msg = "exception message"; |
| EXPECT_THROW( |
| { |
| try { |
| smod({msg}); |
| } catch (const std::runtime_error& e) { |
| EXPECT_STREQ(msg, e.what()); |
| throw; |
| } |
| }, |
| std::runtime_error); |
| } |
| |
| TEST(StaticRuntime, Uninitialized) { |
| const auto src = R"IR( |
| graph(): |
| %0: int = prim::Uninitialized() |
| return (%0) |
| )IR"; |
| auto graph = getGraphFromIR(src); |
| StaticModule smod(graph); |
| const auto ret = smod({}); |
| // If a and b are both uninitialized, then a != b. So just check that the type |
| // is Any |
| EXPECT_EQ(ret.type()->kind(), c10::TypeKind::AnyType); |
| } |
| |
| TEST(StaticRuntime, Format) { |
| const auto src = R"JIT( |
| def forward(self, arg1: int, arg2: Tensor, arg3: str): |
| a = "arg1: {}, arg2: {}, arg3: {}".format(arg1, arg2, arg3) |
| return a[::] |
| )JIT"; |
| testStaticRuntime(src, {1, at::randn({3}), "str"}); |
| } |
| |
| TEST(StaticRuntime, Device) { |
| const auto src = R"JIT( |
| def forward(self, x): |
| return x.device, x.device |
| )JIT"; |
| testStaticRuntime(src, {at::tensor({1})}); |
| } |
| |
| TEST(StaticRuntime, Dtype) { |
| const auto src = R"JIT( |
| def forward(self, x, y): |
| return x.dtype, y.dtype |
| )JIT"; |
| testStaticRuntime( |
| src, {at::tensor({1}, at::kLong), at::tensor({1}, at::kFloat)}); |
| } |
| |
| TEST(StaticRuntime, Dim) { |
| const auto src = R"JIT( |
| def forward(self, x, y): |
| return x.dim(), y.dim() |
| )JIT"; |
| testStaticRuntime(src, {at::randn({2, 2}), at::randn({1})}); |
| } |
| |
| TEST(StaticRuntime, Not) { |
| const auto src = R"JIT( |
| def forward(self, x: bool, y: bool): |
| return not x, not y |
| )JIT"; |
| testStaticRuntime(src, {true, false}); |
| } |
| |
| TEST(StaticRuntime, Bool) { |
| const auto src = R"JIT( |
| def forward(self, x: Tensor, y: int, z: float): |
| return bool(x), bool(y), bool(z) |
| )JIT"; |
| testStaticRuntime(src, {at::randn({1}), 0, 1.151}, {at::zeros({1}), 1, 0.0}); |
| } |
| |
| TEST(StaticRuntime, IsCuda) { |
| const auto src = R"JIT( |
| def forward(self, x: Tensor, y: Tensor): |
| return x.is_cuda, y.is_cuda |
| )JIT"; |
| testStaticRuntime(src, {at::randn({1}), at::randn({1})}); |
| } |
| |
| TEST(StaticRuntime, ToList) { |
| const auto src = R"JIT( |
| graph(%x: Tensor): |
| %type: int = prim::Constant[value=1]() |
| %dim: int = aten::dim(%x) |
| %ret: float[] = prim::tolist(%x, %dim, %type) |
| return (%ret) |
| )JIT"; |
| testStaticRuntime(src, {at::randn({2, 2})}); |
| } |
| |
| TEST(StaticRuntime, IfThenElse) { |
| const auto src = R"IR( |
| graph(%cond: bool, %a: Tensor, %b: Tensor): |
| %none: NoneType = prim::Constant() |
| %c: Tensor = prim::IfThenElse(%cond, %a, %b) |
| %d: Tensor = aten::clone(%c, %none) |
| return (%d) |
| )IR"; |
| |
| std::vector<IValue> args1{true, at::randn({1}), at::randn({1})}; |
| std::vector<IValue> args2{false, at::randn({1}), at::randn({1})}; |
| |
| testStaticRuntime(src, args1); |
| testStaticRuntime(src, args2); |
| } |
| |
| TEST(StaticRuntime, EmptyIfBlock) { |
| const auto src = |
| R"JIT( |
| def forward(self, cond: bool, a: Tensor, b: Tensor): |
| l = [] |
| if cond: |
| l.append((a + b).clone()) |
| return l |
| )JIT"; |
| |
| testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})}); |
| testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})}); |
| } |
| |
| TEST(StaticRuntime, EmptyNestedIfBlock) { |
| const auto src = |
| R"JIT( |
| def forward(self, cond: bool, a: Tensor, b: Tensor): |
| l = [] |
| if cond: |
| if cond: |
| l.append((a + b).clone()) |
| return l |
| )JIT"; |
| |
| testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})}); |
| testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})}); |
| } |
| |
| TEST(StaticRuntime, StackEmpty) { |
| const auto src = R"JIT( |
| def forward(self): |
| x = torch.stack([]) |
| return x |
| )JIT"; |
| |
| torch::jit::Module mod("mod"); |
| mod.define(src); |
| |
| torch::jit::StaticModule smod(mod); |
| EXPECT_THROW(smod({}), c10::Error); |
| } |
| |
| TEST(StaticRuntime, ConcatEmpty) { |
| const auto src = R"JIT( |
| def forward(self): |
| x = torch.concat([]) |
| return x |
| )JIT"; |
| |
| torch::jit::Module mod("mod"); |
| mod.define(src); |
| |
| torch::jit::StaticModule smod(mod); |
| EXPECT_THROW(smod({}), c10::Error); |
| } |
| |
| TEST(StaticRuntime, IntImplicit) { |
| const auto src = R"IR( |
| graph(%a: Tensor): |
| %y: int = aten::IntImplicit(%a) |
| return (%y) |
| )IR"; |
| testStaticRuntime(src, {at::tensor({1}, at::kInt).squeeze()}); |
| } |
| |
| TEST(StaticRuntime, IntImplicit_ThrowOnBadInputs) { |
| const auto src = R"IR( |
| graph(%a: Tensor): |
| %y: int = aten::IntImplicit(%a) |
| return (%y) |
| )IR"; |
| auto graph = getGraphFromIR(src); |
| torch::jit::StaticModule smod(graph); |
| // Not 0D tensor |
| EXPECT_THROW(smod({at::tensor({1, 2}, at::kInt)}), std::runtime_error); |
| // Wrong dtype |
| EXPECT_THROW( |
| smod({at::tensor({1}, at::kFloat).squeeze()}), std::runtime_error); |
| } |
| |
| TEST(StaticRuntime, Select) { |
| const auto src = R"IR( |
| graph(%a: Tensor, %dim: int, %index: int): |
| %none: NoneType = prim::Constant() |
| %b: Tensor = aten::select(%a, %dim, %index) |
| %c: Tensor = aten::clone(%b, %none) |
| return (%c) |
| )IR"; |
| testStaticRuntime(src, {at::randn({2, 2}), 0, 1}); |
| } |
| |
| TEST(StaticRuntime, ReshapeAs) { |
| const auto src = R"JIT( |
| def forward(self, a, b): |
| return a.reshape_as(b).clone() |
| )JIT"; |
| testStaticRuntime(src, {at::randn({2, 2}), at::randn({4})}); |
| } |
| |
| TEST(StaticRuntime, MoveCtor) { |
| auto mod = getDeepAndWideSciptModel(); |
| std::vector<IValue> args{ |
| at::randn({1, 1, 32}), at::randn({1, 1, 32}), at::randn({1, 50})}; |
| |
| torch::jit::StaticModule smod(mod); |
| |
| torch::jit::StaticRuntime runtime(smod); |
| auto expected = runtime(args); |
| |
| torch::jit::StaticRuntime new_runtime(std::move(runtime)); |
| auto actual = new_runtime(args); |
| compareResults(expected, actual); |
| } |
| |
| TEST(StaticRuntime, SingleBlockIfReturnList) { |
| const auto src = R"JIT( |
| def forward(self, a, b, cond: bool): |
| lst = [] |
| if cond: |
| lst.append(a + b) |
| return lst |
| )JIT"; |
| std::vector<IValue> args1{at::randn({1}), at::randn({1}), true}; |
| std::vector<IValue> args2{at::randn({42, 42}), at::randn({42, 42}), false}; |
| testStaticRuntime(src, args1, args2); |
| } |
| |
| TEST(StaticRuntime, NestedBlockIfReturnList) { |
| const auto src = R"JIT( |
| def forward(self, a, b, cond1: bool, cond2: bool): |
| if cond1: |
| lst = [] |
| if cond2: |
| lst.append(a + b) |
| lst.append(a * b) |
| return lst |
| return [] |
| )JIT"; |
| std::vector<IValue> args1{at::randn({1}), at::randn({1}), true, true}; |
| std::vector<IValue> args2{ |
| at::randn({42, 42}), at::randn({42, 42}), true, false}; |
| testStaticRuntime(src, args1, args2); |
| } |
| |
| TEST(StaticRuntime, ClampNaNToNum) { |
| const auto src1 = R"JIT( |
| def forward(self, a): |
| return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone() |
| )JIT"; |
| |
| const auto src2 = R"JIT( |
| def forward(self, a, nan: float): |
| return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone() |
| )JIT"; |
| |
| const auto src3 = R"JIT( |
| def forward(self, a): |
| return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone() |
| )JIT"; |
| |
| auto a = at::tensor({ |
| std::numeric_limits<float>::quiet_NaN(), |
| std::numeric_limits<float>::infinity(), |
| -std::numeric_limits<float>::infinity(), |
| 0.0f, |
| 3.0f |
| }); |
| auto b = a.repeat({10, 5}); |
| |
| // Have to use_allclose even though all NaNs will be replaced - testStaticRuntime |
| // also checks inputs at the end to make sure they're not changed |
| testStaticRuntime(src1, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true); |
| testStaticRuntime(src1, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true); |
| |
| testStaticRuntime(src2, {a, 42.0}, {}, /*use_allclose=*/true, /*use_equalnan=*/true); |
| testStaticRuntime(src2, {a, 2.0}, {b, 1.0}, /*use_allclose=*/true, /*use_equalnan=*/true); |
| |
| testStaticRuntime(src3, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true); |
| testStaticRuntime(src3, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true); |
| |
| // Non-NNC path |
| testStaticRuntime(src1, {a.to(at::kDouble)}, {}, /*use_allclose=*/true, /*use_equalnan=*/true); |
| testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true); |
| } |
| |
| TEST(StaticRuntime, PrepackWeights) { |
| const std::string src = R"IR( |
| graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor): |
| %none: NoneType = prim::Constant() |
| %result: Tensor = fb::quantized_linear_unpacked_weight_v2(%input, %weight, %bias, %scale, %zero_point) |
| %dequantized: Tensor = aten::dequantize(%result) |
| return (%dequantized) |
| )IR"; |
| |
| auto graph = getGraphFromIR(src); |
| PrepackWeights(graph); |
| ASSERT_TRUE(graphHasOp(graph, "quantized::linear")); |
| ASSERT_TRUE(graphHasOp(graph, "quantized::linear_prepack")); |
| ASSERT_FALSE(graphHasOp(graph, "fb::quantized_linear_unpacked_weight_v2")); |
| |
| auto scale = at::tensor({2}, at::kFloat); |
| auto zero_point = at::tensor({3}, at::kLong); |
| |
| auto weight = |
| at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8); |
| auto input = |
| at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8); |
| auto args1 = std::vector<IValue>{input, weight, c10::nullopt, scale, zero_point}; |
| |
| auto weight_2 = |
| at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8); |
| auto input_2 = |
| at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8); |
| auto bias_2 = torch::randn({3}, torch::kFloat); |
| auto args2 = std::vector<IValue>{input, weight, bias_2, scale, zero_point}; |
| |
| testStaticRuntime(src, args1); |
| testStaticRuntime(src, args2); |
| } |
| |
| TEST(StaticRuntime, IfReturningTuple) { |
| const auto src = R"JIT( |
| def forward(self, x, y, cond: bool, idx: int): |
| if cond: |
| tup = (x, y) |
| else: |
| tup = (x, x) |
| return tup[idx] |
| )JIT"; |
| |
| std::vector<IValue> args{at::randn({3}), at::randn({3}), true, 0}; |
| testStaticRuntime(src, args); |
| } |