#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_RPC
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <string>

namespace torch::jit {

using ::c10::IValue;

static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
  if (root.isObject()) {
    restoreAccurateTypeTags(root, root.type());
  }
}

// Pickled objects are stored in a form compatible with Python pickling.
// In TorchScript, List[T]/Dict[K, V] are statically typed and contain
// dynamic type tags that allow T, K, and V to be recovered, but these
// tags are not part of the Python pickle format. However, we can recover
// them from the static type of the top-level object being unpickled,
// because we have a record of the types of the objects it contains as
// attributes.
// `IfPossible` - we can only do this recovery when we have an object as
// the top-level unpickled thing (which is guaranteed for Modules, but
// not for torch.load/torch.save). Otherwise we do not know the types
// of the contained objects and cannot restore the tags.
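//
// As a sketch (assuming a module attribute declared as Dict[str, Tensor]),
// unpickling produces the dict tagged as Dict[Any, Any]; walking the
// static type re-tags it (and its contents) in place:
//
//   IValue root = unpickler.parse_ivalue();  // container tags may be Any
//   restoreAccurateTypeTagsIfPossible(root); // refined via static types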
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
  struct Work {
    TypePtr type;
    IValue value;
  };
  std::vector<Work> to_process = {{type_tag, root}};
  std::unordered_set<const void*> scanned;
  while (!to_process.empty()) {
    Work w = std::move(to_process.back());
    to_process.pop_back();
    // Ensure we only scan each pointer value once; otherwise this can
    // become exponential (and if we allow recursive data in the future,
    // it would not terminate).
    if (w.value.isPtrType()) {
      const void* key = w.value.internalToPointer();
      auto it = scanned.find(key);
      if (it != scanned.end()) {
        continue;
      }
      scanned.emplace_hint(it, key);
    }
    auto kind = w.type->kind();
    if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
      kind = dyn->dynamicKind();
    }
    switch (kind) {
      case TensorType::Kind:
      case StorageType::Kind:
      case NumberType::Kind:
      case FloatType::Kind:
      case ComplexType::Kind:
      case IntType::Kind:
      case NoneType::Kind:
      case GeneratorType::Kind:
      case QuantizerType::Kind:
      case BoolType::Kind:
      case VarType::Kind:
      case CapsuleType::Kind:
      case PyObjectType::Kind:
      case StringType::Kind:
      case FunctionType::Kind:
      case DeviceObjType::Kind:
      case StreamObjType::Kind:
      case QSchemeType::Kind:
      case LayoutType::Kind:
      case MemoryFormatType::Kind:
      case ScalarTypeType::Kind:
      case RRefType::Kind:
      case AnyType::Kind:
      case AnyListType::Kind:
      case AnyTupleType::Kind:
      case AnyClassType::Kind:
      case AnyEnumType::Kind:
        // no-op, there is nothing to tag
        break;
      case c10::SymIntType::Kind:
        TORCH_CHECK(!w.value.toSymInt().is_symbolic());
        // no-op, there is nothing to tag
        break;
      case c10::SymFloatType::Kind:
        TORCH_CHECK(!w.value.toSymFloat().is_symbolic());
        // no-op, there is nothing to tag
        break;
      case DynamicType::Kind:
      case UnionType::Kind:
      case EnumType::Kind:
        // TODO(gmagogsfm): Implement serialization/deserialization of Enum.
        TORCH_INTERNAL_ASSERT(false);
      case TupleType::Kind: {
        auto t = w.value.toTuple();
        for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
          Work elem = {w.type->containedType(i), t->elements().at(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case FutureType::Kind: {
        auto f = w.value.toFuture();
        if (f->completed()) {
          Work elem = {w.type->containedType(0), f->value()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case AwaitType::Kind: {
        auto aw = w.value.toAwait();
        if (aw->completed()) {
          Work elem = {w.type->containedType(0), aw->wait()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case OptionalType::Kind: {
        if (!w.value.isNone()) {
          Work elem = {w.type->containedType(0), w.value};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case ListType::Kind: {
        // Specialized lists do not need their type refined, so we can
        // exit early here.
        if (!w.value.isList()) {
          break;
        }
        auto elem_type = w.type->containedType(0);
        auto lst = w.value.toList();
        lst.unsafeSetElementType(elem_type);
        for (const IValue& item : lst) {
          Work elem = {elem_type, item};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case DictType::Kind: {
        auto d = w.value.toGenericDict();
        auto keyType = w.type->containedType(0);
        auto valType = w.type->containedType(1);
        d.unsafeSetKeyType(keyType);
        d.unsafeSetValueType(valType);
        for (const auto& item : d) {
          Work kelem = {keyType, item.key()};
          Work velem = {valType, item.value()};
          to_process.emplace_back(std::move(kelem));
          to_process.emplace_back(std::move(velem));
        }
      } break;
      // In both cases the dynamic type is a class, and we are going to
      // tag with the dynamic type.
      case InterfaceType::Kind:
      case ClassType::Kind: {
        auto obj = w.value.toObject();
        auto typ = obj->type(); // note: intentionally using the dynamic type,
                                // the static type is potentially less accurate
        for (size_t i = 0; i < typ->numAttributes(); ++i) {
          Work elem = {typ->getAttribute(i), obj->getSlot(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break;
    }
  }
}

namespace {
template <typename T>
bool is(const Type& type) {
  if (type.kind() == T::Kind) {
    return true;
  }
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->tag() == c10::DynamicTypeTrait<T>::tagValue();
  }
  return false;
}
} // namespace

void restoreContainerTypeTags(const IValue& ivalue, const TypePtr& type) {
  if (is<DictType>(*type)) {
    auto dict = ivalue.toGenericDict();
    dict.unsafeSetKeyType(type->containedType(0));
    dict.unsafeSetValueType(type->containedType(1));
  } else if (is<ListType>(*type)) {
    ivalue.toList().unsafeSetElementType(type->containedType(0));
  } else {
    AT_ERROR("Unknown type for tag restoration: ", type->annotation_str());
  }
}

IValue Unpickler::parse_ivalue() {
  run();
  TORCH_CHECK(
      stack_.size() == 1,
      "Unpickler expected 1 element on the stack, but found ",
      stack_.size());
  if (version_ <= 2) {
    // See [type tag serialization]
    restoreAccurateTypeTagsIfPossible(stack_[0]);
  }
  return stack_[0];
}

double Unpickler::readFloat() {
  static_assert(sizeof(double) == 8, "Unpickler assumes 8-byte doubles");
  double big_endian = read<double>();
  double little_endian = 0;

  // Pickle floats are big endian, so reverse the bytes (this assumes a
  // little-endian host).
  auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
  std::reverse_copy(
      big_endian_ptr,
      big_endian_ptr + sizeof(big_endian),
      reinterpret_cast<char*>(&little_endian));

  return little_endian;
}
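
// For reference, BINFLOAT stores an IEEE-754 double in big-endian order:
// 1.0 is serialized as the bytes 3f f0 00 00 00 00 00 00, which
// readFloat() byte-swaps into the host representation.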

void Unpickler::run() {
  // Expect a PROTO opcode and protocol number at the start of the blob
  auto opcode = readOpCode();
  TORCH_CHECK(
      opcode == PickleOpCode::PROTO,
      "Expected PROTO opcode at the start"
      " of pickle archive, found ",
      int(static_cast<uint8_t>(opcode)));
  uint8_t protocol = read<uint8_t>();
  TORCH_CHECK(
      protocol == 2,
      "Only Pickle protocol 2 is supported, found protocol = ",
      protocol);

  while (true) {
    PickleOpCode opcode = readInstruction();
    if (opcode == PickleOpCode::STOP) {
      return;
    }
  }
}
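
// For reference, a protocol-2 pickle of the tuple (1, 2) is the byte
// stream "\x80\x02K\x01K\x02\x86." (PROTO 2, BININT1 1, BININT1 2,
// TUPLE2, STOP); run() dispatches one readInstruction() per opcode until
// STOP is reached.
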
void Unpickler::setInput(size_t memo_id) {
  AT_ASSERT(!stack_.empty());
  if (memo_id >= memo_table_.size()) {
    memo_table_.insert(
        memo_table_.end(), memo_id - memo_table_.size(), IValue());
    memo_table_.push_back(stack_.back());
  } else {
    memo_table_[memo_id] = stack_.back();
  }
}
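
// For reference, pickle encodes shared references through the memo table:
// a stream like BINPUT 0 ... BINGET 0 stores the value on top of the
// stack in memo_table_[0] and later pushes that same value again.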

// emplace_back on bool vectors does not exist on some systems, so we
// avoid it by calling push_back for bool instead.
template <typename T>
inline void append(std::vector<T>& a, T&& e) {
  a.emplace_back(std::forward<T>(e));
}
template <>
inline void append<bool>(std::vector<bool>& a, bool&& e) {
  a.push_back(e);
}

static std::vector<int64_t> tupleToIntList(const IValue& v) {
  return fmap(v.toTupleRef().elements(), [](const IValue& v) -> int64_t {
    return v.toInt();
  });
}

// Note: we cannot use toIntList or toDoubleList because during unpickling
// the lists are not yet tagged.
template <typename T>
static std::vector<T> convertList(const IValue& v) {
  return fmap(v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
}

PickleOpCode Unpickler::readInstruction() {
  auto opcode = readOpCode();
  switch (opcode) {
    case PickleOpCode::EMPTY_LIST: {
      stack_.emplace_back(c10::impl::GenericList(AnyType::get()));
    } break;
    case PickleOpCode::EMPTY_TUPLE: {
      if (empty_tuple_.isNone()) {
        // We only need one object, since tuples are not mutable.
        empty_tuple_ = c10::ivalue::Tuple::create(std::vector<IValue>());
      }
      stack_.emplace_back(empty_tuple_);
    } break;
    case PickleOpCode::BINPUT: {
      size_t memo_id = read<uint8_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::LONG_BINPUT: {
      TORCH_CHECK(
          std::numeric_limits<size_t>::max() >=
              std::numeric_limits<uint32_t>::max(),
          "Found a LONG_BINPUT opcode, but size_t on this system is "
          "not big enough to decode it");
      size_t memo_id = read<uint32_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::MARK: {
      // Mark the location of the container ivalue in the stack
      marks_.push_back(stack_.size());
    } break;
    case PickleOpCode::NEWTRUE: {
      stack_.emplace_back(true);
    } break;
    case PickleOpCode::NEWFALSE: {
      stack_.emplace_back(false);
    } break;
    case PickleOpCode::NONE: {
      stack_.emplace_back();
    } break;
    case PickleOpCode::BININT1: {
      uint8_t value = read<uint8_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT2: {
      uint16_t value = read<uint16_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT: {
      int32_t value = read<int32_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::LONG1: {
      // Only LONG1s with a length of 8 (an int64_t) are supported
      uint8_t length = read<uint8_t>();
      TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
      stack_.emplace_back(read<int64_t>());
    } break;
    case PickleOpCode::BINUNICODE: {
      uint32_t length = read<uint32_t>();
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINFLOAT:
      stack_.emplace_back(readFloat());
      break;
    case PickleOpCode::TUPLE: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      std::vector<IValue> elements;
      const auto tupleSize = stack_.size() - start;
      switch (tupleSize) {
        case 3: {
          auto e3 = pop(stack_);
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(c10::ivalue::Tuple::create(
              std::move(e1), std::move(e2), std::move(e3)));
          break;
        }
        case 2: {
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(
              c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
          break;
        }
        case 1:
          stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
          break;
        default: {
          elements.reserve(stack_.size() - start);
          auto start_it = stack_.begin() + start;
          for (auto it = start_it; it != stack_.end(); ++it) {
            elements.emplace_back(std::move(*it));
          }
          stack_.erase(start_it, stack_.end());
          stack_.emplace_back(c10::ivalue::Tuple::create(std::move(elements)));
          break;
        }
      }
    } break;
    case PickleOpCode::TUPLE1: {
      TORCH_CHECK(
          !stack_.empty(),
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 1 expected");
      stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
    } break;
    case PickleOpCode::TUPLE2: {
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(
          c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
    } break;
    case PickleOpCode::TUPLE3: {
      TORCH_CHECK(
          stack_.size() > 2,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 3 expected");
      auto e3 = pop(stack_);
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(c10::ivalue::Tuple::create(
          std::move(e1), std::move(e2), std::move(e3)));
    } break;
    case PickleOpCode::EMPTY_DICT:
      stack_.emplace_back(
          c10::impl::GenericDict(AnyType::get(), AnyType::get()));
      break;
    case PickleOpCode::APPENDS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      auto list_ivalue = stack_.at(start - 1);
      readList(list_ivalue);
    } break;
    case PickleOpCode::LIST: {
      IValue list_ivalue = c10::impl::GenericList(AnyType::get());
      readList(list_ivalue);
      stack_.push_back(std::move(list_ivalue));
    } break;
    case PickleOpCode::DICT: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      auto dict = c10::impl::GenericDict(AnyType::get(), AnyType::get());
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(stack_.begin() + start, stack_.end());
      stack_.emplace_back(std::move(dict));
    } break;
    case PickleOpCode::SETITEMS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      auto dict = stack_.at(start - 1).toGenericDict();
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(stack_.begin() + start, stack_.end());
    } break;
    case PickleOpCode::BINGET: {
      auto pos = read<uint8_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::LONG_BINGET: {
      auto pos = read<uint32_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::STOP:
      break;
    case PickleOpCode::GLOBAL: {
      // Read the qualified name of the global and resolve it
      auto module_name = readString();
      auto class_name = readString();
      readGlobal(module_name, class_name);
    } break;
    case PickleOpCode::NEWOBJ: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      // pop the empty tuple; the actual action is stored in globals_
      stack_.pop_back();
    } break;
    // Because our NEWOBJ does nothing, BUILD and REDUCE end up doing
    // the same thing.
    case PickleOpCode::BUILD:
    case PickleOpCode::REDUCE: {
      // stack is: <functor_idx> <functor_arg>
      // extract <functor_idx> and remove it from the stack:
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      std::swap(*(stack_.end() - 2), *(stack_.end() - 1));
      size_t idx = stack_.back().toInt();
      stack_.pop_back();
      // stack is: <functor_arg>
      TORCH_CHECK(
          idx < globals_.size(),
          "Parsing error: out of bounds access to globals_");
      globals_.at(idx)();
    } break;
    case PickleOpCode::BINPERSID: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      auto tuple = pop(stack_).toTuple();
      const auto& args = tuple->elements();
      AT_ASSERT(
          args.at(0).toStringRef() == "storage",
          "unknown PERSID key ",
          args.at(0).toStringRef());
      at::ScalarType type = args.at(1).toScalarType();
      const std::string& key = args.at(2).toStringRef();

      at::Device device(args.at(3).toStringRef());
      if (device_) {
        device = *device_;
      }

      at::Storage storage;
      if (storage_context_ != nullptr && storage_context_->hasStorage(key)) {
        // for torch.package logic where storage may be loaded already
        storage = storage_context_->getStorage(key);
      } else {
        int64_t numel = args.at(4).toInt();
        caffe2::TypeMeta dtype = at::CPU(type).typeMeta();

        at::DataPtr storage_ptr;
        // If the tensor has no elements, there is no point in reading a
        // zero-byte record from the input stream and paying that cost.
        if (numel > 0) {
          storage_ptr = read_record_(key);
        }

        storage = at::Storage(
            c10::Storage::use_byte_size_t(),
            numel * dtype.itemsize(),
            std::move(storage_ptr),
            /*allocator=*/nullptr,
            /*resizable=*/false); // NB: we didn't set any allocator for the
                                  // tensor
        if (storage_context_ != nullptr) {
          storage_context_->addStorage(key, storage);
        }
      }

      auto options = at::CPU(type).options();
      if (use_storage_device_) {
        options = options.device(storage.device());
        device = storage.device();
      }

      at::Tensor tensor;
      if (options.backend() == c10::Backend::QuantizedCPU) {
        tensor = at::_empty_affine_quantized({}, options, 0, 0)
                     .set_(storage, 0, {}, {});
      } else {
        tensor = at::empty({0}, options).set_(storage);
      }

      if (device.is_cuda() || device.is_xpu() || device.is_meta() ||
          device.is_hpu()) {
        tensor = tensor.to(device, tensor.scalar_type());
      } else if (device.type() != DeviceType::CPU) {
        AT_ERROR(
            "supported devices include CPU, CUDA, XPU, HPU and Meta, "
            "however got ",
            DeviceTypeName(device.type(), false));
      }
      stack_.emplace_back(std::move(tensor));
    } break;
    case PickleOpCode::SETITEM: {
      // At this OpCode, stack looks like
      //    | Stack Bottom |
      //    | ......       |
      //    | Dict         | -> (stack_size - 3)
      //    | Key          | -> (stack_size - 2)
      //    | Value        | -> (stack_size - 1)
      TORCH_CHECK(
          stack_.size() >= 3,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 3 expected");
      auto stack_size = stack_.size();
      auto dict_pos = stack_size - 3;
      auto key_pos = stack_size - 2;
      auto val_pos = stack_size - 1;
      auto dict = stack_.at(dict_pos).toGenericDict();
      dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos));
      stack_.erase(stack_.begin() + (key_pos), stack_.end());
    } break;
    default: {
      AT_ERROR(
          "Unknown opcode for unpickling at ",
          reinterpret_cast<void*>(opcode),
          ": ",
          int(static_cast<uint8_t>(opcode)));
    } break;
  }
  return opcode;
}

void Unpickler::readGlobal(
    const std::string& module_name,
    const std::string& class_name) {
  if (this->skip_next_read_global) {
    // See [NOTE] skip_next_read_global
    this->skip_next_read_global--;
    if (this->skip_next_read_global == 1) {
      // Pass through to the correct handler
    } else if (this->skip_next_read_global == 0) {
      // Corresponds to the type of `Tensor` being unpickled
      if (module_name != "torch" || class_name != "Tensor") {
        TORCH_WARN(
            "Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++");
      }
      stack_.emplace_back(int64_t(globals_.size() - 1));
      return;
    } else {
      TORCH_CHECK(false, "Unexpected skip_next_read_global value");
    }
  }
  // TODO [unpickler refactor] __main__ isn't used by the pickler anymore;
  // this is only here for backward-compatibility reasons.
  if (module_name == "__main__") {
    if (class_name == "TensorID") {
      globals_.emplace_back([this] {
        auto setitem_data = stack_.back();
        stack_.pop_back();
        TORCH_INTERNAL_ASSERT(
            !tensor_table_.empty(),
            "Pickler tried to write a tensor but had no tensor table to write to");
        stack_.emplace_back(tensor_table_.at(setitem_data.toInt()));
      });
    } else if (class_name == "IntList") {
      globals_.emplace_back([this] {
        stack_.back().toList().unsafeSetElementType(IntType::get());
      });
    } else {
      AT_ERROR("Unknown pickler class id ", class_name);
    }
  } else if (module_name == "torch.jit._pickle") {
    if (class_name == "build_tensor_from_id") {
      globals_.emplace_back([this] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0);
        stack_.pop_back();
        TORCH_CHECK(
            !tensor_table_.empty(),
            "Found a tensor table reference but Unpickler"
            " has no tensor table\n");
        stack_.emplace_back(tensor_table_.at(data.toInt()));
      });
    } else if (class_name == "restore_type_tag") {
      globals_.emplace_back([this] {
        auto tuple = stack_.back().toTuple();
        const auto& data = tuple->elements();
        auto type_str = data.at(1).toStringRef();
        stack_.pop_back();
        TypePtr type = nullptr;
        auto entry = type_cache_.find(type_str);
        if (entry != type_cache_.end()) {
          type = entry->second;
        } else {
          if (type_resolver_ == nullptr) {
            // If we haven't injected a custom way of retrieving types from
            // names, use a barebones type parser.
            type = type_parser_(type_str);
          } else {
            type = type_resolver_(type_str).type_;
          }
          type_cache_[type_str] = type;
        }
        // TODO: Use lookahead to avoid creating the tuple and immediately
        // destroying it here
        restoreContainerTypeTags(data.at(0), type);
        stack_.emplace_back(data.at(0));
      });
    } else {
      TypePtr elem_type = nullptr;
      if (class_name == "build_intlist") {
        elem_type = IntType::get();
      } else if (class_name == "build_tensorlist") {
        elem_type = TensorType::get();
      } else if (class_name == "build_doublelist") {
        elem_type = FloatType::get();
      } else if (class_name == "build_boollist") {
        elem_type = BoolType::get();
      } else {
        AT_ERROR("Unknown pickler class id ", class_name);
      }
      // Unpickle a list specialization (e.g. List[Tensor], List[int], ...)
      globals_.emplace_back([this, elem_type] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0).toList();
        stack_.pop_back();
        data.unsafeSetElementType(elem_type);
        stack_.emplace_back(std::move(data));
      });
    }
  } else if (
      module_name == "torch._utils" &&
      (class_name == "_rebuild_tensor_v2" ||
       class_name == "_rebuild_qtensor")) {
    // Unpickle a tensor
    bool quantized = class_name == "_rebuild_qtensor";
    rebuildTensor(quantized);
  } else if (
      module_name == "torch._tensor" &&
      (class_name == "_rebuild_from_type_v2")) {
    // Unpickle a Tensor with Python attributes or a Subclassed Tensor.
    rebuildTensorFromTypeV2();
  } else if (
      module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
    rebuildSparseTensor();
  } else if (module_name == "builtins" && class_name == "complex") {
    globals_.emplace_back([this] {
      auto tuple = pop(stack_).toTuple();
      const auto& elems = tuple->elements();
      AT_ASSERT(elems.size() == 2);
      auto complex =
          c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
      stack_.emplace_back(complex);
    });

  } else if (module_name == "collections" && class_name == "OrderedDict") {
    // collections.OrderedDict is used in tensor serialization for a tensor's
    // backward hooks (but they are not actually saved with this Pickler)
    globals_.emplace_back([this] {
      // Drop the Tuple that was the argument to OrderedDict and replace
      // it with None. OrderedDicts only appear in tensor deserialization
      // and their value is never used.
      stack_.back() = IValue();
    });
  } else if (module_name == "torch" && class_name == "device") {
    globals_.emplace_back([this] {
      auto device_string = stack_.back().toTupleRef().elements().at(0);
      stack_.pop_back();
      stack_.emplace_back(c10::Device(device_string.toStringRef()));
    });
    stack_.emplace_back(int64_t(globals_.size() - 1));
    return;
  } else if (module_name == "torch.distributed.rpc" && class_name == "rref") {
#ifdef USE_RPC
    return rebuildRRef();
#else
    TORCH_INTERNAL_ASSERT(
        false,
        "RRef unpickling is only supported with the distributed package");
#endif
  } else if (module_name == "torch") {
    // Try to manually resolve several global enums.
    // NOTE: this does not put a global into the global table like the
    // other branches here, because no REDUCE or BUILD will be called on
    // this value. Instead, we just put it on the stack and return early.
    c10::optional<c10::ScalarType> scalar_type;
#define CHECK_SCALAR(_, name)          \
  if (class_name == #name "Storage") { \
    scalar_type = c10::k##name;        \
  }
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR)
#undef CHECK_SCALAR
    if (scalar_type.has_value()) {
      stack_.emplace_back(int64_t(*scalar_type));
      return;
    }

    c10::optional<at::QScheme> qscheme;
    for (int i = 0; i < at::COMPILE_TIME_NUM_QSCHEMES; ++i) {
      if (class_name == toString(static_cast<at::QScheme>(i))) {
        qscheme = static_cast<at::QScheme>(i);
      }
    }
    if (qscheme.has_value()) {
      stack_.emplace_back(int64_t(*qscheme));
      return;
    }
    TORCH_CHECK(
        false,
        "Unpickler found unknown torch global, 'torch.",
        class_name,
        "'");
  } else {
    TORCH_CHECK(
        type_resolver_,
        "Unpickler found unknown type ",
        module_name,
        ".",
        class_name);
    at::StrongTypePtr type =
        type_resolver_(c10::QualifiedName(module_name, class_name));
    if (auto enum_type = type.type_->cast<c10::EnumType>()) {
      globals_.emplace_back([this, enum_type] {
        auto val = stack_.back();
        stack_.pop_back();
        for (const auto& p : enum_type->enumNamesValues()) {
          if (p.second == val) {
            auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
                enum_type, p.first, p.second);
            stack_.emplace_back(std::move(enum_holder));
            return;
          }
        }
      });
    } else {
      // Otherwise, global is a class/object type.
      globals_.emplace_back([this, type] {
        auto val = stack_.back();
        stack_.pop_back();
        auto obj = obj_loader_(type, val);
        stack_.emplace_back(std::move(obj));
      });
    }
  }
  stack_.emplace_back(int64_t(globals_.size() - 1));
}
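
// For reference, GLOBAL handlers are deferred: readGlobal() typically
// pushes the index of a closure appended to globals_, and a later
// REDUCE/BUILD opcode pops that index and invokes the closure on the
// argument left on top of the stack.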

void Unpickler::rebuildSparseTensor() {
  globals_.emplace_back([this] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto layout = elements.at(idx++).toInt();
    at::Tensor result;
    switch (layout) {
      case static_cast<int>(c10::Layout::Sparse): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& indices_tensor = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::Sparse)
                           .requires_grad(requires_grad);
        result = at::_sparse_coo_tensor_unsafe(
            indices_tensor, values_tensor, size, options);
        result = autograd::make_variable(result, options.requires_grad());
        break;
      }
      case static_cast<int>(c10::Layout::SparseCsr): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& crow_indices = elements.at(idx++).toTensor();
        auto& col_indices = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::SparseCsr)
                           .requires_grad(requires_grad);
        result = at::_sparse_csr_tensor_unsafe(
            crow_indices, col_indices, values_tensor, size, options);
        result =
            autograd::make_variable(std::move(result), options.requires_grad());
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "Unsupported sparse tensor layout type in serialization ",
            static_cast<c10::Layout>(layout));
        break;
    }
    stack_.emplace_back(std::move(result));
  });
}

void Unpickler::rebuildTensor(bool quantized) {
  globals_.emplace_back([this, quantized] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto& storage_tensor = elements.at(idx++).toTensor();
    int64_t storage_offset = elements.at(idx++).toInt();
    std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
    std::vector<int64_t> stride = tupleToIntList(elements.at(idx++));
    at::Tensor result;
    if (quantized) {
      auto qparams_tuple = elements.at(idx++).toTuple();
      const auto& qparams = qparams_tuple->elements();
      auto qscheme = static_cast<at::QScheme>(qparams.at(0).toInt());
      switch (qscheme) {
        case at::kPerTensorAffine: {
          double q_scale = qparams.at(1).toDouble();
          int64_t q_zero_point = qparams.at(2).toInt();
          result = at::_empty_affine_quantized(
              {0}, storage_tensor.options(), q_scale, q_zero_point);
        } break;
        case at::kPerChannelAffineFloatQParams:
        case at::kPerChannelAffine: {
          const auto& scales = qparams.at(1).toTensor();
          const auto& zero_points = qparams.at(2).toTensor();
          int64_t axis = qparams.at(3).toInt();
          result = at::_empty_per_channel_affine_quantized(
              {0}, scales, zero_points, axis, storage_tensor.options());
        } break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported tensor quantization type in serialization ",
              toString(qscheme));
          break;
      }
    } else {
      result = at::empty({0}, storage_tensor.options());
    }
    bool requires_grad = elements.at(idx++).toBool();
    idx++; // backwards hooks is empty
    at::TensorImpl* impl = result.unsafeGetTensorImpl();
    impl->set_storage_keep_dtype(storage_tensor.storage());
    impl->set_storage_offset(storage_offset);
    impl->set_sizes_and_strides(size, stride);
    result = autograd::make_variable(result, requires_grad);

    // Handle math_bits if they were pickled; see the `args` built by
    // _reduce_ex_internal for a regular tensor (the final else case).
    // Tensors pickled before that change lack this argument, in which
    // case we do nothing.
    // NOTE: `math_bits` is the 7th arg.
    // NOTE: this only applies to regular tensors, not quantized ones
    // (which also serialize 7 args), hence the !quantized guard.
    if (!quantized && elements.size() == 7) {
      auto math_bits = elements.at(idx++).toGenericDict();
      torch::jit::setTensorMetadata(result, math_bits);
    }

    stack_.emplace_back(std::move(result));
  });
}
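
// For reference, the REDUCE argument tuple consumed above mirrors
// torch._utils._rebuild_tensor_v2: (storage, storage_offset, size,
// stride, requires_grad, backward_hooks[, metadata]); the quantized
// variant (_rebuild_qtensor) inserts its qparams tuple after `stride`.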

void Unpickler::rebuildTensorFromTypeV2() {
  // [NOTE] skip_next_read_global
  // When rebuilding a Tensor with Python attributes or a Subclassed
  // Tensor, we receive `(func, type(self), args, state)` on the stack
  // for `rebuildTensorFromTypeV2`.
  // Thus the next call to readGlobal corresponds to `func`, which is the
  // function to rebuild the base tensor.
  // The readGlobal call after `func` corresponds to the `type` of the
  // Tensor, where we raise a warning if the type is not `torch.Tensor`.
  this->skip_next_read_global = 2;
  auto curr_globals_idx = globals_.size();
  globals_.emplace_back([this, curr_globals_idx] {
    // args is a tuple with the following data:
    //  (function to rebuild the base tensor, type of tensor,
    //   arguments to construct the base tensor, Python state (as a dict))
    auto args = pop(stack_).toTuple();
    size_t tup_idx = 0;
    const auto args_elems = args->elements();
    auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple();
    auto py_state = args_elems.at(tup_idx + 3).toGenericDict();
    if (!py_state.empty()) {
      TORCH_WARN(
          "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded");
    }
    // This calls the function that rebuilds the base tensor,
    // e.g. `rebuildTensor` or `rebuildSparseTensor`.
    stack_.emplace_back(base_tensor_args);
    globals_[curr_globals_idx + 1]();
    stack_.emplace_back(pop(stack_));
  });
}

#ifdef USE_RPC
void Unpickler::rebuildRRef() {
  globals_.emplace_back([this] {
    // This mirrors how an RRef is unpickled in Python; see
    // PyRRef::unpickle.
    auto tuple = std::move(stack_.back()).toTuple();
    const auto& args = tuple->elements();
    stack_.pop_back();
    TORCH_INTERNAL_ASSERT(
        args.size() == distributed::rpc::RFD_TUPLE_SIZE,
        "Pickled RRefForkData must contain 7 numbers.");
    auto ownerId =
        static_cast<int16_t>(args.at(distributed::rpc::OWNER_IDX).toInt());
    // const reference will extend the lifetime of the temporary variable
    const auto& rrefId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::RREFID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::RREFID_ID_IDX).toInt()));
    const auto& forkId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::FORKID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::FORKID_ID_IDX).toInt()));
    auto parent =
        static_cast<int16_t>(args.at(distributed::rpc::PARENT_IDX).toInt());
    const auto& typeStr = static_cast<std::string>(
        args.at(distributed::rpc::TYPE_IDX).toStringRef());
    auto rrefForkData = distributed::rpc::RRefForkData(
        ownerId, rrefId, forkId, parent, typeStr);
    auto& ctx = distributed::rpc::RRefContext::getInstance();
    c10::intrusive_ptr<distributed::rpc::RRef> rref;
    TORCH_INTERNAL_ASSERT(
        type_resolver_ != nullptr, "type_resolver_ is nullptr.");
    at::StrongTypePtr type = type_resolver_(c10::QualifiedName(typeStr));
    rref = ctx.getOrCreateRRef(rrefForkData, type.type_);
    ctx.notifyOwnerAndParentOfFork(
        rrefForkData.forkId_, rrefForkData.parent_, rref);
    stack_.emplace_back(
        c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref));
  });
  stack_.emplace_back(int64_t(globals_.size() - 1));
}
#endif

void Unpickler::readSlowWithBuffer(char* dest, size_t sz) {
  // First, read any partial from buffer (may be 0).
  // We explicitly assume that sz > buffer_remaining_,
  // and that sz is never bigger than buffer_.size().
  AT_ASSERT(sz > buffer_remaining_);
  const size_t from_old_buf = buffer_remaining_;
  if (from_old_buf != 0) {
    memcpy(dest, buffer_.data() + buffer_pos_, from_old_buf);
  }
  const size_t needed = sz - from_old_buf;
  // Full read into the buffer. The calls here all explicitly
  // assume that one buffer will be enough for any sz.
  AT_ASSERT(sz <= buffer_.size());
  buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
  if (buffer_remaining_ < needed) {
    AT_ERROR("Unexpected end of pickler archive.");
  }
  memcpy(dest + from_old_buf, buffer_.data(), needed);
  buffer_pos_ = needed; // the refill above starts at offset 0, so we have
                        // consumed `needed` bytes of the new buffer
  buffer_remaining_ -= needed;
}

// Read a number of bytes from the input stream
std::string Unpickler::readBytes(size_t length) {
  std::string data;
  static const size_t kSmallString = 64;
  if (length <= buffer_remaining_) {
    // Fast-path: entirely in buffer.
    data.assign(buffer_.data() + buffer_pos_, length);
    buffer_pos_ += length;
    buffer_remaining_ -= length;
  } else if (length <= kSmallString) {
    // If the string is smallish, do a full buffer read,
    // and read out of that buffer.
    data.resize(length);
    readSlowWithBuffer(&data[0], length);
  } else {
    // Otherwise, for larger strings, read what we can from
    // the buffer, and then read directly to the destination.
    const size_t from_old_buf = buffer_remaining_;
    if (from_old_buf != 0) {
      data.reserve(length);
      data.append(buffer_.data() + buffer_pos_, from_old_buf);
    }
    data.resize(length);
    const size_t needed = length - from_old_buf;
    size_t nread = reader_(&data[from_old_buf], needed);
    if (nread != needed) {
      AT_ERROR("Unexpected end of pickler archive.");
    }
    buffer_remaining_ = 0;
    // buffer_pos_ has no meaning with buffer_remaining_ == 0.
  }
  return data;
}
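
// For reference: readBytes() serves payloads of at most kSmallString (64)
// bytes through the internal buffer (refilling it via readSlowWithBuffer
// when needed), while larger payloads bypass the buffer and are read
// directly into the output string.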

// Pop all the list items off of the stack and append them to the list at
// the corresponding MARK
void Unpickler::readList(IValue list_ivalue) {
  TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
  size_t start = marks_.back();
  marks_.pop_back();
  auto num_elements = stack_.size() - start;
  auto elements = c10::ArrayRef<IValue>(stack_).slice(start);
  if (list_ivalue.isIntList()) {
    auto list = std::move(list_ivalue).toIntList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toInt());
    }
  } else if (list_ivalue.isTensorList()) {
    auto list = std::move(list_ivalue).toTensorList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toTensor());
    }
  } else if (list_ivalue.isDoubleList()) {
    auto list = std::move(list_ivalue).toDoubleList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toDouble());
    }
  } else if (list_ivalue.isBoolList()) {
    auto list = std::move(list_ivalue).toBoolList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.push_back(elem.toBool());
    }
  } else if (list_ivalue.isList()) {
    auto list = std::move(list_ivalue).toList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem);
    }
  } else {
    AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
  }

  stack_.erase(stack_.begin() + start, stack_.end());
}

static inline bool is_valid_python_id_char(char c) {
  return c == '_' || c == '.' || (c >= '0' && c <= '9') ||
      (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
}

// Read a newline-terminated string
std::string Unpickler::readString() {
  std::string ss;
  while (true) {
    auto* const bufferStart = buffer_.data() + buffer_pos_;
    const auto bufferLeft = buffer_.size() - buffer_pos_;
    char* const newlinePtr =
        static_cast<char*>(memchr(bufferStart, '\n', bufferLeft));
    if (newlinePtr) {
      // Read up to the newline and we are done.
      auto const charsRead = newlinePtr - bufferStart;
      ss.append(bufferStart, charsRead);
      buffer_remaining_ -= charsRead + 1;
      buffer_pos_ += charsRead + 1;
      break;
    } else {
      // We read the whole buffer; refill it.
      for (const char* p = bufferStart; p < bufferStart + bufferLeft; ++p) {
        // Simple check just in case there is no terminating '\n'
        TORCH_CHECK(
            is_valid_python_id_char(*p),
            "Found character '",
            int(uint8_t(*p)),
            "' in string, ",
            "strings must be qualified Python identifiers");
      }
      ss.append(bufferStart, bufferLeft);
      buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
      buffer_pos_ = 0;
    }
  }
  return ss;
}

} // namespace torch::jit