#include <onnx/onnx_pb.h>
#include <torch/csrc/onnx/back_compat.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/onnx/onnx.h>
#include <torch/version.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
#include <torch/csrc/jit/passes/onnx/function_extraction.h>
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
#include <torch/csrc/jit/passes/onnx/naming.h>
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
#include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h>
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
#include <torch/csrc/jit/serialization/export.h>

namespace torch::onnx {

using namespace torch::jit;

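// Registers the ONNX-export bindings (JIT passes, logging helpers, and
// export-related enums) on the given Python module; in practice this is
// called with the `torch._C` extension module during Python binding init.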
void initONNXBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // ONNX specific passes
  m.def("_jit_pass_onnx_remove_print", RemovePrintOps)
      .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
      .def("_jit_pass_onnx", ToONNX)
      .def(
          "_jit_pass_onnx_assign_output_shape",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 const std::vector<at::Tensor>& tensors,
                 const python::IODescriptor& desc,
                 bool onnx_shape_inference,
                 bool is_script,
                 int opset_version) {
                ONNXAssignOutputShape(
                    graph,
                    tensors,
                    desc,
                    onnx_shape_inference,
                    is_script,
                    opset_version);
              }))
      .def(
          "_jit_pass_onnx_function_substitution",
          wrap_pybind_function(ONNXFunctionCallSubstitution))
      .def(
          "_jit_pass_onnx_autograd_function_process",
          wrap_pybind_function(ONNXAutogradFunctionProcess))
      .def(
          "_jit_pass_onnx_peephole",
          ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
                                           int opset_version,
                                           bool fixed_batch_size) {
            return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
          }))
      .def(
          "_jit_pass_onnx_preprocess",
          ::torch::wrap_pybind_function(PreprocessForONNX))
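      // Several passes below take the initializer map (`paramsDict`) by
      // reference, mutate it, and return it to Python;
      // `return_value_policy::move` avoids copying the map on the way out.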
      .def(
          "_jit_pass_onnx_eval_peephole",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                EvalPeepholeONNX(graph, paramsDict);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_cast_all_constant_to_floating",
          ::torch::wrap_pybind_function(CastAllConstantToFloating))
      .def(
          "_jit_pass_onnx_constant_fold",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict,
                 int opset_version) {
                // The lambda pins down the intended overload of
                // ConstantFoldONNX.
                ConstantFoldONNX(graph, paramsDict, opset_version);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_eliminate_unused_items",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                // The lambda pins down the intended overload of
                // EliminateUnusedItemsONNX.
                EliminateUnusedItemsONNX(graph->block(), paramsDict);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_scalar_type_analysis",
          ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
                                           bool lowprecision_cast,
                                           int opset_version) {
            return ScalarTypeAnalysisForONNX(
                graph, lowprecision_cast, opset_version);
          }),
          py::arg("graph"),
          py::arg("lowprecision_cast") = true,
          py::arg("opset_version"))
      .def(
          "_jit_pass_onnx_remove_inplace_ops_for_onnx",
          ::torch::wrap_pybind_function(RemoveInplaceOpsForONNX))
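      // Shape/type inference: the node-level variant refines a single node in
      // place; the graph-level variant runs over the whole graph.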
      .def(
          "_jit_pass_onnx_node_shape_type_inference",
          ::torch::wrap_pybind_function(
              [](Node* n,
                 std::map<std::string, IValue>& params_dict,
                 int opset_version) {
                ONNXShapeTypeInference(n, params_dict, opset_version);
              }))
      .def(
          "_jit_pass_onnx_graph_shape_type_inference",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& params_dict,
                 int opset_version) {
                ONNXShapeTypeInference(graph, params_dict, opset_version);
              }),
          py::arg("graph"),
          py::arg("params_dict"),
          py::arg("opset_version"))
      .def(
          "_jit_pass_onnx_set_dynamic_input_shape",
          ::torch::wrap_pybind_function(ONNXSetDynamicInputShape))
      .def("_jit_pass_onnx_lint", torch::wrap_pybind_function(ONNXLintGraph))
      .def(
          "_jit_pass_onnx_function_extraction",
          ::torch::wrap_pybind_function(
              torch::jit::onnx::ONNXFunctionExtraction))
      .def("_jit_pass_onnx_block", torch::wrap_pybind_function(BlockToONNX))
      .def(
          "_jit_pass_onnx_unpack_quantized_weights",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict,
                 bool caffe2) {
                UnpackQuantizedWeights(graph, paramsDict, caffe2);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_quantization_insert_permutes",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                insertPermutes(graph, paramsDict);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_onnx_list_model_parameters",
          ::torch::wrap_pybind_function(
              [](Module& module) { return list_module_parameters(module); }))
      .def(
          "_jit_pass_prepare_division_for_onnx",
          ::torch::wrap_pybind_function(PrepareDivisionForONNX))
      .def(
          "_jit_onnx_convert_pattern_from_subblock",
          ::torch::wrap_pybind_function(ConvertPatternFromSubblock))
      .def(
          "_jit_pass_fixup_onnx_controlflow_node",
          ::torch::wrap_pybind_function(FixupONNXControlflowNode))
      .def(
          "_jit_pass_onnx_deduplicate_initializers",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue> params_dict,
                 bool is_train) {
                DeduplicateInitializers(graph, params_dict, is_train);
                return params_dict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_clear_scope_records",
          &torch::jit::onnx::ONNXClearScopeRecords)
      .def(
          "_jit_pass_onnx_track_scope_attributes",
          &torch::jit::onnx::ONNXTrackScopeAttributes)
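      // Lightweight diagnostic logging used by the exporter.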
      .def(
          "_jit_is_onnx_log_enabled",
          ::torch::jit::onnx::is_log_enabled,
          "Returns True if ONNX logging is enabled.")
      .def(
          "_jit_set_onnx_log_enabled",
          ::torch::jit::onnx::set_log_enabled,
          "Enables or disables ONNX logging.")
      .def(
          "_jit_set_onnx_log_output_stream",
          [](const std::string& stream_name = "stdout") -> void {
            // Non-owning handles to the global streams; the no-op deleter
            // prevents them from ever being destroyed.
            std::shared_ptr<std::ostream> out;
            if (stream_name == "stdout") {
              out = std::shared_ptr<std::ostream>(
                  &std::cout, [](std::ostream*) {});
            } else if (stream_name == "stderr") {
              out = std::shared_ptr<std::ostream>(
                  &std::cerr, [](std::ostream*) {});
            } else {
              std::cerr << "ERROR: only `stdout` and `stderr` "
                        << "are supported as `stream_name`" << std::endl;
              // Keep the previously configured stream rather than
              // installing a null one.
              return;
            }
            ::torch::jit::onnx::set_log_output_stream(out);
          },
          "Sets the output stream for ONNX logging; only `stdout` and "
          "`stderr` are supported.")
      .def(
          "_jit_onnx_log",
          [](const py::args& args) -> void {
            if (::torch::jit::onnx::is_log_enabled()) {
              auto& out = ::torch::jit::onnx::_get_log_output_stream();
              for (const auto& arg : args) {
                out << ::c10::str(arg);
              }
              out << std::endl;
            }
          },
          "Write `args` to the previously specified ONNX log stream.")
      .def(
          "_jit_pass_onnx_assign_scoped_names_for_node_and_value",
          ::torch::wrap_pybind_function(
              ::torch::jit::onnx::AssignScopedNamesForNodeAndValue),
          "Assign informative scoped names for nodes and values.")
      .def(
          "_jit_onnx_create_full_scope_name",
          ::torch::wrap_pybind_function(
              ::torch::jit::onnx::ONNXScopeName::createFullScopeName),
          "Create a full scope name from class name and variable name.");

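  // Validates a serialized ONNX ModelProto, raising if the proto is
  // malformed.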
  m.def(
      "_check_onnx_proto",
      ::torch::wrap_pybind_function([](const std::string& proto_string) {
        check_onnx_proto(proto_string);
      }),
      py::arg("proto_string"));

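  // Export-related enums mirrored into Python under the `_onnx` submodule.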
  auto onnx = m.def_submodule("_onnx");
  py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType")
      .value("UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
      .value("FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
      .value("UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
      .value("INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
      .value("UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
      .value("INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
      .value("INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
      .value("INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
      .value("STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
      .value("BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
      .value("FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
      .value("DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
      .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
      .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
      .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
      .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
      .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
| .value("FLOAT8E4M3FN", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN) |
| .value( |
| "FLOAT8E4M3FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ) |
| .value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2) |
| .value( |
| "FLOAT8E5M2FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ); |
| |
| py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes") |
| .value("ONNX", OperatorExportTypes::ONNX) |
| .value("ONNX_ATEN", OperatorExportTypes::ONNX_ATEN) |
| .value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK) |
| .value("ONNX_FALLTHROUGH", OperatorExportTypes::ONNX_FALLTHROUGH); |
| |
| py::enum_<TrainingMode>(onnx, "TrainingMode") |
| .value("EVAL", TrainingMode::EVAL) |
| .value("PRESERVE", TrainingMode::PRESERVE) |
| .value("TRAINING", TrainingMode::TRAINING); |
| |
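  // The torch version stamped into exported models as `producer_version`.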
| onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION); |
| |
| #ifdef BUILD_CAFFE2 |
| onnx.attr("_CAFFE2_ATEN_FALLBACK") = true; |
| #else |
| onnx.attr("_CAFFE2_ATEN_FALLBACK") = false; |
| #endif |
| } |
| } // namespace torch::onnx |