blob: 0c5c6e74259cc36d6e54686a17c7d11f46f6465e [file] [log] [blame]
#define PY_SSIZE_T_CLEAN
#include <torch/csrc/utils/python_compat.h>
#include <opcode.h>
#include <stdbool.h>
// see https://bugs.python.org/issue35886
#if PY_VERSION_HEX >= 0x03080000
#define Py_BUILD_CORE
#include <internal/pycore_pystate.h>
// These headers were added in 3.11
#if IS_PYTHON_3_11_PLUS
#include <internal/pycore_frame.h>
#define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt
#include <internal/pycore_opcode.h>
#undef NEED_OPCODE_TABLES
#endif
#undef Py_BUILD_CORE
#endif // PY_VERSION_HEX >= 0x03080000
// All the eval APIs change in 3.11 so we need to decide which one to use on the fly
// https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction
#if IS_PYTHON_3_11_PLUS
#define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame
// The next two functions are taken from
// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182
// These are not exported by the CPython binary and thus we have
// to get our own implementation of them.
// As a simple way to reduce the impact of ABI changes on the CPython side, this check forces
// us to manually re-check that these functions didn't change in the next Python release (3.12)
#if PY_VERSION_HEX >= 0x030C0000 // 3.12
#error "Please ensure that the functions below still match the CPython implementation for 3.12"
#endif
// Returns 1 if an instruction with the given non-quickened `opcode` and full
// `oparg` (EXTENDED_ARG prefixes folded in) occurs in `frame`'s code strictly
// before `frame->prev_instr`, and 0 otherwise.
// Taken from CPython's Objects/frameobject.c; keep it in sync with upstream
// rather than changing it locally.
static int
_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)
{
// This only works when opcode is a non-quickened form:
assert(_PyOpcode_Deopt[opcode] == opcode);
int check_oparg = 0;
for (_Py_CODEUNIT *instruction = _PyCode_CODE(frame->f_code);
instruction < frame->prev_instr; instruction++)
{
// Map a possibly specialized (quickened) opcode back to its base opcode.
int check_opcode = _PyOpcode_Deopt[_Py_OPCODE(*instruction)];
check_oparg |= _Py_OPARG(*instruction);
if (check_opcode == opcode && check_oparg == oparg) {
return 1;
}
// EXTENDED_ARG supplies the high bits of the next instruction's oparg;
// any other instruction ends the accumulated argument.
if (check_opcode == EXTENDED_ARG) {
check_oparg <<= 8;
}
else {
check_oparg = 0;
}
// Skip the inline CACHE code units that follow this instruction.
instruction += _PyOpcode_Caches[check_opcode];
}
return 0;
}
// Copies the current values of `frame`'s fast locals (plain locals, cells and
// free variables) into its f_locals dict, creating the dict if it is NULL.
// Names whose slot is empty are removed from the dict. If the frame has not
// yet executed its leading COPY_FREE_VARS, the free variables are copied in
// from the function's closure first and prev_instr is advanced past it.
// Returns 0 on success, or -1 with a Python exception set.
// Taken from CPython's Objects/frameobject.c (not exported by libpython);
// keep it in sync with upstream.
int
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
/* Merge fast locals into f->f_locals */
PyObject *locals;
PyObject **fast;
PyCodeObject *co;
locals = frame->f_locals;
if (locals == NULL) {
locals = frame->f_locals = PyDict_New();
if (locals == NULL)
return -1;
}
co = frame->f_code;
fast = _PyFrame_GetLocalsArray(frame);
// COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
// here:
int lasti = _PyInterpreterFrame_LASTI(frame);
if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
/* Free vars have not been initialized -- Do that */
// NOTE(review): this `co` shadows the outer one; both hold frame->f_code.
PyCodeObject *co = frame->f_code;
PyObject *closure = frame->f_func->func_closure;
// Free variables are stored after the plain locals and the cell variables.
int offset = co->co_nlocals + co->co_nplaincellvars;
for (int i = 0; i < co->co_nfreevars; ++i) {
PyObject *o = PyTuple_GET_ITEM(closure, i);
Py_INCREF(o);
frame->localsplus[offset + i] = o;
}
// COPY_FREE_VARS doesn't have inline CACHEs, either:
frame->prev_instr = _PyCode_CODE(frame->f_code);
}
for (int i = 0; i < co->co_nlocalsplus; i++) {
_PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);
/* If the namespace is unoptimized, then one of the
following cases applies:
1. It does not contain free variables, because it
uses import * or is a top-level namespace.
2. It is a class namespace.
We don't want to accidentally copy free variables
into the locals dict used by the class.
*/
if (kind & CO_FAST_FREE && !(co->co_flags & CO_OPTIMIZED)) {
continue;
}
PyObject *name = PyTuple_GET_ITEM(co->co_localsplusnames, i);
PyObject *value = fast[i];
if (frame->stacktop) {
if (kind & CO_FAST_FREE) {
// The cell was set by COPY_FREE_VARS.
assert(value != NULL && PyCell_Check(value));
value = PyCell_GET(value);
}
else if (kind & CO_FAST_CELL) {
// Note that no *_DEREF ops can happen before MAKE_CELL
// executes. So there's no need to duplicate the work
// that MAKE_CELL would otherwise do later, if it hasn't
// run yet.
if (value != NULL) {
if (PyCell_Check(value) &&
_PyFrame_OpAlreadyRan(frame, MAKE_CELL, i)) {
// (likely) MAKE_CELL must have executed already.
value = PyCell_GET(value);
}
// (likely) Otherwise it is an arg (kind & CO_FAST_LOCAL),
// with the initial value set when the frame was created...
// (unlikely) ...or it was set to some initial value by
// an earlier call to PyFrame_LocalsToFast().
}
}
}
else {
assert(value == NULL);
}
// Mirror the slot into the dict: drop the name if unbound (a missing
// key is fine), otherwise store the value.
if (value == NULL) {
if (PyObject_DelItem(locals, name) != 0) {
if (PyErr_ExceptionMatches(PyExc_KeyError)) {
PyErr_Clear();
}
else {
return -1;
}
}
}
else {
if (PyObject_SetItem(locals, name, value) != 0) {
return -1;
}
}
}
return 0;
}
// We need to be able to return the _PyInterpreterFrame to python so create
// a python binding for it.
// The wrapper stores a raw, non-owning pointer: nothing here keeps the
// underlying interpreter frame alive. NOTE(review): a wrapper presumably must
// not be used after its frame finishes executing; confirm with callers.
typedef struct THPPyInterpreterFrame {
PyObject_HEAD
_PyInterpreterFrame* frame; // Borrowed reference
} THPPyInterpreterFrame;
// Allocates a wrapper for `frame`: new reference, or NULL on allocation
// failure. Defined later in this file, after the type object.
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
// Defines a getter THPPyInterpreterFrame_<name> that returns a new reference
// to the PyObject-typed field `name` of the wrapped _PyInterpreterFrame.
// Fields such as f_locals or frame_obj can be NULL; a getter must never return
// NULL without an exception set, so a NULL field is reported as None.
#define DECLARE_PYOBJ_ATTR(name)                                                                 \
  static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \
    PyObject* res = (PyObject*)self->frame->name;                                                 \
    if (res == NULL) {                                                                            \
      Py_RETURN_NONE;                                                                             \
    }                                                                                             \
    Py_INCREF(res);                                                                               \
    return res;                                                                                   \
  }

DECLARE_PYOBJ_ATTR(f_func)
DECLARE_PYOBJ_ATTR(f_globals)
DECLARE_PYOBJ_ATTR(f_builtins)
DECLARE_PYOBJ_ATTR(f_locals)
DECLARE_PYOBJ_ATTR(f_code)
DECLARE_PYOBJ_ATTR(frame_obj)

#undef DECLARE_PYOBJ_ATTR
// Getter for `previous`: wraps the caller's _PyInterpreterFrame in a new
// THPPyInterpreterFrame (new reference), or returns None when there is no
// previous frame. Wrapping a NULL frame would produce an object whose every
// attribute access dereferences NULL.
static PyObject* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
  _PyInterpreterFrame* prev = self->frame->previous;
  if (prev == NULL) {
    Py_RETURN_NONE;
  }
  return (PyObject*)THPPyInterpreterFrame_New(prev);
}
// `f_lasti` is not a field of _PyInterpreterFrame, but Python-side code reads
// it and it is awkward to compute there, so it is exposed as a computed
// property: the index of the last executed instruction, as a Python int.
static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
  const int last_index = _PyInterpreterFrame_LASTI(self->frame);
  return PyLong_FromLong(last_index);
}
// Read-only properties exposed on the _PyInterpreterFrame wrapper type; every
// entry has a getter and no setter.
// NOTE(review): the THPDevice_ prefix looks copied from another binding; this
// table belongs to THPPyInterpreterFrameType.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPDevice_properties[] = {
{"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
{"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
{"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
{"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
{"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
{"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
{"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
{"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
{NULL}};
PyTypeObject THPPyInterpreterFrameType = {
PyVarObject_HEAD_INIT(NULL, 0) "torch._C.dynamo.eval_frame._PyInterpreterFrame", /* tp_name */
sizeof(THPPyInterpreterFrame), /* tp_basicsize */
0, /* tp_itemsize */
NULL, /* tp_dealloc */
0, /* tp_vectorcall_offset */
NULL, /* tp_getattr */
NULL, /* tp_setattr */
NULL, /* tp_reserved */
NULL, /* tp_repr */
NULL, /* tp_as_number */
NULL, /* tp_as_sequence */
NULL, /* tp_as_mapping */
NULL, /* tp_hash */
NULL, /* tp_call */
NULL, /* tp_str */
NULL, /* tp_getattro */
NULL, /* tp_setattro */
NULL, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
NULL, /* tp_doc */
NULL, /* tp_traverse */
NULL, /* tp_clear */
NULL, /* tp_richcompare */
0, /* tp_weaklistoffset */
NULL, /* tp_iter */
NULL, /* tp_iternext */
NULL, /* tp_methods */
NULL, /* tp_members */
THPDevice_properties, /* tp_getset */
NULL, /* tp_base */
NULL, /* tp_dict */
NULL, /* tp_descr_get */
NULL, /* tp_descr_set */
0, /* tp_dictoffset */
NULL, /* tp_init */
NULL, /* tp_alloc */
NULL, /* tp_new */
};
// Allocates a wrapper object pointing at `frame` (non-owning pointer).
// Returns a new reference, or NULL if allocation failed.
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
  PyTypeObject* tp = &THPPyInterpreterFrameType;
  THPPyInterpreterFrame* wrapper = (THPPyInterpreterFrame*)tp->tp_alloc(tp, 0);
  if (wrapper != NULL) {
    wrapper->frame = frame;
  }
  return wrapper;
}
#else
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
#endif
// Branch hint marking `x` as unlikely to be true. _WIN32 builds (presumably
// MSVC, which has no __builtin_expect) get a plain pass-through.
#ifdef _WIN32
#define unlikely(x) (x)
#else
#define unlikely(x) __builtin_expect((x), 0)
#endif

// Fatal checks that are active in every build. On failure they print the
// source location to stderr and abort(); NULL_CHECK also prints the pending
// Python exception, if any. Each macro evaluates its argument exactly once
// and is wrapped in do { } while (0) so that `CHECK(x);` is a single
// statement, including as the unbraced body of an if/else.
#define NULL_CHECK(val)                                           \
  do {                                                            \
    if (unlikely((val) == NULL)) {                                \
      fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__); \
      PyErr_Print();                                              \
      abort();                                                    \
    }                                                             \
  } while (0)

#define CHECK(cond)                                                       \
  do {                                                                    \
    if (unlikely(!(cond))) {                                              \
      fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \
      abort();                                                            \
    }                                                                     \
  } while (0)
// Debug-only counterparts of CHECK/NULL_CHECK, plus stderr tracing. They are
// enabled by defining TORCHDYNAMO_DEBUG. Otherwise they expand to nothing and
// their arguments are NOT evaluated, so never put side effects in them.
// DEBUG_TRACE needs at least one argument after `msg` (it expands
// `, __VA_ARGS__`); use DEBUG_TRACE0 for a message with no arguments.
#ifdef TORCHDYNAMO_DEBUG
#define DEBUG_CHECK(cond) CHECK(cond)
#define DEBUG_NULL_CHECK(val) NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...) \
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__)
#define DEBUG_TRACE0(msg) \
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__)
#else
#define DEBUG_CHECK(cond)
#define DEBUG_NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...)
#define DEBUG_TRACE0(msg)
#endif
// Flag to just run a frame normally.
// Sentinel stored in a code object's extra slot in place of a CacheEntry*;
// code marked with it is always run by the default evaluator.
#define SKIP_CODE ((void*)0x1)
static PyObject* noargs = NULL; /* cached empty tuple */
static PyObject* dotzerokey = NULL; /* ".0" */
// Optional Python hooks, set from Python (None clears them): guard_fail_hook
// runs when a guard returns a non-True value, guard_error_hook when a guard raises.
static PyObject* guard_fail_hook = NULL;
static PyObject* guard_error_hook = NULL;
// Optional profiler hooks called around cache lookups; the start hook receives
// guard_profiler_name_str as its argument.
static PyObject* profiler_start_hook = NULL;
static PyObject* profiler_end_hook = NULL;
static PyObject* guard_profiler_name_str = NULL; /* cached py str */
// Code-object extra slot holding each code object's CacheEntry chain (or
// SKIP_CODE). Assigned in torch_c_dynamo_eval_frame_init; the initial -1
// converts to SIZE_MAX.
size_t extra_index = -1;
// Thread-specific storage slot holding each thread's eval-frame callback.
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
// Returns this thread's eval-frame callback: None, False, or a callable.
// A stored callback is returned without taking a new reference.
// NOTE(review): if this thread never stored a callback (the TSS slot is NULL),
// Py_RETURN_NONE returns a NEW reference to Py_None instead, so the two paths
// differ in ownership. Callers in this file treat the result as borrowed;
// confirm this asymmetry is intended.
inline static PyObject* eval_frame_callback_get(void) {
void* result = PyThread_tss_get(&eval_frame_callback_key);
if (unlikely(result == NULL)) {
Py_RETURN_NONE;
} else {
return (PyObject*)result;
}
}
// Stores `obj` as this thread's eval-frame callback. Reference counts are not
// touched here; callers manage ownership. The return value of
// PyThread_tss_set is ignored.
inline static void eval_frame_callback_set(PyObject* obj) {
PyThread_tss_set(&eval_frame_callback_key, obj);
}
// No-op free function registered with _PyEval_RequestCodeExtraIndex for the
// code-extra slot. NOTE(review): since it frees nothing, any CacheEntry chain
// still attached when CPython invokes it is leaked; confirm this is intended.
static void ignored(void* obj) {
  (void)obj;
}
// Forward declarations: the body of the installed eval-frame hook, and the
// slow path it dispatches to when a callback is active.
static PyObject* _custom_eval_frame_shim(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag);
static PyObject* _custom_eval_frame(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag,
    PyObject* callback);

// Frame evaluator that was installed before our shim (3.9+ only). Saved by
// enable_eval_frame_shim, restored by enable_eval_frame_default, and chained
// to by eval_frame_default.
// Typed with THP_EVAL_API_FRAME_OBJECT so the signature matches the
// interpreter's eval-frame function on every supported version:
// PyFrameObject* on 3.9/3.10 and _PyInterpreterFrame* on 3.11+.
static PyObject *(*previous_eval_frame)(PyThreadState *tstate,
    THP_EVAL_API_FRAME_OBJECT *frame, int throw_flag) = NULL;
// Function installed as the interpreter's frame evaluator. From 3.9 on the
// hook receives the thread state directly; on older versions it is looked up
// here. Both variants forward to _custom_eval_frame_shim.
#if PY_VERSION_HEX >= 0x03090000
static PyObject* custom_eval_frame_shim(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag) {
  PyObject* res = _custom_eval_frame_shim(tstate, frame, throw_flag);
  return res;
}
#else
static PyObject* custom_eval_frame_shim(THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) {
  return _custom_eval_frame_shim(PyThreadState_GET(), frame, throw_flag);
}
#endif
// Runs `frame` without Dynamo. On 3.9+ it chains to the evaluator saved in
// previous_eval_frame, or to CPython's default evaluator if none is saved; a
// NULL `tstate` means "use the current thread state". Before 3.9 the default
// evaluator is called directly and `tstate` is unused.
inline static PyObject* eval_frame_default(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag) {
#if PY_VERSION_HEX >= 0x03090000
  PyThreadState* ts = (tstate != NULL) ? tstate : PyThreadState_GET();
  if (previous_eval_frame == NULL) {
    return _PyEval_EvalFrameDefault(ts, frame, throw_flag);
  }
  return previous_eval_frame(ts, frame, throw_flag);
#else
  return _PyEval_EvalFrameDefault(frame, throw_flag);
#endif
}
// Installs custom_eval_frame_shim as the interpreter's frame evaluator unless
// it is already installed. On 3.9+ the evaluator being replaced is saved in
// previous_eval_frame so it can be chained to and later restored.
inline static void enable_eval_frame_shim(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) ==
      &custom_eval_frame_shim) {
    return;  // already installed
  }
  DEBUG_CHECK(previous_eval_frame == NULL);
  previous_eval_frame = _PyInterpreterState_GetEvalFrameFunc(tstate->interp);
  _PyInterpreterState_SetEvalFrameFunc(tstate->interp, &custom_eval_frame_shim);
#else
  if (tstate->interp->eval_frame == &custom_eval_frame_shim) {
    return;  // already installed
  }
  tstate->interp->eval_frame = &custom_eval_frame_shim;
#endif
}
// Removes our shim. On 3.9+ it restores the evaluator saved in
// previous_eval_frame (unless that one is already active) and forgets it.
// Before 3.9 it installs CPython's default evaluator.
// NOTE(review): in non-debug builds nothing stops a NULL previous_eval_frame
// from being installed here; callers are expected to pair this with
// enable_eval_frame_shim.
inline static void enable_eval_frame_default(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) ==
      previous_eval_frame) {
    return;
  }
  DEBUG_CHECK(previous_eval_frame != NULL);
  _PyInterpreterState_SetEvalFrameFunc(tstate->interp, previous_eval_frame);
  previous_eval_frame = NULL;
#else
  if (tstate->interp->eval_frame == &_PyEval_EvalFrameDefault) {
    return;
  }
  tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
#endif
}
// Calls the Dynamo callback as callable(frame, cache_len) and returns its
// result (new reference), or NULL with an exception set. On 3.11+ the
// _PyInterpreterFrame is not a Python object, so it is passed wrapped in a
// THPPyInterpreterFrame.
static inline PyObject* call_callback(
    PyObject* callable,
    THP_EVAL_API_FRAME_OBJECT* _frame,
    long cache_len) {
#if IS_PYTHON_3_11_PLUS
  THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame);
  if (frame == NULL) {
    return NULL;
  }
#else
  PyFrameObject* frame = _frame;
#endif
  // The "O" format takes its own reference to `frame`.
  PyObject* args = Py_BuildValue("(Ol)", frame, cache_len);
#if IS_PYTHON_3_11_PLUS
  // Drop the reference returned by THPPyInterpreterFrame_New; if `args` was
  // built, it keeps the wrapper alive for the call.
  Py_DECREF(frame);
#endif
  if (args == NULL) {
    return NULL;
  }
  PyObject* result = PyObject_CallObject(callable, args);
  Py_DECREF(args);
  return result;
}
// One compiled variant of a code object. Entries form a singly linked list
// stored in the code object's extra slot: new entries are prepended, and
// lookup() moves a hit to the front. check_fn and code are strong references
// taken in create_cache_entry and released by destroy_cache_entry.
typedef struct cache_entry {
// check the guards: lambda: <locals of user function>: bool
PyObject* check_fn;
// modified user bytecode (protected by check_fn's guards)
PyCodeObject* code;
// on a cache miss, linked list of next thing to try
struct cache_entry* next;
} CacheEntry;
// Allocates a CacheEntry holding new references to `guarded_code.check_fn`
// and `guarded_code.code`, linked in front of `next`. Any failure (out of
// memory, missing attribute) is treated as fatal and aborts the process.
static CacheEntry* create_cache_entry(
    CacheEntry* next,
    PyObject* guarded_code) {
  CacheEntry* e = malloc(sizeof(*e));
  // Checked in every build: the field stores below would otherwise
  // dereference NULL in release builds, where DEBUG_NULL_CHECK is a no-op.
  if (e == NULL) {
    PyErr_NoMemory();
  }
  NULL_CHECK(e);

  e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
  NULL_CHECK(e->check_fn);
  e->code = (PyCodeObject*)PyObject_GetAttrString(guarded_code, "code");
  NULL_CHECK(e->code);
  e->next = next;
  return e;
}
// Releases the whole CacheEntry chain starting at `e`: for each entry, drop
// its references to check_fn and code, then free it. NULL and the SKIP_CODE
// sentinel are accepted and ignored.
static void destroy_cache_entry(CacheEntry* e) {
  CacheEntry* cur = e;
  while (cur != NULL && cur != SKIP_CODE) {
    Py_XDECREF(cur->check_fn);
    Py_XDECREF(cur->code);
    CacheEntry* following = cur->next;
    free(cur);
    cur = following;
  }
}
// Returns what is stored in `code`'s Dynamo extra slot: a CacheEntry chain,
// the SKIP_CODE sentinel, or NULL. The status returned by _PyCode_GetExtra is
// ignored; NULL is returned if the call leaves the slot untouched.
inline static CacheEntry* get_extra(PyCodeObject* code) {
  // Read into a genuine void* instead of writing through a void** that
  // aliases a CacheEntry* object, which violates strict aliasing.
  void* slot = NULL;
  _PyCode_GetExtra((PyObject*)code, extra_index, &slot);
  return (CacheEntry*)slot;
}

// Stores `extra` (a CacheEntry chain, SKIP_CODE, or NULL) in `code`'s Dynamo
// extra slot. The previous value is not released here; callers must destroy it
// if needed. The status returned by _PyCode_SetExtra is ignored.
inline static void set_extra(PyCodeObject* code, CacheEntry* extra) {
  // TODO(jansel): would it be faster to bypass this?
  _PyCode_SetExtra((PyObject*)code, extra_index, extra);
}
// Tracing helper: the UTF-8 name of the code object that `frame` executes.
inline static const char* name(THP_EVAL_API_FRAME_OBJECT* frame) {
  PyObject* co_name = frame->f_code->co_name;
  DEBUG_CHECK(PyUnicode_Check(co_name));
  return PyUnicode_AsUTF8(co_name);
}
static PyObject* call_guard_fail_hook(
PyObject* hook,
CacheEntry* e,
size_t index,
PyObject* f_locals) {
// call debugging logic when a guard fails
PyObject* args = PyTuple_Pack(
5,
e->check_fn,
e->code,
f_locals,
PyLong_FromSize_t(index),
(e->next == NULL ? Py_True : Py_False));
if (args == NULL) return NULL;
PyObject* result = PyObject_CallObject(hook, args);
Py_DECREF(args);
return result;
}
// Calls profiler_start_hook(name_str) and returns its result (a new
// reference to pass to call_profiler_end_hook). Returns NULL when no hook is
// installed, when `name_str` is NULL, or when the call fails. Failures are
// reported through PyErr_WriteUnraisable and cleared: the caller goes on to
// run guards and user code, which must not start with an exception pending.
static PyObject* call_profiler_start_hook(PyObject* name_str) {
  if (profiler_start_hook == NULL) return NULL;
  if (name_str == NULL) return NULL;
  PyObject* args = PyTuple_Pack(1, name_str);
  if (args == NULL) {
    PyErr_WriteUnraisable(profiler_start_hook);
    return NULL;
  }
  PyObject* result = PyObject_CallObject(profiler_start_hook, args);
  Py_DECREF(args);
  if (result == NULL) {
    PyErr_WriteUnraisable(profiler_start_hook);
  }
  return result;
}
// Calls profiler_end_hook(record). 'record' is the return value of
// call_profiler_start_hook() and is only borrowed here. Does nothing if no end
// hook is installed or `record` is NULL. This function cannot report failure,
// so any exception is reported through PyErr_WriteUnraisable and cleared
// rather than left pending for the caller's subsequent Python calls.
static void call_profiler_end_hook(PyObject* record) {
  if (profiler_end_hook == NULL) return;
  if (record == NULL) return;
  PyObject* args = PyTuple_Pack(1, record);
  if (args == NULL) {
    PyErr_WriteUnraisable(profiler_end_hook);
    return;
  }
  PyObject* result = PyObject_CallObject(profiler_end_hook, args);
  Py_DECREF(args);
  if (result == NULL) {
    PyErr_WriteUnraisable(profiler_end_hook);
    return;
  }
  Py_DECREF(result);
}
// Return value: borrowed reference
// Is either Py_None or a PyCodeObject (or NULL with an exception set).
//
// Walks the CacheEntry chain starting at `e`. Each entry's check_fn is called
// with frame->f_locals as keyword arguments, plus the ".0" local as the single
// positional argument when it exists. For the first entry whose check_fn
// returns Py_True, the entry is moved to the head of the chain and its code is
// returned. Returns Py_None if no entry matches, or NULL if a guard or a hook
// raised. `prev` is the entry before `e` (NULL when `e` is the head), and
// `index` is `e`'s position, which is passed to the hooks.
static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev, size_t index) {
  if (e == NULL) {
    // NB: intentionally not using Py_RETURN_NONE, to return borrowed ref
    return Py_None;
  }
  PyObject* f_locals = frame->f_locals;
  PyObject* dotzero = PyDict_GetItem(f_locals, dotzerokey);
  PyObject* valid = NULL;
  if (unlikely(dotzero != NULL)) {
    // .0 is a special variable name used for implicit args
    PyObject* args = PyTuple_Pack(1, dotzero);
    if (args == NULL) return NULL;
    valid = PyObject_Call(e->check_fn, args, f_locals);
    Py_DECREF(args);
  } else {
    valid = PyObject_Call(e->check_fn, noargs, f_locals);
  }
  if (unlikely(valid == NULL)) {
    if (guard_error_hook != NULL) {
      // Show the guard error to the hook, then re-raise the original error.
      PyObject *type, *value, *traceback;
      PyErr_Fetch(&type, &value, &traceback);
      PyObject* r = call_guard_fail_hook(guard_error_hook, e, index, f_locals);
      if (r == NULL) {
        // The hook itself raised: propagate its exception and release the
        // saved guard exception instead of leaking it.
        Py_XDECREF(type);
        Py_XDECREF(value);
        Py_XDECREF(traceback);
        return NULL;
      }
      Py_DECREF(r);
      PyErr_Restore(type, value, traceback);
    }
    return NULL;
  }
  const int hit = (valid == Py_True);
  Py_DECREF(valid);
  if (hit) {
    // Keep the head as the most recently used cache entry.
    // If the hit cache entry is not the head of the linked list,
    // move it to the head
    if (prev != NULL) {
      CacheEntry* extra = get_extra(frame->f_code);
      prev->next = e->next;
      e->next = extra;
      set_extra(frame->f_code, e);
    }
    return (PyObject*)e->code;
  }
  if (unlikely(guard_fail_hook != NULL)) {
    PyObject* r = call_guard_fail_hook(guard_fail_hook, e, index, f_locals);
    if (r == NULL) {
      return NULL;
    }
    Py_DECREF(r);
  }
  return lookup(e->next, frame, e, index + 1);
}
// Number of entries in the CacheEntry chain starting at `e` (0 for NULL).
static long cache_size(CacheEntry* e) {
  long count = 0;
  for (CacheEntry* it = e; it != NULL; it = it->next) {
    ++count;
  }
  return count;
}
// Runs `code` (a Dynamo-generated replacement taken from a cache entry) in
// place of `frame`'s own code. It creates a new "shadow" frame for `code`,
// copies in the current values of frame's locals, cells and free variables,
// and runs the shadow frame with eval_frame_default. Returns the result (new
// reference) or NULL with an exception set.
inline static PyObject* eval_custom_code(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    PyCodeObject* code,
    int throw_flag) {
  Py_ssize_t ncells = 0;
  Py_ssize_t nfrees = 0;
  Py_ssize_t nlocals_new = code->co_nlocals;
  Py_ssize_t nlocals_old = frame->f_code->co_nlocals;

  ncells = PyCode_GetNCellvars(code);
  nfrees = PyCode_GetNFreevars(code);

  DEBUG_NULL_CHECK(tstate);
  DEBUG_NULL_CHECK(frame);
  DEBUG_NULL_CHECK(code);

#if IS_PYTHON_3_11_PLUS
  DEBUG_CHECK(ncells == frame->f_code->co_ncellvars);
  DEBUG_CHECK(nfrees == frame->f_code->co_nfreevars);
#else
  DEBUG_CHECK(ncells == PyTuple_GET_SIZE(frame->f_code->co_cellvars));
  DEBUG_CHECK(nfrees == PyTuple_GET_SIZE(frame->f_code->co_freevars));
#endif
  DEBUG_CHECK(nlocals_new >= nlocals_old);

  PyFrameObject* shadow_obj = PyFrame_New(tstate, code, frame->f_globals, NULL);
  // Checked before anything below dereferences or releases shadow_obj.
  if (shadow_obj == NULL) {
    return NULL;
  }

#if IS_PYTHON_3_11_PLUS
  THP_EVAL_API_FRAME_OBJECT* shadow = shadow_obj->f_frame;
  if (shadow == NULL) {
    PyErr_SetString(
        PyExc_SystemError, "eval_custom_code: new frame has no interpreter frame");
    Py_DECREF(shadow_obj);
    return NULL;
  }

  // Give the shadow frame's function the original function's closure,
  // releasing whatever closure it held before.
  PyObject* old_closure = shadow->f_func->func_closure;
  Py_XINCREF(frame->f_func->func_closure);
  shadow->f_func->func_closure = frame->f_func->func_closure;
  Py_XDECREF(old_closure);

  PyObject** fastlocals_old = frame->localsplus;
  PyObject** fastlocals_new = shadow->localsplus;

  // The two code objects may lay out localsplus differently, so values are
  // matched by name:
  //   for i, name in enumerate(localsplusnames_new):
  //     name_to_idx[name] = i
  //   for i, name in enumerate(localsplusnames_old):
  //     fastlocals_new[name_to_idx[name]] = fastlocals_old[i]
  PyObject* name_to_idx = PyDict_New();
  if (name_to_idx == NULL) {
    DEBUG_TRACE0("unable to create localsplus name dict");
    Py_DECREF(shadow_obj);
    return NULL;
  }

  for (Py_ssize_t i = 0; i < code->co_nlocalsplus; i++) {
    PyObject* name = PyTuple_GET_ITEM(code->co_localsplusnames, i);
    PyObject* idx = PyLong_FromSsize_t(i);
    int failed = (name == NULL || idx == NULL ||
                  PyDict_SetItem(name_to_idx, name, idx) != 0);
    // PyDict_SetItem does not steal `idx`; release our reference either way.
    Py_XDECREF(idx);
    if (failed) {
      Py_DECREF(shadow_obj);
      Py_DECREF(name_to_idx);
      return NULL;
    }
  }

  for (Py_ssize_t i = 0; i < frame->f_code->co_nlocalsplus; i++) {
    PyObject* name = PyTuple_GET_ITEM(frame->f_code->co_localsplusnames, i);
    PyObject* idx = (name != NULL) ? PyDict_GetItem(name_to_idx, name) : NULL;
    if (idx == NULL) {
      // The transformed code must define every variable of the original.
      PyErr_SetString(
          PyExc_SystemError,
          "eval_custom_code: local variable missing from transformed code");
      Py_DECREF(shadow_obj);
      Py_DECREF(name_to_idx);
      return NULL;
    }
    Py_ssize_t new_i = PyLong_AsSsize_t(idx);
    if (new_i == (Py_ssize_t)-1 && PyErr_Occurred() != NULL) {
      Py_DECREF(shadow_obj);
      Py_DECREF(name_to_idx);
      return NULL;
    }
    Py_XINCREF(fastlocals_old[i]);
    fastlocals_new[new_i] = fastlocals_old[i];
  }

  Py_DECREF(name_to_idx);
#else
  THP_EVAL_API_FRAME_OBJECT* shadow = shadow_obj;
  PyObject** fastlocals_old = frame->f_localsplus;
  PyObject** fastlocals_new = shadow->f_localsplus;

  // Plain locals keep their indices.
  for (Py_ssize_t i = 0; i < nlocals_old; i++) {
    Py_XINCREF(fastlocals_old[i]);
    fastlocals_new[i] = fastlocals_old[i];
  }

  // Cells and free variables follow the plain locals, so they shift by the
  // difference in co_nlocals.
  for (Py_ssize_t i = 0; i < ncells + nfrees; i++) {
    Py_XINCREF(fastlocals_old[nlocals_old + i]);
    fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i];
  }
#endif

  PyObject* result = eval_frame_default(tstate, shadow, throw_flag);
  Py_DECREF(shadow_obj);
  return result;
}
// Body of the installed eval-frame hook. It dispatches on this thread's
// callback:
//   - None: TorchDynamo disabled on this thread, so run the frame normally
//   - False: run-only mode (reuse existing compiles)
//   - Python callable(): TorchDynamo enabled
// Anything other than None is handled by _custom_eval_frame.
static PyObject* _custom_eval_frame_shim(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag) {
  PyObject* callback = eval_frame_callback_get();
  if (callback != Py_None) {
    return _custom_eval_frame(tstate, frame, throw_flag, callback);
  }
  return eval_frame_default(tstate, frame, throw_flag);
}
// Slow path of the eval-frame hook, used when this thread's callback is not
// None. Outline:
//   1. If throw_flag is set, the code is marked SKIP_CODE, or we are in
//      run-only mode with no cache, run the frame normally.
//   2. Materialize frame->f_locals so the guards can read the locals.
//   3. Run-only mode (callback == Py_False): on a guard hit run the cached
//      code, otherwise run the frame normally; never compile.
//   4. Otherwise, with the thread's callback temporarily set to None: on a
//      guard hit run the cached code. On a miss call `callback`: a non-None
//      result becomes a new cache entry whose code is run, while None marks
//      the code SKIP_CODE and runs the frame normally.
// Returns the frame's result (new reference) or NULL with an exception set.
static PyObject* _custom_eval_frame(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
int throw_flag,
PyObject* callback) {
#if IS_PYTHON_3_11_PLUS
DEBUG_TRACE(
"begin %s %s %i %i",
name(frame),
PyUnicode_AsUTF8(frame->f_code->co_filename),
frame->f_code->co_firstlineno,
_PyInterpreterFrame_LASTI(frame));
#else
DEBUG_TRACE(
"begin %s %s %i %i %i",
name(frame),
PyUnicode_AsUTF8(frame->f_code->co_filename),
frame->f_lineno,
frame->f_lasti,
frame->f_iblock);
#endif
if (throw_flag) {
// When unwinding generators, eval frame is called with throw_flag ==
// true. Frame evaluation is supposed to continue unwinding by propagating
// the exception. Dynamo doesn't really know how to do this, nor does it
// really want to do this, because there's unlikely any code to capture
// (you're going to immediately quit out of the frame, perhaps running
// some unwinding logic along the way). So we just run the default
// handler in this case.
//
// NB: A previous version of this patch returned NULL. This is wrong,
// because returning NULL is *different* from unwinding an exception.
// In particular, you will not execute things like context manager
// __exit__ if you just return NULL.
//
// NB: It's /conceivable/ that you might want to actually still call the
// Dynamo callback when throw_flag == TRUE, to give Dynamo a chance to
// do any stack unwinding code. But this is not really useful because
// (1) Dynamo doesn't actually know how to do stack unwinding, so it would
// immediately skip the frame, and (2) even if it did, this would only
// be profitable if there was tensor code in the unwinding code. Seems
// unlikely.
DEBUG_TRACE("throw %s", name(frame));
return eval_frame_default(tstate, frame, throw_flag);
}
CacheEntry* extra = get_extra(frame->f_code);
if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
DEBUG_TRACE("skip %s", name(frame));
return eval_frame_default(tstate, frame, throw_flag);
}
// TODO(jansel): investigate directly using the "fast" representation
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
// even though we should pass a PyFrameObject.
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
DEBUG_TRACE("error %s", name(frame));
return NULL;
}
// A callback of Py_False indicates "run only" mode, the cache is checked, but
// we never compile.
if (callback == Py_False) {
DEBUG_TRACE("In run only mode %s", name(frame));
PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
PyObject* maybe_cached_code = lookup(extra, frame, NULL, 0);
call_profiler_end_hook(hook_record);
Py_XDECREF(hook_record);
if (maybe_cached_code == NULL) {
// guard eval failed, keep propagating
return NULL;
} else if (maybe_cached_code == Py_None) {
DEBUG_TRACE("cache miss %s", name(frame));
return eval_frame_default(tstate, frame, throw_flag);
}
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
// used cached version
DEBUG_TRACE("cache hit %s", name(frame));
return eval_custom_code(tstate, frame, cached_code, throw_flag);
}
DEBUG_CHECK(PyDict_CheckExact(frame->f_locals));
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins));
// We don't run the current custom_eval_frame behavior for guards.
// So we temporarily set the callback to Py_None to drive the correct behavior
// in the shim.
eval_frame_callback_set(Py_None);
PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
PyObject* maybe_cached_code = lookup(extra, frame, NULL, 0);
call_profiler_end_hook(hook_record);
Py_XDECREF(hook_record);
if (maybe_cached_code == NULL) {
// Python error
// NOTE(review): the callback is left set to Py_None on this path (it is
// not restored); confirm this is intended, as it is for the callback
// failure path below.
return NULL;
} else if (maybe_cached_code != Py_None) {
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
// used cached version
DEBUG_TRACE("cache hit %s", name(frame));
// Re-enable custom behavior
eval_frame_callback_set(callback);
return eval_custom_code(tstate, frame, cached_code, throw_flag);
}
// cache miss
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
// that gets re-interpreted as a PyObject (which it is NOT!)
PyObject* result =
call_callback(callback, frame, cache_size(extra));
if (result == NULL) {
// internal exception, returning here will leak the exception into user code
// this is useful for debugging -- but we don't want it to happen outside of
// testing
// NB: we intentionally DO NOT re-enable custom behavior to prevent
// cascading failure from internal exceptions. The upshot is if
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
// inside the torch.compile block we won't try to Dynamo anything else.
return NULL;
} else if (result != Py_None) {
// The callback produced a compiled result: prepend it to the cache and run it.
DEBUG_TRACE("create cache %s", name(frame));
extra = create_cache_entry(extra, result);
Py_DECREF(result);
set_extra(frame->f_code, extra);
// Re-enable custom behavior
eval_frame_callback_set(callback);
return eval_custom_code(tstate, frame, extra->code, throw_flag);
} else {
// The callback declined to compile: drop the cache and never try this code again.
DEBUG_TRACE("create skip %s", name(frame));
Py_DECREF(result);
destroy_cache_entry(extra);
set_extra(frame->f_code, SKIP_CODE);
// Re-enable custom behavior
eval_frame_callback_set(callback);
return eval_frame_default(tstate, frame, throw_flag);
}
}
// Number of threads whose eval-frame callback is currently non-None, as
// maintained by set_eval_frame. The shim stays installed while it is positive.
// NOTE(review): plain int with no atomics; presumably serialized by the GIL.
static int active_dynamo_threads = 0;

// Counts one more Dynamo-enabled thread and makes sure the shim is installed.
// Returns a new reference to None.
static PyObject* increment_working_threads(PyThreadState* tstate) {
  active_dynamo_threads += 1;
  if (active_dynamo_threads > 0) {
    enable_eval_frame_shim(tstate);
  }
  Py_RETURN_NONE;
}

// Counts one fewer Dynamo-enabled thread (never going below zero). When the
// count reaches zero, it restores the previous frame evaluator. Returns a new
// reference to None.
static PyObject* decrement_working_threads(PyThreadState* tstate) {
  if (active_dynamo_threads <= 0) {
    Py_RETURN_NONE;
  }
  active_dynamo_threads -= 1;
  if (active_dynamo_threads == 0) {
    enable_eval_frame_default(tstate);
  }
  Py_RETURN_NONE;
}
// Replaces this thread's eval-frame callback with `new_callback` and returns
// the old one (a reference owned by the caller). The interpreter-wide shim is
// installed when a thread goes from None to non-None, and removed when the
// last such thread goes back to None.
//  - None: disables TorchDynamo
//  - False: run-only mode (reuse existing compiles)
//  - Python callable(): enables TorchDynamo
static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate) {
  PyObject* old_callback = eval_frame_callback_get();

  // owned by caller
  Py_INCREF(old_callback);

  // The counter helpers return a new reference to None; release it so that
  // toggling Dynamo does not leak references.
  if (old_callback != Py_None && new_callback == Py_None) {
    Py_XDECREF(decrement_working_threads(tstate));
  } else if (old_callback == Py_None && new_callback != Py_None) {
    Py_XDECREF(increment_working_threads(tstate));
  }

  // The thread-local slot holds a reference to its callback: take one on
  // the new value and give up the one held for the old value.
  Py_INCREF(new_callback);
  Py_DECREF(old_callback);

  // Set thread local callback. This will drive behavior of our shim, if/when it
  // is installed.
  eval_frame_callback_set(new_callback);
  return old_callback;
}
// Python entry point: set_eval_frame(callback) -> previous callback.
// `callback` must be None, False, or a callable; anything else raises
// TypeError.
static PyObject* set_eval_frame_py(PyObject* dummy, PyObject* args) {
  PyObject* callback = NULL;
  if (!PyArg_ParseTuple(args, "O:callback", &callback)) {
    DEBUG_TRACE0("arg error");
    return NULL;
  }
  const bool is_none = (callback == Py_None);
  const bool is_false = (callback == Py_False);
  if (!is_none && !is_false && !PyCallable_Check(callback)) {
    DEBUG_TRACE0("arg error");
    PyErr_SetString(PyExc_TypeError, "expected a callable");
    return NULL;
  }
  DEBUG_TRACE("python enabled=%d and is run_only=%d", !is_none, is_false);
  return set_eval_frame(callback, PyThreadState_GET());
}
// Python entry point: reset_code(code). Releases every cache entry attached
// to `code` (a SKIP_CODE mark is simply dropped) and leaves its extra slot
// empty (NULL). Raises TypeError if `code` is not a code object.
static PyObject* reset_code(PyObject* dummy, PyObject* args) {
  PyObject* code = NULL;
  if (!PyArg_ParseTuple(args, "O:code", &code)) {
    DEBUG_TRACE0("arg error");
    return NULL;
  }
  if (!PyCode_Check(code)) {
    DEBUG_TRACE0("arg error");
    PyErr_SetString(PyExc_TypeError, "expected a code object");
    return NULL;
  }
  PyCodeObject* co = (PyCodeObject*)code;
  CacheEntry* entries = get_extra(co);
  destroy_cache_entry(entries);
  set_extra(co, NULL);
  Py_RETURN_NONE;
}
// A dummy C function used in testing: unsupported(a, b) ignores `a` and
// returns `b` as a new reference.
static PyObject* unsupported(PyObject* dummy, PyObject* args) {
  PyObject* ignored_arg = NULL;
  PyObject* passthrough = NULL;
  if (PyArg_ParseTuple(args, "OO", &ignored_arg, &passthrough) == 0) {
    return NULL;
  }
  Py_INCREF(passthrough);
  return passthrough;
}
// Python entry point: skip_code(code). Marks `code` with SKIP_CODE so the
// eval-frame hook always runs it with the default evaluator. Any cache entries
// already attached are released first: the extra slot's registered free
// function (`ignored`) frees nothing, so overwriting the slot would leak them.
// Raises TypeError if `code` is not a code object.
static PyObject* skip_code(PyObject* dummy, PyObject* args) {
  PyObject* obj = NULL;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return NULL;
  }
  if (!PyCode_Check(obj)) {
    PyErr_SetString(PyExc_TypeError, "expected a code object");
    return NULL;
  }
  PyCodeObject* code = (PyCodeObject*)obj;
  // destroy_cache_entry ignores NULL and an existing SKIP_CODE mark.
  destroy_cache_entry(get_extra(code));
  set_extra(code, SKIP_CODE);
  Py_RETURN_NONE;
}
// Python entry point: set_guard_fail_hook(hook). Installs `hook`, which is
// called by lookup() whenever a cache entry's guard returns a non-True value.
// Passing None removes the hook.
static PyObject* set_guard_fail_hook(PyObject* dummy, PyObject* args) {
  PyObject* obj = NULL;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return NULL;
  }
  // Publish the new value before releasing the old one. Dropping the last
  // reference can run arbitrary Python code, which must never see the global
  // pointing at a freed object.
  PyObject* old_hook = guard_fail_hook;
  if (obj == Py_None) {
    guard_fail_hook = NULL;
  } else {
    Py_INCREF(obj);
    guard_fail_hook = obj;
  }
  Py_XDECREF(old_hook);
  Py_RETURN_NONE;
}
// Python entry point: set_guard_error_hook(hook). Installs `hook`, which is
// called by lookup() when evaluating a cache entry's guard raises (the
// original exception is re-raised afterwards). Passing None removes the hook.
static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* args) {
  PyObject* obj = NULL;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return NULL;
  }
  // Publish the new value before releasing the old one. Dropping the last
  // reference can run arbitrary Python code, which must never see the global
  // pointing at a freed object.
  PyObject* old_hook = guard_error_hook;
  if (obj == Py_None) {
    guard_error_hook = NULL;
  } else {
    Py_INCREF(obj);
    guard_error_hook = obj;
  }
  Py_XDECREF(old_hook);
  Py_RETURN_NONE;
}
// Python entry point: clear_profiler_hooks(). Removes both profiler hooks and
// the cached profiler label. Py_CLEAR sets each global to NULL before dropping
// its reference, so code run by a destructor never sees a dangling pointer.
// Returns a new reference to None.
static PyObject* clear_profiler_hooks(PyObject* dummy, PyObject* args) {
  Py_CLEAR(profiler_start_hook);
  Py_CLEAR(profiler_end_hook);
  Py_CLEAR(guard_profiler_name_str);
  Py_RETURN_NONE;
}
// Python entry point: set_profiler_hooks(start, end). Installs the hooks
// called around cache lookups; passing None for either one clears both. In
// every case the cached "TorchDynamo Cache Lookup" label passed to the start
// hook is (re)created.
static PyObject* set_profiler_hooks(PyObject* dummy, PyObject* args) {
  PyObject* start = NULL;
  PyObject* end = NULL;
  if (!PyArg_ParseTuple(args, "OO", &start, &end)) {
    return NULL;
  }
  if (start == Py_None || end == Py_None) {
    // clear_profiler_hooks releases the current hooks itself, so they must
    // not also be released here. It returns a new reference to None.
    Py_XDECREF(clear_profiler_hooks(NULL, NULL));
  } else {
    // Publish the new hooks before releasing the old ones.
    PyObject* old_start = profiler_start_hook;
    PyObject* old_end = profiler_end_hook;
    Py_INCREF(start);
    Py_INCREF(end);
    profiler_start_hook = start;
    profiler_end_hook = end;
    Py_XDECREF(old_start);
    Py_XDECREF(old_end);
  }
  PyObject* label = Py_BuildValue("s", "TorchDynamo Cache Lookup");
  if (label == NULL) {
    return NULL;
  }
  PyObject* old_label = guard_profiler_name_str;
  guard_profiler_name_str = label;
  Py_XDECREF(old_label);
  Py_RETURN_NONE;
}
// Functions exported by torch._C._dynamo.eval_frame; all take positional
// arguments (METH_VARARGS).
static PyMethodDef _methods[] = {
{"set_eval_frame", set_eval_frame_py, METH_VARARGS, NULL},
{"reset_code", reset_code, METH_VARARGS, NULL},
{"unsupported", unsupported, METH_VARARGS, NULL},
{"skip_code", skip_code, METH_VARARGS, NULL},
{"set_guard_fail_hook", set_guard_fail_hook, METH_VARARGS, NULL},
{"set_guard_error_hook", set_guard_error_hook, METH_VARARGS, NULL},
{"set_profiler_hooks", set_profiler_hooks, METH_VARARGS, NULL},
{"clear_profiler_hooks", clear_profiler_hooks, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL}};
// Module definition. m_size is -1: the module keeps its state in the
// file-level globals above rather than in per-module state.
static struct PyModuleDef _module = {
PyModuleDef_HEAD_INIT,
"torch._C._dynamo.eval_frame",
"Module containing hooks to override eval_frame",
-1,
_methods};
PyObject* torch_c_dynamo_eval_frame_init(void) {
extra_index = _PyEval_RequestCodeExtraIndex(ignored);
int result = PyThread_tss_create(&eval_frame_callback_key);
CHECK(result == 0);
Py_INCREF(Py_None);
eval_frame_callback_set(Py_None);
noargs = PyTuple_New(0);
dotzerokey = PyUnicode_InternFromString(".0");
PyObject* module = PyModule_Create(&_module);
#if IS_PYTHON_3_11_PLUS
if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
return NULL;
}
Py_INCREF(&THPPyInterpreterFrameType);
if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
return NULL;
}
#endif
return module;
}