| |
| |
| |
| |
| |
| import numpy as np |
| import unittest |
| import time |
| |
| from caffe2.python import workspace, model_helper |
| from caffe2.python import timeout_guard |
| import caffe2.python.data_workers as data_workers |
| |
| |
def dummy_fetcher(fetcher_id, batch_size):
    """Produce one randomly sized batch of self-consistent data and labels.

    Every element of row ``j`` of ``data`` equals ``j + fetcher_id`` and
    ``labels[j]`` carries the same value, so a consumer can verify that rows
    and labels stay aligned by comparing them.

    Args:
        fetcher_id: Id of the calling fetcher; offsets the row values.
        batch_size: Unused. The number of rows is drawn at random (1..64).

    Returns:
        ``[data, labels]`` where ``data`` is a float array of shape ``(n, 3)``
        and ``labels`` is a float array of shape ``(n,)``.
    """
    # Create random amount of values
    n = np.random.randint(64) + 1

    # The rows are assigned directly. Scaling a zero-initialised array in
    # place (``zeros *= value``) would leave everything at 0 and make any
    # label-vs-data comparison trivially true.
    row_values = np.arange(n, dtype=np.float64) + fetcher_id
    data = np.repeat(row_values[:, np.newaxis], 3, axis=1)
    labels = data[:, 0].copy()

    return [data, labels]
| |
| |
def dummy_fetcher_rnn(fetcher_id, batch_size):
    """Build one random batch of three blobs with hard-coded dimensions.

    Args:
        fetcher_id: Unused.
        batch_size: Size of the batch axis; also the exclusive upper bound of
            the random integers drawn for the labels and sequence lengths.

    Returns:
        ``[data, label, seq_lengths]`` with shapes ``(20, batch_size, 33)``,
        ``(20, batch_size)`` and ``(batch_size,)`` respectively.
    """
    # Hardcoded dimensions of the input blobs
    num_steps, feature_dim = 20, 33

    inputs = np.random.rand(num_steps, batch_size, feature_dim)
    targets = np.random.randint(batch_size, size=(num_steps, batch_size))
    lengths = np.random.randint(batch_size, size=batch_size)
    return [inputs, targets, lengths]
| |
| |
| class DataWorkersTest(unittest.TestCase): |
| |
| def testNonParallelModel(self): |
| workspace.ResetWorkspace() |
| |
| model = model_helper.ModelHelper(name="test") |
| old_seq_id = data_workers.global_coordinator._fetcher_id_seq |
| coordinator = data_workers.init_data_input_workers( |
| model, |
| ["data", "label"], |
| dummy_fetcher, |
| 32, |
| 2, |
| input_source_name="unittest" |
| ) |
| new_seq_id = data_workers.global_coordinator._fetcher_id_seq |
| self.assertEqual(new_seq_id, old_seq_id + 2) |
| |
| coordinator.start() |
| |
| workspace.RunNetOnce(model.param_init_net) |
| workspace.CreateNet(model.net) |
| |
| for _i in range(500): |
| with timeout_guard.CompleteInTimeOrDie(5): |
| workspace.RunNet(model.net.Proto().name) |
| |
| data = workspace.FetchBlob("data") |
| labels = workspace.FetchBlob("label") |
| |
| self.assertEqual(data.shape[0], labels.shape[0]) |
| self.assertEqual(data.shape[0], 32) |
| |
| for j in range(32): |
| self.assertEqual(labels[j], data[j, 0]) |
| self.assertEqual(labels[j], data[j, 1]) |
| self.assertEqual(labels[j], data[j, 2]) |
| |
| coordinator.stop_coordinator("unittest") |
| self.assertEqual(coordinator._coordinators, []) |
| |
| def testRNNInput(self): |
| workspace.ResetWorkspace() |
| model = model_helper.ModelHelper(name="rnn_test") |
| old_seq_id = data_workers.global_coordinator._fetcher_id_seq |
| coordinator = data_workers.init_data_input_workers( |
| model, |
| ["data1", "label1", "seq_lengths1"], |
| dummy_fetcher_rnn, |
| 32, |
| 2, |
| dont_rebatch=False, |
| batch_columns=[1, 1, 0], |
| ) |
| new_seq_id = data_workers.global_coordinator._fetcher_id_seq |
| self.assertEqual(new_seq_id, old_seq_id + 2) |
| |
| coordinator.start() |
| |
| workspace.RunNetOnce(model.param_init_net) |
| workspace.CreateNet(model.net) |
| |
| while coordinator._coordinators[0]._state._inputs < 100: |
| time.sleep(0.01) |
| |
| # Run a couple of rounds |
| workspace.RunNet(model.net.Proto().name) |
| workspace.RunNet(model.net.Proto().name) |
| |
| # Wait for the enqueue thread to get blocked |
| time.sleep(0.2) |
| |
| # We don't dequeue on caffe2 side (as we don't run the net) |
| # so the enqueue thread should be blocked. |
| # Let's now shutdown and see it succeeds. |
| self.assertTrue(coordinator.stop()) |
| |
| @unittest.skip("Test is flaky: https://github.com/pytorch/pytorch/issues/9064") |
| def testInputOrder(self): |
| # |
| # Create two models (train and validation) with same input blobs |
| # names and ensure that both will get the data in correct order |
| # |
| workspace.ResetWorkspace() |
| self.counters = {0: 0, 1: 1} |
| |
| def dummy_fetcher_rnn_ordered1(fetcher_id, batch_size): |
| # Hardcoding some input blobs |
| T = 20 |
| N = batch_size |
| D = 33 |
| data = np.zeros((T, N, D)) |
| data[0][0][0] = self.counters[fetcher_id] |
| label = np.random.randint(N, size=(T, N)) |
| label[0][0] = self.counters[fetcher_id] |
| seq_lengths = np.random.randint(N, size=(N)) |
| seq_lengths[0] = self.counters[fetcher_id] |
| self.counters[fetcher_id] += 1 |
| return [data, label, seq_lengths] |
| |
| workspace.ResetWorkspace() |
| model = model_helper.ModelHelper(name="rnn_test_order") |
| |
| coordinator = data_workers.init_data_input_workers( |
| model, |
| input_blob_names=["data2", "label2", "seq_lengths2"], |
| fetch_fun=dummy_fetcher_rnn_ordered1, |
| batch_size=32, |
| max_buffered_batches=1000, |
| num_worker_threads=1, |
| dont_rebatch=True, |
| input_source_name='train' |
| ) |
| coordinator.start() |
| |
| val_model = model_helper.ModelHelper(name="rnn_test_order_val") |
| coordinator1 = data_workers.init_data_input_workers( |
| val_model, |
| input_blob_names=["data2", "label2", "seq_lengths2"], |
| fetch_fun=dummy_fetcher_rnn_ordered1, |
| batch_size=32, |
| max_buffered_batches=1000, |
| num_worker_threads=1, |
| dont_rebatch=True, |
| input_source_name='val' |
| ) |
| coordinator1.start() |
| |
| workspace.RunNetOnce(model.param_init_net) |
| workspace.CreateNet(model.net) |
| workspace.CreateNet(val_model.net) |
| |
| while coordinator._coordinators[0]._state._inputs < 900: |
| time.sleep(0.01) |
| |
| with timeout_guard.CompleteInTimeOrDie(5): |
| for m in (model, val_model): |
| print(m.net.Proto().name) |
| workspace.RunNet(m.net.Proto().name) |
| last_data = workspace.FetchBlob('data2')[0][0][0] |
| last_lab = workspace.FetchBlob('label2')[0][0] |
| last_seq = workspace.FetchBlob('seq_lengths2')[0] |
| |
| # Run few rounds |
| for _i in range(10): |
| workspace.RunNet(m.net.Proto().name) |
| data = workspace.FetchBlob('data2')[0][0][0] |
| lab = workspace.FetchBlob('label2')[0][0] |
| seq = workspace.FetchBlob('seq_lengths2')[0] |
| self.assertEqual(data, last_data + 1) |
| self.assertEqual(lab, last_lab + 1) |
| self.assertEqual(seq, last_seq + 1) |
| last_data = data |
| last_lab = lab |
| last_seq = seq |
| |
| time.sleep(0.2) |
| |
| self.assertTrue(coordinator.stop()) |