| from typing import List, Tuple |
| |
| from torch.distributed._shard.sharding_spec import ( |
| ShardMetadata, |
| ) |
| |
# Empty export list: a star-import of this module binds no names.
__all__: List[str] = []
| |
| |
def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ShardMetadata, current_shard: ShardMetadata
) -> List[Tuple[int, int, int, int]]:
    """
    Describe, per dimension, where ``saved_shard`` and ``current_shard`` overlap.

    The shards' ``shard_offsets`` and ``shard_sizes`` are walked in lockstep and
    one tuple is produced per dimension:

        (dimension, offset into `saved_shard`, offset into `current_shard`, length)

    Both offsets are measured from the start of the respective shard, so in
    every dimension at least one of them is 0 (the shard that starts later).

    The length is not clamped: for a dimension in which the two shards do not
    overlap it comes out as zero or negative. If the offset/size sequences
    differ in length, only the dimensions common to all of them are reported.
    """
    per_dim_extents = zip(
        saved_shard.shard_offsets,
        current_shard.shard_offsets,
        saved_shard.shard_sizes,
        current_shard.shard_sizes,
    )

    return [
        (
            dim,
            # The overlap begins at the later of the two shard starts; each
            # shard's local offset is how far that point lies past its own
            # start, which is 0 for the shard that starts later.
            max(0, current_start - saved_start),
            max(0, saved_start - current_start),
            # Overlap extent: earliest end minus latest start.
            min(saved_start + saved_extent, current_start + current_extent)
            - max(current_start, saved_start),
        )
        for dim, (
            saved_start,
            current_start,
            saved_extent,
            current_extent,
        ) in enumerate(per_dim_extents)
    ]