blob: 3272d4096c7a33d4e969c4cea14f16487f4f48a1 [file] [log] [blame]
#include <gtest/gtest.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/TensorOperators.h>
namespace torch {
namespace jit {
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}
// Fixture to set up a graph and make assertions clearer
class TopologicalMoveTest : public ::testing::Test {
protected:
TopologicalMoveTest() {
createGraph();
aliasDb = torch::make_unique<AliasDb>(graph);
}
// Nodes are named after their output.
// e.g. "a" is an alias for "the node that outputs the value `a`"
void createGraph() {
graph = std::make_shared<Graph>();
createNode("a", {});
createNode("b", {"a"});
createNode("c", {});
createNode("d", {"a", "b"});
createNode("e", {"c", "b"});
createNode("f", {"e"});
createNode("g", {"e"});
createNode("h", {"g"});
createNode("i", {"g"});
createNode("j", {"i"});
createNode("k", {"i"});
createNode("l", {"a"});
createNode("m", {}, {"l"}); // block depends on l
createNode("n", {"m"});
createNode("o", {"n"});
createNode("p", {});
createNode("q", {});
createNode("r", {"q"});
createNode("s", {"q"});
graph->lint();
}
void createNode(
const std::string& name,
const std::vector<std::string>& inputNames,
const std::vector<std::string>& blockInputNames = {}) {
std::vector<Value*> inputs;
for (const auto& name_ : inputNames) {
// NOLINTNEXTLINE(performance-inefficient-vector-operation)
inputs.push_back(nodes.at(name_)->output());
}
auto node = graph->appendNode(graph->create(prim::AutogradZero, inputs));
node->output()->setDebugName(name);
nodes[name] = node;
if (blockInputNames.size() != 0) {
node->addBlock();
std::vector<Value*> blockDeps;
for (const auto& name_ : blockInputNames) {
// NOLINTNEXTLINE(performance-inefficient-vector-operation)
blockDeps.push_back(nodes.at(name_)->output());
}
auto block = node->blocks().at(0);
block->appendNode(graph->create(prim::AutogradZero, blockDeps));
}
}
bool moveBeforeTopologicallyValid(
const std::string& toInsert,
const std::string& insertPoint) {
std::function<bool(Node*, Node*)> func =
[this](Node* toInsert, Node* insertPoint) {
return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint);
};
return moveWithChecks(toInsert, insertPoint, func);
}
bool moveAfterTopologicallyValid(
const std::string& toInsert,
const std::string& insertPoint) {
std::function<bool(Node*, Node*)> func =
[this](Node* toInsert, Node* insertPoint) {
return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint);
};
return moveWithChecks(toInsert, insertPoint, func);
}
bool moveWithChecks(
const std::string& toInsert,
const std::string& insertPoint,
std::function<bool(Node*, Node*)> func) {
auto n = nodes.at(toInsert);
auto insert = nodes.at(insertPoint);
bool isAfter = n->isAfter(insert);
std::vector<Node*> originalOrdering;
Node* original = isAfter ? n->next() : n->prev();
auto curNode = original;
while (curNode != n->owningBlock()->return_node()) {
originalOrdering.push_back(curNode);
if (isAfter) {
curNode = curNode->next();
} else {
curNode = curNode->prev();
}
}
const auto couldMove = func(n, insert);
// Check the graph is okay
graph->lint();
// If this is the picture of nodes
// <some nodes> ... toInsert ... <some more nodes> ... insertPoint
// ^----------^ check that these nodes haven't moved
curNode = original;
size_t idx = 0;
while (curNode != n->owningBlock()->return_node()) {
EXPECT_TRUE(originalOrdering[idx] == curNode);
if (isAfter) {
curNode = curNode->next();
} else {
curNode = curNode->prev();
}
idx++;
}
return couldMove;
}
void checkPostCondition(
const std::string& toInsert,
const std::string& insertPoint,
bool after) {
if (after) {
EXPECT_EQ(nodes.at(toInsert)->prev(), nodes.at(insertPoint));
} else {
EXPECT_EQ(nodes.at(toInsert)->next(), nodes.at(insertPoint));
}
}
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::shared_ptr<Graph> graph;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::unique_ptr<AliasDb> aliasDb;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::unordered_map<std::string, Node*> nodes;
};
TEST_F(TopologicalMoveTest, SplitsDeps) {
// `r` depends on `q` and sits between `q` and `s` (see createGraph). Moving
// `q` directly before `s` therefore cannot carry `r` along as one unit: since
// `r` must stay after `q`, the postcondition (`q` immediately before `s`)
// implies `r` ends up after `s`. Checks that this split yields a valid graph.
EXPECT_TRUE(moveBeforeTopologicallyValid("q", "s"));
checkPostCondition("q", "s", false);
}
// Move after
TEST_F(TopologicalMoveTest, MoveAfterBackwardSimple) {
// Simple move backward
EXPECT_TRUE(moveAfterTopologicallyValid("c", "a"));
checkPostCondition("c", "a", true);
}
TEST_F(TopologicalMoveTest, MoveAfterBackwardInvalid) {
// simple invalid move backward: `d` reads `b`, which would end up after it
EXPECT_FALSE(moveAfterTopologicallyValid("d", "a"));
}
TEST_F(TopologicalMoveTest, MoveAfterNoOp) {
// doesn't actually move anything
EXPECT_TRUE(moveAfterTopologicallyValid("f", "e"));
checkPostCondition("f", "e", true);
}
TEST_F(TopologicalMoveTest, MoveAfterBackwardMultipleDeps) {
// move backward with multiple dependencies
EXPECT_TRUE(moveAfterTopologicallyValid("e", "c"));
checkPostCondition("e", "c", true);
}
TEST_F(TopologicalMoveTest, MoveAfterBackwardNonZeroWorkingSet) {
// Move backward with non-zero working set
EXPECT_TRUE(moveAfterTopologicallyValid("k", "f"));
checkPostCondition("k", "f", true);
}
TEST_F(TopologicalMoveTest, MoveAfterForwardSimple) {
// Simple move forward
EXPECT_TRUE(moveAfterTopologicallyValid("c", "d"));
checkPostCondition("c", "d", true);
}
TEST_F(TopologicalMoveTest, MoveAfterForwardNonZeroWorkingSet) {
// Move forward with non-zero working set
EXPECT_TRUE(moveAfterTopologicallyValid("f", "l"));
checkPostCondition("f", "l", true);
}
// Move before
TEST_F(TopologicalMoveTest, MoveBeforeForwardSimple) {
// Simple move forward
EXPECT_TRUE(moveBeforeTopologicallyValid("b", "d"));
checkPostCondition("b", "d", false);
}
TEST_F(TopologicalMoveTest, MoveBeforeBackwardSimple) {
// Simple move backward
EXPECT_TRUE(moveBeforeTopologicallyValid("c", "a"));
checkPostCondition("c", "a", false);
}
TEST_F(TopologicalMoveTest, MoveBeforeNoOp) {
// doesn't actually move anything
EXPECT_TRUE(moveBeforeTopologicallyValid("a", "b"));
checkPostCondition("a", "b", false);
}
TEST_F(TopologicalMoveTest, MoveBeforeForwardWithDeps) {
// move forward with deps
EXPECT_TRUE(moveBeforeTopologicallyValid("f", "m"));
checkPostCondition("f", "m", false);
}
TEST_F(TopologicalMoveTest, MoveBeforeBackwardWithDeps) {
// move backward with deps
EXPECT_TRUE(moveBeforeTopologicallyValid("l", "f"));
checkPostCondition("l", "f", false);
}
// check that dependencies in blocks are recognized: `m`'s sub-block reads
// `l`, so `l` may not be placed after `m` (or after `n`, which reads `m`)
TEST_F(TopologicalMoveTest, DepsDisallowMove) {
EXPECT_FALSE(moveAfterTopologicallyValid("l", "m"));
EXPECT_FALSE(moveBeforeTopologicallyValid("m", "l"));
EXPECT_FALSE(moveAfterTopologicallyValid("n", "l"));
EXPECT_FALSE(moveBeforeTopologicallyValid("l", "n"));
}
// Test that moveAfter(n) and moveBefore(n->next()) are not necessarily
// equivalent. Here, the dependency ordering is n -> o -> p. So we can't
// move `n` after `o`, but we can move `n` before `p` (which pushes `o` after
// `p`)
// NOTE(review): as written, the second assertion moves `o` (which already
// sits directly before `p` in createGraph), not `n`, and `p` has no inputs,
// so it only checks that this no-op move is accepted. Confirm whether
// moveBeforeTopologicallyValid("n", "p") was intended.
TEST_F(TopologicalMoveTest, MoveAfterBeforeWithDeps) {
EXPECT_FALSE(moveAfterTopologicallyValid("n", "o"));
EXPECT_TRUE(moveBeforeTopologicallyValid("o", "p"));
checkPostCondition("o", "p", false);
}
namespace {
Node* insertIf(
Graph& g,
Value* condValue,
std::function<std::vector<Value*>()> trueInst,
std::function<std::vector<Value*>()> falseInst) {
auto if_ = g.insertNode(g.create(prim::If, 0));
if_->addInput(condValue); // condition value
auto trueBlock = if_->addBlock();
auto falseBlock = if_->addBlock();
{
// Mutate in true block
WithInsertPoint g(trueBlock);
auto outputs = trueInst();
for (auto output : outputs) {
trueBlock->registerOutput(output);
}
}
{
WithInsertPoint g(falseBlock);
auto outputs = falseInst();
for (auto output : outputs) {
falseBlock->registerOutput(output);
}
}
EXPECT_TRUE(trueBlock->outputs().size() == falseBlock->outputs().size());
for (auto output : trueBlock->outputs()) {
if_->addOutput()->setType(output->type());
}
return if_;
}
template <class Exception, class Functor>
inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
try {
std::forward<Functor>(functor)();
} catch (const Exception& e) {
if (std::string(e.what()).find(expectMessageContains) ==
std::string::npos) {
AT_ERROR(
"Expected error message to contain \"",
expectMessageContains,
"\" but error message was: ",
e.what());
}
return;
}
AT_ERROR(
"Expected to throw exception containing \"",
expectMessageContains,
"\" but didn't throw");
}
} // namespace
// A node that reads a value cannot be moved across an in-place write
// (aten::add_) to that value or to a value it may alias.
TEST(AliasAnalysisTest, AliasingMutationBlocksMoves) {
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->addInput();
// addsB = b + b
// c = a + b
// a += b
// d = c + c
auto addsB = graph->insert(aten::add, {b, b});
auto c = graph->insert(aten::add, {a, b});
auto aMut = graph->insert(aten::add_, {a, b});
auto d = graph->insert(aten::add, {c, c});
graph->lint();
AliasDb aliasDb(graph);
// Can't move past a mutation of a used value
EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(c->node(), aMut->node()));
EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(d->node(), c->node()));
// b should alias to a (since they are both inputs)
EXPECT_FALSE(
aliasDb.moveAfterTopologicallyValid(addsB->node(), aMut->node()));
EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(addsB->node(), c->node()));
graph->lint();
}
// Same idea with a view: an in-place write to a select() view blocks moving
// readers of the (possibly aliased) base across it.
// NOTE(review): `aliasesB` is a select() of `a`, not `b`; it counts as an
// alias of `b` because the graph inputs `a` and `b` may alias each other.
TEST(AliasAnalysisTest, AliasingMutationBlocksMoves2) {
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->addInput();
auto constant = graph->insertConstant(1);
auto fresh = graph->insert(aten::rand, {constant});
auto usesB = graph->insert(aten::add, {b, fresh});
auto aliasesB = graph->insert(aten::select, {a, constant, constant});
auto mutatesAliasOfB = graph->insert(aten::add_, {aliasesB, fresh});
graph->insert(aten::add, {fresh, aliasesB});
graph->lint();
AliasDb aliasDb(graph);
EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(
aliasesB->node(), mutatesAliasOfB->node()));
EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(
usesB->node(), mutatesAliasOfB->node()));
}
TEST(AliasAnalysisTest, SideEffectsBlockMoves) {
// Test moves across side effectful nodes
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto print1 = graph->insertNode(graph->create(prim::Print, {a}, 0));
// The guard stays alive for the rest of the test, so every later insertNode
// lands before print1 (see the pseudo-code listings below).
WithInsertPoint guard(print1);
auto print2 = graph->insertNode(graph->create(prim::Print, {a, a}, 0));
AliasDb aliasDb(graph);
// def foo(a):
// print2(a, a)
// print1(a)
// test moving across each other
EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(print2, print1));
EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(print1, print2));
// test moving where they already are
EXPECT_TRUE(aliasDb.moveBeforeTopologicallyValid(print2, print1));
EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(print1, print2));
graph->insertNode(graph->create(prim::MakeTestTensor, {}, 1));
AliasDb aliasDb2(graph);
// def foo(a):
// print2(a, a)
// non_side_effectful = makeTestTensor()
// print1(a)
// test moving with another (non-side-effectful) node between the two prints:
// even the moves that were no-ops above are now expected to be rejected
EXPECT_FALSE(aliasDb2.moveAfterTopologicallyValid(print2, print1));
EXPECT_FALSE(aliasDb2.moveBeforeTopologicallyValid(print2, print1));
EXPECT_FALSE(aliasDb2.moveAfterTopologicallyValid(print1, print2));
EXPECT_FALSE(aliasDb2.moveBeforeTopologicallyValid(print1, print2));
}
TEST(AliasAnalysisTest, MovingAcrossInnerBlocks) {
// Test moves across inner blocks
// a = rand(1)
// b = rand(1)
// if True:
// a.add_(b)
// c = a + b
auto graph = std::make_shared<Graph>();
auto constant = graph->insertConstant(1);
auto a = graph->insert(aten::rand, {constant});
auto b = graph->insert(aten::rand, {constant});
auto if_ = insertIf(
*graph,
constant,
[&]() -> std::vector<Value*> {
auto aMut = graph->insert(aten::add_, {a, b});
return {aMut};
},
[&]() -> std::vector<Value*> { return {a}; });
auto c = graph->insert(aten::add, {a, b});
graph->lint();
// we should not be able to move `c` before the if statement, since the
// if's true branch may write to `a`.
AliasDb aliasDb(graph);
EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(c->node(), if_));
}
// A None constant has no writers, even though a value unwrapped from it is
// passed to a later op.
TEST(AliasAnalysisTest, NoneHasNoWriters) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph():
%opt : Tensor? = prim::Constant()
%out : Tensor = prim::unchecked_unwrap_optional(%opt)
%ret.2 : Tensor = aten::div(%out, %out, %out)
return (%opt, %out, %ret.2)
)IR",
&*graph,
vmap);
AliasDb aliasDb(graph);
EXPECT_FALSE(aliasDb.hasWriters(vmap["opt"]->node()));
}
// safeToChangeAliasingRelationship must refuse values that escape the graph
// (inputs/outputs) or sit in the wildcard set, and allow local temporaries.
TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=0]()
%b : Tensor = aten::add(%x, %2, %3)
%c : Tensor = aten::add(%x, %2, %3)
%d : Tensor = aten::add(%x, %2, %3)
%e : Tensor = aten::add(%x, %2, %3)
%f : Tensor[] = prim::ListConstruct(%e)
%14 : (Tensor, Tensor) = prim::TupleConstruct(%b, %c)
return (%14)
)IR",
&*graph,
vmap);
AliasDb aliasDb(graph);
// x, b, c escape scope, so we can't introduce an aliasing relationship
EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(vmap["x"], vmap["b"]));
EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(vmap["b"], vmap["x"]));
EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(vmap["b"], vmap["c"]));
EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(vmap["c"], vmap["b"]));
// e aliases the wildcard set because it's contained in a list
EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(vmap["e"], vmap["x"]));
EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(vmap["x"], vmap["e"]));
// d is a temporary with no writers, safe to change aliasing relationship
// here
EXPECT_TRUE(aliasDb.safeToChangeAliasingRelationship(vmap["c"], vmap["d"]));
EXPECT_TRUE(aliasDb.safeToChangeAliasingRelationship(vmap["d"], vmap["c"]));
}
class BatchAndInstanceNormFixture
: public ::testing::TestWithParam<std::tuple<std::string, NodeKind, bool>> {
};
TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNorm) {
auto param = GetParam();
auto fnName = std::get<0>(param);
auto nodeKind = std::get<1>(param);
auto isTraining = std::get<2>(param);
std::string isTrainingStr = std::to_string((int)isTraining);
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor):
%none : NoneType = prim::Constant()
%training : bool = prim::Constant[value=)IR" +
isTrainingStr + R"IR(]()
%momentum : float = prim::Constant[value=1.0]()
%eps : float = prim::Constant[value=1.0e-9]()
%cudnn_enabled : bool = prim::Constant[value=0]()
%res : Tensor = )IR" +
fnName +
R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
return (%res)
)IR",
&*graph);
graph->lint();
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == nodeKind) {
break;
}
}
EXPECT_TRUE(n != nullptr);
AliasDb aliasDb(graph);
EXPECT_TRUE(aliasDb.hasWriters(n) == isTraining);
}
TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNormTrainingUnknown) {
auto param = GetParam();
auto fnName = std::get<0>(param);
auto nodeKind = std::get<1>(param);
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor, %training : bool):
%none : NoneType = prim::Constant()
%momentum : float = prim::Constant[value=1.0]()
%eps : float = prim::Constant[value=1.0e-9]()
%cudnn_enabled : bool = prim::Constant[value=0]()
%res : Tensor = )IR" +
fnName +
R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
return (%res)
)IR",
&*graph);
graph->lint();
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == nodeKind) {
break;
}
}
EXPECT_TRUE(n != nullptr);
AliasDb aliasDb(graph);
EXPECT_TRUE(aliasDb.hasWriters(n));
}
TEST_P(BatchAndInstanceNormFixture, BatchNormTrainingWithNoMeanOrVar) {
auto param = GetParam();
auto fnName = std::get<0>(param);
auto nodeKind = std::get<1>(param);
auto isTraining = std::get<2>(param);
std::string isTrainingStr = std::to_string((int)isTraining);
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%input : Tensor):
%none : NoneType = prim::Constant()
%training : bool = prim::Constant[value=)IR" +
isTrainingStr + R"IR(]()
%momentum : float = prim::Constant[value=1.0]()
%eps : float = prim::Constant[value=1.0e-9]()
%cudnn_enabled : bool = prim::Constant[value=0]()
%res : Tensor = )IR" +
fnName +
R"IR((%input, %none, %none, %none, %none, %training, %momentum, %eps, %cudnn_enabled)
return (%res)
)IR",
&*graph);
graph->lint();
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == nodeKind) {
break;
}
}
EXPECT_TRUE(n != nullptr);
AliasDb aliasDb(graph);
EXPECT_FALSE(aliasDb.hasWriters(n));
}
INSTANTIATE_TEST_SUITE_P(
AliasAnalysisTest,
BatchAndInstanceNormFixture,
::testing::Values(
std::make_tuple("aten::batch_norm", aten::batch_norm, false),
std::make_tuple("aten::instance_norm", aten::instance_norm, false),
std::make_tuple("aten::batch_norm", aten::batch_norm, true),
std::make_tuple("aten::instance_norm", aten::instance_norm, true)));
// Registers an op whose output aliases its input (schema `Tensor(a) ->
// Tensor(a)`) and checks writesToAlias for a pure op vs. an in-place op.
TEST(WriteTrackingTest, Basic) {
RegisterOperators reg({Operator(
"prim::creates_alias(Tensor(a) x) -> Tensor(a)",
[](Stack&) {},
aliasAnalysisFromSchema())});
const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->addInput();
// aten::add(%b, %b)
// aten::add_(%a, %b)
// prim::creates_alias(%a)
auto pureNode = graph->insert(aten::add, {b, b})->node();
auto writingNode = graph->insert(aten::add_, {a, b})->node();
auto node3 = graph->insert(creates_alias, {a})->node();
auto aAlias = node3->output();
graph->lint();
AliasDb aliasDb(graph);
EXPECT_TRUE(aliasDb.mayAlias(aAlias, a));
// graph inputs may alias each other
EXPECT_TRUE(aliasDb.mayAlias(a, b));
EXPECT_FALSE(
aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{a}));
EXPECT_FALSE(
aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{b}));
EXPECT_TRUE(
aliasDb.writesToAlias(writingNode, std::unordered_set<const Value*>{a}));
EXPECT_TRUE(aliasDb.writesToAlias(
writingNode, std::unordered_set<const Value*>{a, b}));
EXPECT_TRUE(aliasDb.writesToAlias(
writingNode, std::unordered_set<const Value*>{aAlias}));
}
// An in-place op (aten::relu_) is reported as mutable.
TEST(WriteTrackingTest, IsMutable) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%x: Tensor):
%b : Tensor = aten::relu_(%x)
return (%b)
)IR",
&*graph);
auto node_iter = graph->block()->nodes().begin();
auto relu = *node_iter;
AliasDb aliasDb(graph);
EXPECT_TRUE(aliasDb.isMutable(relu));
}
// An out-of-place op (aten::mul) is not reported as mutable.
TEST(WriteTrackingTest, IsImmutable) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%x: Tensor, %y : Tensor):
%b : Tensor = aten::mul(%x, %y)
return (%b)
)IR",
&*graph);
auto node_iter = graph->block()->nodes().begin();
auto mul = *node_iter;
AliasDb aliasDb(graph);
EXPECT_FALSE(aliasDb.isMutable(mul));
}
// The in-place aten::add_ node both is mutable and has writers.
TEST(WriteTrackingTest, HasWriters) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x: Tensor, %y : Tensor):
%c1 : int = prim::Constant[value=1]()
%b : Tensor = aten::add_(%x, %y, %c1)
return (%b)
)IR",
&*graph,
vmap);
auto add = vmap["b"]->node();
AliasDb aliasDb(graph);
EXPECT_TRUE(aliasDb.hasWriters(add));
EXPECT_TRUE(aliasDb.isMutable(add));
}
// %y is placed into a tuple, a dict and a list that are all returned, while
// %z is never used; checks mayContainAlias against each of them.
TEST(ContainerAliasingTest, MayContainAlias) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%inp: Tensor[]):
%x : str = prim::Constant[value="a"]()
%y : Tensor = prim::Constant()
%z : Tensor = prim::Constant()
%a : (Tensor) = prim::TupleConstruct(%y)
%b : Dict(str, Tensor) = prim::DictConstruct(%x, %y)
%c : Tensor[] = prim::ListConstruct(%y)
return (%a, %b, %c)
)IR",
&*graph,
vmap);
auto str_output = vmap["x"];
auto ten_output = vmap["y"];
auto local_var = vmap["z"];
AliasDb aliasDb(graph);
EXPECT_TRUE(graph->outputs().size() == 3);
for (auto out : graph->outputs()) {
EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, out));
EXPECT_FALSE(aliasDb.mayContainAlias(local_var, out));
}
EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, graph->inputs()));
EXPECT_FALSE(aliasDb.mayContainAlias(local_var, graph->inputs()));
EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, graph->outputs()));
EXPECT_TRUE(aliasDb.mayContainAlias(
at::ArrayRef<Value*>{ten_output}, graph->outputs()));
EXPECT_FALSE(aliasDb.mayContainAlias(str_output, graph->outputs()));
}
// Aliasing through aten::to: %b.1 may alias %a.1, but (per the assertions)
// neither the graph inputs nor the returned %c.1.
TEST(ContainerAliasingTest, MayContainAlias_cast) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%input.1 : Tensor):
%2 : NoneType = prim::Constant()
%3 : bool = prim::Constant[value=0]()
%4 : int = prim::Constant[value=6]()
%5 : int = prim::Constant[value=1]()
%a.1 : Tensor = aten::add(%input.1, %input.1, %5)
%b.1 : Tensor = aten::to(%a.1, %4, %3, %3, %2)
%c.1 : Tensor = aten::mul(%b.1, %b.1)
return (%c.1)
)IR",
&*graph,
vmap);
auto a = vmap["a.1"];
auto b = vmap["b.1"];
auto c = vmap["c.1"];
AliasDb aliasDb(graph);
EXPECT_TRUE(graph->outputs().size() == 1);
for (auto out : graph->outputs()) {
EXPECT_TRUE(aliasDb.mayContainAlias(c, out));
}
EXPECT_TRUE(aliasDb.mayContainAlias(a, b));
EXPECT_FALSE(aliasDb.mayContainAlias(b, graph->inputs()));
EXPECT_TRUE(aliasDb.mayContainAlias(c, graph->outputs()));
EXPECT_TRUE(
aliasDb.mayContainAlias(at::ArrayRef<Value*>{c}, graph->outputs()));
EXPECT_FALSE(aliasDb.mayContainAlias(b, graph->outputs()));
}
// An int placed into a tuple, dict and list does not make it may-contain-alias
// with those containers.
// NOTE(review): "Primitve" in the test name is a typo; renaming it would
// change gtest filter names, so it is left as is.
TEST(ContainerAliasingTest, PrimitveValuesDontAliasContainers) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph():
%x : str = prim::Constant[value="a"]()
%y : int = prim::Constant[value=1]()
%a : (int) = prim::TupleConstruct(%y)
%b : Dict(str, int) = prim::DictConstruct(%x, %y)
%c : int[] = prim::ListConstruct(%y)
return (%a, %b, %c)
)IR",
&*graph);
auto node_iter = graph->block()->nodes().begin();
node_iter++; // string
Node* int_node = *node_iter++;
AliasDb aliasDb(graph);
EXPECT_TRUE(graph->outputs().size() == 3);
// primitive values don't need to alias container
for (auto out : graph->outputs()) {
EXPECT_FALSE(aliasDb.mayContainAlias(int_node->output(), out));
}
}
// A Union(Dict(str, Tensor), Tensor[]) input may alias inputs of either member
// type; the two member-typed inputs do not directly alias each other but may
// contain aliases of each other.
TEST(ContainerAliasingTest, UnionAliasing) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%a : Dict(str, Tensor),
%b : Tensor[],
%c : Union(Dict(str, Tensor), Tensor[])):
return (%a, %b, %c)
)IR",
&*graph);
AliasDb aliasDb(graph);
auto a = graph->outputs().at(0);
auto b = graph->outputs().at(1);
auto c = graph->outputs().at(2);
EXPECT_TRUE(aliasDb.mayAlias(a, c));
EXPECT_TRUE(aliasDb.mayAlias(b, c));
EXPECT_TRUE(aliasDb.mayAlias(c, c));
EXPECT_FALSE(aliasDb.mayAlias(a, b));
EXPECT_TRUE(aliasDb.mayContainAlias(a, b));
EXPECT_TRUE(aliasDb.mayContainAlias(a, c));
EXPECT_TRUE(aliasDb.mayContainAlias(b, c));
}
TEST(ContainerAliasingTest, InputsCanAliasOutputs) {
// Test input aliasing: every graph input, including %y which the tuple never
// references, is expected to be possibly contained in the tuple (presumably
// because graph inputs may alias one another).
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%x: Tensor, %y: Tensor):
%a : (Tensor) = prim::TupleConstruct(%x)
return (%a)
)IR",
&*graph);
auto node_iter = graph->block()->nodes().begin();
auto tuple_node = *node_iter;
AliasDb aliasDb(graph);
for (auto input : graph->inputs()) {
EXPECT_TRUE(aliasDb.mayContainAlias(input, tuple_node->output()));
}
EXPECT_TRUE(aliasDb.mayContainAlias(graph->inputs(), graph->outputs()));
}
// Test tuple that doesn't come from construct: the returned tuple is produced
// by prim::If, and each Tensor input may be contained in it.
TEST(ContainerAliasingTest, NestedTupleConstruct) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%x : int,
%y : Tensor,
%z : Tensor):
%3 : int = prim::Constant[value=1]()
%4 : bool = aten::eq(%x, %3)
%a : (Tensor) = prim::If(%4)
block0():
%a.1 : (Tensor) = prim::TupleConstruct(%y)
-> (%a.1)
block1():
%a.2 : (Tensor) = prim::TupleConstruct(%z)
-> (%a.2)
return (%a)
)IR",
&*graph);
AliasDb aliasDb(graph);
for (auto input : graph->inputs()) {
if (input->type() == IntType::get()) {
continue;
}
EXPECT_TRUE(aliasDb.mayContainAlias(input, graph->outputs().at(0)));
}
}
// test nested types
TEST(ContainerAliasingTest, NestedTypes) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph():
%a : Tensor = prim::MakeTestTensor()
%a_list : Tensor[] = prim::ListConstruct(%a)
%b : Tensor = prim::MakeTestTensor()
%b_list : Tensor[] = prim::ListConstruct(%b)
%13 : (Tensor[], Tensor[]) = prim::TupleConstruct(%a_list, %b_list)
return (%13)
)IR",
&*graph);
AliasDb aliasDb(graph);
auto g_output = graph->outputs().at(0);
// NOTE(review): `list_2` is the tuple's first element (%a_list) and `list_1`
// its second (%b_list); the names are swapped relative to position, but every
// assertion below is symmetric.
auto list_2 = g_output->node()->inputs().at(0);
auto list_1 = g_output->node()->inputs().at(1);
// TODO FIX assume conservatively for now
EXPECT_TRUE(aliasDb.mayContainAlias(list_1, list_2));
EXPECT_TRUE(aliasDb.mayContainAlias(list_2, list_1));
EXPECT_TRUE(aliasDb.mayContainAlias(list_1, g_output));
EXPECT_TRUE(aliasDb.mayContainAlias(list_2, g_output));
}
// simple example: a tuple built from %0 may contain an alias of %0 but not of
// the unrelated constant %1.
TEST(ContainerAliasingTest, Simple) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph():
%0 : Tensor = prim::Constant()
%1 : Tensor = prim::Constant()
%13 : (Tensor) = prim::TupleConstruct(%0)
return (%13)
)IR",
      &*graph);
  AliasDb aliasDb(graph);
  auto node_iter = graph->block()->nodes().begin();
  auto first_ten = *node_iter++;
  auto second_ten = *node_iter++;
  auto tup_node = *node_iter;
  EXPECT_TRUE(aliasDb.mayContainAlias(first_ten->output(), tup_node->output()));
  // EXPECT_FALSE (rather than EXPECT_TRUE(!...)) gives a clearer failure
  // message.
  EXPECT_FALSE(
      aliasDb.mayContainAlias(second_ten->output(), tup_node->output()));
  std::vector<Value*> first_st = {first_ten->output()};
  std::vector<Value*> second_st = {second_ten->output()};
  std::vector<Value*> tup_st = {tup_node->output()};
  EXPECT_TRUE(aliasDb.mayContainAlias(first_st, tup_st));
  EXPECT_FALSE(aliasDb.mayContainAlias(first_st, second_st));
  EXPECT_FALSE(aliasDb.mayContainAlias(second_st, tup_st));
}
// Two lists built from the same tensor may contain aliases of each other; a
// str constant is not contained in either.
TEST(ContainerAliasingTest, Lists) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph():
%x : str = prim::Constant[value="a"]()
%y : Tensor = prim::Constant()
%c : Tensor[] = prim::ListConstruct(%y)
%d : Tensor[] = prim::ListConstruct(%y)
return (%c, %d)
)IR",
&*graph,
vmap);
AliasDb aliasDb(graph);
auto x = vmap["x"];
auto c = vmap["c"];
EXPECT_FALSE(aliasDb.mayContainAlias(x, c));
EXPECT_FALSE(aliasDb.mayContainAlias(c, x));
auto d = vmap["d"];
EXPECT_TRUE(aliasDb.mayContainAlias(d, c));
EXPECT_TRUE(aliasDb.mayContainAlias(c, d));
}
TEST(ContainerAliasingTest, Lists2) {
// Test list container aliasing
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph():
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%2 : int[] = prim::ListConstruct(%0, %1)
%x : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
%y : Tensor = prim::MakeTestTensor()
%22 : int[] = prim::ListConstruct(%0, %1)
%z : Tensor = prim::MakeTestTensor()
%32 : int[] = prim::ListConstruct(%0, %1)
%fresh : Tensor = prim::MakeTestTensor()
%foo : Tensor[] = prim::ListConstruct(%x, %y)
%43 : Tensor[] = aten::append(%foo, %z)
return ()
)IR",
graph.get(),
vmap);
AliasDb aliasDb(graph);
auto x = vmap["x"];
auto y = vmap["y"];
auto z = vmap["z"];
// Tensors x, y, and z went into a list, so they all may alias each other.
EXPECT_TRUE(aliasDb.mayAlias(x, y));
EXPECT_TRUE(aliasDb.mayAlias(y, z));
EXPECT_TRUE(aliasDb.mayAlias(x, z));
// But we know `fresh` didn't go into a list, so x, y, and z should not
// alias it.
auto fresh = vmap["fresh"];
EXPECT_FALSE(aliasDb.mayAlias(x, fresh));
EXPECT_FALSE(aliasDb.mayAlias(y, fresh));
EXPECT_FALSE(aliasDb.mayAlias(z, fresh));
}
TEST(ContainerAliasingTest, Conservative) {
// test "conservative" analysis writes to the inside of a container.
// The op is registered without alias annotations, so AliasDb is expected to
// treat it as writing to the tensors inside the list it receives.
auto ops = torch::RegisterOperators(
"custom::conservative", [](torch::List<at::Tensor> in) { return in; });
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph():
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%2 : int[] = prim::ListConstruct(%0, %1)
%11 : Tensor = prim::MakeTestTensor()
%12 : Tensor[] = prim::ListConstruct(%11)
%out : Tensor[] = custom::conservative(%12)
%ret.2 : Tensor = aten::div(%11, %11)
return ()
)IR",
graph.get(),
vmap);
AliasDb aliasDb(graph);
auto conservativeOp = vmap["out"]->node();
auto tensor = vmap["11"];
EXPECT_TRUE(aliasDb.writesToAlias(conservativeOp, ValueSet{tensor}));
}
TEST(ContainerAliasingTest, MovesAcrossContainedWrites) {
// `uses::list` is a pure function of its list argument (its kernel ignores
// the input and returns a fresh tensor; only its alias info matters here).
auto ops = torch::RegisterOperators().op(
"uses::list",
torch::RegisterOperators::options()
.catchAllKernel([](torch::List<at::Tensor> in) {
return torch::rand({2, 3});
})
.aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
// Write to the inside of a list (via an element selected from it). Check
// that the list's use (`uses::list`) can't be moved before that write.
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph():
%35 : int = prim::Constant[value=1]()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%23 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %1)
%11 : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
%21 : Tensor = prim::MakeTestTensor()
%l : Tensor[] = prim::ListConstruct(%11, %21)
%24 : Tensor = aten::select(%l, %23)
%25 : int[] = prim::ListConstruct(%0, %1)
%34 : Tensor = prim::MakeTestTensor()
%36 : Tensor = aten::add_(%24, %34, %35)
%37 : Tensor = uses::list(%l)
return (%37)
)IR",
graph.get(),
vmap);
AliasDb aliasDb(graph);
auto listUse = vmap["37"]->node();
auto internalWrite = vmap["36"]->node();
EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(listUse, internalWrite));
}
TEST(ContainerAliasingTest, MovesAcrossContainedWritesNested) {
// The same as above, but with a nested list
auto ops = torch::RegisterOperators().op(
"uses::list",
torch::RegisterOperators::options()
.catchAllKernel([](torch::List<at::Tensor> in) {
return torch::rand({2, 3});
})
.aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
// Write to the inside of a list through a view of a selected element (two
// levels of select). Check that the list's use (`uses::list`) can't be moved
// before that write.
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph():
%38 : int = prim::Constant[value=1]()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%24 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %1)
%11 : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
%21 : Tensor = prim::MakeTestTensor()
%l : Tensor[] = prim::ListConstruct(%11, %21)
%25 : Tensor = aten::select(%l, %24)
%27 : Tensor = aten::select(%25, %24, %24)
%28 : int[] = prim::ListConstruct(%0, %1)
%37 : Tensor = prim::MakeTestTensor()
%39 : Tensor = aten::add_(%27, %37, %38)
%40 : Tensor = uses::list(%l)
return (%40)
)IR",
graph.get(),
vmap);
AliasDb aliasDb(graph);
auto listUse = vmap["40"]->node();
auto internalWrite = vmap["39"]->node();
EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(listUse, internalWrite));
}
// A value returned with schema `Tensor(*)` is a wildcard: it may alias graph
// inputs but not fresh values. Writes only count against the wildcard when
// they target the wildcard itself. The graph is grown in three stages, and a
// new AliasDb is built after each one.
TEST(WildcardsTest, Basic) {
RegisterOperators reg(
{Operator(
"prim::returns_wildcard(Tensor a) -> Tensor(*)",
[](Stack&) {},
aliasAnalysisFromSchema()),
Operator(
"prim::writes(Tensor(z!) a) -> Tensor(a)",
[](Stack&) {},
aliasAnalysisFromSchema())});
const auto returns_wildcard =
Symbol::fromQualString("prim::returns_wildcard");
const auto writes = Symbol::fromQualString("prim::writes");
auto graph = std::make_shared<Graph>();
const auto a = graph->addInput();
const auto constant = graph->insertConstant(1);
const auto fresh = graph->insert(aten::rand, {constant});
const auto fresh2 = graph->insert(aten::rand, {constant});
const auto wildcard = graph->insert(returns_wildcard, {fresh});
{
// Stage 1: no writes yet.
graph->lint();
AliasDb aliasDb(graph);
EXPECT_FALSE(aliasDb.mayAlias(a, fresh));
EXPECT_FALSE(aliasDb.mayAlias(wildcard, fresh));
EXPECT_TRUE(aliasDb.mayAlias(wildcard, a));
EXPECT_FALSE(aliasDb.mayAlias(ValueSet{wildcard}, ValueSet{}));
EXPECT_FALSE(aliasDb.hasWriters(wildcard->node()));
}
// Stage 2: a write to the fresh value `fresh2` must not count as a write to
// the wildcard. (The `->node()` result here is unused.)
graph->insert(writes, {fresh2})->node();
{
graph->lint();
AliasDb aliasDb(graph);
EXPECT_FALSE(aliasDb.hasWriters(wildcard->node()));
}
// Stage 3: write to the wildcard itself.
const auto wildcardWrite = graph->insert(writes, {wildcard})->node();
{
graph->lint();
AliasDb aliasDb(graph);
// Test writes to wildcards
EXPECT_FALSE(aliasDb.writesToAlias(
wildcardWrite, std::unordered_set<const Value*>{fresh}));
EXPECT_FALSE(aliasDb.writesToAlias(
wildcardWrite, std::unordered_set<const Value*>{fresh2}));
EXPECT_TRUE(aliasDb.writesToAlias(
wildcardWrite, std::unordered_set<const Value*>{a}));
EXPECT_TRUE(aliasDb.hasWriters(wildcard->node()));
}
}
// Test that wildcards are correctly divided by type: writers and containment
// seen through Tensor-typed containers must not leak into int-typed ones.
TEST(WildcardsTest, TypeIsolation) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
graph(%ten_list : Tensor[], %int_list : int[], %opt_ten_list : Tensor[]?):
%ten : Tensor = prim::Constant()
%4 : Tensor[] = aten::append(%ten_list, %ten)
%ten_ten_list : Tensor[][] = prim::Constant()
%int_int_list : int[][] = prim::Constant()
return ()
)IR",
      graph.get(),
      vmap);
  AliasDb aliasDb(graph);
  auto opt_ten_list = vmap["opt_ten_list"];
  auto ten_list = vmap["ten_list"];
  auto int_list = vmap["int_list"];

  // %ten_list is passed to aten::append. Both Tensor-list inputs (including
  // the optional one) are expected to report writers; the int list is not.
  EXPECT_FALSE(aliasDb.hasWriters(int_list));
  EXPECT_TRUE(aliasDb.hasWriters(opt_ten_list));
  EXPECT_TRUE(aliasDb.hasWriters(ten_list));

  EXPECT_FALSE(aliasDb.mayContainAlias(int_list, opt_ten_list));
  EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, opt_ten_list));
  EXPECT_TRUE(aliasDb.mayAlias(ten_list, opt_ten_list));

  auto list_of_tensor_lists = vmap["ten_ten_list"];
  EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, list_of_tensor_lists));
  EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, vmap["ten"]));
  // EXPECT_FALSE (rather than EXPECT_TRUE(!...)) so a failure reports the
  // offending expression directly.
  EXPECT_FALSE(
      aliasDb.mayContainAlias(vmap["int_int_list"], list_of_tensor_lists));
}
// Invariant container aliasing: containers of different types cannot alias
// each other, but they may still contain elements that alias each other.
TEST(WildcardsTest, InvariantContainerAliasing) {
  // Parses `ir` into a brand-new graph, recording named values in `vmap`.
  auto parseGraph = [](const char* ir,
                       std::unordered_map<std::string, Value*>& vmap) {
    auto g = std::make_shared<Graph>();
    parseIR(ir, g.get(), vmap);
    return g;
  };

  // Tensor[] vs Tensor?[]: different list types sharing Tensor elements.
  {
    std::unordered_map<std::string, Value*> names;
    auto g = parseGraph(
        R"IR(
graph(%ten_list : Tensor[], %ten_opt_list : Tensor?[]):
%ten : Tensor = prim::Constant()
%4 : Tensor[] = aten::append(%ten_list, %ten)
return ()
)IR",
        names);
    AliasDb db(g);
    auto tensorOptList = names["ten_opt_list"];
    auto tensorList = names["ten_list"];
    EXPECT_FALSE(db.hasWriters(tensorOptList));
    EXPECT_TRUE(db.hasWriters(tensorList));
    EXPECT_TRUE(db.mayContainAlias(tensorList, tensorOptList));
    EXPECT_FALSE(db.mayAlias(tensorList, tensorOptList));
  }

  // Tensors that differ only in rank may still alias.
  {
    std::unordered_map<std::string, Value*> names;
    auto g = parseGraph(
        R"IR(
graph(%float_3D : Float(*, *, *), %float_2D : Float(*, *)):
return ()
)IR",
        names);
    AliasDb db(g);
    EXPECT_TRUE(db.mayAlias(names["float_3D"], names["float_2D"]));
  }

  // Lists of tensors that differ only in rank are the same list type.
  {
    std::unordered_map<std::string, Value*> names;
    auto g = parseGraph(
        R"IR(
graph(%float_3D_list : Float(*, *, *)[], %float_2D_list : Float(*, *)[], %ten: Tensor):
return ()
)IR",
        names);
    AliasDb db(g);
    EXPECT_TRUE(db.mayAlias(names["float_3D_list"], names["float_2D_list"]));
    EXPECT_TRUE(db.mayContainAlias(names["float_3D_list"], names["ten"]));
    EXPECT_TRUE(db.mayContainAlias(names["float_2D_list"], names["ten"]));
  }
}
TEST(AliasRegistrationTest, ConservativeWithInferredSchema) {
  // foo::rand1 is registered by name only (schema inferred from the kernel
  // signature) with CONSERVATIVE alias analysis.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand1",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor /*unused*/) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
  const auto opSymbol = Symbol::fromQualString("foo::rand1");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  auto output = g->insert(opSymbol, {input});

  // CONSERVATIVE analysis must assume the output may reference the input,
  // even though this kernel returns a freshly created tensor.
  AliasDb db(g);
  EXPECT_TRUE(db.mayAlias(input, output));
}
TEST(AliasRegistrationTest, ConservativeWithSpecifiedSchema) {
  // An explicit schema without alias annotations, combined with CONSERVATIVE
  // alias analysis.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand2(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor /*unused*/) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
  const auto opSymbol = Symbol::fromQualString("foo::rand2");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  auto output = g->insert(opSymbol, {input});

  // CONSERVATIVE wins over the (annotation-free) schema: input and output
  // are assumed to possibly alias.
  AliasDb db(g);
  EXPECT_TRUE(db.mayAlias(input, output));
}
TEST(AliasRegistrationTest, ConservativeWithAliasingAnnotationsShouldError) {
  // The schema carries alias annotations, which are only accepted together
  // with AliasAnalysisKind::FROM_SCHEMA; here CONSERVATIVE is requested.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand3(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor /*unused*/) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
  const auto opSymbol = Symbol::fromQualString("foo::rand3");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  g->insert(opSymbol, {input});

  // Registration itself succeeds; the mismatch surfaces only when building
  // the AliasDb looks the operator up.
  expectThrows<c10::Error>(
      [&g] { AliasDb aliasDb(g); },
      "Tried to register operator foo::rand3(Tensor(a) arg1) -> (Tensor(b)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
TEST(AliasRegistrationTest, ConservativeWithAliasingAnnotationsShouldError2) {
  // Same as the previous case, but input and output share the alias set (a).
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand4(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor /*unused*/) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
  const auto opSymbol = Symbol::fromQualString("foo::rand4");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  g->insert(opSymbol, {input});

  // Registration itself succeeds; the mismatch surfaces only when building
  // the AliasDb looks the operator up.
  expectThrows<c10::Error>(
      [&g] { AliasDb aliasDb(g); },
      "Tried to register operator foo::rand4(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
TEST(AliasRegistrationTest, FromSchemaWithInferredSchemaShouldError) {
  // Requesting FROM_SCHEMA while registering by name only (so the schema
  // would be inferred from the kernel) must make the registration throw.
  // The RegisterOperators object is a temporary: registration never completes.
  expectThrows<c10::Error>(
      [] {
        torch::RegisterOperators().op(
            "foo::rand5",
            torch::RegisterOperators::options()
                .catchAllKernel([](at::Tensor /*unused*/) -> at::Tensor {
                  return at::rand({2, 2});
                })
                .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
      },
      "Tried to register operator foo::rand5(Tensor _0) -> (Tensor _0) with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred");
}
TEST(AliasRegistrationTest, FromSchemaInferredPure) {
  // FROM_SCHEMA with an explicit schema that has no alias annotations.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand6(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor /*unused*/) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
  const auto opSymbol = Symbol::fromQualString("foo::rand6");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  auto output = g->insert(opSymbol, {input});

  // Without alias annotations the schema is read as pure, so the output is
  // not considered to alias the input.
  AliasDb db(g);
  EXPECT_FALSE(db.mayAlias(input, output));
}
TEST(AliasRegistrationTest, FromSchemaAliased) {
  // FROM_SCHEMA with input and output in the same alias set (a).
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand7(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor x) -> at::Tensor { return x * 2; })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
  const auto opSymbol = Symbol::fromQualString("foo::rand7");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  auto output = g->insert(opSymbol, {input});

  // The shared annotation means the output may alias the input.
  AliasDb db(g);
  EXPECT_TRUE(db.mayAlias(input, output));
}
TEST(AliasRegistrationTest, FromSchemaPure) {
  // FROM_SCHEMA with input (a) and output (b) in different alias sets.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand8(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor x) -> at::Tensor { return x * 2; })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
  const auto opSymbol = Symbol::fromQualString("foo::rand8");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  auto output = g->insert(opSymbol, {input});

  // Distinct annotations: no alias relationship between input and output.
  AliasDb db(g);
  EXPECT_FALSE(db.mayAlias(input, output));
}
TEST(AliasRegistrationTest, PureNoSchema) {
  // PURE_FUNCTION with the schema inferred from the kernel signature.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand9",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor /*unused*/) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
  const auto opSymbol = Symbol::fromQualString("foo::rand9");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  auto output = g->insert(opSymbol, {input});

  // A pure function cannot produce an alias of its input.
  AliasDb db(g);
  EXPECT_FALSE(db.mayAlias(input, output));
}
TEST(AliasRegistrationTest, PureWithSchema) {
  // PURE_FUNCTION with an explicit, annotation-free schema.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand10(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor /*unused*/) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
  const auto opSymbol = Symbol::fromQualString("foo::rand10");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  auto output = g->insert(opSymbol, {input});

  // A pure function cannot produce an alias of its input.
  AliasDb db(g);
  EXPECT_FALSE(db.mayAlias(input, output));
}
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError) {
  // Alias annotations in the schema are only accepted with FROM_SCHEMA;
  // here PURE_FUNCTION is requested instead.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand11(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor x) -> at::Tensor { return x * 2; })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
  const auto opSymbol = Symbol::fromQualString("foo::rand11");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  g->insert(opSymbol, {input});

  // Registration itself succeeds; the mismatch surfaces only when building
  // the AliasDb looks the operator up.
  expectThrows<c10::Error>(
      [&g] { AliasDb aliasDb(g); },
      "Tried to register operator foo::rand11(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
TEST(AliasRegistrationTest, AliasMoveAtenListOp) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  auto graph_string = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%8 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=1]()
%4 : int = prim::Constant[value=2]()
%y : Tensor[] = prim::ListConstruct(%x)
%6 : Tensor = aten::add_(%x, %4, %5)
%9 : Tensor = aten::cat(%y, %8)
return (%9))IR";
  torch::jit::parseIR(graph_string, graph.get(), vmap);
  AliasDb aliasDb(graph);
  // %y has a single use, in a non-aliasing aten op (aten::cat), so %x is
  // added to %y's contained elements instead of the wildcard set. Hence %x
  // does not alias the result %9.
  EXPECT_FALSE(aliasDb.mayAlias(vmap["x"], vmap["9"]));
  // aten::add_ writes to %x (an element of %y) between the ListConstruct and
  // aten::cat. That write must prevent moving %y's ListConstruct to just
  // before aten::cat.
  EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(
      vmap["y"]->node(), vmap["9"]->node()));
}
TEST(
    AliasRegistrationTest,
    AliasMoveForTupleConstructWithSingleUseAsGraphOutput) {
  // %z's only use is the graph return. In that situation %x and %y are
  // expected to stay non-aliasing, while %z may contain either of them.
  // The tuple type lists one element per TupleConstruct input.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  auto graph_string = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%y : Tensor = prim::MakeTestTensor()
%z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
return (%z))IR";
  torch::jit::parseIR(graph_string, graph.get(), vmap);
  AliasDb aliasDb(graph, /*isFrozen=*/false);
  EXPECT_FALSE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["y"]));
}
TEST(AliasRegistrationTest, RecursiveSubgraphTupleContainment) {
  std::unordered_map<std::string, Value*> vmap;
  auto graph = std::make_shared<Graph>();
  const auto ir = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%y : Tensor = prim::MakeTestTensor()
%z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
return (%z))IR";
  torch::jit::parseIR(ir, graph.get(), vmap);

  // Move the TupleConstruct into its own prim::FunctionalGraph node, so the
  // tuple is produced by a subgraph instead of directly.
  auto* functionalGraph = SubgraphUtils::createSingletonSubgraph(
      vmap["z"]->node(), prim::FunctionalGraph);

  AliasDb aliasDb(graph);
  // The subgraph's output must still be seen as containing %x and %y.
  auto* tupleOut = functionalGraph->output();
  EXPECT_TRUE(aliasDb.mayContainAlias(tupleOut, vmap["x"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(tupleOut, vmap["y"]));
  // The test also pins that %x and %y are reported as possibly aliasing here.
  EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
}
TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
  // Unlike the single-use-as-graph-output case, both tuples here are consumed
  // by prim::TupleIndex. The test expects every tensor to be treated as
  // possibly aliasing every other, and each tuple as possibly containing each
  // of them. Each tuple type lists one element per TupleConstruct input.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  auto graph_string = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%y : Tensor = prim::MakeTestTensor()
%z : Tensor = prim::MakeTestTensor()
%0 : int = prim::Constant[value=0]()
%a : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
%b : (Tensor) = prim::TupleConstruct(%z)
%c : Tensor = prim::TupleIndex(%a, %0)
%d : Tensor = prim::TupleIndex(%b, %0)
return (%c, %d))IR";
  torch::jit::parseIR(graph_string, graph.get(), vmap);
  AliasDb aliasDb(graph, /*isFrozen=*/false);
  EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["z"]));
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["z"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["a"], vmap["x"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["a"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["a"], vmap["z"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["x"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["z"]));
}
TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) {
  // aten::split with an int[] of split sizes: every piece unpacked from the
  // result, and everything flattened from those pieces, may alias %y.
  std::unordered_map<std::string, Value*> vmap;
  auto graph = std::make_shared<Graph>();
  const auto ir = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%0 : int = prim::Constant[value=0]()
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%y : Tensor = aten::add(%x, %x, %0)
%lengths_list : int[] = prim::tolist(%1, %2)
%a : Tensor[] = aten::split(%y, %lengths_list, %0)
%b : Tensor, %c : Tensor = prim::ListUnpack(%a)
%b1 : Tensor = aten::flatten(%b, %0, %1)
%c1 : Tensor = aten::flatten(%c, %0, %1)
%d : Tensor = aten::add(%b1, %c1, %0)
return (%d))IR";
  torch::jit::parseIR(ir, graph.get(), vmap);

  AliasDb aliasDb(graph, /*isFrozen=*/false);
  for (const char* name : {"b", "c", "b1", "c1"}) {
    EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap[name]));
  }
}
TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) {
  // aten::split with a single int split size: every piece unpacked from the
  // result, and everything flattened from those pieces, may alias %y.
  std::unordered_map<std::string, Value*> vmap;
  auto graph = std::make_shared<Graph>();
  const auto ir = R"IR(
graph():
%x : Tensor = prim::MakeTestTensor()
%0 : int = prim::Constant[value=0]()
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%y : Tensor = aten::add(%x, %x, %0)
%a : Tensor[] = aten::split(%y, %2, %0)
%b : Tensor, %c : Tensor = prim::ListUnpack(%a)
%b1 : Tensor = aten::flatten(%b, %0, %1)
%c1 : Tensor = aten::flatten(%c, %0, %1)
%d : Tensor = aten::add(%b1, %c1, %0)
return (%d))IR";
  torch::jit::parseIR(ir, graph.get(), vmap);

  AliasDb aliasDb(graph, /*isFrozen=*/false);
  for (const char* name : {"b", "c", "b1", "c1"}) {
    EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap[name]));
  }
}
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
  // Distinct alias annotations (a)/(b) combined with PURE_FUNCTION; any alias
  // annotation requires FROM_SCHEMA.
  auto opRegistration = torch::RegisterOperators().op(
      "foo::rand12(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor x) -> at::Tensor { return x * 2; })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));
  const auto opSymbol = Symbol::fromQualString("foo::rand12");

  auto g = std::make_shared<Graph>();
  auto input = g->addInput();
  g->insert(opSymbol, {input});

  // Registration itself succeeds; the mismatch surfaces only when building
  // the AliasDb looks the operator up.
  expectThrows<c10::Error>(
      [&g] { AliasDb aliasDb(g); },
      "Tried to register operator foo::rand12(Tensor(a) arg1) -> (Tensor(b)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
} // namespace jit
} // namespace torch