| #include <gtest/gtest.h> |
| |
| #include <test/cpp/jit/test_utils.h> |
| #include <torch/csrc/jit/ir/irparser.h> |
| |
| namespace torch { |
| namespace jit { |
| |
| TEST(IRTest, Attributes) { |
| Graph g; |
| auto one = attr::alpha; |
| auto two = attr::device; |
| auto three = attr::end; |
| auto four = attr::perm; |
| Node* n = g.create(Symbol::fromQualString("foo::bar")); |
| Node& attr = *n; |
| attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what"); |
| ASSERT_EQ(attr.f(one), 3.4); |
| ASSERT_EQ(attr.s(three), "what"); |
| ASSERT_EQ(attr.i(two), 5); |
| attr.s_(one, "no"); |
| ASSERT_EQ(attr.s(one), "no"); |
| ASSERT_TRUE(attr.hasAttribute(three)); |
| ASSERT_TRUE(!attr.hasAttribute(four)); |
| attr.ss_(two, {"hi", "now"}); |
| ASSERT_EQ(attr.ss(two).at(1), "now"); |
| |
| Node* n2 = g.create(Symbol::fromQualString("foo::baz")); |
| Node& attr2 = *n2; |
| attr2.copyAttributes(attr); |
| ASSERT_EQ(attr2.s(one), "no"); |
| attr2.f_(one, 5); |
| ASSERT_EQ(attr.s(one), "no"); |
| ASSERT_EQ(attr2.f(one), 5); |
| } |
| |
| TEST(IRTest, Blocks) { |
| auto g = std::make_shared<Graph>(); |
| const auto graph_string = R"IR( |
| graph(%a : Tensor, |
| %b : Tensor, |
| %c : Tensor): |
| %2 : int = prim::Constant[value=1]() |
| %3 : Tensor = aten::add(%a, %b, %2) |
| %5 : Tensor = prim::If(%c) |
| block0(): |
| %6 : int = prim::Constant[value=1]() |
| %7 : Tensor = aten::add(%3, %3, %6) |
| -> (%7) |
| block1(): |
| %8 : int = prim::Constant[value=1]() |
| %9 : Tensor = aten::add(%b, %3, %8) |
| %10 : int = prim::Constant[value=1]() |
| %11 : Tensor = aten::add(%9, %3, %10) |
| -> (%11) |
| %12 : int = prim::Constant[value=1]() |
| %13 : Tensor = aten::add(%5, %3, %12) |
| return (%13))IR"; |
| torch::jit::parseIR(graph_string, g.get()); |
| |
| g->lint(); |
| testing::FileCheck() |
| .check("add") |
| ->check("prim::If") |
| ->check("block0") |
| ->check("aten::add") |
| ->check("block1") |
| ->check_count("aten::add", 3) |
| ->run(*g); |
| |
| // Removes block0 of the conditional |
| for (auto* node : g->block()->nodes()) { |
| if (node->kind() == prim::If) { |
| node->eraseBlock(0); |
| break; |
| } |
| } |
| |
| testing::FileCheck() |
| .check("add") |
| ->check("prim::If") |
| ->check("block0") |
| ->check_not("block") |
| ->run(*g); |
| g->lint(); |
| // test recursive copy of blocks works |
| auto g2 = g->copy(); |
| testing::FileCheck() |
| .check("add") |
| ->check("prim::If") |
| ->check("block0") |
| ->check_not("block") |
| ->run(*g2); |
| } |
| |
| TEST(IRTest, CommonAncestor) { |
| std::string input_str = R"( |
| graph(%x : Tensor, |
| %a.1 : bool, |
| %b.1 : bool, |
| %c.1 : bool): |
| %4 : int = prim::If(%a.1) |
| block0(): |
| %5 : int = prim::If(%b.1) |
| block0(): |
| %6 : int = prim::Constant[value=2]() |
| -> (%6) |
| block1(): |
| %7 : int = prim::Constant[value=3]() |
| -> (%7) |
| -> (%5) |
| block1(): |
| %8 : int = prim::If(%c.1) |
| block0(): |
| %9 : int = prim::Constant[value=4]() |
| -> (%9) |
| block1(): |
| %10 : int = prim::Constant[value=5]() |
| -> (%10) |
| -> (%8) |
| return (%4) |
| )"; |
| |
| torch::jit::Graph g; |
| std::unordered_map<std::string, torch::jit::Value*> name_to_value; |
| torch::jit::parseIR(input_str, &g, name_to_value); |
| |
| std::vector<std::string> value_names{"6", "7", "9", "10"}; |
| std::unordered_set<std::string> value_names_set( |
| value_names.begin(), value_names.end()); |
| |
| /* clang-format off */ |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
| int ref_blocks_from_graph[4][4] = { |
| /* (6, 6), (6, 7), (6, 9), (6, 10) */ |
| { 2, 1, 0, 0 }, |
| /* (7, 6), (7, 7), (7, 9), (7, 10) */ |
| { 1, 2, 0, 0 }, |
| /* (9, 6), (9, 7), (9, 9), (9, 10) */ |
| { 0, 0, 2, 1, }, |
| /* (10, 6),(10, 7),(10, 9),(10, 10) */ |
| { 0, 0, 1, 2 } |
| }; |
| /* clang-format on */ |
| |
| for (size_t i = 0; i < value_names.size(); ++i) { |
| Value* i_val = name_to_value[value_names[i]]; |
| for (size_t j = 0; j < value_names.size(); ++j) { |
| Value* j_val = name_to_value[value_names[j]]; |
| Block* common_ancestor = |
| i_val->node()->findCommonAncestorBlockWith(j_val->node()); |
| int blocks_from_graph_block = |
| common_ancestor->param_node()->blocksFromGraphBlock(); |
| ASSERT_EQ(blocks_from_graph_block, ref_blocks_from_graph[i][j]); |
| } |
| } |
| } |
| |
| TEST(IRTest, OperatorMap) { |
| OperatorMap<int> op_map; |
| const char* literal1 = |
| "aten::dropout(Tensor input, float p, bool train) -> Tensor"; |
| const char* literal2 = |
| "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor"; |
| const char* literal3 = |
| "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor"; |
| const char* literal4 = |
| "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor"; |
| const char* literal5 = |
| "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor"; |
| const char* literal6 = |
| "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor"; |
| std::shared_ptr<Operator> op1 = getOperatorForLiteral(literal1); |
| std::shared_ptr<Operator> op2 = getOperatorForLiteral(literal2); |
| std::shared_ptr<Operator> op3 = getOperatorForLiteral(literal3); |
| std::shared_ptr<Operator> op4 = getOperatorForLiteral(literal4); |
| std::shared_ptr<Operator> op5 = getOperatorForLiteral(literal5); |
| std::shared_ptr<Operator> op6 = getOperatorForLiteral(literal6); |
| op_map.insert(op1, 1); |
| op_map.insert({{op2, 2}, {op3, 3}}); |
| op_map.insert({{op4, 4}, {op5, 5}}); |
| op_map.insert(op6, 6); |
| ASSERT_TRUE(op_map.contains(*op1)); |
| ASSERT_TRUE(op_map.contains(*op2)); |
| ASSERT_TRUE(op_map.contains(*op3)); |
| ASSERT_TRUE(op_map.contains(*op4)); |
| ASSERT_TRUE(op_map.contains(*op5)); |
| ASSERT_TRUE(op_map.contains(*op6)); |
| op_map.erase(op6); |
| op_map.erase(op3); |
| op_map.erase(op1); |
| ASSERT_FALSE(op_map.contains(*op1)); |
| ASSERT_FALSE(op_map.contains(*op3)); |
| ASSERT_FALSE(op_map.contains(*op6)); |
| op_map.insert(op1, 1); |
| ASSERT_TRUE(op_map.contains(*op1)); |
| c10::optional<int> o1 = op_map.find(*op1); |
| ASSERT_TRUE(o1.has_value()); |
| c10::optional<int> o2 = op_map.find(*op2); |
| ASSERT_TRUE(o2.has_value()); |
| c10::optional<int> o3 = op_map.find(*op3); |
| ASSERT_FALSE(o3.has_value()); |
| c10::optional<int> o4 = op_map.find(*op4); |
| ASSERT_TRUE(o4.has_value()); |
| c10::optional<int> o5 = op_map.find(*op5); |
| ASSERT_TRUE(o5.has_value()); |
| c10::optional<int> o6 = op_map.find(*op6); |
| ASSERT_FALSE(o6.has_value()); |
| } |
| |
| } // namespace jit |
| } // namespace torch |