#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/guard_elimination.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/insert_guards.h>
#include <torch/csrc/jit/passes/liveness.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/symbolic_script.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/jit.h>
#include <torch/script.h>
#include <onnx/onnx_pb.h>
#include <c10/util/Exception.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
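// Operators registered by these tests take their aliasing behavior from their declared schema.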
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
size_t i = 0;
out << "{";
for (auto&& e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "}";
return out;
}
TEST(InternedStringsTest, Basic) {
ASSERT_EQ(prim::Param, Symbol::prim("Param"));
ASSERT_EQ(prim::Return, Symbol::prim("Return"));
ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));
Symbol newsym = Symbol::aten("__NEW_SYMBOL");
size_t symstart = newsym;
ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL"));
// TODO: This test is a bit too close to the implementation details.
ASSERT_EQ(Symbol::aten("What"), symstart + 1);
ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
ASSERT_EQ(Symbol::aten("What"), symstart + 1);
ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2"));
}
TEST(FromQualStringTest, Basic) {
ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
ASSERT_EQ(
Symbol::fromQualString("::").ns().toQualString(),
std::string("namespaces::"));
ASSERT_EQ(
Symbol::fromQualString("new_ns::param").toUnqualString(),
std::string("param"));
ASSERT_EQ(
Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
std::string("new_ns"));
ASSERT_EQ(
Symbol::fromQualString("new_ns::param").ns(),
Symbol::fromQualString("namespaces::new_ns"));
auto bad_inputs = {"scope", ":", ""};
for (auto input : bad_inputs) {
try {
Symbol::fromQualString(input);
ASSERT_TRUE(0);
} catch (const std::exception& c) {
}
}
}
TEST(THNNConvTest, Basic) {
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
std::vector<int64_t> kernel_size = {3, 5};
std::vector<int64_t> stride = {1, 2};
std::vector<int64_t> padding = {2, 1};
constexpr int out_channels = 5;
// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn(
{out_channels, input_size[1], kernel_size[0], kernel_size[1]});
at::Tensor bias = torch::randn({out_channels});
// run forward eagerly
at::Tensor output = at::_slow_conv2d_forward(
input, weight, kernel_size, bias, stride, padding);
// make grad_outputs
at::Tensor grad_output =
torch::randn_like(output, at::MemoryFormat::Preserve);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
std::tie(grad_input, grad_weight, grad_bias) = at::_slow_conv2d_backward(
grad_output,
input,
weight,
kernel_size,
stride,
padding,
{true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
auto ksz_val = graph->insertConstant(kernel_size);
auto kst_val = graph->insertConstant(stride);
auto pad_val = graph->insertConstant(padding);
auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
Value* conv = graph->insert(
aten::_slow_conv2d_forward,
{inputg, weightg, ksz_val, biasg, kst_val, pad_val});
auto outputs = conv->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
}
LowerAllTuples(graph);
graph->lint();
// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
LowerGradOf(*grad_spec.df);
// prepare JIT inputs / gradients
tensor_list tensors_in;
tensors_in.push_back(input);
tensors_in.push_back(weight);
tensors_in.push_back(bias);
tensor_list tensor_grads_in;
tensor_grads_in.push_back(grad_output);
// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
runGradient(grad_spec, tensors_in, tensor_grads_in);
// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
expected_tensors_out.push_back(output);
expected_tensor_grads_out.push_back(grad_input);
expected_tensor_grads_out.push_back(grad_weight);
expected_tensor_grads_out.push_back(grad_bias);
// Compare results
assertAllClose(tensors_out, expected_tensors_out);
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}
TEST(ATenNativeBatchNormTest, Basic) {
// aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
// running_mean, Tensor running_var, bool training, float momentum, float eps)
// -> (Tensor, Tensor, Tensor)
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
bool training = true;
float momentum = 0.9;
float eps = 1e-5;
// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn({input_size[1]});
at::Tensor bias = torch::randn({input_size[1]});
at::Tensor running_mean = torch::randn({input_size[1]});
at::Tensor running_var = torch::randn({input_size[1]});
// running_mean and running_var are updated in-place, so clone them and pass
// separate copies to the eager and JIT runs
at::Tensor running_mean_eager = running_mean.clone();
at::Tensor running_var_eager = running_var.clone();
at::Tensor running_mean_jit = running_mean.clone();
at::Tensor running_var_jit = running_var.clone();
// run forward eagerly
at::Tensor output, savemean, saveinvstd;
std::tie(output, savemean, saveinvstd) = at::native_batch_norm(
input,
weight,
bias,
running_mean_eager,
running_var_eager,
training,
momentum,
eps);
// make grad_outputs
at::Tensor grad_output =
torch::randn_like(output, at::MemoryFormat::Preserve);
at::Tensor grad_savemean =
torch::zeros_like(savemean, at::MemoryFormat::Preserve);
at::Tensor grad_saveinvstd =
torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
// aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
// weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
// save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
// Tensor, Tensor)
std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(
grad_output,
input,
weight,
running_mean_eager,
running_var_eager,
savemean,
saveinvstd,
training,
eps,
{true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
auto training_val = graph->insertConstant(IValue(training));
auto momentum_val = graph->insertConstant(IValue(momentum));
auto eps_val = graph->insertConstant(IValue(eps));
auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
auto running_meang = graph->addInput("running_mean");
auto running_varg = graph->addInput("running_var");
Value* bn = graph->insert(
aten::native_batch_norm,
{inputg,
weightg,
biasg,
running_meang,
running_varg,
training_val,
momentum_val,
eps_val});
auto outputs = bn->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
}
LowerAllTuples(graph);
graph->lint();
// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
LowerGradOf(*grad_spec.df);
// prepare JIT inputs / gradients
tensor_list tensors_in;
tensors_in.push_back(input);
tensors_in.push_back(weight);
tensors_in.push_back(bias);
tensors_in.push_back(running_mean_jit);
tensors_in.push_back(running_var_jit);
tensor_list tensor_grads_in;
tensor_grads_in.push_back(grad_output);
tensor_grads_in.push_back(grad_savemean);
tensor_grads_in.push_back(grad_saveinvstd);
// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
runGradient(grad_spec, tensors_in, tensor_grads_in);
// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
expected_tensors_out.push_back(output);
expected_tensors_out.push_back(savemean);
expected_tensors_out.push_back(saveinvstd);
expected_tensors_out.push_back(running_mean_eager);
expected_tensors_out.push_back(running_var_eager);
expected_tensor_grads_out.push_back(grad_input);
expected_tensor_grads_out.push_back(grad_weight);
expected_tensor_grads_out.push_back(grad_bias);
tensors_out.push_back(running_mean_jit);
tensors_out.push_back(running_var_jit);
// Compare results
assertAllClose(tensors_out, expected_tensors_out);
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}
TEST(CustomFusionTest, Basic) {
#if defined(FBCODE_CAFFE2)
return;
#endif
auto graph_string = R"IR(
graph(%0 : Float(2, 3, 4),
%1 : Float(2, 3, 4)):
%2 : Tensor = aten::mul(%0, %1)
%3 : Tensor = aten::mul(%2, %0)
return (%3))IR";
auto g = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, g.get());
torch::jit::overrideCanFuseOnCPU(true);
CustomFuseGraph(
g,
[](Node* n) { return n->kind() != prim::Param; },
Symbol::fromQualString("prim::FusionGroup"));
torch::jit::overrideCanFuseOnCPU(false);
const auto& nodes = g->nodes();
auto fusion_group =
std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
return node->kind() == Symbol::fromQualString("prim::FusionGroup");
});
AT_ASSERT(fusion_group != nodes.end());
auto subgraph = fusion_group->g(attr::Subgraph);
auto hits = 0;
// two multiplications
for (const auto& n : subgraph->nodes()) {
(void)n;
hits++;
}
AT_ASSERT(hits == 2);
}
TEST(CustomFusionTest, NestedBlocks) {
#if defined(FBCODE_CAFFE2)
return;
#endif
auto graph_string = R"IR(
graph(%0 : Float(2, 3, 4),
%1 : Float(2, 3, 4),
%2 : Float(2, 3, 4)):
%3 : int = prim::Constant[value=1]()
%4 : Tensor = prim::If(%2)
block0():
%5 : Tensor = aten::mul(%0, %2)
%6 : Tensor = aten::mul(%5, %1)
-> (%6)
block1():
%7 : Tensor = aten::add(%0, %2, %3)
%8 : Tensor = aten::add(%7, %1, %3)
-> (%8)
%9 : Tensor = aten::add(%4, %2, %3)
return (%4))IR";
auto g = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, g.get());
CustomFuseGraph(
g,
[](Node* n) { return n->kind() == aten::mul; },
Symbol::fromQualString("prim::FusionGroup"));
// Could be done in more efficient ways, but this is only a test.
std::function<bool(const Block*, Symbol)> dfs = [&](const Block* b,
Symbol s) {
for (auto node : b->nodes()) {
if (node->kind() == s)
return true;
for (auto nested_b : node->blocks())
if (dfs(nested_b, s))
return true;
}
return false;
};
AT_ASSERT(dfs(g->block(), Symbol::fromQualString("prim::FusionGroup")));
}
static const auto cf_examples = R"JIT(
def if_test(a, b):
    # FIXME: use 0 instead of a.
    # c = 0
    c = a
    if bool(a < b):
        c = b
    else:
        c = a
    return c
def if_one(a, b):
    c = b
    if bool(a < b):
        c = a
    return c
def while_test(a, i):
    while bool(i < 3):
        a *= a
        i += 1
    return a
)JIT";
TEST(ControlFlowTest, Basic) {
auto cu = compile(cf_examples);
auto run = [&](const std::string& name, std::vector<IValue> stack) {
auto graph = toGraphFunction(cu->get_function(name)).graph();
Code code(graph, "");
InterpreterState interp(code);
interp.run(stack);
return stack;
};
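// L lifts an int64_t into a scalar tensor IValue; V extracts the int64_t back out.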
auto L = [](int64_t l) { return IValue(scalar_to_tensor(at::Scalar(l))); };
auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); };
auto run_binary = [&](const std::string& name, int64_t a, int64_t b) {
return V(run(name, {L(a), L(b)})[0]);
};
ASSERT_EQ(2, run_binary("if_test", 1, 2));
ASSERT_EQ(3, run_binary("if_test", 3, 2));
ASSERT_EQ(2, run_binary("if_one", 2, 3));
ASSERT_EQ(2, run_binary("if_one", 3, 2));
ASSERT_EQ(256, run_binary("while_test", 2, 0));
}
#if defined(__has_feature)
#if __has_feature(address_sanitizer)
#define HAS_ASANUBSAN 1
#endif
#endif
#ifndef HAS_ASANUBSAN
// This test fails vptr UBSAN checks
TEST(ProtoTest, Basic) {
::ONNX_NAMESPACE::ModelProto proto;
proto.set_producer_name("foo");
}
#endif
// test a few features that are not directly used in schemas yet
TEST(SchemaParserTest, NestedArrays) {
// nested arrays
auto s = parseSchema("at::what(int[][4] foo) -> ()");
ASSERT_TRUE(s.arguments().at(0).N() == 4);
ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments()
.at(0)
.type()
->expectRef<ListType>()
.getElementType()
->expectRef<ListType>()
.getElementType()));
auto s2 = parseSchema("at::what(int[][] foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments()
.at(0)
.type()
->expectRef<ListType>()
.getElementType()
->expectRef<ListType>()
.getElementType()));
}
TEST(SchemaParserTest, OutVariant) {
auto schema_with_out = parseSchema(
"at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
ASSERT_TRUE(schema_with_out.arguments().at(1).is_out());
ASSERT_TRUE(schema_with_out.arguments().at(2).is_out());
auto schema_without_out =
parseSchema("at::foo(Tensor self, *, int scalar) -> (int)");
for (const auto& arg : schema_without_out.arguments()) {
ASSERT_TRUE(!arg.is_out());
}
auto schema_with_is_write = parseSchema(
"aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))");
for (const auto& arg : schema_with_is_write.arguments()) {
ASSERT_TRUE(!arg.is_out());
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SchemaParserTest, NamedReturns) {
// named returns
parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
auto s3 =
parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
}
TEST(SchemaParserTest, Futures) {
// futures
auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(
*s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
}
TEST(SchemaParserTest, AnnotatedAliasSets) {
// test tensor with annotated alias sets
parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
}
TEST(SchemaParserTest, TensorListAnnotatedAliasSets) {
const auto s = parseSchema(
"at::foo(Tensor(a!) self, Tensor(b!)[] out)"
" -> ()");
const AliasInfo* selfAliasInfo = s.arguments().at(0).alias_info();
const AliasInfo* outAliasInfo = s.arguments().at(1).alias_info();
ASSERT_TRUE(
selfAliasInfo->beforeSets() ==
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(selfAliasInfo->isWrite());
ASSERT_TRUE(outAliasInfo->isWrite());
ASSERT_TRUE(outAliasInfo->beforeSets().empty());
ASSERT_EQ(outAliasInfo->containedTypes().size(), 1);
auto containedType = outAliasInfo->containedTypes()[0];
ASSERT_TRUE(containedType.isWrite());
ASSERT_TRUE(
containedType.beforeSets() ==
std::unordered_set<Symbol>{Symbol::fromQualString("alias::b")});
}
TEST(SchemaParserTest, AnnotatedAliasWithoutBeforeSet) {
EXPECT_THAT(
[]() { parseSchema("at::foo(Tensor(!) self) -> Tensor"); },
::testing::Throws<std::runtime_error>(::testing::Property(
&std::runtime_error::what,
::testing::HasSubstr("expected ident but found '!' here"))));
}
TEST(SchemaParserTest, BeforeAfterSets) {
const auto s = parseSchema(
"at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
" -> (Tensor(b|c)[](a!))");
// The list itself is annotated with `a`
const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
ASSERT_NE(aliasInfo, nullptr);
ASSERT_TRUE(
aliasInfo->beforeSets() ==
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(aliasInfo->isWrite());
// Check the contained types
ASSERT_TRUE(!aliasInfo->containedTypes().empty());
const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
const auto expected = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"),
Symbol::fromQualString("alias::c"),
};
ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
ASSERT_FALSE(containedAliasInfo.isWrite());
}
TEST(SchemaParserTest, BeforeAfterSets2) {
const auto s = parseSchema(
"at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
" -> (Tensor(b|c)[](a!))");
// The list itself is annotated with `a`
const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
ASSERT_NE(aliasInfo, nullptr);
ASSERT_EQ(
aliasInfo->beforeSets(),
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_EQ(
aliasInfo->afterSets(),
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(aliasInfo->isWrite());
ASSERT_EQ(aliasInfo->containedTypes().size(), 1);
// Check the contained types
ASSERT_TRUE(!aliasInfo->containedTypes().empty());
const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
const auto expectedBefore = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"),
};
const auto expectedAfter = std::unordered_set<Symbol>{
Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
ASSERT_FALSE(containedAliasInfo.isWrite());
}
TEST(TopologicalIndexTest, Basic) {
Graph graph;
auto node1 = graph.create(prim::AutogradZero);
auto node2 = graph.create(prim::AutogradZero);
auto node3 = graph.create(prim::AutogradZero);
auto node4 = graph.create(prim::AutogradZero);
graph.appendNode(node4);
graph.prependNode(node1);
node2->insertAfter(node1);
node3->insertBefore(node4);
// nodes should be in numerical order
ASSERT_TRUE(node1->isBefore(node2));
ASSERT_TRUE(node1->isBefore(node3));
ASSERT_TRUE(node1->isBefore(node4));
ASSERT_TRUE(node2->isAfter(node1));
ASSERT_TRUE(node2->isBefore(node3));
ASSERT_TRUE(node2->isBefore(node4));
ASSERT_FALSE(node3->isBefore(node1));
ASSERT_FALSE(node3->isBefore(node2));
ASSERT_FALSE(node3->isAfter(node4));
// Build up a block structure
//  node3
//   /\        ...
//  A  B     block1
//      \      ...
//      C    block2
auto block1 = node3->addBlock();
auto A = graph.create(prim::AutogradZero);
block1->appendNode(A);
auto B = graph.create(prim::AutogradZero);
block1->appendNode(B);
auto block2 = B->addBlock();
auto C = graph.create(prim::AutogradZero);
block2->appendNode(C);
// Check isAfter on different block levels
ASSERT_TRUE(node1->isBefore(A));
ASSERT_TRUE(A->isBefore(B));
ASSERT_TRUE(A->isBefore(C));
// make sure things don't blow up on deletions
node2->destroy();
auto node2p = graph.create(prim::AutogradZero);
node2p->insertAfter(node1);
ASSERT_TRUE(node1->isBefore(node2p));
ASSERT_TRUE(node2p->isBefore(node3));
}
TEST(TopologicalIndexTest, Reindex) {
// Induce reindexing to test that path
Graph graph;
std::map<size_t, Node*> nodes;
auto anchor = graph.create(prim::AutogradZero);
graph.appendNode(anchor);
// Inserting to the same place a lot will trigger reindexing
for (auto i = 0; i < 100; ++i) {
auto n = graph.create(prim::AutogradZero);
n->insertAfter(anchor);
nodes[i] = n;
}
// Nodes should be in reverse order
for (auto i = 0; i < 100; ++i) {
for (auto j = i + 1; j < 100; ++j) {
ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
}
}
}
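// Runs t.pow(2) inside a RECORD_FUNCTION scope so that observer callbacks fire
// with t as the recorded input.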
at::Tensor invokeTestRecordFunction(at::Tensor& t) {
RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
auto t2 = t.pow(2);
return t2;
}
static const auto invokeTestRecordFunction_JIT = R"JIT(
def foo(self, t):
    t2 = t.pow(2)
    return t2
def forward(self, t):
    return self.foo(t)
)JIT";
at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
auto module = std::make_shared<script::Module>(
"RecordFunctionTestModule", std::make_shared<script::CompilationUnit>());
module->define(invokeTestRecordFunction_JIT);
return module->forward({t}).toTensor();
}
using TracedTestValues =
std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
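// Each entry is a (function name, sizes of each tensor/scalar argument) pair
// collected by the tracing callbacks below.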
void checkTracedInputs(const TracedTestValues& inputs) {
bool found_test = false;
bool found_pow = false;
bool found_mul = false;
for (const auto& input : inputs) {
const auto& fn = std::get<0>(input);
const auto& sizes = std::get<1>(input);
if (fn == "test") {
found_test = true;
TORCH_CHECK(sizes.size() == 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
} else if (fn == "aten::pow") {
found_pow = true;
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
TORCH_CHECK(sizes[1].empty());
} else if (fn == "aten::mul") {
found_mul = true;
TORCH_CHECK(sizes.size() > 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
}
}
TORCH_CHECK(found_test);
TORCH_CHECK(found_pow);
TORCH_CHECK(found_mul);
}
void checkTracedOutputs(const TracedTestValues& outputs) {
bool found_test = false;
bool found_pow = false;
bool found_mul = false;
for (const auto& output : outputs) {
const auto& fn = std::get<0>(output);
const auto& sizes = std::get<1>(output);
if (fn == "test") {
found_test = true;
TORCH_CHECK(sizes.empty());
} else if (fn == "aten::pow") {
found_pow = true;
TORCH_CHECK(sizes.size() == 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
} else if (fn == "aten::mul") {
found_mul = true;
TORCH_CHECK(sizes.size() == 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
}
}
TORCH_CHECK(found_test);
TORCH_CHECK(found_pow);
TORCH_CHECK(found_mul);
}
static bool bad_scope = false;
template <RecordScope scope, size_t* cnt>
std::unique_ptr<at::ObserverContext> checkScopeCallback(
const at::RecordFunction& fn) {
if (fn.scope() == scope) {
++(*cnt);
} else {
bad_scope = true;
}
return nullptr;
}
template <RecordScope scope, size_t* cnt>
void pushScopedCallback() {
at::addGlobalCallback(
at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
.scopes({scope}));
}
// These cannot be function-local because that would prohibit them
// from being used as template arguments prior to C++17.
static size_t fun_cnt;
static size_t ts_fun_cnt;
static size_t user_scope_cnt;
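// Registers per-scope callbacks and verifies that each one fires exactly once
// for its own RecordScope.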
void checkScopeCallbacks() {
static bool found_function_scope;
static bool found_method_scope;
static bool found_user_scope;
found_function_scope = false;
found_method_scope = false;
found_user_scope = false;
at::addGlobalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
if (fn.scope() == at::RecordScope::FUNCTION &&
std::string(fn.name()) == "test_function") {
found_function_scope = true;
}
if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION &&
std::string(fn.name()) == "test_method") {
found_method_scope = true;
}
if (fn.scope() == at::RecordScope::USER_SCOPE &&
std::string(fn.name()) == "test_user_scope") {
found_user_scope = true;
}
return nullptr;
}));
bad_scope = false;
fun_cnt = 0;
pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
ts_fun_cnt = 0;
pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
user_scope_cnt = 0;
pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
TORCH_CHECK(at::hasCallbacks());
{
RECORD_TORCHSCRIPT_FUNCTION("test_method", {});
{ RECORD_FUNCTION("test_function", {}); }
{ RECORD_USER_SCOPE("test_user_scope"); }
}
TORCH_CHECK(!bad_scope);
TORCH_CHECK(fun_cnt == 1);
TORCH_CHECK(ts_fun_cnt == 1);
TORCH_CHECK(user_scope_cnt == 1);
TORCH_CHECK(found_function_scope);
TORCH_CHECK(found_method_scope);
TORCH_CHECK(found_user_scope);
}
static TracedTestValues traced_inputs;
static TracedTestValues traced_outputs;
static std::unordered_set<std::string> ts_input_names;
static std::unordered_set<std::string> ts_output_names;
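// The callbacks below fill traced_inputs/traced_outputs with (name, sizes)
// entries for eager ops and record the names of TorchScript functions in
// ts_input_names/ts_output_names.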
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
const RecordFunction& fn) {
if (fn.scope() == RecordScope::FUNCTION) {
auto inputs = fn.inputs();
std::vector<std::vector<int64_t>> sizes;
for (const auto& input : inputs) {
if (input.isTensor()) {
sizes.push_back(input.toTensor().sizes().vec());
} else if (input.isScalar()) {
// NOLINTNEXTLINE(modernize-use-emplace)
sizes.push_back(std::vector<int64_t>());
}
}
traced_inputs.push_back(std::make_tuple(fn.name(), sizes));
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
ts_input_names.insert(fn.name());
}
return nullptr;
}
void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) {
if (fn.scope() == RecordScope::FUNCTION) {
auto outputs = fn.outputs();
std::vector<std::vector<int64_t>> sizes;
for (const auto& output : outputs) {
if (output.isTensor()) {
sizes.push_back(output.toTensor().sizes().vec());
} else if (output.isScalar()) {
sizes.emplace_back();
}
}
traced_outputs.push_back(std::make_tuple(fn.name(), sizes));
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
ts_output_names.insert(fn.name());
}
}
TEST(RecordFunctionTest, TracedTestInputsOutputs) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
// [(fn, [[sizes], [sizes], ...]), ...]
addGlobalCallback(
RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
.needsInputs(true)
.needsOutputs(true));
TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
{
auto t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
auto t2 = invokeTestRecordFunction(t);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
eager_inputs = traced_inputs;
eager_outputs = traced_outputs;
traced_inputs.clear();
traced_outputs.clear();
TORCH_CHECK(ts_input_names.empty());
TORCH_CHECK(ts_output_names.empty());
t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
t2 = invokeTestRecordFunctionJIT(t);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
jit_inputs = traced_inputs;
jit_outputs = traced_outputs;
traced_inputs.clear();
traced_outputs.clear();
}
TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());
checkTracedInputs(eager_inputs);
checkTracedOutputs(eager_outputs);
checkTracedInputs(jit_inputs);
checkTracedOutputs(jit_outputs);
at::clearCallbacks();
}
static int sampled_cb_ctr = 0;
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
if (std::string(fn.name()) == "test") {
++sampled_cb_ctr;
}
return nullptr;
}
static int non_sampled_cb_ctr = 0;
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
if (std::string(fn.name()) == "test") {
++non_sampled_cb_ctr;
}
return nullptr;
}
TEST(RecordFunctionTest, SampledCallbacks) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
// test sampled callbacks
sampled_cb_ctr = 0;
auto setup_sampled_callback = [](double sampling_prob) {
return addGlobalCallback(
RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
};
addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
auto handle = setup_sampled_callback(0.5);
auto run_test_function = []() {
auto t = torch::randn({1, 2, 3}, at::kCPU);
for (auto k = 0; k < 1000; k++) {
invokeTestRecordFunction(t);
}
};
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 1000);
TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);
sampled_cb_ctr = 0;
removeCallback(handle);
handle = setup_sampled_callback(0.0);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 2000);
TORCH_CHECK(sampled_cb_ctr == 0);
sampled_cb_ctr = 0;
removeCallback(handle);
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
handle = setup_sampled_callback(1.0);
run_test_function();
TORCH_CHECK(non_sampled_cb_ctr == 3000);
TORCH_CHECK(sampled_cb_ctr == 1000);
clearCallbacks();
// test the scope of the callbacks
checkScopeCallbacks();
clearCallbacks();
}
TEST(RecordFunctionTest, RecordFunctionGuard) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static std::vector<std::string> fn_names;
static std::mutex guard_mtx;
// check record function guard
addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
std::lock_guard<std::mutex> lock(guard_mtx);
// NOLINTNEXTLINE(modernize-use-emplace)
fn_names.push_back(fn.name());
return nullptr;
}));
{
RecordFunctionGuard g1(false);
{
RECORD_USER_SCOPE("A");
{
RecordFunctionGuard g2(true);
RECORD_USER_SCOPE("B");
{
DisableRecordFunctionGuard g3;
RECORD_USER_SCOPE("C");
}
}
{ RECORD_USER_SCOPE("D"); }
}
}
TORCH_CHECK(fn_names.size() == 1);
TORCH_CHECK(fn_names[0] == "B");
clearCallbacks();
}
static std::vector<size_t> ids;
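// Registers a global callback that appends `id` to `ids` whenever a
// RecordFunction fires, and returns the handle for later removal.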
template <size_t id>
auto add_remove_test_add_cb() {
return addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ids.push_back(id);
return nullptr;
}));
}
TEST(RecordFunctionTest, Callbacks) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
auto h1 = add_remove_test_add_cb<1>();
add_remove_test_add_cb<2>();
auto h3 = add_remove_test_add_cb<3>();
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 3);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
ids.clear();
removeCallback(h1);
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 2);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
ids.clear();
removeCallback(h3);
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
clearCallbacks();
// thread local / global callbacks
ids.clear();
add_remove_test_add_cb<1>();
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(ids[0] == 1);
ids.clear();
auto th = std::thread([]() {
addThreadLocalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ids.push_back(2);
return nullptr;
}));
{ RECORD_USER_SCOPE("test_thread"); }
});
th.join();
TORCH_CHECK(ids.size() == 2);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
ids.clear();
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(ids[0] == 1);
ids.clear();
clearCallbacks();
// START: thread local / global context check callbacks
struct TestContext : public ObserverContext {
int a{0};
std::string b;
};
ids.clear();
{ // START: global test
addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction&
/* unused */) -> std::unique_ptr<at::ObserverContext> {
auto ctx = std::make_unique<TestContext>();
ctx->a = 123;
ctx->b = "test_str";
ids.push_back(1);
return ctx;
},
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx != nullptr);
TORCH_CHECK(ctx->a == 123);
TORCH_CHECK(ctx->b == "test_str");
}));
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(ids[0] == 1);
ids.clear();
} // END: global test
{ // START: thread local test
auto ctx_th = std::thread([]() {
const std::string test_str = "test thread str";
addThreadLocalCallback(RecordFunctionCallback(
[](const RecordFunction&
/* unused */) -> std::unique_ptr<at::ObserverContext> {
auto ctx = std::make_unique<TestContext>();
ctx->a = 234;
ctx->b = "test_thread_str";
ids.push_back(2);
return ctx;
},
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx != nullptr);
TORCH_CHECK(ctx->a == 234);
TORCH_CHECK(ctx->b == "test_thread_str");
}));
// Will call both global and thread local callbacks.
{ RECORD_USER_SCOPE("test_thread"); }
});
ctx_th.join();
TORCH_CHECK(ids.size() == 2);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
ids.clear();
} // END: thread local test
clearCallbacks();
}
TEST(RecordFunctionTest, ShouldRun) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static bool ran = false;
auto handle = addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ran = true;
return nullptr;
}));
{ RECORD_USER_SCOPE("test"); }
EXPECT_TRUE(ran) << "first run didn't happen";
ran = false;
disableCallback(handle);
{ RECORD_USER_SCOPE("test"); }
EXPECT_FALSE(ran) << "second run happened but shouldn't have";
ran = false;
reenableCallback(handle);
{ RECORD_USER_SCOPE("test"); }
EXPECT_TRUE(ran) << "run after re-enable didn't happen";
ran = false;
clearCallbacks();
}
TEST(RecordFunctionTest, Basic) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static std::string recorded_op;
static bool has_ids = false;
// test propagation of TLS callbacks
std::thread t([]() {
RecordFunctionGuard enable_rec_fn;
auto handle = addThreadLocalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
recorded_op = fn.name();
return nullptr;
}));
ThreadLocalState state;
std::thread t_child([state]() {
ThreadLocalStateGuard g_tls(state);
RECORD_USER_SCOPE("test_in_thread");
});
t_child.join();
EXPECT_EQ(recorded_op, "test_in_thread");
removeCallback(handle);
});
t.join();
clearCallbacks();
// test set ids
addGlobalCallback(
RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
has_ids = fn.handle() > 0;
return nullptr;
})
.needsIds(true));
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(has_ids);
clearCallbacks();
has_ids = false;
addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
has_ids = fn.handle() > 0;
return nullptr;
}));
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(!has_ids);
clearCallbacks();
}
TEST(RecordFunctionTest, OperatorNameOverload) {
static std::set<std::string> operator_names;
at::addGlobalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn)
-> std::unique_ptr<at::ObserverContext> {
c10::optional<c10::OperatorName> op_name =
fn.operator_name();
if (op_name.has_value()) {
operator_names.insert(c10::toString(*op_name));
} else {
operator_names.insert("No Operator Name");
}
return nullptr;
})
.scopes({at::RecordScope::FUNCTION}));
auto t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(false);
auto t2 = t.pow(2);
at::clearCallbacks();
EXPECT_TRUE(operator_names.count("No Operator Name") == 0)
<< "Expected that all traced operators had an associated OperatorName object";
EXPECT_TRUE(operator_names.count("aten::randn") == 1)
<< "Expected aten::randn to have been called and recorded, but it was not";
EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1)
<< "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not";
}
class TestThreadLocalDebugInfo : public c10::DebugInfoBase {
public:
int getModelId() const {
return model_id_;
}
void setModelId(int model_id) {
model_id_ = model_id;
}
// NOLINTNEXTLINE(modernize-use-equals-default)
~TestThreadLocalDebugInfo() override {}
private:
int model_id_ = 0;
};
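// Asserts that the TestThreadLocalDebugInfo registered for `kind` carries the
// expected model id.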
void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
auto* debug_info = c10::ThreadLocalDebugInfo::get(kind);
TORCH_CHECK(debug_info != nullptr);
auto* test_debug_info = dynamic_cast<TestThreadLocalDebugInfo*>(debug_info);
TORCH_CHECK(test_debug_info != nullptr);
TORCH_CHECK(test_debug_info->getModelId() == model_id);
}
TEST(ThreadLocalDebugInfoTest, Basic) {
static std::atomic<bool> done{false};
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
debug_info->setModelId(42);
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
}
// check that thread local debug info is propagated through fork calls
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
at::launch([]() {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
done = true;
});
}
while (!done) {
}
// check that thread local debug info is propagated through backward pass
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
done = false;
auto handle = addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
done = true;
return nullptr;
}));
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
auto t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
auto t2 = t.pow(2);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
}
removeCallback(handle);
TORCH_CHECK(done);
// check nested debug info
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
{
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
{
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
debug_info->setModelId(314);
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO_2, debug_info);
{
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
done = false;
at::launch([]() {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
done = true;
});
while (!done) {
}
}
}
}
}
}
TEST(TestSymIntArrayRef, BasicConversion) {
const size_t X = 2, Y = 4, Z = 5;
std::vector<int64_t> tgt_size_v{2, 4, 5};
std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)});
auto a = at::randn({1, 4, 1}, at::kCPU);
auto b = a.expand_symint(tgt_size);
auto c = a.expand(tgt_size_v);
ASSERT_TRUE(torch::allclose(b, c));
}
TEST(TestSymInt, NarrowCopyWithSymbolicInt) {
static const size_t LENGTH = 5;
auto a = at::randn({10}, at::kCPU);
c10::SymInt si(LENGTH);
auto b = a.narrow_copy_symint(0, 0, si);
auto c = a.narrow(0, 0, LENGTH);
ASSERT_TRUE(torch::allclose(b, c));
}
TEST(TestSymInt, NarrowCopy) {
static const size_t LENGTH = 5;
auto a = at::randn({10}, at::kCPU);
auto b = a.narrow_copy(0, 0, LENGTH);
auto c = a.narrow(0, 0, LENGTH);
ASSERT_TRUE(torch::allclose(b, c));
}
TEST(TestSymInt, AddSymbolicInt) {
c10::SymInt a(5);
c10::SymInt b(3);
ASSERT_TRUE((a + b).expect_int() == 8);
}
TEST(FallbackGraphsTest, Basic) {
auto x = at::randn({1}, at::kCPU);
auto y = at::randn({1}, at::kCPU);
auto stack = createStack({x.clone(), y.clone()});
auto graph_string = R"IR(
graph(%0 : Float(1),
%1 : Float(1)):
%2 : Tensor = aten::mul(%0, %1)
%3 : Tensor = aten::mul(%2, %0)
return (%3))IR";
auto graph = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, graph.get());
{
Code code(graph, "");
InterpreterState interpreter{code};
interpreter.run(stack);
}
at::Tensor et;
pop(stack, et);
float ef = et.item<float>();
{
EnableProfilingGuard epg;
GraphFunction f("fallbackGraphs", graph, nullptr);
for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) {
stack.emplace_back(x.clone());
stack.emplace_back(y.clone());
if (i == getNumProfiledRuns()) {
// we will be modifying the profiled graph before the
// ProfilingGraphExecutor optimizes it in the next iteration
auto opt_graph = lastExecutedOptimizedGraph();
// this is safe to do since we are done profiling
ProfilingRecord::removeProfileCounter(opt_graph->block());
replaceBlockWithFallbackGraph(opt_graph->block(), opt_graph->inputs());
auto it = opt_graph->block()->nodes().begin();
ASSERT_EQ(it->kind(), prim::FallbackGraph);
auto fallback = *it++;
ASSERT_EQ(it, opt_graph->block()->nodes().end());
ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph));
testing::FileCheck()
.check("Tensor = aten::mul")
->check("Tensor = aten::mul")
->run(*fallback->g(attr::Subgraph));
}
f.run(stack);
at::Tensor at;
pop(stack, at);
float af = at.item<float>();
ASSERT_EQ(af, ef);
}
auto opt_graph = lastExecutedOptimizedGraph();
testing::FileCheck()
.check("(Tensor) = prim::CallFunction")
->run(*opt_graph);
}
}
// TODO this test wasn't running and is broken.
// TEST(AutogradProfilerTest, Basic) {
// constexpr int batch_size = 4;
// constexpr int input_size = 256;
// constexpr int seq_len = 32;
// int hidden_size = 2 * input_size;
// auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
// auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
// auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
// auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
// auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));
// std::stringstream ss;
// {
// RecordProfile guard(ss);
// for (size_t i = 0; i < 100; ++i) {
// std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
// }
// }
// std::string result = ss.str();
// size_t count = 0;
// for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
// count++, pos++) {
// }
// ASSERT_EQ(count, 200);
// }
TEST(NoneSchemaMatchTest, Basic) {
RegisterOperators reg({
Operator(
"prim::test_none() -> int?",
[](Stack& stack) { push(stack, IValue()); },
aliasAnalysisFromSchema()),
Operator(
"prim::is_none(int? a) -> bool",
[](Stack& stack) {
IValue a = pop(stack);
if (a.isNone()) {
push(stack, true);
} else {
push(stack, false);
}
},
aliasAnalysisFromSchema()),
});
// Constant propagation will run test_none and produce a None,
// testing that its type is set appropriately and schema matching doesn't
// fail when running is_none
auto r = std::make_shared<Graph>();
auto& g = *r;
auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
g.registerOutput(out_bool);
ConstantPropagation(r);
auto nodes = r->block()->nodes();
// check that constant propagation ran without failure
AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
}
static int testPassValue = 0;
void fakePass(std::shared_ptr<Graph>& g) {
testPassValue++;
return;
}
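// Register fakePass at static-initialization time; PassManagementTest checks
// via testPassValue that it actually ran.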
RegisterPass p(fakePass);
TEST(PassManagementTest, Basic) {
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%a):
return (%a))IR",
&*graph);
std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))};
auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
GraphExecutor executor(graph, "");
executor.run(stack);
return stack;
};
run(graph, stack);
// we will not run fusion in simple mode
if (!getExecutorMode()) {
AT_ASSERT(testPassValue);
}
}
static void checkShape(TypePtr typ, std::vector<int64_t> expected) {
auto ptp = typ->expect<TensorType>();
ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
}
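// With prev == true, checks the profiled type recorded on the node producing
// n's first input (typically a prim::profile node); otherwise checks n's own
// output type.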
static void checkShape(
Node* n,
std::vector<int64_t> expected,
bool prev = true) {
auto profile = (prev) ? n->inputs().at(0)->node() : n;
checkShape(profile->output()->type(), expected);
}
void count_(
Block* block,
const std::function<bool(Node* n)>& pred,
size_t& count) {
for (Node* n : block->nodes()) {
if (pred(n)) {
count++;
}
for (Block* ib : n->blocks()) {
count_(ib, pred, count);
}
}
}
size_t countNodes(
const std::shared_ptr<Graph>& graph,
const std::function<bool(Node* n)>& pred) {
size_t count = 0;
count_(graph->block(), pred, count);
return count;
}
bool true_pred(Node* n) {
return true;
};
bool is_loop(Node* n) {
return n->kind() == prim::Loop;
};
TEST(LoopPeelerTest, NoInductionVariableUse) {
// do not use an induction variable explicitly
static const auto str_func_def = R"JIT(
def test_peel_n_times():
    sum = 0
    for i in range(10):
        sum += 2
    return sum
)JIT";
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
auto stack = createStack({});
// peeling loop once
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 20);
}
// test peeling more than one iteration
{
LoopsPeeler peeler(true_pred, 3);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 20);
}
}
TEST(LoopPeelerTest, YesInductionVariableUse) {
// uses the induction variable
static const auto str_func_def = R"JIT(
def test_peel_n_times():
    sum = 0
    for i in range(10):
        sum += i
    return sum
)JIT";
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
auto stack = createStack({});
// peeling loop once
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 45);
}
// test peeling more than one iteration
{
LoopsPeeler peeler(true_pred, 3);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 45);
}
}
TEST(LoopPeelerTest, LoopWithTerminationCondition) {
// tests with explicit termination conditions
static const auto str_func_def = R"JIT(
def test_with_cond_times():
    sum = 0
    i = 0
    while (sum < 2):
        sum += i
        i += 1
    return sum
)JIT";
// the peel changes the termination condition to false
// so the original loop doesn't run
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
auto stack = createStack({});
// peeling 5 iterations should update the termination
// condition to false
{
LoopsPeeler peeler(true_pred, 5);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 3);
}
// the termination condition remains true
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
int num_loops =
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
ASSERT_EQ(num_loops, 2);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 3);
}
}
// tests simple nested loops
TEST(LoopPeelerTest, SimpleNestedLoops) {
static const auto str_func_def = R"JIT(
def test_nested_loops():
    sum = 0
    i = 0
    for i in range(10):
        for j in range(10):
            sum += i + j
    return sum
)JIT";
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
auto stack = createStack({});
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
ASSERT_EQ(countNodes(copy, is_loop), 5);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 900);
}
{
LoopsPeeler peeler(true_pred, 5);
auto copy = f.graph()->copy();
peeler.run(copy);
ASSERT_EQ(countNodes(copy, is_loop), 5);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 900);
}
}
TEST(LoopPeelerTest, SimpleNestedLoops2) {
static const auto str_func_def = R"JIT(
def test_nested_loops():
    sum = 0
    i = 0
    for i in range(10):
        j = 0
        while sum < 2:
            sum += i + j
            j += 1
    return sum
)JIT";
auto cu = compile(str_func_def);
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
auto stack = createStack({});
{
LoopsPeeler peeler(true_pred, 1);
auto copy = f.graph()->copy();
peeler.run(copy);
ASSERT_EQ(countNodes(copy, is_loop), 5);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 3);
}
{
LoopsPeeler peeler(true_pred, 5);
auto copy = f.graph()->copy();
peeler.run(copy);
ASSERT_EQ(countNodes(copy, is_loop), 5);
Code code(copy, "");
InterpreterState interpreter{code};
interpreter.run(stack);
ASSERT_EQ(stack.back().toInt(), 3);
}
}
TEST(JitTracing, Basic) {
constexpr int batch_size = 4;
constexpr int input_size = 256;
int hidden_size = 2 * input_size;
auto input = at::randn({batch_size, input_size}, at::kCPU);
auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
auto graph = build_lstm();
auto stack = createStack({input, hx, cx, w_ih, w_hh});
auto traced = TraceGraph(graph, stack);
// Check that the inputs of traced graph have the same type as the inputs
// specified here.
ASSERT_EQ(*traced->inputs().at(0)->type(), *TensorType::create(input));
ASSERT_EQ(*traced->inputs().at(1)->type(), *TensorType::create(hx));
ASSERT_EQ(*traced->inputs().at(2)->type(), *TensorType::create(cx));
ASSERT_EQ(*traced->inputs().at(3)->type(), *TensorType::create(w_ih));
ASSERT_EQ(*traced->inputs().at(4)->type(), *TensorType::create(w_hh));
Tensor prof_out;
pop(stack, prof_out);
{
stack = createStack({input, hx, cx, w_ih, w_hh});
Code cd(traced, "traced");
InterpreterState is{cd};
is.run(stack);
Tensor traced_out;
pop(stack, traced_out);
ASSERT_TRUE(torch::allclose(prof_out, traced_out));
}
{
stack = createStack({input, hx, cx, w_ih, w_hh});
Code cd(graph, "graph");
InterpreterState is{cd};
is.run(stack);
Tensor scripted_out;
pop(stack, scripted_out);
ASSERT_TRUE(torch::allclose(prof_out, scripted_out));
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
static const auto basic_example = R"JIT(
def basic(x, y):
    a = x + y
    b = x * y
    c = x + 1
    d = a - c
    e = b - c
    return d + e
)JIT";
auto cu = compile(basic_example);
auto& fun = toGraphFunction(cu->get_function("basic"));
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
auto x = at::randn({2, 3}, at::kCPU);
auto y = at::randn({2, 3}, at::kCPU);
auto stack = createStack({x, y});
// introduce some profiling information
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
auto copy = pr->profiled_graph_->copy();
ProfilingRecord::removeProfileCounter(copy->block());
InsertGuards(copy);
auto nodes = copy->block()->nodes();
auto guard = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
return n->kind() == prim::Guard;
});
ASSERT_NE(guard, nodes.end());
ASSERT_EQ(
guard->input()->type()->expectRef<TensorType>().sizes().size(),
c10::nullopt);
checkShape(*guard, {2, 3}, false);
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 12);
// now eliminate as many guards as possible
// we should be left with two guards on x and y's defs
EliminateRedundantGuards(copy);
num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 2);
}
TEST(InsertBailOutsTest, Basic) {
static const auto basic_example = R"JIT(
def basic_loop(x, y):
    a = x + 1
    b = y + 2
    c = x + y + 3
    for i in range(10):
        a = a + b
        # invariant
        d = b * c
        #
        a = a - d
    e = a + 4
    return e
)JIT";
auto cu = compile(basic_example);
auto& fun = toGraphFunction(cu->get_function("basic_loop"));
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
auto x = at::randn({2, 3}, at::kCPU);
auto y = at::randn({2, 3}, at::kCPU);
auto stack = createStack({x, y});
// introduce some profiling information
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
auto copy = pr->profiled_graph_->copy();
ProfilingRecord::removeProfileCounter(copy->block());
InsertGuards(copy);
EliminateRedundantGuards(copy);
auto nodes = copy->block()->nodes();
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
auto num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
ASSERT_EQ(num_guards, 3);
InsertBailOuts(copy);
auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; };
auto num_bailouts = std::count_if(nodes.begin(), nodes.end(), is_bailout);
ASSERT_EQ(num_guards, num_bailouts);
std::vector<Node*> bailouts(num_bailouts);
std::copy_if(nodes.begin(), nodes.end(), bailouts.begin(), is_bailout);
for (auto blo : bailouts) {
ASSERT_EQ(blo->inputs().at(0)->node()->kind(), prim::BailoutTemplate);
}
}
TEST(ProfilerTest, Basic) {
constexpr int batch_size = 4;
constexpr int input_size = 256;
int hidden_size = 2 * input_size;
auto input = at::randn({batch_size, input_size}, at::kCPU);
auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
auto g = build_lstm();
auto stack = createStack({input, hx, cx, w_ih, w_hh});
auto& opt_graph = *g.get();
ArgumentSpecCreator arg_spec_creator(opt_graph);
ArgumentSpec spec =
arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
arg_spec_creator.specializeTypes(opt_graph, spec);
auto pr = ProfilingRecord::instrumentGraph(g);
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
// profiled types are stored as attributes and show up in the dump, e.g.
// Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1],
// requires_grad=0, device=cpu)
testing::FileCheck()
.check("Tensor = prim::profile[profiled_type")
->check_same("256")
->run(*pr->profiled_graph_);
auto begin = pr->profiled_graph_->block()->nodes().begin();
auto end = pr->profiled_graph_->block()->nodes().end();
auto mm =
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::add; });
ASSERT_NE(mm, end);
std::vector<int64_t> mm_expected{4, 2048};
std::vector<int64_t> eltwise{4, 512};
checkShape(mm->inputs().at(0)->node()->ty(attr::profiled_type), mm_expected);
auto mul_n =
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mul; });
ASSERT_NE(mul_n, end);
checkShape(mul_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
auto tanh_n =
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; });
ASSERT_NE(tanh_n, end);
checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
}
TEST(ProfilerTest, OptionalProfiling) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%inp : Tensor,
%weight : Tensor,
%bias : Tensor?):
%1 : Tensor = aten::linear(%inp, %weight, %bias)
return (%1))IR",
&*graph,
vmap);
auto pr = ProfilingRecord::instrumentGraph(graph);
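// Allow two profiled runs so that both a real bias tensor and a None bias
// are observed below.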
pr->profiling_count_ = 2;
auto input = torch::randn({1, 2});
auto weight = torch::randn({2, 2});
auto bias = torch::randn({1, 2});
auto stack = createStack({input, weight, bias});
Code cd(pr->profiled_graph_, "");
InterpreterState is{cd};
is.run(stack);
testing::FileCheck()
.check_count("Tensor? = prim::profile[profiled_type", 1, true)
->run(*pr->profiled_graph_);
// make sure we recorded the shape
auto begin = pr->profiled_graph_->block()->nodes().begin();
auto end = pr->profiled_graph_->block()->nodes().end();
auto linear = std::find_if(
begin, end, [](Node* n) { return n->kind() == aten::linear; });
ASSERT_NE(linear, end);
std::vector<int64_t> bias_expected_shape = {1, 2};
auto profiled_bias = linear->namedInput("bias")->node();
checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
ASSERT_EQ(0, profiled_bias->i(attr::seen_none));
auto none_bias = c10::IValue();
stack.clear();
stack.emplace_back(input);
stack.emplace_back(weight);
stack.emplace_back(none_bias);
is = InterpreterState{cd};
is.run(stack);
// make sure we recorded that "None" was seen.
begin = pr->profiled_graph_->block()->nodes().begin();
end = pr->profiled_graph_->block()->nodes().end();
linear = std::find_if(
begin, end, [](Node* n) { return n->kind() == aten::linear; });
ASSERT_NE(linear, end);
profiled_bias = linear->namedInput("bias")->node();
checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
ASSERT_EQ(1, profiled_bias->i(attr::seen_none));
}
TEST(CallStackTest, Basic) {
const auto text = R"(
def ham(x):
return x/7
def bar(x):
return x*3
def baz(x):
return ham(x)*x
def foo(x):
return bar(x)*baz(x)*11
)";
auto cu = compile(text);
const auto& foo = toGraphFunction(cu->get_function("foo"));
for (Node* n : foo.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::i) {
continue;
}
int v = n->i(attr::value);
switch (v) {
case 3: {
// Const 3 comes from function 'bar', which gets inlined to 'foo'.
// The callstack for the corresponding node should contain only the
// function 'bar'.
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 1);
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("bar"));
break;
}
case 7: {
// Const 7 comes from function 'ham', which gets inlined to 'baz',
// which is then inlined to 'foo'. The callstack for the corresponding
// node should contain these two functions.
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 2);
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("baz"));
ASSERT_EQ(std::get<0>(callstack_vector[1]), &cu->get_function("ham"));
break;
}
case 11: {
// Const 11 comes from function 'foo', which is not inlined anywhere
// and thus it should not have a callstack.
ASSERT_FALSE(n->callstack());
break;
}
}
}
}
// Check that inlining doesn't corrupt callstack of the callee's nodes.
const auto& baz = toGraphFunction(cu->get_function("baz"));
for (Node* n : baz.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::i) {
continue;
}
int v = n->i(attr::value);
ASSERT_TRUE(v == 7);
// Const 7 comes from function 'ham', which gets inlined to 'baz'. 'baz'
// was also inlined into 'foo', but when looking at the graph of 'baz' we
// should only see a callstack of depth 1 (containing only 'ham').
ASSERT_TRUE(n->callstack());
auto callstack_vector = (*n->callstack())->vec();
ASSERT_EQ(callstack_vector.size(), 1);
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("ham"));
}
}
}
TEST(CallStackTest, Caching) {
const auto text = R"(
def a(x):
print("a1")
print("a2")
return x
def b(x):
print("b1")
print("b2")
a(x)
return x
def c(x):
print("c1")
print("c2")
b(x)
return x
)";
auto cu = compile(text);
const auto& baz = toGraphFunction(cu->get_function("c"));
std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
for (Node* n : baz.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
n->kindOf(attr::value) != AttributeKind::s) {
continue;
}
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
std::string v = n->s(attr::value);
if (n->callstack()) {
callstack_objects[v] = n->callstack()->get();
}
}
}
// We expect to see nodes prim::Constant[value="a1"] and
// prim::Constant[value="a2"] inlined to function 'c'. Their callstacks are
// the same (a->b->c), so we want to make sure we're not creating different
// callstack entries for them.
ASSERT_TRUE(callstack_objects.count("a1") && callstack_objects.count("a2"));
ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2"));
}
TEST(InlinedCallStackTest, BlockAnnotation) {
Module a("A");
a.define(R"(
def forward(self, x, y, z: int):
if (z == 1):
return x + y
else:
return x * y
)");
Module b("B");
b.define(R"(
def forward(self, x):
return x + 2
)");
Module c("C");
c.register_module("A0", a);
c.register_module("B0", b);
c.define(R"(
def forward(self, x, y, z: int):
return self.A0.forward(x, y, z) + self.B0.forward(x)
)");
auto graph =
toGraphFunction(c.get_method("forward").function()).optimized_graph();
std::stringstream add_ss, mul_ss;
for (Node* n : graph->nodes()) {
if (n->kind() == prim::If) {
for (Block* block : n->blocks()) {
for (Node* if_node : block->nodes()) {
if (if_node->kind() == aten::add) {
for (const auto& e : if_node->callstack().value()->vec()) {
add_ss << std::get<1>(e);
}
add_ss << if_node->sourceRange();
}
if (if_node->kind() == aten::mul) {
for (const auto& e : if_node->callstack().value()->vec()) {
mul_ss << std::get<1>(e);
}
mul_ss << if_node->sourceRange();
}
}
}
}
}
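// In A's forward, 'return x + y' is line 4 and 'return x * y' is line 6; the
// inlined call site in C's forward is line 3, so both callstack strings
// should mention it.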
ASSERT_NE(add_ss.str().find("line 3"), std::string::npos);
ASSERT_NE(add_ss.str().find("line 4"), std::string::npos);
ASSERT_NE(
add_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
ASSERT_NE(add_ss.str().find("return x + y"), std::string::npos);
ASSERT_NE(mul_ss.str().find("line 3"), std::string::npos);
ASSERT_NE(mul_ss.str().find("line 6"), std::string::npos);
ASSERT_NE(
mul_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
ASSERT_NE(mul_ss.str().find("return x * y"), std::string::npos);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InlinedCallStackTest, SelfCallMethods) {
Module a("A");
a.define(R"(
def my_new_method(self, x):
return x * 3
def forward_impl_(self, x, y):
return self.my_new_method(x) + y
def forward(self, x, y):
y = y + 2
return self.forward_impl_(x, y)
)");
Module b("B");
b.define(R"(
def forward(self, x):
return x + 2
)");
Module c("C");
c.register_module("A0", a);
c.register_module("B0", b);
c.define(R"(
def call_b(self, x):
return self.B0.forward(x)
def forward(self, x, y):
return self.A0.forward(x, y) + self.call_b(x)
)");
auto graph =
toGraphFunction(c.get_method("forward").function()).optimized_graph();
std::unordered_map<std::string, size_t> module_hierarchies;
for (Node* n : graph->nodes()) {
auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
if (module_hierarchies.count(hierarchy) == 0) {
module_hierarchies[hierarchy] = 0;
}
module_hierarchies[hierarchy] += 1;
}
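// getNodesModuleHierarchy renders the chain of submodule instances a node was
// inlined through, e.g. "SELF(C).B0(B)" for ops coming from the call into
// B0.forward.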
ASSERT_EQ(module_hierarchies["A0(A)"], 2);
ASSERT_EQ(module_hierarchies["A0(A).SELF(A).SELF(A)"], 2);
ASSERT_EQ(module_hierarchies["A0(A).SELF(A)"], 1);
ASSERT_EQ(module_hierarchies["SELF(C)"], 1);
ASSERT_EQ(module_hierarchies["SELF(C).B0(B)"], 1);
}
TEST(AutogradSymbolsTest, Basic) {
Symbol sym = Symbol::fromQualString("aten::test_symbol");
Graph graph;
auto node = graph.create(sym);
TORCH_CHECK(canRunWithAutograd(node));
sym = Symbol::fromQualString("prim::test_symbol");
node = graph.create(sym);
TORCH_CHECK(canRunWithAutograd(node));
sym = Symbol::fromQualString("prim::FusionGroup");
node = graph.create(sym);
TORCH_CHECK(!canRunWithAutograd(node));
sym = Symbol::fromQualString("custom::test_symbol");
node = graph.create(sym);
TORCH_CHECK(!canRunWithAutograd(node));
}
TEST(DefaultArgTypeHintingTest, Basic) {
const auto text_non_hinted = R"(
def a(x, y=1):
print("a1")
print("a2")
return x
)";
const auto text_hinted = R"(
def a(x, y:int=1):
print("a1")
print("a2")
return x
)";
// Compiling without a type hint on the defaulted argument is expected to
// throw.
try {
compile(text_non_hinted);
ASSERT_TRUE(false); // should not get here
} catch (const std::exception& c) {
}
auto cu = compile(text_hinted);
}
// Basic set case.
TEST(FuturesTest, Basic) {
auto f1 = c10::make_intrusive<Future>(IntType::get());
ASSERT_FALSE(f1->completed());
ASSERT_FALSE(f1->hasValue());
int32_t sat1 = 0;
int32_t sat2 = 0;
f1->addCallback([&](Future& /* unused */) { ++sat1; });
f1->markCompleted(43);
ASSERT_TRUE(f1->completed());
ASSERT_TRUE(f1->hasValue());
ASSERT_FALSE(f1->hasError());
ASSERT_EQ(sat1, 1);
ASSERT_EQ(f1->constValue().toInt(), 43);
ASSERT_EQ(f1->value().toInt(), 43);
f1->addCallback([&](Future& /* unused */) { ++sat2; });
ASSERT_EQ(sat1, 1);
ASSERT_EQ(sat2, 1);
}
// Sparse CUDA tensor test
TEST(FuturesTest, SparseTensor) {
// Skip test if CUDA is not available.
bool has_cuda = at::globalContext().hasCUDA();
if (!has_cuda) {
LOG(INFO) << "CUDA not available, skipping test";
return;
}
for (int i = 0; i < 2; ++i) {
auto f = c10::make_intrusive<Future>(TensorType::get());
at::TensorOptions opts = at::TensorOptions().device(at::DeviceType::CUDA);
auto sparse_tensor = i == 0 ? at::ones(10).to_sparse()
: at::sparse_coo_tensor(
at::arange(10).unsqueeze(0).to(at::kLong),
at::ones({10, 10}),
opts);
// Runs storage extraction for sparse CUDA tensors
f->markCompleted(sparse_tensor);
ASSERT_TRUE(f->completed());
ASSERT_FALSE(f->hasError());
}
}
// Basic error cases.
TEST(FuturesTest, Error) {
auto f1 = c10::make_intrusive<Future>(IntType::get());
int sat1 = 0;
int sat2 = 0;
f1->addCallback([&](Future& /* unused */) { ++sat1; });
f1->setError(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
ASSERT_EQ(sat1, 1);
ASSERT_TRUE(f1->completed());
ASSERT_TRUE(f1->hasError());
ASSERT_FALSE(f1->hasValue());
try {
(void)f1->value();
ASSERT_TRUE(false); // Supposed to throw.
} catch (const std::exception& e) {
ASSERT_TRUE(strcmp(e.what(), "Failed") == 0);
}
f1->addCallback([&](Future& /* unused */) { ++sat2; });
ASSERT_EQ(sat1, 1);
ASSERT_EQ(sat2, 1);
f1->setErrorIfNeeded(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
ASSERT_EQ(sat1, 1);
ASSERT_EQ(sat2, 1);
try {
(void)f1->constValue();
ASSERT_TRUE(false); // Supposed to throw.
} catch (const std::exception& e) {
// Original error should be logged.
ASSERT_TRUE(std::string(e.what()).find("Failed") != std::string::npos);
}
}
// then
TEST(FuturesTest, Then) {
auto f1 = c10::make_intrusive<Future>(IntType::get());
auto f2 = f1->then(
[](Future& f1) -> IValue { return f1.constValue().toInt() + 1; },
IntType::get());
auto f3 = f2->then(
[](Future& f2) -> IValue { return f2.constValue().toInt() * 3; },
IntType::get());
bool done = false;
f3->addCallback([&done](Future& f3) {
ASSERT_EQ(f3.constValue().toInt(), (42 + 1) * 3);
done = true;
});
ASSERT_FALSE(done);
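// Completing the root future runs both chained then() continuations inline:
// (42 + 1) * 3 = 129.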
f1->markCompleted(42);
ASSERT_TRUE(done);
}
// collectAll()
TEST(FuturesTest, CollectAll) {
auto s1 = c10::make_intrusive<Future>(IntType::get());
auto s2 = c10::make_intrusive<Future>(IntType::get());
auto s3 = c10::make_intrusive<Future>(IntType::get());
// Empty case
c10::List<intrusive_ptr<ivalue::Future>> futures(
FutureType::create(IntType::get()));
auto c1 = collectAll(futures);
ASSERT_TRUE(c1->completed());
ASSERT_EQ(c1->value().toList().size(), 0);
ASSERT_TRUE(
*(c1->value().toList().elementType()) ==
*FutureType::create(IntType::get()));
// 1-element, initially not completed.
futures.push_back(s1);
auto c2 = collectAll(futures);
ASSERT_FALSE(c2->completed());
s1->markCompleted(5);
ASSERT_TRUE(c2->completed());
ASSERT_EQ(c2->value().toList().size(), 1);
ASSERT_TRUE(
*(c2->value().toList().elementType()) ==
*FutureType::create(IntType::get()));
ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5);
// 1-element, already completed
auto c3 = collectAll(futures);
ASSERT_TRUE(c3->completed());
ASSERT_EQ(c3->value().toList().size(), 1);
ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5);
// 3 elements.
futures.push_back(s2);
futures.push_back(s3);
auto c4 = collectAll(futures);
ASSERT_FALSE(c4->completed());
s3->markCompleted(7);
ASSERT_FALSE(c4->completed());
s2->markCompleted(6);
ASSERT_TRUE(c4->completed());
ASSERT_EQ(c4->value().toList().size(), 3);
ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5);
ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6);
ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7);
ASSERT_TRUE(
*(c4->value().toList().elementType()) ==
*FutureType::create(IntType::get()));
// Handle exception in the list.
auto s4 = c10::make_intrusive<Future>(IntType::get());
futures.push_back(s4);
auto c5 = collectAll(futures);
ASSERT_FALSE(c5->completed());
s4->setError(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
ASSERT_TRUE(c5->completed());
try {
c5->value();
ASSERT_TRUE(false); // supposed to throw
} catch (const std::exception& e) {
ASSERT_EQ(std::string(e.what()), "Failed");
}
}
// collectAny()
TEST(FuturesTest, CollectAny) {
auto s1 = c10::make_intrusive<Future>(IntType::get());
// Empty case
c10::List<intrusive_ptr<ivalue::Future>> futures(
FutureType::create(IntType::get()));
auto c1 = collectAny(futures);
ASSERT_TRUE(c1->completed());
// 1 element, not yet satisfied
futures.push_back(s1);
auto c2 = collectAny(futures);
ASSERT_FALSE(c2->completed());
s1->markCompleted(5);
ASSERT_TRUE(c2->completed());
ASSERT_TRUE(c2->value().isInt());
ASSERT_EQ(c2->value().toInt(), 5);
// 1 element already satisfied.
auto c3 = collectAny(futures);
ASSERT_TRUE(c3->completed());
ASSERT_TRUE(c3->value().isInt());
ASSERT_EQ(c3->value().toInt(), 5);
// 2 elements
futures.clear();
auto s2 = c10::make_intrusive<Future>(IntType::get());
auto s3 = c10::make_intrusive<Future>(IntType::get());
futures.push_back(s2);
futures.push_back(s3);
auto c4 = collectAny(futures);
ASSERT_FALSE(c4->completed());
s3->markCompleted(7);
ASSERT_TRUE(c4->completed());
ASSERT_EQ(c4->value().toInt(), 7);
s2->markCompleted(1);
ASSERT_EQ(c4->value().toInt(), 7);
}
TEST(TLSFutureCallbacksTest, Basic) {
// cb that verifies the profiler is enabled
auto profilerEnabledCb = [](Future& /* unused */) {
ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
};
// test running callbacks with propagation of TLS state.
{
// Enable the profiler in this thread
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
s1->addCallback(wrapPropagateTLSState(profilerEnabledCb));
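// wrapPropagateTLSState captures the current thread-local state (including
// the enabled profiler), so the callback still observes it when
// markCompleted() runs on another thread.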
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
// Since we join here, we can ensure that all callbacks corresponding to
// markCompleted() have finished.
t.join();
torch::autograd::profiler::disableProfilerLegacy();
}
// then() with TLS State
{
// Enable the profiler in this thread
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
auto s2 = s1->then(
wrapPropagateTLSState([&profilerEnabledCb](Future& s1) {
profilerEnabledCb(s1);
return at::IValue(1);
}),
IntType::get());
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
t.join();
s2->wait();
torch::autograd::profiler::disableProfilerLegacy();
}
}
TEST(ProfilerDisableInCallbackTest, Basic) {
// cb that verifies the profiler is enabled
auto profilerEnabledCb = []() {
ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
};
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
auto verifyProfilerCb =
wrapPropagateTLSState([&profilerEnabledCb](Future& /* unused */) {
// Ensure the profiler is still enabled in this thread.
profilerEnabledCb();
auto t1 = torch::ones({2, 2});
auto t2 = torch::ones({2, 2});
torch::add(t1, t2);
// Don't clean up TLS state; just consolidate the results.
auto opts =
torch::autograd::profiler::ProfilerDisableOptions(false, true);
auto thread_event_lists =
// NOLINTNEXTLINE(performance-move-const-arg)
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
// Ensure that the events from this thread are still profiled and we
// obtain the expected events in our consolidated list when calling
// disableProfilerLegacy().
bool found_ones = false;
bool found_add = false;
for (const auto& li : thread_event_lists) {
for (const auto& evt : li) {
if (strcmp(evt.name(), "aten::add") == 0) {
found_add = true;
} else if (strcmp(evt.name(), "aten::ones") == 0) {
found_ones = true;
}
}
if (found_add && found_ones) {
break;
}
}
ASSERT_TRUE(found_ones);
ASSERT_TRUE(found_add);
});
s1->addCallback(verifyProfilerCb);
// Disable the profiler, but do not consolidate results in the main thread.
auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
// NOLINTNEXTLINE(performance-move-const-arg)
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
t.join();
// Similar to above test, but verifies correctness in the case where
// continuation runs on the main thread.
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
s1 = c10::make_intrusive<Future>(IntType::get());
s1->addCallback(verifyProfilerCb);
// Runs callback inline
s1->markCompleted(at::IValue(1));
opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
// NOLINTNEXTLINE(performance-move-const-arg)
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
}
TEST(RecordDebugHandles, Basic) {
// Enable the profiler in this thread
const std::set<torch::autograd::profiler::ActivityType> activities(
{torch::autograd::profiler::ActivityType::CPU});
torch::autograd::profiler::prepareProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
activities);
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
activities);
{
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
float x{5.9999}, y{2.1212};
float z = x / y;
(void)z;
}
{
RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
float x{5.9999}, y{2.1212};
float z = x / y;
(void)z;
}
auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
const auto& kineto_events = profiler_results_ptr->events();
size_t my_events{0};
for (const auto& e : kineto_events) {
if (e.name() == "my_function") {
ASSERT_EQ(e.debugHandle(), 42);
my_events++;
} else if (e.name() == "not_my_function") {
ASSERT_EQ(e.debugHandle(), -1);
my_events++;
}
}
ASSERT_EQ(my_events, 2);
}
TEST(RecordDebugHandles, ScopedCallbacks) {
// Enable the profiler in this thread
torch::autograd::profiler::prepareProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU});
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU});
{
auto a = torch::rand({128, 128});
auto b = torch::rand({128, 128});
auto c = a + b;
}
auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
ASSERT_TRUE(profiler_results_ptr->events().size() > 0);
// Enable the profiler in this thread
torch::autograd::profiler::prepareProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU});
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU},
{at::RecordScope::LITE_INTERPRETER});
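// With callbacks restricted to LITE_INTERPRETER scopes, the plain aten ops
// below should not produce any profiler events.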
{
auto a = torch::rand({128, 128});
auto b = torch::rand({128, 128});
auto c = a + b;
}
profiler_results_ptr = torch::autograd::profiler::disableProfiler();
ASSERT_TRUE(profiler_results_ptr->events().size() == 0);
torch::autograd::profiler::prepareProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU});
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::KINETO, false, false),
{torch::autograd::profiler::ActivityType::CPU},
{at::RecordScope::LITE_INTERPRETER});
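// Only the RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS block below should
// be recorded; the user scope and the aten ops fall outside LITE_INTERPRETER.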
{
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
auto a = torch::rand({128, 128});
auto b = torch::rand({128, 128});
auto c = a + b;
}
{
RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
auto a = torch::rand({128, 128});
auto b = torch::rand({128, 128});
auto c = a + b;
}
profiler_results_ptr = torch::autograd::profiler::disableProfiler();
const auto& kineto_events = profiler_results_ptr->events();
for (const auto& e : kineto_events) {
if (e.name() == "my_function") {
ASSERT_EQ(e.debugHandle(), 42);
}
}
ASSERT_TRUE(profiler_results_ptr->events().size() == 1);
}
TEST(IValueKWargsTest, Basic) {
const auto text = R"(
def foo(a : int, b : int, c : int = 4):
return a + 2*b + 3*c
)";
auto cu = compile(text);
auto result = cu->get_function("foo")({1}, {{"b", 3}});
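// a=1 (positional), b=3 (keyword), c=4 (default): 1 + 2*3 + 3*4 = 19.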
ASSERT_EQ(result.toInt(), 19);
}
TEST(ComputeFlopsTest, Basic) {
uint64_t flops = 0;
// Test unknown operator
std::unordered_map<std::string, c10::IValue> extra_args;
flops = torch::profiler::impl::computeFlops(
std::string("aten::unknown"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::conv2d
extra_args.clear();
std::vector<int64_t> input_size = {4, 5, 6, 7};
std::vector<int64_t> weight_size = {3, 5, 2, 1};
std::vector<int64_t> padding = {1, 0};
std::vector<int64_t> stride = {1, 1};
std::vector<int64_t> dilation = {0, 0};
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
extra_args["groups"] = 1;
extra_args["padding"] = at::IValue(at::IntArrayRef(padding));
extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
extra_args["dilation"] = at::IValue(at::IntArrayRef(dilation));
flops = torch::profiler::impl::computeFlops(
std::string("aten::conv2d"), extra_args);
ASSERT_EQ(flops, 13440);
// Test aten::conv2d fail
input_size = {4, 5, 6, 7};
weight_size = {4, 5, 6};
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
flops = torch::profiler::impl::computeFlops(
std::string("aten::conv2d"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::conv2d fail 2
weight_size = {3, 5, 2, 1};
stride = {0, 0};
extra_args["weight_size"] = at::IValue(at::IntArrayRef(input_size));
extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
flops = torch::profiler::impl::computeFlops(
std::string("aten::conv2d"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::conv2d fail 3
extra_args.clear();
input_size = {4, 5, 6, 7};
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
flops = torch::profiler::impl::computeFlops(
std::string("aten::conv2d"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::mm
extra_args.clear();
std::vector<int64_t> mat1_sizes = {3, 4, 5, 6};
std::vector<int64_t> mat2_sizes = {6, 5, 4, 3};
extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
flops =
torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
ASSERT_EQ(flops, 43200);
// Test aten::addmm
flops = torch::profiler::impl::computeFlops(
std::string("aten::addmm"), extra_args);
ASSERT_EQ(flops, 43200);
// Test aten::bmm
extra_args.clear();
mat1_sizes = {7, 5, 6};
mat2_sizes = {7, 6, 3};
extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
flops =
torch::profiler::impl::computeFlops(std::string("aten::bmm"), extra_args);
ASSERT_EQ(flops, 1260);
// Test aten::baddbmm
flops = torch::profiler::impl::computeFlops(
std::string("aten::baddbmm"), extra_args);
ASSERT_EQ(flops, 1260);
// Test mm out of range
extra_args.clear();
flops =
torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
ASSERT_EQ(flops, 0);
// Test aten::add.Tensor
extra_args.clear();
std::vector<int64_t> mat_sizes = {3, 4, 5, 6};
extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
flops =
torch::profiler::impl::computeFlops(std::string("aten::add"), extra_args);
ASSERT_EQ(flops, 360);
// Test aten::mul.Tensor
extra_args.clear();
mat_sizes = {3, 4, 5, 6};
extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
flops =
torch::profiler::impl::computeFlops(std::string("aten::mul"), extra_args);
ASSERT_EQ(flops, 360);
}
TEST(TestConstant, TensorGrad) {
auto graph = std::make_shared<Graph>();
IValue ten = torch::randn({3, 5}).requires_grad_(true);
auto con = tryInsertConstant(*graph, ten);
ASSERT_TRUE(con == c10::nullopt);
}
TEST(TestMutation, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%9 : int = prim::Constant[value=4]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%7 : Tensor = aten::add_(%x.3, %2, %2)
%y.1 : Tensor = aten::add(%x.3, %9, %2)
return (%y.1))IR",
&*graph,
vmap);
RemoveTensorMutation(graph, [](Node*) { return false; });
testing::FileCheck().check("aten::add_")->run(*graph);
RemoveTensorMutation(graph, [](Node*) { return true; });
testing::FileCheck().check_not("aten::add_")->run(*graph);
}
TEST(TestInplaceToFunctionalActivation, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%y : Tensor = aten::relu_(%x.3)
return (%y))IR",
&*graph,
vmap);
InplaceToFunctionalActivation(graph);
testing::FileCheck().check("aten::relu")->run(*graph);
testing::FileCheck().check_not("aten::relu_")->run(*graph);
}
TEST(TestRegisterShapeOp, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph():
%2 : int = prim::Constant[value=5]()
%3: int[] = prim::ListConstruct(%2, %2)
return (%3))IR",
&*graph,
vmap);
auto g2 = std::make_shared<Graph>();
parseIR(
R"IR(
graph():
%2 : Tensor = prim::MakeTestTensor()
return (%2))IR",
&*g2,
vmap);
const FunctionSchema& schema = g2->nodes().begin()->schema();
torch::jit::RegisterShapeComputeGraphForSchema(schema, graph);
PropagateShapesOnGraph(g2);
testing::FileCheck().check("5, 5")->run(*g2);
}
TEST(TestFunctionalToInplaceActivation, Basic) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%y : Tensor = aten::relu(%x.3)
return (%y))IR",
&*graph,
vmap);
FunctionalToInplaceActivation(graph);
testing::FileCheck().check("aten::relu_")->run(*graph);
testing::FileCheck().check_not("aten::relu(")->run(*graph);
}
TEST(TestFunctionExecutor, SimpleExecutorTest) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%x.3 : Tensor = aten::add(%x.1, %2, %2)
%y : Tensor = aten::relu(%x.3)
return (%y))IR",
&*graph);
{
auto func = torch::make_unique<GraphFunction>(
"name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::PROFILING);
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
Stack stack = {a};
func->run(stack);
auto g = lastExecutedOptimizedGraph();
testing::FileCheck()
.check("prim::profile")
->check("aten::add")
->check("aten::relu")
->run(*g);
}
{
auto func = torch::make_unique<GraphFunction>(
"name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::SIMPLE);
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
Stack stack = {a};
func->run(stack);
auto g = func->getDebugState().graph;
testing::FileCheck()
.check_not("prim::profile")
->check("aten::add")
->check("aten::relu")
->run(*g);
}
}
TEST(TestFunctionExecutor, RunDecompositionTest) {
static auto* func = torch::jit::GetDecompositionExecutor(
"aten::var(Tensor self, bool unbiased=True) -> Tensor");
for (bool unbiased : {true, false}) {
auto input = at::rand({4, 4});
Stack stack = {input, unbiased};
func->run(stack);
at::Tensor out = pop(stack).toTensor();
ASSERT_TRUE(at::allclose(out, input.var(unbiased)));
}
}
TEST(TestShapeGraphLinting, Basic) {
auto schemas = RegisteredShapeComputeSchemas();
for (const auto& schema : schemas) {
// arange does not actually support complex, leave as
// union[int, float] for now
if (schema->name() == "aten::arange") {
continue;
}
auto g = shapeComputeGraphForSchema(*schema);
TORCH_INTERNAL_ASSERT(g);
LintShapeComputeGraph(schema, *g);
}
}
// TODO: move to test_kernel when global settings are explicit
// fusion parameters
class Composed : public ::testing::Test {
public:
void SetUp() override {
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
}
};
TEST_F(Composed, ComposedOp) {
struct WithCPUFuser {
WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
overrideCanFuseOnCPU(val);
}
~WithCPUFuser() {
overrideCanFuseOnCPU(cpuFuserEnabled);
}
bool cpuFuserEnabled;
};
#ifdef TORCH_ENABLE_LLVM
const auto graph_string = R"IR(
graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
%1 : Float(5, 3, strides=[1, 5], device=cpu)):
%2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
%3 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %2)
%4 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %3)
return (%3, %4))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);
// wrong input sizes so we hit the fallback path
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto b = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat))
.transpose(0, 1);
auto ref1 = a * (a * b);
auto ref2 = a * ref1;
WithCPUFuser g(true);
bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
FuseTensorExprs(
graph,
/*min_group_size*/ 2,
/*add_composed_op*/ true,
/*fuse_to_dynamic_shapes*/ true);
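// With add_composed_op/fuse_to_dynamic_shapes the fused kernel can also be
// fed pre-allocated output buffers on the stack; the second run below
// exercises that path.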
Code code(graph, "");
InterpreterState interpreter{code};
std::vector<IValue> stack = {a, b};
interpreter.run(stack);
at::Tensor out2 = pop(stack).toTensor();
at::Tensor out1 = pop(stack).toTensor();
ASSERT_TRUE(at::allclose(ref1, out1));
ASSERT_TRUE(at::allclose(ref2, out2));
auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
stack = {inp_1, inp_2, a, b};
InterpreterState interpreter2{code};
interpreter2.run(stack);
out2 = pop(stack).toTensor();
out1 = pop(stack).toTensor();
ASSERT_TRUE(at::allclose(ref1, out1));
ASSERT_TRUE(at::allclose(ref2, out2));
// inp_1 is on the bottom of the stack and corresponds to the second
// output; inp_2 is on top and corresponds to the first output.
ASSERT_TRUE(at::allclose(inp_1, ref2));
ASSERT_TRUE(at::allclose(inp_2, ref1));
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = fusable_on_device;
#endif
}
TEST(ConstantPropagation, CustomClassesCanBePropagated) {
const auto src = R"IR(
graph():
%none: NoneType = prim::Constant()
%dim: int = prim::Constant[value=3]()
%shape: int[] = prim::ListConstruct(%dim, %dim)
%weight: Tensor = aten::ones(%shape, %none, %none, %none, %none)
%scale: float = prim::Constant[value=1.]()
%zero_point: int = prim::Constant[value=0]()
%dtype: int = prim::Constant[value=12]()
%weight_q: Tensor = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
%params: __torch__.torch.classes.quantized.LinearPackedParamsBase = quantized::linear_prepack(%weight_q, %none)
return (%params)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(src, graph.get(), vmap);
ConstantPropagation(graph);
testing::FileCheck().check_not("quantized::linear_prepack")->run(*graph);
}
} // namespace jit
} // namespace torch