blob: 2ca1e9cda0a0ebcf7f0f662f736d5a48a7a8f250 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import torch
from typing import Tuple, List
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestHash(JitTestCase):
def test_hash_tuple(self):
def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
return hash(t1) == hash(t2)
self.checkScript(fn, ((1, 2), (1, 2)))
self.checkScript(fn, ((1, 2), (3, 4)))
self.checkScript(fn, ((1, 2), (2, 1)))
def test_hash_tuple_nested_unhashable_type(self):
# Tuples may contain unhashable types like `list`, check that we error
# properly in that case.
@torch.jit.script
def fn_unhashable(t1: Tuple[int, List[int]]):
return hash(t1)
with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"):
fn_unhashable((1, [1]))
def test_hash_tensor(self):
"""Tensors should hash by identity"""
def fn(t1, t2):
return hash(t1) == hash(t2)
tensor1 = torch.tensor(1)
tensor1_clone = torch.tensor(1)
tensor2 = torch.tensor(2)
self.checkScript(fn, (tensor1, tensor1))
self.checkScript(fn, (tensor1, tensor1_clone))
self.checkScript(fn, (tensor1, tensor2))
def test_hash_none(self):
def fn():
n1 = None
n2 = None
return hash(n1) == hash(n2)
self.checkScript(fn, ())
def test_hash_bool(self):
def fn(b1: bool, b2: bool):
return hash(b1) == hash(b2)
self.checkScript(fn, (True, False))
self.checkScript(fn, (True, True))
self.checkScript(fn, (False, True))
self.checkScript(fn, (False, False))
def test_hash_float(self):
def fn(f1: float, f2: float):
return hash(f1) == hash(f2)
self.checkScript(fn, (1.2345, 1.2345))
self.checkScript(fn, (1.2345, 6.789))
self.checkScript(fn, (1.2345, float("inf")))
self.checkScript(fn, (float("inf"), float("inf")))
self.checkScript(fn, (1.2345, float('nan')))
if sys.version_info < (3, 10):
# Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html :
# Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity.
self.checkScript(fn, (float("nan"), float("nan")))
self.checkScript(fn, (float("nan"), float("inf")))
def test_hash_int(self):
def fn(i1: int, i2: int):
return hash(i1) == hash(i2)
self.checkScript(fn, (123, 456))
self.checkScript(fn, (123, 123))
self.checkScript(fn, (123, -123))
self.checkScript(fn, (-123, -123))
self.checkScript(fn, (123, 0))
def test_hash_string(self):
def fn(s1: str, s2: str):
return hash(s1) == hash(s2)
self.checkScript(fn, ("foo", "foo"))
self.checkScript(fn, ("foo", "bar"))
self.checkScript(fn, ("foo", ""))
def test_hash_device(self):
def fn(d1: torch.device, d2: torch.device):
return hash(d1) == hash(d2)
gpu0 = torch.device('cuda:0')
gpu1 = torch.device('cuda:1')
cpu = torch.device('cpu')
self.checkScript(fn, (gpu0, gpu0))
self.checkScript(fn, (gpu0, gpu1))
self.checkScript(fn, (gpu0, cpu))
self.checkScript(fn, (cpu, cpu))