| #pragma once |
| |
| #include <torch/csrc/python_headers.h> |
| |
| #include <torch/csrc/autograd/engine.h> |
| #include <torch/csrc/autograd/function.h> |
| |
| bool THPEngine_initModule(PyObject* module); |
| |
| namespace torch { |
| namespace autograd { |
| namespace python { |
| |
| struct PythonEngine : public Engine { |
| static Engine& get_python_engine(); |
| ~PythonEngine() override; |
| void thread_init( |
| int device, |
| const std::shared_ptr<ReadyQueue>& ready_queue, |
| bool should_increment) override; |
| void thread_on_exception( |
| std::shared_ptr<GraphTask> graph_task, |
| const std::shared_ptr<Node>& fn, |
| std::exception& e) override; |
| variable_list execute( |
| const edge_list& roots, |
| const variable_list& inputs, |
| bool keep_graph, |
| bool create_graph, |
| bool accumulate_grad, |
| const edge_list& outputs = {}) override; |
| |
| c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task( |
| const std::shared_ptr<GraphTask>& graph_task, |
| std::shared_ptr<Node> graph_root, |
| InputBuffer&& input_buffer) override; |
| |
| std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override; |
| std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() |
| override; |
| |
| private: |
| PythonEngine(); |
| }; |
| |
| } // namespace python |
| } // namespace autograd |
| } // namespace torch |