| #pragma once |
| |
| #include <c10/core/impl/InlineDeviceGuard.h> |
| |
| namespace c10 { |
| |
| /// RAII guard that sets a certain default device in its constructor, and |
| /// changes it back to the device that was originally active upon destruction. |
| /// |
| /// The device is always reset to the one that was active at the time of |
| /// construction of the guard. Even if you `set_device` after construction, the |
| /// destructor will still reset the device to the one that was active at |
| /// construction time. |
| /// |
| /// This device guard does NOT have an uninitialized state; it is guaranteed |
| /// to reset a device on exit. If you are in a situation where you *might* |
| /// want to setup a guard (i.e., are looking for the moral equivalent |
| /// of optional<DeviceGuard>), see OptionalDeviceGuard. |
| class DeviceGuard { |
| public: |
| /// No default constructor; see Note [Omitted default constructor from RAII] |
| explicit DeviceGuard() = delete; |
| |
| /// Set the current device to the passed Device. |
| explicit DeviceGuard(Device device) : guard_(device) {} |
| |
| /// This constructor is for testing only. |
| explicit DeviceGuard( |
| Device device, |
| const impl::DeviceGuardImplInterface* impl) |
| : guard_(device, impl) {} |
| |
| /// Copy is disallowed |
| DeviceGuard(const DeviceGuard&) = delete; |
| DeviceGuard& operator=(const DeviceGuard&) = delete; |
| |
| /// Move is disallowed, as DeviceGuard does not have an uninitialized state, |
| /// which is required for moves on types with nontrivial destructors. |
| DeviceGuard(DeviceGuard&& other) = delete; |
| DeviceGuard& operator=(DeviceGuard&& other) = delete; |
| |
| /// Sets the device to the given one. The specified device must be consistent |
| /// with the device type originally specified during guard construction. |
| /// |
| /// TODO: The consistency check here is inconsistent with StreamGuard's |
| /// behavior with set_stream, where a stream on a different device than |
| /// the original one isn't an error; we just reset the stream and then |
| /// switch devices. |
| void reset_device(at::Device device) { |
| guard_.reset_device(device); |
| } |
| |
| /// This method is for testing only. |
| void reset_device( |
| at::Device device, |
| const impl::DeviceGuardImplInterface* impl) { |
| guard_.reset_device(device, impl); |
| } |
| |
| /// Sets the device index to the given one. The device type is inferred |
| /// from the original device type the guard was constructed with. |
| void set_index(DeviceIndex index) { |
| guard_.set_index(index); |
| } |
| |
| /// Returns the device that was set at the time the guard was constructed. |
| Device original_device() const { |
| return guard_.original_device(); |
| } |
| |
| /// Returns the most recent device that was set using this device guard, |
| /// either from construction, or via set_device. |
| Device current_device() const { |
| return guard_.current_device(); |
| } |
| |
| private: |
| impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_; |
| }; |
| |
| /** |
| * A OptionalDeviceGuard is an RAII class that sets a device to some value on |
| * initialization, and resets the device to its original value on destruction. |
| * Morally, a OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but |
| * with extra constructors and methods as appropriate. |
| * |
| * Besides its obvious use (optionally applying a DeviceGuard), |
| * OptionalDeviceGuard is often also used for the following idiom: |
| * |
| * OptionalDeviceGuard g; |
| * for (const auto& t : tensors) { |
| * g.set_device(t.device()); |
| * do_something_with(t); |
| * } |
| * |
| * This usage is marginally more efficient than constructing a DeviceGuard every |
| * iteration of the for loop, as it avoids an unnecessary device reset. |
| * |
| * Unlike DeviceGuard, a OptionalDeviceGuard may be uninitialized. This occurs |
| * when you use the nullary constructor, or pass a nullopt to the constructor. |
| * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the |
| * original device was and they do not reset on destruction. This is why |
| * original_device() and current_device() return optional<Device> rather than |
| * Device (as they do in DeviceGuard), and also is why we didn't just |
| * provide OptionalDeviceGuard by default and hide DeviceGuard from users. |
| * |
| * The semantics of an OptionalDeviceGuard are exactly explained by thinking |
| * of it as an optional<DeviceGuard>. In particular, an initialized |
| * OptionalDeviceGuard doesn't restore device to its value at construction; it |
| * restores device to its value *at initialization*. So if you have the |
| * program: |
| * |
| * setDevice(1); |
| * OptionalDeviceGuard g; |
| * setDevice(2); |
| * g.reset_device(Device(DeviceType::CUDA, 3)); // initializes! |
| * |
| * On destruction, g will reset device to 2, rather than 1. |
| * |
| * An uninitialized OptionalDeviceGuard is distinct from a (initialized) |
| * DeviceGuard whose original_device_ and current_device_ match, since the |
| * DeviceGuard will still reset the device to original_device_. |
| */ |
| class OptionalDeviceGuard { |
| public: |
| /// Create an uninitialized guard. Set the guard later using reset_device. |
| explicit OptionalDeviceGuard() = default; |
| |
| /// Initialize the guard, setting the current device to the passed Device. |
| explicit OptionalDeviceGuard(Device device) : guard_(device) {} |
| |
| /// Initialize the guard if a Device is passed; otherwise leave the |
| /// guard uninitialized. |
| explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {} |
| |
| /// Constructor for testing only. |
| explicit OptionalDeviceGuard( |
| Device device, |
| const impl::DeviceGuardImplInterface* impl) |
| : guard_(device, impl) {} |
| |
| /// Copy is disallowed |
| OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; |
| OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete; |
| |
| /// Move is disallowed |
| /// See Note [Explicit initialization of optional fields] |
| /// and // Note [Move construction for RAII guards is tricky] |
| /// for rationale. |
| OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete; |
| OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete; |
| |
| /// Sets the device to the given one. The specified device must be consistent |
| /// with the device type originally specified during guard construction. |
| void reset_device(at::Device device) { |
| guard_.reset_device(device); |
| } |
| |
| /// For testing only |
| void reset_device( |
| at::Device device, |
| const impl::DeviceGuardImplInterface* impl) { |
| guard_.reset_device(device, impl); |
| } |
| |
| /// Returns the device that was set at the time the guard was constructed. |
| optional<Device> original_device() const { |
| return guard_.original_device(); |
| } |
| |
| /// Returns the most recent device that was set using this device guard, |
| /// either from construction, or via reset_device. |
| optional<Device> current_device() const { |
| return guard_.current_device(); |
| } |
| |
| private: |
| impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{}; |
| }; |
| |
| // Note [Whither the DeviceGuard boilerplate] |
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // Design note: in principle, we could avoid these wrappers using: |
| // |
| // using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>; |
| // using OptionalDeviceGuard = |
| // impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>; |
| // |
| // But the error messages are worse, and our users can't just look at the |
| // header file to find out what's going on. Furthermore, for specializations |
| // like CUDAStreamGuard, it can be profitable to replace some interfaces with |
| // refined types (e.g., return CUDAStream instead of Stream). So, we eat |
| // the boilerplate and write out the API explicitly. |
| |
| } // namespace c10 |