blob: f258775685bfe9a75d474ae1de1c0e37e0cd79ea [file] [log] [blame]
#pragma once
#include <exception>
#include "caffe2/core/blob.h"
#include <gloo/config.h>
#include <gloo/context.h>
#include <gloo/transport/device.h>
namespace caffe2 {
namespace gloo {
TORCH_API void signalFailure(Blob* status_blob, std::exception& exception);
// Options describing which Gloo transport device to create.
// Field order matters: callers may aggregate-initialize as
// createDeviceAttr{transport, interface}.
struct createDeviceAttr {
  // Transport backend name: "tcp" or "ibverbs".
  std::string transport{};
  // Interface to use, e.g. "eth0" (tcp) or "mlx5_0" (ibverbs).
  // An empty string leaves the choice to Gloo.
  std::string interface{};
};
// Creates a Gloo transport device configured by `attr`.
//
// Only the declaration is visible here; presumably the definition selects
// the backend from attr.transport ("tcp" or "ibverbs") and the interface from
// attr.interface — confirm against the out-of-line definition.
//
// `attr` is passed by value. The parameter carries no top-level `const`:
// in a declaration such a qualifier is not part of the function type, so it
// only adds noise (the definition is still free to declare it const).
TORCH_API std::shared_ptr<::gloo::transport::Device> createDevice(
    createDeviceAttr attr);
// Captures the parameters passed to Gloo.
struct GlooParameters {
std::shared_ptr<::gloo::Context> context;
std::vector<const void*> inputs;
std::vector<void*> outputs;
size_t size;
TypeMeta meta;
template <typename T>
std::vector<const T*> getInputs() {
std::vector<const T*> result;
result.reserve(inputs.size());
for (auto& input : inputs) {
result.push_back(reinterpret_cast<const T*>(input));
}
return result;
}
template <typename T>
std::vector<T*> getOutputs() {
std::vector<T*> result;
result.reserve(outputs.size());
for (auto& output : outputs) {
result.push_back(reinterpret_cast<T*>(output));
}
return result;
}
template <typename T>
T* getOutput() {
return reinterpret_cast<T*>(outputs[0]);
}
template <typename T>
bool IsType() const {
return meta.Match<T>();
}
bool operator==(GlooParameters const& other) const {
return context == other.context && inputs == other.inputs &&
outputs == other.outputs && size == other.size;
}
};
} // namespace gloo
} // namespace caffe2