#pragma once

// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).

#include <ATen/Tensor.h>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/warnings.h>

#include <c10/util/CallOnce.h>

#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <queue>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace autograd {
struct ReadyQueue;
}
} // namespace torch

namespace torch {
namespace autograd {

// Maximum reentrant backward depth before switching to a new thread.
// This limit is based on TSAN's deadlock detector, which fails if a
// program holds more than 65 locks in one thread at once. Since we hold
// a mutex in every custom C++ autograd Node, we want to avoid TSAN
// complaints when doing reentrant backwards.
// For reference, see https://github.com/google/sanitizers/issues/950
static constexpr int MAX_DEPTH = 60;

void set_device(int device);
void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error);

struct NodeTask {
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here. Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
  // When a worker receives a task with isShutdownTask_ = true, it will
  // immediately exit. The engine sends a shutdown task to every queue upon
  // its destruction.
  bool isShutdownTask_;

  int getReentrantDepth() const;

  NodeTask(
      // NOLINTNEXTLINE(modernize-pass-by-value)
      std::weak_ptr<GraphTask> base,
      std::shared_ptr<Node> fn,
      InputBuffer inputs,
      bool isShutdownTask = false)
      : base_(base),
        fn_(std::move(fn)),
        inputs_(std::move(inputs)),
        isShutdownTask_(isShutdownTask) {}
};
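
// Illustrative sketch (comment only; `graph_task`, `fn`, `grad`, `input_nr`,
// and `queue` are placeholder names, not part of this header): a producer
// that has computed a gradient for input `input_nr` of `fn` would typically
// accumulate it into an InputBuffer and enqueue a NodeTask roughly like so
// (the optional stream arguments to InputBuffer::add are passed as nullopt
// here for brevity):
//
//   InputBuffer buffer(fn->num_inputs());
//   buffer.add(input_nr, std::move(grad), std::nullopt, std::nullopt);
//   queue->push(NodeTask(graph_task, fn, std::move(buffer)));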

// Guard that sets and restores checkpoint_valid
class CheckpointValidGuard {
 public:
  explicit CheckpointValidGuard(
      const std::shared_ptr<const GraphTask>& graph_task);
  ~CheckpointValidGuard();

 private:
  bool prev_checkpoint_valid_state;
};

struct ReadyQueue {
 private:
  // Returns true when t2 should be (weakly) BEFORE t1 in the queue.
  // Shutdown tasks come first, followed by NodeTasks with an empty fn_.
  struct CompareNodeTaskTime {
    bool operator()(NodeTask const& t1, NodeTask const& t2) {
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (t2.isShutdownTask_) {
        return true;
      } else if (!t1.fn_ || t1.isShutdownTask_) {
        return false;
      } else if (!t2.fn_) {
        return true;
      } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
        return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
      } else {
        return t1.getReentrantDepth() < t2.getReentrantDepth();
      }
    }
  };

  // To notify threads waiting on the ReadyQueue of available tasks on the
  // heap_
  std::condition_variable not_empty_;
  // To protect reads and writes to heap_
  mutable std::mutex mutex_;

  std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime>
      heap_;

 public:
  // incrementOutstandingTasks indicates whether or not we should increment
  // 'outstanding_tasks_' for the associated GraphTask. This should almost
  // always be true and is only set to false in certain cases (see docs for
  // DistEngine.execute_graph_task_until_ready_queue_empty)
  void push(NodeTask item, bool incrementOutstandingTasks = true);
  void pushShutdownTask();
  NodeTask pop();
  bool empty() const;
  size_t size() const;
};
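
// A note on the pop() order implied by CompareNodeTaskTime above: shutdown
// tasks are returned first, then tasks whose fn_ is empty, then the remaining
// tasks ordered so that greater reentrant depth comes first and, within the
// same depth, larger sequence_nr() comes first (i.e. Nodes created later in
// the forward pass are processed earlier in the backward pass).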

// A single instance of this struct should be created for the whole process
// lifetime. The worker thread creation logic and the Engine's destructor rely
// on this.
struct TORCH_API Engine {
  /// Returns a reference to a static `Engine` instance.
  static Engine& get_default_engine();

  static Engine& get_base_engine();

  Engine(const Engine&) = delete;
  Engine(Engine&&) = delete;
  virtual ~Engine();

  // Given a list of (Node, input number) pairs, computes the value of the
  // graph by following next_edge references.
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      bool accumulate_grad,
      const edge_list& outputs = {});
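
  // Illustrative call (comment only; `loss` and `grad_output` are placeholder
  // variables, and building the root edge via impl::gradient_edge is just one
  // possible way to do it): a backward pass from a single output could be
  // driven roughly like
  //
  //   edge_list roots{impl::gradient_edge(loss)};
  //   variable_list grads{grad_output};
  //   Engine::get_default_engine().execute(
  //       roots, grads, /*keep_graph=*/false, /*create_graph=*/false,
  //       /*accumulate_grad=*/true);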

  // Given a pre-populated GraphTask and GraphRoot, computes the backward pass
  // for the graph.
  //
  // NB: This API should only be used by internal autograd-specific
  // machinery and shouldn't be exposed to users in any way.
  virtual c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer);

  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return std::make_unique<AnomalyMetadata>();
  }

  virtual std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() {
    return nullptr;
  }
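
  // Illustrative override (comment only; MySavedVariableHooks is a
  // hypothetical SavedVariableHooks subclass, not part of this header): an
  // engine subclass can supply default hooks for saved variables by
  // overriding the method above, e.g.
  //
  //   std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks()
  //       override {
  //     return std::make_unique<MySavedVariableHooks>();
  //   }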

  // We pass cpu_ready_queue to evaluate_function, so that it knows
  // the correct ready queue to push to after a NodeTask is ready
  void evaluate_function(
      std::shared_ptr<GraphTask>& graph_task,
      Node* func,
      InputBuffer& inputs,
      const std::shared_ptr<ReadyQueue>& cpu_ready_queue);

  void initialize_device_threads_pool();
  virtual void thread_on_exception(
      std::shared_ptr<GraphTask> graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e);

  void queue_callback(std::function<void()> callback);
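
  // Illustrative use (comment only): queue_callback registers a callback that
  // runs once the current GraphTask finishes, and it is meant to be called
  // while a backward pass is executing, e.g.
  //
  //   Engine::get_default_engine().queue_callback([] {
  //     // runs after the current backward pass completes
  //   });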

  bool is_checkpoint_valid();

  // Should be called after fork to notify that worker threads are gone
  void release_workers();

  // Must be called by the subclass before destruction to avoid a
  // data-race-on-vptr.
  void stop();

  // Initializes a device thread for the autograd engine.
  virtual void thread_init(
      int device,
      const std::shared_ptr<ReadyQueue>& ready_queue,
      bool should_increment = true);

 protected:
  Engine();
  void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr);

  // Initializes the thread-local ready queue with the ready queue that is
  // created elsewhere (e.g. thread_init, Engine::execute, etc.), or creates a
  // new ready queue if ready_queue is not provided.
  void init_local_ready_queue(
      std::shared_ptr<ReadyQueue> ready_queue = nullptr);

  std::shared_ptr<ReadyQueue> ready_queue(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      at::Device device);
  std::shared_ptr<ReadyQueue> ready_queue_by_index(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      int device_index);
  // Starts device threads (CUDA, XLA, etc.) in the Engine.
  // Note that this does NOT start the CPU thread.
  void start_device_threads();
  void increment_non_reentrant_thread_count();
  void decrement_non_reentrant_thread_count();
  virtual void thread_main(const std::shared_ptr<GraphTask>& task);
  void reentrant_thread_init();
  void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);

  // Ensures device_ready_queues_ are initialized only once
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::once_flag start_device_threads_flag_;
  // Safe to read device_ready_queues_ without synchronization after
  // initialization
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex post_callbacks_lock_;

  // How many nested reentrant calls are allowed before a new thread is used
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  int max_recursion_depth_;

  struct ThreadPoolShared {
    // Data structures used by the threads for executing reentrant backwards
    // tasks. See Note [Reentrant backwards]
    // Number of available threads for processing new GraphTasks.
    unsigned int num_workers_;
    // The threads will wait on work_ to be notified of GraphTasks
    std::condition_variable work_;
    // To protect reads and writes to graphtasks_queue_ and num_workers_
    // and for synchronizing creating new threads when needed
    std::mutex mutex_;
    // Workers will process the GraphTasks added to this queue. A GraphTask is
    // allocated inside Engine::execute and lives for the duration of execute
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    ThreadPoolShared() : num_workers_(0) {}
  };

  // Temporary workaround until thread shutdown is implemented.
  // We need shared ownership of all these objects because the threads are
  // leaked when Engine shuts down, so there may be threads waiting on work_
  // for the graphtasks_queue_ to be nonempty.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<ThreadPoolShared> thread_pool_shared_;

 private:
  // Number of non-reentrant threads
  std::atomic<uint32_t> non_reentrant_device_thread_count_;
  // Destructor will wait for non-reentrant threads to finish
  std::condition_variable non_reentrant_device_thread_condvar_;
  std::mutex non_reentrant_device_thread_mutex_;
  // stop() must be called before the destruction path goes down to the base
  // class, in order to avoid a data-race-on-vptr. Use this boolean to guard
  // whether stop() has already been called, so we can call this in every
  // destructor of the class hierarchy.
  bool stopped_{false};
};

// Allows python_engine to override the default engine when it loads.
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
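
// Illustrative sketch (comment only; MyEngine is a hypothetical Engine
// subclass, not part of this header) of how an alternative engine can install
// itself as the one returned by Engine::get_default_engine():
//
//   Engine& my_engine_stub() {
//     static MyEngine engine;
//     return engine;
//   }
//   // e.g. at library load time:
//   set_default_engine_stub(my_engine_stub);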

} // namespace autograd
} // namespace torch