| |
| |
| |
| |
| |
| import unittest |
| |
| from caffe2.python import workspace, core |
| import caffe2.python.parallel_workers as parallel_workers |
| |
| |
def create_queue():
    """Create the blobs queue used by the tests and pre-create its blobs.

    Runs a ``CreateBlobsQueue`` operator (one blob per entry, capacity 1000)
    producing a queue blob named ``'queue'``. It then creates ``blob_<i>`` and
    ``status_blob_<i>`` for every i in ``range(100)``, followed by
    ``dequeue_blob`` and ``status_blob``, in the current workspace.

    Returns:
        The name of the queue blob.
    """
    queue_name = 'queue'

    create_op = core.CreateOperator(
        "CreateBlobsQueue", [], [queue_name], num_blobs=1, capacity=1000
    )
    workspace.RunOperatorOnce(create_op)

    # Blob creation is not thread safe. The tests run individual operators
    # via RunOperatorOnce (not CreateNet+RunNet), so every blob the worker
    # threads and the dequeue helper touch is created here, up front.
    # Ordering matches blob_0, status_blob_0, blob_1, status_blob_1, ...
    names_to_create = [
        prefix + str(idx)
        for idx in range(100)
        for prefix in ("blob_", "status_blob_")
    ]
    names_to_create += ["dequeue_blob", "status_blob"]

    for name in names_to_create:
        workspace.C.Workspace.current.create_blob(name)

    return queue_name
| |
| |
def create_worker(queue, get_blob_data):
    """Build a worker function that pushes one value onto ``queue`` per call.

    Args:
        queue: Name of the blobs queue to enqueue into.
        get_blob_data: Callable taking a worker id and returning the value
            to feed into that worker's ``blob_<worker_id>`` blob.

    Returns:
        A function ``dummy_worker(worker_id)``. It feeds
        ``get_blob_data(worker_id)`` into ``blob_<worker_id>``, then runs
        ``SafeEnqueueBlobs`` to push that blob onto ``queue``, writing the
        status to ``status_blob_<worker_id>``.
    """
    def dummy_worker(worker_id):
        id_suffix = str(worker_id)
        data_blob = 'blob_' + id_suffix
        status_blob = 'status_blob_' + id_suffix

        workspace.FeedBlob(data_blob, get_blob_data(worker_id))

        enqueue_op = core.CreateOperator(
            'SafeEnqueueBlobs', [queue, data_blob], [data_blob, status_blob]
        )
        workspace.RunOperatorOnce(enqueue_op)

    return dummy_worker
| |
| |
def dequeue_value(queue):
    """Dequeue one entry from ``queue`` and return it.

    Runs ``SafeDequeueBlobs`` on ``queue``, writing the item into
    ``dequeue_blob`` and the status into ``status_blob``. It then fetches and
    returns the contents of ``dequeue_blob``.
    """
    out_blob, status_blob = 'dequeue_blob', 'status_blob'
    dequeue_op = core.CreateOperator(
        "SafeDequeueBlobs", [queue], [out_blob, status_blob]
    )
    workspace.RunOperatorOnce(dequeue_op)
    return workspace.FetchBlob(out_blob)
| |
| |
class ParallelWorkersTest(unittest.TestCase):
    """Tests for ``parallel_workers`` driving a caffe2 blobs queue.

    Each test resets the workspace, builds a queue with ``create_queue``,
    starts workers via ``parallel_workers.init_workers`` and checks what
    they enqueue. The coordinator is stopped in a ``finally`` clause, so a
    failing assertion does not leave worker threads running.
    """

    def testParallelWorkers(self):
        """Each worker enqueues its own id; dequeued values must match."""
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
        worker_coordinator = parallel_workers.init_workers(dummy_worker)
        worker_coordinator.start()

        try:
            for _ in range(10):
                value = dequeue_value(queue)
                # NOTE(review): assumes the workers started here receive ids
                # 0 and 1 -- confirm how parallel_workers allocates ids.
                self.assertIn(
                    value, [b'0', b'1'], 'Got unexpected value ' + str(value)
                )
        finally:
            # Always stop the workers, even when an assertion above failed,
            # so their threads do not outlive this test.
            stopped = worker_coordinator.stop()
        self.assertTrue(stopped)

    def testParallelWorkersInitFun(self):
        """``init_fun`` output must be what the workers see and enqueue."""
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(
            queue, lambda worker_id: workspace.FetchBlob('data')
        )
        workspace.FeedBlob('data', 'not initialized')

        def init_fun(worker_coordinator, global_coordinator):
            workspace.FeedBlob('data', 'initialized')

        worker_coordinator = parallel_workers.init_workers(
            dummy_worker, init_fun=init_fun
        )
        worker_coordinator.start()

        try:
            for _ in range(10):
                value = dequeue_value(queue)
                self.assertEqual(
                    value, b'initialized', 'Got unexpected value ' + str(value)
                )
        finally:
            # A best effort attempt at a clean shutdown: the result is not
            # asserted, but it runs even when an assertion above failed.
            worker_coordinator.stop()

    def testParallelWorkersShutdownFun(self):
        """``shutdown_fun`` must run when the coordinator is stopped."""
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
        workspace.FeedBlob('data', 'not shutdown')

        def shutdown_fun():
            workspace.FeedBlob('data', 'shutdown')

        worker_coordinator = parallel_workers.init_workers(
            dummy_worker, shutdown_fun=shutdown_fun
        )
        worker_coordinator.start()

        self.assertTrue(worker_coordinator.stop())

        data = workspace.FetchBlob('data')
        self.assertEqual(data, b'shutdown', 'Got unexpected value ' + str(data))