// blob: 953cec8a1bc36e8550b026e49ef4d2b4fab76e75 [file] [log] [blame]
#pragma once
#ifdef USE_C10D_UCC
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
namespace c10d {
// Records one communication event on `_comms_tracer`, but only when
// `torch_ucc_config.enable_comms_logger` is true; when it is false none of the
// macro arguments are evaluated.
//
// The expansion refers to `torch_ucc_config` and `opTypeToString` by
// unqualified name, so both must be visible at the point of use.
//
//   _comms_tracer  pointer-like object exposing `recordComms(...)`
//   _work          object exposing `.get()`; the returned pointer value is
//                  forwarded as a `uintptr_t`
//   _opType        forwarded through `opTypeToString`
//   _rank, _comm_size, _inTensors, _outTensors
//                  forwarded to `recordComms` unchanged
//
// `_comms_tracer` and `_work` are parenthesized because they are followed by
// `->` / `.get()`, which would otherwise bind to only the tail of a compound
// argument (e.g. a ternary or `*ptr`). The remaining parameters sit in
// function-argument positions, where precedence cannot bite, and are left bare
// so that braced-list arguments keep working.
//
// Wrapped in `do { ... } while (0)` so the macro behaves as a single statement.
#define RECORD_COMMS_TRACE(                                                    \
    _comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \
  do {                                                                         \
    if (torch_ucc_config.enable_comms_logger) {                                \
      (_comms_tracer)                                                          \
          ->recordComms(                                                       \
              opTypeToString(_opType),                                         \
              reinterpret_cast<uintptr_t>((_work).get()),                      \
              _rank,                                                           \
              _comm_size,                                                      \
              _inTensors,                                                      \
              _outTensors);                                                    \
    }                                                                          \
  } while (0)
/// Interface for collecting communication traces.
///
/// Apart from getCommsTrace(), the member functions are only declared here;
/// their definitions are not part of this header.
class TORCH_API CommTraceLogger : public torch::CustomClassHolder {
 public:
  /// Marked "unused" by the original author.
  void setCurBlock(const std::string& name);

  /// Marked "unused" by the original author.
  void popBlock();

  /// Records root info where applicable, e.g. broadcast, gather, scatter.
  void recordOptionalInfo(int root = -1);

  /// Records the input/output splits of Alltoallv.
  ///
  /// NOTE(review): every parameter of both recordOptionalInfo overloads has a
  /// default, so a call with no arguments at all is ambiguous and will not
  /// compile; callers have to pass at least one argument.
  void recordOptionalInfo(
      const std::vector<int64_t>& outputSplitSizes = {},
      const std::vector<int64_t>& inputSplitSizes = {});

  /// Records the essential information about one communication.
  void recordComms(
      const std::string& collName,
      const uintptr_t workReq = 0,
      const int rank = -1,
      const int world_size = -1,
      const std::vector<at::Tensor>& inputTensors = {},
      const std::vector<at::Tensor>& outputTensor = {});

  /// Returns a mutable reference to the collected trace entries.
  std::vector<std::string>& getCommsTrace() {
    return comms_trace_;
  }

 private:
  // Trace entries handed out by getCommsTrace().
  std::vector<std::string> comms_trace_;
  // Marked "unused" by the original author.
  std::vector<std::string> curBlocks_;
  // Split sizes; presumably set by the Alltoallv recordOptionalInfo overload
  // -- confirm against the implementation file.
  std::vector<int64_t> curOutSplitSizes_;
  std::vector<int64_t> curInSplitSizes_;
  // Root value; starts at -1, the same value the root overload of
  // recordOptionalInfo uses as its default.
  int curRoot_{-1};
  // Counter starting at 0; how it is advanced is not visible in this header.
  unsigned long seqnum{0};
};
} // namespace c10d
#endif // USE_C10D_UCC