| |
| #include <c10/cuda/CUDACachingAllocator.h> |
| |
| #include <c10/cuda/CUDAException.h> |
| #include <c10/cuda/CUDAFunctions.h> |
| #include <c10/cuda/CUDAGuard.h> |
| #include <c10/util/UniqueVoidPtr.h> |
| #include <c10/util/flat_hash_map.h> |
| #include <c10/util/irange.h> |
| #include <c10/util/llvmMathExtras.h> |
| |
| #include <cuda_runtime_api.h> |
| #include <algorithm> |
| #include <bitset> |
| #include <deque> |
| #include <iterator> |
| #include <map> |
| #include <memory> |
| #include <mutex> |
| #include <regex> |
| #include <set> |
| #include <vector> |
| |
| namespace c10 { |
| |
| C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); |
| |
| namespace cuda { |
| namespace CUDACachingAllocator { |
| |
| // |
| // Yet another caching allocator for CUDA device allocations. |
| // |
| // - Allocations are associated with a stream. Once freed, blocks can be |
| // re-allocated on the same stream, but not on any other stream. |
| // - The allocator attempts to find the smallest cached block that will fit the |
| // requested size. If the block is larger than the requested size, it may be |
| // split. If no block is found, the allocator will delegate to cudaMalloc. |
| // - If the cudaMalloc fails, the allocator will attempt to free one cached |
| // block of sufficient size that is not split and retry the allocation. |
| // If this also fails, the allocator will attempt to free all cached blocks |
| // that are not split and retry the allocation. |
| // - Large (>1MB) and small allocations are stored in separate pools. |
| // Small requests are packed into 2MB buffers. Large requests will use the |
| // smallest available free block or allocate a new block using cudaMalloc. |
| // - To reduce fragmentation, requests between 1MB and 10MB will allocate and |
| // split a 20MB block, if no free block of sufficient size is available. |
| // - To further reduce fragmentation, blocks >= 200MB are not allowed to be |
| // split. These oversize cached blocks will still satisfy requests within |
| // 20MB of the oversize cached block size. |
| // |
| // With this allocator, allocations and frees should logically be considered |
| // "usages" of the memory segment associated with streams, just like kernel |
| // launches. The programmer must insert the proper synchronization if memory |
| // segments are used from multiple streams. |
| // |
| // The library provides a recordStream() function to help insert the correct |
| // synchronization when allocations are used on multiple streams. This will |
| // ensure that the block is not reused before each recorded stream completes |
| // work. |
| // |
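| // A usage sketch (illustrative only, not part of this file) of the |
| // recordStream() contract described above; the tensor and stream names are |
| // hypothetical: |
| // |
| //   // `x` was allocated on the current (default) stream. |
| //   at::cuda::CUDAStream side = at::cuda::getStreamFromPool(); |
| //   { |
| //     at::cuda::CUDAStreamGuard guard(side); |
| //     auto y = x + 1;  // a kernel on `side` reads x's memory |
| //     c10::cuda::CUDACachingAllocator::recordStream( |
| //         x.storage().data_ptr(), side); |
| //   } |
| //   // x may now be freed; its block will not be reused until the work |
| //   // already queued on `side` completes. |
| // |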
| |
| /** |
| * Note [Interaction with CUDA graph capture] |
| * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| * Graph capture performs a dry run of a region of execution, freezing all CUDA |
| * work (and virtual addresses used during that work) into a "graph." The graph |
| * may be "replayed" like a single giant kernel, with greatly reduced CPU |
| * overhead as well as modestly improved GPU performance. |
| * |
| * Because capture bakes in memory addresses, the memory used during capture |
| * must be available for the graph to use during replay. DeviceCachingAllocator |
| * assigns and frees memory eagerly and dynamically, so if we're not careful |
| * about managing graphs' memory, at replay time those memory addresses could be |
| * used by other tensors. |
| * |
| * To guarantee a graph's baked-in addresses are safe to reuse in replay, |
| * DeviceCachingAllocator satisfies allocations from a graph-private memory pool during |
| * capture, and doesn't begin cudaFreeing those addresses until the graph is |
| * destroyed. |
| * |
| * Within the private pool, allocations are freed and reassigned as usual during |
| * capture. Memory regions will be used in a consistent order during replay. So |
| * a private pool doesn't use memory more wastefully than the default pools |
| * during capture, but it does reserve its high-water mark of used memory away |
| * from the default pools as long as the capture(s) it served survive |
| * (regardless of whether those captures are idle or replaying). |
| * |
| * CUDAGraph's requests for private pools are mediated by |
| * DeviceCachingAllocator::notifyCaptureBegin, notifyCaptureEnd, and |
| * notifyCaptureDestroy. |
| */ |
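| |
| // A rough sketch (illustrative; ids are arbitrary) of how a capture drives |
| // the private pool through the three notify* hooks declared further below: |
| // |
| //   notifyCaptureBegin(dev, graph_id, mempool_id); // create or share a PrivatePool |
| //   //   ... allocations made while capturing are served from that pool ... |
| //   notifyCaptureEnd(dev, graph_id);       // capture finished; pool stays reserved |
| //   //   ... the graph may now be replayed any number of times ... |
| //   notifyCaptureDestroy(dev, mempool_id); // once use_count hits 0, the pool's |
| //                                          // blocks become eligible for cudaFree |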
| |
| namespace { |
| |
| using stream_set = ska::flat_hash_set<cuda::CUDAStream>; |
| |
| constexpr size_t kMinBlockSize = |
| 512; // all sizes are rounded to at least 512 bytes |
| constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB |
| constexpr size_t kSmallBuffer = |
| 2097152; // "small" allocations are packed in 2 MiB blocks |
| constexpr size_t kLargeBuffer = |
| 20971520; // "large" allocations may be packed in 20 MiB blocks |
| constexpr size_t kMinLargeAlloc = |
| 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer |
| constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB |
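| |
| // Worked example (illustrative; assumes roundup_power2_divisions is unset, |
| // and follows from round_size/get_allocation_size further below): |
| //   request 5,000 B -> rounded to 5,120 B (kMinBlockSize granularity); |
| //                      small pool; new segments are kSmallBuffer = 2 MiB |
| //   request 3 MiB   -> large pool; new segments are kLargeBuffer = 20 MiB |
| //   request 50 MiB  -> large pool; new segments are rounded up to a multiple |
| //                      of kRoundLarge = 2 MiB, i.e. exactly 50 MiB here |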
| |
| using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>; |
| |
| void update_stat(Stat& stat, int64_t amount) { |
| stat.current += amount; |
| |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
| stat.current >= 0, |
| "Negative tracked stat in CUDA allocator (likely logic error)."); |
| |
| stat.peak = std::max(stat.current, stat.peak); |
| if (amount > 0) { |
| stat.allocated += amount; |
| } |
| if (amount < 0) { |
| stat.freed += -amount; |
| } |
| } |
| |
| void reset_accumulated_stat(Stat& stat) { |
| stat.allocated = 0; |
| stat.freed = 0; |
| } |
| |
| void reset_peak_stat(Stat& stat) { |
| stat.peak = stat.current; |
| } |
| |
| template <typename Func> |
| void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { |
| for (const auto stat_type : c10::irange(stat_types.size())) { |
| if (stat_types[stat_type]) { |
| f(stat_type); |
| } |
| } |
| } |
| |
| void update_stat_array( |
| StatArray& stat_array, |
| int64_t amount, |
| const StatTypes& stat_types) { |
| for_each_selected_stat_type( |
| stat_types, [&stat_array, amount](size_t stat_type) { |
| update_stat(stat_array[stat_type], amount); |
| }); |
| } |
| |
| struct Block; |
| struct PrivatePool; |
| typedef bool (*Comparison)(const Block*, const Block*); |
| |
| struct BlockPool { |
| BlockPool( |
| Comparison comparator, |
| bool small, |
| PrivatePool* private_pool = nullptr) |
| : blocks(comparator), is_small(small), owner_PrivatePool(private_pool) {} |
| std::set<Block*, Comparison> blocks; |
| const bool is_small; |
| PrivatePool* owner_PrivatePool; |
| }; |
| |
| struct Block { |
| int device; // gpu |
| cudaStream_t stream; // allocation stream |
| stream_set stream_uses; // streams on which the block was used |
| size_t size; // block size in bytes |
| BlockPool* pool; // owning memory pool |
| void* ptr; // memory address |
| bool allocated; // in-use flag |
| Block* prev; // prev block if split from a larger allocation |
| Block* next; // next block if split from a larger allocation |
| int event_count; // number of outstanding CUDA events |
| int gc_count; // counter for prioritizing older / less useful blocks for |
| // garbage collection |
| |
| Block( |
| int device, |
| cudaStream_t stream, |
| size_t size, |
| BlockPool* pool, |
| void* ptr) |
| : device(device), |
| stream(stream), |
| stream_uses(), |
| size(size), |
| pool(pool), |
| ptr(ptr), |
| allocated(0), |
| prev(nullptr), |
| next(nullptr), |
| event_count(0), |
| gc_count(0) {} |
| |
| // constructor for search key |
| Block(int device, cudaStream_t stream, size_t size) |
| : device(device), |
| stream(stream), |
| stream_uses(), |
| size(size), |
| pool(nullptr), |
| ptr(nullptr), |
| allocated(0), |
| prev(nullptr), |
| next(nullptr), |
| event_count(0), |
| gc_count(0) {} |
| |
| bool is_split() const { |
| return (prev != nullptr) || (next != nullptr); |
| } |
| }; |
| |
| static bool BlockComparator(const Block* a, const Block* b) { |
| if (a->stream != b->stream) { |
| return (uintptr_t)a->stream < (uintptr_t)b->stream; |
| } |
| if (a->size != b->size) { |
| return a->size < b->size; |
| } |
| return (uintptr_t)a->ptr < (uintptr_t)b->ptr; |
| } |
| |
| static std::string format_size(uint64_t size) { |
| std::ostringstream os; |
| os.precision(2); |
| os << std::fixed; |
| if (size <= 1024) { |
| os << size << " bytes"; |
| } else if (size <= 1048576) { |
| os << (size / 1024.0); |
| os << " KiB"; |
| } else if (size <= 1073741824ULL) { |
| os << size / 1048576.0; |
| os << " MiB"; |
| } else { |
| os << size / 1073741824.0; |
| os << " GiB"; |
| } |
| return os.str(); |
| } |
| |
| struct AllocParams { |
| AllocParams( |
| int device, |
| size_t size, |
| cudaStream_t stream, |
| BlockPool* pool, |
| size_t alloc_size, |
| DeviceStats& stats) |
| : search_key(device, stream, size), |
| pool(pool), |
| alloc_size(alloc_size), |
| block(nullptr), |
| err(cudaSuccess) {} |
| |
| int device() const { |
| return search_key.device; |
| } |
| cudaStream_t stream() const { |
| return search_key.stream; |
| } |
| size_t size() const { |
| return search_key.size; |
| } |
| |
| Block search_key; |
| BlockPool* pool; |
| size_t alloc_size; |
| Block* block; |
| StatTypes stat_types = {false}; |
| cudaError_t err; |
| }; |
| |
| // CUDA graphs helper |
| struct PrivatePool { |
| PrivatePool() |
| : use_count(1), |
| cudaMalloc_count(0), |
| large_blocks(BlockComparator, /*is_small=*/false, this), |
| small_blocks(BlockComparator, /*is_small=*/true, this) {} |
| PrivatePool(const PrivatePool&) = delete; |
| PrivatePool(PrivatePool&&) = delete; |
| PrivatePool& operator=(const PrivatePool&) = delete; |
| // Number of live graphs using this pool |
| int use_count; |
| // Number of unfreed cudaMallocs made for this pool. When use_count and |
| // cudaMalloc_count drop to zero, we can delete this PrivatePool from |
| // graph_pools. |
| int cudaMalloc_count; |
| // Instead of maintaining private BlockPools here, I could stuff all blocks |
| // (private or no) into the top-level large_blocks and small_blocks, and |
| // distinguish private blocks by adding a "pool id" check above the stream |
| // check in BlockComparator. BlockComparator is performance-critical though, |
| // so I'd rather not add more logic to it. |
| BlockPool large_blocks; |
| BlockPool small_blocks; |
| }; |
| |
| struct MempoolIdHash { |
| std::size_t operator()(const MempoolId_t& mempool_id) const noexcept { |
| return mempool_id.first != 0 ? mempool_id.first : mempool_id.second; |
| } |
| }; |
| |
| cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) { |
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
| if (at::cuda::currentStreamCaptureStatusMayInitCtx() == |
| at::cuda::CaptureStatus::None) { |
| #endif |
| return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size)); |
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
| } else { |
| // It's ok to capture cudaMallocs, as long as we never cudaFree those |
| // addresses before replay. |
| // Capturing cudaMalloc behaves nicely: it gives the graph new VA, |
| // but is ignored (won't leakily allocate new memory) in replays. |
| at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed}; |
| return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size)); |
| } |
| #endif |
| } |
| |
| } // namespace |
| |
| class CachingAllocatorConfig { |
| public: |
| static size_t max_split_size() { |
| return instance().m_max_split_size; |
| } |
| static double garbage_collection_threshold() { |
| return instance().m_garbage_collection_threshold; |
| } |
| |
| // This is used to round up the allocation size to the nearest power-of-2 |
| // division. See roundup_power2_next_division below for more details. |
| // As an example, if we want 4 divisions between consecutive powers of 2, this |
| // can be done via the env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 |
| static size_t roundup_power2_divisions() { |
| return instance().m_roundup_power2_divisions; |
| } |
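| |
| // Example configuration string combining the options parsed in parseArgs() |
| // below (the values shown are arbitrary): |
| //   PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,roundup_power2_divisions:4,garbage_collection_threshold:0.8 |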
| |
| private: |
| static CachingAllocatorConfig& instance() { |
| static CachingAllocatorConfig* s_instance = ([]() { |
| auto inst = new CachingAllocatorConfig(); |
| inst->parseArgs(); |
| return inst; |
| })(); |
| return *s_instance; |
| } |
| |
| CachingAllocatorConfig() |
| : m_max_split_size(std::numeric_limits<size_t>::max()), |
| m_roundup_power2_divisions(0), |
| m_garbage_collection_threshold(0) {} |
| size_t m_max_split_size; |
| size_t m_roundup_power2_divisions; |
| double m_garbage_collection_threshold; |
| |
| void parseArgs() { |
| const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF"); |
| if (val != NULL) { |
| const std::string config(val); |
| |
| std::regex exp("[\\s,]+"); |
| std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); |
| std::sregex_token_iterator end; |
| std::vector<std::string> options(it, end); |
| |
| for (auto option : options) { |
| std::regex exp2("[:]+"); |
| std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); |
| std::sregex_token_iterator end2; |
| std::vector<std::string> kv(it2, end2); |
| if (kv.size() >= 2) { |
| /* Maximum split size in MB. Limited to large size blocks */ |
| if (kv[0].compare("max_split_size_mb") == 0) { |
| size_t val2 = stoi(kv[1]); |
| TORCH_CHECK( |
| val2 > kLargeBuffer / (1024 * 1024), |
| "CachingAllocator option max_split_size_mb too small, must be > ", |
| kLargeBuffer / (1024 * 1024), |
| ""); |
| val2 = std::max(val2, kLargeBuffer / (1024 * 1024)); |
| val2 = std::min( |
| val2, (std::numeric_limits<size_t>::max() / (1024 * 1024))); |
| m_max_split_size = val2 * 1024 * 1024; |
| } else if (kv[0].compare("roundup_power2_divisions") == 0) { |
| size_t val2 = stoi(kv[1]); |
| TORCH_CHECK( |
| llvm::isPowerOf2_64(val2), |
| "For roundups, the divisons has to be power of 2 ", |
| ""); |
| m_roundup_power2_divisions = val2; |
| } else if (kv[0].compare("garbage_collection_threshold") == 0) { |
| /* |
| * Perform garbage collection of GPU memory blocks to avoid |
| * triggering expensive sync-and-reclaim-all operation. Upon setting |
| * the threshold (e.g., 0.8), the allocator will start reclaiming |
| * blocks if GPU memory capacity usage exceeds the threshold (i.e., |
| * 80% of total memory). |
| * Values 0.0 and 1.0 are not allowed as they are less meaningful. |
| */ |
| double val2 = stod(kv[1]); |
| TORCH_CHECK( |
| val2 > 0, |
| "garbage_collect_threshold too small, set it 0.0~1.0", |
| ""); |
| TORCH_CHECK( |
| val2 < 1.0, |
| "garbage_collect_threshold too big, set it 0.0~1.0", |
| ""); |
| m_garbage_collection_threshold = val2; |
| } else { |
| TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]); |
| } |
| } |
| } |
| } |
| } |
| }; |
| |
| class DeviceCachingAllocator { |
| private: |
| // lock around all operations |
| mutable std::recursive_mutex mutex; |
| |
| // device statistics |
| DeviceStats stats; |
| |
| // unallocated cached blocks larger than 1 MB |
| BlockPool large_blocks; |
| |
| // unallocated cached blocks 1 MB or smaller |
| BlockPool small_blocks; |
| |
| // allocated or in use by a stream. Holds all active allocations, |
| // whether they came from graph_pools or one of the BlockPools above. |
| ska::flat_hash_set<Block*> active_blocks; |
| |
| // captures_underway tracks if a capture might be underway on any stream. |
| // Most of the time it's zero, in which case malloc can avoid calling |
| // cudaStreamGetCaptureInfo in the hot path. |
| int captures_underway = 0; |
| // See free() for this thing's purpose |
| std::vector<Block*> needs_events_deferred_until_no_capture; |
| // outstanding cuda events |
| ska::flat_hash_map< |
| cuda::CUDAStream, |
| std::deque<std::pair<cudaEvent_t, Block*>>> |
| cuda_events; |
| |
| // total memory currently allocated from the system by this allocator. |
| size_t total_allocated_memory = 0; |
| |
| size_t allowed_memory_maximum = 0; |
| |
| bool set_fraction = false; |
| |
| // Members specific to CUDA graphs |
| |
| // Private pools for CUDA graphs |
| ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash> |
| graph_pools; |
| // Pools no longer referenced by any graph. Their BlockPools are eligible for |
| // release_blocks. Can't be a vector or deque because we might erase entries in |
| // any order. Could be an std::list, but we don't care much, access and |
| // insert/erase are rare. |
| ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash> |
| graph_pools_freeable; |
| |
| // Maps a capturing stream to its assigned private pool, |
| // in case we want multiple captures to share the same pool |
| ska::flat_hash_map<CaptureId_t, MempoolId_t> capture_to_pool_map; |
| |
| public: |
| DeviceCachingAllocator() |
| : large_blocks(BlockComparator, /*is_small=*/false), |
| small_blocks(BlockComparator, /*is_small=*/true) { |
| stats.max_split_size = CachingAllocatorConfig::max_split_size(); |
| } |
| |
| // All public methods (except the above) acquire the allocator mutex. |
| // Thus, do not call a public method from another public method. |
| |
| Block* malloc(int device, size_t size, cudaStream_t stream) { |
| std::unique_lock<std::recursive_mutex> lock(mutex); |
| |
| if (C10_LIKELY(captures_underway == 0)) { |
| // Processes end-of-life events for outstanding allocations used on |
| // multiple streams (checks if their GPU-side uses are complete and |
| // recycles their memory if so) |
| // |
| // Q. Why skip process_events if a capture might be underway? |
| // A. process_events involves cudaEventQueries, illegal during CUDA graph |
| // capture. |
| // Dumb simple solution: defer reclaiming these allocations until after |
| // capture. Cross-stream memory use is uncommon, so the deferral's |
| // effect on memory use during capture should be small. |
| process_events(); |
| } |
| |
| size = round_size(size); |
| auto& pool = get_pool(size, stream); |
| const size_t alloc_size = get_allocation_size(size); |
| AllocParams params(device, size, stream, &pool, alloc_size, stats); |
| params.stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
| params.stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true; |
| |
| // First, try to get a block from the existing pool. |
| bool block_found = |
| // Search pool |
| get_free_block(params) |
| // Trigger callbacks and retry search |
| || (trigger_free_memory_callbacks(params) && get_free_block(params)); |
| |
| // Can't reuse an existing block; try to get a new one. |
| if (!block_found) { |
| // Do garbage collection if the flag is set. |
| if (C10_UNLIKELY( |
| set_fraction && |
| CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) { |
| garbage_collect_cached_blocks(); |
| } |
| // Attempt allocate |
| block_found = alloc_block(params, false) |
| // Free enough available cached blocks to satisfy alloc and retry |
| // alloc. |
| || (release_available_cached_blocks(params) && |
| alloc_block(params, false)) |
| // Free all non-split cached blocks and retry alloc. |
| || (C10_LIKELY(captures_underway == 0) && release_cached_blocks() && |
| alloc_block(params, true)); |
| } |
| |
| if (!block_found) { |
| // For any error code other than cudaErrorMemoryAllocation, |
| // alloc_block should have thrown an exception already. |
| TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation); |
| |
| size_t device_free; |
| size_t device_total; |
| C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); |
| std::string allowed_info; |
| |
| if (set_fraction) { |
| allowed_info = format_size(allowed_memory_maximum) + " allowed; "; |
| } |
| |
| stats.num_ooms += 1; |
| |
| // "total capacity": total global memory on GPU |
| // "allowed": memory is allowed to use, which set by fraction. |
| // "already allocated": memory allocated by the program using the |
| // caching allocator |
| // "free": free memory as reported by the CUDA API |
| // "cached": memory held by the allocator but not used by the program |
| // |
| // The "allocated" amount does not include memory allocated outside |
| // of the caching allocator, such as memory allocated by other programs |
| // or memory held by the driver. |
| // |
| // The sum of "allocated" + "free" + "cached" may be less than the |
| // total capacity due to memory held by the driver and usage by other |
| // programs. |
| // |
| // Note that at this point release_cached_blocks has already returned all |
| // possible "cached" memory to the driver. The only remaining "cached" |
| // memory is split from a larger block that is partially in-use. |
| TORCH_CHECK_WITH( |
| CUDAOutOfMemoryError, |
| false, |
| "CUDA out of memory. Tried to allocate ", |
| format_size(alloc_size), |
| " (GPU ", |
| device, |
| "; ", |
| format_size(device_total), |
| " total capacity; ", |
| format_size( |
| stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)] |
| .current), |
| " already allocated; ", |
| format_size(device_free), |
| " free; ", |
| allowed_info, |
| format_size( |
| stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)] |
| .current), |
| " reserved in total by PyTorch)", |
| " If reserved memory is >> allocated memory try setting max_split_size_mb to avoid" |
| " fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF", |
| ""); |
| } |
| |
| TORCH_INTERNAL_ASSERT( |
| params.err == cudaSuccess && params.block != nullptr && |
| params.block->ptr != nullptr); |
| Block* block = params.block; |
| Block* remaining = nullptr; |
| |
| const bool already_split = block->is_split(); |
| if (should_split(block, size)) { |
| remaining = block; |
| |
| block = new Block(device, stream, size, &pool, block->ptr); |
| block->prev = remaining->prev; |
| if (block->prev) { |
| block->prev->next = block; |
| } |
| block->next = remaining; |
| |
| remaining->prev = block; |
| remaining->ptr = static_cast<char*>(remaining->ptr) + size; |
| remaining->size -= size; |
| bool inserted = pool.blocks.insert(remaining).second; |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); |
| |
| if (already_split) { |
| // An already-split inactive block is being shrunk by size bytes. |
| update_stat_array( |
| stats.inactive_split_bytes, -block->size, params.stat_types); |
| } else { |
| // A new split inactive block is being created from a previously unsplit |
| // block, size remaining->size bytes. |
| for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { |
| update_stat(stats.inactive_split_bytes[stat_type], remaining->size); |
| update_stat(stats.inactive_split[stat_type], 1); |
| }); |
| } |
| } else if (already_split) { |
| // An already-split block is becoming active |
| for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { |
| update_stat(stats.inactive_split_bytes[stat_type], -block->size); |
| update_stat(stats.inactive_split[stat_type], -1); |
| }); |
| } |
| |
| block->allocated = true; |
| bool inserted = active_blocks.insert(block).second; |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); |
| |
| for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { |
| update_stat(stats.allocation[stat_type], 1); |
| update_stat(stats.allocated_bytes[stat_type], block->size); |
| update_stat(stats.active[stat_type], 1); |
| update_stat(stats.active_bytes[stat_type], block->size); |
| }); |
| if (block->size >= CachingAllocatorConfig::max_split_size()) |
| update_stat(stats.oversize_allocations, 1); |
| |
| c10::reportMemoryUsageToProfiler( |
| block->ptr, |
| block->size, |
| stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, |
| stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, |
| c10::Device(c10::DeviceType::CUDA, device)); |
| |
| return block; |
| } |
| |
| void free(Block* block) { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| |
| block->allocated = false; |
| |
| // The following logic may modify the underlying Block (e.g., by merging it |
| // with neighbors), changing its size; store the originals for reporting. |
| auto orig_block_ptr = block->ptr; |
| auto orig_block_size = block->size; |
| |
| StatTypes stat_types = {false}; |
| stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
| stat_types[static_cast<size_t>(get_stat_type_for_pool(*(block->pool)))] = |
| true; |
| for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
| update_stat(stats.allocation[stat_type], -1); |
| update_stat(stats.allocated_bytes[stat_type], -block->size); |
| }); |
| if (block->size >= CachingAllocatorConfig::max_split_size()) |
| update_stat(stats.oversize_allocations, -1); |
| |
| if (!block->stream_uses.empty()) { |
| if (C10_UNLIKELY(captures_underway)) { |
| // It's forbidden to cudaEventQuery an event recorded during CUDA graph |
| // capture. We conservatively defer recording end-of-life events until |
| // the next call to process_events() (which won't happen until no |
| // captures are underway) |
| needs_events_deferred_until_no_capture.push_back(block); |
| } else { |
| insert_events(block); |
| } |
| } else { |
| free_block(block); |
| } |
| |
| c10::reportMemoryUsageToProfiler( |
| orig_block_ptr, |
| -orig_block_size, |
| stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, |
| stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, |
| c10::Device(c10::DeviceType::CUDA, block->device)); |
| } |
| |
| void* getBaseAllocation(Block* block, size_t* outSize) { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| while (block->prev) { |
| block = block->prev; |
| } |
| void* basePtr = block->ptr; |
| if (outSize) { |
| size_t size = 0; |
| while (block) { |
| size += block->size; |
| block = block->next; |
| } |
| *outSize = size; |
| } |
| return basePtr; |
| } |
| |
| void recordStream(Block* block, cuda::CUDAStream stream) { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| if (stream.stream() == block->stream) { |
| // ignore uses on the allocation stream, since those don't require any |
| // special synchronization |
| return; |
| } |
| block->stream_uses.insert(stream); |
| } |
| |
| /** set memory fraction to limit maximum allocated memory **/ |
| void setMemoryFraction(double fraction) { |
| size_t device_free; |
| size_t device_total; |
| C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); |
| allowed_memory_maximum = static_cast<size_t>(fraction * device_total); |
| set_fraction = true; |
| } |
| |
| /** returns cached blocks to the system allocator **/ |
| void emptyCache() { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| release_cached_blocks(); |
| } |
| |
| /** Retrieves info (total size + largest block) of the memory cache **/ |
| void cacheInfo(size_t* total, size_t* largest) { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| if (*largest == |
| 0) { // make an initial guess if a zero *largest is passed in |
| size_t tmp_bytes; |
| C10_CUDA_CHECK(cudaMemGetInfo( |
| largest, // Use free memory as an optimistic initial guess of *largest |
| &tmp_bytes)); |
| } |
| cache_info_aux(large_blocks, total, largest); |
| cache_info_aux(small_blocks, total, largest); |
| for (const auto& gp : graph_pools) { |
| cache_info_aux(gp.second->large_blocks, total, largest); |
| cache_info_aux(gp.second->small_blocks, total, largest); |
| } |
| } |
| |
| /** Returns a copy of the memory allocator stats **/ |
| DeviceStats getStats() { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| return stats; |
| } |
| |
| /** Resets the historical accumulation stats for the device **/ |
| void resetAccumulatedStats() { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| |
| for (const auto statType : |
| c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) { |
| reset_accumulated_stat(stats.allocation[statType]); |
| reset_accumulated_stat(stats.segment[statType]); |
| reset_accumulated_stat(stats.active[statType]); |
| reset_accumulated_stat(stats.inactive_split[statType]); |
| reset_accumulated_stat(stats.allocated_bytes[statType]); |
| reset_accumulated_stat(stats.reserved_bytes[statType]); |
| reset_accumulated_stat(stats.active_bytes[statType]); |
| reset_accumulated_stat(stats.inactive_split_bytes[statType]); |
| } |
| |
| stats.num_alloc_retries = 0; |
| stats.num_ooms = 0; |
| reset_accumulated_stat(stats.oversize_allocations); |
| reset_accumulated_stat(stats.oversize_segments); |
| } |
| |
| /** Resets the historical peak stats for the device **/ |
| void resetPeakStats() { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| |
| for (const auto statType : |
| c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) { |
| reset_peak_stat(stats.allocation[statType]); |
| reset_peak_stat(stats.segment[statType]); |
| reset_peak_stat(stats.active[statType]); |
| reset_peak_stat(stats.inactive_split[statType]); |
| reset_peak_stat(stats.allocated_bytes[statType]); |
| reset_peak_stat(stats.reserved_bytes[statType]); |
| reset_peak_stat(stats.active_bytes[statType]); |
| reset_peak_stat(stats.inactive_split_bytes[statType]); |
| } |
| reset_peak_stat(stats.oversize_allocations); |
| reset_peak_stat(stats.oversize_segments); |
| } |
| |
| /** Dump a complete snapshot of the memory held by the allocator. Potentially |
| * VERY expensive. **/ |
| std::vector<SegmentInfo> snapshot() const { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| |
| std::vector<SegmentInfo> result; |
| const auto all_blocks = get_all_blocks(); |
| |
| for (const Block* const head_block : all_blocks) { |
| if (head_block->prev != nullptr) { |
| continue; |
| } |
| result.emplace_back(); |
| SegmentInfo& segment_info = result.back(); |
| segment_info.device = head_block->device; |
| segment_info.address = reinterpret_cast<int64_t>(head_block->ptr); |
| segment_info.is_large = (!head_block->pool->is_small); |
| |
| const Block* block = head_block; |
| while (block != nullptr) { |
| segment_info.blocks.emplace_back(); |
| BlockInfo& block_info = segment_info.blocks.back(); |
| |
| block_info.size = block->size; |
| block_info.allocated = block->allocated; |
| block_info.active = block->allocated || (block->event_count > 0) || |
| !block->stream_uses.empty(); |
| |
| segment_info.total_size += block_info.size; |
| if (block_info.allocated) { |
| segment_info.allocated_size += block_info.size; |
| } |
| if (block_info.active) { |
| segment_info.active_size += block_info.size; |
| } |
| |
| block = block->next; |
| } |
| } |
| |
| std::sort( |
| result.begin(), |
| result.end(), |
| [](const SegmentInfo& a, const SegmentInfo& b) { |
| return a.address < b.address; |
| }); |
| |
| return result; |
| } |
| |
| // This function takes the size and the number of divisions and rounds the |
| // size argument up to the nearest power-of-2 division. |
| // For example, to round up 1200 with 4 divisions: 1200 lies between 1024 and |
| // 2048, and 4 divisions between them give the values 1024, 1280, 1536, and |
| // 1792. So the function returns 1280 as the nearest power-of-2-division |
| // ceiling of 1200. |
| static size_t roundup_power2_next_division(size_t size, size_t divisions) { |
| if (C10_UNLIKELY(size <= 4 || divisions <= 1)) { |
| return size; |
| } |
| if (llvm::isPowerOf2_64(size)) { |
| return size; |
| } |
| |
| // Divide the space between consecutive powers of 2 into equal divisions. |
| // If the resulting step is zero, return the power-of-2 ceiling. |
| size_t power2_floor = llvm::PowerOf2Floor(size); |
| size_t power2_division = |
| power2_floor >> (63 - llvm::countLeadingZeros(divisions)); |
| if (C10_UNLIKELY(power2_division == 0)) { |
| return (power2_floor << 1); |
| } |
| size_t round_size_floor = size & (~(power2_division - 1)); |
| return (round_size_floor == size) ? size |
| : round_size_floor + power2_division; |
| } |
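| |
| // Worked trace of the arithmetic above for size = 1200, divisions = 4 |
| // (matching the example in the comment preceding this function): |
| //   power2_floor     = 1024 |
| //   power2_division  = 1024 >> (63 - countLeadingZeros(4)) = 1024 >> 2 = 256 |
| //   round_size_floor = 1200 & ~(256 - 1) = 1024 |
| //   result           = 1024 + 256 = 1280 |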
| |
| static size_t round_size(size_t size) { |
| if (size < kMinBlockSize) { |
| return kMinBlockSize; |
| } else { |
| auto divisions = CachingAllocatorConfig::roundup_power2_divisions(); |
| if (divisions > 0 && size > (kMinBlockSize * divisions)) { |
| return roundup_power2_next_division(size, divisions); |
| } else { |
| return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); |
| } |
| } |
| } |
| |
| // See Note [Interaction with CUDA graph capture] |
| |
| // Called by CUDAGraph::capture_begin |
| void notifyCaptureBegin(CaptureId_t graph_id, MempoolId_t mempool_id) { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| captures_underway++; |
| auto it = graph_pools.find(mempool_id); |
| if (it == graph_pools.end()) { |
| // mempool_id does not reference an existing pool. Make a new pool for |
| // this capture. |
| graph_pools.emplace(mempool_id, std::make_unique<PrivatePool>()); |
| } else { |
| // mempool_id references an existing pool, which the current capture will |
| // share. Check this pool is live (at least one other capture already |
| // references it). |
| TORCH_INTERNAL_ASSERT(it->second->use_count > 0); |
| it->second->use_count++; |
| } |
| // Maps this graph_id to mempool_id and makes sure this graph_id wasn't |
| // somehow assigned a mempool_id already. Keeps essential effect (insert) |
| // out of macro. |
| bool inserted = capture_to_pool_map.insert({graph_id, mempool_id}).second; |
| TORCH_INTERNAL_ASSERT(inserted); |
| } |
| |
| // Called by CUDAGraph::capture_end |
| void notifyCaptureEnd(CaptureId_t graph_id) { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| captures_underway--; |
| auto it = capture_to_pool_map.find(graph_id); |
| TORCH_INTERNAL_ASSERT(it != capture_to_pool_map.end()); |
| capture_to_pool_map.erase(it); |
| } |
| |
| // Called by CUDAGraph::reset |
| void notifyCaptureDestroy(MempoolId_t mempool_id) { |
| std::lock_guard<std::recursive_mutex> lock(mutex); |
| // The instantiated cudaGraphExec_t has been destroyed. We can't blindly |
| // delete and cudaFree the mempool its capture used, because |
| // 1. other graph(s) might share the same pool |
| // 2. the user might still hold references to output tensors allocated |
| // during capture. |
| // To handle 1 and 2, we track the number of graphs using this particular |
| // mempool. When the count reaches 0, we tell release_cached_blocks it may now |
| // cudaFree blocks from this graph's pool when it discovers they're unused |
| // (unsplit). |
| auto it = graph_pools.find(mempool_id); |
| TORCH_INTERNAL_ASSERT(it != graph_pools.end()); |
| auto uc = --(it->second->use_count); |
| TORCH_INTERNAL_ASSERT(uc >= 0); |
| if (uc == 0) { |
| // Allows free_cached_blocks to begin cudaFreeing this pool's memory, |
| // and makes sure this pool wasn't somehow made freeable already. |
| bool inserted = |
| graph_pools_freeable.insert({mempool_id, it->second.get()}).second; |
| TORCH_INTERNAL_ASSERT(inserted); |
| } |
| } |
| |
| private: |
| // Private methods do not acquire the allocator mutex; callers must hold it. |
| |
| std::vector<const Block*> get_all_blocks() const { |
| std::vector<const Block*> blocks; |
| blocks.insert( |
| blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end()); |
| blocks.insert( |
| blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end()); |
| for (const auto& gp : graph_pools) { |
| blocks.insert( |
| blocks.end(), |
| gp.second->small_blocks.blocks.begin(), |
| gp.second->small_blocks.blocks.end()); |
| blocks.insert( |
| blocks.end(), |
| gp.second->large_blocks.blocks.begin(), |
| gp.second->large_blocks.blocks.end()); |
| } |
| blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end()); |
| return blocks; |
| } |
| |
| /** moves a block into a pool of cached free blocks */ |
| void free_block(Block* block) { |
| TORCH_INTERNAL_ASSERT( |
| !block->allocated && block->event_count == 0 && |
| block->stream_uses.empty()); |
| |
| size_t original_block_size = block->size; |
| |
| auto& pool = *block->pool; |
| int64_t net_change_inactive_split_blocks = 0; |
| int64_t net_change_inactive_split_size = 0; |
| |
| const std::array<Block*, 2> merge_candidates = {block->prev, block->next}; |
| for (Block* merge_candidate : merge_candidates) { |
| const int64_t subsumed_size = |
| try_merge_blocks(block, merge_candidate, pool); |
| if (subsumed_size > 0) { |
| net_change_inactive_split_blocks -= 1; |
| net_change_inactive_split_size -= subsumed_size; |
| } |
| } |
| |
| active_blocks.erase(block); |
| // Makes sure the Block* isn't already present in the pool we're freeing it |
| // back into. |
| bool inserted = pool.blocks.insert(block).second; |
| TORCH_INTERNAL_ASSERT(inserted); |
| |
| if (block->is_split()) { |
| net_change_inactive_split_blocks += 1; |
| net_change_inactive_split_size += block->size; |
| } |
| |
| StatTypes stat_types = {false}; |
| stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
| stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true; |
| for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
| update_stat( |
| stats.inactive_split[stat_type], net_change_inactive_split_blocks); |
| update_stat( |
| stats.inactive_split_bytes[stat_type], |
| net_change_inactive_split_size); |
| update_stat(stats.active[stat_type], -1); |
| update_stat(stats.active_bytes[stat_type], -original_block_size); |
| }); |
| } |
| |
| /** combine previously split blocks. returns the size of the subsumed block, |
| * or 0 on failure. */ |
| size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) { |
| if (!src || src->allocated || src->event_count > 0 || |
| !src->stream_uses.empty()) { |
| return 0; |
| } |
| |
| AT_ASSERT(dst->is_split() && src->is_split()); |
| |
| if (dst->prev == src) { |
| dst->ptr = src->ptr; |
| dst->prev = src->prev; |
| if (dst->prev) { |
| dst->prev->next = dst; |
| } |
| } else { |
| dst->next = src->next; |
| if (dst->next) { |
| dst->next->prev = dst; |
| } |
| } |
| |
| const size_t subsumed_size = src->size; |
| dst->size += subsumed_size; |
| auto erased = pool.blocks.erase(src); |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1); |
| delete src; |
| |
| return subsumed_size; |
| } |
| |
| BlockPool& get_pool(size_t size, cudaStream_t stream) { |
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
| // captures_underway is a conservative guess that the current stream may be |
| // capturing. It's only > 0 if some thread has begun and not yet ended a |
| // capture, so it's usually 0, and we can short-circuit |
| // cudaStreamCaptureStatus (which does a TLS lookup). |
| if (C10_UNLIKELY(captures_underway)) { |
| CaptureId_t id; |
| cudaStreamCaptureStatus status; |
| C10_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id)); |
| if (status != cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) { |
| TORCH_INTERNAL_ASSERT( |
| status != |
| cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated); |
| // Retrieves the private pool assigned to this capture. |
| auto it0 = capture_to_pool_map.find(id); |
| TORCH_INTERNAL_ASSERT(it0 != capture_to_pool_map.end()); |
| auto it1 = graph_pools.find(it0->second); |
| TORCH_INTERNAL_ASSERT(it1 != graph_pools.end()); |
| if (size <= kSmallSize) { |
| return it1->second->small_blocks; |
| } else { |
| return it1->second->large_blocks; |
| } |
| } |
| } |
| #endif |
| if (size <= kSmallSize) { |
| return small_blocks; |
| } else { |
| return large_blocks; |
| } |
| } |
| |
| StatType get_stat_type_for_pool(const BlockPool& pool) { |
| return pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL; |
| } |
| |
| bool should_split(const Block* block, size_t size) { |
| size_t remaining = block->size - size; |
| if (block->pool->is_small) { |
| return remaining >= kMinBlockSize; |
| } else { |
| return (size < CachingAllocatorConfig::max_split_size()) && |
| (remaining > kSmallSize); |
| } |
| } |
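| |
| // Illustrative split decisions (assuming the default, effectively unlimited |
| // max_split_size()): |
| //   small pool: 2 MiB block, 512 KiB request   -> remainder 1.5 MiB >= 512 B -> split |
| //   large pool: 20 MiB block, 3 MiB request    -> remainder 17 MiB > 1 MiB   -> split |
| //   large pool: 20 MiB block, 19.5 MiB request -> remainder 0.5 MiB <= 1 MiB -> keep whole |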
| |
| static size_t get_allocation_size(size_t size) { |
| if (size <= kSmallSize) { |
| return kSmallBuffer; |
| } else if (size < kMinLargeAlloc) { |
| return kLargeBuffer; |
| } else { |
| return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); |
| } |
| } |
| |
| bool get_free_block(AllocParams& p) { |
| BlockPool& pool = *p.pool; |
| |
| if (C10_UNLIKELY( |
| set_fraction && |
| CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) { |
| // Track block reuse interval only when garbage collection is enabled. |
| for (auto& b : pool.blocks) { |
| ++b->gc_count; |
| } |
| } |
| auto it = pool.blocks.lower_bound(&p.search_key); |
| if (it == pool.blocks.end() || (*it)->stream != p.stream()) |
| return false; |
| // Do not return an oversized block for a large request |
| if ((p.size() < CachingAllocatorConfig::max_split_size()) && |
| ((*it)->size >= CachingAllocatorConfig::max_split_size())) |
| return false; |
| // Allow oversized block size to be rounded up but within a limit |
| if ((p.size() >= CachingAllocatorConfig::max_split_size()) && |
| ((*it)->size >= p.size() + kLargeBuffer)) |
| return false; |
| p.block = *it; |
| (*it)->gc_count = 0; // Denote this block has been used |
| pool.blocks.erase(it); |
| return true; |
| } |
| |
| bool trigger_free_memory_callbacks(AllocParams& p) { |
| bool freed_memory = false; |
| for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) { |
| freed_memory |= |
| FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute(); |
| } |
| return freed_memory; |
| } |
| |
| void garbage_collect_cached_blocks() { |
| // Free unused cached blocks to reclaim GPU memory. |
| // Unlike release_cached_blocks(), this does not enforce synchronization and |
| // therefore incurs less overhead. |
| |
| size_t gc_threshold = static_cast<size_t>( |
| CachingAllocatorConfig::garbage_collection_threshold() * |
| allowed_memory_maximum); |
| // No need to trigger GC yet |
| if (total_allocated_memory <= gc_threshold) { |
| return; |
| } |
| const auto target_size = total_allocated_memory - gc_threshold; |
| size_t gc_reclaimed = 0; |
| |
| // Calculate the total age of the free-able blocks. We'll use it later to |
| // get "avg age" threshold. |
| double total_age = 0.0; |
| int freeable_block_count = 0; |
| for (auto& b : large_blocks.blocks) { |
| if (!b->is_split()) { |
| total_age += b->gc_count; |
| ++freeable_block_count; |
| } |
| } |
| // No free-able blocks? |
| if (freeable_block_count == 0) { |
| return; |
| } |
| |
| // Repeat GC until we reach reclaim > target size. |
| bool block_freed = true; |
| while (gc_reclaimed < target_size && block_freed == true && |
| freeable_block_count > 0) { |
| // Free blocks exceeding this age threshold first. |
| double age_threshold = total_age / freeable_block_count; |
| // Stop iteration if we can no longer free a block. |
| block_freed = false; |
| |
| // Free blocks of > avg age. Don't stop upon reaching the target_size, |
| // we don't want this GC to be triggered frequently. |
| auto it = large_blocks.blocks.begin(); |
| while (it != large_blocks.blocks.end()) { |
| Block* block = *it; |
| ++it; |
| if (!block->is_split() && block->gc_count >= age_threshold) { |
| block_freed = true; |
| gc_reclaimed += block->size; |
| total_age -= block->gc_count; // Decrement the age |
| freeable_block_count--; // One less block that can be freed |
| release_block(block); |
| } |
| } |
| } |
| } |
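| |
| // Numeric sketch of the trigger above (hypothetical values, assuming a |
| // memory fraction has been set): |
| //   garbage_collection_threshold = 0.8, allowed_memory_maximum = 10 GiB |
| //   gc_threshold = 0.8 * 10 GiB = 8 GiB |
| //   total_allocated_memory = 9 GiB -> target_size = 9 GiB - 8 GiB = 1 GiB |
| //   Unsplit large blocks whose gc_count is at least the running average age |
| //   are released until at least 1 GiB has been reclaimed. |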
| |
| bool alloc_block(AllocParams& p, bool isRetry) { |
| // Defensively checks for preexisting CUDA error state. |
| C10_CUDA_CHECK(cudaGetLastError()); |
| |
| size_t size = p.alloc_size; |
| void* ptr; |
| |
| if (isRetry) { |
| stats.num_alloc_retries += 1; |
| } |
| |
| if (set_fraction && |
| total_allocated_memory + size > allowed_memory_maximum) { |
| p.err = cudaErrorMemoryAllocation; |
| return false; |
| } else { |
| p.err = cudaMallocMaybeCapturing(&ptr, size); |
| if (p.err != cudaSuccess) { |
| if (p.err == cudaErrorMemoryAllocation) { |
| // If this is the first attempt (!isRetry), we can forgive and clear |
| // CUDA's internal error state. |
| // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH |
| // will take over to throw a helpful exception. The user can choose to |
| // catch the exception, free some stuff in their script, and attempt |
| // their allocation again. In this case, we can also forgive and clear |
| // CUDA's internal error state. |
| cudaGetLastError(); |
| } else { |
| // If the error's unrelated to memory allocation, we should throw |
| // immediately. |
| C10_CUDA_CHECK(p.err); |
| } |
| return false; |
| } |
| } |
| |
| if (p.pool->owner_PrivatePool) { |
| // The block is for a CUDA graph's PrivatePool. |
| p.pool->owner_PrivatePool->cudaMalloc_count++; |
| } |
| |
| total_allocated_memory += size; |
| p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); |
| for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { |
| update_stat(stats.segment[stat_type], 1); |
| update_stat(stats.reserved_bytes[stat_type], size); |
| }); |
| if (size >= CachingAllocatorConfig::max_split_size()) |
| update_stat(stats.oversize_segments, 1); |
| |
| // p.block came from new, not cudaMalloc. It should not be nullptr here. |
| TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); |
| return true; |
| } |
| |
| /** Free one or more oversize blocks to the system allocator, but only enough |
| * to satisfy the target size **/ |
| bool release_available_cached_blocks(const AllocParams& p) { |
| if (CachingAllocatorConfig::max_split_size() == |
| std::numeric_limits<size_t>::max()) |
| return false; |
| BlockPool& pool = *p.pool; |
| Block key = p.search_key; |
| key.size = (key.size < CachingAllocatorConfig::max_split_size()) |
| ? CachingAllocatorConfig::max_split_size() |
| : key.size; |
| auto it = pool.blocks.lower_bound(&key); |
| if (it == pool.blocks.end() || (*it)->stream != p.stream()) { |
| // No single block is large enough; free multiple oversize blocks, |
| // starting with the largest |
| if (it == pool.blocks.begin()) |
| return false; |
| size_t totalReleased = 0; |
| --it; // Back up one item. Now on the largest block for the correct |
| // stream |
| while ((totalReleased < key.size) && |
| ((*it)->size >= CachingAllocatorConfig::max_split_size()) && |
| ((*it)->stream == p.stream())) { |
| auto cur = it; |
| totalReleased += (*it)->size; |
| if (it != pool.blocks.begin()) { |
| --it; |
| release_block(*cur); |
| } else { |
| release_block(*cur); |
| break; |
| } |
| } |
| if (totalReleased < key.size) |
| return false; |
| } else { |
| release_block(*it); |
| } |
| return true; |
| } |
| |
| bool release_cached_blocks() { |
| // First ensure that all blocks that can't currently be allocated due to |
| // outstanding events are returned to the pool. |
| synchronize_and_free_events(); |
| |
| // Free all non-split cached blocks to system allocator |
| release_blocks(large_blocks); |
| release_blocks(small_blocks); |
| |
| for (auto it = graph_pools_freeable.begin(); |
| it != graph_pools_freeable.end();) { |
| // See notifyCaptureDestroy for the strategy here. |
| TORCH_INTERNAL_ASSERT(it->second->use_count == 0); |
| release_blocks(it->second->small_blocks); |
| release_blocks(it->second->large_blocks); |
| if (it->second->cudaMalloc_count == 0) { |
| auto erase_count = graph_pools.erase(it->first); |
| TORCH_INTERNAL_ASSERT(erase_count == 1); |
| it = graph_pools_freeable.erase(it); |
| } else { |
| ++it; |
| } |
| } |
| |
| return true; |
| } |
| |
| void release_block(Block* block) { |
| C10_CUDA_CHECK(cudaFree((void*)block->ptr)); |
| total_allocated_memory -= block->size; |
| |
| auto* pool = block->pool; |
| if (pool->owner_PrivatePool) { |
| // The cudaFreed block belonged to a CUDA graph's PrivatePool. |
| TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->cudaMalloc_count > 0); |
| pool->owner_PrivatePool->cudaMalloc_count--; |
| } |
| |
| StatTypes stat_types = {false}; |
| stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
| stat_types[static_cast<size_t>(get_stat_type_for_pool(*pool))] = true; |
| for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
| update_stat(stats.segment[stat_type], -1); |
| update_stat(stats.reserved_bytes[stat_type], -block->size); |
| }); |
| if (block->size >= CachingAllocatorConfig::max_split_size()) |
| update_stat(stats.oversize_segments, -1); |
| |
| pool->blocks.erase(block); |
| delete block; |
| } |
| |
| void release_blocks(BlockPool& pool) { |
| // Frees all non-split blocks |
| auto it = pool.blocks.begin(); |
| while (it != pool.blocks.end()) { |
| Block* block = *it; |
| ++it; |
| if (!block->prev && !block->next) { |
| release_block(block); |
| } |
| } |
| } |
| |
| cudaEvent_t create_event_internal() { |
| cudaEvent_t event; |
| C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); |
| return event; |
| } |
| |
| void free_event_internal(cudaEvent_t event) { |
| C10_CUDA_CHECK(cudaEventDestroy(event)); |
| } |
| |
| void synchronize_and_free_events() { |
| // Synchronize on outstanding events and then free associated blocks. |
| |
| // This function syncs, so capture should not be underway. Might as well |
| // make sure capture-deferred end of life events get processed too. |
| TORCH_INTERNAL_ASSERT(captures_underway == 0); |
| insert_events_deferred_until_no_capture(); |
| |
| for (auto& st : cuda_events) { |
| for (auto& e : st.second) { |
| cudaEvent_t event = e.first; |
| Block* block = e.second; |
| |
| C10_CUDA_CHECK(cudaEventSynchronize(event)); |
| free_event_internal(event); |
| |
| block->event_count--; |
| if (block->event_count == 0) { |
| free_block(block); |
| } |
| } |
| } |
| |
| cuda_events.clear(); |
| } |
| |
| void insert_events(Block* block) { |
| int prev_device; |
| C10_CUDA_CHECK(cudaGetDevice(&prev_device)); |
| |
| stream_set streams(std::move(block->stream_uses)); |
| AT_ASSERT(block->stream_uses.empty()); |
| for (auto& stream : streams) { |
| C10_CUDA_CHECK(cudaSetDevice(stream.device_index())); |
| |
| cudaEvent_t event = create_event_internal(); |
| C10_CUDA_CHECK(cudaEventRecord(event, stream.stream())); |
| |
| block->event_count++; |
| cuda_events[stream].emplace_back(event, block); |
| } |
| |
| C10_CUDA_CHECK(cudaSetDevice(prev_device)); |
| } |
| |
| void insert_events_deferred_until_no_capture() { |
| if (C10_UNLIKELY(needs_events_deferred_until_no_capture.size() > 0)) { |
| for (auto* block : needs_events_deferred_until_no_capture) { |
| TORCH_INTERNAL_ASSERT(!block->stream_uses.empty()); |
| insert_events(block); |
| } |
| needs_events_deferred_until_no_capture.clear(); |
| } |
| } |
| |
| void process_events() { |
| insert_events_deferred_until_no_capture(); |
| |
| // Process outstanding cudaEvents. Events that are completed are |
| // removed from the queue, and the 'event_count' for the |
| // corresponding allocation is decremented. We maintain a separate |
| // list of events per stream to avoid head-of-line delays if one |
| // or more streams has long-running operations. |
| for (auto it = cuda_events.begin(); it != cuda_events.end();) { |
| while (!it->second.empty()) { |
| auto& e = it->second.front(); |
| cudaEvent_t event = e.first; |
| Block* block = e.second; |
| |
| cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(event)); |
| if (err == cudaErrorNotReady) { |
| // ignore and clear the error if not ready |
| cudaGetLastError(); |
| break; |
| } else if (err != cudaSuccess) { |
| C10_CUDA_CHECK(err); |
| } |
| |
| free_event_internal(event); |
| |
| block->event_count--; |
| if (block->event_count == 0) { |
| free_block(block); |
| } |
| it->second.pop_front(); |
| } |
| |
| if (it->second.empty()) { |
| it = cuda_events.erase(it); |
| } else { |
| it++; |
| } |
| } |
| } |
| |
| // Accumulates sizes of all memory blocks for given device in given pool |
| void cache_info_aux(const BlockPool& pool, size_t* total, size_t* largest) { |
| for (const auto& block : pool.blocks) { |
| const auto blocksize = block->size; |
| *total += blocksize; |
| if (blocksize > *largest) { |
| *largest = blocksize; |
| } |
| } |
| } |
| }; |
| |
| class THCCachingAllocator { |
| private: |
| std::mutex mutex; |
| |
| // allocated blocks by device pointer |
| ska::flat_hash_map<void*, Block*> allocated_blocks; |
| |
| // lock around calls to cudaFree (to prevent deadlocks with NCCL) |
| mutable std::mutex cuda_free_mutex; |
| |
| void add_allocated_block(Block* block) { |
| std::lock_guard<std::mutex> lock(mutex); |
| allocated_blocks[block->ptr] = block; |
| } |
| |
| public: |
| std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator; |
| |
| std::mutex* getCudaFreeMutex() const { |
| return &cuda_free_mutex; |
| } |
| |
| Block* get_allocated_block(void* ptr, bool remove = false) { |
| std::lock_guard<std::mutex> lock(mutex); |
| auto it = allocated_blocks.find(ptr); |
| if (it == allocated_blocks.end()) { |
| return nullptr; |
| } |
| Block* block = it->second; |
| if (remove) { |
| allocated_blocks.erase(it); |
| } |
| return block; |
| } |
| |
| void init(int device_count) { |
| const auto size = static_cast<int64_t>(device_allocator.size()); |
| if (size < device_count) { |
| device_allocator.resize(device_count); |
| for (const auto i : c10::irange(size, device_count)) { |
| device_allocator[i] = std::make_unique<DeviceCachingAllocator>(); |
| } |
| } |
| } |
| |
| /** allocates a block which is safe to use from the provided stream */ |
| void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) { |
| TORCH_INTERNAL_ASSERT( |
| 0 <= device && static_cast<size_t>(device) < device_allocator.size(), |
| "Allocator not initialized for device ", |
| device, |
| ": did you call init?"); |
| Block* block = device_allocator[device]->malloc(device, size, stream); |
| add_allocated_block(block); |
| *devPtr = (void*)block->ptr; |
| } |
| |
| void free(void* ptr) { |
| if (!ptr) { |
| return; |
| } |
| Block* block = get_allocated_block(ptr, true /* remove */); |
| if (!block) { |
| TORCH_CHECK(false, "invalid device pointer: ", ptr); |
| } |
| device_allocator[block->device]->free(block); |
| } |
| |
| void setMemoryFraction(double fraction, int device) { |
| TORCH_INTERNAL_ASSERT( |
| 0 <= device && static_cast<size_t>(device) < device_allocator.size(), |
| "Allocator not initialized for device ", |
| device, |
| ": did you call init?"); |
| TORCH_INTERNAL_ASSERT( |
| 0 <= fraction && fraction <= 1, |
| "invalid fraction:", |
| fraction, |
| ". Please set within (0, 1)."); |
| int activated_device; |
| C10_CUDA_CHECK(cudaGetDevice(&activated_device)); |
| if (activated_device != device) { |
| C10_CUDA_CHECK(cudaSetDevice(device)); |
| } |
| device_allocator[device]->setMemoryFraction(fraction); |
| } |
| |
| void emptyCache() { |
| for (auto& da : device_allocator) |
| da->emptyCache(); |
| } |
| |
| void* getBaseAllocation(void* ptr, size_t* outSize) { |
| Block* block = get_allocated_block(ptr); |
| if (!block) { |
| TORCH_CHECK(false, "invalid device pointer: ", ptr); |
| } |
| return device_allocator[block->device]->getBaseAllocation(block, outSize); |
| } |
| |
| void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) { |
| // An empty tensor's storage().data() might be a null ptr. As there are no |
| // blocks associated with such tensors, it is fine to do nothing here. |
| if (!ptr.get()) { |
| return; |
| } |
| |
| // If a tensor was not allocated by this instance, simply skip it. |
| // This usually happens when CUDA tensors are shared across processes; |
| // we have implemented a reference-counting-based sharing mechanism to |
| // guarantee that tensors won't be accidentally freed by one process while |
| // they are still being used in another. |
| if (ptr.get_deleter() != &raw_delete) |
| return; |
| |
| Block* block = get_allocated_block(ptr.get()); |
| // block must not be null by the time we reach here |
| TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found"); |
| device_allocator[block->device]->recordStream(block, stream); |
| } |
| |
| std::vector<SegmentInfo> snapshot() { |
| std::vector<SegmentInfo> result; |
| for (auto& da : device_allocator) { |
| auto snap = da->snapshot(); |
| result.insert(result.end(), snap.begin(), snap.end()); |
| } |
| |
| return result; |
| } |
| }; |
| |
| THCCachingAllocator caching_allocator; |
| |
| // Returns whether to force all allocations to bypass the caching allocator and |
| // go straight to cudaMalloc. This setting is useful when debugging GPU memory |
| // errors, since the caching allocator foils cuda-memcheck. |
| bool forceUncachedAllocator() { |
| static bool force_uncached = |
| getenv("PYTORCH_NO_CUDA_MEMORY_CACHING") != nullptr; |
| return force_uncached; |
| } |
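| |
| // Example (shell, illustrative; the script name is hypothetical): bypass the |
| // caching allocator while running under cuda-memcheck: |
| //   PYTORCH_NO_CUDA_MEMORY_CACHING=1 cuda-memcheck python my_script.py |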
| |
| static void uncached_delete(void* ptr) { |
| C10_CUDA_CHECK(cudaFree(ptr)); |
| } |
| |
| // NB: I decided not to fold this into THCCachingAllocator, because the latter |
| // has a lot more methods and it wasn't altogether clear that they should |
| // actually be publicly exposed |
| struct CudaCachingAllocator : public Allocator { |
| DataPtr allocate(size_t size) const override { |
| constexpr size_t one_exa_bytes = 1152921504606846976ULL; |
| TORCH_CHECK_WITH( |
| CUDAOutOfMemoryError, |
| size < one_exa_bytes, |
| "CUDA out of memory. Tried to allocate more than 1EB memory."); |
| int device; |
| C10_CUDA_CHECK(cudaGetDevice(&device)); |
| void* r = nullptr; |
| if (forceUncachedAllocator()) { |
| // Deliberately don't use cudaMallocMaybeCapturing here, to force an error |
| // if someone tries to use forceUncachedAllocator while capturing. |
| C10_CUDA_CHECK(cudaMalloc(&r, size)); |
| return {r, r, &uncached_delete, Device(DeviceType::CUDA, device)}; |
| } |
| if (size != 0) { |
| caching_allocator.malloc( |
| &r, device, size, cuda::getCurrentCUDAStream(device)); |
| } |
| return {r, r, &raw_delete, Device(DeviceType::CUDA, device)}; |
| } |
| DeleterFnPtr raw_deleter() const override { |
| if (forceUncachedAllocator()) { |
| return &uncached_delete; |
| } else { |
| return &raw_delete; |
| } |
| } |
| }; |
| |
| CudaCachingAllocator device_allocator; |
| |
| Allocator* get(void) { |
| return &device_allocator; |
| } |
| |
| void init(int device_count) { |
| caching_allocator.init(device_count); |
| } |
| |
| void setMemoryFraction(double fraction, int device) { |
| caching_allocator.setMemoryFraction(fraction, device); |
| } |
| |
| void emptyCache(void) { |
| caching_allocator.emptyCache(); |
| } |
| |
| void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) { |
| caching_allocator.device_allocator[dev_id]->cacheInfo( |
| cachedAndFree, largestBlock); |
| } |
| |
| void* getBaseAllocation(void* ptr, size_t* size) { |
| return caching_allocator.getBaseAllocation(ptr, size); |
| } |
| |
| void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) { |
| caching_allocator.recordStream(ptr, stream); |
| } |
| |
| std::mutex* getFreeMutex() { |
| return caching_allocator.getCudaFreeMutex(); |
| } |
| |
| static inline void assertValidDevice(int device) { |
| const auto device_num = caching_allocator.device_allocator.size(); |
| TORCH_CHECK( |
| 0 <= device && device < static_cast<int64_t>(device_num), |
| "Invalid device argument ", |
| device, |
| ": did you call init?"); |
| } |
| |
| DeviceStats getDeviceStats(int device) { |
| assertValidDevice(device); |
| return caching_allocator.device_allocator[device]->getStats(); |
| } |
| |
| void resetAccumulatedStats(int device) { |
| assertValidDevice(device); |
| caching_allocator.device_allocator[device]->resetAccumulatedStats(); |
| } |
| |
| void resetPeakStats(int device) { |
| assertValidDevice(device); |
| caching_allocator.device_allocator[device]->resetPeakStats(); |
| } |
| |
| std::vector<SegmentInfo> snapshot() { |
| return caching_allocator.snapshot(); |
| } |
| |
| // CUDAGraph interactions |
| void notifyCaptureBegin( |
| int device, |
| CaptureId_t graph_id, |
| MempoolId_t mempool_id) { |
| assertValidDevice(device); |
| caching_allocator.device_allocator[device]->notifyCaptureBegin( |
| graph_id, mempool_id); |
| } |
| |
| void notifyCaptureEnd(int device, CaptureId_t graph_id) { |
| assertValidDevice(device); |
| caching_allocator.device_allocator[device]->notifyCaptureEnd(graph_id); |
| } |
| |
| void notifyCaptureDestroy(int device, MempoolId_t mempool_id) { |
| assertValidDevice(device); |
| caching_allocator.device_allocator[device]->notifyCaptureDestroy(mempool_id); |
| } |
| |
| // |
| // In CUDA IPC, when a sender process sends a tensor to a receiver process, |
| // getIpcDevPtr is called by the receiving process to map the CUDA memory from |
| // the sending process into its own address space. |
| // |
| // CUDA IPC only allows sharing one big memory block associated with a |
| // cudaIpcMemHandle_t, and it can be opened only **once** per context per |
| // process. There can be multiple types of storage in the same IPC mem block, |
| // so we must cache the device ptr to construct typed storages as they come. |
| // |
| // ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in the |
| // receiving process that can be used to access the memory block in the sender |
| // process. It only saves a weak_ptr to the device pointer in the map; the |
| // shared_ptr is used to reconstruct all storages in this cudaMalloc |
| // allocation, and the mapping is closed via cudaIpcCloseMemHandle when the |
| // reference count drops to 0. |
| // |
| namespace { |
| std::mutex IpcMutex; |
| ska::flat_hash_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr; |
| } // namespace |
| |
| std::shared_ptr<void> getIpcDevPtr(std::string handle) { |
| std::lock_guard<std::mutex> lock(IpcMutex); |
| |
| auto iter = ipcMemHandle_to_devptr.find(handle); |
| if (iter != ipcMemHandle_to_devptr.end()) { |
| auto devptr = iter->second.lock(); |
| if (devptr) |
| return devptr; |
| } |
| // This ipcMemHandle hasn't been opened, or already expired, open it to |
| // enable IPC access to that mem block. |
| void* dev = nullptr; |
| auto ipc_handle = reinterpret_cast<const cudaIpcMemHandle_t*>(handle.c_str()); |
| C10_CUDA_CHECK( |
| cudaIpcOpenMemHandle(&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess)); |
| // The devPtr has to be deleted on the same device it was created on. |
| int curr_device; |
| C10_CUDA_CHECK(cudaGetDevice(&curr_device)); |
| auto sp = std::shared_ptr<void>(dev, [handle, curr_device](void* ptr) { |
| cuda::CUDAGuard device_guard(curr_device); |
| std::lock_guard<std::mutex> deleter_lock(IpcMutex); |
| C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); |
| ipcMemHandle_to_devptr.erase(handle); |
| }); |
| std::weak_ptr<void> wp = sp; |
| // To eliminate an additional search, we can use the hinted insert(). |
| // It doesn't overwrite when the key already exists (ptr expired), |
| // but since the deleter for sp erases the entry, |
| // this is safe to do now. |
| ipcMemHandle_to_devptr.insert(iter, {handle, wp}); |
| |
| return sp; |
| } |
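| |
| // Receiver-side usage sketch (illustrative; `handle_bytes` is a hypothetical |
| // std::string holding the raw cudaIpcMemHandle_t received from the sender, |
| // and `offset_in_block` is a hypothetical storage offset): |
| //   std::shared_ptr<void> base = getIpcDevPtr(handle_bytes); |
| //   void* storage_ptr = static_cast<char*>(base.get()) + offset_in_block; |
| //   // Keep `base` alive as long as any storage uses the mapping; when the |
| //   // last shared_ptr dies, cudaIpcCloseMemHandle closes the handle. |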
| |
| void* raw_alloc(size_t nbytes) { |
| if (nbytes == 0) { |
| return nullptr; |
| } |
| int device; |
| C10_CUDA_CHECK(cudaGetDevice(&device)); |
| void* r = nullptr; |
| caching_allocator.malloc( |
| &r, device, nbytes, cuda::getCurrentCUDAStream(device)); |
| return r; |
| } |
| |
| void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) { |
| if (nbytes == 0) { |
| return nullptr; |
| } |
| int device; |
| C10_CUDA_CHECK(cudaGetDevice(&device)); |
| void* r = nullptr; |
| caching_allocator.malloc(&r, device, nbytes, stream); |
| return r; |
| } |
| |
| void raw_delete(void* ptr) { |
| caching_allocator.free(ptr); |
| } |
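| |
| // Minimal usage sketch of the raw interface above (illustrative): |
| //   void* buf = c10::cuda::CUDACachingAllocator::raw_alloc(1 << 20); |
| //   // ... launch kernels on the current stream that use `buf` ... |
| //   c10::cuda::CUDACachingAllocator::raw_delete(buf); |
| // raw_alloc_with_stream behaves the same but records the given stream as the |
| // allocation stream. |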
| |
| } // namespace CUDACachingAllocator |
| |
| } // namespace cuda |
| } // namespace c10 |