| #include <c10/util/flat_hash_map.h> |
| #include <c10/util/irange.h> |
| #include <c10/xpu/XPUCachingAllocator.h> |
| |
| #include <deque> |
| #include <mutex> |
| #include <set> |
| #include <vector> |
| |
| namespace c10::xpu::XPUCachingAllocator { |
| |
| using namespace c10::CachingDeviceAllocator; |
| |
| // newly allocated memory is aligned to 512 bytes |
| constexpr size_t kDeviceAlignment = 512; |
| // all sizes are rounded to at least 512 bytes |
| constexpr size_t kMinBlockSize = 512; |
| // largest "small" allocation is 1 MiB |
| constexpr size_t kSmallSize = 1048576; |
| // "small" allocations are packed in 2 MiB blocks |
| constexpr size_t kSmallBuffer = 2097152; |
| // "large" allocations may be packed in 20 MiB blocks |
| constexpr size_t kLargeBuffer = 20971520; |
| // allocations between 1 and 10 MiB may use kLargeBuffer |
| constexpr size_t kMinLargeAlloc = 10485760; |
| // round up large allocations to 2 MiB |
| constexpr size_t kRoundLarge = 2097152; |
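| // For illustration: a 500000-byte request rounds up to 500224 bytes (a |
| // multiple of kMinBlockSize) and is served from a 2 MiB small-pool buffer; |
| // a 3 MiB request is served from a 20 MiB large-pool buffer; a 64 MiB |
| // request is rounded up to a multiple of kRoundLarge. |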
| |
| namespace { |
| using stream_set = ska::flat_hash_set<xpu::XPUStream>; |
| |
| struct Block; |
| typedef bool (*Comparison)(const Block*, const Block*); |
| bool BlockComparatorSize(const Block* a, const Block* b); |
| |
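| // Cached, unallocated blocks, ordered by (queue, size, ptr) so that |
| // lower_bound can find the smallest cached block that fits a request. |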
| struct BlockPool { |
| BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {} |
| std::set<Block*, Comparison> blocks; |
| const bool is_small; |
| }; |
| |
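| // A Block describes a contiguous chunk of device memory: either a whole |
| // allocation or a piece carved out of one. prev/next link pieces split from |
| // the same physical allocation so free neighbors can be merged again. |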
| struct Block { |
| DeviceIndex device; |
| sycl::queue* queue{nullptr}; // underlying queue of the allocation stream |
| stream_set stream_uses; // streams on which the block was used |
| size_t size; // block size in bytes |
| size_t requested_size; // memory originally requested |
| BlockPool* pool{nullptr}; // owning memory pool |
| void* ptr{nullptr}; // memory address |
| bool allocated{false}; // in-use flag |
| Block* prev{nullptr}; // prev block if split from a larger allocation |
| Block* next{nullptr}; // next block if split from a larger allocation |
| int event_count{0}; // number of outstanding XPU events |
| |
| Block( |
| DeviceIndex device, |
| sycl::queue* queue, |
| size_t size, |
| BlockPool* pool, |
| void* ptr) |
| : device(device), |
| queue(queue), |
| stream_uses(), |
| size(size), |
| requested_size(0), |
| pool(pool), |
| ptr(ptr) {} |
| |
| // constructor for search key |
| Block(DeviceIndex device, sycl::queue* queue, size_t size) |
| : device(device), |
| queue(queue), |
| stream_uses(), |
| size(size), |
| requested_size(0) {} |
| |
| bool is_split() const { |
| return (prev != nullptr) || (next != nullptr); |
| } |
| }; |
| |
| bool BlockComparatorSize(const Block* a, const Block* b) { |
| if (a->queue != b->queue) { |
| return reinterpret_cast<uintptr_t>(a->queue) < |
| reinterpret_cast<uintptr_t>(b->queue); |
| } |
| if (a->size != b->size) { |
| return a->size < b->size; |
| } |
| return reinterpret_cast<uintptr_t>(a->ptr) < |
| reinterpret_cast<uintptr_t>(b->ptr); |
| } |
| |
| struct AllocParams { |
| AllocParams( |
| DeviceIndex device, |
| size_t size, |
| sycl::queue* queue, |
| BlockPool* pool, |
| size_t alloc_size) |
| : search_key(device, queue, size), |
| pool(pool), |
| alloc_size(alloc_size), |
| block(nullptr) {} |
| |
| DeviceIndex device() const { |
| return search_key.device; |
| } |
| |
| sycl::queue* queue() const { |
| return search_key.queue; |
| } |
| |
| size_t size() const { |
| return search_key.size; |
| } |
| |
| Block search_key; |
| BlockPool* pool; |
| size_t alloc_size; |
| Block* block; |
| StatTypes stat_types = {}; |
| }; |
| |
| } // anonymous namespace |
| |
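| // Per-device caching allocator. All state is guarded by a single recursive |
| // mutex; callers reach it through the XPUAllocator frontend defined below. |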
| class DeviceCachingAllocator { |
| private: |
| mutable std::recursive_mutex mutex; |
| DeviceStats stats; |
| BlockPool large_blocks; // unallocated cached blocks larger than 1 MiB |
| BlockPool small_blocks; // unallocated cached blocks 1 MiB or smaller |
| ska::flat_hash_set<Block*> active_blocks; // allocated or in use by a stream |
| ska::flat_hash_map<xpu::XPUStream, std::deque<std::pair<sycl::event, Block*>>> |
| xpu_events; |
| DeviceIndex device_index; |
| |
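| // Tries to fold the free neighbor `src` into `dst`. Merging is skipped if |
| // `src` is null, still allocated, has outstanding events, or is in use on |
| // other streams. Returns the number of bytes subsumed (0 if not merged). |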
| size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) { |
| if (!src || src->allocated || src->event_count > 0 || |
| !src->stream_uses.empty()) { |
| return 0; |
| } |
| |
| TORCH_INTERNAL_ASSERT(dst->is_split() && src->is_split()); |
| if (dst->prev == src) { // [src dst] |
| dst->ptr = src->ptr; |
| dst->prev = src->prev; |
| if (dst->prev) { |
| dst->prev->next = dst; |
| } |
| } else { // [dst src] |
| dst->next = src->next; |
| if (dst->next) { |
| dst->next->prev = dst; |
| } |
| } |
| const size_t subsumed_size = src->size; |
| dst->size += subsumed_size; |
| auto erased = pool.blocks.erase(src); |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1); |
| delete src; |
| |
| return subsumed_size; |
| } |
| |
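| // Returns an idle block to its pool: merges it with any free neighbors, |
| // moves it from active_blocks into the pool, and updates the active and |
| // requested byte counters. |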
| void free_block(Block* block) { |
| TORCH_INTERNAL_ASSERT( |
| !block->allocated && block->event_count == 0 && |
| block->stream_uses.empty()); |
| |
| size_t original_block_size = block->size; |
| size_t requested_size = block->requested_size; |
| auto& pool = *block->pool; |
| const std::array<Block*, 2> merge_candidates = {block->prev, block->next}; |
| for (Block* merge_candidate : merge_candidates) { |
| try_merge_blocks(block, merge_candidate, pool); |
| } |
| |
| active_blocks.erase(block); |
| bool inserted = pool.blocks.insert(block).second; |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); |
| |
| StatTypes stat_types = get_stat_types_for_pool(pool); |
| for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
| stats.active_bytes[stat_type].decrease(original_block_size); |
| stats.requested_bytes[stat_type].decrease(requested_size); |
| }); |
| } |
| |
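| // Drains completed events recorded for cross-stream uses. A block is only |
| // returned to its pool once every event recorded on it has completed. |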
| void process_events() { |
| using namespace sycl::info; |
| for (auto it = xpu_events.begin(); it != xpu_events.end();) { |
| while (!it->second.empty()) { |
| auto& e = it->second.front(); |
| auto event = e.first; |
| auto* block = e.second; |
| if (event.get_info<event::command_execution_status>() != |
| event_command_status::complete) { |
| break; |
| } |
| block->event_count--; |
| if (block->event_count == 0) { |
| free_block(block); |
| } |
| it->second.pop_front(); |
| } |
| |
| if (it->second.empty()) { |
| it = xpu_events.erase(it); |
| } else { |
| it++; |
| } |
| } |
| } |
| |
| static size_t round_size(size_t size) { |
| if (size < kMinBlockSize) { |
| return kMinBlockSize; |
| } else { |
| return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); |
| } |
| } |
| |
| static size_t get_allocation_size(size_t size) { |
| if (size <= kSmallSize) { |
| return kSmallBuffer; |
| } else if (size < kMinLargeAlloc) { |
| return kLargeBuffer; |
| } else { |
| return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); |
| } |
| } |
| |
| BlockPool& get_pool(size_t size) { |
| if (size <= kSmallSize) { |
| return small_blocks; |
| } else { |
| return large_blocks; |
| } |
| } |
| |
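| // Finds the best-fit cached block for the request. The comparator keys on |
| // the queue first, so if lower_bound lands on a block from a different |
| // queue there is no reusable block for this request. |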
| bool get_free_block(AllocParams& p) { |
| BlockPool& pool = *p.pool; |
| auto it = pool.blocks.lower_bound(&p.search_key); |
| if (it == pool.blocks.end() || (*it)->queue != p.queue()) { |
| return false; |
| } |
| p.block = *it; |
| pool.blocks.erase(it); |
| return true; |
| } |
| |
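| // Allocates a fresh buffer of p.alloc_size bytes from the SYCL runtime and |
| // wraps it in a new Block. Returns false if the device allocation fails. |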
| bool alloc_block(AllocParams& p) { |
| auto size = p.alloc_size; |
| auto device = p.device(); |
| void* ptr = sycl::aligned_alloc_device( |
| kDeviceAlignment, |
| size, |
| xpu::get_raw_device(device), |
| xpu::get_device_context()); |
| if (!ptr) { |
| return false; |
| } |
| p.block = new Block(device, p.queue(), size, p.pool, ptr); |
| for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { |
| stats.reserved_bytes[stat_type].increase(size); |
| }); |
| return true; |
| } |
| |
| void synchronize_and_free_events() { |
| for (auto& xe : xpu_events) { |
| for (auto& e : xe.second) { |
| auto event = e.first; |
| auto* block = e.second; |
| event.wait(); |
| block->event_count--; |
| if (block->event_count == 0) { |
| free_block(block); |
| } |
| } |
| } |
| xpu_events.clear(); |
| } |
| |
| void release_block(Block* block) { |
| /* |
| * Note [Safe to Free Blocks on BlockPool] |
| * |
| * Callers must ensure that all accesses to the block, whose raw pointer is |
| * allocated by SYCL APIs, have been completed before invoking sycl::free. |
| * |
| * We have to perform a device-level synchronization before freeing these |
| * blocks to guarantee that all kernels that may access them have finished. |
| */ |
| sycl::free(block->ptr, xpu::get_device_context()); |
| auto* pool = block->pool; |
| pool->blocks.erase(block); |
| |
| StatTypes stat_types = get_stat_types_for_pool(*pool); |
| for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
| stats.reserved_bytes[stat_type].decrease(block->size); |
| }); |
| |
| delete block; |
| } |
| |
| void release_blocks(BlockPool& pool) { |
| auto it = pool.blocks.begin(); |
| while (it != pool.blocks.end()) { |
| Block* block = *it; |
| ++it; |
| if (!block->prev && !block->next) { |
| release_block(block); |
| } |
| } |
| } |
| |
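| // Waits for all outstanding events and streams on this device, then frees |
| // every cached block that is not split (i.e. whole physical allocations). |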
| bool release_cached_blocks() { |
| synchronize_and_free_events(); |
| // See Note [Safe to Free Blocks on BlockPool] |
| c10::xpu::syncStreamsOnDevice(device_index); |
| |
| release_blocks(large_blocks); |
| release_blocks(small_blocks); |
| return true; |
| } |
| |
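| // Decides whether the unused remainder of a block is worth keeping as a |
| // separate free block: at least kMinBlockSize in the small pool, or more |
| // than kSmallSize in the large pool. |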
| bool should_split(const Block* block, size_t size) { |
| size_t remaining = block->size - size; |
| if (block->pool->is_small) { |
| return remaining >= kMinBlockSize; |
| } else { |
| return remaining > kSmallSize; |
| } |
| } |
| |
| StatTypes get_stat_types_for_pool(const BlockPool& pool) { |
| StatTypes stat_types = {}; |
| stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
| stat_types[static_cast<size_t>( |
| pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true; |
| return stat_types; |
| } |
| |
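| // Finalizes an allocation from params.block: optionally splits the unused |
| // tail back into the pool, marks the block allocated, registers it as |
| // active, and bumps the allocation counters. |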
| Block* alloc_found_block( |
| AllocParams params, |
| size_t orig_size, |
| bool split_remainder) { |
| auto size = params.size(); |
| auto device = params.device(); |
| BlockPool* pool = params.pool; |
| sycl::queue* queue = params.queue(); |
| |
| TORCH_INTERNAL_ASSERT( |
| params.block != nullptr && params.block->ptr != nullptr); |
| Block* block = params.block; |
| Block* remaining = nullptr; |
| |
| if (split_remainder) { |
| remaining = block; |
| |
| block = new Block(device, queue, size, pool, block->ptr); |
| block->prev = remaining->prev; |
| if (block->prev) { |
| block->prev->next = block; |
| } |
| block->next = remaining; |
| |
| remaining->prev = block; |
| remaining->ptr = static_cast<char*>(remaining->ptr) + size; |
| remaining->size -= size; |
| bool inserted = pool->blocks.insert(remaining).second; |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); |
| } |
| |
| block->allocated = true; |
| block->requested_size = orig_size; |
| bool inserted = active_blocks.insert(block).second; |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); |
| |
| for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { |
| stats.allocated_bytes[stat_type].increase(block->size); |
| stats.active_bytes[stat_type].increase(block->size); |
| stats.requested_bytes[stat_type].increase(block->requested_size); |
| }); |
| |
| return block; |
| } |
| |
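| // Records a barrier event on every stream that used the block besides its |
| // allocation stream; the block is returned to its pool later, once all of |
| // these events have completed (see process_events). |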
| void insert_events(Block* block) { |
| stream_set streams(std::move(block->stream_uses)); |
| TORCH_INTERNAL_ASSERT(block->stream_uses.empty()); |
| for (auto& stream : streams) { |
| block->event_count++; |
| xpu_events[stream].emplace_back( |
| stream.queue().ext_oneapi_submit_barrier(), block); |
| } |
| } |
| |
| public: |
| DeviceCachingAllocator(DeviceIndex device_index) |
| : large_blocks(/* small */ false), |
| small_blocks(/* small */ true), |
| device_index(device_index) {} |
| |
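| // Allocation entry point for this device. Tries, in order: a cached block |
| // from the matching pool, a fresh device allocation, and a fresh allocation |
| // after releasing all cached blocks. Throws OutOfMemoryError if all fail. |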
| Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) { |
| std::scoped_lock<std::recursive_mutex> lock(mutex); |
| process_events(); |
| size_t size = round_size(orig_size); |
| auto& pool = get_pool(size); |
| const size_t alloc_size = get_allocation_size(size); |
| AllocParams params(device, size, &queue, &pool, alloc_size); |
| params.stat_types = get_stat_types_for_pool(pool); |
| |
| // First, try to get a block from the existing pool. |
| bool block_found = get_free_block(params); |
| // Can't reuse an existing block, try to get a new one. |
| if (!block_found) { |
| block_found = alloc_block(params) || |
| (release_cached_blocks() && alloc_block(params)); |
| } |
| if (!block_found) { |
| c10::xpu::DeviceProp device_prop; |
| c10::xpu::get_device_properties(&device_prop, device); |
| auto device_total = device_prop.global_mem_size; |
| auto allocated_bytes = |
| stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)] |
| .current; |
| auto reserved_bytes = |
| stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)] |
| .current; |
| TORCH_CHECK_WITH( |
| OutOfMemoryError, |
| false, |
| "XPU out of memory. Tried to allocate ", |
| format_size(alloc_size), |
| ". GPU ", |
| static_cast<int>(device), |
| " has a total capacity of ", |
| format_size(device_total), |
| ". Of the allocated memory ", |
| format_size(allocated_bytes), |
| " is allocated by PyTorch, and ", |
| format_size(reserved_bytes - allocated_bytes), |
| " is reserved by PyTorch but unallocated.", |
| " Please use `empty_cache` to release all unoccupied cached memory."); |
| } |
| bool split_remainder = should_split(params.block, params.size()); |
| return alloc_found_block(std::move(params), orig_size, split_remainder); |
| } |
| |
| void free(Block* block) { |
| std::scoped_lock<std::recursive_mutex> lock(mutex); |
| block->allocated = false; |
| |
| StatTypes stat_types = get_stat_types_for_pool(*block->pool); |
| for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
| stats.allocated_bytes[stat_type].decrease(block->size); |
| }); |
| |
| if (!block->stream_uses.empty()) { |
| insert_events(block); |
| } else { |
| free_block(block); |
| } |
| } |
| |
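| // Marks the block as used on `stream`. A use on a stream other than the |
| // allocation stream defers the block's reuse until that stream's pending |
| // work completes (see insert_events). |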
| void recordStream(Block* block, xpu::XPUStream stream) { |
| std::scoped_lock<std::recursive_mutex> lock(mutex); |
| if (stream.queue() == *block->queue) { |
| return; |
| } |
| block->stream_uses.insert(stream); |
| } |
| |
| void emptyCache() { |
| std::scoped_lock<std::recursive_mutex> lock(mutex); |
| release_cached_blocks(); |
| } |
| |
| DeviceStats getStats() { |
| std::scoped_lock<std::recursive_mutex> lock(mutex); |
| return stats; |
| } |
| |
| void resetAccumulatedStats() { |
| std::scoped_lock<std::recursive_mutex> lock(mutex); |
| |
| for (const auto statType : |
| c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) { |
| stats.allocated_bytes[statType].reset_accumulated(); |
| stats.reserved_bytes[statType].reset_accumulated(); |
| stats.active_bytes[statType].reset_accumulated(); |
| stats.requested_bytes[statType].reset_accumulated(); |
| } |
| } |
| |
| void resetPeakStats() { |
| std::scoped_lock<std::recursive_mutex> lock(mutex); |
| |
| for (const auto statType : |
| c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) { |
| stats.allocated_bytes[statType].reset_peak(); |
| stats.reserved_bytes[statType].reset_peak(); |
| stats.active_bytes[statType].reset_peak(); |
| stats.requested_bytes[statType].reset_peak(); |
| } |
| } |
| }; |
| |
| void local_raw_delete(void* ptr); |
| |
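| // Frontend allocator: routes requests to the per-device caching allocators |
| // and keeps the ptr -> Block map needed to free by raw pointer. |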
| class XPUAllocator : public Allocator { |
| private: |
| std::mutex mutex; |
| ska::flat_hash_map<void*, Block*> allocated_blocks; |
| |
| void add_allocated_block(Block* block) { |
| std::lock_guard<std::mutex> lock(mutex); |
| allocated_blocks[block->ptr] = block; |
| } |
| |
| Block* get_allocated_block(void* ptr, bool remove = false) { |
| std::scoped_lock<std::mutex> lock(mutex); |
| auto it = allocated_blocks.find(ptr); |
| if (it == allocated_blocks.end()) { |
| return nullptr; |
| } |
| Block* block = it->second; |
| if (remove) { |
| allocated_blocks.erase(it); |
| } |
| return block; |
| } |
| |
| public: |
| std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators; |
| |
| void init(DeviceIndex device_count) { |
| const auto size = static_cast<DeviceIndex>(device_allocators.size()); |
| if (size < device_count) { |
| device_allocators.resize(device_count); |
| for (const auto i : c10::irange(size, device_count)) { |
| device_allocators[i] = std::make_unique<DeviceCachingAllocator>(i); |
| } |
| } |
| } |
| |
| void malloc( |
| void** devPtr, |
| DeviceIndex device, |
| size_t size, |
| sycl::queue& queue) { |
| TORCH_INTERNAL_ASSERT( |
| 0 <= device && static_cast<size_t>(device) < device_allocators.size(), |
| "Allocator not initialized for device ", |
| static_cast<int16_t>(device), |
| ": did you call init?"); |
| Block* block = device_allocators[device]->malloc(device, size, queue); |
| add_allocated_block(block); |
| *devPtr = block->ptr; |
| const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
| if (C10_UNLIKELY(interp)) { |
| (*interp)->trace_gpu_memory_allocation( |
| c10::kXPU, reinterpret_cast<uintptr_t>(*devPtr)); |
| } |
| } |
| |
| void free(void* ptr) { |
| if (!ptr) { |
| return; |
| } |
| Block* block = get_allocated_block(ptr, /* remove */ true); |
| TORCH_CHECK(block, "invalid device pointer: ", ptr); |
| device_allocators[block->device]->free(block); |
| const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
| if (C10_UNLIKELY(interp)) { |
| (*interp)->trace_gpu_memory_deallocation( |
| c10::kXPU, reinterpret_cast<uintptr_t>(block->ptr)); |
| } |
| } |
| |
| void emptyCache() { |
| for (auto& da : device_allocators) { |
| da->emptyCache(); |
| } |
| } |
| |
| void recordStream(const DataPtr& ptr, XPUStream stream) { |
| if (!ptr.get()) { |
| return; |
| } |
| if (ptr.get_deleter() != &local_raw_delete) { |
| return; |
| } |
| |
| Block* block = get_allocated_block(ptr.get()); |
| TORCH_CHECK(block, "No allocated block can be found."); |
| device_allocators[block->device]->recordStream(block, stream); |
| } |
| |
| DataPtr allocate(size_t size) override { |
| auto device = c10::xpu::current_device(); |
| void* r = nullptr; |
| if (size != 0) { |
| this->malloc(&r, device, size, xpu::getCurrentXPUStream(device)); |
| } |
| return {r, r, &local_raw_delete, Device(DeviceType::XPU, device)}; |
| } |
| |
| DeleterFnPtr raw_deleter() const override { |
| return &local_raw_delete; |
| } |
| |
| void* raw_alloc(size_t size) { |
| if (size == 0) { |
| return nullptr; |
| } |
| auto device = c10::xpu::current_device(); |
| void* r = nullptr; |
| malloc(&r, device, size, xpu::getCurrentXPUStream(device)); |
| return r; |
| } |
| |
| void* raw_alloc_with_stream(size_t size, XPUStream stream) { |
| if (size == 0) { |
| return nullptr; |
| } |
| auto device = c10::xpu::current_device(); |
| void* r = nullptr; |
| malloc(&r, device, size, stream); |
| return r; |
| } |
| |
| void raw_delete(void* ptr) { |
| this->free(ptr); |
| } |
| |
| void copy_data(void* dest, const void* src, std::size_t count) const final { |
| xpu::getCurrentXPUStream().queue().memcpy(dest, src, count); |
| } |
| |
| void assertValidDevice(DeviceIndex device) { |
| const auto device_num = device_allocators.size(); |
| TORCH_CHECK( |
| 0 <= device && device < static_cast<int64_t>(device_num), |
| "Invalid device argument ", |
| device, |
| ": did you call init?"); |
| } |
| |
| DeviceStats getDeviceStats(DeviceIndex device) { |
| assertValidDevice(device); |
| return device_allocators[device]->getStats(); |
| } |
| |
| void resetPeakStats(DeviceIndex device) { |
| assertValidDevice(device); |
| device_allocators[device]->resetPeakStats(); |
| } |
| |
| void resetAccumulatedStats(DeviceIndex device) { |
| assertValidDevice(device); |
| device_allocators[device]->resetAccumulatedStats(); |
| } |
| }; |
| |
| static XPUAllocator allocator; |
| |
| void local_raw_delete(void* ptr) { |
| allocator.free(ptr); |
| } |
| |
| Allocator* get() { |
| return &allocator; |
| } |
| |
| void init(DeviceIndex device_count) { |
| return allocator.init(device_count); |
| } |
| |
| void emptyCache() { |
| return allocator.emptyCache(); |
| } |
| |
| void resetPeakStats(DeviceIndex device) { |
| return allocator.resetPeakStats(device); |
| } |
| |
| void resetAccumulatedStats(DeviceIndex device) { |
| return allocator.resetAccumulatedStats(device); |
| } |
| |
| DeviceStats getDeviceStats(DeviceIndex device) { |
| return allocator.getDeviceStats(device); |
| } |
| |
| void* raw_alloc(size_t size) { |
| return allocator.raw_alloc(size); |
| } |
| |
| void raw_delete(void* ptr) { |
| return allocator.raw_delete(ptr); |
| } |
| |
| void recordStream(const DataPtr& dataPtr, XPUStream stream) { |
| return allocator.recordStream(dataPtr, stream); |
| } |
| |
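| // Illustrative usage from other C10/ATen code (assuming init() has been |
| // called with the device count): |
| // |
| //   void* p = c10::xpu::XPUCachingAllocator::raw_alloc(nbytes); |
| //   // ... enqueue kernels that use p on the current XPU stream ... |
| //   c10::xpu::XPUCachingAllocator::raw_delete(p); |
| // |
| // If memory allocated on one stream is consumed on another, call |
| // recordStream(data_ptr, stream) so deallocation waits for that stream. |
| |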
| REGISTER_ALLOCATOR(kXPU, &allocator) |
| |
| } // namespace c10::xpu::XPUCachingAllocator |