blob: bdd1c0dde4b9965bf321f7668ebb5b5635d61692 [file] [log] [blame]
#include <torch/mps.h>
#include <ATen/Context.h>
#include <c10/util/irange.h>
#include <cstddef>
namespace torch {
namespace mps {

/// Reports whether the MPS backend can be used, as answered by the
/// registered MPS hooks (`hasMPS()`).
bool is_available() {
  const bool has_mps = at::detail::getMPSHooks().hasMPS();
  return has_mps;
}

/// Sets the seed for the MPS's default generator.
///
/// When `is_available()` is false this returns without doing anything
/// (no error is raised, unlike `synchronize()`).
///
/// @param seed New seed value for the default MPS generator.
void manual_seed(uint64_t seed) {
  if (!is_available()) {
    return;
  }
  // A local handle to the default generator; its mutex() and
  // set_current_seed() are called through this non-const copy.
  auto default_gen = at::detail::getMPSHooks().getDefaultMPSGenerator();
  // See Note [Acquire lock when using random generators]
  std::lock_guard<std::mutex> seed_guard(default_gen.mutex());
  default_gen.set_current_seed(seed);
}

/// Calls the MPS hooks' `deviceSynchronize()`; presumably this blocks until
/// previously queued MPS work has finished (NOTE(review): confirm against
/// the hooks implementation).
///
/// Fails a `TORCH_CHECK` with "No MPS devices are available" when
/// `is_available()` is false.
void synchronize() {
  TORCH_CHECK(is_available(), "No MPS devices are available");
  at::detail::getMPSHooks().deviceSynchronize();
}

} // namespace mps
} // namespace torch