// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
#include "test_utils.h"
#include <ATen/core/ivalue.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/allclose.h>
#endif
#include <memory>
#include <unordered_map>
using namespace torch::jit;
using namespace torch;
using c10::IValue;
namespace torch {
namespace jit {
namespace test {
namespace {
class GraphExecutorWrapper {
public:
GraphExecutorWrapper() = default;
explicit GraphExecutorWrapper(const std::shared_ptr<Graph>& graph)
: graph_exec_(graph, "") {}
c10::IValue operator()(const std::vector<c10::IValue>& args) {
Stack stack(args);
graph_exec_.run(stack);
if (stack.size() == 1) {
return stack[0];
}
return c10::ivalue::Tuple::create(stack);
}
private:
GraphExecutor graph_exec_;
};
// Test scripts passed to testStaticRuntime can be either IR or JIT (TorchScript).
// The logic for running the script and producing a corresponding StaticModule
// is a bit different for each case. This logic is encapsulated within concrete
// implementations of this class, and testStaticRuntime is only aware of this
// interface.
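//
// For reference, the two flavors look roughly like this (illustrative
// snippets only, not taken from any particular test):
//
//   JIT (TorchScript):
//     def forward(self, a, b):
//         return (a + b).clone()
//
//   IR:
//     graph(%a : Tensor, %b : Tensor):
//       %1 : int = prim::Constant[value=1]()
//       %c : Tensor = aten::add(%a, %b, %1)
//       return (%c)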
class StaticRuntimeTestContext {
public:
virtual ~StaticRuntimeTestContext() = default;
virtual IValue getExpected(const std::vector<IValue>& args) = 0;
virtual StaticModule makeStaticModule(
const StaticModuleOptions& opt) const = 0;
};
class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext {
public:
explicit ModuleStaticRuntimeTestContext(const std::string& source_jit)
: module_("module") {
module_.define(source_jit);
}
IValue getExpected(const std::vector<IValue>& args) override {
return module_.forward(args);
}
StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
return torch::jit::StaticModule(
module_, /* is_frozen */ false, opt, /* sample_inputs */ {});
}
private:
Module module_;
};
class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
public:
explicit GraphStaticRuntimeContext(const std::string& source_ir) {
graph_ = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(source_ir, graph_.get(), vmap);
graph_exec_ = GraphExecutorWrapper(graph_);
}
IValue getExpected(const std::vector<IValue>& args) override {
return graph_exec_(args);
}
StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
return StaticModule(graph_, opt, /* sample_inputs */ {});
}
private:
std::shared_ptr<Graph> graph_;
GraphExecutorWrapper graph_exec_;
};
std::unique_ptr<StaticRuntimeTestContext> makeTestContext(
const std::string& source) {
try {
return std::make_unique<ModuleStaticRuntimeTestContext>(source);
} catch (const std::runtime_error&) {
// Could not parse as TorchScript; assume the source is IR.
return std::make_unique<GraphStaticRuntimeContext>(source);
}
}
void compareTensorLists(
const std::vector<IValue>& l, /* expected */
const std::vector<IValue>& r, /* actual */
const bool use_allclose,
const bool use_equalnan) {
EXPECT_TRUE(l.size() == r.size());
for (size_t i = 0; i < l.size(); ++i) {
ASSERT_TRUE(l[i].isTensor());
ASSERT_TRUE(r[i].isTensor());
VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl;
VLOG(2) << "output " << i << ": \n" << r[i] << std::endl;
if (!l[i].toTensor().defined()) {
EXPECT_TRUE(!r[i].toTensor().defined());
} else {
if (use_allclose) {
EXPECT_TRUE(at::allclose(
l[i].toTensor(),
r[i].toTensor(),
/*rtol*/ 1e-05,
/*atol*/ 1e-08,
use_equalnan));
} else {
EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
}
}
}
}
} // namespace
void compareResults(
const IValue& expect,
const IValue& actual,
const bool use_allclose,
const bool use_equalnan) {
if (expect.isTensor()) {
EXPECT_TRUE(actual.isTensor());
VLOG(2) << "expect " << expect.toTensor() << std::endl;
VLOG(2) << "output " << actual.toTensor() << std::endl;
if (use_allclose) {
EXPECT_TRUE(at::allclose(
expect.toTensor(),
actual.toTensor(),
/*rtol*/ 1e-05,
/*atol*/ 1e-08,
use_equalnan));
} else {
EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
}
return;
} else if (expect.isTuple()) {
EXPECT_TRUE(actual.isTuple());
auto lhs = expect.toTupleRef().elements();
auto rhs = actual.toTupleRef().elements();
EXPECT_TRUE(lhs.size() == rhs.size());
for (size_t i = 0; i < lhs.size(); i++) {
compareResults(lhs[i], rhs[i], use_allclose, use_equalnan);
}
} else if (expect.isList()) {
EXPECT_TRUE(actual.isList());
auto lhs = expect.toList();
auto rhs = actual.toList();
EXPECT_TRUE(lhs.size() == rhs.size());
for (size_t i = 0; i < lhs.size(); i++) {
compareResults(lhs[i], rhs[i], use_allclose, use_equalnan);
}
} else if (expect.isGenericDict()) {
EXPECT_TRUE(actual.isGenericDict());
auto lhs = expect.toGenericDict();
auto rhs = actual.toGenericDict();
EXPECT_TRUE(lhs.size() == rhs.size());
for (auto& lh : lhs) {
auto f = rhs.find(lh.key());
EXPECT_FALSE(f == rhs.end());
compareResults(lh.value(), f->value(), use_allclose, use_equalnan);
}
} else {
// fall back to the default comparison impl in IValue
EXPECT_TRUE(expect == actual);
}
}
at::Tensor getTensor(const at::IValue& ival) {
if (ival.isTensor()) {
return ival.toTensor();
} else if (ival.isTensorList()) {
auto tensor_vec = ival.toTensorVector();
TORCH_CHECK(tensor_vec.size() == 1);
return tensor_vec[0];
} else if (ival.isTuple()) {
auto tuple = ival.toTuple();
auto ivalue_vec = tuple->elements();
TORCH_CHECK(ivalue_vec.size() == 1);
return ivalue_vec[0].toTensor();
} else {
CAFFE_THROW("Unknown input IValue");
}
}
Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind) {
return smodule.findNodeWithKindForTesting(kind);
}
Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
const auto symbol = c10::Symbol::fromQualString(kind);
DepthFirstGraphNodeIterator it(graph);
for (auto* node = it.next(); node != nullptr; node = it.next()) {
if (node->kind() == symbol) {
return node;
}
}
return nullptr;
}
bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind) {
return getNodeWithKind(smodule, kind) != nullptr;
}
bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
return getNodeWithKind(graph, kind) != nullptr;
}
std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script) {
script::Module module("module");
module.define(jit_script);
return module.get_method("forward").graph();
}
std::shared_ptr<Graph> getGraphFromIR(const std::string& ir) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(ir, graph.get(), vmap);
return graph;
}
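// Rough usage sketch for the graph helpers above (the pass and the op kind
// below are made-up placeholders, not real APIs):
//
//   auto graph = getGraphFromIR(ir_src);
//   SomePassUnderTest(graph);
//   EXPECT_FALSE(hasNodeWithKind(graph, "aten::some_op"));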
void compareResultsWithJIT(
StaticRuntime& runtime,
const std::shared_ptr<Graph>& graph,
const std::vector<c10::IValue>& args,
const bool use_allclose,
const bool use_equalnan) {
GraphExecutorWrapper graph_exec(graph);
auto expected = graph_exec(args);
auto actual = runtime(args, {});
runtime.check_for_memory_leak();
compareResults(expected, actual, use_allclose, use_equalnan);
}
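// Runs `source` (either TorchScript or IR; see StaticRuntimeTestContext above)
// through both the JIT reference implementation and Static Runtime under
// several StaticModuleOptions configurations, and compares the results.
// A rough usage sketch (names, shapes, and the script are illustrative only):
//
//   const auto src = R"JIT(
//     def forward(self, a, b):
//         return (a + b).clone()
//   )JIT";
//   std::vector<IValue> args{at::randn({2, 3}), at::randn({2, 3})};
//   std::vector<IValue> args2{at::randn({4, 6}), at::randn({4, 6})};
//   testStaticRuntime(src, args, args2);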
void testStaticRuntime(
const std::string& source,
const std::vector<IValue>& args,
const std::vector<IValue>& args2,
const bool use_allclose,
const bool use_equalnan,
const bool check_resize) {
auto test_context = makeTestContext(source);
std::vector<IValue> args_tensors, args_copy;
for (const auto& ival : args) {
if (ival.isTensor()) {
args_tensors.emplace_back(ival);
const at::Tensor& t = ival.toTensor();
args_copy.emplace_back(t.clone());
}
}
auto expect = test_context->getExpected(args);
for (bool enable_out_variant : {true, false}) {
for (bool manage_output_tensors : {true, false}) {
for (bool enable_tensorexpr_fusion : {true, false}) {
if (!enable_out_variant && manage_output_tensors) {
continue;
}
// run static runtime three times
// 1st run: collect allocation profiles (args)
// 2nd run: exercise memory planner and resizing with args2
// 3rd run: run with args again
StaticModuleOptions opts{
.enable_out_variant = enable_out_variant,
.optimize_memory = enable_out_variant,
.manage_output_tensors = manage_output_tensors,
.enable_tensorexpr_fusion = enable_tensorexpr_fusion};
auto smodule = test_context->makeStaticModule(opts);
StaticRuntime runtime(smodule);
auto actual = runtime(args, {});
if (actual.isTensor()) {
EXPECT_GE(smodule.num_nodes(), 2)
<< "If we only have one node, the output of the op we are testing is "
<< "not being managed by the memory planner! A failure here "
<< "can typically be fixed by clone()ing the output of the test script.";
}
runtime.check_for_memory_leak();
// first run
VLOG(2) << "enable_out_variant: " << enable_out_variant;
VLOG(2) << "manage_output_tensors: " << manage_output_tensors;
VLOG(2) << "enable_tensorexpr_fusion: " << enable_tensorexpr_fusion;
VLOG(2) << "args: " << args;
VLOG(2) << "args2: " << args2;
VLOG(2) << "expect: " << expect;
VLOG(2) << "actual: " << actual;
compareResults(expect, actual, use_allclose, use_equalnan);
VLOG(2) << "first run comparison done";
if (manage_output_tensors) {
actual = IValue();
runtime.deallocateOutputTensors();
runtime.checkOutputTensorMemoryLeaks();
}
if (!args2.empty()) {
auto* memory_planner = runtime.get_memory_planner();
size_t managed_bytes =
memory_planner ? memory_planner->total_managed() : 0;
// Run static runtime again with inputs of a different shape.
expect = test_context->getExpected(args2);
actual = runtime(args2, {});
runtime.check_for_memory_leak();
VLOG(2) << "comparing with args2";
compareResults(expect, actual, use_allclose, use_equalnan);
VLOG(2) << "second run comparison done";
if (manage_output_tensors) {
actual = IValue();
runtime.deallocateOutputTensors();
runtime.checkOutputTensorMemoryLeaks();
}
size_t new_managed_bytes =
memory_planner ? memory_planner->total_managed() : 0;
if (check_resize && new_managed_bytes > 0) {
EXPECT_GT(new_managed_bytes, managed_bytes);
}
// Run static runtime again with an input of the shape observed during
// the profile run.
expect = test_context->getExpected(args);
actual = runtime(args, {});
runtime.check_for_memory_leak();
// third run
VLOG(2) << "comparing third run";
compareResults(expect, actual, use_allclose, use_equalnan);
VLOG(2) << "third run comparison done";
if (manage_output_tensors) {
actual = IValue();
runtime.deallocateOutputTensors();
runtime.checkOutputTensorMemoryLeaks();
}
} else {
// run static runtime again to exercise the memory planner
// and allocate managed tensors.
actual = runtime(args, {});
runtime.check_for_memory_leak();
VLOG(2) << "comparing second run with same args";
compareResults(expect, actual, use_allclose, use_equalnan);
VLOG(2) << "second run comparison done";
if (manage_output_tensors) {
actual = IValue();
runtime.deallocateOutputTensors();
runtime.checkOutputTensorMemoryLeaks();
}
// third run to use the allocated managed tensors.
actual = runtime(args, {});
runtime.check_for_memory_leak();
if (manage_output_tensors) {
actual = IValue();
runtime.deallocateOutputTensors();
runtime.checkOutputTensorMemoryLeaks();
}
}
}
}
}
// make sure inputs were not modified
VLOG(2) << "Printing out input tensors";
compareTensorLists(args_tensors, args_copy, use_allclose, use_equalnan);
}
bool hasProcessedNodeWithName(
torch::jit::StaticModule& smodule,
const char* name) {
return smodule.findNodeWithKindForTesting(name) != nullptr;
}
} // namespace test
} // namespace jit
} // namespace torch