| |
| |
| |
| |
| |
| from caffe2.python import core, schema |
| import numpy as np |
| |
| import unittest |
| import pickle |
| import random |
| |
class TestField(unittest.TestCase):
    """Checks the private bookkeeping that ``schema.Field`` sets up on init."""

    def testInitShouldSetEmptyParent(self):
        # A freshly built field is expected to report (None, 0) as its parent.
        root = schema.Field([])
        expected_parent = (None, 0)
        self.assertTupleEqual(root._parent, expected_parent)

    def testInitShouldSetFieldOffsets(self):
        # Alternate plain scalars with two-member structs (scalar + list) and
        # compare the recorded offsets against the known-good sequence.
        children = [
            schema.Scalar(dtype=np.int32),
            schema.Struct(
                ('field1', schema.Scalar(dtype=np.int32)),
                ('field2', schema.List(schema.Scalar(dtype=str))),
            ),
            schema.Scalar(dtype=np.int32),
            schema.Struct(
                ('field3', schema.Scalar(dtype=np.int32)),
                ('field4', schema.List(schema.Scalar(dtype=str)))
            ),
            schema.Scalar(dtype=np.int32),
        ]
        composite = schema.Field(children)
        expected_offsets = [0, 1, 4, 5, 8, 9]
        self.assertListEqual(composite._field_offsets, expected_offsets)

    def testInitShouldSetFieldOffsetsIfNoChildren(self):
        # With no children the offset table still holds the initial 0 entry.
        childless = schema.Field([])
        expected_offsets = [0]
        self.assertListEqual(childless._field_offsets, expected_offsets)
| |
| |
class TestDB(unittest.TestCase):
    """Tests for the record types exposed by ``caffe2.python.schema``.

    Covers construction, cloning, pickling, indexing (by name, position and
    nested ``a:b:c`` path), struct addition/subtraction, metadata
    preservation, column-list round trips and ``Scalar`` dtype/shape checks.

    Only non-deprecated ``unittest`` assertions are used (``assertEqual``
    rather than ``assertEquals``) and builtin ``str`` replaces the removed
    ``np.str`` alias, so the suite runs on current Python and NumPy.
    """

    def testPicklable(self):
        # A Struct must survive a pickle round trip with its fields intact,
        # and unknown attributes must still resolve to the getattr default.
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', schema.List(schema.Scalar(dtype=str)))
        )
        # Safe: the payload is produced locally one call earlier.
        s2 = pickle.loads(pickle.dumps(s))
        for r in (s, s2):
            self.assertIsInstance(r.field1, schema.Scalar)
            self.assertIsInstance(r.field2, schema.List)
            self.assertIsNone(getattr(r, 'non_existent', None))

    def testListSubclassClone(self):
        # clone() must keep the subclass type, compare equal, and be a copy.
        class Subclass(schema.List):
            pass

        s = Subclass(schema.Scalar())
        clone = s.clone()
        self.assertIsInstance(clone, Subclass)
        self.assertEqual(s, clone)
        self.assertIsNot(clone, s)

    def testListWithEvictedSubclassClone(self):
        # Same clone contract as above, for ListWithEvicted subclasses.
        class Subclass(schema.ListWithEvicted):
            pass

        s = Subclass(schema.Scalar())
        clone = s.clone()
        self.assertIsInstance(clone, Subclass)
        self.assertEqual(s, clone)
        self.assertIsNot(clone, s)

    def testStructSubclassClone(self):
        # Same clone contract as above, for Struct subclasses.
        class Subclass(schema.Struct):
            pass

        s = Subclass(
            ('a', schema.Scalar()),
        )
        clone = s.clone()
        self.assertIsInstance(clone, Subclass)
        self.assertEqual(s, clone)
        self.assertIsNot(clone, s)

    def testNormalizeField(self):
        # Bare dtypes passed as field values are wrapped into Scalars.
        s = schema.Struct(('field1', np.int32), ('field2', str))
        self.assertEqual(
            s,
            schema.Struct(
                ('field1', schema.Scalar(dtype=np.int32)),
                ('field2', schema.Scalar(dtype=str))
            )
        )

    def testTuple(self):
        # Tuple(...) is equivalent to a Struct with field_<i> names.
        s = schema.Tuple(np.int32, str, np.float32)
        s2 = schema.Struct(
            ('field_0', schema.Scalar(dtype=np.int32)),
            ('field_1', schema.Scalar(dtype=str)),
            ('field_2', schema.Scalar(dtype=np.float32))
        )
        self.assertEqual(s, s2)
        self.assertEqual(s[0], schema.Scalar(dtype=np.int32))
        self.assertEqual(s[1], schema.Scalar(dtype=str))
        self.assertEqual(s[2], schema.Scalar(dtype=np.float32))
        # Indexing with several positions yields a Struct in the given order.
        self.assertEqual(
            s[2, 0],
            schema.Struct(
                ('field_2', schema.Scalar(dtype=np.float32)),
                ('field_0', schema.Scalar(dtype=np.int32)),
            )
        )
        # test iterator behavior
        for i, (v1, v2) in enumerate(zip(s, s2)):
            self.assertEqual(v1, v2)
            self.assertEqual(s[i], v1)
            self.assertEqual(s2[i], v1)

    def testRawTuple(self):
        # RawTuple(n) is a Struct of n untyped Scalars named field_<i>.
        s = schema.RawTuple(2)
        self.assertEqual(
            s, schema.Struct(
                ('field_0', schema.Scalar()), ('field_1', schema.Scalar())
            )
        )
        self.assertEqual(s[0], schema.Scalar())
        self.assertEqual(s[1], schema.Scalar())

    def testStructIndexing(self):
        # Item access by name matches attribute access; a tuple of names
        # selects a sub-Struct in the requested order.
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', schema.List(schema.Scalar(dtype=str))),
            ('field3', schema.Struct()),
        )
        self.assertEqual(s['field2'], s.field2)
        self.assertEqual(s['field2'], schema.List(schema.Scalar(dtype=str)))
        self.assertEqual(s['field3'], schema.Struct())
        self.assertEqual(
            s['field2', 'field1'],
            schema.Struct(
                ('field2', schema.List(schema.Scalar(dtype=str))),
                ('field1', schema.Scalar(dtype=np.int32)),
            )
        )

    def testListInStructIndexing(self):
        # Nested 'field:lengths' / 'field:values' paths reach into a List.
        a = schema.List(schema.Scalar(dtype=str))
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', a)
        )
        self.assertEqual(s['field2:lengths'], a.lengths)
        self.assertEqual(s['field2:values'], a.items)
        # NOTE(review): 'fields2' is not a field of `s` (the declared name is
        # 'field2'), so these lookups already miss at the top level.
        # Presumably 'field2:...' was intended - confirm before changing.
        with self.assertRaises(KeyError):
            s['fields2:items:non_existent']
        with self.assertRaises(KeyError):
            s['fields2:non_existent']

    def testListWithEvictedInStructIndexing(self):
        # As above, plus the extra '_evicted_values' member.
        a = schema.ListWithEvicted(schema.Scalar(dtype=str))
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', a)
        )
        self.assertEqual(s['field2:lengths'], a.lengths)
        self.assertEqual(s['field2:values'], a.items)
        self.assertEqual(s['field2:_evicted_values'], a._evicted_values)
        # NOTE(review): 'fields2' is not a field of `s` (the declared name is
        # 'field2'), so these lookups already miss at the top level.
        # Presumably 'field2:...' was intended - confirm before changing.
        with self.assertRaises(KeyError):
            s['fields2:items:non_existent']
        with self.assertRaises(KeyError):
            s['fields2:non_existent']

    def testMapInStructIndexing(self):
        # Map keys/values are reachable under 'field:values:keys|values'.
        a = schema.Map(
            schema.Scalar(dtype=np.int32),
            schema.Scalar(dtype=np.float32),
        )
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', a)
        )
        self.assertEqual(s['field2:values:keys'], a.keys)
        self.assertEqual(s['field2:values:values'], a.values)
        # NOTE(review): 'fields2' is not a field of `s` (the declared name is
        # 'field2'), so this lookup already misses at the top level.
        # Presumably 'field2:...' was intended - confirm before changing.
        with self.assertRaises(KeyError):
            s['fields2:keys:non_existent']

    def testPreservesMetadata(self):
        # Metadata attached to scalars, list items and list lengths must be
        # visible on the original, on a clone, and after from_blob_list().
        s = schema.Struct(
            ('a', schema.Scalar(np.float32)),
            ('b', schema.Scalar(
                np.int32,
                metadata=schema.Metadata(categorical_limit=5)
            )),
            ('c', schema.List(
                schema.Scalar(
                    np.int32,
                    metadata=schema.Metadata(categorical_limit=6)
                )
            ))
        )
        # attach metadata to lengths field
        s.c.lengths.set_metadata(schema.Metadata(categorical_limit=7))

        self.assertIsNone(s.a.metadata)
        self.assertEqual(5, s.b.metadata.categorical_limit)
        self.assertEqual(6, s.c.value.metadata.categorical_limit)
        self.assertEqual(7, s.c.lengths.metadata.categorical_limit)
        sc = s.clone()
        self.assertIsNone(sc.a.metadata)
        self.assertEqual(5, sc.b.metadata.categorical_limit)
        self.assertEqual(6, sc.c.value.metadata.categorical_limit)
        self.assertEqual(7, sc.c.lengths.metadata.categorical_limit)
        sv = schema.from_blob_list(
            s, [
                np.array([3.4]), np.array([2]), np.array([3]),
                np.array([1, 2, 3])
            ]
        )
        self.assertIsNone(sv.a.metadata)
        self.assertEqual(5, sv.b.metadata.categorical_limit)
        self.assertEqual(6, sv.c.value.metadata.categorical_limit)
        self.assertEqual(7, sv.c.lengths.metadata.categorical_limit)

    def testDupField(self):
        # Declaring the same field name twice is rejected.
        with self.assertRaises(ValueError):
            schema.Struct(
                ('a', schema.Scalar()),
                ('a', schema.Scalar()))

    def testAssignToField(self):
        # Build the struct outside the assertRaises body so that only the
        # attribute assignment can satisfy the expected TypeError.
        s = schema.Struct(('a', schema.Scalar()))
        with self.assertRaises(TypeError):
            s.a = schema.Scalar()

    def testPreservesEmptyFields(self):
        # An empty nested Struct must not be dropped by clone() or by
        # from_blob_list(), even though it contributes no blobs.
        s = schema.Struct(
            ('a', schema.Scalar(np.float32)),
            ('b', schema.Struct()),
        )
        sc = s.clone()
        self.assertIn("a", sc.fields)
        self.assertIn("b", sc.fields)
        sv = schema.from_blob_list(s, [np.array([3.4])])
        self.assertIn("a", sv.fields)
        self.assertIn("b", sv.fields)
        self.assertEqual(0, len(sv.b.fields))

    def testStructSubstraction(self):
        # `s1 - s2` removes the fields named in s2; a non-Struct right-hand
        # operand is a TypeError.
        s1 = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Scalar()),
            ('c', schema.Scalar()),
        )
        s2 = schema.Struct(
            ('b', schema.Scalar())
        )
        s = s1 - s2
        self.assertEqual(['a', 'c'], s.field_names())

        s3 = schema.Struct(
            ('a', schema.Scalar())
        )
        s = s1 - s3
        self.assertEqual(['b', 'c'], s.field_names())

        with self.assertRaises(TypeError):
            s1 - schema.Scalar()

    def testStructNestedSubstraction(self):
        # Subtraction recurses into nested structs with the same name.
        s1 = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Struct(
                ('c', schema.Scalar()),
                ('d', schema.Scalar()),
                ('e', schema.Scalar()),
                ('f', schema.Scalar()),
            )),
        )
        s2 = schema.Struct(
            ('b', schema.Struct(
                ('d', schema.Scalar()),
                ('e', schema.Scalar()),
            )),
        )
        s = s1 - s2
        self.assertEqual(['a', 'b:c', 'b:f'], s.field_names())

    def testStructAddition(self):
        # `s1 + s2` merges disjoint structs; clashing names or a non-Struct
        # operand raise TypeError.
        s1 = schema.Struct(
            ('a', schema.Scalar())
        )
        s2 = schema.Struct(
            ('b', schema.Scalar())
        )
        s = s1 + s2
        self.assertIn("a", s.fields)
        self.assertIn("b", s.fields)
        with self.assertRaises(TypeError):
            s1 + s1
        with self.assertRaises(TypeError):
            s1 + schema.Scalar()

    def testStructNestedAddition(self):
        # Addition merges nested structs that share a name, but refuses to
        # merge a nested struct with a scalar of the same name.
        s1 = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Struct(
                ('c', schema.Scalar())
            )),
        )
        s2 = schema.Struct(
            ('b', schema.Struct(
                ('d', schema.Scalar())
            ))
        )
        s = s1 + s2
        self.assertEqual(['a', 'b:c', 'b:d'], s.field_names())

        s3 = schema.Struct(
            ('b', schema.Scalar()),
        )
        with self.assertRaises(TypeError):
            s1 + s3

    def testGetFieldByNestedName(self):
        # 'a:b:c' paths resolve through nested structs; unknown or empty
        # paths raise KeyError.
        st = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Struct(
                ('c', schema.Struct(
                    ('d', schema.Scalar()),
                )),
            )),
        )
        self.assertRaises(KeyError, st.__getitem__, '')
        self.assertRaises(KeyError, st.__getitem__, 'x')
        self.assertRaises(KeyError, st.__getitem__, 'x:y')
        self.assertRaises(KeyError, st.__getitem__, 'b:c:x')
        a = st['a']
        self.assertIsInstance(a, schema.Scalar)
        bc = st['b:c']
        self.assertIn('d', bc.fields)
        bcd = st['b:c:d']
        self.assertIsInstance(bcd, schema.Scalar)

    def testAddFieldByNestedName(self):
        # Fields may be declared with nested 'x:y' names; they are merged
        # into existing structs unless that would overwrite a scalar.
        f_a = schema.Scalar(blob=core.BlobReference('blob1'))
        f_b = schema.Struct(
            ('c', schema.Struct(
                ('d', schema.Scalar(blob=core.BlobReference('blob2'))),
            )),
        )
        f_x = schema.Struct(
            ('x', schema.Scalar(blob=core.BlobReference('blob3'))),
        )

        # 'b:c:d' already exists as a scalar, so neither replacing it nor
        # nesting beneath it is allowed.
        with self.assertRaises(TypeError):
            schema.Struct(
                ('a', f_a),
                ('b', f_b),
                ('b:c:d', f_x),
            )
        with self.assertRaises(TypeError):
            schema.Struct(
                ('a', f_a),
                ('b', f_b),
                ('b:c:d:e', f_x),
            )

        st = schema.Struct(
            ('a', f_a),
            ('b', f_b),
            ('e:f', f_x),
        )
        self.assertEqual(['a', 'b:c:d', 'e:f:x'], st.field_names())
        self.assertEqual(['blob1', 'blob2', 'blob3'], st.field_blobs())

        st = schema.Struct(
            ('a', f_a),
            ('b:c:e', f_x),
            ('b', f_b),
        )
        self.assertEqual(['a', 'b:c:e:x', 'b:c:d'], st.field_names())
        self.assertEqual(['blob1', 'blob3', 'blob2'], st.field_blobs())

        st = schema.Struct(
            ('a:a1', f_a),
            ('b:b1', f_b),
            ('a', f_x),
        )
        self.assertEqual(['a:a1', 'a:x', 'b:b1:c:d'], st.field_names())
        self.assertEqual(['blob1', 'blob3', 'blob2'], st.field_blobs())

    def testContains(self):
        # `in` accepts the same nested paths as item access.
        st = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Struct(
                ('c', schema.Struct(
                    ('d', schema.Scalar()),
                )),
            )),
        )
        self.assertIn('a', st)
        self.assertIn('b:c', st)
        self.assertIn('b:c:d', st)
        self.assertNotIn('', st)
        self.assertNotIn('x', st)
        self.assertNotIn('b:c:x', st)
        self.assertNotIn('b:c:d:x', st)

    def testFromEmptyColumnList(self):
        # An empty column list rebuilds an empty Struct.
        st = schema.Struct()
        columns = st.field_names()
        rec = schema.from_column_list(col_names=columns)
        self.assertEqual(rec, schema.Struct())

    def testFromColumnList(self):
        # from_column_list() must rebuild the record, and keep each blob
        # paired with its column name, whatever the column order.
        st = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.List(schema.Scalar())),
            ('c', schema.Map(schema.Scalar(), schema.Scalar()))
        )
        columns = st.field_names()
        expected_names = sorted(columns)
        # Local seeded RNG: the permutations are reproducible from run to
        # run and the global `random` state is left untouched.
        rng = random.Random(0)
        # test that recovery works for arbitrary order
        for _ in range(10):
            some_blobs = [core.BlobReference('blob:' + x) for x in columns]
            rec = schema.from_column_list(columns, col_blobs=some_blobs)
            self.assertTrue(rec.has_blobs())
            self.assertEqual(expected_names, sorted(rec.field_names()))
            self.assertEqual([str(blob) for blob in rec.field_blobs()],
                             ['blob:' + name for name in rec.field_names()])
            rng.shuffle(columns)

    def testStructGet(self):
        # Struct.get() returns the named field, or the default when absent.
        # unittest assertions are used so the checks survive `python -O`.
        net = core.Net('test_net')
        s1 = schema.NewRecord(net, schema.Scalar(np.float32))
        s2 = schema.NewRecord(net, schema.Scalar(np.float32))
        t = schema.Tuple(s1, s2)
        self.assertEqual(t.get('field_0', None), s1)
        self.assertEqual(t.get('field_1', None), s2)
        self.assertIsNone(t.get('field_2', None))

    def testScalarForVoidType(self):
        # A shaped scalar may use dtype None but not np.void; unshaped
        # Scalar(np.void) and Scalar(None) compare equal.
        schema.Scalar((None, (2, )))  # must not raise
        with self.assertRaises(TypeError):
            schema.Scalar((np.void, (2, )))

        s1_good = schema.Scalar(np.void)
        s2_good = schema.Scalar(None)
        self.assertEqual(s1_good, s2_good)

    def testScalarShape(self):
        # The shape part of a (dtype, shape) spec is reflected in
        # field_type().shape; negative dimensions are rejected.
        s0 = schema.Scalar(np.int32)
        self.assertEqual(s0.field_type().shape, ())

        s1_good = schema.Scalar((np.int32, 5))
        self.assertEqual(s1_good.field_type().shape, (5, ))

        with self.assertRaises(ValueError):
            schema.Scalar((np.int32, -1))

        s1_hard = schema.Scalar((np.int32, 1))
        self.assertEqual(s1_hard.field_type().shape, (1, ))

        s2 = schema.Scalar((np.int32, (2, 3)))
        self.assertEqual(s2.field_type().shape, (2, 3))

    def testDtypeForCoreType(self):
        # Known core types map to numpy dtypes; unknown ones raise TypeError.
        dtype = schema.dtype_for_core_type(core.DataType.FLOAT16)
        self.assertEqual(dtype, np.float16)

        with self.assertRaises(TypeError):
            schema.dtype_for_core_type(100)