#include <torch/csrc/jit/runtime/interpreter.h>
#include <ATen/Parallel.h>
#include <ATen/core/ivalue.h>
#include <ATen/record_function.h>
#include <c10/core/thread_pool.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/interpreter/code_impl.h>
#include <torch/csrc/jit/runtime/interpreter/frame.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/script_profile.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <string>
#ifdef USE_RPC
#include <torch/csrc/distributed/autograd/context/container.h>
using torch::distributed::autograd::DistAutogradContainer;
#endif
#include <exception>
#include <memory>
#include <mutex>
#include <ostream>
#include <stdexcept>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
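// When torch_jit_enable_rethrow_caught_exception (defined below) is set, an
// exception caught inside runImpl() is rethrown as-is (or, if a Future is
// attached, recorded on it via setError) instead of being wrapped into the
// usual "failed in the TorchScript interpreter" error; see the catch block
// in runImpl().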
C10_DEFINE_bool(
torch_jit_enable_rethrow_caught_exception,
false,
"enable rethrowing caught exception");
namespace torch {
namespace jit {
using CodeImpl = interpreter::CodeImpl;
// Before we translate to interpreter instructions, we do
// some preprocessing of the graph to turn it into a form that is closer
// to what the instructions will look like.
// In particular we:
// * Compute whether an input to a node is the last use, so we can issue a
//   MOVE rather than a LOAD instruction.
// * Insert Drop nodes for any value that is unused, creating a dummy use
//   that causes the interpreter to free the value.
//   A Drop node just pops its input off the stack to ensure the interpreter
//   releases references to values that are never used. Drop nodes are also
//   inserted when the last use of a value is in some conditionally run control
//   flow (e.g. one side of an If) and the interpreter must free the value only
//   after the control flow has reconverged.
// Outputs are:
// * graph - the post-processed copy of g
// * move_flags[n] - a list of booleans, one for each input,
// indicating whether this is the last use of the value. The interpreter
// should generate a move rather than a copy in this case.
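//
// For example (illustrative only): if a value %x is used by two later nodes,
// the earlier use is compiled to LOAD (copy the register onto the stack)
// while the final use becomes MOVE, which steals the register so the value
// can be released as soon as its last use has run.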
TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
if (!t.defined()) {
return TensorType::get()->withUndefined();
}
auto r = TensorType::create(t);
if (!at::GradMode::is_enabled()) {
return r->withRequiresGrad(false);
}
return r;
}
namespace {
inline int64_t getDistAutogradContextId() {
#ifdef USE_RPC
return DistAutogradContainer::currentContextId();
#else
return 0;
#endif
}
} // namespace
thread_local InterpreterStateImpl* tls_int_state_ptr_ = nullptr;
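// RAII guard that publishes the currently running interpreter via
// tls_int_state_ptr_ and restores the previous pointer on destruction.
// runImpl() installs one so that currentCallstack() and
// currentModuleHierarchy() (defined at the bottom of this file) can inspect
// the interpreter executing on the current thread.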
struct TLSCurrentInterpreterGuard {
TLSCurrentInterpreterGuard(InterpreterStateImpl* state)
: prev_state_(tls_int_state_ptr_) {
tls_int_state_ptr_ = state;
}
~TLSCurrentInterpreterGuard() {
tls_int_state_ptr_ = prev_state_;
}
private:
InterpreterStateImpl* prev_state_;
};
// InterpreterStateImpl holds the state used to run a Code.
struct InterpreterStateImpl : c10::intrusive_ptr_target {
InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
: taskLauncher_(std::move(taskLauncher)) {
enterFrame(code, 0);
}
private:
using Frame = torch::jit::interpreter::Frame;
struct WarnedNodes {
public:
// Inserts idx into warned_nodes_ and returns a boolean indicating whether
// the insertion actually happened (i.e. idx was not already in the set).
bool insert(int32_t idx) {
std::unique_lock<std::mutex> lock(mutex_);
return warned_nodes_.insert(idx).second;
}
private:
std::mutex mutex_;
std::unordered_set<int32_t> warned_nodes_;
};
WarnedNodes warned_nodes_;
// if we need to suspend, where do we reset the stack?
// answer: to where it was when we were called, not
// including any inputs to this function
int64_t stack_start_ = -1;
c10::intrusive_ptr<Future> future_;
TaskLauncher taskLauncher_;
// This holds all the tensors for this interpreter run.
// We don't bother minimizing the size of this vector, since the extra
// memory used by the pointers in it will be small.
// Instead we are very aggressive about releasing tensors when they become
// dead to make sure memory management happens efficiently. We optimize for
// the case where derivatives are run with retain_graph=False; when it is
// True, the interpreter and this array get copied. If this ever becomes a
// bottleneck, then we _should_ consider minimizing the total number of
// registers.
std::vector<IValue> registers;
// A stack of objects that have been __enter__'d.
std::vector<IValue> entered_objects;
std::vector<Frame> frames;
c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
c10::raw::intrusive_ptr::incref(this);
return c10::intrusive_ptr<InterpreterStateImpl>::reclaim(this);
}
void enterFrame(const Code& code, size_t base_pointer) {
frames.emplace_back(Frame{code.pImpl, 0, base_pointer, c10::nullopt});
registers.resize(registers.size() + code.pImpl->register_size_);
}
void leaveFrame() {
registers.resize(registers.size() - frames.back().function->register_size_);
frames.pop_back();
}
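// Runs `f` on the current stack. If f.call() executes the callee as
// interpreted Code, the callback below pushes a new frame, so with `next`
// set we advance the pc of the calling frame (now second from the top) past
// the call instruction; otherwise the callee ran inline and the current
// frame's pc is advanced instead.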
void callFunction(
Function& f,
Stack& stack,
c10::optional<size_t> bailOut = c10::nullopt,
bool next = true) {
bool newFrame = f.call(stack, bailOut, [&](const Code& code) {
enterFrame(code, stack.size() - code.num_inputs());
checkAndStartRecordFunction(frames.back(), stack);
});
if (next) {
(frames.rbegin() + (newFrame ? 1 : 0))->pc++;
}
}
// Registers are addressed relative to the end of the register list, so that
// when we call functions we are referring to the registers of the currently
// executing function.
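// For example, reg(1) is registers.back(): the registers of the innermost
// frame occupy the last register_size_ slots of `registers` (see
// enterFrame()/leaveFrame() above), so counting back from the end always
// addresses the current frame.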
IValue& reg(size_t reg) {
return *(registers.end() - reg);
}
void dump(std::ostream& out, const Stack& stack) const {
out << "Stack:\n";
for (const auto& val : stack) {
out << val;
out << "\n";
}
}
#if defined(__GNUC__) || defined(__clang__)
#define JIT_USE_COMPUTED_GOTO
#endif
// Primitives for making interpreter internal state transitions.
// We maintain two local variables as the internal interpreter state:
// `frame` is the current frame that the interpreter operates on.
// `inst` is the current instruction pointed to by the program counter.
//
// Instruction blocks should always be declared through the `INST` macro and
// the instruction body should always start with an `INST_GUARD` declaration.
// Blocks should also be terminated with either `INST_NEXT` (to go to the
// next instruction) or `INST_DISPATCH` (to jump to a computed position
// using `INST_FETCH`).
#define INST_FETCH(X) (frame.function->instructions_[frame.pc += (X)])
#define INST_GUARD \
profiling::InstructionSpan span { \
*frame.function->instructions_source()[frame.pc] \
}
#if defined(JIT_USE_COMPUTED_GOTO)
#define INST(NAME) \
NAME: \
label_##NAME
#define INST_DISPATCH goto* dispatch_table[inst.op]
#else
#define INST(NAME) NAME
#define INST_DISPATCH break
#endif
#define INST_NEXT \
inst = INST_FETCH(1); \
INST_DISPATCH
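// As a concrete example, the LOAD handler in runImpl() below follows the
// required pattern:
//
//   case INST(LOAD): {
//     INST_GUARD;
//     stack.emplace_back(reg(inst.X));
//   }
//   INST_NEXT;
//
// With computed gotos enabled, INST(LOAD) also emits a label_LOAD jump
// target, and INST_NEXT fetches the next instruction and jumps through
// dispatch_table; otherwise INST_NEXT fetches the next instruction and
// breaks out of the switch so the enclosing loop dispatches again.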
bool runImpl(Stack& stack) {
// If we have never run before, then we might have to return the
// stack when we suspend; record where it starts so we return the
// right portion of the stack.
if (stack_start_ == -1) {
TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
stack_start_ = stack.size() - frames.back().function->n_inputs;
} else {
// during restarts, all of the stack is always our own, so we leave
// nothing
stack_start_ = 0;
}
TLSCurrentInterpreterGuard g(this);
if (frames.back().pc == 0 && stack_start_ == 0) {
checkAndStartRecordFunction(frames.back(), stack);
}
#if defined(JIT_USE_COMPUTED_GOTO)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
static void* dispatch_table[] = {
#define DISPATCH_TABLE_ENTRY(op, _) &&label_##op,
FORALL_OPCODES(DISPATCH_TABLE_ENTRY)
#undef DISPATCH_TABLE_ENTRY
};
#endif
try {
while (true) {
Frame& frame = frames.back();
Instruction inst = INST_FETCH(0);
switch (inst.op) {
case INST(ENTER): {
INST_GUARD;
const auto& obj = peek(stack, 0, 1);
TORCH_INTERNAL_ASSERT(obj.isObject());
entered_objects.push_back(obj);
}
INST_NEXT;
case INST(EXIT): {
INST_GUARD;
auto obj = entered_objects.back().toObject();
auto& f = obj->type()->getMethod("__exit__");
push(stack, std::move(obj));
entered_objects.pop_back();
push(stack, IValue());
push(stack, IValue());
push(stack, IValue());
callFunction(f, stack);
continue;
}
case INST(OP): {
INST_GUARD;
#ifndef NDEBUG
size_t init_size = stack.size();
#endif
frame.function->operator_table_[inst.X](stack);
#ifndef NDEBUG
frame.function->assert_stack_size(inst.X, init_size, stack.size());
#endif
}
INST_NEXT;
case INST(OPN): {
INST_GUARD;
stack.emplace_back(inst.N);
#ifndef NDEBUG
size_t init_size = stack.size();
#endif
frame.function->operator_table_[inst.X](stack);
#ifndef NDEBUG
frame.function->assert_stack_size(inst.X, init_size, stack.size());
#endif
}
INST_NEXT;
case INST(LOAD): {
INST_GUARD;
stack.emplace_back(reg(inst.X));
}
INST_NEXT;
case INST(MOVE): {
INST_GUARD;
stack.emplace_back(std::move(reg(inst.X)));
}
INST_NEXT;
case INST(STORE): {
INST_GUARD;
reg(inst.X) = pop(stack);
}
INST_NEXT;
case INST(STOREN): {
INST_GUARD;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(stack.size() >= inst.N);
for (size_t i = inst.N; i > 0; --i) {
reg(inst.X + i - 1) = pop(stack);
}
}
INST_NEXT;
case INST(DROP): {
INST_GUARD;
stack.pop_back();
}
INST_NEXT;
case INST(DROPR): {
INST_GUARD;
reg(inst.X) = IValue();
}
INST_NEXT;
case INST(LOADC): {
INST_GUARD;
stack.emplace_back(frame.function->constant_table_[inst.X]);
}
INST_NEXT;
case INST(GET_ATTR): {
INST_GUARD;
const auto& userObj = stack.back().toObjectRef();
stack.back() = userObj.getSlot(inst.X);
}
INST_NEXT;
case INST(SET_ATTR): {
INST_GUARD;
auto v = pop(stack);
auto& userObj = stack.back().toObjectRef();
userObj.setSlot(inst.X, std::move(v));
stack.pop_back();
}
INST_NEXT;
case INST(JF): {
INST_GUARD;
if (pop(stack).toBool()) {
inst = INST_FETCH(1);
} else {
inst = INST_FETCH(inst.X);
}
}
INST_DISPATCH;
case INST(JMP): {
INST_GUARD;
inst = INST_FETCH(inst.X);
}
INST_DISPATCH;
case INST(LOOP): {
INST_GUARD;
// stack: iteration_count, max_iter, cond, loop_carried_deps...
auto fr = stack.end() - (inst.N + 1);
int64_t trip_count = fr[0].toInt();
int64_t max_trip_count = fr[1].toInt();
bool cond = fr[2].toBool();
if (trip_count < max_trip_count && cond) {
fr[2] = trip_count;
fr[0] = trip_count + 1;
inst = INST_FETCH(1);
} else {
size_t n_loop_carried = inst.N - 2;
for (const auto i : c10::irange(n_loop_carried)) {
fr[i] = std::move(fr[i + 3]);
}
drop(stack, 3); // iteration_count, max_iter, cond
inst = INST_FETCH(inst.X);
}
}
INST_DISPATCH;
case INST(CALL): {
INST_GUARD;
Function* fn = frame.function->function_table_[inst.X];
callFunction(*fn, stack);
continue;
}
case INST(INTERFACE_CALL): {
INST_GUARD;
// Note the hash table lookup to find the function. This could be
// optimized further if necessary, e.g. by caching parts of the
// hashing computation or by storing the offset when the object is
// turned into an interface.
//
// Consider passing
// `frames.back().function->remaining_bailout_depth_` into
// `get_executor().getPlanFor()` to propagate the caller's depth
// restrictions onto children. While this strategy has the potential
// to reduce the number of compilations for overly dynamic callers,
// we might miss opportunities where a caller is dynamic but a
// callee gets stable arguments.
Function& function =
peek(stack, 0, inst.N)
.toObject()
->type()
->getMethod(
frame.function->constant_table_[inst.X].toStringRef());
callFunction(function, stack);
continue;
}
case INST(RET): {
if (frames.size() > 1) {
leaveFrame();
continue;
}
if (future_) {
auto num_outputs = frames.back().function->n_outputs;
if (num_outputs == 1) {
future_->markCompleted(stack.back());
} else {
future_->markCompleted(
c10::ivalue::Tuple::create(jit::last(stack, num_outputs)));
}
}
// destroy the last frame and call RecordFunction's end callbacks
leaveFrame();
return false;
}
case INST(WAIT): {
INST_GUARD;
auto future = stack.back().toFuture();
if (!future->completed()) {
getOrCreateFuture();
// callback needs to be a struct rather than a lambda so that
// we can move the stack to the other thread
struct Callback {
Callback(
c10::intrusive_ptr<InterpreterStateImpl> state,
Stack stack)
: stateImpl_(std::move(state)),
state_(stateImpl_),
stack_(std::move(stack)),
dist_autograd_context_id_(getDistAutogradContextId()) {
state_ = InterpreterState(stateImpl_);
}
void operator()(c10::ivalue::Future& /* unused */) {
stateImpl_->taskLauncher_(InterpreterContinuation(
state_,
std::move(stack_),
dist_autograd_context_id_,
std::move(tls_state_)));
}
private:
c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
InterpreterState state_;
Stack stack_;
int64_t dist_autograd_context_id_;
// preserve the original ThreadLocalState
at::ThreadLocalState tls_state_;
};
// We are suspending, so we need to reset the stack to where it was
// when we were called, not including our inputs. If the stack
// started empty, we can avoid a true copy by swapping, which leaves
// the original stack empty.
Stack copied;
if (stack_start_ == 0) {
copied.swap(stack);
} else {
copied.insert(
copied.begin(),
std::make_move_iterator(stack.begin() + stack_start_),
std::make_move_iterator(stack.end()));
stack.resize(stack_start_);
}
// save pc into the frame so we continue here when restored
future->addCallback(
Callback(intrusive_from_this(), std::move(copied)));
return true;
}
stack.pop_back();
stack.emplace_back(future->value());
}
INST_NEXT;
case INST(PROFILE_OP): {
INST_GUARD;
auto& frame_id_ref = frame.id;
if (!frame_id_ref.has_value()) {
frame_id_ref = Frame::genId();
}
const auto& callback =
frame.function->profile_function_table_[inst.X];
push(stack, c10::IValue{static_cast<int64_t>(*frame_id_ref)});
callback(stack);
}
INST_NEXT;
case INST(FAIL_GUARD): {
INST_GUARD;
// patch FAIL_GUARD back to GUARD
GRAPH_DEBUG(
"Bailout ", inst.X, " triggered via bailout_requests_!");
frame.function->instructions_[frame.pc].op = GUARD;
push(stack, false);
}
INST_NEXT;
case INST(TYPECHECK): {
INST_GUARD;
unsigned num_inputs = inst.N, i = 0;
TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
// Check every input's shape against profiled (expected) shape.
for (i = 0; i < num_inputs; i++) {
auto& input = peek(stack, i, num_inputs);
auto& t = input.toTensor();
const TypePtr& expected = frame.function->type_table_[inst.X + i];
auto* expected_type = expected->castRaw<TensorType>();
if (t.defined() && !expected_type->matchTensor(t)) {
push(stack, false);
break;
}
}
if (i == num_inputs) {
push(stack, true);
}
}
INST_NEXT;
case INST(GUARD): {
INST_GUARD;
if (!stack.back().isTensor()) {
// stack.back() is an Uninitialized IValue and this is a guard
// on a block output. Uninitialized IValues are never used
// so it's safe to pass this guard check
push(stack, true);
} else {
auto& t = stack.back().toTensor();
const TypePtr& expected = frame.function->type_table_[inst.X];
auto* expected_type = expected->castRaw<TensorType>();
if (t.defined() &&
!frames.back().symbols2dims.bindSymbolicShapes(
t.sizes(), expected_type->symbolic_sizes())) {
push(stack, false);
} else {
push(stack, expected_type->matchTensor(t));
}
}
}
INST_NEXT;
case INST(TAIL_CALL): {
INST_GUARD;
GRAPH_DEBUG("running TAIL_CALL for ", inst.X);
frame.function->function_table_[inst.X]->ensure_defined();
size_t remaining_bailout_depth =
frame.function->remaining_bailout_depth_ > 0
? frame.function->remaining_bailout_depth_ - 1
: 0;
auto& f = *frame.function->function_table_[inst.X];
size_t num_inputs = f.num_inputs();
size_t base_pointer = frame.base_pointer;
TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
size_t inputs_start = stack.size() - num_inputs;
for (const auto i : c10::irange(num_inputs)) {
stack.at(base_pointer + i) =
std::move(stack.at(inputs_start + i));
}
stack.resize(base_pointer + num_inputs);
leaveFrame();
callFunction(f, stack, remaining_bailout_depth, false);
continue;
}
case INST(LIST_UNPACK): {
INST_GUARD;
listUnpack(stack, inst.X);
}
INST_NEXT;
case INST(TUPLE_CONSTRUCT): {
INST_GUARD;
tupleConstruct(stack, inst.X);
}
INST_NEXT;
case INST(TUPLE_SLICE): {
INST_GUARD;
tupleSlice(stack, inst.X, inst.X + inst.N);
}
INST_NEXT;
case INST(NAMED_TUPLE_CONSTRUCT): {
INST_GUARD;
namedTupleConstruct(
stack,
frame.function->type_table_[inst.X]->expect<TupleType>(),
inst.N);
}
INST_NEXT;
case INST(LIST_CONSTRUCT): {
INST_GUARD;
const auto& type =
frame.function->type_table_[inst.X]->expectRef<ListType>();
listConstruct(stack, type, inst.N);
}
INST_NEXT;
case INST(DICT_CONSTRUCT): {
INST_GUARD;
const auto& type =
frame.function->type_table_[inst.X]->expectRef<DictType>();
dictConstruct(stack, type, inst.N);
}
INST_NEXT;
case INST(CREATE_OBJECT): {
INST_GUARD;
auto type =
frame.function->type_table_[inst.X]->expect<ClassType>();
createObject(stack, type);
}
INST_NEXT;
case INST(ISINSTANCE): {
INST_GUARD;
at::ArrayRef<TypePtr> types(
&frame.function->type_table_[inst.X],
&frame.function->type_table_[inst.X] + inst.N);
isinstance(stack, types);
}
INST_NEXT;
case INST(TUPLE_INDEX): {
INST_GUARD;
tupleIndex(stack);
}
INST_NEXT;
case INST(RAISE_EXCEPTION): {
INST_GUARD;
raiseExceptionWithMessage(stack);
}
INST_NEXT;
case INST(UNCHECKED_CAST): {
INST_GUARD;
noop(stack);
}
INST_NEXT;
case INST(__IS__): {
INST_GUARD;
is(stack);
}
INST_NEXT;
case INST(UN_INITIALIZED): {
INST_GUARD;
unInitialized(stack);
}
INST_NEXT;
case INST(__ISNOT__): {
INST_GUARD;
isNot(stack);
}
INST_NEXT;
case INST(FORMAT): {
INST_GUARD;
format(stack, inst.X);
}
INST_NEXT;
case INST(DEVICE): {
INST_GUARD;
device(stack);
}
INST_NEXT;
case INST(DTYPE): {
INST_GUARD;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!stack.empty());
dtype(stack);
}
INST_NEXT;
case INST(DIM): {
INST_GUARD;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!stack.empty());
dim(stack);
}
INST_NEXT;
case INST(__NOT__): {
INST_GUARD;
_not(stack);
}
INST_NEXT;
case INST(DICT_INDEX): {
INST_GUARD;
dictIndex(stack);
}
INST_NEXT;
case INST(TO_LIST): {
INST_GUARD;
toList(stack);
}
INST_NEXT;
case INST(NUM_TO_TENSOR): {
INST_GUARD;
numToTensorScalar(stack);
}
INST_NEXT;
case INST(IS_CUDA): {
INST_GUARD;
isCuda(stack);
}
INST_NEXT;
case INST(FORK): {
INST_GUARD;
// Move inputs to a separate stack
auto& forked_fn =
toGraphFunction(*frame.function->function_table_[inst.X]);
InterpreterState forked_interpreter(
forked_fn.get_executor().getPlanFor(stack).code, taskLauncher_);
InterpreterContinuation continuation(
forked_interpreter,
Stack(stack.end() - inst.N, stack.end()),
getDistAutogradContextId());
drop(stack, inst.N);
push(stack, forked_interpreter.getFuture());
taskLauncher_(std::move(continuation));
}
INST_NEXT;
case INST(AWAITABLE): {
INST_GUARD;
auto fn_ptr = frame.function->function_table_[inst.X];
auto& fn = toGraphFunction(*fn_ptr);
auto num_outputs = fn.graph()->outputs().size();
TypePtr out_type;
if (num_outputs == 1) {
out_type = fn.graph()->outputs()[0]->type();
} else {
std::vector<TypePtr> out_types;
for (const auto& o : fn.graph()->outputs()) {
out_types.push_back(o->type());
}
out_type = TupleType::create(out_types);
}
auto args = std::vector<IValue>(stack.end() - inst.N, stack.end());
auto aw = c10::make_intrusive<c10::ivalue::Await>(out_type);
aw->setArgs(std::move(args));
aw->setFn(
[&args = aw->args(),
fn_ptr,
taskLauncher = taskLauncher_]() -> IValue {
auto& fn = toGraphFunction(*fn_ptr);
auto n_out = fn.graph()->outputs().size();
torch::jit::Stack s;
for (const auto& arg : args) {
s.push_back(arg);
}
InterpreterState await_interpreter(
fn.get_executor().getPlanFor(s).code, taskLauncher);
await_interpreter.run(s);
if (n_out == 1) {
return s.back();
}
return c10::ivalue::Tuple::create(jit::last(s, n_out));
});
drop(stack, inst.N);
push(stack, std::move(aw));
}
INST_NEXT;
case INST(WARN): {
INST_GUARD;
// Keep track of which WARN instructions have been executed before;
// we only want to execute each WARN once, to match the default
// Python warning behavior.
bool need_warn = true;
if (inst.X != -1) {
need_warn = warned_nodes_.insert(inst.X);
}
Node* node =
frames.back().function->instructions_source_.at(frame.pc);
auto range = node->sourceRange().source();
if (range->filename()) {
drop(stack, 1);
const auto& msg = stack.back().toStringRef();
if (need_warn) {
auto line = range->starting_line_no() +
range->lineno_for_offset(node->sourceRange().start());
c10::SourceLocation location{
"", range->filename()->c_str(), uint32_t(line)};
// Sends the warning to the warning handler with the
// "verbatim" flag. This flag ensures the warning handler
// will print the exception as configured.
c10::warn(c10::Warning(
c10::UserWarning(), location, msg, /*verbatim=*/true));
}
stack.pop_back();
} else {
const auto& msg = stack.back().toStringRef();
if (need_warn) {
TORCH_WARN(msg);
}
stack.pop_back();
}
}
INST_NEXT;
}
}
} catch (std::exception& e) {
for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
it != end;
++it) {
auto& f = it->toObject()->type()->getMethod("__exit__");
Stack stack;
push(stack, *it);
push(stack, IValue());
push(stack, IValue());
push(stack, IValue());
try {
f.run(stack);
} catch (std::exception& _) {
// TODO(T98048876): Handle `_` correctly.
}
}
if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
if (future_) {
future_->setError(std::current_exception());
return false;
}
throw;
}
auto* jit_exception = dynamic_cast<JITException*>(&e);
// Janky af. See https://github.com/pytorch/pytorch/issues/54612
auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
c10::optional<std::string> python_class_name;
if (jit_exception) {
python_class_name = jit_exception->getPythonClassName();
}
handleError(
e, (bool)jit_exception, not_implemented_error, python_class_name);
return false;
}
}
#undef INST_NEXT
#undef INST_DISPATCH
#undef INST
#undef INST_GUARD
#undef INST_FETCH
#undef JIT_USE_COMPUTED_GOTO
void formatStackTrace(std::ostream& out) {
format_stack_trace(out, callstack());
}
void handleError(
const std::exception& e,
bool is_jit_exception,
c10::NotImplementedError* not_implemented_error,
c10::optional<std::string> python_class_name) {
ExceptionMessage msg(e);
std::ostringstream ss;
std::string class_name =
python_class_name ? *python_class_name : "RuntimeError";
ss << "The following operation failed in the TorchScript interpreter.\n";
formatStackTrace(ss);
ss << class_name << ": " << msg << "\n";
if (future_) {
future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
} else if (is_jit_exception) {
// save the original exception's message when creating a new JITException
throw JITException(ss.str(), python_class_name, e.what());
} else if (not_implemented_error) {
throw c10::NotImplementedError(
ss.str(),
not_implemented_error->backtrace(),
not_implemented_error->caller());
} else {
if (get_cpp_stacktraces_enabled()) {
ss << e.what() << "\n";
}
throw std::runtime_error(ss.str());
}
}
static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
if (!frame.record_function) {
auto step_callbacks = at::getStepCallbacksUnlessEmpty(
at::RecordScope::TORCHSCRIPT_FUNCTION);
if (C10_UNLIKELY(step_callbacks.has_value())) {
auto rec_fn =
std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
if (rec_fn->needsInputs()) {
rec_fn->before(
frame.function->function_name_,
last(stack, frame.function->n_inputs));
} else {
rec_fn->before(frame.function->function_name_);
}
frame.record_function = std::move(rec_fn);
}
}
}
public:
// One way to avoid the overhead of forming strings would be to return
// a vector of frame.function, i.e. CodeImpl*.
// This is not exactly clean, as it would expose internal details of the
// interpreter, but that way we would hold onto the graph/node and Function
// and could create the module hierarchy string for each event in the
// autograd profiler at the end, when consolidating events.
// At the moment the overhead does not seem exorbitantly large.
// Another option would be to return a vector of (string, InlinedCallstackPtrs),
// where the string would contain the function name and the typename of self.
// Format of the returned vector of strings:
// For each frame, the corresponding module name, type and function name
// are in the following format:
// <module-instance-name>(module type)::<function-name>
// Special keys for module-instance-name:
// - TOP: for top level module
// - SELF: When method/function of the frame is associated with
// previous frame's module instance
// - INSTANCE_NAME_UNKNOWN: instance name cannot be figured out
// - CALL_FUNCTION: call to free function
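// For example (illustrative instance/type names), a top-level forward()
// that calls a method on a submodule stored in attribute `sub` would
// produce entries like:
//   TOP(MyModule)::forward
//   sub(MySubmodule)::forward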
std::vector<std::string> moduleHierarchy() const {
std::vector<std::string> module_function_list;
std::string module_hierarchy("TOP");
for (size_t i = 0; i < frames.size(); ++i) {
const Frame& frame = frames[i];
std::string fn_name = frame.function->function_name_;
// For each frame, the type of the class with which the function is
// associated is queried here, and the type name is added to the
// module hierarchy.
const auto& g = frame.function->graph_;
std::string g_self_type;
if (g && !g->inputs().empty()) {
const auto& g_self_type_ptr =
g->inputs()[0]->type()->cast<c10::ClassType>();
if (g_self_type_ptr) {
g_self_type = g_self_type_ptr->name()->qualifiedName();
g_self_type = g_self_type.substr(g_self_type.find_last_of('.') + 1);
}
}
module_hierarchy.append("(")
.append(g_self_type)
.append(")::")
.append(fn_name);
module_function_list.emplace_back(std::move(module_hierarchy));
size_t pc = frame.pc;
// CALL nodes have already advanced the pc, so
// undo that to report the call node
if (i + 1 < frames.size()) {
--pc;
}
Node* node = frame.function->instructions_source_[pc];
if (node->callstack()) {
for (const auto& p : (*node->callstack())->vec()) {
fn_name = std::get<0>(p)->name();
const auto& opt_module_info = std::get<2>(p);
if (opt_module_info.has_value()) {
const auto& module_instance_info = opt_module_info.value();
module_hierarchy = utils::get_module_info(module_instance_info);
module_hierarchy.append("::").append(fn_name);
} else {
// This is likely a call to a free function, not associated with
// any class.
module_hierarchy = "::";
module_hierarchy.append(fn_name);
}
module_function_list.emplace_back(std::move(module_hierarchy));
}
}
module_hierarchy = std::string();
// If this node is a prim::CallMethod node, then the following frame
// will contain the op being executed.
// For such a CallMethod node, we add the object instance name
// associated with it, since the following frame will not have it.
if (node->kind() == prim::CallMethod) {
std::string class_instance_name;
if (node->input(0)->node()->kind() == prim::GetAttr) {
class_instance_name = node->input(0)->node()->s(attr::name);
} else if (
!node->owningGraph()->inputs().empty() &&
node->input(0) == node->owningGraph()->inputs()[0]) {
class_instance_name = "SELF";
} else {
class_instance_name = "INSTANCE_NAME_UNKNOWN";
}
module_hierarchy = std::move(class_instance_name);
} else if (node->kind() == prim::CallFunction) {
auto function_constant = node->input(0)->node();
auto fun_type =
function_constant->output()->type()->expect<FunctionType>();
auto fun_name = fun_type->function()->name();
module_hierarchy = "CALL_FUNCTION::";
module_hierarchy.append(fun_name);
}
}
return module_function_list;
}
std::vector<StackEntry> callstack() const {
std::vector<StackEntry> entries;
for (const auto i : c10::irange(frames.size())) {
const Frame& frame = frames[i];
std::string previous_fn_name = frame.function->function_name_;
size_t pc = frame.pc;
// CALL nodes have already advanced the pc, so
// undo that to report the call node
if (i + 1 < frames.size()) {
--pc;
}
Node* node = frame.function->instructions_source_[pc];
if (node->callstack()) {
for (const auto& p : (*node->callstack())->vec()) {
entries.emplace_back(StackEntry{previous_fn_name, std::get<1>(p)});
previous_fn_name = std::get<0>(p)->name();
}
}
entries.emplace_back(StackEntry{previous_fn_name, node->sourceRange()});
}
return entries;
}
c10::intrusive_ptr<Future> getOrCreateFuture() {
if (!future_) {
future_ =
c10::make_intrusive<Future>(frames.front().function->return_type_);
}
return future_;
}
c10::intrusive_ptr<Future> runAsync(Stack& stack) {
getOrCreateFuture();
runImpl(stack);
return future_;
}
void run(Stack& stack) {
// By the time the continuation completes the frame will be gone, so this
// must be done before calling runImpl().
TORCH_INTERNAL_ASSERT(!frames.empty());
const auto num_outputs = frames.front().function->n_outputs;
if (runImpl(stack)) {
future_->wait();
if (num_outputs == 1) {
push(stack, future_->value());
} else {
auto tuple = future_->value().toTuple();
for (const IValue& value : tuple->elements()) {
push(stack, value);
}
}
}
}
};
std::vector<StackEntry> currentCallstack() {
if (tls_int_state_ptr_) {
auto cs = tls_int_state_ptr_->callstack();
std::reverse(cs.begin(), cs.end());
return cs;
}
return std::vector<StackEntry>();
}
std::vector<std::string> currentModuleHierarchy() {
if (tls_int_state_ptr_) {
return tls_int_state_ptr_->moduleHierarchy();
}
return std::vector<std::string>();
}
std::ostream& operator<<(std::ostream& out, const Code& code) {
out << *code.pImpl->graph_ << "\n";
code.pImpl->dump(out);
return out;
}
Code::Code(
const std::shared_ptr<Graph>& graph,
std::string function_name,
size_t remaining_bailout_depth)
: pImpl(new CodeImpl(
graph,
std::move(function_name),
remaining_bailout_depth)) {}
Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}
Code::~Code() = default;
MobileCode::MobileCode(
const std::shared_ptr<Graph>& graph,
std::string function_name,
bool emit_default_input_instructions,
bool support_default_args_before_out,
bool emit_promoted_ops,
size_t remaining_bailout_depth)
: Code(new interpreter::MobileCodeImpl(
graph,
std::move(function_name),
emit_default_input_instructions,
support_default_args_before_out,
emit_promoted_ops,
remaining_bailout_depth)) {}
MobileCode::~MobileCode() = default;
const std::vector<GraphExecutor*>& Code::grad_executors() {
return pImpl->grad_executors();
}
const std::vector<GraphExecutor*>& Code::diff_graph_op_executors() {
return pImpl->diff_graph_op_executors();
}
size_t Code::num_bailouts() const {
return pImpl->type_table_.size();
}
void Code::request_bailout(size_t index) {
pImpl->request_bailout(index);
}
size_t Code::num_inputs() const {
return pImpl->n_inputs;
}
size_t Code::num_outputs() const {
return pImpl->n_outputs;
}
const std::vector<c10::IValue>& Code::constant_table() const {
return pImpl->constant_table();
}
const std::vector<Instruction>& Code::instructions() const {
return pImpl->instructions();
}
const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
const {
return pImpl->op_to_num_specified_args();
}
const std::vector<Node*>& Code::instructions_source() const {
return pImpl->instructions_source();
}
const std::vector<TypePtr>& Code::type_table() const {
return pImpl->type_table_;
}
size_t Code::register_size() const {
return pImpl->register_size_;
}
InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
: pImpl(c10::make_intrusive<InterpreterStateImpl>(
code,
std::move(taskLauncher))) {}
InterpreterState::~InterpreterState() = default;
void InterpreterState::run(Stack& stack) {
static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
}
c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
return static_cast<InterpreterStateImpl*>(pImpl.get())->runAsync(stack);
}
c10::intrusive_ptr<Future> InterpreterState::getFuture() {
return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
}
InterpreterState::InterpreterState(
c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
: pImpl(std::move(pImpl_)) {}
void InterpreterContinuation::operator()() {
#ifdef USE_RPC
auto prev_dist_id = DistAutogradContainer::currentContextId();
DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_);
#endif
if (tls_state_ != c10::nullopt) {
at::ThreadLocalStateGuard g(*tls_state_);
state.runAsync(stack);
} else {
state.runAsync(stack);
}
#ifdef USE_RPC
DistAutogradContainer::forceCurrentContextId(prev_dist_id);
#endif
}
} // namespace jit
} // namespace torch