blob: d520ee2fa7ec4bd6b3b4018259ee0773b3d06d95 [file] [log] [blame]
#include <algorithm>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/serialization/python_print.h>
namespace torch {
namespace jit {
// Process-wide singleton holding the JIT logging configuration: the raw
// level specification, the parsed file-name -> level table consulted by
// is_enabled(), and the stream that log output is written to.
//
// NOTE(review): none of the accessors below take a lock; changing the levels
// or the output stream while other threads are logging is unsynchronized.
class JitLoggingConfig {
 public:
  // Returns the single instance, constructing it on first use.
  static JitLoggingConfig& getInstance() {
    static JitLoggingConfig instance;
    return instance;
  }

  JitLoggingConfig(JitLoggingConfig const&) = delete;
  JitLoggingConfig& operator=(JitLoggingConfig const&) = delete;

 private:
  // Raw ':'-separated level specification (interpreted by parse()).
  std::string logging_levels;
  // File name (without extension) -> logging level, rebuilt by parse().
  std::unordered_map<std::string, size_t> files_to_levels;
  // Non-owning; defaults to std::cerr until setOutputStream() is called.
  std::ostream* out = &std::cerr;

  // Seeds the specification from the PYTORCH_JIT_LOG_LEVEL environment
  // variable (empty when unset) and parses it.
  JitLoggingConfig() {
    const char* jit_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
    logging_levels.assign(jit_log_level == nullptr ? "" : jit_log_level);
    parse();
  }

  // Rebuilds files_to_levels from logging_levels.
  void parse();

 public:
  std::string getLoggingLevels() const {
    return this->logging_levels;
  }

  // Replaces the specification and immediately re-parses it.
  void setLoggingLevels(std::string levels) {
    this->logging_levels = std::move(levels);
    parse();
  }

  const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
    return this->files_to_levels;
  }

  // Only the address is stored: `out_stream` must outlive all later logging.
  void setOutputStream(std::ostream& out_stream) {
    this->out = &out_stream;
  }

  std::ostream& getOutputStream() {
    return *(this->out);
  }
};
// Returns a copy of the currently configured logging-level specification.
std::string get_jit_logging_levels() {
  const auto& config = JitLoggingConfig::getInstance();
  return config.getLoggingLevels();
}
// Replaces the logging-level specification; the new value is re-parsed
// immediately by the configuration singleton.
void set_jit_logging_levels(std::string level) {
  auto& config = JitLoggingConfig::getInstance();
  config.setLoggingLevels(std::move(level));
}
// Redirects JIT log output to `stream`. Only the stream's address is kept,
// so the caller must keep it alive for as long as logging may occur.
void set_jit_logging_output_stream(std::ostream& stream) {
  auto& config = JitLoggingConfig::getInstance();
  config.setOutputStream(stream);
}
// Returns the stream JIT log output is currently written to.
std::ostream& get_jit_logging_output_stream() {
  auto& config = JitLoggingConfig::getInstance();
  return config.getOutputStream();
}
// Returns a string representation of a node's header (e.g. its outputs,
// the node kind and its inputs).
// NOTE(review): the four trailing `false` flags passed to Node::print
// presumably suppress extra detail (source locations, attributes, scopes,
// body) -- confirm against Node::print's declaration.
std::string getHeader(const Node* node) {
  std::ostringstream header;
  node->print(header, 0, {}, false, false, false, false);
  return header.str();
}
// Rebuilds `files_to_levels` from `logging_levels`.
//
// The specification is a ':'-separated list of entries; empty entries are
// skipped. Each entry is an optional run of '>' characters followed by a
// file name. Any extension after the last '.' in the name part is dropped.
// The stored level is the index just past the last '>' (so "name" -> 0,
// ">name" -> 1, ">>name" -> 2). is_enabled() later compares that level
// against a JitLoggingLevels value. An implicit "function" entry at level 0
// is always prepended.
//
// If a file appears more than once, the highest level wins. This lets an
// explicit user entry such as ">>function" raise the implicit "function"
// entry instead of being silently ignored.
void JitLoggingConfig::parse() {
  std::stringstream in_ss;
  in_ss << "function:" << this->logging_levels;

  files_to_levels.clear();
  std::string line;
  while (std::getline(in_ss, line, ':')) {
    if (line.empty()) {
      continue;
    }

    const auto index_at = line.find_last_of('>');
    const size_t begin_index =
        index_at == std::string::npos ? 0 : index_at + 1;
    // One past the last '>' -- equal to the number of '>' characters when
    // they all lead the entry, which is the documented form.
    const size_t logging_level = begin_index;

    // Only treat a '.' inside the file-name part as an extension separator,
    // rather than relying on unsigned wrap-around being clamped by substr().
    const auto dot_index = line.find_last_of('.');
    const size_t end_index =
        (dot_index == std::string::npos || dot_index < begin_index)
        ? line.size()
        : dot_index;
    auto filename = line.substr(begin_index, end_index - begin_index);

    auto result = files_to_levels.emplace(std::move(filename), logging_level);
    if (!result.second) {
      // Duplicate file: keep the most verbose level requested.
      result.first->second = std::max(result.first->second, logging_level);
    }
  }
}
// Reports whether logging at `level` is enabled for the source file `cfname`.
// The path is reduced with c10::detail::StripBasename and its extension
// (text after the last '.') is dropped. The result is looked up in the parsed
// file -> level table. Files absent from the table are disabled. Otherwise
// `level` must not exceed the configured level.
bool is_enabled(const char* cfname, JitLoggingLevels level) {
  const auto& levels = JitLoggingConfig::getInstance().getFilesToLevels();

  const std::string base = c10::detail::StripBasename(std::string{cfname});
  const auto dot = base.find_last_of('.');
  const std::string stem =
      dot == std::string::npos ? base : base.substr(0, dot);

  const auto found = levels.find(stem);
  return found != levels.end() &&
      level <= static_cast<JitLoggingLevels>(found->second);
}
// Renders `graph` through PythonPrint and returns the printed text.
//
// In `GraphExecutor`, where `log_function` is invoked, the original function
// is not available. A throwaway GraphFunction named "source_dump" is
// therefore wrapped around the graph so PythonPrint has a function to print.
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
  torch::jit::GraphFunction dummy("source_dump", graph, nullptr);
  std::vector<at::IValue> constant_table;
  PrintDepsTable deps_table;
  PythonPrint printer(constant_table, deps_table);
  printer.printFunction(dummy);
  return printer.str();
}
// Prepends `prefix` to every line of `in_str`.
// Every emitted line, including the last one, ends with '\n'. An empty
// input yields an empty string.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream in_ss(in_str);
  std::ostringstream out_ss;
  std::string line;
  while (std::getline(in_ss, line)) {
    // '\n' rather than std::endl: flushing a string stream after every line
    // is pure overhead and does not change the produced text.
    out_ss << prefix << line << '\n';
  }
  return out_ss.str();
}
// Builds a header of the form "[<LEVEL> <file>:<line>] " and prepends it to
// every line of `in_str`.
// - <LEVEL> is printed via operator<<(std::ostream&, JitLoggingLevels).
// - <file> is `fn` reduced by c10::detail::StripBasename.
// - <line> is `l`, zero-padded to at least three characters.
std::string jit_log_prefix(
    JitLoggingLevels level,
    const char* fn,
    int l,
    const std::string& in_str) {
  std::ostringstream header;
  header << '[' << level << ' '
         << c10::detail::StripBasename(std::string(fn)) << ':'
         << std::setfill('0') << std::setw(3) << l << "] ";
  return jit_log_prefix(header.str(), in_str);
}
// Writes the short display name of a logging level ("DUMP", "UPDATE" or
// "DEBUG") to `out`. Any other value trips an internal assertion.
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
  switch (level) {
    case JitLoggingLevels::GRAPH_DUMP:
      return out << "DUMP";
    case JitLoggingLevels::GRAPH_UPDATE:
      return out << "UPDATE";
    case JitLoggingLevels::GRAPH_DEBUG:
      return out << "DEBUG";
    default:
      TORCH_INTERNAL_ASSERT(false, "Invalid level");
  }
  return out;
}
} // namespace jit
} // namespace torch