#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <mutex>
#include <unordered_map>
#include <utility>

#include <torch/csrc/cuda/CUDAPluggableAllocator.h>

namespace torch::cuda::CUDAPluggableAllocator {

int device_count = 0;

void custom_raw_deleter(void* ptr);

_AllocationMetadata::_AllocationMetadata()
    : size(0), device_idx(-1), stream{} {}

_AllocationMetadata::_AllocationMetadata(
    size_t size,
    c10::DeviceIndex device_idx,
    cudaStream_t stream)
    : size(size), device_idx(device_idx), stream(stream) {}

// This is a lightweight API for registering allocators from plain function
// pointers (e.g. ones loaded from an external .so library). It avoids having
// to link against libtorch for C++-based custom allocators, and it also makes
// such allocators usable from Python.
CUDAPluggableAllocator::CUDAPluggableAllocator(
    std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
    std::function<void(void*, size_t, int, cudaStream_t)> free_fn)
    : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}
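
// A minimal sketch (illustrative only, not part of this translation unit) of
// the two callbacks the constructor above expects, assuming a hypothetical
// allocator that simply forwards to cudaMalloc/cudaFree:
//
//   void* my_alloc(size_t size, int device, cudaStream_t stream) {
//     void* ptr = nullptr;
//     cudaSetDevice(device);
//     // stream is also available here for stream-ordered allocation
//     // (e.g. cudaMallocAsync), but this sketch ignores it.
//     cudaMalloc(&ptr, size);
//     return ptr;
//   }
//
//   void my_free(void* ptr, size_t size, int device, cudaStream_t stream) {
//     cudaFree(ptr);
//   }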

CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other)
    : alloc_fn_(other.alloc_fn_),
      free_fn_(other.free_fn_),
      init_fn_(other.init_fn_),
      reset_fn_(other.reset_fn_),
      memory_fraction_fn_(other.memory_fraction_fn_),
      base_alloc_fn_(other.base_alloc_fn_),
      record_stream_fn_(other.record_stream_fn_),
      begin_allocate_to_pool_fn_(other.begin_allocate_to_pool_fn_),
      end_allocate_to_pool_fn_(other.end_allocate_to_pool_fn_),
      relase_pool_fn_(other.relase_pool_fn_) {}

void CUDAPluggableAllocator::set_init_fn(std::function<void(int)> init_fn) {
  init_fn_ = std::move(init_fn);
}

void CUDAPluggableAllocator::set_reset_fn(std::function<void()> reset_fn) {
  reset_fn_ = std::move(reset_fn);
}

void CUDAPluggableAllocator::set_memory_fraction_fn(
    std::function<void(double, int)> memory_fraction_fn) {
  memory_fraction_fn_ = std::move(memory_fraction_fn);
}

void CUDAPluggableAllocator::set_base_alloc_fn(
    std::function<void*(void*, size_t*)> base_alloc_fn) {
  base_alloc_fn_ = std::move(base_alloc_fn);
}

void CUDAPluggableAllocator::set_record_stream_fn(
    std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn) {
  record_stream_fn_ = std::move(record_stream_fn);
}

void CUDAPluggableAllocator::set_begin_allocate_to_pool(
    std::function<
        void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)>
        capture_begin_fn) {
  begin_allocate_to_pool_fn_ = std::move(capture_begin_fn);
}

void CUDAPluggableAllocator::set_end_allocate_to_pool_fn(
    std::function<void(int, c10::cuda::MempoolId_t)> capture_about_to_end_fn) {
  end_allocate_to_pool_fn_ = std::move(capture_about_to_end_fn);
}

void CUDAPluggableAllocator::set_release_pool(
    std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn) {
  relase_pool_fn_ = std::move(capture_destroy_fn);
}

void* CUDAPluggableAllocator::malloc(
    size_t size,
    c10::DeviceIndex device,
    cudaStream_t stream) {
  void* r = alloc_fn_(size, device, stream);
  {
    const std::lock_guard<std::mutex> lock(allocator_mutex_);
    allocation_metadata_.emplace(r, _AllocationMetadata(size, device, stream));
  }
  return r;
}

c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) {
  c10::DeviceIndex device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
  void* r = this->malloc(size, device, stream);
  c10::DataPtr data_ptr = {
      r, r, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
  return data_ptr;
}
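
// For illustration only: once this allocator has been made the active CUDA
// allocator (see changeCurrentAllocator below), ordinary tensor construction
// routes through allocate(). A hypothetical client-side usage sketch:
//
//   at::Tensor t = at::empty(
//       {1 << 20}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));
//   // The storage above comes from alloc_fn_ on the current device/stream
//   // and is handed back to free_fn_ when the storage is released.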

c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const {
  return &custom_raw_deleter;
}

void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) {
  c10::DeviceIndex device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
  return malloc(nbytes, device, stream);
}

void* CUDAPluggableAllocator::raw_alloc_with_stream(
    size_t nbytes,
    cudaStream_t stream) {
  c10::DeviceIndex device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  return malloc(nbytes, device, stream);
}

void CUDAPluggableAllocator::raw_delete(void* ptr) {
  cudaStream_t stream{};
  c10::DeviceIndex device_idx = -1;
  size_t size = 0;
  {
    const std::lock_guard<std::mutex> lock(allocator_mutex_);
    TORCH_CHECK(
        allocation_metadata_.count(ptr),
        "Trying to free a pointer not allocated here");
    _AllocationMetadata& metadata = allocation_metadata_[ptr];
    size = metadata.size;
    device_idx = metadata.device_idx;
    stream = metadata.stream;
    allocation_metadata_.erase(ptr);
  }
  free_fn_(ptr, size, device_idx, stream);
}

void CUDAPluggableAllocator::init(int device_count) {
  if (init_fn_) {
    init_fn_(device_count);
  }
  initialized_ = true;
}

bool CUDAPluggableAllocator::initialized() {
  return initialized_;
}

void CUDAPluggableAllocator::setMemoryFraction(
    double fraction,
    c10::DeviceIndex device) {
  if (memory_fraction_fn_) {
    memory_fraction_fn_(fraction, device);
  }
}

void CUDAPluggableAllocator::emptyCache() {
  if (reset_fn_) {
    return reset_fn_();
  }
}

void CUDAPluggableAllocator::cacheInfo(
    c10::DeviceIndex device,
    size_t* largestBlock) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support cacheInfo. "
      "If you need it, please file an issue describing your use case.");
}

void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) {
  if (base_alloc_fn_) {
    return base_alloc_fn_(ptr, size);
  } else {
    return ptr;
  }
}

void CUDAPluggableAllocator::recordStream(
    const c10::DataPtr& ptr,
    streamType stream) {
  if (record_stream_fn_) {
    record_stream_fn_(ptr.get(), stream);
  }
}

c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator::
    getDeviceStats(c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support getDeviceStats. "
      "If you need it, please file an issue describing your use case.");
}

void CUDAPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support resetAccumulatedStats. "
      "If you need it, please file an issue describing your use case.");
}

void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support resetPeakStats. "
      "If you need it, please file an issue describing your use case.");
}

c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
    snapshot() {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support snapshot. "
      "If you need it, please file an issue describing your use case.");
}

std::shared_ptr<void> CUDAPluggableAllocator::getIpcDevPtr(std::string handle) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support getIpcDevPtr. "
      "If you need it, please file an issue describing your use case.");
}

// CUDAGraph interactions
void CUDAPluggableAllocator::beginAllocateToPool(
    c10::DeviceIndex device,
    c10::cuda::MempoolId_t mempool_id,
    std::function<bool(cudaStream_t)> filter) {
  if (begin_allocate_to_pool_fn_) {
    begin_allocate_to_pool_fn_(device, mempool_id, std::move(filter));
  }
}

void CUDAPluggableAllocator::endAllocateToPool(
    c10::DeviceIndex device,
    c10::cuda::MempoolId_t mempool_id) {
  if (end_allocate_to_pool_fn_) {
    end_allocate_to_pool_fn_(device, mempool_id);
  }
}

void CUDAPluggableAllocator::releasePool(
    c10::DeviceIndex device,
    c10::cuda::MempoolId_t mempool_id) {
  if (relase_pool_fn_) {
    relase_pool_fn_(device, mempool_id);
  }
}

void CUDAPluggableAllocator::recordHistory(
    bool enabled,
    c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
    size_t alloc_trace_max_entries,
    c10::cuda::CUDACachingAllocator::RecordContext when) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support recordHistory. "
      "If you need it, please file an issue describing your use case.");
}

void CUDAPluggableAllocator::attachOutOfMemoryObserver(
    c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support attachOutOfMemoryObserver. "
      "If you need it, please file an issue describing your use case.");
}

void CUDAPluggableAllocator::attachAllocatorTraceTracker(
    c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not support attachAllocatorTraceTracker. "
      "attachAllocatorTraceTracker is only used inside PyTorch.");
}

std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState>
CUDAPluggableAllocator::getCheckpointState(
    c10::DeviceIndex device,
    at::cuda::MempoolId_t id) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support getCheckpointState. "
      "If you need it, please file an issue describing your use case.");
}

c10::cuda::CUDACachingAllocator::CheckpointDelta CUDAPluggableAllocator::
    setCheckpointPoolState(
        c10::DeviceIndex device,
        std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support setCheckpointPoolState. "
      "If you need it, please file an issue describing your use case.");
}

void CUDAPluggableAllocator::enablePeerAccess(
    c10::DeviceIndex dev,
    c10::DeviceIndex dev_to_access) {
  c10::cuda::CUDAGuard device_guard(dev);
  cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
  if (err == cudaErrorPeerAccessAlreadyEnabled) {
    // ignore and clear the error if access was already enabled
    (void)cudaGetLastError();
  } else {
    C10_CUDA_CHECK(err);
  }
}

cudaError_t CUDAPluggableAllocator::memcpyAsync(
    void* dst,
    int dstDevice,
    const void* src,
    int srcDevice,
    size_t count,
    cudaStream_t stream,
    bool p2p_enabled) {
  return cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream);
}

std::string CUDAPluggableAllocator::name() {
  return "pluggable";
}

void CUDAPluggableAllocator::copy_data(
    void* dest,
    const void* src,
    std::size_t count) const {
  C10_CUDA_CHECK(
      cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
}

std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
    current_custom_allocator;

std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
getCurrentAllocator() {
  return current_custom_allocator;
}

// TODO: accept the remaining hook functions as arguments as well
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
createCustomAllocator(
    std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
    std::function<void(void*, size_t, int, cudaStream_t)> free_fn) {
  std::shared_ptr<CUDAPluggableAllocator> allocator(
      new CUDAPluggableAllocator(std::move(alloc_fn), std::move(free_fn)));
  allocator->init(device_count);
  return allocator;
}

void changeCurrentAllocator(
    const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>&
        allocator) {
  TORCH_CHECK(
      !c10::cuda::CUDACachingAllocator::allocator.load()->initialized(),
      "Can't swap an already initialized allocator");
  c10::cuda::CUDACachingAllocator::allocator.store(allocator.get());
  current_custom_allocator = allocator;
}
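
// A hedged end-to-end sketch (illustrative only), assuming the hypothetical
// my_alloc/my_free callbacks outlined near the two-argument constructor above:
//
//   auto allocator =
//       torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
//           my_alloc, my_free);
//   // Must run before the first CUDA allocation, since an already
//   // initialized allocator cannot be swapped out (see the TORCH_CHECK in
//   // changeCurrentAllocator above).
//   torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(allocator);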

void custom_raw_deleter(void* ptr) {
  current_custom_allocator->raw_delete(ptr);
}

} // namespace torch::cuda::CUDAPluggableAllocator