// Source blob: 81d6be7b852006a3ddaf4455f82d291d9fc2b003
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/Repeat.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/repeat_interleave.h>
#include <ATen/ops/repeat_interleave_native.h>
#endif
/// Writes the interleaved source indices for repeat_interleave on CPU.
///
/// For every `i` in `[0, size)` the value `i` is stored into
/// `result_ptr[cumsum_ptr[i] - repeat_ptr[i] .. cumsum_ptr[i])`, i.e.
/// `repeat_ptr[i]` consecutive slots ending at offset `cumsum_ptr[i]`.
///
/// @param repeat_ptr   `size` repeat counts; each must be non-negative.
/// @param cumsum_ptr   `size` end offsets, one per count (presumably the
///                     inclusive prefix sum of the counts -- this function
///                     only relies on `cumsum_ptr[i]` being the end offset
///                     for index `i`).
/// @param result_ptr   output buffer; must hold `result_size` elements.
/// @param size         number of entries in `repeat_ptr` / `cumsum_ptr`.
/// @param result_size  number of elements available in `result_ptr`.
///
/// Fails via TORCH_CHECK when `result_size` does not equal the last end
/// offset, or when a repeat count is negative.
///
/// NOTE: the signature (including the non-const pointers) is kept as-is
/// because this function is passed as a non-type template argument.
template <typename index_t>
static void compute_cpu(
    index_t* repeat_ptr,
    int64_t* cumsum_ptr,
    index_t* result_ptr,
    int64_t size,
    int64_t result_size) {
  // With no counts there is no `cumsum_ptr[size - 1]` to read; the only
  // consistent output is an empty one.
  if (size <= 0) {
    TORCH_CHECK(
        (result_size == 0), "allocated size does not match required size");
    return;
  }
  TORCH_CHECK(
      (result_size == cumsum_ptr[size - 1]),
      "allocated size does not match required size");
  at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
    for (const auto i : c10::irange(i_begin, i_end)) {
      const index_t count = repeat_ptr[i];
      TORCH_CHECK((count >= 0), "repeats can not be negative");
      const int64_t end = cumsum_ptr[i];
      const int64_t start = end - count;
      // A negative count at some *other* index makes the end offsets
      // non-monotonic, so this index's range can fall outside the buffer even
      // though its own count is valid and the final offset matched
      // `result_size` (e.g. counts {5, -3} with result_size 2). Reject the
      // range before writing instead of relying on the other iteration's
      // check, which may run later or on another thread.
      TORCH_CHECK(
          (start >= 0 && end <= result_size), "repeats can not be negative");
      for (const auto j : c10::irange(start, end)) {
        result_ptr[j] = static_cast<index_t>(i);
      }
    }
  });
}
namespace at::native {
/// CPU entry point for the single-tensor form of repeat_interleave.
///
/// Dispatches on the scalar type of `repeat` through AT_DISPATCH_INDEX_TYPES
/// and forwards to `repeat_interleave_common`, instantiated with
/// `compute_cpu` for the selected `index_t`.
///
/// @param repeat       tensor of repeat counts.
/// @param output_size  optional output length, forwarded unchanged.
/// @return the tensor produced by `repeat_interleave_common`.
Tensor repeat_interleave_cpu(
    const Tensor& repeat,
    c10::optional<int64_t> output_size) {
  const auto index_dtype = repeat.scalar_type();
  Tensor result;
  AT_DISPATCH_INDEX_TYPES(index_dtype, "repeat_interleave_cpu", [&] {
    result = repeat_interleave_common<index_t, compute_cpu<index_t>>(
        repeat, output_size);
  });
  return result;
}
/// repeat_interleave with a tensor of repeat counts.
///
/// Builds an index tensor from `repeats` (via the single-tensor
/// `at::repeat_interleave_symint`) and gathers from the input with
/// `index_select` along `dim`.
///
/// @param self         input tensor.
/// @param repeats      0-dim tensor, or 1-dim tensor whose length is either 1
///                     or the size of the input along `dim`.
/// @param dim          dimension to repeat along; when absent the input is
///                     flattened and dimension 0 is used.
/// @param output_size  optional total output length, forwarded to the
///                     index-building call.
/// @return the gathered tensor, with the input's conj / neg bits restored.
Tensor repeat_interleave_symint(
    const Tensor& self,
    const Tensor& repeats,
    c10::optional<int64_t> dim,
    c10::optional<SymInt> output_size) {
  // Remember the conj / neg bits and undo them on the working tensor; they
  // are restored on the result below.
  Tensor source = self;
  const bool had_conj = source.is_conj();
  if (had_conj) {
    source = source.conj();
  }
  const bool had_neg = source.is_neg();
  if (had_neg) {
    source = source._neg_view();
  }

  // No explicit dim: operate on the flattened input along dimension 0.
  if (!dim.has_value()) {
    source = source.flatten();
    dim = 0;
  }
  const int64_t axis = dim.value();

  // Only scalar or vector repeat counts are accepted.
  const auto repeats_rank = repeats.dim();
  if (repeats_rank != 0 && repeats_rank != 1) {
    AT_ERROR("repeats must be 0-dim or 1-dim tensor");
  }

  Tensor counts = repeats;
  if (repeats_rank == 0 || repeats.sym_size(0) == 1) {
    // A single count is expanded to one entry per slice along `axis`.
    counts = repeats.reshape({1}).expand_symint({source.sym_size(axis)});
  } else {
    TORCH_CHECK(
        repeats.sym_size(0) == source.sym_size(axis),
        "repeats must have the same size as input along dim, but got repeats.size(0) = ",
        repeats.sym_size(0), " and input.size(", axis, ") = ", source.sym_size(axis)
    );
  }

  const Tensor gather_index =
      at::repeat_interleave_symint(counts, output_size);
  Tensor gathered = source.index_select(axis, gather_index);

  // Put the conj / neg bits back, in the same order they were removed.
  if (had_conj) {
    gathered = gathered.conj();
  }
  if (had_neg) {
    gathered = gathered._neg_view();
  }
  return gathered;
}
/// repeat_interleave with a single scalar repeat count.
///
/// Inserts a new dimension right after `dim`, expands it to `repeats`, then
/// materializes the result contiguously and merges the two dimensions.
///
/// @param self         input tensor.
/// @param repeats      number of times each element is repeated; must be
///                     non-negative.
/// @param dim_opt      dimension to repeat along; when absent the input is
///                     flattened and dimension 0 is used.
/// @param output_size  optional expected output length along the repeated
///                     dimension; only validated, never used to size the
///                     result.
/// @return the repeated tensor.
Tensor repeat_interleave_symint(
    const Tensor& self,
    c10::SymInt repeats,
    c10::optional<int64_t> dim_opt,
    c10::optional<SymInt> output_size) {
  Tensor work = self;
  if (!dim_opt.has_value()) {
    work = self.flatten();
  }
  const int64_t axis = c10::maybe_wrap_dim(dim_opt.value_or(0), self.dim());
  TORCH_CHECK(repeats >= 0, "Repeats must be non-negative");

  // Add a size-1 dimension directly after `axis` and expand it to `repeats`.
  const int64_t repeat_axis = axis + 1;
  work = work.unsqueeze(repeat_axis);
  auto target_shape = work.sym_sizes().vec();
  target_shape[repeat_axis] = repeats;
  work = work.expand_symint(target_shape);

  // `output_size` carries no extra information for the scalar overload; it is
  // accepted for parity with the tensor overload and merely checked here.
  if (output_size.has_value()) {
    auto calculated_size = (repeats * target_shape[axis]).guard_int(__FILE__, __LINE__);
    TORCH_CHECK(*output_size == calculated_size, "repeat_interleave: Invalid output_size, expected ",
        calculated_size, " but got ", *output_size);
  }

  // Materialize the expanded view, then fold the repeat dimension into `axis`.
  Tensor materialized = work.clone(at::MemoryFormat::Contiguous);
  return materialized.flatten(axis, repeat_axis);
}
} // namespace at::native