| # @package parallel_workers |
| # Module caffe2.python.parallel_workers |
| |
| |
| |
| |
| |
| |
| ''' |
| This module provides a python-land multithreaded mechanism for executing work. |
| |
| Basic usage is as follows: |
| coordinator = parallel_workers.init_workers( |
| my_worker_fun, |
| worker_name="train" |
| ) |
| ... |
| coordinator.start() |
| |
The first argument is the function to run in a loop on each worker thread (see num_worker_threads).
| It has the call signature |
| worker_fun(worker_id) |
| |
| Argument 'worker_name' is used to distinguish different workers, |
| such as workers processing train data or workers processing test data. |
| |
Optionally, one can pass an "init function" (init_fun) that is called once before
| threads start, and has call signature: |
| my_init_fun(worker_coordinator, global_coordinator) |
| |
Note that for data_parallel_models, init_workers will be called
for each GPU. The 'coordinator' returned by init_workers is the same
global coordinator object on every call.
| ''' |
| |
| import logging |
| import threading |
| import atexit |
| import time |
| import collections |
| import traceback |
| |
| from abc import ABCMeta, abstractmethod |
| |
# Module-level logger shared by the coordinators and workers in this module;
# records at INFO and above are let through.
log = logging.getLogger(name="parallel_workers")
log.setLevel(level=logging.INFO)
# Interval constant in seconds. Not referenced elsewhere in this module;
# presumably read by importers -- verify before removing.
LOG_INT_SECS = 60
| |
| |
def init_workers(
    worker_fun,
    num_worker_threads=2,
    worker_name="train",
    init_fun=None,
    external_loggers=None,
    shutdown_fun=None,
):
    """Create a WorkerCoordinator and its (not yet started) worker threads.

    Args:
        worker_fun: callable invoked as ``worker_fun(worker_id)`` repeatedly
            by each worker thread while its coordinator is active.
        num_worker_threads: number of threads (and worker ids) to allocate.
        worker_name: label for this coordinator; also the key used by
            ``GlobalWorkerCoordinator.stop_coordinator``.
        init_fun: optional ``init_fun(worker_coordinator, global_coordinator)``
            run by ``GlobalWorkerCoordinator.start`` before threads start.
        external_loggers: optional iterable of objects exposing
            ``log(metrics)``; one Metrics object is shared by all workers
            created by this call.
        shutdown_fun: optional zero-argument callable run when the started
            coordinator is stopped.

    Returns:
        The module-wide ``global_coordinator`` (the same object on every
        call). Call its ``start()`` to launch the threads.
    """
    shared_metrics = Metrics(external_loggers)

    # Worker ids are allocated from the process-wide coordinator so they are
    # unique across every init_workers call.
    ids = [
        global_coordinator.get_new_worker_id()
        for _ in range(num_worker_threads)
    ]

    coordinator = WorkerCoordinator(
        worker_name, ids, init_fun, shutdown_fun=shutdown_fun)

    threads = []
    for wid in ids:
        worker = Worker(coordinator, wid, worker_fun, shared_metrics)
        threads.append(
            threading.Thread(
                target=run_worker,
                name="parallel_workers worker id {}".format(wid),
                args=[coordinator, worker],
            )
        )

    # The threads are only constructed here; WorkerCoordinator._start
    # launches them.
    coordinator._workers = threads
    global_coordinator.add(coordinator)

    return global_coordinator
| |
| |
class Metrics(object):
    """Accumulates numeric metrics and forwards them to external loggers.

    Values live in a ``defaultdict`` so a key that has never been written
    reads as 0.
    """

    def __init__(self, external_loggers):
        # Iterable of objects exposing ``log(metrics_dict)``, or None/empty
        # to disable forwarding.
        self._external_loggers = external_loggers
        self._metrics = collections.defaultdict(int)

    def reset_metrics(self):
        """Discard every accumulated value."""
        self._metrics = collections.defaultdict(int)

    def log_metrics(self):
        """Hand the current metrics dict to each external logger.

        A logger that raises is reported on stdout and does not prevent the
        remaining loggers from being called.
        """
        for ext_logger in (self._external_loggers or ()):
            try:
                ext_logger.log(self._metrics)
            except Exception as err:
                print("Failed to call ExternalLogger: {}".format(err))

    def put_metric(self, key, value, count=True):
        """Add ``value`` to ``key``.

        Unless ``count`` is False, ``'<key>_count'`` is also incremented by
        one, so averages can be derived later.
        """
        self._metrics[key] += value
        if not count:
            return
        self._metrics['{}_count'.format(key)] += 1
| |
| |
# Base class carrying the ABCMeta metaclass, created with a call that is valid
# on both Python 2 and Python 3. A ``__metaclass__`` class attribute is
# ignored by Python 3, so with that form the @abstractmethod markers below
# were never enforced.
_StateBase = ABCMeta("_StateBase", (object,), {})


class State(_StateBase):
    """Abstract lifecycle hooks for state owned by a WorkerCoordinator.

    WorkerCoordinator calls ``start()`` from ``_start``, ``stop()`` from
    ``_stop``, and ``cleanup()`` from ``_wait_finish`` once no worker thread
    is still alive. Subclasses must implement all three; instantiating
    ``State`` itself, or a subclass missing any of them, raises TypeError.
    """

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass
| |
| |
class WorkerCoordinator(object):
    """Owns the worker threads of one ``init_workers`` call and their lifecycle.

    ``init_workers`` fills ``_workers`` with ``threading.Thread`` objects.
    This class starts them, flags them to stop (workers poll ``is_active``),
    and waits for them to exit.
    """

    def __init__(
        self, worker_name, worker_ids, init_fun,
        state=None, shutdown_fun=None
    ):
        self._active = True
        self._started = False
        self._workers = []  # threading.Thread objects, set by init_workers
        self._worker_name = worker_name
        self._worker_ids = worker_ids
        self._init_fun = init_fun
        self._state = state  # optional State whose hooks follow our lifecycle
        self._shutdown_fun = shutdown_fun

    def is_active(self):
        """Return False once ``_stop`` has been called; workers poll this."""
        return self._active

    def init(self, global_coordinator):
        """Run the user init function (if any) unless already started.

        The init function is called as
        ``init_fun(this_coordinator, global_coordinator)``.
        """
        if self._init_fun and not self._started:
            self._init_fun(self, global_coordinator)

    def _start(self):
        """Start the state and all worker threads; no-op if already started."""
        if self._started:
            return
        self._active = True
        self._started = True
        if self._state:
            self._state.start()

        for w in self._workers:
            # Daemon threads do not keep the interpreter alive at exit.
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        """Flag workers to stop and run the shutdown hooks.

        Args:
            reason: optional failure description; logged as an error.
        """
        self._active = False
        if reason is not None:
            log.error("Data input failed due to an error: {}".format(reason))
        # shutdown_fun only runs for a started coordinator; since _started is
        # cleared below, a repeated _stop does not run it again.
        if self._shutdown_fun and self._started:
            self._shutdown_fun()
        # Unlike shutdown_fun, state.stop() runs on every _stop call.
        if self._state:
            self._state.stop()

        self._started = False

    def _wait_finish(self, cleanup=None):
        """Wait (up to 5 s per thread) for the worker threads to exit.

        Threads that were never started, and the calling thread itself, are
        not joined. If no worker thread is alive afterwards, the state's
        ``cleanup()`` is run.

        Args:
            cleanup: unused by this implementation.

        Returns:
            True if no worker thread is still alive after waiting.
        """
        print("Wait for workers to die: {}".format(self._worker_name))
        current = threading.current_thread()
        for w in self._workers:
            # Thread.join raises RuntimeError for a thread that was never
            # started (e.g. init_workers ran but start() did not, and the
            # atexit handler stops everything), and a thread cannot join
            # itself. ``ident`` stays None until a thread has been started.
            if w is not current and w.ident is not None:
                w.join(5.0)  # don't wait forever, thread may be blocked in i/o
        success = True
        # NOTE(review): the current thread is skipped when joining but not
        # here, so calling this from one of this coordinator's own workers
        # always reports failure and skips cleanup -- confirm intended.
        for w in self._workers:
            if w.is_alive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        # Release memory for the scratch blobs
        if success and self._state:
            self._state.cleanup()

        print("All workers terminated: {}".format(success))
        return success

    def get_worker_ids(self):
        """Return the worker ids assigned to this coordinator."""
        return self._worker_ids
| |
| |
class GlobalWorkerCoordinator(object):
    """Process-wide registry of WorkerCoordinators and worker-id allocator.

    Construction registers an ``atexit`` hook that calls ``stop()``.
    """

    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0  # next worker id to hand out
        self._worker_ids = []  # every worker id handed out so far
        self.register_shutdown_handler()

    def add(self, coordinator):
        """Register ``coordinator`` so that start/stop calls reach it."""
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        """Allocate, record and return the next sequential worker id."""
        new_id = self._fetcher_id_seq
        self._fetcher_id_seq = new_id + 1
        self._worker_ids.append(new_id)
        return new_id

    def get_worker_ids(self):
        """Return the (live) list of every worker id allocated so far."""
        return self._worker_ids

    def start(self):
        """Initialize every registered coordinator, then start them all."""
        # Two separate passes so that every init function runs serially here
        # before any coordinator spawns its threads.
        for coord in self._coordinators:
            coord.init(self)
        for coord in self._coordinators:
            coord._start()

    def stop(self):
        """Stop all coordinators, wait for their threads, and forget them.

        Returns:
            True if every coordinator reported that its threads finished.
        """
        for coord in self._coordinators:
            coord._stop()
        # Every coordinator is waited on, even after one reports failure.
        outcomes = [coord._wait_finish() for coord in self._coordinators]
        self._coordinators = []
        return all(outcomes)

    def stop_coordinator(self, worker_name):
        """Stop, wait for, and unregister coordinators named ``worker_name``."""
        for coord in self._coordinators:
            if coord._worker_name != worker_name:
                continue
            coord._stop()
            coord._wait_finish()
        self._coordinators = [
            coord for coord in self._coordinators
            if coord._worker_name != worker_name
        ]

    def register_shutdown_handler(self):
        """Arrange for ``self.stop()`` to run at interpreter exit."""
        atexit.register(lambda: self.stop())
| |
| |
class Worker(object):
    """Per-thread wrapper that runs ``worker_fun`` and records timing metrics.

    ``run_worker`` drives an instance through start/run/finish once per loop
    iteration and calls ``handle_exception`` when ``run`` raises.
    """

    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun=None,
        metrics=None
    ):
        self._coordinator = coordinator
        self._worker_id = worker_id
        self._worker_fun = worker_fun
        self._metrics = metrics  # Metrics-like object, or None to skip metrics

    def start(self):
        """Record the start time of one iteration."""
        self._start_time = time.time()

    def run(self):
        """Call the user's worker function once with this worker's id."""
        self._worker_fun(self._worker_id)

    def handle_exception(self, e):
        """Report an exception raised by ``run`` and stop the coordinator.

        Intended to be called from the ``except`` clause handling ``e`` (as
        ``run_worker`` does), because ``traceback.print_exc`` and
        ``log.exception`` report the exception currently being handled.
        """
        traceback.print_exc()
        # The message needs %-placeholders for its args; passing ``e`` with
        # no placeholder makes logging fail with "not all arguments
        # converted during string formatting". Use the module logger rather
        # than the root logger.
        log.exception("Exception in worker %s: %s", self._worker_id, e)
        self._coordinator._stop("Exception in worker {}: {}".format(
            self._worker_id, e
        ))

    def finish(self):
        """Record this iteration's elapsed time and forward the metrics.

        Does nothing when the worker was created without a metrics object
        (``metrics`` defaults to None).
        """
        if self._metrics is None:
            return
        self._metrics.put_metric(
            'worker_time', time.time() - self._start_time)
        self._metrics.log_metrics()
| |
| |
# Process-wide singleton: init_workers registers every coordinator here and
# returns this same object. Constructing it registers an atexit hook that
# calls its stop().
global_coordinator = GlobalWorkerCoordinator()
| |
| |
def run_worker(coordinator, worker):
    """Thread target: drive ``worker`` while ``coordinator`` is active.

    Each iteration calls ``worker.start()`` and then ``worker.run()``. An
    exception from ``run`` is passed to ``worker.handle_exception``, and the
    loop proceeds to the next ``is_active()`` check. ``worker.finish()``
    runs after every iteration, whether or not ``run`` raised.
    """
    while True:
        if not coordinator.is_active():
            break
        worker.start()
        try:
            worker.run()
        except Exception as exc:
            worker.handle_exception(exc)
        finally:
            worker.finish()