blob: a46f434dff58dad8b19d0aedc7f29c18f8d5466f [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_RPC
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <ATen/quantized/Quantizer.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <string>
#include <type_traits>
namespace torch::jit {
using ::c10::IValue;
// Protocol 2 is the highest that can be decoded by Python 2
// See https://docs.python.org/3/library/pickle.html#data-stream-format
constexpr static uint8_t PROTOCOL_VERSION = 2;
// NOLINTNEXTLINE(bugprone-exception-escape)
Pickler::~Pickler() {
flush();
}
void Pickler::protocol() {
push<PickleOpCode>(PickleOpCode::PROTO);
push<uint8_t>(PROTOCOL_VERSION);
}
void Pickler::startTuple() {
// All attributes get pushed into a tuple and their indices saved in the
// module def
push<PickleOpCode>(PickleOpCode::MARK);
}
void Pickler::endTuple() {
push<PickleOpCode>(PickleOpCode::TUPLE);
}
void Pickler::stop() {
push<PickleOpCode>(PickleOpCode::STOP);
flush();
}
// unmemoized version called by pushIValue
void Pickler::pushIValueImpl(const IValue& ivalue) {
if (ivalue.isTensor()) {
pushTensor(ivalue);
} else if (ivalue.isTuple()) {
pushTuple(ivalue);
} else if (ivalue.isDouble()) {
pushDouble(ivalue.toDouble());
} else if (ivalue.isComplexDouble()) {
pushComplexDouble(ivalue);
} else if (ivalue.isInt()) {
pushInt(ivalue.toInt());
} else if (ivalue.isBool()) {
pushBool(ivalue.toBool());
} else if (ivalue.isString()) {
pushString(ivalue.toStringRef());
} else if (ivalue.isGenericDict()) {
pushDict(ivalue);
} else if (ivalue.isNone()) {
push<PickleOpCode>(PickleOpCode::NONE);
} else if (ivalue.isIntList()) {
pushSpecializedList(ivalue, "build_intlist", [=](const IValue& ivalue) {
for (const int64_t item : ivalue.toIntVector()) {
pushInt(item);
}
});
} else if (ivalue.isTensorList()) {
pushSpecializedList(ivalue, "build_tensorlist", [=](const IValue& ivalue) {
for (const at::Tensor& item : ivalue.toTensorVector()) {
pushIValue(item);
}
});
} else if (ivalue.isDoubleList()) {
pushSpecializedList(ivalue, "build_doublelist", [=](const IValue& ivalue) {
for (double item : ivalue.toDoubleVector()) {
pushDouble(item);
}
});
} else if (ivalue.isBoolList()) {
pushSpecializedList(ivalue, "build_boollist", [=](const IValue& ivalue) {
for (bool item : ivalue.toBoolList()) {
pushBool(item);
}
});
// note: isList must be after isIntList and friends because
// isList is true for all lists.
} else if (ivalue.isList()) {
pushGenericList(ivalue);
} else if (ivalue.isObject()) {
auto obj = ivalue.toObject();
auto type = obj->type();
if (memoized_class_types_ != nullptr) {
// memoize every class type the Pickler encountered
// This is used to make sure we capture all the run-time types
// and serialize them properly for class/interface polymorphism
memoized_class_types_->emplace_back(type);
}
auto type_name = type->name().value();
if (type_renamer_) {
type_name = type_renamer_(type);
}
pushGlobal(type_name.prefix(), type_name.name());
push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
push<PickleOpCode>(PickleOpCode::NEWOBJ);
if (checkHasValidSetGetState(type)) {
Function& getstate = type->getMethod("__getstate__");
pushIValue(getstate({obj}));
} else {
push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
push<PickleOpCode>(PickleOpCode::MARK);
for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
pushString(type->getAttributeName(i));
pushIValue(obj->getSlot(i));
}
push<PickleOpCode>(PickleOpCode::SETITEMS);
}
push<PickleOpCode>(PickleOpCode::BUILD);
} else if (ivalue.isDevice()) {
pushDevice(ivalue);
} else if (ivalue.isCapsule()) {
std::stringstream err;
err << "Cannot serialize custom bound C++ class";
if (memoized_class_types_ && !memoized_class_types_->empty()) {
if (auto qualname = memoized_class_types_->back()->name()) {
err << " " << qualname->qualifiedName();
}
}
err << ". Please define serialization methods via def_pickle() for "
"this class.";
AT_ERROR(err.str());
} else if (ivalue.isRRef()) {
#ifdef USE_RPC
TORCH_CHECK(
torch::distributed::rpc::getAllowJitRRefPickle() == true,
"RRef jit pickling is only allowed inside RPC calls.");
pushRRef(ivalue);
#else
TORCH_CHECK(
false, "RRef pickling is only supported with the distributed package");
#endif
} else if (ivalue.isEnum()) {
auto enum_holder = ivalue.toEnumHolder();
const auto& qualified_class_name =
enum_holder->type()->qualifiedClassName();
pushGlobal(qualified_class_name.prefix(), qualified_class_name.name());
pushIValue(enum_holder->value());
push<PickleOpCode>(PickleOpCode::REDUCE);
} else {
AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
}
}
void Pickler::pushDevice(const IValue& ivalue) {
auto device = ivalue.toDevice();
auto deviceStr = device.str();
auto it = memoized_devices_map_.find(deviceStr);
if (it == memoized_devices_map_.end()) {
pushGlobal("torch", "device");
pushString(deviceStr);
push<PickleOpCode>(PickleOpCode::TUPLE1);
push<PickleOpCode>(PickleOpCode::REDUCE);
memoized_devices_map_[deviceStr] = pushNextBinPut();
} else {
pushBinGet(it->second);
}
}
#ifdef USE_RPC
void Pickler::pushRRef(const IValue& ivalue) {
// It is the same as how rref is pickled in python, see PyRRef::pickle
auto rrefInterface = ivalue.toRRef();
auto rref =
c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(rrefInterface);
pushGlobal("torch.distributed.rpc", "rref");
auto& ctx = distributed::rpc::RRefContext::getInstance();
auto rrefForkData = ctx.prepareChildFork(rref);
push<PickleOpCode>(PickleOpCode::MARK);
pushInt(rrefForkData.ownerId_);
pushInt(rrefForkData.rrefId_.createdOn_);
pushInt(rrefForkData.rrefId_.localId_);
pushInt(rrefForkData.forkId_.createdOn_);
pushInt(rrefForkData.forkId_.localId_);
pushInt(rrefForkData.parent_);
pushString(rrefForkData.typeStr_);
push<PickleOpCode>(PickleOpCode::TUPLE);
push<PickleOpCode>(PickleOpCode::REDUCE);
}
#endif
void Pickler::pushIValue(const IValue& ivalue) {
bool shouldMemoizeByPointer =
ivalue.isPtrType() && !ivalue.isString() && ivalue.use_count() > 1;
// Mutable ivalues are memoized by pointer equality, which we handle at this
// outer granularity. Immutable ivalues are memoized by value equality which
// is handled in the type-specific handlers inside pushIValueImpl.
if (shouldMemoizeByPointer) {
const void* ptr = ivalue.internalToPointer();
TORCH_CHECK(
ptr != nullptr,
"Pickler cannot memoize ",
ivalue.tagKind(),
" IValue ",
ivalue);
auto memo_entry = memoized_ivalue_map_.find(ptr);
if (memo_entry != memoized_ivalue_map_.end()) {
// This value has already been pushed, just do a BINGET
pushBinGet(memo_entry->second);
return;
}
pushIValueImpl(ivalue);
memoized_ivalues_.push_back(ivalue);
memoized_ivalue_map_[ptr] = pushNextBinPut();
} else {
pushIValueImpl(ivalue);
}
}
void Pickler::pushInt(int64_t n) {
if (n >= std::numeric_limits<uint8_t>::min() &&
n <= std::numeric_limits<uint8_t>::max()) {
push<PickleOpCode>(PickleOpCode::BININT1);
push<uint8_t>(n);
} else if (
n >= std::numeric_limits<uint16_t>::min() &&
n <= std::numeric_limits<uint16_t>::max()) {
push<PickleOpCode>(PickleOpCode::BININT2);
push<uint16_t>(n);
} else if (
n >= std::numeric_limits<int32_t>::min() &&
n <= std::numeric_limits<int32_t>::max()) {
push<PickleOpCode>(PickleOpCode::BININT);
push<int32_t>(n);
} else {
// Push 8 byte integer
push<PickleOpCode>(PickleOpCode::LONG1);
push<uint8_t>(8);
push<int64_t>(n);
}
}
void Pickler::pushBool(bool value) {
push<PickleOpCode>(value ? PickleOpCode::NEWTRUE : PickleOpCode::NEWFALSE);
}
void Pickler::pushBinGet(uint32_t memo_id) {
if (memo_id <= std::numeric_limits<uint8_t>::max()) {
push<PickleOpCode>(PickleOpCode::BINGET);
push<uint8_t>(memo_id);
} else {
// Memoized too many items, issue a LONG_BINGET instead
push<PickleOpCode>(PickleOpCode::LONG_BINGET);
push<uint32_t>(memo_id);
}
}
// unmemoized encoding of a string
void Pickler::pushStringImpl(const std::string& string) {
push<PickleOpCode>(PickleOpCode::BINUNICODE);
push<uint32_t>(string.size());
pushBytes(string);
}
void Pickler::pushString(const std::string& string) {
auto it = memoized_strings_map_.find(string);
if (it == memoized_strings_map_.end()) {
pushStringImpl(string);
memoized_strings_map_[string] = pushNextBinPut();
} else {
pushBinGet(it->second);
}
}
void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
const at::Storage& storage = tensor.storage();
void* addr = storage.unsafeGetStorageImpl();
auto it = memoized_storage_map_.find(addr);
if (it != memoized_storage_map_.end()) {
pushBinGet(it->second);
return;
}
// Tuple for persistent_load
push<PickleOpCode>(PickleOpCode::MARK);
// typename
pushString("storage");
// data_type
std::string data_type =
std::string(toString(tensor.scalar_type())).append("Storage");
pushGlobal("torch", data_type);
// root_key
std::string root_key = get_tensor_id_ != nullptr
? get_tensor_id_(tensor)
: c10::to_string(tensor_data_.size());
pushString(root_key);
// location
pushString(tensor.device().str());
// size
pushInt(tensor.storage().nbytes() / tensor.element_size());
push<PickleOpCode>(PickleOpCode::TUPLE);
push<PickleOpCode>(PickleOpCode::BINPERSID);
// TODO: Skip this if not writing tensors
memoized_storage_map_[addr] = pushNextBinPut();
tensor_data_.push_back(tensor);
}
void Pickler::pushBytes(const std::string& string) {
static const size_t kSmallStr = 32;
if (string.size() <= kSmallStr &&
bufferPos_ + string.size() <= buffer_.size()) {
// Small string that fits: buffer the data.
memcpy(buffer_.data() + bufferPos_, string.data(), string.size());
bufferPos_ += string.size();
} else {
// Otherwise, first flush, then write directly.
flush();
writer_(string.data(), string.size());
}
}
void Pickler::pushGlobal(
c10::string_view module_name,
c10::string_view class_name) {
std::string key;
key.reserve(module_name.size() + class_name.size() + 2);
key.append(module_name.data(), module_name.size());
key.push_back('\n');
key.append(class_name.data(), class_name.size());
key.push_back('\n');
const auto memo_entry = memoized_globals_map_.find(key);
if (memo_entry == memoized_globals_map_.end()) {
push<PickleOpCode>(PickleOpCode::GLOBAL);
pushBytes(key);
// Push BINPUT without adding anything to the memoized_ivalues_
size_t memo_id = pushNextBinPut();
memoized_globals_map_.insert({key, memo_id});
} else {
pushBinGet(memo_entry->second);
}
}
void Pickler::pushTensor(const IValue& ivalue) {
if (tensor_table_ == nullptr) {
pushLiteralTensor(ivalue);
} else {
pushTensorReference(ivalue);
}
}
void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) {
pushGlobal("torch._utils", "_rebuild_sparse_tensor");
push<PickleOpCode>(PickleOpCode::MARK);
// layout
auto layout = static_cast<int>(tensor.layout());
pushInt(layout);
switch (layout) {
case static_cast<int>(c10::Layout::Sparse):
// size
push<PickleOpCode>(PickleOpCode::MARK);
for (auto size : tensor.sizes()) {
pushInt(size);
}
push<PickleOpCode>(PickleOpCode::TUPLE);
// requires grad
pushIValue(tensor.requires_grad());
// indices
pushTensor(tensor._indices());
// values
pushTensor(tensor._values());
break;
case static_cast<int>(c10::Layout::SparseCsr):
push<PickleOpCode>(PickleOpCode::MARK);
for (auto size : tensor.sizes()) {
pushInt(size);
}
push<PickleOpCode>(PickleOpCode::TUPLE);
pushIValue(tensor.requires_grad());
pushTensor(tensor.crow_indices());
pushTensor(tensor.col_indices());
pushTensor(tensor.values());
break;
default:
TORCH_CHECK(
false,
"Unsupported sparse tensor layout type in serialization ",
static_cast<c10::Layout>(layout));
break;
}
// backward_hooks
pushGlobal("collections", "OrderedDict");
push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
// Construct the collections.OrderedDict for the backward_hooks
push<PickleOpCode>(PickleOpCode::REDUCE);
push<PickleOpCode>(PickleOpCode::TUPLE);
// Call torch._utils._rebuild_sparse_coo_tensor
push<PickleOpCode>(PickleOpCode::REDUCE);
}
void Pickler::pushLiteralTensor(const IValue& ivalue) {
// In contrast to tensor references, literal tensors are included in the
// pickle program binary blob. They are written to the file after the STOP
// opcode. They can't be included in the pickle program itself without a bunch
// of extra machinery since byte strings are limited to 4 GB.
//
// The format here is the same one used by `torch.save()`. The code for the
// format can be found in `torch/serialization.py`.
auto& tensor = ivalue.toTensor();
if (tensor.is_sparse() || tensor.is_sparse_csr()) {
pushLiteralSparseTensor(tensor);
return;
}
bool quantized = tensor.is_quantized();
// The arguments to this function are:
// storage, storage_offset, size, stride, requires_grad, backward_hooks
pushGlobal(
"torch._utils", quantized ? "_rebuild_qtensor" : "_rebuild_tensor_v2");
push<PickleOpCode>(PickleOpCode::MARK);
pushStorageOfTensor(tensor);
// storage offset
pushInt(tensor.storage_offset());
// size
push<PickleOpCode>(PickleOpCode::MARK);
for (auto size : tensor.sizes()) {
pushInt(size);
}
push<PickleOpCode>(PickleOpCode::TUPLE);
// stride
push<PickleOpCode>(PickleOpCode::MARK);
for (auto stride : tensor.strides()) {
pushInt(stride);
}
push<PickleOpCode>(PickleOpCode::TUPLE);
if (quantized) {
push<PickleOpCode>(PickleOpCode::MARK);
pushGlobal("torch", toString(tensor.qscheme()));
// tuple of (qscheme, scale, zp) or (qscheme, scales, zps, axis)
switch (tensor.qscheme()) {
case at::kPerTensorAffine:
pushDouble(tensor.q_scale());
pushInt(tensor.q_zero_point());
break;
case at::kPerChannelAffineFloatQParams:
case at::kPerChannelAffine: {
pushTensor(tensor.q_per_channel_scales());
pushTensor(tensor.q_per_channel_zero_points());
pushInt(tensor.q_per_channel_axis());
} break;
default:
TORCH_CHECK(
false,
"Unsupported tensor quantization type in serialization ",
toString(tensor.qscheme()));
break;
}
push<PickleOpCode>(PickleOpCode::TUPLE);
}
// requires_grad
pushIValue(tensor.requires_grad());
// backward_hooks
pushGlobal("collections", "OrderedDict");
push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
// Construct the collections.OrderedDict for the backward_hooks
push<PickleOpCode>(PickleOpCode::REDUCE);
if (!quantized) {
// Only push it for regular tensor if the dictionary is not empty.
auto metadata = torch::jit::getTensorMetadata(tensor);
if (!metadata.empty()) {
// IValues based on std::unordered_map<K, V> are slow and deprecated.
// Thus, pass a c10::Dict to pushDict.
c10::Dict<std::string, bool> math_bits_;
for (const auto& pair : metadata) {
math_bits_.insert(pair.first, pair.second);
}
pushDict(math_bits_);
}
}
push<PickleOpCode>(PickleOpCode::TUPLE);
// Call torch._utils._rebuild_tensor_v2
push<PickleOpCode>(PickleOpCode::REDUCE);
}
void Pickler::pushSpecializedList(
const IValue& ivalue,
const char* list_name,
const std::function<void(const IValue&)>& item_pusher) {
pushGlobal("torch.jit._pickle", list_name);
// Reduce arguments are spread (e.g. `*args`) before calling the global,
// so wrap in a tuple
push<PickleOpCode>(PickleOpCode::MARK);
push<PickleOpCode>(PickleOpCode::EMPTY_LIST);
// Mark list
push<PickleOpCode>(PickleOpCode::MARK);
// Add all items
item_pusher(ivalue);
// Finish list
push<PickleOpCode>(PickleOpCode::APPENDS);
// Finish tuple
push<PickleOpCode>(PickleOpCode::TUPLE);
// Call reduce
push<PickleOpCode>(PickleOpCode::REDUCE);
}
static inline double swapDouble(double value) {
const char* bytes = reinterpret_cast<const char*>(&value);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double flipped;
char* out_bytes = reinterpret_cast<char*>(&flipped);
for (const auto i : c10::irange(sizeof(double))) {
out_bytes[i] = bytes[sizeof(double) - i - 1];
}
return *reinterpret_cast<double*>(out_bytes);
}
void Pickler::pushDouble(double value) {
push<PickleOpCode>(PickleOpCode::BINFLOAT);
// Python pickle format is big endian, swap.
push<double>(swapDouble(value));
}
void Pickler::pushComplexDouble(const IValue& value) {
c10::complex<double> d = value.toComplexDouble();
pushGlobal("builtins", "complex");
pushIValue(d.real());
pushIValue(d.imag());
push<PickleOpCode>(PickleOpCode::TUPLE2);
push<PickleOpCode>(PickleOpCode::REDUCE);
}
void Pickler::pushLong(const std::string& data) {
uint64_t size = data.size();
TORCH_INTERNAL_ASSERT(
size <= std::numeric_limits<uint8_t>::max(),
"Cannot pickle a long larger than 255 bytes");
push<PickleOpCode>(PickleOpCode::LONG1);
push<uint8_t>(size);
pushBytes(data);
}
void Pickler::pushTensorReference(const IValue& ivalue) {
pushGlobal("torch.jit._pickle", "build_tensor_from_id");
tensor_table_->push_back(ivalue.toTensor());
int64_t tensor_id = tensor_table_->size() - 1;
// Reduce arguments are spread (e.g. `*args`) before calling the global,
// so wrap in a tuple
push<PickleOpCode>(PickleOpCode::MARK);
pushIValue(tensor_id);
push<PickleOpCode>(PickleOpCode::TUPLE);
push<PickleOpCode>(PickleOpCode::REDUCE);
}
// startTypeTag() and endTypeTag() must be called in a pair, with 1 argument
// pushed on the stack in between them. They will add the type of a container
// ivalue to the stack as a string so we can preserve type tags across
// serialization
void Pickler::startTypeTag() {
if (tag_aggregates_) {
pushGlobal("torch.jit._pickle", "restore_type_tag");
}
}
namespace {
c10::optional<std::string> type_printer(const c10::Type& type) {
if (auto dyn = type.castRaw<c10::DynamicType>()) {
return dyn->fallback()->annotation_str(type_printer);
}
return c10::nullopt;
}
} // namespace
// See startTypeTag
void Pickler::endTypeTag(const IValue& ivalue) {
if (!tag_aggregates_) {
return;
}
TORCH_INTERNAL_ASSERT(ivalue.isGenericDict() || ivalue.isList());
// Push the dict type
auto type = ivalue.type();
TORCH_INTERNAL_ASSERT(type);
auto annot_str = type->annotation_str(type_printer);
pushString(annot_str);
// Pop the dict and type into a tuple
push<PickleOpCode>(PickleOpCode::TUPLE2);
// Call function via reduce
push<PickleOpCode>(PickleOpCode::REDUCE);
}
void Pickler::pushDict(const IValue& ivalue) {
auto dict = ivalue.toGenericDict();
startTypeTag();
push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
static_assert(
std::is_unsigned<decltype(dict.size())>::value,
"Expected size to be non-negative.");
push<PickleOpCode>(PickleOpCode::MARK);
// Sort the dict for deterministic keys
for (const auto& entry : dict) {
pushIValue(entry.key());
pushIValue(entry.value());
}
push<PickleOpCode>(PickleOpCode::SETITEMS);
endTypeTag(ivalue);
}
size_t Pickler::pushNextBinPut() {
if (memo_id_ <= std::numeric_limits<uint8_t>::max()) {
push<PickleOpCode>(PickleOpCode::BINPUT);
push<uint8_t>(memo_id_);
} else {
// Memoized too many items, issue a LONG_BINPUT instead
push<PickleOpCode>(PickleOpCode::LONG_BINPUT);
push<uint32_t>(memo_id_);
}
AT_ASSERT(memo_id_ <= std::numeric_limits<uint32_t>::max());
++memo_id_;
return memo_id_ - 1;
}
void Pickler::pushGenericList(const IValue& ivalue) {
auto list = ivalue.toListRef();
startTypeTag();
// Push the list items
push<PickleOpCode>(PickleOpCode::EMPTY_LIST);
push<PickleOpCode>(PickleOpCode::MARK);
for (const IValue& item : list) {
pushIValue(item);
}
push<PickleOpCode>(PickleOpCode::APPENDS);
endTypeTag(ivalue);
}
void Pickler::pushTuple(const IValue& ivalue) {
auto tuple = ivalue.toTuple();
auto tuple_size = tuple->elements().size();
switch (tuple_size) {
case 0: {
push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
} break;
case 1: {
pushIValue(tuple->elements()[0]);
push<PickleOpCode>(PickleOpCode::TUPLE1);
} break;
case 2: {
pushIValue(tuple->elements()[0]);
pushIValue(tuple->elements()[1]);
push<PickleOpCode>(PickleOpCode::TUPLE2);
} break;
case 3: {
pushIValue(tuple->elements()[0]);
pushIValue(tuple->elements()[1]);
pushIValue(tuple->elements()[2]);
push<PickleOpCode>(PickleOpCode::TUPLE3);
} break;
default: {
push<PickleOpCode>(PickleOpCode::MARK);
for (const IValue& item : tuple->elements()) {
pushIValue(item);
}
push<PickleOpCode>(PickleOpCode::TUPLE);
} break;
}
}
WriteableTensorData getWriteableTensorData(
const at::Tensor& tensor,
bool to_cpu) {
WriteableTensorData result;
result.tensor_ = tensor;
result.size_ = tensor.storage().nbytes();
// TODO HIP support
if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
// NB: This new tensor is created to support cuda tensors.
// Storages can be mutated when converting tensors from cuda to cpu,
// and we need a cpu tensor to copy data from.
result.tensor_ =
at::empty({0}, tensor.options())
.set_(
tensor.storage(),
/* storage_offset = */ 0,
/* size = */
{static_cast<int64_t>(
tensor.storage().nbytes() / tensor.element_size())},
/* stride = */ {1})
.cpu();
TORCH_CHECK(
result.tensor_.storage().nbytes() == result.size_,
"Storage tensor size did not match record size");
}
return result;
}
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
// Check that the schemas for __getstate__ and __setstate__ are correct
auto getstate = cls->findMethod("__getstate__");
if (getstate == nullptr) {
return false;
}
auto get_schema = getstate->getSchema();
// Check __getstate__
// __getstate__ is expected to be (self) -> T
TORCH_CHECK(
get_schema.arguments().size() == 1,
"'__getstate__' must have 'self' as its only argument, but found ",
get_schema.arguments().size(),
" arguments");
TORCH_CHECK(
get_schema.returns().size() == 1,
"'__getstate__' must return 1 value, but found ",
get_schema.returns().size());
// Check __setstate__ if the method exists
// __setstate__ is expected to be (self, T) -> None
auto setstate = cls->findMethod("__setstate__");
if (!setstate) {
return false;
}
auto set_schema = setstate->getSchema();
TORCH_CHECK(
set_schema.arguments().size() == 2,
"'__setstate__' must have 'self' and the state as its "
"only arguments, but found ",
set_schema.arguments().size(),
" arguments");
TORCH_CHECK(
set_schema.returns().size() == 1,
"'__setstate__' must return None, but found ",
set_schema.returns().size(),
" return values");
TORCH_CHECK(
set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
"'__setstate__' must return None, but found value of type",
set_schema.returns().at(0).type()->annotation_str());
// Check that the return type of __getstate__ matches the input to
// __setstate__
auto get_type = get_schema.returns().at(0).type();
auto set_type = set_schema.arguments().at(1).type();
TORCH_CHECK(
get_type->isSubtypeOf(*set_type),
"'__getstate__'s return type (",
get_type->annotation_str(),
") does not match '__setstate__'s argument type (",
set_type->annotation_str(),
")");
return true;
}
} // namespace torch::jit