blob: a5021bb1e0b45d1619ac47a2a1b75ae88cb787b5 [file] [log] [blame]
#include <ATen/ATen.h>
#include <c10/util/CallOnce.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
// pthread.h is included for tracking bad forks
#ifndef WIN32
#include <pthread.h>
#endif
namespace torch {
namespace mps {
namespace {
// True for children forked after mps init
static bool in_bad_fork = false;
// Called in the forked child if mps has already been initialized
static void forked_mps_child() {
in_bad_fork = true;
}
// Should be called before the first mps call.
static void track_bad_mps_fork() {
#ifndef WIN32
static c10::once_flag flag;
c10::call_once(
flag, [] { pthread_atfork(nullptr, nullptr, forked_mps_child); });
#endif
}
} // namespace
// Binding for `_mps_is_in_bad_fork` (METH_NOARGS): returns True when the
// in_bad_fork flag has been set by the atfork child handler, else False.
static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (in_bad_fork) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_get_default_generator` (METH_NOARGS).
// Arms fork tracking first, then hands the hooks' default MPS generator to
// THPGenerator_initDefaultGenerator and returns the resulting Python object.
static PyObject* MPSModule_getDefaultMPSGenerator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  track_bad_mps_fork();
  const auto& mps_hooks = at::detail::getMPSHooks();
  return THPGenerator_initDefaultGenerator(mps_hooks.getDefaultMPSGenerator());
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_is_available` (METH_NOARGS).
// Arms fork tracking, then reports the hooks' hasMPS() as a Python bool.
static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  track_bad_mps_fork();
  const bool mps_present = at::detail::getMPSHooks().hasMPS();
  return PyBool_FromLong(mps_present ? 1 : 0);
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_is_on_macos_13_or_newer` (METH_O).
// `args` is the single Python argument: an integer minor version, unpacked
// as uint32 and forwarded to the hooks' isOnMacOS13orNewer(). Non-integer
// input raises "invalid argument to isOnMacOS13orNewer()".
static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(args), "invalid argument to isOnMacOS13orNewer()");
  const auto minor_version = THPUtils_unpackUInt32(args);
  const bool at_least =
      at::detail::getMPSHooks().isOnMacOS13orNewer(minor_version);
  return PyBool_FromLong(at_least ? 1 : 0);
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_deviceSynchronize` (METH_NOARGS): forwards to the MPS
// hooks' deviceSynchronize() and returns None.
static PyObject* MPSModule_deviceSynchronize(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  const auto& mps_hooks = at::detail::getMPSHooks();
  mps_hooks.deviceSynchronize();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_emptyCache` (METH_NOARGS): forwards to the MPS hooks'
// emptyCache() and returns None.
static PyObject* MPSModule_emptyCache(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  const auto& mps_hooks = at::detail::getMPSHooks();
  mps_hooks.emptyCache();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_setMemoryFraction` (METH_O).
// `args` is the single Python argument. It must pass THPUtils_checkDouble,
// otherwise "invalid argument to setMemoryFraction()" is raised. The
// unpacked double is passed unchanged to the hooks' setMemoryFraction().
static PyObject* MPSModule_setMemoryFraction(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkDouble(args), "invalid argument to setMemoryFraction()");
  const auto& mps_hooks = at::detail::getMPSHooks();
  mps_hooks.setMemoryFraction(THPUtils_unpackDouble(args));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_currentAllocatedMemory` (METH_NOARGS): returns the
// hooks' getCurrentAllocatedMemory() value as a Python int.
static PyObject* MPSModule_currentAllocatedMemory(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  const auto allocated = at::detail::getMPSHooks().getCurrentAllocatedMemory();
  return PyLong_FromUnsignedLongLong(allocated);
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_driverAllocatedMemory` (METH_NOARGS): returns the
// hooks' getDriverAllocatedMemory() value as a Python int.
static PyObject* MPSModule_driverAllocatedMemory(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  const auto driver_allocated =
      at::detail::getMPSHooks().getDriverAllocatedMemory();
  return PyLong_FromUnsignedLongLong(driver_allocated);
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_profilerStartTrace` (METH_VARARGS).
// Expects exactly two positional arguments: a mode string and a bool
// "wait until completed" flag. Both are forwarded to the hooks'
// profilerStartTrace(). Returns None; on a tuple-arity mismatch the Python
// error set by PyArg_ParseTuple is propagated by returning nullptr.
static PyObject* MPSModule_profilerStartTrace(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* py_mode = nullptr;
  // Holds a Python bool, not a string.
  PyObject* py_wait_flag = nullptr;
  const bool parsed = PyArg_ParseTuple(args, "OO", &py_mode, &py_wait_flag);
  if (!parsed) {
    return nullptr;
  }
  // Unpack the mode before the flag, so a bad mode is reported first.
  const std::string trace_mode = THPUtils_unpackString(py_mode);
  const bool wait_until_completed = THPUtils_unpackBool(py_wait_flag);
  at::detail::getMPSHooks().profilerStartTrace(
      trace_mode, wait_until_completed);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Binding for `_mps_profilerStopTrace` (METH_NOARGS): forwards to the MPS
// hooks' profilerStopTrace() and returns None.
static PyObject* MPSModule_profilerStopTrace(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  const auto& mps_hooks = at::detail::getMPSHooks();
  mps_hooks.profilerStopTrace();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python method table for the MPS bindings defined above, exposed through
// python_functions(). Each entry maps a Python-visible name to its C
// implementation and calling convention (METH_NOARGS / METH_O /
// METH_VARARGS). The all-null entry terminates the table, as CPython requires.
//
// The lint suppression must be a single line placed immediately before the
// declaration. clang-tidy's NOLINTNEXTLINE only applies to the very next
// line, and its check list must be closed on the same line.
// NOLINTNEXTLINE(*-c-arrays, *-global-variables)
static struct PyMethodDef _MPSModule_methods[] = {
    {"_mps_deviceSynchronize",
     MPSModule_deviceSynchronize,
     METH_NOARGS,
     nullptr},
    {"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
    {"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
    {"_mps_is_on_macos_13_or_newer",
     MPSModule_isMacOS13orNewer,
     METH_O,
     nullptr},
    {"_mps_get_default_generator",
     MPSModule_getDefaultMPSGenerator,
     METH_NOARGS,
     nullptr},
    {"_mps_emptyCache", MPSModule_emptyCache, METH_NOARGS, nullptr},
    {"_mps_setMemoryFraction", MPSModule_setMemoryFraction, METH_O, nullptr},
    {"_mps_currentAllocatedMemory",
     MPSModule_currentAllocatedMemory,
     METH_NOARGS,
     nullptr},
    {"_mps_driverAllocatedMemory",
     MPSModule_driverAllocatedMemory,
     METH_NOARGS,
     nullptr},
    {"_mps_profilerStartTrace",
     MPSModule_profilerStartTrace,
     METH_VARARGS,
     nullptr},
    {"_mps_profilerStopTrace",
     MPSModule_profilerStopTrace,
     METH_NOARGS,
     nullptr},
    {nullptr}};
// Returns a pointer to the first entry of this file's null-terminated
// PyMethodDef table, so the caller can register the MPS bindings.
PyMethodDef* python_functions() {
  PyMethodDef* const first_entry = &_MPSModule_methods[0];
  return first_entry;
}
} // namespace mps
} // namespace torch