#include <ATen/native/quantized/PackedParams.h>
#include <torch/library.h>
#include <torch/torch.h>
namespace {
using namespace torch::autograd;
using namespace at;
// This class is a custom autograd Function that enables a dynamically
// quantized linear layer to pass the input gradient back to the previous
// layers. It can be used when the user is adopting mixed precision for
// training after quantization. From the torch layer we have no direct access
// to the linear_dynamic operator, so it is reached via the redispatching
// mechanism.
// TODO: currently we support per-tensor quantization only; per-channel
// quantization will be added later.
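//
// A hypothetical usage sketch (for illustration only; the Python-side setup
// is not part of this file):
//
//   out = torch.ops.quantized.linear_dynamic(x, packed_w, reduce_range)
//   out.sum().backward()  # grad flows back to x through this Function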
class PackedLinearWeightDynamicBackward
: public Function<PackedLinearWeightDynamicBackward> {
public:
static torch::Tensor forward(
AutogradContext* ctx,
at::Tensor input,
const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
bool reduce_range) {
static auto op =
at::Dispatcher::singleton()
.findSchemaOrThrow("quantized::linear_dynamic", "")
.typed<at::Tensor(
at::Tensor,
c10::intrusive_ptr<
LinearPackedParamsBase,
c10::detail::intrusive_target_default_null_type<
LinearPackedParamsBase>> const&,
bool)>();
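// Redispatch directly to the CPU backend kernel; going through the normal
// dispatch path would re-enter this Autograd kernel.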
auto output = op.redispatch(
DispatchKeySet({DispatchKey::CPU}), input, packed_weight, reduce_range);
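// Stash the packed weight so backward() can unpack it to build grad_input.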
ctx->saved_data["weight"] = packed_weight;
return output;
}
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
auto packed_weight =
ctx->saved_data["weight"].toCustomClass<LinearPackedParamsBase>();
auto unpacked_parameters = packed_weight->unpack();
auto original_weight = std::get<0>(unpacked_parameters);
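// grad_input = grad_output * W. apply_dynamic computes x * W_packed^T, so
// pack the transposed weight W^T to obtain grad_output * W below.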
original_weight = at::permute(original_weight, {1, 0});
auto grad_output = grad_outputs[0];
static auto op = at::Dispatcher::singleton()
.findSchemaOrThrow("quantized::linear_prepack", "")
.typed<c10::intrusive_ptr<LinearPackedParamsBase>(
at::Tensor, c10::optional<at::Tensor>)>();
auto prepacked_weight = op.call(original_weight, nullopt);
auto grad_input = prepacked_weight->apply_dynamic(grad_output);
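// Only the input is differentiable; return undefined tensors for the
// packed_weight and reduce_range arguments.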
return {grad_input, torch::Tensor(), torch::Tensor()};
}
};
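// Kernel registered below under the Autograd key. The leading DispatchKeySet
// is supplied by the dispatcher and is unused here.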
at::Tensor packed_linear_weight_grad(
c10::DispatchKeySet ks,
at::Tensor input,
const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
bool reduce_range) {
return PackedLinearWeightDynamicBackward::apply(
input, packed_weight, reduce_range);
}
} // namespace
namespace at {
namespace native {
namespace {
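// Route quantized::linear_dynamic through the custom autograd Function when
// the Autograd key is in the dispatch key set (e.g. an input requires grad).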
TORCH_LIBRARY_IMPL(quantized, Autograd, m) {
m.impl(
TORCH_SELECTIVE_NAME("quantized::linear_dynamic"),
TORCH_FN(packed_linear_weight_grad));
}
} // namespace
} // namespace native
} // namespace at