// blob: 28f17f10ff439704c09da0dff2885f89b1a2d0b4 [file] [log] [blame] [edit]
#include <gtest/gtest.h>
#include <c10/util/irange.h>
#include <torch/torch.h>
#include <test/cpp/api/support.h>
using namespace torch::nn;
using namespace torch::test;
// Empty module declared at global scope. The tests below use it to check the
// reported module name and the behaviour of `as<>()` casts.
struct AGIUnit : public torch::nn::Module {};
namespace test {

// Shares its type name with the global `AGIUnit` but lives in a namespace;
// the name test expects it to be reported as "test::AGIUnit".
struct AGIUnit : public torch::nn::Module {};

// Module that hands an explicit name ("Foo") to the base-class constructor.
struct AGIUnit2 : public torch::nn::Module {
  AGIUnit2() : torch::nn::Module("Foo") {}
};

} // namespace test
// Fixture used by every TEST_F in this file; it adds nothing on top of
// `torch::test::SeedingFixture`.
struct ModuleTest : public torch::test::SeedingFixture {};
// A new module reports training mode; eval() switches it off and train()
// switches it back on.
TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear linear(3, 4);
  ASSERT_TRUE(linear->is_training());

  linear->eval();
  ASSERT_FALSE(linear->is_training());

  linear->train();
  ASSERT_TRUE(linear->is_training());
}
// After a backward pass every parameter has a non-zero gradient; the default
// zero_grad() call leaves every gradient undefined.
TEST_F(ModuleTest, ZeroGrad) {
  Linear linear(3, 4);
  auto input = torch::ones({8, 3}, torch::requires_grad());
  auto loss = linear(input).sum();
  loss.backward();

  for (auto& param : linear->parameters()) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto gradient = param.grad();
    ASSERT_TRUE(gradient.defined());
    ASSERT_NE(gradient.sum().item<float>(), 0);
  }

  linear->zero_grad();

  for (auto& param : linear->parameters()) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto gradient = param.grad();
    ASSERT_FALSE(gradient.defined());
  }
}
// zero_grad() on a module where one parameter (`y`) never took part in a
// backward pass, so its gradient is undefined the whole time.
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  struct TestModule : torch::nn::Module {
    TestModule() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x;
    torch::Tensor y;
  };

  TestModule net;
  // Only `x` feeds into the value that is differentiated.
  auto doubled = net.x * 2;
  doubled.sum().backward();
  ASSERT_TRUE(net.x.grad().defined());
  ASSERT_FALSE(net.y.grad().defined());

  // With set_to_none == false the existing gradient is kept but zeroed.
  net.zero_grad(/*set_to_none=*/false);
  ASSERT_TRUE(net.x.grad().defined());
  ASSERT_FALSE(net.y.grad().defined());
  ASSERT_EQ(net.x.grad().sum().item<float>(), 0);

  // The default call leaves both gradients undefined.
  net.zero_grad();
  ASSERT_FALSE(net.x.grad().defined());
  ASSERT_FALSE(net.y.grad().defined());
}
// register_module() rejects names that contain a dot and names that are
// empty.
TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct TestModel : Module {};

  ASSERT_THROWS_WITH(
      TestModel{}.register_module("name.with.dot", Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");

  ASSERT_THROWS_WITH(
      TestModel{}.register_module("", Linear(3, 4)),
      "Submodule name must not be empty");
}
// Registering a second submodule under an already used name throws.
TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct TestModel : Module {};

  TestModel parent;
  parent.register_module("linear", Linear(3, 4));

  ASSERT_THROWS_WITH(
      parent.register_module("linear", Linear(3, 4)),
      "Submodule 'linear' already defined");
}
// replace_module() throws when no submodule with the given name exists.
TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
  Module parent;

  ASSERT_THROWS_WITH(
      parent.replace_module("linear", Linear(3, 4)),
      "Submodule 'linear' is not defined");
}
// Replacing a registered submodule swaps in the new module: the named
// parameters reflect the new shape and the registry points at the same
// implementation object as the holder returned by replace_module().
TEST_F(ModuleTest, ReplaceModule) {
  struct TestModel : public torch::nn::Module {
    torch::nn::Linear l1{nullptr};
    TestModel() {
      l1 = register_module("l1", Linear(3, 4));
    }
  };

  auto model = std::make_shared<TestModel>();
  model->l1 = model->replace_module("l1", Linear(5, 6));

  auto params = model->named_parameters();
  ASSERT_EQ(params["l1.weight"].size(0), 6);

  auto submodules = model->named_modules();
  ASSERT_EQ(model->l1.get(), submodules["l1"]->as<Linear>());
}
// unregister_module() throws for an unknown name and removes a registered
// child otherwise.
TEST_F(ModuleTest, UnregisterModule) {
  struct TestModel : Module {};
  TestModel parent;

  // Nothing is registered yet.
  ASSERT_THROWS_WITH(
      parent.unregister_module("linear"),
      "No Module with name `linear` is registered");

  parent.register_module("linear", Linear(3, 4));
  parent.unregister_module("linear");
  ASSERT_TRUE(parent.children().empty());
}
// register_parameter() rejects names that contain a dot and names that are
// empty.
TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct TestModel : Module {};

  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");

  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}
// Registering a second parameter under an already used name throws.
TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  struct TestModel : Module {};

  TestModel owner;
  owner.register_parameter("p", torch::ones(5));

  ASSERT_THROWS_WITH(
      owner.register_parameter("p", torch::ones(5)),
      "Parameter 'p' already defined");
}
// An undefined tensor registered as a parameter is not reported by
// parameters(). With the default `requires_grad` argument a warning is
// emitted exactly once.
TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
  struct TestModel : Module {};

  {
    // requires_grad passed explicitly as false.
    TestModel owner;
    owner.register_parameter(
        "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
    ASSERT_EQ(owner.parameters().size(), 0);
  }

  {
    // requires_grad left at its default; the warning text is captured.
    WarningCapture captured;
    TestModel owner;
    owner.register_parameter("undefined_tensor", torch::Tensor());
    ASSERT_EQ(owner.parameters().size(), 0);

    const auto warning_hits = count_substr_occurrences(
        captured.str(),
        "Ignoring the `requires_grad=true` function parameter");
    ASSERT_EQ(warning_hits, 1);
  }
}
// register_buffer() rejects names that contain a dot and names that are
// empty.
TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct TestModel : Module {};

  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");

  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}
// Registering a second buffer under an already used name throws.
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  struct TestModel : Module {};

  TestModel owner;
  owner.register_buffer("p", torch::ones(5));

  ASSERT_THROWS_WITH(
      owner.register_buffer("p", torch::ones(5)),
      "Buffer 'p' already defined");
}
// name() yields the (qualified) type name unless an explicit name was given
// to the Module constructor.
TEST_F(ModuleTest, CanGetName) {
  // Non-fatal EXPECT_* checks are used here because demangling may fail.
  AGIUnit global_unit;

  // Queried twice on the same object to make sure there are no bugs in the
  // lazy initialization semantics.
  EXPECT_EQ(global_unit.name(), "AGIUnit");
  EXPECT_EQ(global_unit.name(), "AGIUnit");

  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
// as<T>() returns the implementation pointer when the dynamic type matches
// (holder type, Impl type or base Module) and nullptr otherwise. This is
// checked through the holder, a shared_ptr<Module> and a Module reference.
TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  Linear linear(3, 4);

  // Access through the module holder.
  ASSERT_EQ(linear->as<Linear>(), linear.get());
  ASSERT_EQ(linear->as<LinearImpl>(), linear.get());
  ASSERT_EQ(linear->as<Module>(), linear.get());
  ASSERT_EQ(linear->as<AGIUnit>(), nullptr);

  // Access through a type-erased shared pointer.
  std::shared_ptr<Module> base_ptr = linear.ptr();
  ASSERT_EQ(base_ptr->as<Linear>(), linear.get());
  ASSERT_EQ(base_ptr->as<LinearImpl>(), linear.get());
  ASSERT_EQ(base_ptr->as<Module>(), linear.get());
  ASSERT_EQ(base_ptr->as<AGIUnit>(), nullptr);

  // Access through a plain reference to the base class.
  Module& base_ref = *base_ptr;
  ASSERT_EQ(base_ref.as<Linear>(), linear.get());
  ASSERT_EQ(base_ref.as<LinearImpl>(), linear.get());
  ASSERT_EQ(base_ref.as<Module>(), linear.get());
  ASSERT_EQ(base_ref.as<AGIUnit>(), nullptr);

  // The pointer obtained from a successful cast is usable.
  if (auto* as_linear = base_ref.as<Linear>()) {
    ASSERT_EQ(as_linear->weight.ndimension(), 2);
  }

  // The same checks starting from an unrelated module type.
  AGIUnit other;
  ASSERT_EQ(other.as<Linear>(), nullptr);
  ASSERT_EQ(other.as<LinearImpl>(), nullptr);
  ASSERT_EQ(other.as<AGIUnit>(), &other);
}
// Shared body for the CPU and CUDA variants below. Moves a module to
// `to_device` and then converts it to `to_dtype`, checking after each step
// that defined tensors were converted and undefined ones stayed undefined.
void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
    torch::Device to_device,
    torch::Dtype to_dtype) {
  {
    // Case 1: Undefined tensors as parameters (Linear without bias).
    Linear linear(LinearOptions(10, 20).bias(false));
    ASSERT_TRUE(linear->weight.defined());
    ASSERT_FALSE(linear->bias.defined());

    linear->to(to_device);
    ASSERT_TRUE(linear->weight.defined());
    ASSERT_EQ(linear->weight.device().type(), to_device.type());
    ASSERT_FALSE(linear->bias.defined());

    linear->to(to_dtype);
    ASSERT_TRUE(linear->weight.defined());
    ASSERT_EQ(linear->weight.dtype(), to_dtype);
    ASSERT_FALSE(linear->bias.defined());
  }

  {
    // Case 2: Undefined tensors as buffers (BatchNorm1d without running
    // statistics).
    BatchNorm1d batch_norm(
        BatchNorm1dOptions(5).track_running_stats(false).affine(true));
    ASSERT_TRUE(batch_norm->weight.defined());
    ASSERT_FALSE(batch_norm->running_mean.defined());

    batch_norm->to(to_device);
    ASSERT_TRUE(batch_norm->weight.defined());
    ASSERT_EQ(batch_norm->weight.device().type(), to_device.type());
    ASSERT_FALSE(batch_norm->running_mean.defined());

    batch_norm->to(to_dtype);
    ASSERT_TRUE(batch_norm->weight.defined());
    ASSERT_EQ(batch_norm->weight.dtype(), to_dtype);
    ASSERT_FALSE(batch_norm->running_mean.defined());
  }
}
// CPU / double variant of the shared conversion check.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      /*to_device=*/torch::kCPU, /*to_dtype=*/torch::kDouble);
}
// CUDA / double variant of the shared conversion check.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      /*to_device=*/torch::kCUDA, /*to_dtype=*/torch::kDouble);
}
// parameters()/named_parameters() and buffers()/named_buffers() only report
// tensors that are defined, and the reported tensors are the module's own.
TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
  {
    // Linear without bias: only `weight` is listed.
    Linear linear(LinearOptions(10, 20).bias(false));

    auto plain = linear->parameters();
    ASSERT_EQ(plain.size(), 1);

    auto named = linear->named_parameters();
    ASSERT_EQ(named.size(), 1);

    ASSERT_TRUE(pointer_equal(plain[0], named["weight"]));
    ASSERT_TRUE(pointer_equal(named["weight"], linear->weight));
  }

  {
    // No running statistics tracked: no buffers are listed.
    BatchNorm1d batch_norm(
        BatchNorm1dOptions(5).track_running_stats(false).affine(false));

    auto plain = batch_norm->buffers();
    ASSERT_EQ(plain.size(), 0);

    auto named = batch_norm->named_buffers();
    ASSERT_EQ(named.size(), 0);
  }

  {
    // Running statistics tracked: three buffers, in this order.
    BatchNorm1d batch_norm(
        BatchNorm1dOptions(5).track_running_stats(true).affine(false));

    auto plain = batch_norm->buffers();
    ASSERT_EQ(plain.size(), 3);

    auto named = batch_norm->named_buffers();
    ASSERT_EQ(named.size(), 3);

    ASSERT_TRUE(pointer_equal(plain[0], named["running_mean"]));
    ASSERT_TRUE(
        pointer_equal(named["running_mean"], batch_norm->running_mean));

    ASSERT_TRUE(pointer_equal(plain[1], named["running_var"]));
    ASSERT_TRUE(pointer_equal(named["running_var"], batch_norm->running_var));

    ASSERT_TRUE(pointer_equal(plain[2], named["num_batches_tracked"]));
    ASSERT_TRUE(pointer_equal(
        named["num_batches_tracked"], batch_norm->num_batches_tracked));
  }
}
// Moves a module CPU -> cuda:0 -> cuda:1 -> CPU and finally converts it to
// float64, checking every parameter after each step.
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear linear(128, 64);

  // Starting point: float32 parameters on the CPU.
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(param.dtype(), torch::kFloat32);
  }

  linear->to(torch::Device(torch::kCUDA, 0));
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 0);
  }

  linear->to(torch::Device(torch::kCUDA, 1));
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 1);
  }

  linear->to(torch::Device(torch::kCPU));
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CPU);
  }

  linear->to(torch::kFloat64);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kFloat64);
  }
}
// With requires_grad switched off on every parameter, the module is
// converted to an integer dtype and then to a device and dtype in one call.
TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) {
  Linear linear(128, 64);
  for (auto& param : linear->parameters()) {
    param.requires_grad_(false);
  }

  // Dtype-only conversion.
  linear->to(torch::kInt32);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kInt32);
  }

  // Combined device + dtype conversion.
  linear->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.device().type(), torch::Device::Type::CUDA);
    ASSERT_EQ(param.device().index(), 1);
  }
  for (auto& param : linear->parameters()) {
    ASSERT_EQ(param.dtype(), torch::kUInt8);
  }
}
// A module deriving directly from Module without overriding clone() throws
// when clone() is called.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  struct UnCloneable : public Module {};

  UnCloneable plain;
  ASSERT_THROWS_WITH(plain.clone(), "clone() has not been implemented");
}
// A module that supplies its own clone() override can be cloned without an
// exception. The override here simply returns a null pointer.
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  // Local type; unrelated to the `Cloneable<T>` template used elsewhere in
  // this file.
  struct Cloneable : public Module {
    std::shared_ptr<Module> clone(
        const torch::optional<torch::Device>& device =
            torch::nullopt) const override {
      return nullptr;
    }
  };

  Cloneable overriding;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW({ overriding.clone(); });
}
// Cloneable module with three Linear submodules (two parameters each) and one
// buffer; used by the CloneCreatesDistinctParameters* tests.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TestDistinctParametersModule
    : public Cloneable<TestDistinctParametersModule> {
  TestDistinctParametersModule() {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }

  // Registers all submodules and the buffer.
  void reset() override {
    l1 = register_module("l1", Linear(10, 3));
    l2 = register_module("l2", Linear(3, 5));
    l3 = register_module("l3", Linear(5, 100));
    buffer = register_buffer("buf", torch::ones({2, 2}));
  }

  Linear l1{nullptr};
  Linear l2{nullptr};
  Linear l3{nullptr};
  torch::Tensor buffer;
};
// Checks that `m1` and `m2` hold equal-valued but distinct tensors: each
// parameter/buffer of `m1` starts out allclose to its counterpart in `m2`
// without sharing storage identity, and after adding 2 in place to the `m1`
// tensor the two no longer compare close.
// Expects 6 named parameters and 1 named buffer on both modules.
void testDistinctParameters(
    std::shared_ptr<Module> m1,
    std::shared_ptr<Module> m2) {
  auto first_params = m1->named_parameters();
  auto second_params = m2->named_parameters();
  ASSERT_EQ(first_params.size(), 6);
  ASSERT_EQ(second_params.size(), 6);

  for (auto& entry : first_params) {
    ASSERT_FALSE(pointer_equal(entry.value(), second_params[entry.key()]));
    ASSERT_TRUE(entry->allclose(second_params[entry.key()]));
    entry->add_(2);
  }
  for (auto& entry : first_params) {
    ASSERT_FALSE(entry->allclose(second_params[entry.key()]));
  }

  auto first_buffers = m1->named_buffers();
  auto second_buffers = m2->named_buffers();
  ASSERT_EQ(first_buffers.size(), 1);
  ASSERT_EQ(second_buffers.size(), 1);

  for (auto& entry : first_buffers) {
    ASSERT_FALSE(pointer_equal(entry.value(), second_buffers[entry.key()]));
    ASSERT_TRUE(entry->allclose(second_buffers[entry.key()]));
    entry->add_(2);
  }
  for (auto& entry : first_buffers) {
    ASSERT_FALSE(entry->allclose(second_buffers[entry.key()]));
  }
}
// clone() without a device argument yields independent parameters/buffers.
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  auto copy = original->clone();
  testDistinctParameters(original, copy);
}
// Same check with the module living on cuda:0 and cloned to that same device.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;

  torch::Device cuda0(torch::kCUDA, 0);
  original->to(cuda0);

  auto copy = original->clone(cuda0);
  testDistinctParameters(original, copy);
}
// Clones a module living on cuda:0 onto cuda:1, checks the device of every
// parameter on both sides, then runs the distinctness check.
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
  auto original = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;

  torch::Device source_device(torch::kCUDA, 0);
  torch::Device target_device(torch::kCUDA, 1);
  original->to(source_device);

  auto copy = original->clone(target_device);

  for (auto& param : original->parameters()) {
    ASSERT_EQ(param.device(), source_device);
  }
  for (auto& param : copy->parameters()) {
    ASSERT_EQ(param.device(), target_device);
  }

  // The clone has to be moved to the source device before comparing, as
  // allclose expects two tensors on the same device.
  copy->to(source_device);
  testDistinctParameters(original, copy);
}
// In both the original and the clone, the `weight` member refers to the very
// tensor stored in the parameter registry; the clone's tensor has the same
// values as the original's but is a different tensor.
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };

  auto original = std::make_shared<TestModule>();
  {
    // Modify the values in place so they differ from what reset() produces.
    torch::NoGradGuard no_grad;
    original->weight += 1;
  }

  // Original: member and registry entry are the same tensor.
  ASSERT_TRUE(pointer_equal(
      original->weight, original->named_parameters()["weight"]));
  ASSERT_TRUE(
      original->weight.allclose(original->named_parameters()["weight"]));

  auto copy = std::dynamic_pointer_cast<TestModule>(
      std::shared_ptr<Module>(original->clone()));

  // Clone: distinct from the original, consistent within itself, same values.
  ASSERT_FALSE(pointer_equal(copy->weight, original->weight));
  ASSERT_TRUE(
      pointer_equal(copy->weight, copy->named_parameters()["weight"]));
  ASSERT_TRUE(copy->weight.allclose(copy->named_parameters()["weight"]));
  ASSERT_TRUE(copy->weight.allclose(original->weight));
  ASSERT_FALSE(
      pointer_equal(copy->weight, original->named_parameters()["weight"]));
}
// Cloning a module also clones its submodule: the nested parameter is a
// distinct tensor with equal values, and the plain `value` member compares
// equal as well.
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  // Leaf module with one parameter and one non-tensor member.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
    int value = 0;
  };

  // Wrapper that owns a single `TestModule` as its submodule.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto original = std::make_shared<NestedModule>();
  {
    // Change both the tensor values and the plain member before cloning.
    torch::NoGradGuard no_grad;
    original->module->weight += 1;
    original->module->value = 123;
  }

  auto copy = std::dynamic_pointer_cast<NestedModule>(original->clone());

  ASSERT_FALSE(
      pointer_equal(copy->module->weight, original->module->weight));
  ASSERT_TRUE(pointer_equal(
      copy->module->weight, copy->module->named_parameters()["weight"]));
  ASSERT_TRUE(copy->module->named_parameters()["weight"].allclose(
      original->module->weight));
  ASSERT_TRUE(copy->module->weight.allclose(original->module->weight));
  ASSERT_EQ(copy->module->value, original->module->value);
}
// A module moved to cuda:0 and cloned without a device argument yields a
// clone whose parameters and buffers are all on cuda:0.
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
    Linear l1{nullptr};
    Linear l2{nullptr};
    Linear l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule original;
  torch::Device cuda0(torch::kCUDA, 0);
  original.to(cuda0);

  auto copy = original.clone();

  for (const auto& param : copy->parameters()) {
    ASSERT_EQ(param.device().type(), cuda0.type());
    ASSERT_EQ(param.device().index(), cuda0.index());
  }
  for (const auto& buf : copy->buffers()) {
    ASSERT_EQ(buf.device().type(), cuda0.type());
    ASSERT_EQ(buf.device().index(), cuda0.index());
  }
}
// clone(device) on a module that was never moved places every parameter and
// buffer of the clone on the requested device (cuda:1).
TEST_F(
    ModuleTest,
    CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
    Linear l1{nullptr};
    Linear l2{nullptr};
    Linear l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule original;
  torch::Device cuda1(torch::kCUDA, 1);

  // everything is on CPU here
  auto copy = original.clone(cuda1);

  for (const auto& param : copy->parameters()) {
    ASSERT_EQ(param.device().type(), cuda1.type());
    ASSERT_EQ(param.device().index(), cuda1.index());
  }
  for (const auto& buf : copy->buffers()) {
    ASSERT_EQ(buf.device().type(), cuda1.type());
    ASSERT_EQ(buf.device().index(), cuda1.index());
  }
}
// Module with three 2x2 parameters named "a", "b" and "c" (filled with
// zeros, ones and twos respectively).
struct ParameterTestModule : public Module {
  ParameterTestModule() {
    a = register_parameter("a", torch::zeros({2, 2}));
    b = register_parameter("b", torch::ones({2, 2}));
    c = register_parameter("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a;
  torch::Tensor b;
  torch::Tensor c;
};
// Both parameter accessors report all three registered parameters.
TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule holder;

  const auto plain = holder.parameters();
  ASSERT_EQ(plain.size(), 3);

  const auto named = holder.named_parameters();
  ASSERT_EQ(named.size(), 3);
}
// named_parameters() exposes each parameter under its registered name.
TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule holder;
  auto named = holder.named_parameters();

  ASSERT_TRUE(named.contains("a"));
  ASSERT_TRUE(named.contains("b"));
  ASSERT_TRUE(named.contains("c"));
}
// Module with three 2x2 buffers named "a", "b" and "c" (filled with zeros,
// ones and twos respectively).
struct BufferTestModule : public Module {
  BufferTestModule() {
    a = register_buffer("a", torch::zeros({2, 2}));
    b = register_buffer("b", torch::ones({2, 2}));
    c = register_buffer("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a;
  torch::Tensor b;
  torch::Tensor c;
};
// Both buffer accessors report all three registered buffers.
TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule holder;

  const auto plain = holder.buffers();
  ASSERT_EQ(plain.size(), 3);

  const auto named = holder.named_buffers();
  ASSERT_EQ(named.size(), 3);
}
// named_buffers() exposes each buffer under its registered name.
TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule holder;
  auto named = holder.named_buffers();

  ASSERT_TRUE(named.contains("a"));
  ASSERT_TRUE(named.contains("b"));
  ASSERT_TRUE(named.contains("c"));
}
// Minimal implementation class for the ModuleHolder tests. `x_` is 123 when
// default-constructed, otherwise the value passed to the constructor.
struct AImpl : public torch::nn::Module {
  AImpl() : AImpl(123) {}
  AImpl(int x) : x_(x) {}

  int x_;
};
// Generates the holder type `A` wrapping `AImpl`.
TORCH_MODULE(A);
// A default-constructed holder is non-empty and wraps a default-constructed
// implementation.
TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  A holder;
  ASSERT_TRUE(holder);
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder->x_, 123);
}
// Constructor arguments given to the holder reach the matching constructor of
// the implementation.
TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  A holder(5);
  ASSERT_TRUE(holder);
  ASSERT_FALSE(holder.is_empty());
  ASSERT_EQ(holder->x_, 5);
}
// A holder initialised from nullptr is empty, and dereferencing it throws.
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  A holder = nullptr;
  ASSERT_FALSE(holder);
  ASSERT_TRUE(holder.is_empty());
  ASSERT_THROWS_WITH(holder->x_, "Accessing empty ModuleHolder");
}
// Module with two parameters ("p1", "p2") and two buffers ("b1", "b2"), each
// a random 1-D tensor of length `size`. forward() returns its input
// unchanged.
struct TestModule : public torch::nn::Module {
  TestModule(int64_t size) {
    p1 = register_parameter("p1", torch::randn({size}));
    p2 = register_parameter("p2", torch::randn({size}));
    b1 = register_buffer("b1", torch::randn({size}));
    b2 = register_buffer("b2", torch::randn({size}));
  }

  torch::Tensor forward(torch::Tensor input) {
    return input;
  }

  torch::Tensor p1;
  torch::Tensor p2;
  torch::Tensor b1;
  torch::Tensor b2;
};
// modules() on a flat Sequential returns the container itself followed by its
// children, in order.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;

  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<ModulePtr> actual = model->modules();
  std::vector<ModulePtr> expected = {
      model.ptr(), model[0], model[1], model[2]};

  ASSERT_EQ(actual.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(actual[idx].get(), expected[idx].get());
  }
}
// modules(/*include_self=*/false) returns only the children.
TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;

  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<ModulePtr> actual = model->modules(/*include_self=*/false);
  std::vector<ModulePtr> expected = {model[0], model[1], model[2]};

  ASSERT_EQ(actual.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(actual[idx].get(), expected[idx].get());
  }
}
// named_modules() on a flat Sequential: the container itself comes first
// under an empty key, followed by the children keyed "0", "1", "2".
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;

  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, ModulePtr> actual = model->named_modules();
  std::vector<ModulePtr> expected = {
      model.ptr(), model[0], model[1], model[2]};

  ASSERT_EQ(actual.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    // Entry 0 is the model itself (empty key); entry k > 0 is child k - 1.
    const std::string expected_key =
        (idx == 0) ? std::string() : std::to_string(idx - 1);
    ASSERT_EQ(actual[idx].key(), expected_key);
    // Assert pointer equality.
    ASSERT_EQ(actual[idx].value().get(), expected[idx].get());
  }
}
// named_modules() with include_self == false lists only the children, keyed
// by their position.
TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;

  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, ModulePtr> actual = model->named_modules(
      /*name_prefix=*/std::string(), /*include_self=*/false);
  std::vector<ModulePtr> expected = {model[0], model[1], model[2]};

  ASSERT_EQ(actual.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    ASSERT_EQ(actual[idx].key(), std::to_string(idx));
    // Assert pointer equality.
    ASSERT_EQ(actual[idx].value().get(), expected[idx].get());
  }
}
// children() on a flat Sequential returns the direct submodules in order.
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;

  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<ModulePtr> actual = model->children();
  std::vector<ModulePtr> expected = {model[0], model[1], model[2]};

  ASSERT_EQ(actual.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(actual[idx].get(), expected[idx].get());
  }

  // For this flat model, children() and modules() without self coincide.
  ASSERT_EQ(actual, model->modules(/*include_self=*/false));
}
// named_children() on a flat Sequential lists the direct submodules keyed by
// their position.
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;

  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, ModulePtr> actual = model->named_children();
  std::vector<ModulePtr> expected = {model[0], model[1], model[2]};

  ASSERT_EQ(actual.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    ASSERT_EQ(actual[idx].key(), std::to_string(idx));
    // Assert pointer equality.
    ASSERT_EQ(actual[idx].value().get(), expected[idx].get());
  }
}
// parameters() returns tensors that share their data pointer with the
// module's own `p1` and `p2`, in registration order.
TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
  TestModule flat(1);
  std::vector<torch::Tensor> listed = flat.parameters();

  ASSERT_EQ(listed.size(), 2);
  ASSERT_EQ(listed[0].data_ptr<float>(), flat.p1.data_ptr<float>());
  ASSERT_EQ(listed[1].data_ptr<float>(), flat.p2.data_ptr<float>());
}
// named_parameters() returns "p1" and "p2" in registration order, each
// sharing its data pointer with the corresponding member tensor.
TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  TestModule flat(1);
  torch::OrderedDict<std::string, torch::Tensor> listed =
      flat.named_parameters();

  ASSERT_EQ(listed.size(), 2);

  ASSERT_EQ(listed[0].key(), "p1");
  ASSERT_EQ(listed[0]->data_ptr<float>(), flat.p1.data_ptr<float>());

  ASSERT_EQ(listed[1].key(), "p2");
  ASSERT_EQ(listed[1]->data_ptr<float>(), flat.p2.data_ptr<float>());
}
// buffers() returns tensors that share their data pointer with the module's
// own `b1` and `b2`, in registration order.
TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
  TestModule flat(1);
  std::vector<torch::Tensor> listed = flat.buffers();

  ASSERT_EQ(listed.size(), 2);
  ASSERT_EQ(listed[0].data_ptr<float>(), flat.b1.data_ptr<float>());
  ASSERT_EQ(listed[1].data_ptr<float>(), flat.b2.data_ptr<float>());
}
// named_buffers() returns "b1" and "b2" in registration order, each sharing
// its data pointer with the corresponding member tensor.
TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  TestModule flat(1);
  torch::OrderedDict<std::string, torch::Tensor> listed =
      flat.named_buffers();

  ASSERT_EQ(listed.size(), 2);

  ASSERT_EQ(listed[0].key(), "b1");
  ASSERT_EQ(listed[0]->data_ptr<float>(), flat.b1.data_ptr<float>());

  ASSERT_EQ(listed[1].key(), "b2");
  ASSERT_EQ(listed[1]->data_ptr<float>(), flat.b2.data_ptr<float>());
}
// Tree node used by the deep-model tests. Each node stores `number` in a
// scalar tensor and registers every entry of `modules` as a child named
// after its position ("0", "1", ...).
struct TestContainer : public torch::nn::Module {
  TestContainer(int64_t number, std::vector<TestContainer> modules = {})
      : tensor(torch::tensor(number)) {
    for (const auto idx : c10::irange(modules.size())) {
      // Each child is moved out of the by-value argument onto the heap.
      auto child = std::make_shared<TestContainer>(std::move(modules[idx]));
      register_module(std::to_string(idx), std::move(child));
    }
  }

  torch::Tensor tensor;
};
// Returns the number stored in `module`, which must point to a
// `TestContainer`: the result of the dynamic cast is dereferenced without a
// null check.
int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
  const auto container = std::dynamic_pointer_cast<TestContainer>(module);
  return container->tensor.item<int64_t>();
}
std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
return std::make_shared<TestContainer>(TestContainer(
0,
{TestContainer(1, {TestContainer(2), TestContainer(3)}),
TestContainer(4),
TestContainer(
5,
{TestContainer(6),
TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
}
// Expected (qualified name, stored number) pairs for the tree built by
// make_deeply_nested_test_container() when traversed with the prefix
// "test_prefix". The stored number of each entry equals its position in the
// list.
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  const std::vector<std::string> names = {
      "test_prefix",
      "test_prefix.0",
      "test_prefix.0.0",
      "test_prefix.0.1",
      "test_prefix.1",
      "test_prefix.2",
      "test_prefix.2.0",
      "test_prefix.2.1",
      "test_prefix.2.1.0",
      "test_prefix.2.1.1"};

  std::vector<std::pair<std::string, int64_t>> pairs;
  pairs.reserve(names.size());
  for (std::size_t position = 0; position < names.size(); ++position) {
    pairs.emplace_back(names[position], static_cast<int64_t>(position));
  }
  return pairs;
}
// modules() on the nested container returns all ten nodes, ordered so that
// the number stored in each node equals its position in the result.
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
  ASSERT_EQ(modules.size(), 10);
  for (const auto i : c10::irange(modules.size())) {
    // The stored item is an int64_t while the loop index is unsigned; cast
    // explicitly so the comparison does not mix signed and unsigned operands.
    ASSERT_EQ(get_test_container_item(modules[i]), static_cast<int64_t>(i));
  }
}
// named_modules() with a prefix on the nested container yields the expected
// qualified names and stored numbers, in order.
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  using ModulePtr = std::shared_ptr<torch::nn::Module>;

  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, ModulePtr> actual =
      model->named_modules(/*name_prefix=*/"test_prefix");
  auto expected = make_key_value_pairs_for_deeply_nested_container();

  ASSERT_EQ(actual.size(), expected.size());
  for (const auto idx : c10::irange(expected.size())) {
    ASSERT_EQ(actual[idx].key(), expected[idx].first);
    ASSERT_EQ(
        get_test_container_item(actual[idx].value()), expected[idx].second);
  }
}
// children() on the nested container returns only the three direct children
// of the root (stored numbers 1, 4 and 5).
TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> direct = model->children();

  ASSERT_EQ(direct.size(), 3);
  ASSERT_EQ(get_test_container_item(direct[0]), 1);
  ASSERT_EQ(get_test_container_item(direct[1]), 4);
  ASSERT_EQ(get_test_container_item(direct[2]), 5);
}
// named_children() on the nested container returns the three direct children
// of the root keyed "0", "1", "2" (stored numbers 1, 4 and 5).
TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> direct =
      model->named_children();

  ASSERT_EQ(direct.size(), 3);

  ASSERT_EQ(get_test_container_item(direct[0].value()), 1);
  ASSERT_EQ(direct[0].key(), "0");

  ASSERT_EQ(get_test_container_item(direct[1].value()), 4);
  ASSERT_EQ(direct[1].key(), "1");

  ASSERT_EQ(get_test_container_item(direct[2].value()), 5);
  ASSERT_EQ(direct[2].key(), "2");
}
// apply() with a `Module&` callback visits all ten nodes; the number stored
// in each visited node matches the running visit counter.
TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  int64_t visited = 0;

  model->apply([&visited](torch::nn::Module& module) {
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), visited++);
  });

  ASSERT_EQ(visited, 10);
}
// Same as the non-const variant, but apply() is invoked through a pointer to
// const and the callback takes `const Module&`.
TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  int64_t visited = 0;

  model->apply([&visited](const torch::nn::Module& module) {
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), visited++);
  });

  ASSERT_EQ(visited, 10);
}
// apply() with a (name, Module&) callback and a name prefix visits all ten
// nodes; both the qualified name and the stored number of each visited node
// match the expected list at the running index.
TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      // `expected` is captured by reference, like in the other named apply
      // tests in this file: it outlives the call, so there is no need to copy
      // the whole vector into the closure.
      [&index, &expected](const std::string& name, torch::nn::Module& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}
// Named apply() through a pointer to const: the callback receives the
// qualified name and a `const Module&` for each of the ten nodes.
TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t visited = 0;

  model->apply(
      [&visited, &expected](
          const std::string& name, const torch::nn::Module& module) {
        ASSERT_EQ(name, expected[visited].first);
        ASSERT_EQ(
            module.as<const TestContainer>()->tensor.item<int64_t>(),
            expected[visited++].second);
      },
      /*name_prefix=*/"test_prefix");

  ASSERT_EQ(visited, 10);
}
// apply() with a `shared_ptr<Module>` callback visits all ten nodes in the
// same order as the reference-based overload.
TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  int64_t visited = 0;

  model->apply([&visited](const std::shared_ptr<torch::nn::Module>& module) {
    ASSERT_EQ(get_test_container_item(module), visited++);
  });

  ASSERT_EQ(visited, 10);
}
// Named apply() with a `shared_ptr<Module>` callback: name and stored number
// of each visited node match the expected list at the running index.
TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t visited = 0;

  model->apply(
      [&visited, &expected](
          const std::string& name,
          const std::shared_ptr<torch::nn::Module>& module) {
        ASSERT_EQ(name, expected[visited].first);
        ASSERT_EQ(get_test_container_item(module), expected[visited++].second);
      },
      /*name_prefix=*/"test_prefix");

  ASSERT_EQ(visited, 10);
}
// modules() has to hand out the top-level module as a shared_ptr when
// include_self is true. That throws for a module that lives on the stack, and
// succeeds when self is excluded or when the module is owned by a shared_ptr.
TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
  {
    // Stack-allocated module, self included: throws.
    TestModule module(1);
    ASSERT_THROWS_WITH(
        module.modules(),
        "It looks like you attempted to retrieve "
        "your top-level module as a shared_ptr");
  }
  {
    // Stack-allocated module, self excluded: fine.
    TestModule module(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(module.modules(/*include_self=*/false));
  }
  {
    // Module owned by a shared_ptr, self included: fine.
    auto module = std::make_shared<TestModule>(1);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(module->modules());
  }
}
// Module without a pretty_print() override; used by the PrettyPrint test.
struct EmptyModule : public torch::nn::Module {};
// Streaming a module prints its name when pretty_print() is not overridden
// and the output of the override otherwise.
TEST_F(ModuleTest, PrettyPrint) {
  struct TestModule : public torch::nn::Module {
    TestModule(int x, float y) : x_(x), y_(y) {}

    void pretty_print(std::ostream& stream) const override {
      stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
    }

    int x_;
    float y_;
  };

  // No override.
  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
  // Custom override.
  ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
}
// Module whose forward() returns a plain integer (the element count of the
// input) instead of a tensor.
struct ModuleWithNonTensorForwardImpl : public torch::nn::Module {
  int64_t forward(torch::Tensor x) {
    return x.numel();
  }
};
// Generates the holder type `ModuleWithNonTensorForward`.
TORCH_MODULE(ModuleWithNonTensorForward);
// Calling the holder forwards to the implementation's forward(), including
// its non-tensor return value.
TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
  ModuleWithNonTensorForward holder;
  ASSERT_EQ(holder(torch::ones(123)), 123);
}