| # Owner(s): ["module: meta tensors"] |
| |
| import copy |
| import gc |
| import random |
| import threading |
| import unittest |
| |
| import torch |
| from torch.testing._internal.common_utils import ( |
| find_library_location, |
| IS_FBCODE, |
| IS_MACOS, |
| IS_SANDCASTLE, |
| IS_WINDOWS, |
| run_tests, |
| TestCase, |
| ) |
| from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary |
| |
| |
def C():
    """Return a freshly allocated one-element tensor.

    Every call produces a new tensor object, so the result can serve as a
    distinct key in this module's weak-dictionary tests.
    """
    shape = (1,)
    return torch.randn(shape)
| |
| |
# These tests are ported from cpython/Lib/test/test_weakref.py,
# but adapted to use tensors rather than plain objects as keys.
| class WeakTest(TestCase): |
| COUNT = 10 |
| |
| def test_make_weak_keyed_dict_from_dict(self): |
| o = torch.randn(2) |
| dict = WeakIdKeyDictionary({o: 364}) |
| self.assertEqual(dict[o], 364) |
| |
| def test_make_weak_keyed_dict_from_weak_keyed_dict(self): |
| o = torch.randn(3) |
| dict = WeakIdKeyDictionary({o: 364}) |
| dict2 = WeakIdKeyDictionary(dict) |
| self.assertEqual(dict[o], 364) |
| |
| def check_popitem(self, klass, key1, value1, key2, value2): |
| weakdict = klass() |
| weakdict[key1] = value1 |
| weakdict[key2] = value2 |
| self.assertEqual(len(weakdict), 2) |
| k, v = weakdict.popitem() |
| self.assertEqual(len(weakdict), 1) |
| if k is key1: |
| self.assertIs(v, value1) |
| else: |
| self.assertIs(v, value2) |
| k, v = weakdict.popitem() |
| self.assertEqual(len(weakdict), 0) |
| if k is key1: |
| self.assertIs(v, value1) |
| else: |
| self.assertIs(v, value2) |
| |
| def test_weak_keyed_dict_popitem(self): |
| self.check_popitem(WeakIdKeyDictionary, C(), "value 1", C(), "value 2") |
| |
| def check_setdefault(self, klass, key, value1, value2): |
| self.assertIsNot( |
| value1, |
| value2, |
| "invalid test -- value parameters must be distinct objects", |
| ) |
| weakdict = klass() |
| o = weakdict.setdefault(key, value1) |
| self.assertIs(o, value1) |
| self.assertIn(key, weakdict) |
| self.assertIs(weakdict.get(key), value1) |
| self.assertIs(weakdict[key], value1) |
| |
| o = weakdict.setdefault(key, value2) |
| self.assertIs(o, value1) |
| self.assertIn(key, weakdict) |
| self.assertIs(weakdict.get(key), value1) |
| self.assertIs(weakdict[key], value1) |
| |
| def test_weak_keyed_dict_setdefault(self): |
| self.check_setdefault(WeakIdKeyDictionary, C(), "value 1", "value 2") |
| |
| def check_update(self, klass, dict): |
| # |
| # This exercises d.update(), len(d), d.keys(), k in d, |
| # d.get(), d[]. |
| # |
| weakdict = klass() |
| weakdict.update(dict) |
| self.assertEqual(len(weakdict), len(dict)) |
| for k in weakdict.keys(): |
| self.assertIn(k, dict, "mysterious new key appeared in weak dict") |
| v = dict.get(k) |
| self.assertIs(v, weakdict[k]) |
| self.assertIs(v, weakdict.get(k)) |
| for k in dict.keys(): |
| self.assertIn(k, weakdict, "original key disappeared in weak dict") |
| v = dict[k] |
| self.assertIs(v, weakdict[k]) |
| self.assertIs(v, weakdict.get(k)) |
| |
| def test_weak_keyed_dict_update(self): |
| self.check_update(WeakIdKeyDictionary, {C(): 1, C(): 2, C(): 3}) |
| |
| def test_weak_keyed_delitem(self): |
| d = WeakIdKeyDictionary() |
| o1 = torch.randn(1) |
| o2 = torch.randn(2) |
| d[o1] = "something" |
| d[o2] = "something" |
| self.assertEqual(len(d), 2) |
| del d[o1] |
| self.assertEqual(len(d), 1) |
| self.assertEqual(list(d.keys()), [o2]) |
| |
| def test_weak_keyed_union_operators(self): |
| try: |
| {} | {} |
| except TypeError: |
| self.skipTest("dict union not supported in this Python") |
| |
| o1 = C() |
| o2 = C() |
| o3 = C() |
| wkd1 = WeakIdKeyDictionary({o1: 1, o2: 2}) |
| wkd2 = WeakIdKeyDictionary({o3: 3, o1: 4}) |
| wkd3 = wkd1.copy() |
| d1 = {o2: "5", o3: "6"} |
| pairs = [(o2, 7), (o3, 8)] |
| |
| tmp1 = wkd1 | wkd2 # Between two WeakKeyDictionaries |
| self.assertEqual(dict(tmp1), dict(wkd1) | dict(wkd2)) |
| self.assertIs(type(tmp1), WeakIdKeyDictionary) |
| wkd1 |= wkd2 |
| self.assertEqual(wkd1, tmp1) |
| |
| tmp2 = wkd2 | d1 # Between WeakKeyDictionary and mapping |
| self.assertEqual(dict(tmp2), dict(wkd2) | d1) |
| self.assertIs(type(tmp2), WeakIdKeyDictionary) |
| wkd2 |= d1 |
| self.assertEqual(wkd2, tmp2) |
| |
| tmp3 = wkd3.copy() # Between WeakKeyDictionary and iterable key, value |
| tmp3 |= pairs |
| self.assertEqual(dict(tmp3), dict(wkd3) | dict(pairs)) |
| self.assertIs(type(tmp3), WeakIdKeyDictionary) |
| |
| tmp4 = d1 | wkd3 # Testing .__ror__ |
| self.assertEqual(dict(tmp4), d1 | dict(wkd3)) |
| self.assertIs(type(tmp4), WeakIdKeyDictionary) |
| |
| del o1 |
| self.assertNotIn(4, tmp1.values()) |
| self.assertNotIn(4, tmp2.values()) |
| self.assertNotIn(1, tmp3.values()) |
| self.assertNotIn(1, tmp4.values()) |
| |
| def test_weak_keyed_bad_delitem(self): |
| d = WeakIdKeyDictionary() |
| o = torch.randn(1) |
| # An attempt to delete an object that isn't there should raise |
| # KeyError. It didn't before 2.3. |
| self.assertRaises(KeyError, d.__delitem__, o) |
| self.assertRaises(KeyError, d.__getitem__, o) |
| |
| # If a key isn't of a weakly referencable type, __getitem__ and |
| # __setitem__ raise TypeError. __delitem__ should too. |
| self.assertRaises(TypeError, d.__delitem__, 13) |
| self.assertRaises(TypeError, d.__getitem__, 13) |
| self.assertRaises(TypeError, d.__setitem__, 13, 13) |
| |
| def test_make_weak_keyed_dict_repr(self): |
| dict = WeakIdKeyDictionary() |
| self.assertRegex(repr(dict), "<WeakIdKeyDictionary at 0x.*>") |
| |
| def check_threaded_weak_dict_copy(self, type_, deepcopy): |
| # `deepcopy` should be either True or False. |
| exc = [] |
| |
| # Cannot give these slots as weakrefs weren't supported |
| # on these objects until later versions of Python |
| class DummyKey: # noqa: B903 |
| def __init__(self, ctr): |
| self.ctr = ctr |
| |
| class DummyValue: # noqa: B903 |
| def __init__(self, ctr): |
| self.ctr = ctr |
| |
| def dict_copy(d, exc): |
| try: |
| if deepcopy is True: |
| _ = copy.deepcopy(d) |
| else: |
| _ = d.copy() |
| except Exception as ex: |
| exc.append(ex) |
| |
| def pop_and_collect(lst): |
| gc_ctr = 0 |
| while lst: |
| i = random.randint(0, len(lst) - 1) |
| gc_ctr += 1 |
| lst.pop(i) |
| if gc_ctr % 10000 == 0: |
| gc.collect() # just in case |
| |
| d = type_() |
| keys = [] |
| values = [] |
| # Initialize d with many entries |
| for i in range(70000): |
| k, v = DummyKey(i), DummyValue(i) |
| keys.append(k) |
| values.append(v) |
| d[k] = v |
| del k |
| del v |
| |
| t_copy = threading.Thread(target=dict_copy, args=(d, exc)) |
| t_collect = threading.Thread(target=pop_and_collect, args=(keys,)) |
| |
| t_copy.start() |
| t_collect.start() |
| |
| t_copy.join() |
| t_collect.join() |
| |
| # Test exceptions |
| if exc: |
| raise exc[0] |
| |
| def test_threaded_weak_key_dict_copy(self): |
| # Issue #35615: Weakref keys or values getting GC'ed during dict |
| # copying should not result in a crash. |
| self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, False) |
| |
| def test_threaded_weak_key_dict_deepcopy(self): |
| # Issue #35615: Weakref keys or values getting GC'ed during dict |
| # copying should not result in a crash. |
| self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, True) |
| |
| |
| # Adapted from cpython/Lib/test/mapping_tests.py |
class WeakKeyDictionaryTestCase(TestCase):
    """Generic mapping-protocol tests applied to ``WeakIdKeyDictionary``.

    Keys are tensors. ``self.reference`` is a plain ``dict`` built in
    ``__init__``. Most tests populate the mapping under test (``type2test``)
    from it and compare the results against it.
    """

    # Shared fixture: three tensor keys created once, when the class body
    # runs. Because this class-level dict holds them, the keys stay strongly
    # referenced for as long as the class exists. (Name-mangled to
    # ``_WeakKeyDictionaryTestCase__ref``.)
    __ref = {torch.randn(1): 1, torch.randn(2): 2, torch.randn(3): 3}
    # The mapping type under test.
    type2test = WeakIdKeyDictionary

    def _reference(self):
        # Shallow copy, so each test instance can mutate its own reference.
        return self.__ref.copy()

    def _empty_mapping(self):
        """Return an empty mapping object"""
        return self.type2test()

    def _full_mapping(self, data):
        """Return a mapping object with the value contained in data
        dictionary"""
        x = self._empty_mapping()
        for key, value in data.items():
            x[key] = value
        return x

    def __init__(self, *args, **kw):
        # NOTE(review): this calls unittest.TestCase.__init__ directly instead
        # of super().__init__, so any __init__ logic in torch's TestCase is
        # skipped. Confirm this is intentional.
        unittest.TestCase.__init__(self, *args, **kw)
        self.reference = self._reference().copy()

        # A (key, value) pair not in the mapping
        # dict.popitem() removes the most recently inserted pair, so this is
        # the torch.randn(3) entry of __ref. It is removed from self.reference.
        key, value = self.reference.popitem()
        self.other = {key: value}

        # A (key, value) pair in the mapping
        # Pop the next pair only to record it, then put it straight back.
        key, value = self.reference.popitem()
        self.inmapping = {key: value}
        self.reference[key] = value

    def test_read(self):
        # Test for read only operations on mapping
        p = self._empty_mapping()
        p1 = dict(p)  # workaround for singleton objects
        d = self._full_mapping(self.reference)
        if d is p:
            p = p1
        # Indexing
        for key, value in self.reference.items():
            self.assertEqual(d[key], value)
        knownkey = next(iter(self.other.keys()))
        self.assertRaises(KeyError, lambda: d[knownkey])
        # len
        self.assertEqual(len(p), 0)
        self.assertEqual(len(d), len(self.reference))
        # __contains__
        for k in self.reference:
            self.assertIn(k, d)
        for k in self.other:
            self.assertNotIn(k, d)
        # cmp
        self.assertTrue(
            p == p
        )  # NB: don't use assertEqual, that doesn't actually use ==
        self.assertTrue(d == d)
        self.assertTrue(p != d)
        self.assertTrue(d != p)
        # bool
        if p:
            self.fail("Empty mapping must compare to False")
        if not d:
            self.fail("Full mapping must compare to True")

        # keys(), items(), iterkeys() ...
        def check_iterandlist(iter, lst, ref):
            # `iter` must be a real iterator, and iterating it, the
            # materialized list, and the reference view must all yield the
            # same set of elements.
            self.assertTrue(hasattr(iter, "__next__"))
            self.assertTrue(hasattr(iter, "__iter__"))
            x = list(iter)
            self.assertTrue(set(x) == set(lst) == set(ref))

        check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys())
        check_iterandlist(iter(d), list(d.keys()), self.reference.keys())
        check_iterandlist(iter(d.values()), list(d.values()), self.reference.values())
        check_iterandlist(iter(d.items()), list(d.items()), self.reference.items())
        # get
        key, value = next(iter(d.items()))
        knownkey, knownvalue = next(iter(self.other.items()))
        self.assertEqual(d.get(key, knownvalue), value)
        self.assertEqual(d.get(knownkey, knownvalue), knownvalue)
        self.assertNotIn(knownkey, d)

    def test_write(self):
        # Test for write operations on mapping
        p = self._empty_mapping()
        # Indexing
        for key, value in self.reference.items():
            p[key] = value
            self.assertEqual(p[key], value)
        for key in self.reference.keys():
            del p[key]
            self.assertRaises(KeyError, lambda: p[key])
        p = self._empty_mapping()
        # update
        p.update(self.reference)
        self.assertEqual(dict(p), self.reference)
        items = list(p.items())
        p = self._empty_mapping()
        p.update(items)
        self.assertEqual(dict(p), self.reference)
        d = self._full_mapping(self.reference)
        # setdefault
        key, value = next(iter(d.items()))
        knownkey, knownvalue = next(iter(self.other.items()))
        self.assertEqual(d.setdefault(key, knownvalue), value)
        self.assertEqual(d[key], value)
        self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue)
        self.assertEqual(d[knownkey], knownvalue)
        # pop
        self.assertEqual(d.pop(knownkey), knownvalue)
        self.assertNotIn(knownkey, d)
        self.assertRaises(KeyError, d.pop, knownkey)
        default = 909
        d[knownkey] = knownvalue
        self.assertEqual(d.pop(knownkey, default), knownvalue)
        self.assertNotIn(knownkey, d)
        self.assertEqual(d.pop(knownkey, default), default)
        # popitem
        key, value = d.popitem()
        self.assertNotIn(key, d)
        self.assertEqual(value, self.reference[key])
        p = self._empty_mapping()
        self.assertRaises(KeyError, p.popitem)

    def test_constructor(self):
        # Two fresh empty mappings compare equal.
        self.assertEqual(self._empty_mapping(), self._empty_mapping())

    def test_bool(self):
        self.assertTrue(not self._empty_mapping())
        self.assertTrue(self.reference)
        self.assertTrue(bool(self._empty_mapping()) is False)
        self.assertTrue(bool(self.reference) is True)

    def test_keys(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.keys()), [])
        # NOTE(review): from here on, ``d`` is the plain reference dict, not the
        # mapping under test, so these checks exercise ``dict``. Moving them
        # onto the weak mapping changes what is tested. Verify that its keys()
        # supports tensor membership checks first.
        d = self.reference
        self.assertIn(next(iter(self.inmapping.keys())), d.keys())
        self.assertNotIn(next(iter(self.other.keys())), d.keys())
        self.assertRaises(TypeError, d.keys, None)

    def test_values(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.values()), [])

        # values() takes no arguments.
        self.assertRaises(TypeError, d.values, None)

    def test_items(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.items()), [])

        # items() takes no arguments.
        self.assertRaises(TypeError, d.items, None)

    def test_len(self):
        d = self._empty_mapping()
        self.assertEqual(len(d), 0)

    def test_getitem(self):
        # NOTE(review): ``d`` is the plain reference dict, not the mapping
        # under test. Confirm whether this is intended.
        d = self.reference
        self.assertEqual(
            d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values()))
        )

        self.assertRaises(TypeError, d.__getitem__)

    def test_update(self):
        # mapping argument
        d = self._empty_mapping()
        d.update(self.other)
        self.assertEqual(list(d.items()), list(self.other.items()))

        # No argument
        d = self._empty_mapping()
        d.update()
        self.assertEqual(d, self._empty_mapping())

        # item sequence
        d = self._empty_mapping()
        d.update(self.other.items())
        self.assertEqual(list(d.items()), list(self.other.items()))

        # Iterator
        d = self._empty_mapping()
        d.update(self.other.items())
        self.assertEqual(list(d.items()), list(self.other.items()))

        # FIXME: Doesn't work with UserDict
        # self.assertRaises((TypeError, AttributeError), d.update, None)
        self.assertRaises((TypeError, AttributeError), d.update, 42)

        outerself = self

        # Minimal mapping exposing only keys() and __getitem__, backed by the
        # test's reference dict.
        class SimpleUserDict:
            def __init__(self) -> None:
                self.d = outerself.reference

            def keys(self):
                return self.d.keys()

            def __getitem__(self, i):
                return self.d[i]

        d.clear()
        d.update(SimpleUserDict())
        # Compare by key identity: sort (id(key), value) pairs from both sides.
        i1 = sorted((id(k), v) for k, v in d.items())
        i2 = sorted((id(k), v) for k, v in self.reference.items())
        self.assertEqual(i1, i2)

        class Exc(Exception):
            pass

        d = self._empty_mapping()

        # keys() itself raises: update must propagate Exc.
        class FailingUserDict:
            def keys(self):
                raise Exc

        self.assertRaises(Exc, d.update, FailingUserDict())

        d.clear()

        # The keys() iterator yields one key and then raises: Exc must propagate.
        class FailingUserDict:
            def keys(self):
                class BogonIter:
                    def __init__(self) -> None:
                        self.i = 1

                    def __iter__(self):
                        return self

                    def __next__(self):
                        if self.i:
                            self.i = 0
                            return "a"
                        raise Exc

                return BogonIter()

            def __getitem__(self, key):
                return key

        self.assertRaises(Exc, d.update, FailingUserDict())

        # keys() iterates normally ("a".."z") but __getitem__ raises: Exc must
        # propagate.
        class FailingUserDict:
            def keys(self):
                class BogonIter:
                    def __init__(self) -> None:
                        self.i = ord("a")

                    def __iter__(self):
                        return self

                    def __next__(self):
                        if self.i <= ord("z"):
                            rtn = chr(self.i)
                            self.i += 1
                            return rtn
                        raise StopIteration

                return BogonIter()

            def __getitem__(self, key):
                raise Exc

        self.assertRaises(Exc, d.update, FailingUserDict())

        d = self._empty_mapping()

        # A sequence argument whose iteration raises immediately.
        class badseq:
            def __iter__(self):
                return self

            def __next__(self):
                raise Exc

        self.assertRaises(Exc, d.update, badseq())

        # Elements of a pair sequence must have length 2.
        self.assertRaises(ValueError, d.update, [(1, 2, 3)])

    # no test_fromkeys or test_copy as both os.environ and selves don't support it

    def test_get(self):
        d = self._empty_mapping()
        self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
        self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
        # NOTE(review): the remaining checks run against the plain reference
        # dict, not the mapping under test.
        d = self.reference
        self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
        self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
        self.assertEqual(
            d.get(next(iter(self.inmapping.keys()))),
            next(iter(self.inmapping.values())),
        )
        self.assertEqual(
            d.get(next(iter(self.inmapping.keys())), 3),
            next(iter(self.inmapping.values())),
        )
        self.assertRaises(TypeError, d.get)
        self.assertRaises(TypeError, d.get, None, None, None)

    def test_setdefault(self):
        # setdefault requires at least a key argument.
        d = self._empty_mapping()
        self.assertRaises(TypeError, d.setdefault)

    def test_popitem(self):
        # Empty mapping raises KeyError; popitem takes no arguments.
        d = self._empty_mapping()
        self.assertRaises(KeyError, d.popitem)
        self.assertRaises(TypeError, d.popitem, 42)

    def test_pop(self):
        d = self._empty_mapping()
        k, v = next(iter(self.inmapping.items()))
        d[k] = v
        self.assertRaises(KeyError, d.pop, next(iter(self.other.keys())))

        self.assertEqual(d.pop(k), v)
        self.assertEqual(len(d), 0)

        self.assertRaises(KeyError, d.pop, k)
| |
| |
| # Adapted from cpython/Lib/test/mapping_tests.py |
| class WeakKeyDictionaryScriptObjectTestCase(TestCase): |
| def _reference(self): |
| self.__ref = { |
| torch.classes._TorchScriptTesting._Foo(1, 2): 1, |
| torch.classes._TorchScriptTesting._Foo(2, 3): 2, |
| torch.classes._TorchScriptTesting._Foo(3, 4): 3, |
| } |
| return self.__ref.copy() |
| |
| def _empty_mapping(self): |
| """Return an empty mapping object""" |
| return WeakIdKeyDictionary(ref_type=_WeakHashRef) |
| |
| def _full_mapping(self, data): |
| """Return a mapping object with the value contained in data |
| dictionary""" |
| x = self._empty_mapping() |
| for key, value in data.items(): |
| x[key] = value |
| return x |
| |
| def setUp(self): |
| if IS_MACOS: |
| raise unittest.SkipTest("non-portable load_library call used in test") |
| |
| def __init__(self, *args, **kw): |
| unittest.TestCase.__init__(self, *args, **kw) |
| if IS_SANDCASTLE or IS_FBCODE: |
| torch.ops.load_library( |
| "//caffe2/test/cpp/jit:test_custom_class_registrations" |
| ) |
| elif IS_MACOS: |
| # don't load the library, just skip the tests in setUp |
| return |
| else: |
| lib_file_path = find_library_location("libtorchbind_test.so") |
| if IS_WINDOWS: |
| lib_file_path = find_library_location("torchbind_test.dll") |
| torch.ops.load_library(str(lib_file_path)) |
| |
| self.reference = self._reference().copy() |
| |
| # A (key, value) pair not in the mapping |
| key, value = self.reference.popitem() |
| self.other = {key: value} |
| |
| # A (key, value) pair in the mapping |
| key, value = self.reference.popitem() |
| self.inmapping = {key: value} |
| self.reference[key] = value |
| |
| def test_read(self): |
| # Test for read only operations on mapping |
| p = self._empty_mapping() |
| p1 = dict(p) # workaround for singleton objects |
| d = self._full_mapping(self.reference) |
| if d is p: |
| p = p1 |
| # Indexing |
| for key, value in self.reference.items(): |
| self.assertEqual(d[key], value) |
| knownkey = next(iter(self.other.keys())) |
| self.assertRaises(KeyError, lambda: d[knownkey]) |
| # len |
| self.assertEqual(len(p), 0) |
| self.assertEqual(len(d), len(self.reference)) |
| # __contains__ |
| for k in self.reference: |
| self.assertIn(k, d) |
| for k in self.other: |
| self.assertNotIn(k, d) |
| # cmp |
| self.assertTrue( |
| p == p |
| ) # NB: don't use assertEqual, that doesn't actually use == |
| self.assertTrue(d == d) |
| self.assertTrue(p != d) |
| self.assertTrue(d != p) |
| # bool |
| if p: |
| self.fail("Empty mapping must compare to False") |
| if not d: |
| self.fail("Full mapping must compare to True") |
| |
| # keys(), items(), iterkeys() ... |
| def check_iterandlist(iter, lst, ref): |
| self.assertTrue(hasattr(iter, "__next__")) |
| self.assertTrue(hasattr(iter, "__iter__")) |
| x = list(iter) |
| self.assertTrue(set(x) == set(lst) == set(ref)) |
| |
| check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys()) |
| check_iterandlist(iter(d), list(d.keys()), self.reference.keys()) |
| check_iterandlist(iter(d.values()), list(d.values()), self.reference.values()) |
| check_iterandlist(iter(d.items()), list(d.items()), self.reference.items()) |
| # get |
| key, value = next(iter(d.items())) |
| knownkey, knownvalue = next(iter(self.other.items())) |
| self.assertEqual(d.get(key, knownvalue), value) |
| self.assertEqual(d.get(knownkey, knownvalue), knownvalue) |
| self.assertNotIn(knownkey, d) |
| |
| def test_write(self): |
| # Test for write operations on mapping |
| p = self._empty_mapping() |
| # Indexing |
| for key, value in self.reference.items(): |
| p[key] = value |
| self.assertEqual(p[key], value) |
| for key in self.reference.keys(): |
| del p[key] |
| self.assertRaises(KeyError, lambda: p[key]) |
| p = self._empty_mapping() |
| # update |
| p.update(self.reference) |
| self.assertEqual(dict(p), self.reference) |
| items = list(p.items()) |
| p = self._empty_mapping() |
| p.update(items) |
| self.assertEqual(dict(p), self.reference) |
| d = self._full_mapping(self.reference) |
| # setdefault |
| key, value = next(iter(d.items())) |
| knownkey, knownvalue = next(iter(self.other.items())) |
| self.assertEqual(d.setdefault(key, knownvalue), value) |
| self.assertEqual(d[key], value) |
| self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue) |
| self.assertEqual(d[knownkey], knownvalue) |
| # pop |
| self.assertEqual(d.pop(knownkey), knownvalue) |
| self.assertNotIn(knownkey, d) |
| self.assertRaises(KeyError, d.pop, knownkey) |
| default = 909 |
| d[knownkey] = knownvalue |
| self.assertEqual(d.pop(knownkey, default), knownvalue) |
| self.assertNotIn(knownkey, d) |
| self.assertEqual(d.pop(knownkey, default), default) |
| # popitem |
| key, value = d.popitem() |
| self.assertNotIn(key, d) |
| self.assertEqual(value, self.reference[key]) |
| p = self._empty_mapping() |
| self.assertRaises(KeyError, p.popitem) |
| |
| def test_constructor(self): |
| self.assertEqual(self._empty_mapping(), self._empty_mapping()) |
| |
| def test_bool(self): |
| self.assertTrue(not self._empty_mapping()) |
| self.assertTrue(self.reference) |
| self.assertTrue(bool(self._empty_mapping()) is False) |
| self.assertTrue(bool(self.reference) is True) |
| |
| def test_keys(self): |
| d = self._empty_mapping() |
| self.assertEqual(list(d.keys()), []) |
| d = self.reference |
| self.assertIn(next(iter(self.inmapping.keys())), d.keys()) |
| self.assertNotIn(next(iter(self.other.keys())), d.keys()) |
| self.assertRaises(TypeError, d.keys, None) |
| |
| def test_values(self): |
| d = self._empty_mapping() |
| self.assertEqual(list(d.values()), []) |
| |
| self.assertRaises(TypeError, d.values, None) |
| |
| def test_items(self): |
| d = self._empty_mapping() |
| self.assertEqual(list(d.items()), []) |
| |
| self.assertRaises(TypeError, d.items, None) |
| |
| def test_len(self): |
| d = self._empty_mapping() |
| self.assertEqual(len(d), 0) |
| |
| def test_getitem(self): |
| d = self.reference |
| self.assertEqual( |
| d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values())) |
| ) |
| |
| self.assertRaises(TypeError, d.__getitem__) |
| |
| def test_update(self): |
| # mapping argument |
| d = self._empty_mapping() |
| d.update(self.other) |
| self.assertEqual(list(d.items()), list(self.other.items())) |
| |
| # No argument |
| d = self._empty_mapping() |
| d.update() |
| self.assertEqual(d, self._empty_mapping()) |
| |
| # item sequence |
| d = self._empty_mapping() |
| d.update(self.other.items()) |
| self.assertEqual(list(d.items()), list(self.other.items())) |
| |
| # Iterator |
| d = self._empty_mapping() |
| d.update(self.other.items()) |
| self.assertEqual(list(d.items()), list(self.other.items())) |
| |
| # FIXME: Doesn't work with UserDict |
| # self.assertRaises((TypeError, AttributeError), d.update, None) |
| self.assertRaises((TypeError, AttributeError), d.update, 42) |
| |
| outerself = self |
| |
| class SimpleUserDict: |
| def __init__(self) -> None: |
| self.d = outerself.reference |
| |
| def keys(self): |
| return self.d.keys() |
| |
| def __getitem__(self, i): |
| return self.d[i] |
| |
| d.clear() |
| d.update(SimpleUserDict()) |
| i1 = sorted((id(k), v) for k, v in d.items()) |
| i2 = sorted((id(k), v) for k, v in self.reference.items()) |
| self.assertEqual(i1, i2) |
| |
| class Exc(Exception): |
| pass |
| |
| d = self._empty_mapping() |
| |
| class FailingUserDict: |
| def keys(self): |
| raise Exc |
| |
| self.assertRaises(Exc, d.update, FailingUserDict()) |
| |
| d.clear() |
| |
| class FailingUserDict: |
| def keys(self): |
| class BogonIter: |
| def __init__(self) -> None: |
| self.i = 1 |
| |
| def __iter__(self): |
| return self |
| |
| def __next__(self): |
| if self.i: |
| self.i = 0 |
| return "a" |
| raise Exc |
| |
| return BogonIter() |
| |
| def __getitem__(self, key): |
| return key |
| |
| self.assertRaises(Exc, d.update, FailingUserDict()) |
| |
| class FailingUserDict: |
| def keys(self): |
| class BogonIter: |
| def __init__(self) -> None: |
| self.i = ord("a") |
| |
| def __iter__(self): |
| return self |
| |
| def __next__(self): |
| if self.i <= ord("z"): |
| rtn = chr(self.i) |
| self.i += 1 |
| return rtn |
| raise StopIteration |
| |
| return BogonIter() |
| |
| def __getitem__(self, key): |
| raise Exc |
| |
| self.assertRaises(Exc, d.update, FailingUserDict()) |
| |
| d = self._empty_mapping() |
| |
| class badseq: |
| def __iter__(self): |
| return self |
| |
| def __next__(self): |
| raise Exc |
| |
| self.assertRaises(Exc, d.update, badseq()) |
| |
| self.assertRaises(ValueError, d.update, [(1, 2, 3)]) |
| |
| # no test_fromkeys or test_copy as both os.environ and selves don't support it |
| |
| def test_get(self): |
| d = self._empty_mapping() |
| self.assertTrue(d.get(next(iter(self.other.keys()))) is None) |
| self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) |
| d = self.reference |
| self.assertTrue(d.get(next(iter(self.other.keys()))) is None) |
| self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) |
| self.assertEqual( |
| d.get(next(iter(self.inmapping.keys()))), |
| next(iter(self.inmapping.values())), |
| ) |
| self.assertEqual( |
| d.get(next(iter(self.inmapping.keys())), 3), |
| next(iter(self.inmapping.values())), |
| ) |
| self.assertRaises(TypeError, d.get) |
| self.assertRaises(TypeError, d.get, None, None, None) |
| |
| def test_setdefault(self): |
| d = self._empty_mapping() |
| self.assertRaises(TypeError, d.setdefault) |
| |
| def test_popitem(self): |
| d = self._empty_mapping() |
| self.assertRaises(KeyError, d.popitem) |
| self.assertRaises(TypeError, d.popitem, 42) |
| |
| def test_pop(self): |
| d = self._empty_mapping() |
| k, v = next(iter(self.inmapping.items())) |
| d[k] = v |
| self.assertRaises(KeyError, d.pop, next(iter(self.other.keys()))) |
| |
| self.assertEqual(d.pop(k), v) |
| self.assertEqual(len(d), 0) |
| |
| self.assertRaises(KeyError, d.pop, k) |
| |
| |
if __name__ == "__main__":
    # When this file is executed directly, run its tests through torch's
    # test runner.
    run_tests()