| #include <gtest/gtest.h> |
| |
| #include <torch/csrc/jit/ir/ir.h> |
| #include <torch/csrc/jit/ir/irparser.h> |
| #include <torch/csrc/jit/testing/file_check.h> |
| |
| #include <sstream> |
| #include <string> |
| |
| namespace torch { |
| namespace jit { |
| |
| /** \brief Parse IR from \p S, print the parsed graph and verify that the output |
| * string matches the original string. |
| * |
| * The function is sensitive to value naming and whitespace, so it should be |
| * used with care. Nevertheless, it helps to keep tests more compact. |
| */ |
| static void checkRoundtrip(const std::string& s) { |
| auto graph = std::make_shared<Graph>(); |
| parseIR(s, &*graph); |
| std::ostringstream ss; |
| ss << *graph; |
| std::string parsed = ss.str(); |
| |
| // Skip whitespace in the beginning of the input string. |
| int i = 0; |
| for (char c : s) { |
| if (!isspace(c)) { |
| break; |
| } |
| i++; |
| } |
| std::string original = s.substr(i, s.size()); |
| if (original != parsed) { |
| std::cerr << "Input:" << std::endl << original << std::endl; |
| std::cerr << "Parsed:" << std::endl << parsed << std::endl; |
| } |
| AT_ASSERT(original == parsed); |
| } |
| |
| TEST(IRParserTest, Basic) { |
| auto graph = std::make_shared<Graph>(); |
| std::unordered_map<std::string, Value*> vmap; |
| parseIR( |
| R"IR( |
| graph(%0 : Tensor, %1 : Tensor): |
| %2 : Tensor = foo::add(%0, %1) |
| %res, %3 = foo::mul(%0, %2) |
| %x, %y = foo::combine(%res, %2, %3) |
| return (%x, %y, %res))IR", |
| &*graph, |
| vmap); |
| |
| AT_ASSERT(graph->inputs().size() == 2); |
| AT_ASSERT(graph->outputs().size() == 3); |
| Value* x = graph->outputs()[0]; |
| Value* y = graph->outputs()[1]; |
| Value* res = graph->outputs()[2]; |
| Value* t0 = graph->inputs()[0]; |
| Value* t1 = graph->inputs()[1]; |
| AT_ASSERT(vmap["x"] == x); |
| AT_ASSERT(vmap["y"] == y); |
| AT_ASSERT(vmap["res"] == res); |
| AT_ASSERT(vmap["0"] == t0); |
| AT_ASSERT(vmap["1"] == t1); |
| AT_ASSERT(x->node() == y->node()); |
| Node* comb = x->node(); |
| Value* t2 = comb->inputs()[1]; |
| Value* t3 = comb->inputs()[2]; |
| AT_ASSERT(vmap["2"] == t2); |
| AT_ASSERT(vmap["3"] == t3); |
| AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine")); |
| AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y})); |
| AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3})); |
| Node* mul = res->node(); |
| AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul")); |
| AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2})); |
| AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3})); |
| Node* add = t2->node(); |
| AT_ASSERT(add->kind().toQualString() == std::string("foo::add")); |
| AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1})); |
| AT_ASSERT(add->outputs() == std::vector<Value*>({t2})); |
| } |
| |
| TEST(IRParserTest, NestedBlock) { |
| checkRoundtrip(R"IR( |
| graph(): |
| %0 : Tensor = a::a() |
| block0(): |
| %1 : Tensor = b::b() |
| block0(): |
| %2 : Tensor = c::c() |
| -> () |
| -> () |
| %3 : Tensor = d::d() |
| return (%3) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, If) { |
| checkRoundtrip(R"IR( |
| graph(%0 : Tensor, |
| %1 : Tensor, |
| %2 : Tensor): |
| %3 : int = prim::Constant[value=1]() |
| %4 : Tensor = aten::add(%0, %1, %3) |
| %5 : Tensor = prim::If(%2) |
| block0(): |
| %6 : int = prim::Constant[value=1]() |
| %7 : Tensor = aten::add(%1, %3, %6) |
| %8 : int = prim::Constant[value=1]() |
| %9 : Tensor = aten::add(%7, %3, %8) |
| -> (%9) |
| %10 : int = prim::Constant[value=1]() |
| %11 : Tensor = aten::add(%5, %3, %10) |
| return (%11) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, If2) { |
| checkRoundtrip(R"IR( |
| graph(%0 : Tensor, |
| %1 : Tensor, |
| %2 : Tensor): |
| %3 : int = prim::Constant[value=-1]() |
| %4 : Tensor = aten::add(%0, %1, %3) |
| %5 : Tensor = prim::If(%2) |
| block0(): |
| %6 : int = prim::Constant[value=1]() |
| %7 : Tensor = aten::add(%1, %3, %6) |
| %8 : int = prim::Constant[value=1]() |
| %9 : Tensor = aten::add(%7, %3, %8) |
| -> (%9) |
| %10 : int = prim::Constant[value=-987]() |
| %11 : Tensor = aten::add(%5, %3, %10) |
| return (%11) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, InferredTypeIsTensor) { |
| auto graph = std::make_shared<Graph>(); |
| parseIR( |
| R"IR( |
| graph(%a): |
| return (%a))IR", |
| &*graph); |
| AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get())); |
| } |
| |
| TEST(IRParserTest, ValueReuse) { |
| // Check that parser correctly handles values reusing the same name. |
| auto graph = std::make_shared<Graph>(); |
| parseIR( |
| R"IR( |
| graph(%x): |
| %x = a::a(%x) |
| %x = b::b(%x) |
| return (%x))IR", |
| &*graph); |
| Value* x0 = graph->inputs()[0]; |
| Value* x2 = graph->outputs()[0]; |
| Node* b = x2->node(); |
| Value* x1 = b->inputs()[0]; |
| Node* a = x1->node(); |
| AT_ASSERT(a->inputs() == std::vector<Value*>({x0})); |
| AT_ASSERT(a->outputs() == std::vector<Value*>({x1})); |
| AT_ASSERT(b->inputs() == std::vector<Value*>({x1})); |
| AT_ASSERT(b->outputs() == std::vector<Value*>({x2})); |
| } |
| |
| TEST(IRParserTest, Attributes) { |
| // Check that parser handles attributes and types. |
| checkRoundtrip( |
| R"IR( |
| graph(%0 : Tensor, |
| %1 : Tensor, |
| %2 : Tensor): |
| %3 : int, %4 : Tensor = qqq::qqq[i_asdf=2, f_asdf=3., s_asdf="hello", ss_asdf=["hello world", "bye bye"]](%0) |
| %5 : int, %6 : Tensor = ppp::ppp[i_asdf=2, f_asdf=3., s_asdf="\"\"\"\"\nhe\"llo", q=[3, 2, 4]](%0) |
| %7 : float = vvv::vvv[s_asdf="hello"](%0) |
| %8 : string = z::z() |
| return (%7) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, OptionalTypes) { |
| checkRoundtrip( |
| R"IR( |
| graph(%0 : Tensor, |
| %1 : Tensor, |
| %2 : Tensor): |
| %3 : int? = prim::Constant() |
| return (%3) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, StarTensor) { |
| checkRoundtrip( |
| R"IR( |
| graph(%0 : Tensor, |
| %1 : Tensor, |
| %2 : Tensor): |
| %3 : Float(*, *, *) = prim::Constant() |
| return (%3) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, UnshapedTensor) { |
| checkRoundtrip( |
| R"IR( |
| graph(%0 : Tensor, |
| %1 : Tensor, |
| %2 : Tensor): |
| %3 : Long() = prim::Constant() |
| return (%3) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, ShapedTensor) { |
| checkRoundtrip( |
| R"IR( |
| graph(%0 : Tensor, |
| %1 : Tensor, |
| %2 : Tensor): |
| %3 : Double(4, 4, 5) = prim::Constant() |
| return (%3) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, NestedContrainer) { |
| checkRoundtrip( |
| R"IR( |
| graph(): |
| %0 : float[] = prim::Constant[value=[1., 2., 3.]]() |
| %1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]() |
| %2 : (float[], str[]) = prim::TupleConstruct(%0, %1) |
| return (%2) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, MalformedShapeAnnotation) { |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
| EXPECT_ANY_THROW(checkRoundtrip( |
| R"IR( |
| graph(%0 : Tensor, |
| %1 : Tensor, |
| %2 : Tensor): |
| %3 : Double(4!, 4, 5) = prim::Constant() |
| return (%3) |
| )IR")); |
| } |
| |
| TEST(IRParserTest, FileCheck) { |
| auto graph = std::make_shared<Graph>(); |
| const std::string& text = |
| R"IR( |
| graph(%a): |
| # CHECK: return |
| return (%a))IR"; |
| |
| parseIR(text, &*graph); |
| AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get())); |
| torch::jit::testing::FileCheck().run(text, *graph); |
| } |
| |
| TEST(IRParserTest, Strides) { |
| auto graph = std::make_shared<Graph>(); |
| std::unordered_map<std::string, Value*> vmap; |
| parseIR( |
| R"IR( |
| graph(%a : Float(4, 5), |
| %b : Float(4, 5, strides=[5, 1]), |
| %c : Double(*, *)): |
| return (%a) |
| )IR", |
| &*graph, |
| vmap); |
| Value* a = graph->inputs()[0]; |
| Value* b = graph->inputs()[1]; |
| Value* c = graph->inputs()[2]; |
| |
| auto a_type = a->type()->cast<TensorType>(); |
| auto a_sizes = *a_type->sizes().concrete_sizes(); |
| auto a_strides = a_type->strides().concrete_sizes(); |
| AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5); |
| AT_ASSERT(a_strides == std::nullopt); |
| |
| auto b_type = b->type()->cast<TensorType>(); |
| auto b_sizes = *b_type->sizes().concrete_sizes(); |
| auto b_strides = *(b_type->strides().sizes()); |
| AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5); |
| AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1); |
| |
| auto c_type = c->type()->cast<TensorType>(); |
| AT_ASSERT(*c_type->sizes().size() == 2); |
| AT_ASSERT(c_type->sizes().concrete_sizes() == std::nullopt); |
| AT_ASSERT(c_type->strides().concrete_sizes() == std::nullopt); |
| } |
| |
| TEST(IRParserTest, MalformedStrides) { |
| auto graph = std::make_shared<Graph>(); |
| std::unordered_map<std::string, Value*> vmap; |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
| EXPECT_ANY_THROW(parseIR( |
| R"IR( |
| graph(%a : Float(4, strides=[5], 5)): |
| return (%a) |
| )IR", |
| &*graph, |
| vmap)); |
| } |
| |
| TEST(IRParserTest, TensorShapes) { |
| checkRoundtrip( |
| R"IR( |
| graph(%a : Float(4, 5), |
| %b : Float(4, 5, strides=[5, 1]), |
| %c : Double(*, *)): |
| return (%a) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, DeviceAndRequiresGradTensors) { |
| checkRoundtrip( |
| R"IR( |
| graph(%a : Float(*, *, device=cpu), |
| %b : Float(*, *, requires_grad=1), |
| %c : Long(5, 10, requires_grad=1, device=cpu), |
| %d : Float(5, requires_grad=0, device=cuda:2), |
| %e : Long(4, 3, 1, strides=[6, 2, 1], requires_grad=0, device=cuda:1), |
| %f : Float(), |
| %g : Float(device=cpu), |
| %h : Float(requires_grad=1), |
| %i : Float(requires_grad=0, device=cuda:1), |
| %j : Double(*, *, requires_grad=0)): |
| return (%a) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, ListConstant) { |
| auto graph = std::make_shared<Graph>(); |
| parseIR( |
| R"IR( |
| graph(): |
| %d : int[] = prim::Constant[value=[1,2,3]]() |
| return (%d) |
| )IR", |
| &*graph); |
| Node* n = graph->outputs()[0]->node(); |
| AT_ASSERT(n->kind() == prim::Constant); |
| AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival); |
| const auto& genericList = n->ival(attr::value).toList(); |
| std::vector<int> int_vals; |
| // NOLINTNEXTLINE(performance-implicit-conversion-in-loop) |
| for (const IValue& ival : genericList) { |
| int_vals.push_back(ival.toInt()); |
| } |
| AT_ASSERT(int_vals.size() == 3); |
| AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3); |
| } |
| |
| TEST(IRParserTest, PartialStarTensor) { |
| checkRoundtrip( |
| R"IR( |
| graph(%x : Float(10, *, 10)): |
| return (%x) |
| )IR"); |
| } |
| |
| TEST(IRParserTest, ComplexTensorAttributes) { |
| checkRoundtrip( |
| R"IR( |
| graph(%x : Double(*, 200, *, requires_grad=1, device=cuda:1), |
| %b : Float(5, *, requires_grad=1), |
| %c : Long(*, 10, device=cpu)): |
| return (%x) |
| )IR"); |
| } |
| } // namespace jit |
| } // namespace torch |