blob: a1db68087a121222b1b11bb431f0ab494f73255c [file] [log] [blame]
#include <gtest/gtest.h>
#include <ATen/MetaFunctions.h>
#include <torch/torch.h>
#include <vector>
TEST(MetaTensorTest, MetaDeviceApi) {
  // Two CPU operands whose shapes ({4} and {3, 4}) broadcast to {3, 4}.
  auto lhs_cpu = at::ones({4}, at::kFloat);
  auto rhs_cpu = at::ones({3, 4}, at::kFloat);

  // Convert both operands to the meta device first; when at::add() is given
  // meta inputs, its result is expected to be a meta tensor as well.
  auto lhs_meta = lhs_cpu.to(c10::kMeta);
  auto rhs_meta = rhs_cpu.to(c10::kMeta);
  auto sum = at::add(lhs_meta, rhs_meta);

  // .to() produced new tensors, so the original operands must remain on CPU.
  ASSERT_EQ(lhs_cpu.device(), c10::kCPU);
  ASSERT_EQ(rhs_cpu.device(), c10::kCPU);
  ASSERT_EQ(sum.device(), c10::kMeta);

  // Shape inference on the meta device must still yield the broadcast shape.
  const c10::IntArrayRef observed = sum.sizes();
  const std::vector<int64_t> expected{3, 4};
  ASSERT_EQ(observed, expected);
}
TEST(MetaTensorTest, MetaNamespaceApi) {
  // Operands stay on the CPU; their shapes ({4} and {3, 4}) broadcast to {3, 4}.
  auto lhs = at::ones({4}, at::kFloat);
  auto rhs = at::ones({3, 4}, at::kFloat);

  // Functions in the at::meta:: namespace accept tensors from any backend and
  // return a meta tensor, so no explicit device conversion is needed here.
  auto sum = at::meta::add(lhs, rhs);

  // Inputs are untouched (still CPU) while the result lives on the meta device.
  ASSERT_EQ(lhs.device(), c10::kCPU);
  ASSERT_EQ(rhs.device(), c10::kCPU);
  ASSERT_EQ(sum.device(), c10::kMeta);

  // The meta kernel must report the broadcast output shape.
  const c10::IntArrayRef observed = sum.sizes();
  const std::vector<int64_t> expected{3, 4};
  ASSERT_EQ(observed, expected);
}