# Owner(s): ["module: unknown"]

import unittest

import torch
import torch.testing._internal.common_utils as common
from torch.testing._internal.common_cuda import (
    TEST_CUDA,
    TEST_MULTIGPU,
    TEST_NUMBA_CUDA,
)
from torch.testing._internal.common_utils import TEST_NUMPY

if TEST_NUMPY:
    import numpy

if TEST_NUMBA_CUDA:
    import numba.cuda

class TestNumbaIntegration(common.TestCase):
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_cuda_array_interface(self):
        """torch.Tensor exposes __cuda_array_interface__ for cuda tensors.

        An object t is considered a cuda-tensor if:
            hasattr(t, '__cuda_array_interface__')

        A cuda-tensor provides a tensor description dict:
            shape: (integer, ...) Tensor shape.
            strides: (integer, ...) Tensor strides, in bytes.
            typestr: (str) A numpy-style typestr.
            data: (int, boolean) A (data_ptr, read-only) tuple.
            version: (int) Interface version; torch currently exports 2.

        See:
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        """
        types = [
            torch.DoubleTensor,
            torch.FloatTensor,
            torch.HalfTensor,
            torch.LongTensor,
            torch.IntTensor,
            torch.ShortTensor,
            torch.CharTensor,
            torch.ByteTensor,
        ]
        dtypes = [
            numpy.float64,
            numpy.float32,
            numpy.float16,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for tp, npt in zip(types, dtypes):
            # CPU tensors do not implement the interface.
            cput = tp(10)
            self.assertFalse(hasattr(cput, "__cuda_array_interface__"))
            self.assertRaises(AttributeError, lambda: cput.__cuda_array_interface__)

            # Sparse CPU/CUDA tensors do not implement the interface.
            if tp not in (torch.HalfTensor,):
                indices_t = torch.empty(1, cput.size(0), dtype=torch.long).clamp_(min=0)
                sparse_t = torch.sparse_coo_tensor(indices_t, cput)

                self.assertFalse(hasattr(sparse_t, "__cuda_array_interface__"))
                self.assertRaises(
                    AttributeError, lambda: sparse_t.__cuda_array_interface__
                )

                sparse_cuda_t = torch.sparse_coo_tensor(indices_t, cput).cuda()

                self.assertFalse(hasattr(sparse_cuda_t, "__cuda_array_interface__"))
                self.assertRaises(
                    AttributeError, lambda: sparse_cuda_t.__cuda_array_interface__
                )

            # CUDA tensors have the attribute and v2 interface.
            cudat = tp(10).cuda()

            self.assertTrue(hasattr(cudat, "__cuda_array_interface__"))

            ar_dict = cudat.__cuda_array_interface__

            self.assertEqual(
                set(ar_dict.keys()), {"shape", "strides", "typestr", "data", "version"}
            )

            self.assertEqual(ar_dict["shape"], (10,))
            self.assertIs(ar_dict["strides"], None)
            # typestr from numpy, cuda-native little-endian
            self.assertEqual(ar_dict["typestr"], numpy.dtype(npt).newbyteorder("<").str)
            self.assertEqual(ar_dict["data"], (cudat.data_ptr(), False))
            self.assertEqual(ar_dict["version"], 2)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_array_adaptor(self):
        """Torch __cuda_array_interface__ exposes tensor data to numba.cuda."""
        torch_dtypes = [
            torch.complex64,
            torch.complex128,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.uint8,
            torch.int8,
            torch.uint16,
            torch.int16,
            torch.uint32,
            torch.int32,
            torch.uint64,
            torch.int64,
            torch.bool,
        ]
        for dt in torch_dtypes:
            # CPU tensors of all types do not register as cuda arrays;
            # attempts to convert them raise a TypeError.
            cput = torch.arange(10).to(dt)
            npt = cput.numpy()

            self.assertFalse(numba.cuda.is_cuda_array(cput))
            with self.assertRaises(TypeError):
                numba.cuda.as_cuda_array(cput)

            # Any cuda tensor is a cuda array.
            cudat = cput.to(device="cuda")
            self.assertTrue(numba.cuda.is_cuda_array(cudat))

            numba_view = numba.cuda.as_cuda_array(cudat)
            self.assertIsInstance(numba_view, numba.cuda.devicearray.DeviceNDArray)

            # The reported type of the cuda array matches the numpy type of
            # the cpu tensor.
            self.assertEqual(numba_view.dtype, npt.dtype)
            self.assertEqual(numba_view.strides, npt.strides)
            self.assertEqual(numba_view.shape, cudat.shape)

            # Pass back to cuda from host for all equality checks below, needed for
            # float16 comparisons, which aren't supported cpu-side.

            # The data is identical in the view.
            self.assertEqual(cudat, torch.tensor(numba_view.copy_to_host()).to("cuda"))

            # Writes to the torch.Tensor are reflected in the numba array.
            cudat[:5] = 11
            self.assertEqual(cudat, torch.tensor(numba_view.copy_to_host()).to("cuda"))

            # Strided tensors are supported.
            strided_cudat = cudat[::2]
            strided_npt = cput[::2].numpy()
            strided_numba_view = numba.cuda.as_cuda_array(strided_cudat)

            self.assertEqual(strided_numba_view.dtype, strided_npt.dtype)
            self.assertEqual(strided_numba_view.strides, strided_npt.strides)
            self.assertEqual(strided_numba_view.shape, strided_cudat.shape)

            # As of numba 0.40.0 support for strided views is ...limited...
            # Cannot verify correctness of strided view operations.
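            # Worked example of the byte-stride bookkeeping checked above,
            # assuming dt == torch.float32 (4-byte items): cput[::2] skips
            # every other element, so both the numpy array and the numba view
            # report strides == (8,), versus (4,) for the contiguous original.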

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_conversion_errors(self):
        """Numba properly detects the array interface for torch.Tensor variants."""
        # CPU tensors are not cuda arrays.
        cput = torch.arange(100)

        self.assertFalse(numba.cuda.is_cuda_array(cput))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(cput)

        # Sparse tensors are not cuda arrays, regardless of device.
        sparset = torch.sparse_coo_tensor(cput[None, :], cput)

        self.assertFalse(numba.cuda.is_cuda_array(sparset))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(sparset)

        sparse_cuda_t = sparset.cuda()

        self.assertFalse(numba.cuda.is_cuda_array(sparse_cuda_t))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(sparse_cuda_t)

        # Device-status overrides gradient status.
        # CPU+gradient isn't a cuda array.
        cpu_gradt = torch.zeros(100).requires_grad_(True)

        self.assertFalse(numba.cuda.is_cuda_array(cpu_gradt))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(cpu_gradt)

        # CUDA+gradient raises a RuntimeError on check or conversion.
        #
        # Note: detecting the interface via hasattr would change behavior on
        # Python 2, where hasattr swallowed all exceptions, not just
        # AttributeError.
        cuda_gradt = torch.zeros(100).requires_grad_(True).cuda()

        # Both the check and the conversion raise RuntimeError.
        with self.assertRaises(RuntimeError):
            numba.cuda.is_cuda_array(cuda_gradt)
        with self.assertRaises(RuntimeError):
            numba.cuda.as_cuda_array(cuda_gradt)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    @unittest.skipIf(not TEST_MULTIGPU, "No multigpu")
    def test_active_device(self):
        """'as_cuda_array' tensor device must match active numba context."""
        # Both torch/numba default to device 0 and can interop freely.
        cudat = torch.arange(10, device="cuda")
        self.assertEqual(cudat.device.index, 0)
        self.assertIsInstance(
            numba.cuda.as_cuda_array(cudat), numba.cuda.devicearray.DeviceNDArray
        )

        # Tensors on a non-default device raise an API error if converted,
        cudat = torch.arange(10, device=torch.device("cuda", 1))
        with self.assertRaises(numba.cuda.driver.CudaAPIError):
            numba.cuda.as_cuda_array(cudat)

        # but can be converted after switching to the device's context.
        with numba.cuda.devices.gpus[cudat.device.index]:
            self.assertIsInstance(
                numba.cuda.as_cuda_array(cudat), numba.cuda.devicearray.DeviceNDArray
            )

    @unittest.skip(
        "Test is temporarily disabled, see https://github.com/pytorch/pytorch/issues/54418"
    )
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface(self):
        """torch.as_tensor() and torch.tensor() support the __cuda_array_interface__ protocol.

        If an object exposes __cuda_array_interface__, .as_tensor() and .tensor()
        will use the exposed device memory.

        See:
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        """
        dtypes = [
            numpy.complex64,
            numpy.complex128,
            numpy.float64,
            numpy.float32,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for dtype in dtypes:
            numpy_arys = [
                numpy.ones((), dtype=dtype),
                numpy.arange(6).reshape(2, 3).astype(dtype),
                numpy.arange(6)
                .reshape(2, 3)
                .astype(dtype)[1:],  # View offset should be ignored
                numpy.arange(6)
                .reshape(2, 3)
                .astype(dtype)[:, None],  # change the strides but still contiguous
            ]

            # Zero-copy when using `torch.as_tensor()`
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.as_tensor(numba_ary, device="cuda")
                self.assertEqual(
                    numba_ary.__cuda_array_interface__,
                    torch_ary.__cuda_array_interface__,
                )
                self.assertEqual(
                    torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype)
                )

                # Check that `torch_ary` and `numba_ary` point to the same
                # device memory.
                torch_ary += 42
                self.assertEqual(
                    torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype)
                )

            # Implicit-copy because `torch_ary` is a CPU array
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.as_tensor(numba_ary, device="cpu")
                self.assertEqual(
                    torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype)
                )

                # Check that `torch_ary` and `numba_ary` point to different
                # memory.
                torch_ary += 42
                self.assertEqual(
                    torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype) + 42
                )

            # Explicit-copy when using `torch.tensor()`
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.tensor(numba_ary, device="cuda")
                self.assertEqual(
                    torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype)
                )

                # Check that `torch_ary` and `numba_ary` point to different
                # memory.
                torch_ary += 42
                self.assertEqual(
                    torch_ary.cpu().data.numpy(),
                    numpy.asarray(numba_ary, dtype=dtype) + 42,
                )

    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface_inferred_strides(self):
        """torch.as_tensor(numba_ary) should have correct inferred (contiguous) strides."""
        # This could, in theory, be combined with test_from_cuda_array_interface but
        # that test is overly strict: it checks that the exported protocols are exactly
        # the same, which cannot handle differing exported protocol versions.
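        # Worked example of the inference being tested: the interface may
        # export strides=None for C-contiguous data, so torch must compute
        # row-major strides itself.  For shape (2, 3) and float64 (8-byte
        # items) that means byte strides (24, 8), i.e. element strides (3, 1).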
        dtypes = [
            numpy.float64,
            numpy.float32,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for dtype in dtypes:
            numpy_ary = numpy.arange(6).reshape(2, 3).astype(dtype)
            numba_ary = numba.cuda.to_device(numpy_ary)
            self.assertTrue(numba_ary.is_c_contiguous())
            torch_ary = torch.as_tensor(numba_ary, device="cuda")
            self.assertTrue(torch_ary.is_contiguous())

    @unittest.skip(
        "Test is temporarily disabled, see https://github.com/pytorch/pytorch/issues/54418"
    )
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface_lifetime(self):
        """torch.as_tensor(obj) grabs a reference to obj so that the lifetime of obj exceeds that of the tensor."""
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device="cuda")
        self.assertEqual(
            torch_ary.__cuda_array_interface__, numba_ary.__cuda_array_interface__
        )  # No copy

        del numba_ary
        self.assertEqual(
            torch_ary.cpu().data.numpy(), numpy.arange(6)
        )  # `torch_ary` is still alive

    @unittest.skip(
        "Test is temporarily disabled, see https://github.com/pytorch/pytorch/issues/54418"
    )
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    @unittest.skipIf(not TEST_MULTIGPU, "No multigpu")
    def test_from_cuda_array_interface_active_device(self):
        """torch.as_tensor() tensor device must match active numba context."""
        # Zero-copy: both torch/numba default to device 0 and can interop freely.
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device="cuda")
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
        self.assertEqual(
            torch_ary.__cuda_array_interface__, numba_ary.__cuda_array_interface__
        )

        # Implicit-copy: when the Numba and Torch devices differ.
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device=torch.device("cuda", 1))
        self.assertEqual(torch_ary.get_device(), 1)
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))

        if1 = torch_ary.__cuda_array_interface__
        if2 = numba_ary.__cuda_array_interface__
        self.assertNotEqual(if1["data"], if2["data"])
        del if1["data"]
        del if2["data"]
        self.assertEqual(if1, if2)


if __name__ == "__main__":
    common.run_tests()