blob: 6216281ecb6c1ef692801d31412feb8be4a07a1b [file] [log] [blame]
#include <torch/csrc/autograd/profiler_python.h>
#include <atomic>
#include <cstdint>
#include <deque>
#include <iostream>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include <Python.h>
#include <frameobject.h>
#include <ATen/core/TensorBase.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/Optional.h>
#include <c10/util/StringUtil.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/orchestration/python_tracer.h>
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_strings.h>
namespace py = pybind11;
namespace torch {
namespace profiler {
namespace impl {
namespace {
// Kinds of Python calls the tracer distinguishes. The enumerator values are
// used as tuple indices (see CallTypeHelper), so they must start at zero and
// be contiguous.
enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
// Number of CallType enumerators.
static constexpr size_t CallTypeSize = 4;
// Guard against adding an enumerator without updating CallTypeSize.
static_assert(
    static_cast<size_t>(PyOptimizerCall) + 1 == CallTypeSize,
    "CallTypeSize must equal the number of CallType enumerators.");
// Placeholder "ephemeral" argument for call types that need none.
using no_ephemeral_t = std::tuple<>;
// Sentinel system TID marking start frames whose TID is not yet known.
static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();
// ============================================================================
// == Miscellaneous structs and utils =========================================
// ============================================================================
// Identifies a Python source location by (filename, function name, line).
// `filename_` and `name_` are non-owning pointers into the code object's
// string data; equality compares those pointers, not the string contents.
struct CodeLocation {
  CodeLocation() = default;

  // Captures the current line of `frame` together with its code object's
  // filename and name.
  explicit CodeLocation(PyFrameObject* frame)
      : line_number_{PyFrame_GetLineNumber(frame)} {
    THPCodeObjectPtr frame_code{PyFrame_GetCode(frame)};
    filename_ = THPUtils_unpackStringView(frame_code->co_filename).data();
    name_ = THPUtils_unpackStringView(frame_code->co_name).data();
  }

  bool operator==(const CodeLocation& rhs) const {
    return line_number_ == rhs.line_number_ && filename_ == rhs.filename_ &&
        name_ == rhs.name_;
  }

  const char* filename_{nullptr};
  const char* name_{nullptr};
  int line_number_{0};
};
// Imports `module_name`, then resolves `<cls>.<attr>.__code__` and returns it
// as a code object.
//
// The result is cached by the callers in function-local statics for the
// lifetime of the process. We therefore keep a strong reference (`release()`)
// instead of holding a pointer borrowed from a temporary `py::object`.
// Otherwise the pointer would only stay valid for as long as Python happens
// to keep the function alive.
PyCodeObject* importCodeObject(
    const char* module_name,
    const char* cls,
    const char* attr) {
  pybind11::gil_scoped_acquire gil;
  py::object code =
      py::module::import(module_name).attr(cls).attr(attr).attr("__code__");
  PyObject* res = code.release().ptr();
  TORCH_INTERNAL_ASSERT(PyCode_Check(res));
  return (PyCodeObject*)res;
}

// Returns the code object whose frames are traced as CallType `C`.
template <CallType C>
PyCodeObject* getCode();

// Code object of `torch.nn.Module.__call__`.
template <>
PyCodeObject* getCode<CallType::PyModuleCall>() {
  static auto module_call_code =
      importCodeObject("torch.nn", "Module", "__call__");
  return module_call_code;
}

// Code object of `torch.optim.Optimizer._optimizer_step_code`.
template <>
PyCodeObject* getCode<CallType::PyOptimizerCall>() {
  static auto optimizer_step_code =
      importCodeObject("torch.optim", "Optimizer", "_optimizer_step_code");
  return optimizer_step_code;
}
} // namespace
} // namespace impl
} // namespace profiler
} // namespace torch
// Hashes a CodeLocation from its `filename_`/`name_` pointers and line
// number, the same fields compared by CodeLocation::operator==.
// `operator()` is const, as the standard Hash requirement expects.
template <>
struct std::hash<torch::profiler::impl::CodeLocation> {
  size_t operator()(const torch::profiler::impl::CodeLocation& x) const {
    return c10::get_hash(x.filename_, x.name_, x.line_number_);
  }
};
namespace torch {
namespace profiler {
namespace impl {
namespace {
// ============================================================================
// == CallTypeHelper: Tools for generic programming on specializations. =======
// ============================================================================
// Compile-time helper for code that must be instantiated once per CallType.
// `tuple_type` is `std::tuple<ClassT<(CallType)0>, ..., ClassT<(CallType)N-1>>`
// with N = CallTypeSize, so element I corresponds to the enumerator with
// value I. `map` invokes `f(std::get<C>(t), args...)` for every C in order.
template <template <CallType> class ClassT>
class CallTypeHelper final {
 private:
  static_assert(
      CallType::PyCall == 0,
      "CallTypeHelper uses integer math which depends on a zero start.");
  static constexpr size_t End = CallTypeSize;
  // Declaration only: used inside `decltype` to spell `tuple_type`.
  template <size_t... I>
  static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
      std::index_sequence<I...>);
  // Applies `f` to element C, then recurses to C + 1 while C + 1 < End.
  // `if_constexpr` keeps the recursion from being instantiated past the last
  // element. Routing `t` through `_` delays checking the call until the
  // branch is taken.
  template <size_t C, typename T, typename FunctorT, typename... Args>
  static void map(T& t, FunctorT& f, Args&&... args) {
    f(std::get<C>(t), args...);
    c10::guts::if_constexpr<C + 1 < End>(
        [&](auto _) { map<C + 1>(_(t), f, std::forward<Args>(args)...); });
  }
 public:
  using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));
  // Calls `f(std::get<C>(t), args...)` for each CallType C, starting at 0.
  template <typename FunctorT, typename... Args>
  static void map(tuple_type& t, FunctorT& f, Args&&... args) {
    map<0>(t, f, std::forward<Args>(args)...);
  }
};
// ============================================================================
// == Event type definitions. =================================================
// ============================================================================
// When we are tracing a Python program, the general procedure is to record
// every time we enter or exit a function and later replay these events during
// post processing. Thus, during the profiling phase we want to do the MINIMAL
// amount of work to capture all of the information that we need; otherwise we
// will distort the profile. (While we don't wish to be terribly inefficient
// during post processing, we are willing to do extra fixup work in post if it
// reduces overhead in the profiling phase.)
//
// When the tracer first enters a frame, it constructs a CallKey for that
// location. The contents of the key vary by context. For a python function
// the key is the `CodeLocation` (filename, function name, line number) of the
// frame. For an `nn.Module` or an optimizer step the key is a (non-owning)
// pointer to `self`. For a C function it is a (non-owning) pointer to the
// function's `PyMethodDef`. A CallKey should be small, inexpensive, and POD.
//
// We then collect a CallKey<CallType::PyCall> for the calling frame for better
// source tracking. This pair is a `Callsite`, and serves as a first level key
// during tracing. We lookup the Callsite in a thread local cache which maps
// Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
// TraceKey and return. On a cache miss, we use a global value cache to store
// whatever fields we need from the two CallKeys, generate a new TraceKey, and
// update the local cache.
//
// During post processing we:
// 1) Determine the type represented by a TraceKey by checking which
//    sub-cache of the thread local cache it appears in.
// 2) Look up the pair of CallKeys from the thread local cache.
// 3) Look up the expanded values of each CallKey from the global value cache.
//
// To add a new event type to the cache:
// 1) Add an entry to the `CallType` enum.
// 2) Add a specialization of Config which defines key_t, ephemeral_t,
//    cache_t and event_type.
// 3) Add a specialization of ValueCache::store and ValueCache::load.
//
// -------------------------
// -- Ephemeral arguments --
// -------------------------
// The value cache mechanism assumes that `key_t` is enough to specify the
// correct value. However it may not be possible to materialize a value using
// only an instance of `key_t`. As a result, the cache also accepts "ephemeral"
// inputs which can be used to populate the value cache. Ephemeral inputs come
// with two caveats:
// 1) They are NOT safe to save, and cannot be used after `ValueCache::store`.
// 2) They should be used to access data that is not expected to change from
//    call to call, such as the name of a function.
// Per-CallType configuration. Each specialization provides:
//   key_t       - per-call key (must be trivially copyable; see Callsite),
//   ephemeral_t - extra input that is only valid during `ValueCache::store`,
//   cache_t     - the ValueCache storage for this CallType,
//   event_type  - the EventType the call is reported as.
template <CallType>
struct Config;
// Plain Python function calls: keyed by code location, with the cached value
// being the PyFrameState built in `ValueCache::store<CallType::PyCall>`.
template <>
struct Config<CallType::PyCall> {
  using key_t = CodeLocation;
  using ephemeral_t = no_ephemeral_t;
  using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
  static constexpr EventType event_type = EventType::PyCall;
};
// Shared Config shape for calls keyed by a `self` object (nn.Module and
// optimizer instances). The ephemeral input is the frame being entered, which
// `set_class` uses to record the location of the traced entry point.
template <typename Key, typename Cls, typename ParameterInfo>
struct ExtendedPyCallConfig {
  using key_t = Key;
  using cls_t = Cls;
  using ephemeral_t = PyFrameObject*;
  // Per-`self` value: the object's class and a snapshot of its parameters.
  struct ClsAndParameters {
    cls_t cls_;
    std::vector<ParameterInfo> parameters_;
  };
  struct Cache {
    // Location of `nn.Module.__call__` or
    // `optim.Optimizer._optimizer_step_code`, set on first use by `set_class`.
    c10::optional<CodeLocation> location_;
    ska::flat_hash_map<key_t, ClsAndParameters> cls_and_parameters_;
    // Class -> `__name__`, populated by `set_class`.
    ska::flat_hash_map<cls_t, at::StringView> cls_names_;
  };
  using cache_t = Cache;
  // Module and optimizer calls are reported as ordinary Python calls.
  static constexpr EventType event_type = EventType::PyCall;
};
// `nn.Module.__call__` frames, keyed by the module instance.
template <>
struct Config<CallType::PyModuleCall> : ExtendedPyCallConfig<
                                            PyModuleSelf,
                                            PyModuleCls,
                                            NNModuleInfo::ParameterInfo> {};
// Optimizer step frames, keyed by the optimizer instance.
template <>
struct Config<CallType::PyOptimizerCall> : ExtendedPyCallConfig<
                                               PyOptimizerSelf,
                                               PyOptimizerCls,
                                               OptimizerInfo::ParameterInfo> {};
// C function calls: keyed by the function's PyMethodDef pointer. The
// ephemeral input is the callable itself, whose `repr` becomes the name.
template <>
struct Config<CallType::PyCCall> {
  using key_t = PyMethod;
  using ephemeral_t = PyObject*;
  using cache_t = ska::flat_hash_map<key_t, at::StringView>;
  static constexpr EventType event_type = EventType::PyCCall;
};
// ============================================================================
// == Callsite & ValueCache: Storage during profiling =========================
// ============================================================================
// A (callee key, caller location) pair. This is the first-level key the
// tracer looks up in its thread-local TraceKey cache.
template <CallType C>
class Callsite {
 public:
  static constexpr CallType call_type = C;
  using key_t = typename Config<C>::key_t;
  static_assert(
      std::is_trivially_copyable<key_t>::value,
      "Key should be trivial, as it is passed by value.");

  // `value` is converted to key_t; `f_back` is the calling frame.
  template <typename U>
  Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {}

  bool operator==(const Callsite<C>& rhs) const {
    return caller_ == rhs.caller_ && value_ == rhs.value_;
  }

  key_t value_;
  Config<CallType::PyCall>::key_t caller_;
};
// ============================================================================
// == Type specific store and load implementations. ===========================
// ============================================================================
// Short names for each CallType's key type, used by the store/load
// specializations below.
using PyCallKey = Config<CallType::PyCall>::key_t;
using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
using PyCCallKey = Config<CallType::PyCCall>::key_t;
using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;
// Store of the expanded values behind each CallKey, shared by every thread's
// ThreadLocalResults. `store` runs during profiling and only materializes a
// value the first time a key is seen. `load` runs during post processing.
class ValueCache {
 public:
  ValueCache() = default;
  ValueCache(const ValueCache&) = delete;
  // Caches the value for the key if absent. The ephemeral argument may only
  // be used during this call (it is not safe to retain).
  template <CallType C>
  void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);
  // Builds the ExtraFields for `callsite`: the caller's frame state plus the
  // CallType-specific args. The end time starts at the minimum time_t and is
  // overwritten during replay (PostProcess::populate).
  template <CallType C>
  auto load(const Callsite<C>& callsite, size_t python_tid) const {
    auto caller = load<CallType::PyCall>(callsite.caller_);
    TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
    return ExtraFields<Config<C>::event_type>{
        /*end_time_ns=*/std::numeric_limits<time_t>::min(),
        python_tid,
        caller.frame_state_,
        load<C>(callsite.value_)};
  }
  // Returns tensor metadata if `p` is exactly a torch Tensor, else nullopt.
  c10::optional<TensorMetadata> recordIfTensor(py::handle p);
  // Collects (str key, exact-Tensor value) entries of `tensor_map`; other
  // entries are skipped.
  std::vector<std::pair<std::string, TensorMetadata>> unpackTensorMap(
      py::dict tensor_map);
  // Strips the prefixes reported by `torch.profiler.python_tracer` from the
  // cached filenames.
  void trimPrefixes();
 private:
  // Per-CallType value lookup; throws if the key was never stored.
  template <CallType C>
  typename ExtraFields<Config<C>::event_type>::args_t load(
      const typename Config<C>::key_t&) const;
  template <CallType C>
  using State = typename Config<C>::cache_t;
  // One cache_t per CallType, indexed by the CallType value.
  CallTypeHelper<State>::tuple_type state_;
};
// Shared helper for `self`-keyed call types. On first use it records and
// stores the entry point's CodeLocation (asserting the frame runs the expected
// code object). Returns the class of `key` and caches the class's `__name__`
// if not already known.
template <CallType C>
typename Config<C>::cls_t set_class(
    ValueCache* value_cache,
    typename Config<C>::cache_t& cache,
    const typename Config<C>::key_t& key,
    const typename Config<C>::ephemeral_t& frame) {
  const bool location_unknown = !cache.location_.has_value();
  if (C10_UNLIKELY(location_unknown)) {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    TORCH_INTERNAL_ASSERT(code.get() == getCode<C>());
    cache.location_ = PyCallKey(frame);
    value_cache->store<CallType::PyCall>(*cache.location_, no_ephemeral_t());
  }

  py::object cls_object = py::handle((PyObject*)key).attr("__class__");
  auto cls = typename Config<C>::cls_t(cls_object.ptr());

  auto& names = cache.cls_names_;
  if (names.find(cls) == names.end()) {
    names[cls] = at::StringView(py::str(cls_object.attr("__name__")));
  }
  return cls;
}
// Builds TensorMetadata for an exact torch Tensor. Strides are only recorded
// for strided layouts; other layouts get an empty stride vector.
TensorMetadata toTensorMetadata(PyObject* self) {
  TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self));
  const auto& t = THPVariable_Unpack(self);
  RawTensorMetadata raw{t};
  std::vector<int64_t> strides;
  if (raw.layout_ == at::kStrided) {
    strides = t.strides().vec();
  }
  return TensorMetadata{raw, t.sizes().vec(), std::move(strides)};
}
// Returns metadata for `p` if it is exactly a torch Tensor, otherwise nullopt.
c10::optional<TensorMetadata> ValueCache::recordIfTensor(py::handle p) {
  if (!THPVariable_CheckExact(p.ptr())) {
    return c10::nullopt;
  }
  return toTensorMetadata(p.ptr());
}
// Collects (key, metadata) for every entry of `tensor_map` whose key is a str
// and whose value is exactly a torch Tensor; other entries are ignored.
std::vector<std::pair<std::string, TensorMetadata>> ValueCache::unpackTensorMap(
    py::dict tensor_map) {
  std::vector<std::pair<std::string, TensorMetadata>> tensors;
  for (const auto& entry : tensor_map) {
    const bool is_named_tensor = py::isinstance<py::str>(entry.first) &&
        THPVariable_CheckExact(entry.second.ptr());
    if (is_named_tensor) {
      tensors.emplace_back(
          py::cast<std::string>(entry.first),
          toTensorMetadata(entry.second.ptr()));
    }
  }
  return tensors;
}
// Caches the PyFrameState (line, filename, name) for a code location the
// first time it is seen.
template <>
void ValueCache::store<CallType::PyCall>(const PyCallKey& key, no_ephemeral_t) {
  auto& locations = std::get<CallType::PyCall>(state_);
  if (locations.find(key) != locations.end()) {
    return;
  }
  locations[key] = {
      key.line_number_,
      at::StringView(key.filename_),
      at::StringView(key.name_)};
}
// Plain Python calls carry only a frame state; the remaining args are empty.
template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
    const PyCallKey& key) const {
  const auto& frame_state = std::get<CallType::PyCall>(state_).at(key);
  return {frame_state, c10::nullopt};
}
// On the first call of a given module instance, records its class and a
// snapshot of its `_parameters` (str-keyed, exact Tensors only) together with
// each parameter's `grad` when that is a Tensor.
template <>
void ValueCache::store<CallType::PyModuleCall>(
    const PyModuleCallKey& key,
    Config<CallType::PyModuleCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  if (cache.cls_and_parameters_.find(key) !=
      cache.cls_and_parameters_.end()) {
    return;
  }
  auto cls = set_class<CallType::PyModuleCall>(this, cache, key, frame);
  py::dict named_params = py::handle((PyObject*)key).attr("_parameters");
  std::vector<NNModuleInfo::ParameterInfo> parameters;
  for (auto& item : named_params) {
    auto* tensor = item.second.ptr();
    const bool is_named_tensor =
        py::isinstance<py::str>(item.first) && THPVariable_CheckExact(tensor);
    if (!is_named_tensor) {
      continue;
    }
    parameters.push_back(
        {item.first.cast<std::string>(),
         toTensorMetadata(tensor),
         recordIfTensor(py::getattr(item.second, "grad", py::none()))});
  }
  cache.cls_and_parameters_[key] = {cls, std::move(parameters)};
}
// Rebuilds NNModuleInfo for a module instance. The frame state is that of
// `nn.Module.__call__`, not of the module's caller.
template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
    const PyModuleCallKey& key) const {
  const auto& cache = std::get<CallType::PyModuleCall>(state_);
  TORCH_INTERNAL_ASSERT(cache.location_.has_value());
  const auto& entry = cache.cls_and_parameters_.at(key);
  NNModuleInfo module_info{
      key, entry.cls_, cache.cls_names_.at(entry.cls_), entry.parameters_};
  const auto& call_frame =
      std::get<CallType::PyCall>(state_).at(*cache.location_);
  return {
      /*frame_state_=*/call_frame,
      /*module_info_=*/std::move(module_info),
      /*optimizer_info_=*/c10::nullopt};
}
// On the first call of a given optimizer instance, records its class and, for
// every exact-Tensor param in `param_groups`, the param's metadata, its
// `grad` (if a Tensor) and its tensor-valued entries in `self.state`.
template <>
void ValueCache::store<CallType::PyOptimizerCall>(
    const PyOptimizerCallKey& key,
    Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  if (cache.cls_and_parameters_.find(key) !=
      cache.cls_and_parameters_.end()) {
    return;
  }
  auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
  const py::handle self{(PyObject*)key};
  std::vector<OptimizerInfo::ParameterInfo> parameters;
  auto groups = (py::list)self.attr("param_groups");
  for (const auto& group : groups) {
    for (auto& param : py::cast<py::dict>(group).attr("get")("params")) {
      if (!THPVariable_CheckExact(param.ptr())) {
        continue;
      }
      // While `self.state` is permitted to store data in an arbitrary way,
      // all generic optimizers (SGD, Adam, etc) use param as the key since
      // the state in question is tied to particular parameters. We can
      // relax this assumption if the need arises.
      parameters.push_back(
          {toTensorMetadata(param.ptr()),
           recordIfTensor(py::getattr(param, "grad", py::none())),
           unpackTensorMap(py::cast<py::dict>(self.attr("state"))
                               .attr("get")(param, py::dict()))});
    }
  }
  cache.cls_and_parameters_[key] = {cls, std::move(parameters)};
}
// Rebuilds OptimizerInfo for an optimizer instance. The frame state is that
// of the optimizer step hook (`cache.location_`).
template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<
    CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  // Matches the PyModuleCall loader. `location_` is set by `set_class` on the
  // first store, and dereferencing an empty optional below would be UB.
  TORCH_INTERNAL_ASSERT(cache.location_.has_value());
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  const auto& cls = cls_and_parameters.cls_;
  OptimizerInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/c10::nullopt,
      /*optimizer_info_=*/std::move(info)};
}
// Names a C function by the `repr` of the callable, the first time its
// PyMethodDef key is seen.
template <>
void ValueCache::store<CallType::PyCCall>(
    const PyCCallKey& key,
    Config<CallType::PyCCall>::ephemeral_t arg) {
  auto& names = std::get<CallType::PyCCall>(state_);
  if (names.find(key) != names.end()) {
    return;
  }
  names[key] = at::StringView(py::repr(arg));
}
// Returns the cached name for a C function key; throws if it was never stored.
template <>
ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
    const PyCCallKey& key) const {
  const auto& names = std::get<CallType::PyCCall>(state_);
  return names.at(key);
}
// Removes the first matching prefix (as listed by
// `torch.profiler.python_tracer._prefix_regex()`) from every cached filename.
// The prefix list is fetched once per process.
// TODO: Use re2.
void ValueCache::trimPrefixes() {
  static const auto prefixes = []() {
    pybind11::gil_scoped_acquire gil;
    return py::module::import("torch.profiler.python_tracer")
        .attr("_prefix_regex")()
        .cast<std::vector<std::string>>();
  }();
  for (auto& entry : std::get<CallType::PyCall>(state_)) {
    std::string path = entry.second.filename_.str();
    for (const auto& prefix : prefixes) {
      // `rfind(prefix, 0) == 0` is true exactly when `path` starts with it.
      if (path.rfind(prefix, 0) != 0) {
        continue;
      }
      path.erase(0, prefix.size());
      entry.second.filename_ = at::StringView(path);
      break;
    }
  }
}
// ============================================================================
// == TraceKey cache ==========================================================
// ============================================================================
using python_tracer::TraceKey;

// Hands out process-wide unique TraceKeys starting at 1; safe to call from
// multiple threads.
TraceKey nextKey() {
  static std::atomic<uint64_t> counter{0};
  const uint64_t value = counter.fetch_add(1) + 1;
  return TraceKey{value};
}
// Thread-local map from Callsite<C> to its TraceKey. On a miss, the callee and
// caller values are pushed into the shared ValueCache and a fresh key is made.
template <CallType C>
struct TraceKeyCacheState {
  // `operator()` is const, as hash functors are expected to be.
  struct Hash {
    size_t operator()(const Callsite<C>& key) const {
      return c10::get_hash(key.value_, key.caller_);
    }
  };

  // Returns the TraceKey for `callsite`, creating and caching one on first
  // sight. `ephemeral` is only forwarded to `ValueCache::store`.
  TraceKey intern(
      Callsite<C> callsite,
      typename Config<C>::ephemeral_t ephemeral,
      ValueCache& value_cache) {
    auto it = state_.find(callsite);
    if (C10_UNLIKELY(it == state_.end())) {
      value_cache.store<C>(callsite.value_, ephemeral);
      value_cache.store<CallType::PyCall>(callsite.caller_, no_ephemeral_t());
      it = state_.insert({callsite, nextKey()}).first;
    }
    return it->second;
  }

  // Loads the (callee, caller) values for `callsite`. Takes a const reference
  // because the callsite is only read.
  auto lookup(const Callsite<C>& callsite, ValueCache& value_cache) const {
    return std::make_pair(
        value_cache.load<C>(callsite.value_),
        value_cache.load<CallType::PyCall>(callsite.caller_));
  }

  ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
};
// ============================================================================
// == Core CPython data types =================================================
// ============================================================================
// Python object passed as the second argument to `PyEval_SetProfile`.
// Because each thread gets its own TraceContext, `pyProfileFn` can record
// straight into that thread's results without threads colliding.
struct ThreadLocalResults;
struct TraceContext {
  // `PyObject_HEAD` already expands to a declaration ending in `;`, so no
  // extra semicolon is added here.
  PyObject_HEAD
  ThreadLocalResults* thread_local_results_;
};
// CPython boilerplate to define `TraceContext` as a proper python object.
// Slots are positional, so their order must match PyTypeObject exactly. Only
// tp_basicsize, tp_flags, tp_doc and tp_new are set here; `PyType_Ready` (see
// python_tracer::init) fills in the inherited slots such as tp_alloc.
static PyTypeObject TraceContextType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "TraceContext", /* tp_name */
    sizeof(TraceContext), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */ // NOLINT: modernize-use-nullptr
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_as_async (formerly tp_reserved) */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    "Python tracer TLS", /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    PyType_GenericNew, /* tp_new */
    nullptr /* tp_free */
};
// Scoped guard: acquires the GIL when constructed and, when destroyed, swaps
// back to the PyThreadState that was current at construction. Callers may
// therefore `PyThreadState_Swap` freely while it is alive.
class gil_and_restore_thread {
 public:
  gil_and_restore_thread() : initial_thread_state_{PyThreadState_Get()} {}

  ~gil_and_restore_thread() {
    PyThreadState_Swap(initial_thread_state_);
    // `gil_scoped_acquire` is a bit fragile in on-demand mode:
    // https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
    const bool interpreter_not_initialized = !Py_IsInitialized();
    if (interpreter_not_initialized) {
      gil_.disarm();
    }
  }

  PyThreadState* initial_thread_state() const {
    return initial_thread_state_;
  }

 private:
  // Declared first so the GIL is held before `PyThreadState_Get` runs.
  pybind11::gil_scoped_acquire gil_;
  PyThreadState* initial_thread_state_;
};
// ============================================================================
// == Thread local cache ======================================================
// ============================================================================
class PythonTracer;

// Per-Python-thread recording state: the thread's TraceContext (handed to
// CPython as the profile object), its TraceKey caches and its exit timestamps.
// It is pinned in memory because `ctx_` points back at it, so it is neither
// copyable nor movable.
struct ThreadLocalResults {
  ThreadLocalResults(
      PyThreadState* thread_state,
      ValueCache* value_cache,
      PythonTracer* active_tracer)
      : thread_state_{thread_state},
        ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)},
        value_cache_{value_cache},
        active_tracer_{active_tracer} {
    ctx_->thread_local_results_ = this;
  }

  ThreadLocalResults() = delete;
  ThreadLocalResults(const ThreadLocalResults&) = delete;
  ThreadLocalResults(ThreadLocalResults&&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
  // Previously declared as `operator=(const ThreadLocalResults&&)`, which is
  // not the move-assignment operator.
  ThreadLocalResults& operator=(ThreadLocalResults&&) = delete;
  ~ThreadLocalResults() {
    Py_DECREF((PyObject*)ctx_);
  }

  // Builds a Callsite<C> from `args` and returns its TraceKey, interning it on
  // first use. `E` must match Config<C>::event_type (checked at compile time).
  template <CallType C, EventType E, typename Ephemeral, typename... Args>
  TraceKey intern(Ephemeral ephemeral, Args... args) {
    static_assert(
        Config<C>::event_type == E,
        "ThreadLocalResults.intern called from the wrong typed context.");
    auto callsite = Callsite<C>(std::forward<Args>(args)...);
    return std::get<C>(trace_keys_).intern(callsite, ephemeral, *value_cache_);
  }

  static constexpr size_t BLOCK_SIZE = 1024;

  PyThreadState* thread_state_;
  TraceContext* ctx_;
  ValueCache* value_cache_;
  PythonTracer* active_tracer_;
  CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
  // Timestamps of Python-frame returns/exceptions.
  AppendOnlyList<approx_time_t, BLOCK_SIZE> exit_times_;
  // Timestamps of C-call returns/exceptions.
  AppendOnlyList<approx_time_t, BLOCK_SIZE> c_exit_times_;
};
// ============================================================================
// == Tracing implementation ==================================================
// ============================================================================
// Python tracer backed by `PyEval_SetProfile`. The constructor registers
// `pyProfileFn` on every thread of the current interpreter, `stop` removes it,
// and `getEvents` replays the recorded enter/exit events into `Result`s.
class PythonTracer final : public python_tracer::PythonTracerBase {
 public:
  PythonTracer(torch::profiler::impl::RecordQueue* queue);
  ~PythonTracer() override;
  // CPython profile callback; `obj` is this thread's TraceContext.
  static int pyProfileFn(
      PyObject* obj,
      PyFrameObject* frame,
      int what,
      PyObject* arg);
  void stop() override;
  std::vector<std::shared_ptr<Result>> getEvents(
      std::function<time_t(approx_time_t)> time_converter,
      std::vector<python_tracer::CompressedEvent>& enters,
      time_t end_time_ns) override;
  // A frame that was already on the stack when tracing began.
  struct StartFrame {
    TraceKey trace_key_;
    approx_time_t start_time;
  };
 private:
  void recordPyCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      bool is_startup_frame);
  void recordCCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      PyObject* arg);
  const std::vector<PyThreadState*> interpreterThreads() const;
  // NOTE(review): `active_lock_` is a per-instance member initialized to
  // false, so the compare-exchange in the constructor always succeeds on a
  // fresh tracer and the "already an active Python tracer" warning cannot
  // fire. A process-wide lock would enforce the single-tracer intent, but it
  // would also need exception-safe release in the constructor. Confirm
  // before changing.
  std::atomic<bool> active_lock_{false};
  bool active_{false};
  torch::profiler::impl::RecordQueue* queue_;
  PyInterpreterState* interpreter_;
  // Code object of `torch.nn.Module.__call__`.
  PyCodeObject* module_call_code_;
  // Code object of `torch.optim.Optimizer._optimizer_step_code`.
  PyCodeObject* optimizer_hook_;
  std::vector<StartFrame> start_frames_;
  // One entry per interpreter thread. A deque keeps element addresses stable
  // as entries are appended (each TraceContext points at its entry).
  std::deque<ThreadLocalResults> thread_local_results_;
  ValueCache value_cache_;
};
const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
pybind11::gil_scoped_acquire gil;
std::vector<PyThreadState*> out;
if (SOFT_ASSERT(interpreter_)) {
auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
while (thread_state != nullptr) {
out.push_back(thread_state);
thread_state = PyThreadState_Next(thread_state);
}
}
return out;
}
// Takes the tracer lock and, for every interpreter thread, creates its
// ThreadLocalResults. It then records the frames already on that thread's
// stack (up to 128, as start frames) and installs `pyProfileFn` as the
// thread's profile function.
PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
    : queue_(queue),
      interpreter_(nullptr),
      module_call_code_(getCode<CallType::PyModuleCall>()),
      optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
  TORCH_CHECK(queue_ != nullptr);

  bool expected{false};
  active_ = active_lock_.compare_exchange_strong(expected, true);
  if (!active_) {
    TORCH_WARN(
        "There is already an active Python tracer. "
        "Refusing to register profile functions.");
    return;
  }

  gil_and_restore_thread gil;
  interpreter_ = PyInterpreterState_Get();

  if (!gil.initial_thread_state()) {
    TORCH_WARN("PyThreadState_Get returned NULL");
    return;
  }

  // Register the tracer in each thread.
  for (const auto thread_state : interpreterThreads()) {
    PyThreadState_Swap(thread_state);

    thread_local_results_.emplace_back(thread_state, &value_cache_, this);
    auto* ctx = thread_local_results_.back().ctx_;

    // When we begin profiling there are already frames on the Python
    // interpreter stack. To ensure a complete trace, we must push calls
    // to all the prior frames onto our event stack. (We stop at depth=128)
    std::vector<THPFrameObjectPtr> current_stack;
    auto frame = PyEval_GetFrame();
    Py_XINCREF(frame);

    size_t depth = 0; // Make sure we can't infinite loop.
    while (frame != nullptr) {
      current_stack.emplace_back(frame);
      if (++depth == 128) {
        break;
      }

      // NB: `PyFrame_GetBack` returns a strong reference.
      frame = PyFrame_GetBack(frame);
    }

    // Replay outermost-first so enters are recorded in call order.
    for (auto it = current_stack.rbegin(); it != current_stack.rend(); ++it) {
      recordPyCall(thread_local_results_.back(), it->get(), true);
      auto frame_refcount = Py_REFCNT(it->get());

      // We hold one reference in `current_stack`, and the interpreter holds
      // another.
      TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
    }

    // Note:
    //   This profile will not compose with other CPython profilers, and
    //   cannot be round tripped via `sys.settrace(sys.gettrace())`
    PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
  }
}
// Unregisters `pyProfileFn` from every thread still using it and releases the
// tracer lock. This is a no-op (beyond taking the GIL) if the tracer is not
// active.
void PythonTracer::stop() {
  gil_and_restore_thread gil;
  if (!active_) {
    return;
  }
  for (const auto thread_state : interpreterThreads()) {
    const bool uses_our_profiler =
        thread_state->c_profilefunc == &PythonTracer::pyProfileFn;
    if (uses_our_profiler) {
      PyThreadState_Swap(thread_state);
      PyEval_SetProfile(nullptr, nullptr);
    }
  }

  auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
  active_ = false;
  SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
}

// Warns and stops the tracer if the owner forgot to call `stop()`.
PythonTracer::~PythonTracer() {
  if (!active_) {
    return;
  }
  TORCH_WARN("`PythonTracer::stop()` was not called.");
  stop();
}
void PythonTracer::recordPyCall(
ThreadLocalResults& tls,
PyFrameObject* frame,
bool is_startup_frame) {
static constexpr auto E = EventType::PyCall;
const auto key = [&]() -> TraceKey {
auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
if (code.get() == module_call_code_) {
// By default, CPython stores locals in a "fast" format, with an array
// of names and an array of values. Consequently, frame->f_locals is
// NULL since the interpreter has no need to populate it.
//
// If these arrays were part of the public API then we could very
// quickly access `self`. Unfortunately they are not, and moreover are
// not stable across versions. As a result, we are forced to call
// `PyFrame_FastToLocals` which forces the interpreter to materialize
// the full dict of locals.
auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
Py_INCREF(self.get());
auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
TORCH_INTERNAL_ASSERT(back != nullptr);
return tls.intern<CallType::PyModuleCall, E>(
frame, self.get(), back.get());
} else if (code.get() == optimizer_hook_) {
auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
Py_INCREF(self.get());
auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
TORCH_INTERNAL_ASSERT(back != nullptr);
return tls.intern<CallType::PyOptimizerCall, E>(
frame, self.get(), back.get());
} else {
auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
auto f_back = (back.get() != nullptr) ? back.get() : frame;
return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
}
}();
const auto time = getApproximateTime();
is_startup_frame ? start_frames_.push_back({key, time})
: queue_->getSubqueue()->emplace_py_call(key, time);
}
// Records entry into a C function call, keyed by the function's PyMethodDef
// and the Python frame that made the call.
void PythonTracer::recordCCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    PyObject* arg) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(Py_TYPE(arg) == &PyCFunction_Type);
  auto* c_function = reinterpret_cast<PyCFunctionObject*>(arg);
  void* method_def = static_cast<void*>(c_function->m_ml);

  // NB: For C calls a new frame is not created, so we use `frame` rather than
  //     `frame->f_back`.
  const auto trace_key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
      arg, method_def, frame);
  queue_->getSubqueue()->emplace_py_call(trace_key, getApproximateTime());
}
// ============================================================================
// == Post processing =========================================================
// ============================================================================
// An exit timestamp tagged with the Python thread it came from. `operator>`
// orders by time only, so `std::greater<>` yields a min-heap on `t_`.
struct Exit {
  time_t t_;
  size_t python_tid_;

  bool operator>(const Exit& other) const {
    return other.t_ < t_;
  }
};
// Replays recorded enter/exit events into profiler `Result`s.
//
// The constructor gathers, for each Python thread (python_tid = index into
// `tls`), the ExtraFields of every interned TraceKey and all exit timestamps
// (converted with `time_converter`). `run` then walks the enter events in time
// order and matches them against exits on per-thread stacks.
class PostProcess {
 public:
  PostProcess(
      std::function<time_t(approx_time_t)> time_converter,
      std::deque<ThreadLocalResults>& tls,
      const ValueCache& value_cache,
      time_t end_time_ns)
      : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} {
    for (size_t python_tid : c10::irange(tls.size())) {
      CallTypeHelper<TraceKeyCacheState>::map(
          tls[python_tid].trace_keys_, *this, value_cache, python_tid);

      addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
      addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
    }
  }

  // Appends an enter event for each frame that was already on the stack when
  // tracing started. Their system TID is `NoTID` and is back-filled in
  // `populate`.
  void set_start_frames(
      const std::vector<PythonTracer::StartFrame>& start_frames,
      std::vector<python_tracer::CompressedEvent>& enters) {
    for (const auto& frame : start_frames) {
      enters.push_back(
          {frame.trace_key_,
           NoTID, // Allows us to detect unhandled start frames
           {},
           time_converter_(frame.start_time)});
    }
  }

  // Invoked by CallTypeHelper::map for each CallType: loads the ExtraFields
  // for every TraceKey interned on thread `python_tid`.
  template <CallType C>
  void operator()(
      const TraceKeyCacheState<C>& trace_cache,
      const ValueCache& value_cache,
      size_t python_tid) {
    for (const auto& it : trace_cache.state_) {
      const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
          {it.second, value_cache.load(it.first, python_tid)});
      TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
    }
  }

  // Queues every exit timestamp of event type E (earliest on top).
  template <EventType E, size_t N>
  void addExits(AppendOnlyList<approx_time_t, N>& exits, size_t python_tid) {
    for (const auto i : exits) {
      get_state<E>().exits_.push({time_converter_(i), python_tid});
    }
  }

  // Sorts `enters` by time and produces Results for PyCall then PyCCall.
  std::vector<std::shared_ptr<Result>> run(
      std::vector<python_tracer::CompressedEvent>& enters) {
    // Take elements by reference: the comparator runs O(n log n) times and
    // copying each CompressedEvent is unnecessary.
    std::stable_sort(
        enters.begin(), enters.end(), [](const auto& a, const auto& b) {
          return a.enter_t_ < b.enter_t_;
        });
    std::vector<std::shared_ptr<Result>> out;
    populate<EventType::PyCall>(enters, out);
    populate<EventType::PyCCall>(enters, out);
    return out;
  }

 private:
  // Creates Results for the enters that belong to event type E. Each Result's
  // end time is the first queued exit that follows it on the same Python
  // thread, or `end_time_` if none does.
  template <EventType E>
  void populate(
      std::vector<python_tracer::CompressedEvent>& enters,
      std::vector<std::shared_ptr<Result>>& out) {
    using stack_t = std::vector<std::shared_ptr<Result>>;
    const auto initial_size = out.size();
    auto pop = [](stack_t& stack, time_t t) {
      TORCH_INTERNAL_ASSERT(stack.size(), "Python replay stack is empty.");
      c10::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
      stack.pop_back();
    };

    ska::flat_hash_map<size_t, stack_t> stacks;
    auto& state = get_state<E>();
    for (const auto& enter : enters) {
      auto fields_it = state.fields_.find(enter.key_);
      if (fields_it != state.fields_.end()) {
        // Close every call that exited before this one started.
        while (!state.exits_.empty() &&
               state.exits_.top().t_ < enter.enter_t_) {
          auto& exit = state.exits_.top();
          pop(stacks[exit.python_tid_], exit.t_);
          state.exits_.pop();
        }
        out.push_back(Result::create(
            enter.enter_t_,
            enter.system_tid_,
            enter.kineto_info_,
            fields_it->second));

        stacks[fields_it->second.python_tid_].push_back(out.back());
      }
    }

    // Handle events which were still running when profiling ended.
    for (auto& i : stacks) {
      while (!i.second.empty()) {
        pop(i.second, end_time_);
      }
    }

    // Assign system TIDs to start events based on the system TID of the next
    // observed event with the same Python TID.
    ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
        tid_map;
    auto it = out.rbegin();
    for (C10_UNUSED auto _ : c10::irange(initial_size, out.size())) {
      const auto python_tid =
          c10::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
      if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
        const auto& tid_info =
            tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
                .first->second;
        (*it)->start_tid_ = tid_info.first;
        (*it)->kineto_info_ = tid_info.second;
      }
      tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
      ++it;
    }
  }

  // Per-EventType replay state: loaded fields and a min-heap of exits.
  template <EventType E>
  struct State {
    ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
    std::priority_queue<Exit, std::vector<Exit>, std::greater<>> exits_;
  };

  template <EventType E>
  auto& get_state() {
    return std::get < E == EventType::PyCall ? 0 : 1 > (state_);
  }

  time_t end_time_;
  std::function<time_t(approx_time_t)> time_converter_;
  std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
};
struct PythonIDVisitor {
void operator()(ExtraFields<EventType::PyCall>& py_call) {
py_call.id_ = ++current_python_id_;
if (py_call.module_.has_value()) {
auto& m = py_call.module_;
auto& module_ids = module_ids_[m->cls_];
m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
}
}
void operator()(ExtraFields<EventType::PyCCall>& py_call) {
py_call.id_ = ++current_python_id_;
}
template <typename T>
void operator()(T&) {}
size_t current_python_id_{0};
ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
module_ids_;
};
// Turns the recorded data into Results sorted by start time, with Python ids
// assigned in that order. Trims filename prefixes first.
std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
    std::function<time_t(approx_time_t)> time_converter,
    std::vector<python_tracer::CompressedEvent>& enters,
    time_t end_time_ns) {
  value_cache_.trimPrefixes();

  PostProcess replay(
      std::move(time_converter),
      thread_local_results_,
      value_cache_,
      end_time_ns);
  replay.set_start_frames(start_frames_, enters);
  std::vector<std::shared_ptr<Result>> results = replay.run(enters);

  const auto by_start_time = [](const std::shared_ptr<Result>& lhs,
                                const std::shared_ptr<Result>& rhs) {
    return lhs->start_time_ns_ < rhs->start_time_ns_;
  };
  std::stable_sort(results.begin(), results.end(), by_start_time);

  PythonIDVisitor assign_ids;
  for (auto& result : results) {
    c10::visit(assign_ids, result->extra_fields_);
  }
  return results;
}
// ============================================================================
// == API =====================================================================
// ============================================================================
// CPython profile hook. Dispatches on `what`: Python and C call entries are
// interned, and returns/exceptions append an exit timestamp. Every other event
// is ignored. Always returns 0.
int PythonTracer::pyProfileFn(
    PyObject* obj,
    PyFrameObject* frame,
    int what,
    PyObject* arg) {
  auto* ctx = reinterpret_cast<TraceContext*>(obj);
  ThreadLocalResults& results = *ctx->thread_local_results_;
  if (what == PyTrace_CALL) {
    results.active_tracer_->recordPyCall(results, frame, false);
  } else if (what == PyTrace_C_CALL) {
    results.active_tracer_->recordCCall(results, frame, arg);
  } else if (what == PyTrace_EXCEPTION || what == PyTrace_RETURN) {
    results.exit_times_.emplace_back(getApproximateTime());
  } else if (what == PyTrace_C_EXCEPTION || what == PyTrace_C_RETURN) {
    results.c_exit_times_.emplace_back(getApproximateTime());
  }
  return 0;
}
// Factory registered with the profiler; creates a PythonTracer on `queue`.
std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
    torch::profiler::impl::RecordQueue* queue) {
  std::unique_ptr<python_tracer::PythonTracerBase> tracer =
      std::make_unique<PythonTracer>(queue);
  return tracer;
}
} // namespace
} // namespace impl
} // namespace profiler
} // namespace torch
namespace torch {
namespace autograd {
namespace profiler {
namespace python_tracer {
void init() {
pybind11::gil_scoped_acquire gil;
TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
torch::profiler::impl::python_tracer::registerTracer(
&torch::profiler::impl::getTracer);
}
} // namespace python_tracer
} // namespace profiler
} // namespace autograd
} // namespace torch