#include <torch/csrc/jit/mobile/interpreter.h>

#include <ATen/core/class_type.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/function.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/operator_name.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace torch {
namespace jit {
char const* toString(OpCode op);
std::ostream& operator<<(std::ostream& out, Instruction inst);
namespace mobile {
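// Constructing an InterpreterState pushes the entry frame for `code`. All
// frames share a single flat register file owned by the interpreter.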
InterpreterState::InterpreterState(const Code& code) {
  enterFrame(code);
}

namespace {
static thread_local std::vector<DebugHandle> exception_debug_handles_;
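
// Creates an empty object of the given class type, with one slot per
// attribute, and pushes it onto the stack.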
void createObject(Stack& stack, const at::ClassTypePtr& type) {
  auto userObj = c10::ivalue::Object::create(
      c10::StrongTypePtr(type->compilation_unit(), type),
      type->numAttributes());
  push(stack, std::move(userObj));
}

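// Pops a value from the stack and pushes true if its dynamic type is a
// subtype of any of the candidate types, false otherwise.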
void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
  at::TypePtr ty = pop(stack).type<c10::DynamicType>();
  for (const at::TypePtr& candidate : types) {
    if (ty->isSubtypeOf(*candidate)) {
      push(stack, true);
      return;
    }
  }
  push(stack, false);
}
} // namespace

using namespace at;

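// Returns the debug handles recorded by the most recent interpreter
// exception on this thread, ordered from innermost to outermost frame.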
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
  return exception_debug_handles_;
}

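// Pushes a new frame and grows the shared register file by the number of
// registers the frame's code requires.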
void InterpreterState::enterFrame(const Code& code) {
  frames_.emplace_back(code);
  registers_.resize(registers_.size() + code.register_size_);
}

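// Releases the top frame's registers and pops the frame.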
void InterpreterState::leaveFrame() {
  registers_.resize(
      registers_.size() - frames_.back().getCode().register_size_);
  frames_.pop_back();
}

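// Records, for each live frame, the debug handle of the instruction that was
// executing when an exception was thrown. For all but the innermost frame the
// PC has already advanced past the call instruction, so it is decremented by
// one; -1 marks a frame without a debug handle.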
void InterpreterState::saveExceptionDebugHandles() {
  std::vector<DebugHandle> exception_debug_handles;
  for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
    size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
    if (auto handle = frame->getDebugHandle(pc)) {
      exception_debug_handles.push_back(*handle);
    } else {
      exception_debug_handles.push_back(-1);
    }
  }
  exception_debug_handles_ = std::move(exception_debug_handles);
}

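// Invokes `f`, entering a new frame when `f` carries mobile code. The frame
// that issued the call (the previous top if a new frame was pushed, otherwise
// the current top) is stepped past its call instruction.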
void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
  bool newFrame =
      f.call(stack, [&](const mobile::Code& code) { enterFrame(code); });
  (frames_.rbegin() + (newFrame ? 1 : 0))->step();
}

bool InterpreterState::run(Stack& stack) {
  while (true) {
    try {
      auto& frame = frames_.back();
      const auto& code = frame.getCode();
      const auto pc = frame.getPC();
      auto inst = frame.getInstruction();

      // If no valid debug handle is found, fall back to the pc.
      // This can happen when debug handles were not saved.
      DebugHandle debug_handle = pc;
      if (auto handle = frame.getDebugHandle()) {
        debug_handle = *handle;
      }

      // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
      // if (inst.op == OP) {
      //   std::cout << ", " << code.op_names_[inst.X].name;
      //   if (!code.op_names_[inst.X].overload_name.empty()) {
      //     std::cout << "." << code.op_names_[inst.X].overload_name;
      //   }
      // }
      // std::cout << std::endl;

      // TODO(iliacher): remove this workaround once RecordFunction is handled
      // in the Dispatcher (check with iliacher whether that has already
      // happened).
      // Note that this is also not exception-safe: if an exception is thrown
      // below, RecordFunction is left enabled. That is a TODO.
      // NOTE: this recordFunction logic takes up ~2-3% of CPU cycles in some
      // workflows. Do we need it, and/or can we opt out of
      // isRecordFunctionEnabled with a macro? If we delete it, things appear
      // to work just fine.
      bool prev_value = isRecordFunctionEnabled();
      if (!prev_value) {
        // enable RecordFunction only for the duration of this instruction
        enableRecordFunction(true);
      }
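      // Dispatch on the opcode. Most handlers advance the PC via frame.step();
      // control-flow opcodes (CALL, INTERFACE_CALL, JF, JMP, LOOP, RET) manage
      // the PC themselves.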
      switch (inst.op) {
        case OP: {
          if (at::hasGlobalCallbacks()) {
            if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>(
                    c10::ThreadLocalDebugInfo::get(
                        c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) {
              mobile_debug_info->setOpIdx(pc);
            }
          }

          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_[inst.X].name, debug_handle, stack);
          code.operators_[inst.X](stack);
          frame.step();
        } break;
        case OPN: {
          stack.push_back(inst.N);
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_[inst.X].name, debug_handle, stack);
          code.operators_[inst.X](stack);
          frame.step();
        } break;
        case CALL: {
          auto& function = *frame.getCode().functions_.at(inst.X);
          callFunction(function, stack);
        } break;
        case INTERFACE_CALL: {
          torch::jit::Function& method =
              peek(stack, 0, inst.N)
                  .toObject()
                  ->type()
                  ->getMethod(code.constants_[inst.X].toStringRef());
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              method.name(), debug_handle, stack);
          callFunction(method, stack);
        } break;
        case LOAD:
          stack.emplace_back(reg(inst.X));
          frame.step();
          break;
        case MOVE:
          stack.emplace_back(std::move(reg(inst.X)));
          frame.step();
          break;
        case STORE:
          reg(inst.X) = pop(stack);
          frame.step();
          break;
        case STOREN:
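          // Pop N values into registers X..X+N-1; the value on top of the
          // stack ends up in the highest-numbered register.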
          for (size_t i = inst.N; i > 0; --i) {
            reg(inst.X + i - 1) = pop(stack);
          }
          frame.step();
          break;
        case DROP:
          pop(stack);
          frame.step();
          break;
        case DROPR:
          reg(inst.X) = IValue();
          frame.step();
          break;
        case LOADC:
          stack.emplace_back(code.constants_[inst.X]);
          frame.step();
          break;
        case GET_ATTR: {
          auto userObj = pop(stack).toObject();
          auto value = userObj->getSlot(inst.X);
          push(stack, std::move(value));
          frame.step();
        } break;
        case SET_ATTR: {
          auto v = pop(stack);
          auto userObj = pop(stack).toObject();
          // Mobile only: since the number of slots is not known ahead of
          // time, grow the type's attribute count until slot inst.X exists
          // before calling setSlot.
          // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
          while (userObj->type()->numAttributes() <= inst.X) {
            std::stringstream ss;
            ss << userObj->type()->numAttributes();
            userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
          }
          userObj->setSlot(inst.X, std::move(v));
          frame.step();
        } break;
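        // JF: pop a boolean; fall through to the next instruction when it is
        // true, otherwise jump by the relative offset X.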
        case JF:
          frame.jump(pop(stack).toBool() ? 1 : inst.X);
          break;
        case JMP:
          frame.jump(inst.X);
          break;
        case LOOP: {
          // stack: iteration_count, max_iter, cond, loop_carried_deps...
          auto sframe = stack.end() - (inst.N + 1);
          int64_t trip_count = sframe[0].toInt();
          int64_t max_trip_count = sframe[1].toInt();
          bool cond = sframe[2].toBool();
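          // Continue while the trip count is below the maximum and the
          // condition holds; otherwise shift the loop-carried values down
          // over the three bookkeeping slots and jump out of the loop.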
          if (trip_count < max_trip_count && cond) {
            sframe[2] = trip_count;
            sframe[0] = trip_count + 1;
            frame.step();
          } else {
            size_t n_loop_carried = inst.N - 2;
            for (const auto i : c10::irange(n_loop_carried)) {
              sframe[i] = std::move(sframe[i + 3]);
            }
            drop(stack, 3); // iteration_count, max_iter, cond
            frame.jump(inst.X);
          }
        } break;
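        // RET: pop the current frame; keep executing the caller if one
        // remains, otherwise the program has finished.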
        case RET:
          leaveFrame();
          if (frames_.size() > 0) {
            continue;
          }
          return false;
        case LIST_CONSTRUCT: {
          listConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case LIST_UNPACK: {
          listUnpack(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_CONSTRUCT: {
          tupleConstruct(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_SLICE: {
          tupleSlice(stack, inst.X, inst.X + inst.N);
          frame.step();
        } break;
        case TUPLE_INDEX: {
          tupleIndex(stack);
          frame.step();
        } break;
        case RAISE_EXCEPTION: {
          raiseExceptionWithMessage(stack);
          frame.step();
        } break;
        case __IS__: {
          is(stack);
          frame.step();
        } break;
        case UN_INITIALIZED: {
          unInitialized(stack);
          frame.step();
        } break;
        case __ISNOT__: {
          isNot(stack);
          frame.step();
        } break;
        case FORMAT: {
          format(stack, inst.X);
          frame.step();
        } break;
        case DEVICE: {
          device(stack);
          frame.step();
        } break;
        case DTYPE: {
          dtype(stack);
          frame.step();
        } break;
        case DIM: {
          dim(stack);
          frame.step();
        } break;
        case __NOT__: {
          _not(stack);
          frame.step();
        } break;
        case DICT_INDEX: {
          dictIndex(stack);
          frame.step();
        } break;
        case TO_LIST: {
          toList(stack);
          frame.step();
        } break;
        case NUM_TO_TENSOR: {
          numToTensorScalar(stack);
          frame.step();
        } break;
        case IS_CUDA: {
          isCuda(stack);
          frame.step();
        } break;
        case DICT_CONSTRUCT: {
          dictConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case NAMED_TUPLE_CONSTRUCT: {
          namedTupleConstruct(stack, code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case CREATE_OBJECT: {
          auto type = code.types_.at(inst.X)->expect<c10::ClassType>();
          createObject(stack, type);
          frame.step();
        } break;
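        // ISINSTANCE: the N candidate types are stored contiguously in
        // code.types_ starting at index X.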
        case ISINSTANCE: {
          at::ArrayRef<TypePtr> types(&code.types_.at(inst.X), inst.N);
          isinstance(stack, types);
          frame.step();
        } break;
        case WARN: {
          drop(stack, 1);
          // Note: Please don't move the pop(stack) code below into the
          // TORCH_WARN macro since TORCH_WARN fails to evaluate its arguments
          // when STRIP_ERROR_MESSAGES is defined (which happens for production
          // mobile builds). This will cause the stack to be in an inconsistent
          // state. It has previously resulted in a SEV (S22350).
          const auto& sref = stack.back().toStringRef();
          TORCH_WARN(sref);
          stack.pop_back();
          frame.step();
        } break;
        default:
          AT_ERROR(toString(inst.op), " is invalid.");
      }

      if (!prev_value) {
        enableRecordFunction(false);
      }
      // This exception must be caught first, as it is derived from c10::Error.
    } catch (c10::BackendRuntimeException& e) {
      saveExceptionDebugHandles();
      TORCH_RETHROW(e);
    } catch (c10::Error& error) {
      // Catch and rethrow the error so that the exception debug handles,
      // which are queried later, can be saved first.
      saveExceptionDebugHandles();
      TORCH_RETHROW(error);
    } catch (...) {
      saveExceptionDebugHandles();
      throw;
    }
    // for (auto val : stack) {
    //   if (val.isTensor()) {
    //     std::cout << val.toTensor().sizes() << std::endl;
    //   } else {
    //     std::cout << val << std::endl;
    //   }
    // }
  }
  return false;
}

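// Registers are addressed 1-based from the end of the register file, so
// reg(1) refers to the most recently allocated register of the innermost
// frame.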
IValue& InterpreterState::reg(size_t reg) {
  return *(registers_.end() - reg);
}

} // namespace mobile
} // namespace jit
} // namespace torch