| #pragma once |
| |
| #include <c10/core/Storage.h> |
| #include "caffe2/core/operator.h" |
| #include "caffe2/core/tensor.h" |
| |
| namespace caffe2 { |
| |
| #ifndef C10_MOBILE |
| struct TORCH_API OfflineTensor { |
| // A shell tensor to record shape and dtype |
| Tensor shape_tensor{CPU}; |
| |
| void setShapeAndType( |
| const std::vector<int>& sizes, |
| at::Device device, |
| caffe2::TypeMeta data_type) { |
| shape_tensor.unsafeGetTensorImpl()->set_storage_and_dtype( |
| at::Storage::create_legacy(device), data_type); |
| shape_tensor.Resize(sizes); |
| CHECK(!shape_tensor.storage_initialized()); |
| CHECK(shape_tensor.dtype_initialized()); |
| } |
| }; |
| |
| class OfflineTensorShapeFunctions : public ExternalTensorFunctionsBase { |
| public: |
| explicit OfflineTensorShapeFunctions() : ExternalTensorFunctionsBase() {} |
| ~OfflineTensorShapeFunctions() override {} |
| bool isQuantized() const override { |
| return false; |
| } |
| bool IsSameMetaType(TypeIdentifier id) override; |
| void SetupExternalTensorDescriptor( |
| const Blob* blob, |
| std::vector<std::vector<uint64_t>>* shapes, |
| std::vector<std::vector<float>>* all_scales, |
| std::vector<std::vector<int32_t>>* all_offsets, |
| ExternalTensorDescriptor* desc) override; |
| void LoadInfoOfBlob( |
| const Blob* /* unused */, |
| std::vector<float>* /* unused */, |
| std::vector<float>* /* unused */, |
| uint32_t* /* unused */) override {} |
| TypeIdentifier GetTypeMetaId() override; |
| TypeMeta GetExternalTensorType(const void* c) override; |
| vector<int64_t> GetExternalTensorInfo( |
| const void* c, |
| size_t* capacity, |
| DeviceOption* device) override; |
| }; |
| #endif |
| } // namespace caffe2 |