blob: a6d6a261068d6332406f3f2adeaf2969607d82be [file] [log] [blame]
import torch
# Payload-type labels passed to ``cref.record_start`` for the
# "hook_future_metric" metric, distinguishing sparse from dense gradients.
RPC_SPARSE = "rpc_sparse"
RPC_DENSE = "rpc_dense"
def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper function that encodes a sparse COO tensor as a list of its
    indices, values, and size so it can be sent over RPC.

    The tensor is coalesced first, so duplicate indices are summed and the
    indices are sorted in the encoded form.

    Args:
        sparse_tensor (torch.Tensor): a sparse_coo_tensor to encode.

    Returns:
        list: ``[indices, values, size]`` of the coalesced tensor, where
        ``indices`` and ``values`` are tensors and ``size`` is a
        ``torch.Size``.
    """
    sparse_tensor = sparse_tensor.coalesce()
    return [sparse_tensor.indices(), sparse_tensor.values(), sparse_tensor.size()]
def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper function that rebuilds a coalesced sparse COO tensor from its
    list encoding (the inverse of ``sparse_tensor_to_rpc_format``).

    Args:
        sparse_rpc_format (list): ``[indices, values, size]`` describing a
            sparse_coo_tensor.

    Returns:
        torch.Tensor: the coalesced sparse_coo_tensor.
    """
    indices = sparse_rpc_format[0]
    values = sparse_rpc_format[1]
    size = sparse_rpc_format[2]
    rebuilt = torch.sparse_coo_tensor(indices, values, size)
    # Coalesce so duplicate indices are merged and indices come back sorted.
    return rebuilt.coalesce()
def process_bucket_with_remote_server(state, bucket):
    r"""
    Processes a gradient bucket passed by a DDP communication hook during
    ``.backward()`` by sending it to a remote server's ``average_gradient``
    over RPC.

    Both dense and sparse tensors are supported. A sparse bucket is encoded
    with ``sparse_tensor_to_rpc_format`` before sending, and a list reply is
    decoded with ``sparse_rpc_format_to_tensor``. The time from issuing the
    RPC until its future completes is recorded under "hook_future_metric".

    Args:
        state (object): maintains state during the training process; this
            function reads ``state.cref``, ``state.batch_number`` and
            ``state.get_key``.
        bucket (GradBucket): gradient bucket.

    Returns:
        Future: resolves to a one-element list holding the server's result,
        moved with ``.cuda(cref.rank)``.
    """
    cref = state.cref
    tensor = bucket.buffer()
    # When CUDA RPC is disabled, send the gradient from CPU memory.
    if not cref.use_cuda_rpc:
        tensor = tensor.cpu()
    sparse = tensor.is_sparse
    if sparse:
        tensor = sparse_tensor_to_rpc_format(tensor)
    b_index = bucket.get_index()
    server_args = [
        cref.server_rref,
        state.batch_number,
        b_index,
        tensor,
    ]
    key = state.get_key(b_index)
    cref.record_start(
        "hook_future_metric",
        key,
        RPC_SPARSE if sparse else RPC_DENSE,
    )
    fut = cref.server_rref.rpc_async().average_gradient(*server_args)

    def callback(completed_fut):
        cref.record_end("hook_future_metric", key)
        result = completed_fut.wait()
        # A list reply is the sparse RPC encoding; turn it back into a tensor.
        if isinstance(result, list):
            result = sparse_rpc_format_to_tensor(result)
        # NOTE(review): assumes cref.rank is a valid local CUDA device index;
        # confirm for multi-node runs where global rank != local device.
        return [result.cuda(cref.rank)]

    return fut.then(callback)