from caffe2.python.dataio import (
    CompositeReader,
    CompositeReaderBuilder,
    ReaderBuilder,
    ReaderWithDelay,
    ReaderWithLimit,
    ReaderWithTimeLimit,
)
from caffe2.python.dataset import Dataset
from caffe2.python.db_file_reader import DBFileReader
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup, final_output, WorkspaceType
from caffe2.python.test_util import TestCase
from caffe2.python.cached_reader import CachedReader
from caffe2.python import core, workspace, schema
from caffe2.python.net_builder import ops

import numpy as np
import numpy.testing as npt
import os
import shutil
import unittest
import tempfile


def make_source_dataset(ws, size=100, offset=0, name=None):
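    """Create a Dataset named `name` whose single 'label' field holds the
    integers [offset, offset + size) and feed it into workspace `ws`."""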
    name = name or "src"
    src_init = core.Net("{}_init".format(name))
    with core.NameScope(name):
        src_values = Struct(('label', np.array(range(offset, offset + size))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs, name=name)
        FeedRecord(src_blobs, src_values, ws)
    ws.run(src_init)
    return src_ds


def make_destination_dataset(ws, schema, name=None):
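    """Create an empty Dataset named `name` with the given `schema` and
    initialize it in workspace `ws` so it can be written to."""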
    name = name or 'dst'
    dst_init = core.Net('{}_init'.format(name))
    with core.NameScope(name):
        dst_ds = Dataset(schema, name=name)
        dst_ds.init_empty(dst_init)
    ws.run(dst_init)
    return dst_ds


class TestReaderBuilder(ReaderBuilder):
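    """A minimal ReaderBuilder for these tests: setup() materializes a source
    dataset of `size` consecutive labels starting at `offset`, and
    new_reader() returns it for piping."""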
    def __init__(self, name, size, offset):
        self._schema = schema.Struct(
            ('label', schema.Scalar()),
        )
        self._name = name
        self._size = size
        self._offset = offset
        self._src_ds = None

    def schema(self):
        return self._schema

    def setup(self, ws):
        self._src_ds = make_source_dataset(ws, offset=self._offset, size=self._size,
                                           name=self._name)
        return {}

    def new_reader(self, **kwargs):
        return self._src_ds


class TestCompositeReader(TestCase):
    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader(self):
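        """Pipe three source datasets through a CompositeReader into one
        destination dataset and check that every field arrives intact."""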
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_dses = [make_source_dataset(ws, offset=offset, size=size, name=name)
                    for (name, offset) in zip(names, offsets)]

        data = [ws.fetch_blob(str(src.field_blobs[0])) for src in src_dses]
        # Sanity check we didn't overwrite anything
        for d, offset in zip(data, offsets):
            npt.assert_array_equal(d, range(offset, offset + size))

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds.content().clone_schema())
                for name, src_ds in zip(names, src_dses)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader = CompositeReader(names,
                                     [src_ds.reader() for src_ds in src_dses])
            pipe(reader, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        for i in range(num_srcs):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[names[i]].label())))
            npt.assert_array_equal(data[i], written_data, "i: {}".format(i))

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader_builder(self):
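        """Same as test_composite_reader, but build the readers through a
        CompositeReaderBuilder instead of constructing them directly."""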
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_ds_builders = [
            TestReaderBuilder(offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds_builder.schema())
                for name, src_ds_builder in zip(names, src_ds_builders)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader_builder = CompositeReaderBuilder(
                names, src_ds_builders)
            reader_builder.setup(ws=ws)
            pipe(reader_builder.new_reader(), dst_ds.writer(),
                 num_runtime_threads=3)
        session.run(tg)

        for name, offset in zip(names, offsets):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[name].label())))
            npt.assert_array_equal(range(offset, offset + size), written_data,
                                   "name: {}".format(name))


class TestReaderWithLimit(TestCase):
    def test_runtime_threads(self):
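        """Run a pipeline with several runtime threads and use counters to
        verify how often task_init, task_instance_init and the per-iteration
        body execute; then repeat with a count-limited reader."""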
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = make_source_dataset(ws)
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        # Read full data set from original reader
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        # Read with a count-limited reader
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)

    def _test_limit_reader_init_shared(self, size):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # Make source dataset
        src_ds = make_source_dataset(ws, size=size)

        # Make an identically-sized empty destination Dataset
        dst_ds = make_destination_dataset(ws, src_ds.content().clone_schema())

        return ws, session, src_ds, dst_ds

    def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
                                  expected_read_len_threshold,
                                  expected_finish, num_threads, read_delay,
                                  **limiter_args):
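        """Pipe `size` records through a `reader_class` limiter (optionally
        behind a ReaderWithDelay of `read_delay` seconds) using `num_threads`
        threads, then check that the number of records written is within
        +/- `expected_read_len_threshold` of `expected_read_len` and that the
        reader reports data_finished() == `expected_finish`."""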
        ws, session, src_ds, dst_ds = \
            self._test_limit_reader_init_shared(size)

        # Read through the limited reader.
        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            if read_delay > 0:
                reader = reader_class(ReaderWithDelay(src_ds.reader(),
                                                      read_delay),
                                      **limiter_args)
            else:
                reader = reader_class(src_ds.reader(), **limiter_args)
            pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
        session.run(tg)
        read_len = len(sorted(ws.blobs[str(dst_ds.content().label())].fetch()))

        # Do a fuzzy match (expected_read_len +/- expected_read_len_threshold)
        # to eliminate flakiness for time-limited tests
        self.assertGreaterEqual(
            read_len,
            expected_read_len - expected_read_len_threshold)
        self.assertLessEqual(
            read_len,
            expected_read_len + expected_read_len_threshold)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(read_len))
        )
        self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
                         expected_finish)

    def test_count_limit_reader_without_limit(self):
        # No iter count specified, should read all records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=None)

    def test_count_limit_reader_with_zero_limit(self):
        # Zero iter count specified, should read 0 records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=0,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=0)

    def test_count_limit_reader_with_low_limit(self):
        # Read with limit smaller than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=10,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=10)

    def test_count_limit_reader_with_high_limit(self):
        # Read with limit larger than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=110)

    def test_time_limit_reader_without_limit(self):
        # No duration specified, should read all records.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0.1,
                                       duration=0)

    def test_time_limit_reader_with_short_limit(self):
        # Read with insufficient time limit
        size = 50
        num_threads = 4
        sleep_duration = 0.25
        duration = 1
        expected_read_len = int(round(num_threads * duration / sleep_duration))
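        # e.g. 4 threads, each taking ~0.25s per read, within a 1s budget
        # gives roughly 4 * 1 / 0.25 = 16 reads before the limit triggers.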
        # Because the time limit check happens before the delay + read op,
        # subtract a little bit of time to ensure we don't squeeze in an
        # extra read.
        duration = duration - 0.25 * sleep_duration

        # NOTE: `expected_read_len_threshold` is needed because this test
        # varies significantly under load: we may read anywhere from 0 to N
        # samples, where N = expected_read_len.
        # Hence we assert a read count of N/2, plus or minus N/2.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=size,
                                       expected_read_len=expected_read_len / 2,
                                       expected_read_len_threshold=expected_read_len / 2,
                                       expected_finish=False,
                                       num_threads=num_threads,
                                       read_delay=sleep_duration,
                                       duration=duration)

    def test_time_limit_reader_with_long_limit(self):
        # Read with ample time limit
        # NOTE: we don't use `expected_read_len_threshold` because the duration,
        # read_delay, and # threads should be more than sufficient
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=50,
                                       expected_read_len=50,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=4,
                                       read_delay=0.2,
                                       duration=10)


class TestDBFileReader(TestCase):
    def setUp(self):
        self.temp_paths = []

    def tearDown(self):
        # In case any test method fails, clean up temp paths.
        for path in self.temp_paths:
            self._delete_path(path)

    @staticmethod
    def _delete_path(path):
        if os.path.isfile(path):
            os.remove(path)  # Remove file.
        elif os.path.isdir(path):
            shutil.rmtree(path)  # Remove dir recursively.

    def _make_temp_path(self):
        # Make a temp path as db_path.
        with tempfile.NamedTemporaryFile() as f:
            temp_path = f.name
        self.temp_paths.append(temp_path)
        return temp_path

    @staticmethod
    def _build_source_reader(ws, size):
        src_ds = make_source_dataset(ws, size)
        return src_ds.reader()

    @staticmethod
    def _read_all_data(ws, reader, session):
        dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())

        with TaskGroup() as tg:
            pipe(reader, dst_ds.writer(), num_runtime_threads=8)
        session.run(tg)

        return ws.blobs[str(dst_ds.content().label())].fetch()

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_cached_reader(self):
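        """Build an on-disk cache from a 100-record source, verify that later
        CachedReaders serve the cached 100 records even when handed larger
        sources, and that deleting the cache file forces a rebuild from the
        new source."""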
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Read data for the first time.
        cached_reader1 = CachedReader(
            self._build_source_reader(ws, 100), db_path, loop_over=False,
        )
        build_cache_step = cached_reader1.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache.
        cached_reader2 = CachedReader(
            self._build_source_reader(ws, 200), db_path,
        )
        build_cache_step = cached_reader2.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)

        # We removed the cache, so we expect to receive data from the
        # original reader.
        cached_reader3 = CachedReader(
            self._build_source_reader(ws, 300), db_path,
        )
        build_cache_step = cached_reader3.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        self._delete_path(db_path)

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_db_file_reader(self):
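        """Write a 100-record source into a LevelDB file via CachedReader,
        then read it back with a standalone DBFileReader."""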
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Build a cache DB file.
        cached_reader = CachedReader(
            self._build_source_reader(ws, 100),
            db_path=db_path,
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        session.run(build_cache_step)

        # Read data from cache DB file.
        db_file_reader = DBFileReader(
            db_path=db_path,
            db_type='LevelDB',
        )
        data = self._read_all_data(ws, db_file_reader, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)