// blob: f989f4945389a91b91aa3fd20e584d7ff18dd34b
#include <torch/csrc/jit/mobile/interpreter.h>
#include <ATen/core/class_type.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/function.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/operator_name.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
namespace torch {
namespace jit {
// Forward declarations; the definitions are not part of this file.
// toString(OpCode) is used by InterpreterState::run to name an unrecognized
// opcode in its error message.
char const* toString(OpCode op);
// Stream insertion for a bytecode Instruction. In this file it is only
// referenced from the commented-out instruction trace inside
// InterpreterState::run.
std::ostream& operator<<(std::ostream& out, Instruction inst);
namespace mobile {
// Builds an interpreter state and immediately enters a frame for `code`
// (see enterFrame), so that run() has a frame to execute.
InterpreterState::InterpreterState(const Code& code) {
enterFrame(code);
}
namespace {
// Per-thread record of the debug handles captured by
// InterpreterState::saveExceptionDebugHandles(), which run() invokes before
// rethrowing an exception. Exposed read-only through
// getInterpretersExceptionDebugHandles().
static thread_local std::vector<DebugHandle> exception_debug_handles_;
// Pushes a freshly created object of class `type` onto `stack`. The object is
// created with one slot per attribute currently declared on `type`, and its
// type is held through a StrongTypePtr built from the type's compilation unit.
void createObject(Stack& stack, const at::ClassTypePtr& type) {
  push(
      stack,
      c10::ivalue::Object::create(
          c10::StrongTypePtr(type->compilation_unit(), type),
          type->numAttributes()));
}
// Pops the top value of `stack` and pushes a bool telling whether the value's
// dynamic type is a subtype of at least one entry of `types`. Candidates are
// tested in order and testing stops at the first match; an empty `types`
// yields false.
void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
  const at::TypePtr value_type = pop(stack).type<c10::DynamicType>();
  bool matched = false;
  for (auto it = types.begin(); it != types.end() && !matched; ++it) {
    matched = value_type->isSubtypeOf(**it);
  }
  push(stack, matched);
}
} // namespace
using namespace at;
// Returns the calling thread's most recently saved exception debug handles
// (one entry per interpreter frame, innermost frame first, as written by
// InterpreterState::saveExceptionDebugHandles()). The reference points at
// thread-local storage that is overwritten by the next save on this thread.
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
return exception_debug_handles_;
}
// Pushes a new frame executing `code` and grows the shared register file by
// the number of registers that code declares.
void InterpreterState::enterFrame(const Code& code) {
  frames_.emplace_back(code);
  const auto grown_size = registers_.size() + code.register_size_;
  registers_.resize(grown_size);
}
// Pops the innermost frame, first shrinking the shared register file by the
// number of registers that frame's code declared (the inverse of enterFrame).
void InterpreterState::leaveFrame() {
  const auto released = frames_.back().getCode().register_size_;
  registers_.resize(registers_.size() - released);
  frames_.pop_back();
}
void InterpreterState::saveExceptionDebugHandles() {
std::vector<DebugHandle> exception_debug_handles;
for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
if (auto handle = frame->getDebugHandle(pc)) {
exception_debug_handles.push_back(*handle);
} else {
exception_debug_handles.push_back(-1);
}
}
exception_debug_handles_ = std::move(exception_debug_handles);
}
// Invokes `f` on `stack`. The callback handed to Function::call enters a new
// interpreter frame for the code it is given; call()'s boolean result says
// whether such a frame was pushed. Afterwards the calling frame is advanced
// past the call instruction: that is the top frame when no frame was pushed,
// and the one directly beneath the top otherwise.
void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
  auto on_new_frame = [this](const mobile::Code& code) { enterFrame(code); };
  const bool pushed_frame = f.call(stack, on_new_frame);
  auto caller = frames_.rbegin();
  if (pushed_frame) {
    ++caller;
  }
  caller->step();
}
// Executes bytecode, one instruction per loop iteration, starting at the
// current pc of the innermost frame and continuing until a RET pops the last
// remaining frame.
//
// @param stack  Operand stack shared by all frames; instructions pop their
//               inputs from it and push their results onto it.
// @return       Always false (the only normal exit is the RET of the outermost
//               frame).
// @throws       Rethrows whatever an instruction throws. Before rethrowing,
//               the debug handles of all active frames are saved so they can
//               be queried through getInterpretersExceptionDebugHandles().
//               Malformed bytecode operands (out-of-range operator, constant
//               or type indices, or operand counts exceeding the stack) are
//               reported as exceptions instead of reading out of bounds.
bool InterpreterState::run(Stack& stack) {
  // Scope guard for the RecordFunction enable flag: turns it on if it was off
  // and restores the previous state when the scope is left, no matter whether
  // that happens by falling through, `continue`, `return` or an exception.
  struct RecordFunctionEnabledScope {
    RecordFunctionEnabledScope()
        : was_enabled_(at::isRecordFunctionEnabled()) {
      if (!was_enabled_) {
        // enable only for the RecordFunction
        at::enableRecordFunction(true);
      }
    }
    ~RecordFunctionEnabledScope() {
      if (!was_enabled_) {
        at::enableRecordFunction(false);
      }
    }
    RecordFunctionEnabledScope(const RecordFunctionEnabledScope&) = delete;
    RecordFunctionEnabledScope& operator=(const RecordFunctionEnabledScope&) =
        delete;

    const bool was_enabled_;
  };

  while (true) {
    try {
      auto& frame = frames_.back();
      const auto& code = frame.getCode();
      const auto pc = frame.getPC();
      auto inst = frame.getInstruction();
      // If no valid debug handle found then just log pc.
      // This is possible when we did not save debug handles
      DebugHandle debug_handle = pc;
      if (auto handle = frame.getDebugHandle()) {
        debug_handle = *handle;
      }
      // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
      // if (inst.op == OP) {
      //   std::cout << ", " << code.op_names_[inst.X].name;
      //   if (!code.op_names_[inst.X].overload_name.empty()) {
      //     std::cout << "." << code.op_names_[inst.X].overload_name;
      //   }
      // }
      // std::cout << std::endl;

      // TODO(iliacher): remove the workaround after RecordFunction is in
      // Dispatcher
      // Check with iliacher if has been done.
      // NOTE: this recordFunction logic takes up ~2-3% of cpu cycles in some
      // workflows. do we need it and/or can we opt-out of
      // isRecordFunctionEnabled with a macro? if we delete it, things appear to
      // work just fine.
      // The guard restores the previous enable state on every exit path of
      // this iteration, including RET and thrown exceptions.
      const RecordFunctionEnabledScope record_function_scope;

      switch (inst.op) {
        case OP: {
          if (at::hasGlobalCallbacks()) {
            if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>(
                    c10::ThreadLocalDebugInfo::get(
                        c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) {
              mobile_debug_info->setOpIdx(pc);
            }
          }
          // .at() rejects operator indices outside the code's tables.
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_.at(inst.X).name, debug_handle, stack);
          code.operators_.at(inst.X)(stack);
          frame.step();
        } break;
        case OPN: {
          // Variadic operator: the argument count is passed on the stack.
          stack.push_back(inst.N);
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_.at(inst.X).name, debug_handle, stack);
          code.operators_.at(inst.X)(stack);
          frame.step();
        } break;
        case CALL: {
          auto& function = *frame.getCode().functions_.at(inst.X);
          callFunction(function, stack);
        } break;
        case INTERFACE_CALL: {
          // The receiver is the first of the N arguments on the stack, so N
          // must be non-zero and must not exceed the stack depth.
          TORCH_CHECK(
              inst.N > 0 && static_cast<size_t>(inst.N) <= stack.size(),
              "INTERFACE_CALL N=",
              inst.N,
              " not in range [1, ",
              stack.size(),
              "]");
          torch::jit::Function& method =
              peek(stack, 0, inst.N)
                  .toObject()
                  ->type()
                  ->getMethod(code.constants_.at(inst.X).toStringRef());
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              method.name(), debug_handle, stack);
          callFunction(method, stack);
        } break;
        case LOAD:
          stack.emplace_back(reg(inst.X));
          frame.step();
          break;
        case MOVE:
          stack.emplace_back(std::move(reg(inst.X)));
          frame.step();
          break;
        case STORE:
          reg(inst.X) = pop(stack);
          frame.step();
          break;
        case STOREN:
          // Pops N values into registers X .. X+N-1, last register first.
          for (size_t i = inst.N; i > 0; --i) {
            reg(inst.X + i - 1) = pop(stack);
          }
          frame.step();
          break;
        case DROP:
          pop(stack);
          frame.step();
          break;
        case DROPR:
          reg(inst.X) = IValue();
          frame.step();
          break;
        case LOADC:
          stack.emplace_back(code.constants_.at(inst.X));
          frame.step();
          break;
        case GET_ATTR: {
          auto userObj = pop(stack).toObject();
          auto value = userObj->getSlot(inst.X);
          push(stack, std::move(value));
          frame.step();
        } break;
        case SET_ATTR: {
          auto v = pop(stack);
          auto userObj = pop(stack).toObject();
          // Mobile only: since the number of slots is not known, resize the
          // numAttributes before setSlot.
          // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
          while (userObj->type()->numAttributes() <= inst.X) {
            std::stringstream ss;
            ss << userObj->type()->numAttributes();
            userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
          }
          userObj->setSlot(inst.X, std::move(v));
          frame.step();
        } break;
        case JF:
          // Jump by X when the popped condition is false, else fall through.
          frame.jump(pop(stack).toBool() ? 1 : inst.X);
          break;
        case JMP:
          frame.jump(inst.X);
          break;
        case LOOP: {
          // stack: iteration_count, max_iter, cond, loop_carried_deps...
          // The loop header occupies N + 1 stack entries, three of which are
          // bookkeeping, so N must be at least 2 and must fit on the stack.
          TORCH_CHECK(
              inst.N >= 2 && static_cast<size_t>(inst.N) + 1 <= stack.size(),
              "LOOP N=",
              inst.N,
              " is invalid for a stack of size ",
              stack.size());
          auto sframe = stack.end() - (inst.N + 1);
          int64_t trip_count = sframe[0].toInt();
          int64_t max_trip_count = sframe[1].toInt();
          bool cond = sframe[2].toBool();
          if (trip_count < max_trip_count && cond) {
            sframe[2] = trip_count;
            sframe[0] = trip_count + 1;
            frame.step();
          } else {
            // Loop finished: slide the loop-carried values down over the
            // three bookkeeping entries and leave the loop body.
            size_t n_loop_carried = inst.N - 2;
            for (const auto i : c10::irange(n_loop_carried)) {
              sframe[i] = std::move(sframe[i + 3]);
            }
            drop(stack, 3); // iteration_count, max_iter, cond
            frame.jump(inst.X);
          }
        } break;
        case RET:
          // `frame` and `code` dangle after leaveFrame(); they must not be
          // touched again in this iteration.
          leaveFrame();
          if (frames_.size() > 0) {
            continue;
          }
          return false;
        case LIST_CONSTRUCT: {
          listConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case LIST_UNPACK: {
          listUnpack(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_CONSTRUCT: {
          tupleConstruct(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_SLICE: {
          tupleSlice(stack, inst.X, inst.X + inst.N);
          frame.step();
        } break;
        case TUPLE_INDEX: {
          tupleIndex(stack);
          frame.step();
        } break;
        case RAISE_EXCEPTION: {
          raiseExceptionWithMessage(stack);
          frame.step();
        } break;
        case __IS__: {
          is(stack);
          frame.step();
        } break;
        case UN_INITIALIZED: {
          unInitialized(stack);
          frame.step();
        } break;
        case __ISNOT__: {
          isNot(stack);
          frame.step();
        } break;
        case FORMAT: {
          format(stack, inst.X);
          frame.step();
        } break;
        case DEVICE: {
          device(stack);
          frame.step();
        } break;
        case DTYPE: {
          dtype(stack);
          frame.step();
        } break;
        case DIM: {
          dim(stack);
          frame.step();
        } break;
        case __NOT__: {
          _not(stack);
          frame.step();
        } break;
        case DICT_INDEX: {
          dictIndex(stack);
          frame.step();
        } break;
        case TO_LIST: {
          toList(stack);
          frame.step();
        } break;
        case NUM_TO_TENSOR: {
          numToTensorScalar(stack);
          frame.step();
        } break;
        case IS_CUDA: {
          isCuda(stack);
          frame.step();
        } break;
        case DICT_CONSTRUCT: {
          dictConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case NAMED_TUPLE_CONSTRUCT: {
          namedTupleConstruct(stack, code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case CREATE_OBJECT: {
          auto type = code.types_.at(inst.X)->expect<c10::ClassType>();
          createObject(stack, type);
          frame.step();
        } break;
        case ISINSTANCE: {
          // .at() validates the first candidate; the check below makes sure
          // the remaining N - 1 candidates are inside the type table too.
          const auto& first_candidate = code.types_.at(inst.X);
          TORCH_CHECK(
              static_cast<size_t>(inst.X) + static_cast<size_t>(inst.N) <=
                  code.types_.size(),
              "ISINSTANCE type range [",
              inst.X,
              ", ",
              inst.X + inst.N,
              ") exceeds the type table of size ",
              code.types_.size());
          at::ArrayRef<TypePtr> types(&first_candidate, inst.N);
          isinstance(stack, types);
          frame.step();
        } break;
        case WARN: {
          drop(stack, 1);
          // Note: Please don't move the pop(stack) code below into the
          // TORCH_WARN macro since TORCH_WARN fails to evaluate its arguments
          // when STRIP_ERROR_MESSAGES is defined (which happens for production
          // mobile builds). This will cause the stack to be in an inconsistent
          // state. It has previously resulted in a SEV (S22350).
          const auto& sref = stack.back().toStringRef();
          TORCH_WARN(sref);
          stack.pop_back();
          frame.step();
        } break;
        default:
          AT_ERROR(toString(inst.op), " is invalid.");
      }
      // This exception must be caught first as it derived from c10::Error
    } catch (c10::BackendRuntimeException& e) {
      saveExceptionDebugHandles();
      TORCH_RETHROW(e);
    } catch (c10::Error& error) {
      // Reason for catching and rethrowing the error is so that we can
      // set the exception pc that is queried later
      saveExceptionDebugHandles();
      TORCH_RETHROW(error);
    } catch (...) {
      saveExceptionDebugHandles();
      throw;
    }
    // for (auto val : stack) {
    //   if (val.isTensor()) {
    //     std::cout << val.toTensor().sizes() << std::endl;
    //   } else {
    //     std::cout << val << std::endl;
    //   }
    // }
  }
  return false;
}
// Returns the register addressed by `reg`. Registers are numbered backwards
// from the end of the shared register file: index 1 is the last element of
// registers_, index registers_.size() is the first.
//
// @param reg  1-based index counted from the end of registers_.
// @return     Mutable reference to the addressed register.
// @throws     c10::Error (via TORCH_CHECK) when `reg` is 0 or exceeds the
//             number of registers, instead of dereferencing an out-of-range
//             iterator.
IValue& InterpreterState::reg(size_t reg) {
  TORCH_CHECK(
      reg > 0 && reg <= registers_.size(), "Invalid register index: ", reg);
  return *(registers_.end() - reg);
}
} // namespace mobile
} // namespace jit
} // namespace torch