#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#include <ATen/ops/_unsafe_view.h>
#endif
namespace {
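// Boxed fallback for the Functionalize dispatch key.
// For ops that don't get a codegen'd functionalization kernel, this kernel:
//   (1) syncs and unwraps any FunctionalTensorWrapper inputs on the stack,
//   (2) redispatches to the underlying op with functionalization disabled,
//   (3) wraps the outputs again if any input was functional, or if the op is
//       a factory function (no tensor inputs at all).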
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  TORCH_INTERNAL_ASSERT(!schema.hasAnyAliasInfo(), "mutating and aliasing ops should all have codegen'd kernels");
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);
  auto any_functional_inputs = false;
  auto any_tensor_inputs = false;
  for (uint64_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      any_tensor_inputs = true;
      auto t = ivalue.toTensor();
      if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(t);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      any_tensor_inputs = true;
      auto tensors = ivalue.toTensorList();
      if (at::functionalization::impl::isFunctionalTensor(tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isOptionalTensorList()) {
      any_tensor_inputs = true;
      auto opt_tensors = ivalue.toOptionalTensorList();
      if (at::functionalization::impl::isFunctionalTensor(opt_tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(opt_tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    }
  }
  // we should wrap the output if any inputs were wrapped,
  // OR if we're hitting a factory function (with no tensor inputs)
  auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
  {
    at::AutoDispatchSkipFunctionalize guard;
    op.callBoxed(stack);
  }
  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);
  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor() && should_wrap_outputs) {
      auto t = ivalue.toTensor();
      if (!t.defined()) continue;
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList() && should_wrap_outputs) {
      auto tensors = ivalue.toTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList() && should_wrap_outputs) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}
}
// Vanilla implementation to compute contiguous strides given some sizes.
// Should probably refactor this into shared code (also used in TensorImpl.h)
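// For example: sizes [2, 3, 4] -> strides [12, 4, 1].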
std::vector<int64_t> compute_contiguous_strides(c10::IntArrayRef sizes) {
  auto n = sizes.size();
  std::vector<int64_t> strides(n);
  if (n == 0) return strides;
  strides[n - 1] = 1;
  // Contiguous strides: each stride is the product of all sizes to its right.
  for (int64_t i = static_cast<int64_t>(n) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * sizes[i + 1];
  }
  return strides;
}
// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more details.
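// Illustrative example (a sketch, not taken from this file): for a functional
// 1-D tensor `a` with 4 elements,
//   a.resize_(8);  // grows the storage: handled as a true mutation below
//                  // (the wrapper's storage gets replaced)
//   a.resize_(2);  // shrinks: handled as a view, equivalent to
//                  // a.as_strided({2}, {1}) over the original storage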
const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format) {
  // First unwrap the tensor arguments
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  // Case 1: arguments are not functional tensors, so we no-op and redispatch.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    at::Tensor tmp_output = self_.resize_(size, memory_format);
    return self;
  }
  // Case 2: actually functionalize resize_()
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::resize(self_, size, memory_format);
  }
  auto itemsize = self.dtype().itemsize();
  auto storage_offset = self.storage_offset();
  auto new_size_bytes = at::detail::computeStorageNbytesContiguous(size, itemsize, storage_offset);
  auto needs_resize_storage = new_size_bytes > self.storage().nbytes();
  if (needs_resize_storage) {
    // If resize_() actually increases the size of the storage, then we need to tell FunctionalTensorWrapper about it.
    // See Note [resize_() in functionalization pass]
    auto func_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
    func_impl->maybe_replace_storage(tmp_output);
    // See the note - we're guaranteed at this point that "self" is *not* a view (and has no outstanding views)
    // So we don't need to treat the output of resize as a view tensor.
    return self;
  }
  // Otherwise, we know that we're resizing to a smaller size.
  // resize_() is effectively a view operator.
  // The output of resizing is equivalent to taking a slice of a larger tensor.
  // We have to emulate this "slicing" with an as_strided call.
  auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
      if (reapply_views) {
        return base.as_strided(size, compute_contiguous_strides(size));
      } else {
        return at::as_strided_copy(base, size, compute_contiguous_strides(size));
      }
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return base.as_strided_scatter(mutated_view, size, compute_contiguous_strides(size));
    }
  );
  at::functionalization::impl::mutate_view_meta(self, view_meta);
  return self;
}
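// `lift` wraps a plain (non-functional) tensor into a FunctionalTensorWrapper.
// The assert below documents the precondition: the input must not already be
// a functional tensor, so no unwrapping/sync step is needed here.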
at::Tensor lift_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  return at::functionalization::impl::to_functional_tensor(self);
}
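// Illustrative example (not from this file): `to("xla")` or `to("lazy")`
// targets a device that has opted into functionalization, while `to("cpu")`
// (or a copy with no explicit target device from a CPU tensor) does not.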
bool device_opted_into_functionalization(c10::Device self_device, c10::optional<c10::Device> tgt_device) {
  // If the target device is empty, then the output tensor should be on the same device as the input
  auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
  return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
}
// note I only need this because the to.dtype/to.dtype_layout overload calls this, so we skip the op above.
// We should probably get rid of this though.
at::Tensor _to_copy_functionalize(
    const at::Tensor & self,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    bool non_blocking,
    c10::optional<at::MemoryFormat> memory_format) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // sync any pending updates
    at::functionalization::impl::sync(self);
    // pass the unwrapped tensor to the backend
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);
  // Special case: if the Functionalize key is not in TLS, we assume that we're running
  // on a lazy backend (LTC).
  // In that case, if we're copying to a non-functionalize-enabled device,
  // then the functionalization pass should "end". We need to sync any updates on the input
  // tensor, but we shouldn't wrap the output.
  if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
    if (!device_opted_into_functionalization(self.device(), device)) {
      return out;
    }
  }
  return at::functionalization::impl::to_functional_tensor(out);
}
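// Illustrative example (a sketch based on the comment above, not verified
// here): under LTC, `lazy_tensor.to("cpu")` can reach this kernel with
// Functionalize absent from the TLS include set; CPU has not opted into
// functionalization, so the copy is returned unwrapped and the
// functionalization pass "ends" at that point.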
// Why is _unsafe_view special-cased here?
// Basically just to satisfy autograd's debug asserts.
// The situation:
// - _unsafe_view's autograd kernel has debug asserts to confirm
// that the input and output alias storage.
// - _unsafe_view's schema in native_functions.yaml
// does not contain alias annotations, so it advertises as non-aliasing.
// - functionalization will then treat _unsafe_view like a non-aliasing op.
// Specifically, autograd will redispatch to functionalization's
// boxed fallback kernel, which creates a new FunctionalTensorWrapper output
// that does **not** alias storage with the input, tripping the assert.
// The kernel written here just manually reifies the aliasing relationship.
//
// Another way to handle this would be to fix unsafe_view's alias annotations
// in native_functions.yaml, but I think this would be a pessimization.
// The idea with _unsafe_view is that you're guaranteed that the input
// is a temporary, and don't actually have to worry about propagating
// mutations between the input and output.
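// Illustrative sketch (assumption, not from this file): a functionalized
// program might contain
//   auto tmp = at::add(a, b);                  // tmp is a temporary
//   auto out = at::_unsafe_view(tmp, {2, 3});  // nothing else aliases tmp
// and the kernel below makes `out` an actual view of `tmp`'s wrapper, so
// autograd's storage-aliasing debug assert is satisfied.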
at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::IntArrayRef size) {
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    return at::_unsafe_view(self, size);
  }
  auto self_ = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::_unsafe_view(self_, size);
  }
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
      return at::_unsafe_view(base, size);
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::_unsafe_view(mutated_view, base.sizes());
    }
  );
  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, view_meta);
  // See Note [Propagating strides in the functionalization pass]
  // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
  auto inferred_size = at::infer_size_dv(size, self.numel());
  auto stride = at::detail::computeStride(self.sizes(), self.strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride.value());
  return out;
}
TORCH_LIBRARY_IMPL(_, Functionalize, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}
TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
m.impl("resize_", TORCH_FN(resize__functionalization));
m.impl("lift", TORCH_FN(lift_functionalize));
m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
}