blob: a6afc0a8f64156c9a9dcb0f92927810d321c3db3 [file] [log] [blame]
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/custom_class.h>
namespace torch {
namespace jit {
class CUDAEvent;
// This class is a wrapper around c10::cuda::CUDAStream.
// It is needed because TorchBind does not support all of the argument types
// for c10::cuda::CUDAStream. For more details, please refer to
// c10/cuda/CUDAStream.h.
class CUDAStream final : public CustomClassHolder {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CUDAStream(
c10::optional<c10::Device> device = c10::nullopt,
int64_t priority = 0) {
constexpr int64_t PRIORITY_INDEX = 0;
c10::DeviceIndex device_index =
device.has_value() ? device->index() : c10::cuda::current_device();
stream_ = std::make_unique<c10::cuda::CUDAStream>(
c10::cuda::getStreamFromPool(priority < PRIORITY_INDEX, device_index));
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CUDAStream(c10::cuda::CUDAStream s) {
stream_ = std::make_unique<c10::cuda::CUDAStream>(s);
}
bool query() {
return stream_->query();
}
c10::intrusive_ptr<CUDAEvent> recordEvent(
c10::intrusive_ptr<CUDAEvent> event);
void synchronize() {
stream_->synchronize();
}
void waitEvent(c10::intrusive_ptr<CUDAEvent> event);
void waitStream(c10::intrusive_ptr<CUDAStream> stream);
/// Get the CUDA device index that this stream is associated with.
int64_t device_index() const {
return stream_->device_index();
}
/// Get the full Device that this stream is associated with. The Device
/// is guaranteed to be a CUDA device.
c10::Device device() const {
return stream_->device();
}
/// Return the stream ID corresponding to this particular stream.
int64_t id() const {
return stream_->id();
}
private:
std::unique_ptr<c10::cuda::CUDAStream> stream_;
friend class CUDAEvent;
};
// This class is a wrapper around at::cuda::CUDAStream.
// It is needed because TorchBind does not support all of the argument types
// for at::cuda::CUDAEvent. For more details, please refer to
// aten/src/ATen/cuda/CUDAEvent.h.
class CUDAEvent final : public CustomClassHolder {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CUDAEvent(
bool enable_timing = false,
bool blocking = false,
bool interprocess = false) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int flags = cudaEventDisableTiming;
if (enable_timing) {
flags = cudaEventDefault;
}
if (blocking) {
flags |= cudaEventBlockingSync;
}
if (interprocess) {
TORCH_CHECK(!enable_timing);
flags |= cudaEventInterprocess;
}
event_ = std::make_unique<at::cuda::CUDAEvent>(flags);
}
double elapsedTime(c10::intrusive_ptr<CUDAEvent> end) {
return event_->elapsed_time(*end->event_);
}
std::string ipcHandle() {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
cudaIpcEventHandle_t handle;
event_->ipc_handle(&handle);
std::string str_handle((const char*)&handle, sizeof(handle));
return str_handle;
}
bool query() {
return event_->query();
}
void record(c10::intrusive_ptr<CUDAStream> stream);
void synchronize() {
event_->synchronize();
}
void wait(c10::intrusive_ptr<CUDAStream> stream);
private:
void recordInternal(CUDAStream* stream);
std::unique_ptr<at::cuda::CUDAEvent> event_;
friend class CUDAStream;
};
c10::intrusive_ptr<CUDAEvent> CUDAStream::recordEvent(
c10::intrusive_ptr<CUDAEvent> event) {
if (!event) {
event = c10::make_intrusive<CUDAEvent>();
}
event->recordInternal(this);
return event;
}
void CUDAStream::waitEvent(c10::intrusive_ptr<CUDAEvent> event) {
event->event_->block(*stream_);
}
void CUDAStream::waitStream(c10::intrusive_ptr<CUDAStream> stream) {
auto ev = c10::make_intrusive<CUDAEvent>();
stream->recordEvent(ev);
waitEvent(ev);
}
void CUDAEvent::record(c10::intrusive_ptr<CUDAStream> stream) {
event_->record(*stream->stream_);
}
void CUDAEvent::recordInternal(CUDAStream* stream) {
event_->record(*stream->stream_);
}
void CUDAEvent::wait(c10::intrusive_ptr<CUDAStream> stream) {
event_->block(*stream->stream_);
}
TORCH_LIBRARY(cuda, m) {
auto stream_class = m.class_<torch::jit::CUDAStream>("Stream").def(
torch::init<c10::optional<c10::Device>, int64_t>(),
"",
{torch::arg("device") = c10::nullopt, torch::arg("priority") = 0});
auto event_class = m.class_<torch::jit::CUDAEvent>("Event").def(
torch::init<bool, bool, bool>(),
"",
{torch::arg("enable_timing") = false,
torch::arg("blocking") = false,
torch::arg("interprocess") = false});
stream_class.def("query", &CUDAStream::query)
.def("record_event", &CUDAStream::recordEvent)
.def("synchronize", &CUDAStream::synchronize)
.def("wait_event", &CUDAStream::waitEvent)
.def("wait_stream", &CUDAStream::waitStream)
.def("device_index", &CUDAStream::device_index)
.def_property("device", &CUDAStream::device)
.def("id", &CUDAStream::id);
event_class.def("elapsed_time", &CUDAEvent::elapsedTime)
.def("query", &CUDAEvent::query)
.def("record", &CUDAEvent::record)
.def("synchronize", &CUDAEvent::synchronize)
.def("wait", &CUDAEvent::wait);
};
} // namespace jit
} // namespace torch