# blob: ad9316cac762030c08eb84e17446be660b6ee483 [file] [log] [blame] [edit]
# Owner(s): ["module: mps"]
import io
import platform
import sys
import math
import random
import unittest
import warnings
import subprocess
import tempfile
import os
import copy
import gc
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from collections import defaultdict
from torch import inf
from torch.nn import Buffer, Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
(gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial
from torch.testing._internal.common_methods_invocations import (
op_db,
DecorateInfo,
UnaryUfuncInfo,
ReductionOpInfo,
SpectralFuncInfo,
BinaryUfuncInfo,
)
from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
import numpy as np
import torch
import torch.utils._pytree as pytree
from itertools import product
import operator
# Two separate deep copies of `op_db`, one per test suite that consumes them.
test_error_inputs_op_db = copy.deepcopy(op_db)
test_consistency_op_db = copy.deepcopy(op_db)

# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`:
# every op that is not a unary/reduction/spectral/binary ufunc info and that
# carries a reference implementation (`ref`).
_ref_test_ops = tuple(
    candidate
    for candidate in op_db
    if not isinstance(
        candidate,
        (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo),
    )
    and candidate.ref is not None
)
def xfailIf(condition):
    """Build a decorator that marks a test as an expected failure when `condition` is truthy.

    When `condition` is falsy the decorated function is returned untouched.
    """
    def decorate(func):
        return unittest.expectedFailure(func) if condition else func

    return decorate
def xfailIfMacOS14_4Plus(func):
    """Mark `func` as an expected failure when `product_version` is greater than 14.3."""
    # `product_version` is a module-level global assigned further down in this
    # file, so it is looked up when the decorator is applied, not at definition.
    if product_version > 14.3:  # noqa: F821
        return unittest.expectedFailure(func)
    return func
def mps_ops_grad_modifier(ops):
    """Attach MPS gradient-test skip/expected-failure decorators to each op.

    Generator: for every op in `ops` it looks up `op.name + op.variant_test_name`
    (concatenated with no separator, hence keys such as 'divtrunc_rounding') in
    the tables below, appends a matching `DecorateInfo` to `op.decorators`, and
    then yields the op. Ops are mutated in place.

    In every table the value is the list of dtypes the decorator applies to;
    `None` is passed through to `DecorateInfo` unchanged.
    """
    XFAILLIST_GRAD = {
        # precision issues
        'special.polygammaspecial_polygamma_n_0': [torch.float16],
        'polygammapolygamma_n_0': [torch.float16],
        'nn.functional.binary_cross_entropy': [torch.float16],

        # Unimplemented ops
        '__getitem__': [torch.float16],
        '_segment_reduce': [torch.float16, torch.float32],
        '_chunk_cat': [torch.float16, torch.float32],
        'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
        'unfold': [torch.float16, torch.float32],
        'sparse.mmreduce': [torch.float32],  # csr not supported
        'unique_consecutive': [torch.float16, torch.float32],
        'special_modified_bessel_i0': [torch.float16, torch.float32],
        'scalar_tensor': [torch.float16, torch.float32],
        'cdist': [torch.float32],
        'masked.scatter': [torch.float16, torch.float32],
        'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
        'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
        'aminmax': [torch.float32, torch.float16],

        # Correctness issues
        'atanh': [torch.float32],

        # Random output
        'exponential': [torch.float16, torch.float32],

        # CPU errors
        # derivative for aten::nextafter is not implemented on CPU
        'nextafter': None,
        # derivative for aten::floor_divide is not implemented on CPU
        'floor_divide': [torch.float16, torch.float32],
        # derivative for aten::narrow_copy is not implemented on CPU
        'narrow_copy': [torch.float16, torch.float32],
        # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
        'histogramdd': [torch.float16, torch.float32],
        # derivative for aten::histogram is not implemented
        'histogram': [torch.float16, torch.float32],
        # 'bool' object is not iterable
        'allclose': [torch.float16, torch.float32],
        'equal': [torch.float16, torch.float32],
        # 'float' object is not iterable
        'item': [torch.float16, torch.float32],
        # "mse_backward_cpu_out" not implemented for 'Half'
        'nn.functional.mse_loss': [torch.float16],
        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
        'nn.functional.smooth_l1_loss': [torch.float16],
        # cpu error: grad requires non-empty inputs
        'randn': [torch.float16, torch.float32],
        'signal.windows.bartlett': [torch.float32],
        'signal.windows.blackman': [torch.float32],
        'signal.windows.cosine': [torch.float32],
        'signal.windows.exponential': [torch.float32],
        'signal.windows.gaussian': [torch.float32],
        'signal.windows.general_cosine': [torch.float32],
        'signal.windows.general_hamming': [torch.float32],
        'signal.windows.hamming': [torch.float32],
        'signal.windows.hann': [torch.float32],
        'signal.windows.kaiser': [torch.float32],
        'signal.windows.nuttall': [torch.float32],
        'eye': [torch.float16, torch.float32],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operation in backward pass
        '_unsafe_masked_index': [torch.float16],
        '_unsafe_masked_index_put_accumulate': [torch.float16],
    }

    MACOS_12_3_XFAILLIST_GRAD = {
        # Unsupported Border padding mode, forward pass success as fallback to cpu
        'grid_sampler_2d': [torch.float32],
        # Unimplemented
        'logaddexp2': [torch.float32],
    }

    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'masked.softmin': [torch.float32, torch.float16],
        'masked.softmax': [torch.float32, torch.float16],
        'masked.log_softmax': [torch.float32, torch.float16],

        # Unsupported Border padding mode, forward pass success as fallback to cpu
        'grid_sampler_2d': [torch.float32],

        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    SKIPLIST_GRAD = {
        'nn.functional.pairwise_distance': [torch.float16],
        # failed assertion `destination datatype must be fp32'
        'nn.functional.conv1d': [torch.float16],
        'nn.functional.conv2d': [torch.float16],
        'nn.functional.conv3d': [torch.float16],
        'nn.functional.conv_transpose1d': [torch.float16],
        'nn.functional.conv_transpose2d': [torch.float16],
        'nn.functional.conv_transpose3d': [torch.float16],
    }

    MACOS_13_3_XFAILLIST_GRAD = {
        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
        'linalg.matrix_rank': None,

        # Exception: Caused by sample input at index 3 on MPS
        'nn.functional.conv3d': [torch.float32],
    }

    def addDecorator(op, d) -> None:
        # Copy first so the append never mutates a decorators sequence that
        # might be shared with another op.
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        key = op.name + op.variant_test_name
        if key in XFAILLIST_GRAD:
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=XFAILLIST_GRAD[key]))

        if key in SKIPLIST_GRAD:
            # `unittest.skip` is a decorator factory: call it with a reason so the
            # skip is reported with a message instead of relying on its handling
            # of a bare function argument (which leaves the reason empty).
            addDecorator(op, DecorateInfo(
                unittest.skip("Skipped!"),
                dtypes=SKIPLIST_GRAD[key]))

        if key in ON_MPS_XFAILLIST:
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=ON_MPS_XFAILLIST[key]))

        # The OS-version checks stay behind the `key in ...` test so they are
        # only evaluated for ops that actually appear in the table.
        if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=MACOS_12_3_XFAILLIST_GRAD[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key]))

        if key in MACOS_13_3_XFAILLIST_GRAD and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=MACOS_13_3_XFAILLIST_GRAD[key]))
        yield op
def mps_ops_modifier(ops):
    """Attach MPS consistency-test skip/expected-failure decorators to each op.

    Generator: for every op in `ops` it looks up `op.name + op.variant_test_name`
    (concatenated with no separator, hence keys such as 'divtrunc_rounding' or
    'nn.functional.padcircular') in the tables below, appends the matching
    `DecorateInfo` entries to `op.decorators`, and then yields the op. Ops are
    mutated in place.

    In the dict tables the value is the list of dtypes the decorator applies
    to; `None` is passed through to `DecorateInfo` unchanged. Several tables
    are only consulted for particular macOS versions, decided from the
    module-level `product_version` and `torch.backends.mps.is_macos13_or_newer`.
    """
    # Supported complex OPS
    SUPPORTED_COMPLEX_OPS = {
        '__radd__',
        '__rmul__',
        '__getitem__',
        'abs',
        'add',
        'alias_copy',
        'argwhere',
        'atleast_1d',
        'atleast_2d',
        'atleast_3d',
        'as_strided',
        'as_strided_copy',
        'as_strided_scatter',
        'broadcast_tensors',
        'broadcast_to',
        'chalf',
        'cfloat',
        'chunk',
        'clone',
        'conj',
        'conj_physical',
        'contiguous',
        'diag',
        'diag_embed',
        'diagflat',
        'diagonal',
        'diagonal_copy',
        'diagonal_scatter',
        'dsplit',
        'empty',
        'empty_permuted',
        'empty_strided',
        'eye',
        'exp',
        'expand',
        'expand_as',
        'expand_copy',
        'flatten',
        'fill',
        'full',
        'H',
        'hsplit',
        'imag',
        'index_select',
        'isfinite',
        'isinf',
        'isreal',
        'item',
        'kron',
        'linalg.diagonal',
        'linalg.svd',
        'linspace',
        'logspace',
        'linspacetensor_overload',
        'logspacetensor_overload',
        'mH',
        'mT',
        'masked_scatter',
        'masked_select',
        'meshgridlist_of_tensors',
        'meshgridvariadic_tensors',
        'movedim',
        'mul',
        'narrow',
        'narrow_copy',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv_transpose1d',
        'nn.functional.conv_transpose2d',
        'nn.functional.feature_alpha_dropoutwithout_train',
        'nn.functional.padcircular',
        'nn.functional.tanhshrink',
        'nn.functional.unfold',
        'nonzero',
        'ones',
        'outer',
        'permute',
        'positive',
        'randn',
        'ravel',
        'real',
        'repeat_interleave',
        'reshape_as',
        'reshape',
        'resolve_conj',
        'resolve_neg',
        'scalar_tensor',
        'select',
        'sgn',
        'slice',
        'split',
        'split_with_sizes',
        'split_with_sizes_copy',
        'splitlist_args',
        'squeeze',
        'squeezemultiple',
        'sub',
        'svd',
        't',
        't_copy',
        'tanh',
        'tensor_split',
        'transpose',
        'T',
        'unbind',
        'unflatten',
        'unfold',
        'unfold_copy',
        'unsafe_chunk',
        'unsafe_split',
        'unsqueeze',
        'unsqueeze_copy',
        'view_as',
        'view_as_real',
        'view',
        'view_copy',
        'vsplit',
        'zero_',
        'zeros',
    }

    # Complex ops that are only treated as supported when product_version >= 14.0
    # (see the complex-dtype check at the end of the loop below).
    AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
        '__rdiv__',
        '__rmatmul__',
        '_chunk_cat',
        '_unsafe_masked_index',
        'acos',
        'acosh',
        'all',
        'allclose',
        'any',
        'addcdiv',
        'addcmul',
        'addmmdecomposed',
        'addmv',
        'asin',
        'atan',
        'atanh',
        'bfloat16',
        'bmm',
        'bool',
        'cartesian_prod',
        'cat',
        'char',
        'column_stack',
        'combinations',
        'corrcoef',
        'constant_pad_nd',
        'cos',
        'cosh',
        'count_nonzero',
        'diff',
        'div',
        'divno_rounding_mode',
        'dot',
        'dstack',
        'einsum',
        'eq',
        'equal',
        'exp2',
        'expm1',
        'fft.fft',
        'fft.fft2',
        'fft.fftn',
        'fft.fftshift',
        'fft.ifft',
        'fft.ifft2',
        'fft.ifftn',
        'fft.ifftshift',
        'fft.irfftn',
        'fft.irfft2',
        'fft.irfft',
        'fft.hfftn',
        'fft.hfft2',
        'fft.hfft',
        'flip',
        'fliplr',
        'flipud',
        'float',
        'gradient',
        'half',
        'hstack',
        'inner',
        'int',
        'isclose',
        'isnan',
        'ldexp',
        'linalg.multi_dot',
        'linalg.pinv',
        'log10',
        'log1p',
        'log2',
        'log',
        'logical_and',
        'logical_not',
        'logical_or',
        'logical_xor',
        'logsumexp',
        'long',
        'masked_fill',
        'masked.mean',
        'masked.prod',
        'masked.std',
        'masked.sum',
        'masked.var',
        'masked.logsumexp',
        'matmul',
        'mean',
        'mm',
        'mv',
        'ne',
        'neg',
        'nn.functional.padconstant',
        'nn.functional.padreflect',
        'nn.functional.padreplicate',
        'nn.functional.pixel_shuffle',
        'nn.functional.pixel_unshuffle',
        'nn.functional.rms_norm',
        'nn.functional.softsign',
        'pinverse',
        'prod',
        'reciprocal',
        'roll',
        'rot90',
        'rsqrt',
        'short',
        'sigmoid',
        'sin',
        'sinh',
        'sqrt',
        'square',
        'stack',
        'stft',
        'sum',
        'sum_to_size',
        'tan',
        'tensordot',
        'trace',
        'trapz',
        'trapezoid',
        'tril',
        'triu',
        'true_divide',
        'vstack',
        'where',
        'byte',
    }

    # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758
    MACOS_12_3_XFAILLIST = {
        # Top 60
        # expected failures
        # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
        # fixed in macOS 13.3. Currently error is not raised.
        'pow': [torch.int16, torch.int64, torch.uint8, torch.int8],
        # expected failures
        '__rpow__': [torch.uint8, torch.int8],

        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'cdist': [torch.float32],
        'tan': [torch.uint8, torch.float32],

        # Data type support starts from macOS 13
        'nn.functional.avg_pool1d': [torch.int64],
        'nn.functional.avg_pool2d': [torch.int64],
        'nn.functional.local_response_norm': [torch.int64],
        '__radd__': [torch.uint8],
        '__rdiv__': [torch.uint8],
        '__rmul__': [torch.uint8],
        'abs': [torch.uint8],
        'acos': [torch.uint8],
        'acosh': [torch.uint8],
        'add': [torch.uint8],
        'asin': [torch.uint8],
        'asinh': [torch.uint8],
        'atan': [torch.uint8],
        'atanh': [torch.uint8],
        'ceil': [torch.uint8],
        'corrcoef': [torch.uint8],
        'cos': [torch.uint8],
        'cosh': [torch.uint8],
        'cov': [torch.uint8],
        'cumulative_trapezoid': [torch.uint8],
        'deg2rad': [torch.uint8],
        'diff': [torch.uint8],
        'eq': [torch.uint8],
        'equal': [torch.uint8],
        'erf': [torch.uint8],
        'exp2': [torch.uint8],
        'exp': [torch.uint8],
        'expm1': [torch.uint8],
        'floor': [torch.uint8],
        'fmax': [torch.uint8],
        'fmin': [torch.uint8],
        'fmod': [torch.uint8],
        'ge': [torch.uint8],
        'gt': [torch.uint8],
        'isclose': [torch.uint8],
        'isnan': [torch.uint8],
        'kron': [torch.uint8],
        'le': [torch.uint8],
        'log10': [torch.uint8],
        'log1p': [torch.uint8],
        'log2': [torch.uint8],
        'log': [torch.uint8],
        'logical_and': [torch.uint8],
        'logical_or': [torch.uint8],
        'logical_xor': [torch.uint8],
        'logit': [torch.uint8],
        'lt': [torch.uint8],
        'masked.mean': [torch.uint8],
        'masked.std': [torch.uint8],
        'masked.var': [torch.uint8],
        'maximum': [torch.uint8],
        'minimum': [torch.uint8],
        'mul': [torch.uint8],
        'ne': [torch.uint8],
        'neg': [torch.uint8],
        'nn.functional.cosine_embedding_loss': [torch.uint8],
        'nn.functional.margin_ranking_loss': [torch.uint8],
        'nn.functional.poisson_nll_loss': [torch.uint8],
        'nn.functional.softsign': [torch.uint8],
        'nn.functional.tanhshrink': [torch.uint8],
        'nn.functional.triplet_margin_loss': [torch.uint8],
        'nn.functional.triplet_margin_with_distance_loss': [torch.uint8],
        'nn.functional.pairwise_distance': [torch.uint8],
        'outer': [torch.uint8],
        'rad2deg': [torch.uint8],
        'reciprocal': [torch.uint8],
        'remainder': [torch.uint8],
        'round': [torch.uint8],
        'rsqrt': [torch.uint8],
        'sigmoid': [torch.uint8],
        'sign': [torch.uint8],
        'signbit': [torch.uint8],
        'sin': [torch.uint8],
        'sinh': [torch.uint8],
        'special.ndtr': [torch.uint8],
        'sqrt': [torch.uint8],
        'sub': [torch.uint8],
        'trapezoid': [torch.uint8],
        'trapz': [torch.uint8],
        'true_divide': [torch.uint8],
        'trunc': [torch.uint8],
        'xlogy': [torch.uint8],
        'minbinary': [torch.uint8],
        'maxbinary': [torch.uint8],
        'divtrunc_rounding': [torch.uint8],
        'divfloor_rounding': [torch.uint8],
        'divno_rounding_mode': [torch.uint8],
        'floor_divide': [torch.uint8],
        'ldexp': [torch.uint8],
        # square internally calls into power, and will type cast to int64, which supports starting from macOS 13
        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_BEFORE_13_3_XFAILLIST = {
        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'tan': [torch.float32],
        'cdist': [torch.float32],

        # CPU Error: cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # The tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
        # Elements from index 30 and 5133 are both equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
        # Unsupported dtypes
        'cumsum': [torch.int64],
        'cumprod': [torch.int64],
        'cumulative_trapezoid': [torch.int64],
        'masked.cumsum': [torch.int64],
        'masked.cumprod': [torch.int64],
        'linalg.vander': [torch.int64],
    }

    MACOS_AFTER_13_1_XFAILLIST = {
        # before macOS 13.2 it falls back to cpu and pass the forward pass
        'grid_sampler_2d': [torch.float32],  # Unsupported Border padding mode
        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_13_3_XFAILLIST = {
        # Failure due to precision issue for fp16
        # on both cpu and mps there are test cases that might produce inf result
        # 'nn.functional.pairwise_distance': [torch.float16],

        # The tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
        # Elements from index 30 and 5133 are both equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
    }

    MACOS_BEFORE_14_4_XFAILLIST = {
        # These ops work fine in 14.4 but fail in 14.2 or 13.x
        'fft.hfft2': [torch.complex64],
    }

    # Those ops are not expected to work
    UNIMPLEMENTED_XFAILLIST = {
        # Failures due to lack of op implementation on MPS backend
        'login': None,
        'linalg.eig': None,
        'linalg.eigvals': None,
        'put': None,
        'nn.functional.conv_transpose3d': None,
        'rounddecimals_neg_3': None,
        'rounddecimals_3': None,
        'rounddecimals_0': None,
        '__rsub__': None,
        'angle': None,
        'cauchy_': None,
        'cauchy': None,
        'cholesky': None,
        'cholesky_inverse': None,
        'cholesky_solve': None,
        'cummax': None,
        'cummin': None,
        'erfc': None,
        'frexp': None,
        'gcd': None,
        'geqrf': None,
        'nn.functional.grid_sample': None,  # Unsupported Border padding mode
        'heaviside': None,
        'i0': None,
        'igamma': None,
        'igammac': None,
        'index_copy': None,
        'index_reduceprod': None,
        'index_reducemean': None,
        'index_reduceamax': None,
        'index_reduceamin': None,
        'isneginf': None,
        'isposinf': None,
        'kthvalue': None,
        'lcm': None,
        'linalg.cholesky': None,
        'linalg.cholesky_ex': None,
        'linalg.cond': None,
        'linalg.detsingular': None,
        'linalg.det': None,
        'linalg.eigh': None,
        'linalg.eigvalsh': None,
        'linalg.householder_product': None,
        'linalg.ldl_factor': None,
        'linalg.ldl_factor_ex': None,
        'linalg.ldl_solve': None,
        'linalg.lstsq': None,
        'linalg.lstsqgrad_oriented': None,
        'linalg.lu': None,
        'linalg.lu_factor_ex': None,
        'linalg.lu_solve': None,
        'linalg.matrix_norm': [torch.float32],
        'linalg.norm': [torch.float32],
        'linalg.normsubgradients_at_zero': [torch.float32],
        'linalg.qr': None,
        'linalg.slogdet': None,
        'linalg.solve': None,
        'linalg.solve_ex': None,
        'linalg.svdvals': None,
        'linalg.tensorsolve': None,
        'linalg.vecdot': None,
        'logcumsumexp': None,
        'logdet': None,
        'lu': None,
        'lu_solve': None,
        'lu_unpack': None,
        'masked.median': None,
        'matrix_exp': None,
        'mode': None,
        'nanmedian': None,
        'native_dropout_backward': None,
        'normnuc': None,
        'nn.functional.fractional_max_pool2d': None,
        'nn.functional.fractional_max_pool3d': None,
        'nn.functional.adaptive_avg_pool3d': None,
        'nn.functional.adaptive_max_pool3d': None,
        'nn.functional.interpolatearea': None,
        'nn.functional.interpolatebicubic': None,
        'nn.functional.interpolatetrilinear': None,
        'nn.functional.max_unpool1dgrad': None,
        'nn.functional.max_unpool2dgrad': None,
        'nn.functional.max_unpool3dgrad': None,
        'nn.functional.avg_pool3d': None,
        'nn.functional.ctc_loss': None,
        'nn.functional.embedding_bag': None,
        'nn.functional.hardshrink': None,
        'nn.functional.max_pool3d': None,
        'nn.functional.max_unpool1d': None,
        'nn.functional.max_unpool2d': None,
        'nn.functional.max_unpool3d': None,
        'nn.functional.multi_margin_loss': None,
        'nn.functional.multilabel_margin_loss': None,
        'nn.functional.pdist': None,
        'nn.functional.rrelu': None,
        'nn.functional.norm': None,
        'ormqr': None,
        'pca_lowrank': None,
        'qr': None,
        'rsub': None,
        'scatter_reduceamax': None,
        'scatter_reduceamin': None,
        'scatter_reducemin': None,
        'scatter_reducemean': None,
        'scatter_reduceprod': None,
        'scatter_reducesum': None,
        'segment_reduce': None,
        '_segment.reduce': None,
        'segment.reduce': None,
        'segment_reduce_offsets': None,
        '_segment_reduce_offsets': None,
        '_segment_reduce_lengths': None,
        '_segment_reducelengths': None,
        '_segment_reduceoffsets': None,
        'sinc': None,
        'sparse.mm': None,
        'sparse.mmreduce': None,
        'special.airy_ai': None,
        'special.bessel_j0': None,
        'special.bessel_j1': None,
        'special.bessel_y0': None,
        'special.bessel_y1': None,
        'special.chebyshev_polynomial_t': None,
        'special.chebyshev_polynomial_u': None,
        'special.entr': None,
        'special.erfcx': None,
        'special.hermite_polynomial_h': None,
        'special.hermite_polynomial_he': None,
        'special.i0e': None,
        'special.i1': None,
        'special.i1e': None,
        'special.laguerre_polynomial_l': None,
        'special.log_ndtr': None,
        'special.modified_bessel_i0': None,
        'special.modified_bessel_i1': None,
        'special.modified_bessel_k0': None,
        'special.modified_bessel_k1': None,
        'special.ndtri': None,
        'special.scaled_modified_bessel_k0': None,
        'special.scaled_modified_bessel_k1': None,
        'special.spherical_bessel_j0': None,
        'special.xlog1py': None,
        'special.zeta': None,
        'svd_lowrank': None,
        'symeig': None,
        'take': None,
        'to': None,
        'to_sparse': None,
        'unique': None,
        'vdot': None,
        'segment_reduce_': None,
        '_upsample_bilinear2d_aa': None,
        'geometric' : None,
        'geometric_': None,
        'log_normal_': None,
        'log_normal': None,
        'cdouble': None,
        'double': None,
        'nn.functional.softminwith_dtype': None,
        'log_softmaxwith_dtype': None,
        'softmaxwith_dtype': None,
        'float_power': None,
        'full_like': None,
        'linalg.matrix_rankhermitian': None,
        'linalg.pinvhermitian': None,
        'nonzero_static': None,

        # MPS: input sizes must be divisible by output sizes
        'nn.functional.adaptive_avg_pool1d': None,
        'nn.functional.adaptive_avg_pool2d': None,

        # Unsupported dtypes
        # bmm is not supported for integral types
        'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'ones_like': None,
        'zeros_like': None,

        # Convolution for integral types is not supported on MPS
        'nn.functional.conv1d': [torch.int64],
        'nn.functional.conv2d': [torch.int64],
        'nn.functional.conv3d': [torch.int64],
        'nn.functional.conv_transpose1d': [torch.int64],
        'nn.functional.conv_transpose2d': [torch.int64],

        # Unsupported dtypes
        'dot': [torch.int64],
        'histc': [torch.float16],
        'index_add': [torch.int64],
        'log1p': [torch.int64],
        'sigmoid': [torch.int64],
        'atan2': [torch.int64],

        # GEMM on MPS is not supported for integral types
        'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        '__rmatmul__': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'inner': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mat': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'unravel_index': [torch.int32, torch.int64],

        # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as
        # the MPS framework doesn't support float64
        'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # returned output on CPU is float64
        'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operations not supported
        '_unsafe_masked_index_put_accumulate': [torch.bool, torch.int8, torch.uint8, torch.float16, torch.int16, torch.int64],
    }

    # NOTE: `product_version` is a module-level global; it is read when the
    # generator first runs, not when this function is defined.
    if product_version < 14.0:
        # FFT and BFloat16 support was added in MacOS 14
        UNIMPLEMENTED_XFAILLIST.update({
            'bfloat16': None,
            'fft.fft': None,
            'fft.fft2': None,
            'fft.fftn': None,
            'fft.hfft': None,
            'fft.hfft2': None,
            'fft.hfftn': None,
            'fft.ifft': None,
            'fft.ifft2': None,
            'fft.ifftn': None,
            'fft.ihfft': None,
            'fft.ihfft2': None,
            'fft.ihfftn': None,
            'fft.irfft': None,
            'fft.irfft2': None,
            'fft.irfftn': None,
            'fft.rfft': None,
            'fft.rfft2': None,
            'fft.rfftn': None,
            'stft': None,
            # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers,
            # not reproducible in later OS. Added assert to op if used in < 14.0
            'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8],
            'nn.functional.max_pool2d': [torch.uint8],
        })

    if product_version < 15.0:
        UNIMPLEMENTED_XFAILLIST.update({
            'quantile': None,
            'nanquantile': None,
        })

    UNDEFINED_XFAILLIST = {
        # Top 60 operators
        # topk fails with duplicate indices
        'topk': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # Failures due to random output that they generate using
        # Philox engine causing mismatch with CPU results
        'multinomial': [torch.float16, torch.float32],  # random results
        'uniform': [torch.float16, torch.float32],
        'rand_like': [torch.float16, torch.float32],
        'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'randn_like': [torch.float16, torch.float32],
        'bernoulli': [torch.float16, torch.float32],
        'exponential': [torch.float16, torch.float32],
        'nn.functional.feature_alpha_dropoutwith_train': [torch.float16, torch.float32],
        # NOTE(review): float16/float32 are each listed twice here; the repeats look redundant.
        'normal': [torch.float16, torch.float32, torch.float16, torch.float32],
        'normalin_place': [torch.float16, torch.float32],
        'normalnumber_mean': [torch.float16, torch.float32],
        'nn.functional.alpha_dropout': [torch.float16, torch.float32],
        'nn.functional.dropout': [torch.float16, torch.float32],
        'nn.functional.dropout2d': [torch.float16, torch.float32],
        'nn.functional.dropout3d': [torch.float16, torch.float32],
        # See https://github.com/pytorch/pytorch/issues/111479
        'nn.functional.multi_head_attention_forward': [torch.float32, torch.float16],

        # duplicate indices are used in the testcase - undefined behaviour
        'index_put': None,
        # zero to negative integer powers are undefined
        '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64],
        'resize_': [torch.float16, torch.float32],
        'resize_as_': [torch.float16, torch.float32],

        # CPU Errors:
        'addr': [torch.bool, torch.int16, torch.int32,
                 torch.int64, torch.uint8, torch.int8],  # "addmv_impl_cpu" not implemented for 'Half'
        'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                    torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values
        'as_strided_partial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                     torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values

        # random results
        # mps vs cpu:
        # Mismatched elements: 40 / 96 (41.7%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        # cuda(2.0.0.dev20230301+cu117) vs cpu:
        # Mismatched elements: 56 / 96 (58.3%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        'nn.functional.scaled_dot_product_attention': [torch.float32, torch.float16],

        # float output for float16 input on MPS
        'logit': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
        'linalg.matrix_rank': None,
    }

    EMPTY_OPS_SKIPLIST = {
        # Fill tensors with uninitialized data, causing mismatch with CPU.
        # They occasionally match, thus skipping them.
        # See https://github.com/pytorch/pytorch/issues/100175
        'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16,
                              torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # CPU: empty is returning all 0's and there is a mismatch with MPS
        # allocation (MacOS 13). See
        # https://pytorch.org/docs/2.0/generated/torch.empty.html
        'empty': [torch.bool, torch.float16, torch.float32, torch.int16,
                  torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_permuted': [torch.bool, torch.float16, torch.float32, torch.int16,
                           torch.int32, torch.int64, torch.uint8, torch.int8],
    }

    SKIPLIST = {
        # Unsupported
        # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
        'nn.functional.avg_pool2d': [torch.float16],

        # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
        'nn.functional.conv3d': None,
    }

    def addDecorator(op, d) -> None:
        # Rebuild `op.decorators` as a fresh list (empty when it was None)
        # before appending, so the append always targets a list.
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        # Lookup key is name and variant concatenated with no separator.
        key = op.name + op.variant_test_name
        if key in EMPTY_OPS_SKIPLIST:
            addDecorator(op, DecorateInfo(
                unittest.skip("Skipping empty ops."),
                dtypes=EMPTY_OPS_SKIPLIST[key]))
        if key in SKIPLIST:
            addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
        # Unconditional expected failures (independent of the macOS version,
        # apart from the version-dependent entries merged into
        # UNIMPLEMENTED_XFAILLIST above).
        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
            if key in xfaillist:
                addDecorator(op, DecorateInfo(
                    unittest.expectedFailure,
                    dtypes=xfaillist[key]))

        # Version-gated expected failures.
        if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4):
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=MACOS_BEFORE_14_4_XFAILLIST[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=MACOS_BEFORE_13_3_XFAILLIST[key]))

        if key in MACOS_AFTER_13_1_XFAILLIST and torch.backends.mps.is_macos13_or_newer(2):
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=MACOS_AFTER_13_1_XFAILLIST[key]))

        if key in MACOS_13_3_XFAILLIST and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=MACOS_13_3_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                unittest.expectedFailure,
                dtypes=MACOS_12_3_XFAILLIST[key]))

        # If ops is not supported for complex types, expect it to fail
        if key not in SUPPORTED_COMPLEX_OPS and (key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS or product_version < 14.0):
            addDecorator(op, DecorateInfo(unittest.expectedFailure, dtypes=[torch.complex32, torch.complex64]))

        yield op
def mps_ops_error_inputs_modifier(ops):
    """Yield ops that define error inputs, flagging known MPS failures.

    Ops whose ``error_inputs_func`` is ``None`` are dropped. Every remaining op
    is yielded; when ``op.name + op.variant_test_name`` is in the
    expected-failure set below, a ``DecorateInfo(unittest.expectedFailure)`` is
    appended to ``op.decorators`` before the op is yielded.
    """
    # Error input samples do not take a dtype argument.
    expected_failures = {
        # Exceptions are not raised
        '__rmod__',
        '__rsub__',
        '__rpow__',
        'bernoulli',
        'clamp_max',
        'clamp_min',
        'masked_scatter',

        # unsupported float64 dtype
        'cat',
        'complex',
        'multinomial',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv3d',
        'gather',
        'scatter',
        'scatter_add',

        # unsupported complex dtypes
        'masked_fill',

        # MPS does not support tensor dimensions > 16
        'amax',
        'amin',
        'aminmax',

        # memory overlapping checks
        'index_select',

        # unimplemented
        'logcumsumexp',
    }

    def mark_expected_failure(target) -> None:
        # Work on a fresh list so a tuple (or None) of decorators is handled too.
        marker = DecorateInfo(unittest.expectedFailure)
        current = [] if target.decorators is None else list(target.decorators)
        current.append(marker)
        target.decorators = current

    for op in ops:
        if op.error_inputs_func is not None:
            if op.name + op.variant_test_name in expected_failures:
                mark_expected_failure(op)
            yield op
# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811
    NNTestCase = NoTest  # noqa: F811

# macOS "major.minor" as a float. `platform.mac_ver()` yields an empty version
# string when it cannot determine one, in which case this becomes -1.0.
product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)

# Value printed by `sysctl -n hw.memsize`, parsed as an int.
# The query can fail on hosts where MPS is unavailable (missing binary or
# unknown key). Every test is replaced by NoTest there, so fall back to 0
# instead of aborting the import; on an MPS-capable host the error is re-raised
# unchanged.
try:
    total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))
except (OSError, subprocess.CalledProcessError, ValueError):
    if torch.backends.mps.is_available():
        raise
    total_memory = 0

# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'
def skipMPSMemoryLeakCheckIf(condition):
    """Return a decorator that turns off the MPS leak check when `condition` is truthy.

    The decorator stores ``not condition`` in ``fn._do_mps_memory_leak_check``.
    A function whose flag is already falsy is returned untouched, so an earlier
    opt-out is never re-enabled.
    """
    def dec(fn):
        if not getattr(fn, '_do_mps_memory_leak_check', True):
            # Already opted out; leave the existing flag alone.
            return fn
        fn._do_mps_memory_leak_check = not condition
        return fn
    return dec
class MpsMemoryLeakCheck:
    """Context manager that looks for MPS memory growth across the wrapped block.

    On entry it records `torch.mps.current_allocated_memory()` and
    `torch.mps.driver_allocated_memory()`. On a clean exit it compares them with
    fresh readings: growth seen by both raises ``RuntimeError``; growth seen only
    by the caching allocator emits a warning.
    """

    def __init__(self, testcase, name=None):
        self.testcase = testcase
        # Default the label to the test's id when no explicit name is given.
        if name is None:
            name = testcase.id()
        self.name = name

    def __enter__(self):
        # Performs a gc if required (required if any memory is held)
        if torch.mps.current_allocated_memory() > 0:
            gc.collect()
            torch.mps.empty_cache()

        # Acquires caching allocator and driver statistics before the test is run
        self.caching_allocator_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exec_type is not None:
            return

        # An increase in caching-allocator memory is a discrepancy indicating a
        # possible memory leak; short-circuit when there is none.
        if torch.mps.current_allocated_memory() <= self.caching_allocator_before:
            return

        # Validates the discrepancy persists after garbage collection and
        # is confirmed by the driver API
        gc.collect()
        torch.mps.empty_cache()

        # Query memory multiple times to ensure the leak was not transient
        for _ in range(3):
            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
            driver_mem_allocated = torch.mps.driver_allocated_memory()

            caching_allocator_discrepancy = caching_allocator_mem_allocated > self.caching_allocator_before
            driver_discrepancy = driver_mem_allocated > self.driver_before

            if not (caching_allocator_discrepancy or driver_discrepancy):
                # Leak was false positive, exit loop
                break

        if not caching_allocator_discrepancy:
            return

        if driver_discrepancy:
            # A caching allocator discrepancy validated by the driver API is a failure
            msg = (f"MPS driver API confirmed a leak in {self.name}! "
                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
                   f"and is now reported as {caching_allocator_mem_allocated}. "
                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
            raise RuntimeError(msg)

        # Just raises a warning if the leak is not validated by the driver API
        msg = ("MPS caching allocator reports a memory leak not "
               f"verified by the driver API in {self.name}! "
               f"Caching allocator allocated memory was {self.caching_allocator_before} "
               f"and is now reported as {caching_allocator_mem_allocated}. "
               f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
        warnings.warn(msg)
class TestAutocastMPS(TestCase):
    """Checks `torch.autocast` with ``device_type="mps"`` on chained matrix multiplies."""

    def test_matmul_autocast(self):
        """Two chained `torch.mm` calls under autocast produce float16 output.

        The same chain is computed again outside the autocast context and
        compared with the autocast result after casting to float16.
        """
        autocast_tensor_A = torch.rand((8, 8), device="mps")
        autocast_tensor_B = torch.rand((8, 8), device="mps")
        tensor_A = autocast_tensor_A.clone().detach()
        tensor_B = autocast_tensor_B.clone().detach()

        with torch.autocast(device_type="mps"):
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B)
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor)

        # Reference chain, evaluated outside the autocast context.
        output_tensor = torch.mm(tensor_A, tensor_B)
        output_tensor = torch.mm(tensor_A, output_tensor)

        self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16")
        # Adjacent literals are used instead of a backslash continuation so
        # source indentation can never leak into the message.
        mismatch_msg = (
            "Autocast & non-autocast tensors did not match, "
            f"got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}"
        )
        self.assertEqual(autocast_output_tensor,
                         output_tensor.to(torch.float16),
                         mismatch_msg)
# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
    """TestCase that can wrap test methods in an `MpsMemoryLeakCheck`."""

    _do_mps_memory_leak_check = True

    def __init__(self, method_name='runTest'):
        super().__init__(method_name)
        if getattr(self, method_name, None) is None:
            return
        # Wraps the tested method if we should do MPS memory check.
        if TEST_MPS_MEM_LEAK_CHECK and self._do_mps_memory_leak_check:
            self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)

    def assertLeaksNoMpsTensors(self, name=None):
        """Return a leak-check context manager labelled `name` (default: this test's id)."""
        if name is None:
            name = self.id()
        return MpsMemoryLeakCheck(self, name)

    def wrap_with_mps_policy(self, method_name, policy):
        """Replace the attribute `method_name` on this instance with a policy-wrapped version."""
        wrapped = super().wrap_method_with_policy(getattr(self, method_name), policy)
        setattr(self, method_name, wrapped)

    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
    def wrap_with_mps_memory_check(self, method):
        """Wrap `method` so it always runs under `assertLeaksNoMpsTensors`."""
        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)
class TestMemoryLeak(TestCaseMPS):
    """Tests for the leak detector itself and for a CPU<->MPS copy/cast round trip."""

    def test_mps_memory_leak_detection(self):
        """The leak check passes a no-op and raises for a function that retains an MPS tensor."""
        retained = []

        def no_leak():
            pass

        # Trigger an intentional memory leak
        def leak_gpu0():
            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
            retained.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))

        checked_no_leak = self.wrap_with_mps_memory_check(no_leak)
        checked_leak = self.wrap_with_mps_memory_check(leak_gpu0)

        checked_no_leak()

        # check if a runtime error for memory leak was emitted which would
        # confirm whether memory leak detection worked successfully or not.
        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
            checked_leak()

    def test_copy_cast_no_leak(self):
        """Driver-allocated memory is unchanged by a second cpu/float32 -> mps/float16 round trip."""

        def step(x):
            on_cpu = x.to(device='cpu', dtype=torch.float32)
            back_on_mps = on_cpu.to(device='mps', dtype=torch.float16)  # noqa: F841

        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
        # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
        step(a)
        torch.mps.empty_cache()
        driver_before = torch.mps.driver_allocated_memory()

        step(a)
        torch.mps.empty_cache()
        driver_after = torch.mps.driver_allocated_memory()

        self.assertEqual(driver_before, driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
class TestPixelShuffle(TestCaseMPS):
    """Exercises `nn.PixelShuffle` and `nn.PixelUnshuffle` on MPS tensors."""

    def test_pixel_shuffle_unshuffle(self):
        """Run shuffle/unshuffle success and error cases for inputs with 1 to 5 dimensions."""

        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
                                                 upscale_factor=None, is_contiguous=True):

            def generate_input():
                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
                extra_channel = 0 if valid_channels_dim else 1
                channels = random.randint(1, 4) * upscale_factor ** 2 + extra_channel
                height = random.randint(5, 10)
                width = random.randint(5, 10)

                if num_input_dims == 1:
                    sample = torch.rand(channels, requires_grad=True, device='mps')
                    assert is_contiguous
                elif num_input_dims == 2:
                    sample = torch.rand(width, height, requires_grad=True, device='mps').T
                    if is_contiguous:
                        sample = sample.contiguous()
                else:
                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                    sample = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
                    sample = sample.transpose(-1, -2)
                    if is_contiguous:
                        sample = sample.contiguous()

                if not is_contiguous and len(sample.reshape(-1)) > 0:
                    assert not sample.is_contiguous()

                # Start from a fresh tensor that requires grad.
                sample = sample.detach().clone()
                sample.requires_grad = True
                return sample

            # Function to imperatively ensure pixels are shuffled to the correct locations.
            # Used to validate the batch operations in pixel_shuffle.
            def _verify_pixel_shuffle(source, shuffled, factor):
                out_c, out_h, out_w = shuffled.size(-3), shuffled.size(-2), shuffled.size(-1)
                for c, h, w in itertools.product(range(out_c), range(out_h), range(out_w)):
                    height_idx = h // factor
                    width_idx = w // factor
                    channel_idx = (factor * (h % factor)) + (w % factor) + (c * factor ** 2)
                    self.assertEqual(shuffled[..., c, h, w], source[..., channel_idx, height_idx, width_idx])

            if upscale_factor is None:
                upscale_factor = random.randint(2, 5)
            sample = generate_input()

            ps = nn.PixelShuffle(upscale_factor)
            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)

            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
                output = ps(sample)
                _verify_pixel_shuffle(sample, output, upscale_factor)
                output.backward(output.data)
                self.assertEqual(sample.data, sample.grad.data)

                # Ensure unshuffle properly inverts shuffle.
                unshuffle_output = pus(output)
                self.assertEqual(sample, unshuffle_output)
            else:
                with self.assertRaises(RuntimeError):
                    ps(sample)

        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
                                                    downscale_factor=None):
            if downscale_factor is None:
                downscale_factor = random.randint(2, 5)
            channels = random.randint(1, 4)
            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)

            if num_input_dims == 1:
                sample = torch.rand(channels, requires_grad=True, device='mps')
            elif num_input_dims == 2:
                sample = torch.rand(height, width, requires_grad=True, device='mps')
            else:
                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                sample = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')

            pus = nn.PixelUnshuffle(downscale_factor)
            with self.assertRaises(RuntimeError):
                pus(sample)

        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
            # For 1D - 2D, this is an error case.
            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
            contiguity_options = [True, False] if num_input_dims > 1 else [True]
            for is_contiguous in contiguity_options:
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
                )

            # Error cases for pixel_unshuffle.
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)

        # 1D through 5D inputs, in increasing order.
        for num_input_dims in range(1, 6):
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=num_input_dims)
class MPSReluTest(TestCaseMPS):
    """Compares `torch.nn.ReLU` (out-of-place and in-place) against a NumPy reference."""

    def _npRelu(self, np_features):
        """NumPy reference: element-wise max with zero, cast back to the input dtype."""
        zeros = np.zeros(np_features.shape)
        return np.maximum(np_features, zeros).astype(np_features.dtype)

    def testNpRelu(self):
        """Sanity-check the NumPy reference on a fixed 2x5 example."""
        features = np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                           0.9]])
        expected = np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]])
        torch.testing.assert_close(expected, self._npRelu(features))

    def _testRelu(self, np_features, device):
        """Check out-of-place ReLU on `device` against the NumPy reference."""
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        relu = torch.nn.ReLU(inplace=False)
        py_relu_cpu = relu(py_tensor).to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)

    def _testReluInPlace(self, np_features, device):
        """Check in-place ReLU on `device`: both the result and the mutated input must match."""
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        relu = torch.nn.ReLU(inplace=True)
        py_relu_cpu = relu(py_tensor).to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)
        # Inplace Relu modifies the initial input and it should match the output of Relu
        self.assertEqual(np_relu, py_tensor.to("cpu"))

    def testNumbersCPU(self):
        """Run both ReLU variants on the CPU with int32 data."""
        for t in [np.int32]:
            # Force execution on CPU even if a GPU kernel is available for the type.
            for check in (self._testRelu, self._testReluInPlace):
                check(
                    np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                    device="cpu")

    def testNumbersGPU(self):
        """Run both ReLU variants on MPS with float16/float32 data, including empty arrays."""
        for t in [np.float16, np.float32]:
            for check in (self._testRelu, self._testReluInPlace):
                check(
                    np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                    device="mps")
            for check in (self._testRelu, self._testReluInPlace):
                check(np.array([]).astype(t), device="mps")
class MatmulTest(TestCaseMPS):
    """Compares `torch.matmul` on MPS with the CPU result for several operand ranks."""

    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
        """Multiply two random MPS tensors (optionally expanded) and compare with the CPU product."""
        tensor1_mps = torch.randn(shape_tensor_1, device="mps")
        if expand_tensor_1_shape:
            tensor1_mps = tensor1_mps.expand(expand_tensor_1_shape)

        tensor2_mps = torch.randn(shape_tensor_2, device="mps")
        if expand_tensor_2_shape:
            tensor2_mps = tensor2_mps.expand(expand_tensor_2_shape)

        tensor1_cpu = tensor1_mps.to("cpu")
        tensor2_cpu = tensor2_mps.to("cpu")

        expected = torch.matmul(tensor1_cpu, tensor2_cpu)
        actual = torch.matmul(tensor1_mps, tensor2_mps)

        self.assertEqual(expected, actual.to("cpu"))

    def test_vector_x_vector(self):
        # uses `dot`
        self._helper(3, 3)

    def test_matrix_x_vector(self):
        # uses `addmv`
        self._helper((3, 4), 4)

    def test_batched_matrix_x_broadcasted_vector(self):
        self._helper((10, 3, 4), 4)

    def test_batched_matrix_x_batched_matrix(self):
        # uses `bmm.out`
        self._helper((10, 3, 4), (10, 4, 5))

    def test_batched_matrix_x_broadcasted_matrix(self):
        self._helper((10, 3, 4), (4, 5))
class MPSLeakyReluTest(TestCaseMPS):
    """Compares `torch.nn.LeakyReLU` forward and backward results between CPU and MPS."""

    def _npLeakyRelu(self, np_features, negative_slope=0.1):
        """NumPy reference: max(x, negative_slope * x), cast back to the input dtype."""
        scaled = negative_slope * np_features
        return np.maximum(np_features, scaled).astype(np_features.dtype)

    def testNpLeakyRelu(self):
        """Sanity-check the NumPy reference on a fixed 2x5 example."""
        features = np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                           0.9]])
        expected = np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
                             [0.1, -0.03, 0.5, -0.07, 0.9]])
        torch.testing.assert_close(
            expected,
            self._npLeakyRelu(features, negative_slope=0.1))

    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
        """Run LeakyReLU forward/backward on matching CPU and MPS inputs and compare."""
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')

        # A transpose needs at least two dims and no zero-sized dim.
        transposable = 0 not in shape and len(shape) >= 2
        if not contiguous and transposable:
            # Tranposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            mps_x = mps_x.transpose(0, 1)
            assert not mps_x.is_contiguous()

        cpu_x.requires_grad_()
        mps_x.requires_grad_()

        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass
        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')

        mps_leaky_relu.backward(gradient=mps_grad)
        cpu_leaky_relu.backward(gradient=cpu_grad)

        assert cpu_x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(cpu_x.grad, mps_x.grad)

    def testNumbersCPU(self):
        """Sweep dtypes, shapes (including empty and scalar) and contiguity."""
        dtypes_to_try = [torch.float, torch.half]
        shapes_to_try = [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]
        for t, shape, contiguous in itertools.product(dtypes_to_try, shapes_to_try, [True, False]):
            self._testLeakyRelu(shape,
                                dtype=t,
                                negative_slope=0.2,
                                contiguous=contiguous)
class TestAvgPool(TestCaseMPS):
    """Average-pooling checks built on sum-pooling reference helpers."""

    def _sum_pool2d(self, x, kernel_size):
        """Sum every non-overlapping `kernel_size` window of `x` using `unfold`."""
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        """Sum 3D windows by splitting along dim 0 and reusing `_sum_pool2d` on each slab."""
        # Because unfold does not support 3D sliding window we will split tensor to multiple tensors and calculate sum
        h = kernel_size[0]
        splited_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # sum_pool2d assumes tensor in (1, 1, n, m) view, so unsqueeze two times
        splited_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in splited_x]
        joined_x = torch.cat(splited_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        """Reference 2D average pool: window sums divided by the window element count."""
        # math.prod replaces a call to `reduce`, which was never imported here.
        size = math.prod(kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        """Reference 3D average pool: window sums divided by the window element count."""
        # math.prod replaces a call to `reduce`, which was never imported here.
        size = math.prod(kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_avg_pool2d_with_zero_divisor(self):
        """`divisor_override=0` must be rejected with a RuntimeError."""
        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))

    def test_doubletensor_avg_pool2d_with_divisor(self):
        """`divisor_override` results match window sums divided by the same divisor."""
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        """ceil_mode pooling with padding must not produce NaN on CPU or MPS."""
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
        y = torch.nn.functional.avg_pool2d(
            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
class TestMPS(TestCaseMPS):
def test_exp(self, device="mps", dtype=torch.float):
for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
a = torch.tensor(v, dtype=dtype, device="mps") * b
self.compare_with_numpy(torch.exp, np.exp, a)
def test_conv_raises_error(self, device='mps', dtype=torch.float):
    """A Conv1d with 65537 output channels must raise NotImplementedError on MPS."""
    wide_conv = nn.Conv1d(1, 65537, 3, padding=1).to('mps')
    sample = torch.ones([1, 1, 3])
    with self.assertRaises(NotImplementedError):
        wide_conv(sample.to("mps"))
def test_triu_inf(self, device="mps", dtype=torch.float):
    """`torch.triu` of an all -inf tensor matches between CPU and MPS for diagonals -1, 0, 1."""
    for diag in (-1, 0, 1):
        cpu_mask = torch.full((3, 6, 6), float("-inf"))
        mps_mask = cpu_mask.clone().detach().to('mps')
        expected = torch.triu(cpu_mask, diagonal=diag)
        actual = torch.triu(mps_mask, diagonal=diag)
        self.assertEqual(expected, actual)
def test_exp1(self, device="mps", dtype=torch.float):
input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
output = torch.exp(input)
output_cpu = torch.exp(input.cpu())
# If exponentWithTensor: MPS call is used on M1 running 14.5 test will fail with
# Mismatched elements: 3 / 4 (75.0%)
# Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
# Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8)
def test_exp_strided_output(self):
    """`exp` of a permuted (256, 10) tensor matches between MPS and CPU."""
    base_mps = torch.rand((256, 10), device='mps')
    base_cpu = base_mps.to("cpu")

    permuted_mps = base_mps.permute(1, 0)
    permuted_cpu = base_cpu.permute(1, 0)

    res = permuted_mps.exp()
    res_cpu = permuted_cpu.exp()
    self.assertEqual(res, res_cpu)
def _testLeakyRelu(self, np_features, negative_slope, device):
    """Compare LeakyReLU forward output and input gradients between CPU and MPS.

    `device` is accepted but not read; the MPS tensor is always created on 'mps'.
    """
    cpu_x = torch.from_numpy(np_features).requires_grad_()
    mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
    activation = torch.nn.LeakyReLU(negative_slope)

    cpu_out = activation(cpu_x)
    mps_out = activation(mps_x)
    torch.testing.assert_close(cpu_out, mps_out.to('cpu'))

    # test backward pass
    cpu_grad = torch.ones_like(cpu_out)
    mps_grad = cpu_grad.to('mps')
    cpu_out.backward(gradient=cpu_grad)
    mps_out.backward(gradient=mps_grad)
    torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
def testNumbersGPU(self):
    """Run the LeakyReLU CPU/MPS comparison on a fixed 2x5 float32 array."""
    raw = np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]])
    for t in [np.float32]:
        self._testLeakyRelu(raw.astype(t), negative_slope=0.1, device="mps")
def test_fill(self):
    """`fill_` on a zero tensor gives the same result on MPS and CPU."""
    def helper(val, shape, dtype):
        filled_mps = torch.zeros(shape, device='mps', dtype=dtype).fill_(val)
        filled_cpu = torch.zeros(shape, device='cpu', dtype=dtype).fill_(val)
        self.assertEqual(filled_mps, filled_cpu)

    cases = (
        (0, [1024], torch.float32),
        (0.2, [2, 3], torch.float32),
        (0.2 + 0.5j, [2, 3], torch.complex64),
    )
    for val, shape, dtype in cases:
        helper(val, shape, dtype)
def test_fill_storage_offset(self):
    """`fill_` through indexed views updates the same elements on MPS as on CPU."""
    # Case 1: fill the second row of a [2, 10] tensor with a Python scalar.
    shape = [2, 10]
    val = 0.2
    base_mps = torch.ones(shape, device="mps")
    row_mps = base_mps[:][1].fill_(val)
    base_cpu = torch.ones(shape, device="cpu")
    row_cpu = base_cpu[:][1].fill_(val)

    self.assertEqual(row_mps, row_cpu)
    self.assertEqual(base_mps, base_cpu)

    # Case 2: fill single columns of a [1, 10] tensor with a 0-dim tensor value.
    shape = [1, 10]
    val = 0.0
    base_mps = torch.ones(shape, device="mps")
    val_tensor_mps = torch.tensor(val, device="mps")
    col_mps = base_mps[:, 9].fill_(val_tensor_mps)
    # Regression test for https://github.com/pytorch/pytorch/issues/114692
    base_mps[:, 5].fill_(val_tensor_mps)
    base_cpu = torch.ones(shape, device="cpu")
    val_tensor_cpu = torch.tensor(val, device="cpu")
    col_cpu = base_cpu[:, 9].fill_(val_tensor_cpu)
    base_cpu[:, 5].fill_(val_tensor_cpu)

    self.assertEqual(col_mps.to(device="cpu"), col_cpu)
    self.assertEqual(base_mps.to(device="cpu"), base_cpu)
def test_cdist_large(self, device="mps"):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(100, 10, device=device)
y = torch.randn(100, 10, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual)
def test_cdist_large_batch(self, device="mps"):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(4, 3, 100, 10, device=device)
y = torch.randn(4, 3, 100, 10, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual)
def test_cdist_non_contiguous(self, device="mps"):
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(5, 7, device=device).mT
y = torch.randn(5, 3, device=device).mT
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertEqual(expected, actual)
x = torch.randn(7, 5, device=device)
y = torch.randn(5, 3, device=device).t()
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertEqual(expected, actual)
x = torch.randn(5, 7, device=device).t()
y = torch.randn(3, 5, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertTrue(y.is_contiguous())
self.assertEqual(expected, actual)
def test_cdist_non_contiguous_batch(self, device="mps"):
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(4, 3, 2, 5, 7, device=device).mT
y = torch.randn(4, 3, 2, 5, 3, device=device).mT
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertEqual(expected, actual)
x = torch.randn(7, 2, 7, 5, device=device)
y = torch.randn(7, 2, 5, 3, device=device).mT
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertEqual(expected, actual)
x = torch.randn(4, 5, 7, device=device).mT
y = torch.randn(4, 3, 5, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertTrue(y.is_contiguous())
self.assertEqual(expected, actual)
def test_cdist_euclidean_large(self, device="mps"):
def _test_euclidean_large_cdist(sizex, sizey=None):
if sizey is None:
sizey = sizex
x = torch.randn(sizex, device=device, dtype=torch.float)
y = torch.randn(sizey, device=device, dtype=torch.float)
eps = 1e-6
# to avoid extremum
x = x - (((x - y) < eps).float() * 2 * eps)
x.requires_grad = True
y.requires_grad = True
dist = torch.cdist(x, y, p=2)
# Do a backward pass to check that it is valid for large
# matrices
loss = dist.sum()
loss.backward()
_test_euclidean_large_cdist((2000, 5))
def test_cdist_same_inputs(self, device="mps"):
# Test to detect issues in cdist gradient calculation
# When the distances are 0
sizex = (1, 27, 32)
for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
x = torch.randn(sizex, device=device, dtype=torch.float)
dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
y = x.clone()
eps = 1e-6
x.requires_grad = True
d = torch.cdist(x, y)
d.backward(dist_grad)
# Check that the backward passs does not contain invalid
# values such as nan or inf
assert torch.isfinite(x.grad).all()
def _brute_cdist(self, x, y, p=2):
    """Reference pairwise distance: p-norm of every row difference between `x` and `y`.

    Returns an uninitialised (r1, r2) tensor on `x.device` when either input
    has zero rows (r1 = x.shape[-2], r2 = y.shape[-2]).
    """
    rows_x, rows_y = x.shape[-2], y.shape[-2]
    if 0 in (rows_x, rows_y):
        return torch.empty(rows_x, rows_y, device=x.device)
    # Broadcast to (..., r1, r2, m) and reduce over the feature dimension.
    pairwise_diff = x[..., None, :] - y[..., None, :, :]
    return torch.norm(pairwise_diff, p=p, dim=-1)
def test_cdist_norm(self, device="mps"):
for r1 in [3, 4]:
for m in [2, 3]:
for r2 in [4, 6]:
for p in [0, 1, 1.5, 2.5, float('inf')]:
x = torch.randn(r1, m, device=device)
y = torch.randn(r2, m, device=device)
if p == 2:
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual, rtol=0, atol=0.02)
else:
actual = torch.cdist(x, y, p=p)
expected = self._brute_cdist(x, y, p=p)
self.assertEqual(expected, actual)
def test_cdist_norm_batch(self, device="mps"):
for r1 in [3, 4]:
for m in [2, 3]:
for r2 in [4, 6]:
for p in [0, 3, 1.5, 2.5, float('inf')]:
x = torch.randn(2, 3, 6, r1, m, device=device)
y = torch.randn(2, 3, 6, r2, m, device=device)
if p == 2:
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual, rtol=0, atol=0.02)
else:
actual = torch.cdist(x, y, p=p)
expected = self._brute_cdist(x, y, p=p)
self.assertEqual(expected, actual)
def test_mm(self):
    """A (5, 6) by (6, 5) product of ones on MPS is a 5x5 matrix of sixes."""
    lhs = torch.ones(5, 6).to("mps")
    rhs = torch.ones(6, 5).to("mps")
    product_cpu = torch.mm(lhs, rhs).cpu()
    expected = torch.full((5, 5), 6.0)
    torch.testing.assert_close(product_cpu, expected)
def test_linalg_cross(self):
    """`torch.linalg.cross` (plain and `out=` forms) agrees between MPS and CPU.

    Covers int32, int64 and float32, first with equal shapes and then with
    broadcastable shapes.
    """
    device = "mps"

    def make_operand(dtype, shape):
        if dtype is torch.int32 or dtype is torch.int64:
            return torch.randint(0, 99999, shape, dtype=dtype, device=device)
        return torch.rand(*shape, dtype=dtype, device=device)

    def check(dtype, shape_x, shape_y):
        x = make_operand(dtype, shape_x)
        y = make_operand(dtype, shape_y)
        x_cpu = x.to("cpu")
        y_cpu = y.to("cpu")
        res1 = torch.linalg.cross(x, y, dim=1)
        res2 = torch.tensor((), dtype=dtype, device=device)
        res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
        res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
        torch.linalg.cross(x, y, dim=1, out=res2)
        torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res1_cpu)
        self.assertEqual(res2, res2_cpu)

    for dtype in [torch.int32, torch.int64, torch.float32]:
        check(dtype, (100, 3, 100), (100, 3, 100))
        # test for broadcastable inputs
        check(dtype, (1, 3, 2), (4, 3, 1))
def test_cross(self):
    """`torch.cross` along dim 1 of 4x3 inputs matches between MPS and CPU."""
    lhs_mps = torch.randn(4, 3, device="mps")
    rhs_mps = torch.randn(4, 3, device="mps")
    lhs_cpu = lhs_mps.to("cpu")
    rhs_cpu = rhs_mps.to("cpu")
    on_mps = torch.cross(lhs_mps, rhs_mps, dim=1)
    on_cpu = torch.cross(lhs_cpu, rhs_cpu, dim=1)
    self.assertEqual(on_mps, on_cpu)
def test_addmm(self):
    """`addmm` of all-ones operands on MPS gives a 5x5 matrix of sevens."""
    bias = torch.ones(5, 5).to("mps")
    lhs = torch.ones(5, 6).to("mps")
    rhs = torch.ones(6, 5).to("mps")
    result_cpu = torch.addmm(bias, lhs, rhs).to("cpu")
    expected = torch.full((5, 5), 7.0)
    torch.testing.assert_close(result_cpu, expected)
def test_bmm(self):
    """`torch.bmm` on MPS matches the CPU result in value and size."""
    batch1_cpu = torch.randn(10, 3, 4)
    batch2_cpu = torch.randn(10, 4, 5)
    batch1_mps, batch2_mps = (t.detach().clone().to("mps") for t in (batch1_cpu, batch2_cpu))

    expected = torch.bmm(batch1_cpu, batch2_cpu)
    actual = torch.bmm(batch1_mps, batch2_mps)

    self.assertEqual(expected, actual)
    self.assertEqual(expected.size(), actual.size())
@xfailIf(product_version < 15.0)
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_large_bmm(self, dtype):
    """`torch.bmm` with 11 x 20064 x 128 low-precision operands on MPS matches the CPU result."""
    lhs = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
    rhs = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
    expected = torch.bmm(lhs.cpu(), rhs.cpu())
    actual = torch.bmm(lhs, rhs)

    # Using the low precision comparison for FP16
    if dtype == torch.float16:
        tol = 1e-2
    else:
        tol = None
    self.assertEqual(expected, actual, atol=tol, rtol=tol)
    self.assertEqual(expected.size(), actual.size())
def test_addr(self):
    """`addr` of all-ones operands on MPS gives a 5x10 matrix of twos."""
    base = torch.ones(5, 10).to("mps")
    vec1 = torch.ones(5).to("mps")
    vec2 = torch.ones(10).to("mps")
    result_cpu = torch.addr(base, vec1, vec2).to("cpu")
    expected = torch.full((5, 10), 2.0)
    torch.testing.assert_close(result_cpu, expected)
def test_trace(self):
    """`torch.trace` of a random 3x3 matrix matches between CPU and MPS."""
    matrix_cpu = torch.randn(3, 3)
    matrix_mps = matrix_cpu.detach().clone().to("mps")

    expected = torch.trace(matrix_cpu)
    actual = torch.trace(matrix_mps)

    self.assertEqual(expected, actual)
    self.assertEqual(expected.size(), actual.size())
def test_addbmm(self):
    """`torch.addbmm` on MPS matches the CPU result in value and size."""
    M_cpu = torch.randn(3, 5)
    batch1_cpu = torch.randn(10, 3, 4)
    batch2_cpu = torch.randn(10, 4, 5)
    M_mps, batch1_mps, batch2_mps = (
        t.detach().clone().to("mps") for t in (M_cpu, batch1_cpu, batch2_cpu)
    )

    expected = torch.addbmm(M_cpu, batch1_cpu, batch2_cpu)
    actual = torch.addbmm(M_mps, batch1_mps, batch2_mps)

    self.assertEqual(expected, actual)
    self.assertEqual(expected.size(), actual.size())
def test_baddbmm(self):
    """`torch.baddbmm` with alpha=1.2, beta=0.8 matches between CPU and MPS for three shape sets."""
    def helper(input_shape, batch1_shape, batch2_shape):
        M_cpu = torch.randn(input_shape)
        batch1_cpu = torch.randn(batch1_shape)
        batch2_cpu = torch.randn(batch2_shape)
        scaling = {"beta": 0.8, "alpha": 1.2}

        M_mps, batch1_mps, batch2_mps = (
            t.detach().clone().to("mps") for t in (M_cpu, batch1_cpu, batch2_cpu)
        )

        expected = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, **scaling)
        actual = torch.baddbmm(M_mps, batch1_mps, batch2_mps, **scaling)

        self.assertEqual(expected, actual)
        self.assertEqual(expected.size(), actual.size())

    shape_sets = (
        ((3, 5), (10, 3, 4), (10, 4, 5)),
        ((10, 3, 5), (10, 3, 4), (10, 4, 5)),
        ((1, 77, 77), (8, 77, 64), (8, 64, 77)),
    )
    for input_shape, batch1_shape, batch2_shape in shape_sets:
        helper(input_shape=input_shape, batch1_shape=batch1_shape, batch2_shape=batch2_shape)
def test_local_scalar_dense_mps(self):
    """``.item()`` on a one-element MPS tensor must return the same scalar as on CPU."""
    host_tensor = torch.randn(1)
    device_tensor = host_tensor.to("mps")
    host_scalar = host_tensor.item()
    device_scalar = device_tensor.item()
    torch.testing.assert_close(host_scalar, device_scalar)
def test_linear_1d_weight(self):
    """F.linear with a 1-D weight and with a single-row 2-D weight must agree between CPU and MPS."""
    host = 'cpu'
    for weight_shape in ([8], [1, 8]):
        weight = torch.rand(weight_shape).to(host)
        inp = torch.rand([1, 2, 2, 8]).to(host)
        inp_mps = inp.to('mps')
        weight_mps = weight.to('mps')

        expected = F.linear(inp, weight)
        actual = F.linear(inp_mps, weight_mps)
        self.assertEqual(expected, actual)
def test_linear_bias(self):
    """nn.Linear whose bias was replaced by a scalar or a 2-D parameter must agree between CPU and MPS."""
    def compare(bias_shape):
        inp = torch.randn(2, 2, 2, 64, device="cpu")
        layer = torch.nn.Linear(64, 4, device="cpu")
        # Swap the layer's bias for a parameter of the requested shape.
        layer.bias = torch.nn.Parameter(torch.randn(bias_shape, dtype=torch.float32, device="cpu"))
        expected = layer(inp)

        # Move the very same module to MPS and evaluate it again.
        inp_mps = inp.to("mps")
        layer.to("mps")
        actual = layer(inp_mps)

        self.assertEqual(expected, actual)

    for bias_shape in ((), (2, 4)):
        compare(bias_shape)
def test_linear_errors(self):
    """F.linear must raise RuntimeError for integer weights and for mixed CPU/MPS operands."""
    size = (3, 3)
    linear = torch.nn.functional.linear

    # Unsupported dtype: int8 weights
    with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
        inp = torch.rand(size, device='mps')
        int_weight = torch.randint(-10, 10, size, dtype=torch.int8, device='mps')
        linear(inp, int_weight)

    # Weights on the wrong device
    with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
        inp = torch.rand(size, device='mps')
        cpu_weight = torch.rand(size, device='cpu')
        linear(inp, cpu_weight)

    # Input on the wrong device
    with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
        cpu_inp = torch.rand(size, device='cpu')
        weight = torch.rand(size, device='mps')
        linear(cpu_inp, weight)
def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
    """Compare an nn.Linear layer on MPS against an identically-weighted CPU layer.

    Args:
        in_features: forwarded to ``torch.nn.Linear``.
        out_features: forwarded to ``torch.nn.Linear``.
        shape: shape passed to ``torch.randn`` to create the input.
        bias: whether both layers are created with a bias term.
        backward_pass: when true, first-order gradients and then
            second-order ("gradgrad") gradients are compared as well.
    """
    cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
    mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)

    # Use the same weights and bias as the ones from the cpu
    mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")

    if bias:
        mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")

    # The input is created once and copied back to the CPU so both layers see the same values.
    linear_mps_input = torch.randn(shape).to('mps')
    linear_cpu_input = linear_mps_input.detach().clone().to('cpu')

    if backward_pass:
        linear_mps_input = linear_mps_input.requires_grad_()
        linear_cpu_input = linear_cpu_input.requires_grad_()

    linear_cpu_output = cpu_linear(linear_cpu_input)
    linear_mps_output = mps_linear(linear_mps_input)

    self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu'))
    self.assertEqual(linear_cpu_output.size(), linear_mps_output.size())

    if backward_pass:
        # The upstream gradients require grad themselves so that the
        # gradgrad section below can check the gradient flowing back into them.
        cpu_grad = torch.rand_like(linear_cpu_output, requires_grad=True)
        grad = cpu_grad.detach().to('mps').requires_grad_()

        # create_graph=True keeps the backward graph, which the later
        # ``.grad.backward(...)`` calls depend on.
        linear_cpu_output.backward(gradient=cpu_grad, create_graph=True)
        linear_mps_output.backward(gradient=grad, create_graph=True)

        # First-order gradients are compared with explicit tolerances.
        self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size())
        self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

        self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size())
        self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
        if bias:
            self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size())
            self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

        # test gradgrad
        x_grad_out = torch.rand_like(linear_cpu_input)
        x_grad_out_mps = x_grad_out.to("mps")
        w_grad_out = torch.rand_like(cpu_linear.weight)
        w_grad_out_mps = w_grad_out.to("mps")

        # Zero the stored gradients in place (through a detached alias) before
        # the second backward pass writes into them.
        linear_cpu_input.grad.detach().zero_()
        linear_mps_input.grad.detach().zero_()
        cpu_linear.weight.grad.detach().zero_()
        mps_linear.weight.grad.detach().zero_()
        if bias:
            b_grad_out = torch.rand_like(cpu_linear.bias)
            b_grad_out_mps = b_grad_out.to("mps")
            cpu_linear.bias.grad.detach().zero_()
            mps_linear.bias.grad.detach().zero_()

        # Backpropagate through each first-order gradient; retain_graph=True
        # because several backward calls traverse the same graph.
        linear_cpu_input.grad.backward(x_grad_out, retain_graph=True)
        linear_mps_input.grad.backward(x_grad_out_mps, retain_graph=True)
        cpu_linear.weight.grad.backward(w_grad_out, retain_graph=True)
        mps_linear.weight.grad.backward(w_grad_out_mps, retain_graph=True)
        if bias:
            cpu_linear.bias.grad.backward(b_grad_out, retain_graph=True)
            mps_linear.bias.grad.backward(b_grad_out_mps, retain_graph=True)

        # Second-order results: gradient w.r.t. the upstream gradient, the
        # input and the parameters must match across devices.
        self.assertEqual(cpu_grad.grad, grad.grad)
        self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad)
        self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad)
        if bias:
            self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad)
def test_linear1D(self):
    """Forward-only Linear(2 -> 3) with bias on a 1-D input of length 2."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=[2],
        bias=True,
        backward_pass=False,
    )
def test_linear1D_backward(self):
    """Forward and backward Linear(2 -> 3) with bias on a 1-D input of length 2."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=[2],
        bias=True,
        backward_pass=True,
    )
def test_linear2D(self):
    """Forward-only Linear(2 -> 3) with bias on a (4, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 2),
        bias=True,
        backward_pass=False,
    )
def test_linear2D_backward(self):
    """Forward and backward Linear(2 -> 3) with bias on a (4, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 2),
        bias=True,
        backward_pass=True,
    )
def test_linear2D_no_bias(self):
    """Forward-only Linear(2 -> 3) without bias on a (4, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 2),
        bias=False,
        backward_pass=False,
    )
def test_linear2D_no_bias_backward(self):
    """Forward and backward Linear(2 -> 3) without bias on a (4, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 2),
        bias=False,
        backward_pass=True,
    )
def test_linear3D(self):
    """Forward-only Linear(2 -> 3) with bias on a (4, 5, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 5, 2),
        bias=True,
        backward_pass=False,
    )
def test_linear3D_backward(self):
    """Forward and backward Linear(2 -> 3) with bias on a (4, 5, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 5, 2),
        bias=True,
        backward_pass=True,
    )
def test_linear3D_no_bias(self):
    """Forward-only Linear(2 -> 3) without bias on a (4, 5, 2) input.

    The layer is built with ``bias=False`` so that this test actually
    covers the bias-free 3-D path instead of repeating ``test_linear3D``.
    """
    self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=False, backward_pass=False)
def test_linear3D_no_bias_backward(self):
    """Forward and backward Linear(2 -> 3) without bias on a (4, 5, 2) input.

    The layer is built with ``bias=False`` so that this test actually
    covers the bias-free 3-D path instead of repeating
    ``test_linear3D_backward``.
    """
    self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=False, backward_pass=True)
def test_uniform(self):
    """Check Uniform sample shapes, out-of-range log_prob/cdf values and rsample gradients."""
    lo = torch.zeros(5, 5, requires_grad=True)
    hi = (torch.ones(5, 5) * 3).requires_grad_()
    lo_1d = torch.zeros(1, requires_grad=True)
    hi_1d = (torch.ones(1) * 3).requires_grad_()

    # (bounds, positional args for sample(), expected sample size)
    shape_cases = (
        ((lo, hi), (), (5, 5)),
        ((lo, hi), ((7,),), (7, 5, 5)),
        ((lo_1d, hi_1d), (), (1,)),
        ((lo_1d, hi_1d), ((1,),), (1, 1)),
        ((0.0, 1.0), ((1,),), (1,)),
    )
    for bounds, sample_args, expected_size in shape_cases:
        self.assertEqual(Uniform(*bounds).sample(*sample_args).size(), expected_size)

    # Check log_prob computation when value outside range
    dist = Uniform(lo_1d, hi_1d, validate_args=False)
    too_high = torch.tensor([4.0])
    too_low = torch.tensor([-1.0])
    for outside in (too_high, too_low):
        self.assertEqual(dist.log_prob(outside).item(), -inf)

    # check cdf computation when value outside range
    self.assertEqual(dist.cdf(too_low).item(), 0)
    self.assertEqual(dist.cdf(too_high).item(), 1)

    # Draw the uniform noise once, then rewind the RNG so rsample() consumes
    # the same noise; the gradient assertions below compare against it.
    saved_state = torch.get_rng_state()
    noise = lo.new(lo.size()).uniform_()
    torch.set_rng_state(saved_state)
    draw = Uniform(lo, hi).rsample()
    draw.backward(torch.ones_like(draw))
    self.assertEqual(lo.grad, 1 - noise)
    self.assertEqual(hi.grad, noise)
    lo.grad.zero_()
    hi.grad.zero_()
def test_randperm(self, device="mps"):
    """Check that torch.randperm on ``device`` produces valid permutations.

    Covers several sizes and dtypes, the ``out=`` variant (including a
    non-contiguous destination), the default dtype and the empty case.
    Results of the plain and ``out=`` calls are each validated as a
    permutation rather than compared with one another.

    Args:
        device: device on which the permutations are generated.
    """
    rng_device = None

    def assert_is_permutation(result, n):
        # Sorted on the CPU as int64, a permutation of range(n) is exactly arange(n).
        self.assertEqual(result.cpu().sort().values.long(), torch.arange(n, device=device))

    for n in (5, 100, 50000, 100000):
        for dtype in (torch.long, torch.half, torch.float):
            if n > 2049 and dtype == torch.half:  # Large n for torch.half will raise an exception, do not test here.
                continue
            with torch.random.fork_rng(devices=rng_device):
                res1 = torch.randperm(n, dtype=dtype, device=device)
            res2 = torch.empty(0, dtype=dtype, device=device)
            torch.randperm(n, out=res2, dtype=dtype, device=device)
            assert_is_permutation(res1, n)
            # The out= result used to be computed and then ignored.
            assert_is_permutation(res2, n)

    # Default type is long
    for n in (100, 10000):
        self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)

    # randperm of 0 elements is an empty tensor
    res1 = torch.randperm(0)
    # Explicit dtype (the value the dtype loop above ends on) instead of
    # relying on the leaked loop variable.
    res2 = torch.tensor(5, dtype=torch.float, device=device)
    torch.randperm(0, out=res2)
    self.assertEqual(res1.numel(), 0)
    self.assertEqual(res2.numel(), 0)

    # Test non-contiguous tensors
    for n in (4, 5, 6, 10, 20):
        non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
        self.assertFalse(non_contiguous_tensor.is_contiguous())
        with torch.random.fork_rng(devices=rng_device):
            res = torch.randperm(n, dtype=torch.long, device=device)
        torch.randperm(n, out=non_contiguous_tensor)
        assert_is_permutation(res, n)
        # Also validate what was written into the non-contiguous destination.
        assert_is_permutation(non_contiguous_tensor, n)
# Test forward maxpool2d
def test_max_pool2d(self):
    """Compare MaxPool2d outputs, indices and input gradients between MPS and CPU."""
    def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
        # An all-ones input makes every element of a window tie for the maximum.
        make_input = torch.ones if test_ties else torch.randn
        cpu_x = make_input(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
                                  ceil_mode=ceil_mode, return_indices=return_indices)

        mps_result = pool(x)
        cpu_result = pool(cpu_x)
        if return_indices is False:
            y, ref_y = mps_result, cpu_result
            idx = ref_idx = None
        else:
            y, idx = mps_result
            ref_y, ref_idx = cpu_result

        cpu_grad = torch.ones_like(ref_y)
        grad = cpu_grad.to('mps')

        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y)
        if return_indices is not False:
            self.assertEqual(idx, ref_idx)
        self.assertEqual(x.grad, cpu_x.grad)

    # (input shape, pooling keyword arguments); the (1, 1000, 1, 4) entries
    # are the max_pool1d-style cases.
    cases = (
        # no batch dimension, then batched inputs
        ((8, 4, 4), dict(ks=2)),
        ((2, 8, 4, 4), dict(ks=2)),
        ((1, 1000, 32, 32), dict(ks=4)),
        ((1, 1000, 1, 4), dict(ks=(1, 4))),
        # padding
        ((1, 1000, 32, 32), dict(ks=4, padding=1)),
        ((1, 1000, 1, 4), dict(ks=(1, 4), padding=(0, 1))),
        # dilation
        ((1, 1000, 32, 32), dict(ks=4, dilation=2)),
        # NOTE(review): grouped with the dilation cases but passes padding=(0, 2) - confirm intent
        ((1, 1000, 1, 4), dict(ks=(1, 4), padding=(0, 2))),
        # ceil mode
        ((1, 1000, 32, 32), dict(ks=4, ceil_mode=True)),
        ((1, 1000, 1, 4), dict(ks=(1, 4), ceil_mode=True)),
    )

    # Values and gradients only
    for shape, pool_kwargs in cases:
        helper(shape, **pool_kwargs)

    # Same cases again with returned indices, with and without ties
    for test_ties in [False, True]:
        for shape, pool_kwargs in cases:
            helper(shape, return_indices=True, test_ties=test_ties, **pool_kwargs)
def test_adaptive_avg_pool2d_output_size_one(self):
    """AdaptiveAvgPool2d((1, 1)) on MPS must equal a spatial mean and have the expected strides."""
    def helper(size, memory_format):
        x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
        if memory_format == 'non_contiguous':
            x = x[::2, ::2, ::2, ::2]
        else:
            x = x.to(memory_format=memory_format)

        pooled = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
        # Reference: mean over the two trailing dims, reshaped to (N, C, 1, 1).
        reference = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))

        pooled.sum().backward()  # make sure it doesn't crash

        self.assertEqual(pooled, reference)
        channels = pooled.size(1)
        if memory_format == torch.channels_last:
            self.assertTrue(pooled.is_contiguous(memory_format=torch.channels_last))
            expected_stride = [channels, 1, channels, channels]
        else:
            self.assertTrue(pooled.is_contiguous())
            expected_stride = [channels, 1, 1, 1]
        self.assertEqual(pooled.stride(), expected_stride)

    helper((2, 3, 6, 6), torch.contiguous_format)
def test_masked_scatter(self):
    """In-place masked_scatter_ must give identical results on MPS and CPU for 2-D to 5-D shapes."""
    def check(shape):
        src_mps = torch.randn(shape, device="mps")
        src_cpu = src_mps.detach().clone().cpu()
        mask_mps = torch.rand(shape, device="mps") < 0.6
        mask_cpu = mask_mps.detach().clone().cpu()
        dst_mps = torch.randn(shape, device="mps")
        dst_cpu = dst_mps.detach().clone().cpu()

        dst_mps.masked_scatter_(mask_mps, src_mps)
        dst_cpu.masked_scatter_(mask_cpu, src_cpu)

        self.assertEqual(dst_mps, dst_cpu)

    shapes = (
        [2, 5],
        [10, 10],
        [5, 10, 3],
        [10, 5, 10, 3],
        [10, 5, 10, 3, 20],
    )
    for shape in shapes:
        check(shape)
def test_masked_fill(self):
    """masked_fill_ on MPS must match an element-by-element fill done on the CPU."""
    device = "mps"
    num_dest = 10

    filled = torch.zeros(num_dest, dtype=torch.float32, device=device)
    mask = torch.randint(2, (num_dest,), dtype=torch.bool, device=device)
    val = random.random()
    reference = torch.zeros(num_dest, dtype=torch.float32)

    mask_cpu = mask.to("cpu")

    filled.masked_fill_(mask, val)
    # Build the expected result by hand, one position at a time.
    for i, selected in enumerate(mask_cpu):
        if selected:
            reference[i] = val

    # Exact comparison: no tolerance allowed.
    self.assertEqual(filled.to("cpu"), reference, atol=0, rtol=0)
def test_masked_fill__non_contiguous(self):
    """masked_fill_ with an all-False mask on a transposed view must leave the data untouched."""
    shape = (3, 5)
    neg_inf = float("-inf")

    data_mps = torch.randn(shape, device="mps")
    data_cpu = data_mps.detach().clone().cpu()
    mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
    mask_cpu = mask_mps.detach().clone().cpu()

    # Operate on transposed views of the data and of the mask.
    view_mps = data_mps.T
    view_cpu = data_cpu.T

    view_mps.masked_fill_(mask_mps.T, neg_inf)
    view_cpu.masked_fill_(mask_cpu.T, neg_inf)

    self.assertEqual(view_mps, view_cpu)
    # The mask selects nothing, so no element may have become -inf.
    self.assertFalse((view_mps == neg_inf).any())
def test_nhwc_operation(self):
    """A channels-last CPU tensor copied to MPS must compare equal to its source."""
    def helper(shape, channels_last=False):
        import numpy as np
        np.random.seed(332)
        # Deterministic values in [128, 256)
        samples = np.random.random_sample(size=shape)
        arr = (256 - 128) * samples + 128
        cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        mps_x = cpu_x.detach().clone().to('mps').requires_grad_()

        # This passes
        self.assertEqual(mps_x, cpu_x)

    helper((2, 2, 2, 2), True)
# Test forward batch norm
def test_batch_norm(self):
    """Compare batch normalization on MPS against the CPU reference.

    Outputs, running statistics and gradients (input, and weight/bias when
    present) are checked for 3-D, 4-D and 5-D inputs, through both the
    functional API and the BatchNorm1d/2d/3d modules.
    """
    def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
               track_running_stats=True, test_module=False):
        """Run one batch-norm configuration on CPU and MPS and assert that they agree.

        ``training`` is only passed to the functional calls; the module branch
        never calls train()/eval() on the modules it builds.
        NOTE(review): modules presumably stay in their default mode, so the
        training=False and training=True invocations would exercise the same
        module behavior when test_module=True - confirm.
        """

        import numpy as np
        # Re-seeded on every call, so each configuration sees the same input
        # values, in [128, 256).
        np.random.seed(332)
        arr = (256 - 128) * np.random.random_sample(size=shape) + 128
        cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
        if (channels_last):
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # Per-channel statistics and affine parameters have one entry per channel (dim 1).
        mean_shape = [shape[1]]
        cpu_running_mean = None
        cpu_running_var = None
        running_mean = None
        running_var = None
        if (track_running_stats):
            mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
            cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
            var_arr = 32 * np.random.random_sample(size=mean_shape)
            cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
            running_mean = cpu_running_mean.detach().clone().to('mps')
            running_var = cpu_running_var.detach().clone().to('mps')

        weight = None
        cpu_weight = None
        bias = None
        cpu_bias = None
        if (wts):
            cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
            weight = cpu_weight.detach().clone().to('mps').requires_grad_()
            cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        y = None
        ref_y = None

        if (not test_module):
            # Functional API: statistics and parameters are passed explicitly.
            y = torch.nn.functional.batch_norm(x, running_mean, running_var,
                                               weight=weight,
                                               bias=bias,
                                               training=training,
                                               momentum=momentum, eps=eps)
            ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                   weight=cpu_weight,
                                                   bias=cpu_bias,
                                                   training=training,
                                                   momentum=momentum, eps=eps)

        else:
            # Module API: pick the BatchNorm variant from the input rank.
            batchnorm_op = None
            mps_batchnorm_op = None

            if (len(shape) == 3):
                batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                    eps=eps,
                                                    momentum=momentum,
                                                    affine=wts,
                                                    track_running_stats=track_running_stats,
                                                    device='cpu')
                mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='mps')
            elif (len(shape) == 4):
                batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                    eps=eps,
                                                    momentum=momentum,
                                                    affine=wts,
                                                    track_running_stats=track_running_stats,
                                                    device='cpu')
                mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='mps')
            elif (len(shape) == 5):
                batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                    eps=eps,
                                                    momentum=momentum,
                                                    affine=wts,
                                                    track_running_stats=track_running_stats,
                                                    device='cpu')
                mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='mps')

            # Overwrite the modules' own statistics/parameters with the tensors
            # prepared above so that both devices start from identical state.
            if (track_running_stats):
                batchnorm_op.running_mean = cpu_running_mean
                batchnorm_op.running_var = cpu_running_var
                mps_batchnorm_op.running_mean = running_mean
                mps_batchnorm_op.running_var = running_var
            if (wts):
                batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
                batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
                mps_batchnorm_op.weight = torch.nn.Parameter(weight)
                mps_batchnorm_op.bias = torch.nn.Parameter(bias)

            ref_y = batchnorm_op(cpu_x)
            y = mps_batchnorm_op(x)

        self.assertEqual(y, ref_y)
        # Running statistics after the forward pass must agree as well.
        if (not test_module):
            self.assertEqual(running_mean, cpu_running_mean)
            self.assertEqual(running_var, cpu_running_var)
        else:
            self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
            self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)

        # Backward pass with the same random upstream gradient on both devices.
        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        ref_y.backward(gradient=cpu_grad)
        y.backward(gradient=grad)

        self.assertEqual(x.grad, cpu_x.grad)
        if (wts):
            if (not test_module):
                self.assertEqual(weight.grad, cpu_weight.grad)
                self.assertEqual(bias.grad, cpu_bias.grad)
            else:
                self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
                self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)

    for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
        for test_module in [False, True]:
            for track_running_stats in [True, False]:
                # Only the default memory format is exercised here.
                for channels_last in [False]:
                    if (channels_last and len(shape) != 4):
                        continue

                    # Running stats must be tracked in eval mode
                    if (track_running_stats):
                        helper(shape, eps=0, momentum=1, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)

                    # Training-mode configurations run with and without running stats.
                    helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
def test_batch_norm_backward(self):
    """Backward through two stacked BatchNorm2d layers, the second one frozen, must not crash."""
    inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
    first = torch.nn.BatchNorm2d(8).to("mps")
    frozen = torch.nn.BatchNorm2d(8).to("mps")
    # Freeze the affine parameters of the outer layer only.
    for param in (frozen.weight, frozen.bias):
        param.requires_grad = False

    outputs = frozen(first(inputs))

    # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
    outputs.sum().backward()
# Regression test for https://github.com/pytorch/pytorch/issues/133520
def test_batch_norm_slices(self):
    """BatchNorm2d applied to a sliced batch (``x[5:]``) must agree between CPU and MPS."""
    bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
    bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')

    batch_cpu = torch.randn(100, 100, 35, 45).to('cpu')
    batch_mps = batch_cpu.to('mps')

    # Both inputs are slices that skip the first five samples of the batch.
    expected = bn_cpu(batch_cpu[5:])
    actual = bn_mps(batch_mps[5:])

    self.assertEqual(expected, actual)
def test_layer_norm_backward(self):
    """Backward through two stacked LayerNorm layers, the second one frozen, must not crash."""
    inputs = torch.rand(4, 4, device="mps", requires_grad=True)
    first = torch.nn.LayerNorm(4).to("mps")
    frozen = torch.nn.LayerNorm(4).to("mps")
    # Freeze the affine parameters of the outer layer only.
    for param in (frozen.weight, frozen.bias):
        param.requires_grad = False

    outputs = frozen(first(inputs))

    # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
    outputs.sum().backward()
def test_norm(self):
    """torch.norm must agree between MPS and CPU for default, inf, p=1 and per-dim norms."""
    def check(mps_tensor, cpu_tensor, *args, **kwargs):
        res = torch.norm(mps_tensor, *args, **kwargs)
        res_cpu = torch.norm(cpu_tensor, *args, **kwargs)
        self.assertEqual(res, res_cpu)

    a = torch.arange(9, dtype=torch.float, device="mps") - 4
    b = a.reshape((3, 3))
    a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
    b_cpu = a_cpu.reshape((3, 3))

    # Default norm first, then the infinity norm, each on the flat and the 3x3 view.
    for p_args in ((), (float('inf'),)):
        check(a, a_cpu, *p_args)
        check(b, b_cpu, *p_args)

    rows = [[1, 2, 3], [-1, 1, 4]]
    c = torch.tensor(rows, dtype=torch.float, device="mps")
    c_cpu = torch.tensor(rows, dtype=torch.float, device="cpu")

    for norm_kwargs in (dict(dim=0), dict(dim=1), dict(p=1, dim=1)):
        check(c, c_cpu, **norm_kwargs)

    d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
    d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)

    check(d, d_cpu, dim=(1, 2))

    # Norms of the two leading slices, compared as a pair.
    res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
    res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
    self.assertEqual(res, res_cpu)
def test_linalg_vector_norm(self):
    """torch.linalg.vector_norm must agree between MPS and CPU for ord=0 and ord=3.5."""
    def check(mps_tensor, cpu_tensor, **kwargs):
        res_mps = torch.linalg.vector_norm(mps_tensor, **kwargs)
        res_cpu = torch.linalg.vector_norm(cpu_tensor, **kwargs)
        self.assertEqual(res_mps, res_cpu)

    x_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps")
    x_cpu = x_mps.detach().clone().cpu()
    check(x_mps, x_cpu, ord=0)

    a_mps = torch.arange(27, dtype=torch.float, device="mps") - 4
    a_cpu = torch.arange(27, dtype=torch.float, device="cpu") - 4
    cube_mps = a_mps.reshape(3, 3, 3)
    cube_cpu = a_cpu.reshape(3, 3, 3)

    # Full reduction over the flat vector and over the 3x3x3 tensor
    check(a_mps, a_cpu, ord=3.5)
    check(cube_mps, cube_cpu, ord=3.5)

    # Reduction along each single dimension
    for dim in range(0, cube_mps.dim()):
        check(cube_mps, cube_cpu, ord=3.5, dim=dim)
def test_layer_norm(self):
    """Compare LayerNorm forward results and gradients between MPS and CPU."""
    # TODO: Test non-contiguous
    def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):
        def mps_leaf(cpu_tensor):
            # Independent MPS copy that tracks its own gradient.
            return cpu_tensor.detach().clone().to('mps').requires_grad_()

        ref_in = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
        dev_in = mps_leaf(ref_in)

        norm_kwargs = dict(eps=eps, elementwise_affine=elementwise_affine, dtype=dtype)
        ref_norm = torch.nn.LayerNorm(normalized_shape, device='cpu', **norm_kwargs)
        dev_norm = torch.nn.LayerNorm(normalized_shape, device='mps', **norm_kwargs)

        # Random affine parameters are always drawn, but only installed when
        # the layers are affine.
        ref_weight = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
        dev_weight = mps_leaf(ref_weight)
        ref_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
        dev_bias = mps_leaf(ref_bias)

        if elementwise_affine:
            ref_norm.weight = torch.nn.Parameter(ref_weight)
            dev_norm.weight = torch.nn.Parameter(dev_weight)
            ref_norm.bias = torch.nn.Parameter(ref_bias)
            dev_norm.bias = torch.nn.Parameter(dev_bias)

        ref_out = ref_norm(ref_in)
        dev_out = dev_norm(dev_in)

        ref_grad = torch.randn(ref_out.shape)
        dev_grad = ref_grad.to('mps')

        ref_out.backward(ref_grad)
        dev_out.backward(dev_grad)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)
        if elementwise_affine:
            self.assertEqual(dev_norm.weight.grad, ref_norm.weight.grad)
            self.assertEqual(dev_norm.bias.grad, ref_norm.bias.grad)

    shape_pairs = (
        ((2, 2, 2, 2), (2, 2)),
        ((2, 3, 4, 5), (4, 5)),
        ((2, 3, 4, 5, 6), (4, 5, 6)),
    )
    for elementwise_affine in [True, False]:
        for input_shape, normalized_shape in shape_pairs:
            helper(input_shape, normalized_shape, elementwise_affine=elementwise_affine)

    # Regression test for https://github.com/pytorch/pytorch/issues/96113
    half_input = torch.randn(1, 2, 16).to("mps", dtype=torch.float16)
    torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(half_input)
@xfailIf(product_version < 14.0)
def test_ifft(self):
    """rfft followed by irfft on MPS must reproduce the original signal.

    The test is wrapped as an expected failure when ``product_version`` is below 14.0.
    """
    # See: https://github.com/pytorch/pytorch/issues/124096
    mps = torch.device("mps")
    length = 64

    signal = torch.rand(length, device=mps)
    spectrum = torch.fft.rfft(signal)
    round_trip = torch.fft.irfft(spectrum, n=signal.shape[0])

    # Expecting the inverted to yield the original signal
    self.assertEqual(round_trip, signal)
# Regression test for https://github.com/pytorch/pytorch/issues/135223
def test_fftfreq(self):
    """torch.fft.fftfreq for 10**4 points must match between CPU and MPS."""
    num_points = 10**4
    expected = torch.fft.fftfreq(num_points, device='cpu')
    actual = torch.fft.fftfreq(num_points, device='mps')
    self.assertEqual(expected, actual)
def test_instance_norm(self):
    """Compare instance normalization on MPS against the CPU reference.

    Outputs, running statistics and gradients (input, and weight/bias when
    present) are checked for 3-D, 4-D and 5-D inputs, through both the
    functional API and the InstanceNorm1d/2d/3d modules.
    """
    def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
        """Run one instance-norm configuration on CPU and MPS and assert that they agree."""

        import numpy as np
        # Re-seeded on every call, so each configuration sees the same input
        # values, in [128, 256).
        np.random.seed(332)
        arr = (256 - 128) * np.random.random_sample(size=shape) + 128
        cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
        if (channels_last):
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # Per-channel statistics and affine parameters have one entry per channel (dim 1).
        mean_shape = [shape[1]]
        cpu_running_mean = None
        cpu_running_var = None
        running_mean = None
        running_var = None
        if (track_running_stats):
            mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
            cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
            var_arr = 32 * np.random.random_sample(size=mean_shape)
            cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
            running_mean = cpu_running_mean.detach().clone().to('mps')
            running_var = cpu_running_var.detach().clone().to('mps')

        weight = None
        cpu_weight = None
        bias = None
        cpu_bias = None
        if (wts):
            cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
            weight = cpu_weight.detach().clone().to('mps').requires_grad_()
            cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        y = None
        ref_y = None

        if (not test_module):
            # Functional API: statistics and parameters are passed explicitly.
            ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                      weight=cpu_weight,
                                                      bias=cpu_bias,
                                                      momentum=momentum, eps=eps)
            y = torch.nn.functional.instance_norm(x, running_mean, running_var,
                                                  weight=weight,
                                                  bias=bias,
                                                  momentum=momentum, eps=eps)

        else:
            # Module API: pick the InstanceNorm variant from the input rank.
            instancenorm_op = None
            mps_instancenorm_op = None

            if (len(shape) == 3):
                instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
                                                          eps=eps,
                                                          momentum=momentum,
                                                          affine=wts,
                                                          track_running_stats=track_running_stats,
                                                          device='cpu')
                mps_instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
                                                              eps=eps,
                                                              momentum=momentum,
                                                              affine=wts,
                                                              track_running_stats=track_running_stats,
                                                              device='mps')
            elif (len(shape) == 4):
                instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
                                                          eps=eps,
                                                          momentum=momentum,
                                                          affine=wts,
                                                          track_running_stats=track_running_stats,
                                                          device='cpu')
                mps_instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
                                                              eps=eps,
                                                              momentum=momentum,
                                                              affine=wts,
                                                              track_running_stats=track_running_stats,
                                                              device='mps')
            elif (len(shape) == 5):
                instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
                                                          eps=eps,
                                                          momentum=momentum,
                                                          affine=wts,
                                                          track_running_stats=track_running_stats,
                                                          device='cpu')
                mps_instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
                                                              eps=eps,
                                                              momentum=momentum,
                                                              affine=wts,
                                                              track_running_stats=track_running_stats,
                                                              device='mps')

            # Overwrite the modules' own statistics/parameters with the tensors
            # prepared above so that both devices start from identical state.
            if (track_running_stats):
                instancenorm_op.running_mean = cpu_running_mean
                instancenorm_op.running_var = cpu_running_var
                mps_instancenorm_op.running_mean = running_mean
                mps_instancenorm_op.running_var = running_var
            if (wts):
                instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
                instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
                mps_instancenorm_op.weight = torch.nn.Parameter(weight)
                mps_instancenorm_op.bias = torch.nn.Parameter(bias)

            ref_y = instancenorm_op(cpu_x)
            y = mps_instancenorm_op(x)

        self.assertEqual(y, ref_y)
        # Running statistics after the forward pass must agree as well.
        if (not test_module):
            self.assertEqual(running_mean, cpu_running_mean)
            self.assertEqual(running_var, cpu_running_var)
        else:
            self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
            self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)

        # Backward pass with the same random upstream gradient on both devices.
        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        ref_y.backward(gradient=cpu_grad)
        y.backward(gradient=grad)

        self.assertEqual(x.grad, cpu_x.grad)
        if (wts):
            if (not test_module):
                self.assertEqual(weight.grad, cpu_weight.grad)
                self.assertEqual(bias.grad, cpu_bias.grad)
            else:
                self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
                self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)

    for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
        for test_module in [False, True]:
            for track_running_stats in [True, False]:
                # Only the default memory format is exercised here.
                for channels_last in [False]:
                    if (channels_last and len(shape) != 4):
                        continue

                    # Running stats must be tracked in eval mode
                    if (track_running_stats):
                        helper(shape, eps=0, momentum=1, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)

                    # The same four eps/momentum/wts combinations as the last four
                    # calls above, run with and without running stats.
                    helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
def test_weight_norm(self):
    """Check weight-norm parametrized layers on MPS against CPU.

    Linear, Conv2d and Conv3d layers are wrapped with
    ``torch.nn.utils.parametrizations.weight_norm`` for several ``dim``
    values. The forward output, the gradients of both parametrization
    tensors and the input gradient are compared between the two devices.
    """
    def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim):
        # Apply the same parametrization to the CPU and the MPS layer.
        ref_module = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim)
        mps_module = torch.nn.utils.parametrizations.weight_norm(model, dim=dim)

        ref_out = ref_module(cpu_x)
        mps_out = mps_module(x)
        self.assertEqual(ref_out, mps_out)

        # Drive both backward passes with the same upstream gradient.
        upstream_cpu = torch.randn(ref_out.shape)
        upstream_mps = upstream_cpu.to('mps')
        ref_out.backward(gradient=upstream_cpu)
        mps_out.backward(gradient=upstream_mps)

        ref_weight = cpu_model.parametrizations.weight
        mps_weight = model.parametrizations.weight
        self.assertEqual(ref_weight.original0.grad, mps_weight.original0.grad)
        self.assertEqual(ref_weight.original1.grad, mps_weight.original1.grad)
        self.assertEqual(x.grad, cpu_x.grad)

    def mps_leaf(cpu_tensor):
        # Independent MPS copy that tracks its own gradient.
        return cpu_tensor.detach().clone().to('mps').requires_grad_()

    def helper(dim, layer='linear', dtype=torch.float32):
        if layer == 'linear':
            cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
            x = mps_leaf(cpu_x)

            cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
            weight = mps_leaf(cpu_weight)
            cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
            bias = mps_leaf(cpu_bias)

            cpu_linear = torch.nn.Linear(5, 10, device='cpu')
            linear = torch.nn.Linear(5, 10, device='mps')
            # Give both layers the same explicitly generated parameters.
            with torch.no_grad():
                cpu_linear.weight.copy_(cpu_weight)
                cpu_linear.bias.copy_(cpu_bias)
                linear.weight.copy_(weight)
                linear.bias.copy_(bias)
            validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim)

        if layer == 'conv':
            cpu_x = torch.randn((3, 5, 5), device='cpu', dtype=dtype, requires_grad=True)
            x = mps_leaf(cpu_x)

            cpu_conv = torch.nn.Conv2d(3, 3, 3, device='cpu')
            conv = torch.nn.Conv2d(3, 3, 3, device='mps')
            # The MPS layer takes over the CPU layer's initial parameters.
            with torch.no_grad():
                conv.weight.copy_(cpu_conv.weight)
                conv.bias.copy_(cpu_conv.bias)
            validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)

        if layer == 'conv3d':
            cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True)
            x = mps_leaf(cpu_x)

            cpu_conv = torch.nn.Conv3d(3, 3, 3, device='cpu')
            conv = torch.nn.Conv3d(3, 3, 3, device='mps')
            with torch.no_grad():
                conv.weight.copy_(cpu_conv.weight)
                conv.bias.copy_(cpu_conv.bias)
            validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)

    for dim in (0, 1, -1):
        helper(dim, layer='linear')

    for dim in (0, 1, 2, 3, -1):
        helper(dim, layer='conv')

    # The Conv3d cases only run when product_version is at least 13.2.
    if product_version >= 13.2:
        for dim in (0, 1, 2, 3, 4, -1):
            helper(dim, layer='conv3d')
# Test conv2d
def test_conv2d_unit(self):
    """Compare ``F.conv2d`` forward and backward results on MPS against CPU.

    The output, input gradient, weight gradient and (when a bias is used)
    bias gradient are compared with explicit tolerances.
    """
    def helper(input_shape, wt_shape,
               stride=1, padding=0,
               dilation=1, groups=1,
               bias_shape=None):
        def mps_leaf(cpu_tensor):
            return cpu_tensor.detach().clone().to('mps').requires_grad_()

        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = mps_leaf(cpu_x)
        cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
        wt = mps_leaf(cpu_wt)

        with_bias = bias_shape is not None
        cpu_bias = (torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
                    if with_bias else None)
        bias = mps_leaf(cpu_bias) if with_bias else None

        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
        y = torch.nn.functional.conv2d(x, wt, bias=bias, **conv_kwargs)
        ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, **conv_kwargs)

        # All-ones upstream gradient, shared by both devices.
        cpu_grad = torch.ones_like(ref_y)
        y.backward(gradient=cpu_grad.to('mps'))
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
        self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
        if with_bias:
            self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)

    # Single strided, padded convolution: N=1, C_in=3, C_out=64, 64x64 input, 4x4 kernel.
    helper((1, 3, 64, 64), (64, 3, 4, 4), stride=2, padding=1)

    N, C_in, C_out = 4, 16, 8
    H = W = 32
    kH = kW = 3

    for groups in [1, 2, 4]:
        base_in = (N, C_in, H, W)
        base_wt = (C_out, C_in // groups, kH, kW)
        large_in = (N, C_in * 2, H * 2, W * 2)
        large_wt = (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2)

        configurations = [
            (base_in, base_wt, None),
            (base_in, base_wt, (C_out)),
            (large_in, large_wt, None),
            (large_in, large_wt, (C_out * 2)),
        ]
        for in_shape, wt_shape, bias_shape in configurations:
            # Every configuration is run twice, each time on fresh random data.
            for _ in range(2):
                helper(in_shape, wt_shape, bias_shape=bias_shape, groups=groups)
# Test conv transpose 2d
def test_conv_transpose2d(self):
    """Compare ``F.conv_transpose2d`` forward and backward on MPS against CPU.

    Sweeps stride, padding, output padding and dilation, with and without
    a bias. The output, input gradient and weight gradient are compared.
    """
    def helper(input_shape, wt_shape,
               stride=1, padding=0,
               output_padding=0,
               dilation=1, groups=1,
               bias_shape=None):
        def mps_leaf(cpu_tensor):
            return cpu_tensor.detach().clone().to('mps').requires_grad_()

        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = mps_leaf(cpu_x)
        cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
        wt = mps_leaf(cpu_wt)

        cpu_bias, bias = None, None
        if bias_shape is not None:
            cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = mps_leaf(cpu_bias)

        op_kwargs = dict(stride=stride, padding=padding, output_padding=output_padding,
                         groups=groups, dilation=dilation)
        y = torch.nn.functional.conv_transpose2d(x, wt, bias=bias, **op_kwargs)
        ref_y = torch.nn.functional.conv_transpose2d(cpu_x, cpu_wt, bias=cpu_bias, **op_kwargs)

        cpu_grad = torch.randn(ref_y.shape)
        y.backward(gradient=cpu_grad.to('mps'))
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
        self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
        # NOTE: the bias gradient is not compared by this test.

    N, C_in, C_out = 4, 2, 8
    H = W = 32
    kH = kW = 3
    in_shape = (N, C_out, H, W)
    wt_shape = (C_out, C_in, kH, kW)

    sweep = product([1, 2, 3], [0, 1, 2], [0, 1, 2], [1, 2])
    for stride, padding, output_padding, dilation in sweep:
        # Only keep combinations where output_padding is smaller than both
        # the stride and the dilation.
        if output_padding >= stride or output_padding >= dilation:
            continue
        settings = dict(stride=stride, padding=padding,
                        output_padding=output_padding, dilation=dilation)
        for bias_shape in (None, (C_in)):
            # Each setting is run twice, each time on fresh random data.
            for _ in range(2):
                helper(in_shape, wt_shape, bias_shape=bias_shape, **settings)
# Test sigmoid
def test_sigmoid(self):
    """Compare ``nn.Sigmoid`` outputs and input gradients on MPS against CPU."""
    def helper(shape):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = ref_in.detach().clone().to('mps').requires_grad_()

        activation = torch.nn.Sigmoid()
        mps_out = activation(mps_in)
        ref_out = activation(ref_in)

        # Backpropagate an all-ones gradient on both devices.
        ones = torch.ones_like(ref_out)
        mps_out.backward(gradient=ones.to('mps'))
        ref_out.backward(gradient=ones)

        self.assertEqual(mps_out, ref_out)
        self.assertEqual(mps_in.grad, ref_in.grad)

    for shape in [(2, 3, 4, 5), (2, 3, 4), (2, 8, 4, 5)]:
        helper(shape)
# Test tanh
def test_tanh(self):
    """Compare ``nn.Tanh`` outputs and input gradients on MPS against CPU."""
    def helper(shape):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = ref_in.detach().clone().to('mps').requires_grad_()

        activation = torch.nn.Tanh()
        mps_out = activation(mps_in)
        ref_out = activation(ref_in)

        # Backpropagate an all-ones gradient on both devices.
        ones = torch.ones_like(ref_out)
        mps_out.backward(gradient=ones.to('mps'))
        ref_out.backward(gradient=ones)

        self.assertEqual(mps_out, ref_out)
        self.assertEqual(mps_in.grad, ref_in.grad)

    for shape in [(2, 3, 4, 5), (2, 3, 4), (2, 8, 4, 5)]:
        helper(shape)
def test_threshold(self):
    """Compare ``nn.Threshold`` on MPS against CPU, optionally with backward."""
    def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
        module = nn.Threshold(threshold=threshold, value=value, inplace=inplace)

        ref_in = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
        mps_in = ref_in.detach().clone().to('mps').requires_grad_(requires_grad)

        ref_out = module(ref_in)
        mps_out = module(mps_in)

        ref_upstream = torch.ones_like(ref_out)
        mps_upstream = ref_upstream.to('mps')

        self.assertEqual(ref_out, mps_out)

        # Gradients are only checked when the inputs track them.
        if not requires_grad:
            return
        ref_out.backward(gradient=ref_upstream)
        mps_out.backward(gradient=mps_upstream)
        self.assertEqual(ref_in.grad, mps_in.grad)

    helper(threshold=0.1, value=20, num_elems=2)
    helper(threshold=-0.1, value=10, num_elems=10)
    helper(threshold=0.5, value=-15, num_elems=100)
    # The in-place variant is run without gradient tracking.
    helper(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False)
# Test pow
def test_pow(self):
    """Compare the three ``torch.pow`` overloads on MPS against CPU."""
    def helper(shape):
        def random_pair():
            # Random CPU tensor plus an independent copy on MPS.
            on_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            return on_cpu, on_cpu.detach().clone().to('mps')

        # aten::pow.Tensor_Tensor
        base_cpu, base_mps = random_pair()
        expo_cpu, expo_mps = random_pair()
        self.assertEqual(torch.pow(base_mps, expo_mps), torch.pow(base_cpu, expo_cpu))

        # aten::pow.Tensor_Scalar
        base_cpu, base_mps = random_pair()
        scalar_exponent = random.random()
        self.assertEqual(torch.pow(base_mps, scalar_exponent),
                         torch.pow(base_cpu, scalar_exponent))

        # aten::pow.Scalar
        scalar_base = random.random()
        expo_cpu, expo_mps = random_pair()
        self.assertEqual(torch.pow(scalar_base, expo_mps),
                         torch.pow(scalar_base, expo_cpu))

    helper((2, 8, 4, 5))
# Test addcmul
def test_addcmul(self):
    """Compare ``torch.addcmul`` on MPS against CPU for float, integral and mixed dtypes."""
    def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None):
        def rand_helper(dtype):
            # Floating dtypes draw from randn; integral dtypes from randint below 10.
            if not dtype.is_floating_point:
                return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False)
            return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)

        # The second and third operands default to the first operand's dtype.
        if ytype is None:
            ytype = xtype
        if ztype is None:
            ztype = xtype

        cpu_operands = [rand_helper(xtype), rand_helper(ytype), rand_helper(ztype)]
        mps_operands = [t.detach().clone().to('mps') for t in cpu_operands]

        mps_result = torch.addcmul(*mps_operands, value=value)
        ref_result = torch.addcmul(*cpu_operands, value=value)
        self.assertEqual(mps_result, ref_result)

    for shape, value in [((2, 3, 4, 5), 0.1), ((2, 8, 4, 5), 0.1),
                         ((2, 3, 4, 5), 0.2), ((2, 8, 4, 5), 0.2)]:
        helper(shape, value)

    # Integral types
    helper((2, 2), 1.0, xtype=torch.int32)
    helper((2, 2), 2.0, xtype=torch.int16)

    # Mixed types
    helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32)
    helper((3, 2), 1.0, ytype=torch.float16)
    helper((2, 3), 1.0, ztype=torch.float16)
    helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8)
    helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8)
# Test addcdiv
def test_addcdiv(self):
    """Compare ``torch.addcdiv`` (functional and ``out=``) on MPS against CPU."""
    def helper(shape, value):
        def cpu_randn():
            return torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)

        def on_mps(cpu_tensor):
            return cpu_tensor.detach().clone().to('mps')

        cpu_x = cpu_randn()
        cpu_y = cpu_randn()
        # clamp to avoid division by 0
        cpu_z = cpu_randn().clamp_min_(0.1)
        cpu_out = cpu_randn()

        mps_x, mps_y, mps_z, mps_out = (on_mps(t) for t in (cpu_x, cpu_y, cpu_z, cpu_out))

        expected = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value)
        self.assertEqual(torch.addcdiv(mps_x, mps_y, mps_z, value=value), expected)

        # test .out variant
        out_result = torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value)
        self.assertEqual(out_result, expected)

    helper((2, 3, 4, 5), 0.1)
    helper((2, 8, 4, 5), 0.2)
    helper((2, 3, 4, 5), 1.0)  # value of 1 should be ignored internally
def test_addcdiv_transpose(self):
    """Run ``addcdiv`` with every mix of plain and transposed operands."""
    # Regression test for issue https://github.com/pytorch/pytorch/issues/118115
    # Testing continuity of all input tensors
    def helper(shape, value):
        shape_t = shape[::-1]

        def operand(flag):
            # flag 0: tensor created directly with ``shape``;
            # otherwise: tensor created with the reversed shape, then transposed.
            if flag == 0:
                return torch.rand(shape, device="cpu")
            return torch.rand(shape_t, device="cpu").t()

        for i, j, k in product(range(2), repeat=3):
            x = operand(i)
            y = operand(j)
            z = operand(k)

            # The MPS copies are taken before the in-place CPU update below.
            x_mps, y_mps, z_mps = (t.detach().clone().to(device="mps") for t in (x, y, z))

            result_cpu = x.addcdiv_(y, z, value=value)
            result_mps = x_mps.addcdiv(y_mps, z_mps, value=value)

            result_mps_out = result_cpu.detach().clone().to('mps')
            torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value)

            self.assertEqual(result_cpu, result_mps)
            self.assertEqual(result_cpu, result_mps_out)

    for shape in [(2, 3), (100, 300)]:
        for value in (1.0, 0.2):
            helper(shape, value)
def test_buffer_size_match(self):
    """Multiply a 1-D tensor by a 3-D tensor on MPS and compare with CPU."""
    # this test shouldn't cause any crash
    size = 16
    vec_cpu = torch.rand(size, device='cpu')
    cube_cpu = torch.rand(size, size, size, device='cpu')
    vec_mps, cube_mps = vec_cpu.to('mps'), cube_cpu.to('mps')

    expected = vec_cpu @ cube_cpu
    self.assertEqual(expected, vec_mps @ cube_mps)
def test_transpose_inplace(self):
    """In-place ``transpose_`` on MPS must match the CPU result."""
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    ref = torch.tensor(rows, device='cpu')
    dev = torch.tensor(rows, device='mps')

    ref.transpose_(0, 1)
    dev.transpose_(0, 1)

    self.assertEqual(ref, dev.to('cpu'))
def test_expand_cpu_to_mps_copy(self):
    """Copy an expanded CPU scalar to MPS and compare it with the CPU tensor."""
    # https://github.com/pytorch/pytorch/issues/78642
    copied = torch.tensor(1).expand([10]).to("mps")
    reference = torch.tensor(1).expand([10])

    self.assertEqual(reference, copied.cpu())
def test_cpu_to_strided_mps_copy(self):
    """Assigning into a strided MPS slice must give the same result from a CPU or an MPS source."""
    # https://github.com/pytorch/pytorch/issues/86975
    mps_device = torch.device("mps")
    table = [[1, 2], [3, 4], [5, 6]]

    # Source values live on the CPU.
    from_cpu = torch.Tensor(table).to(mps_device)
    cpu_values = torch.Tensor([-1, -1])
    from_cpu[1:, 1] = cpu_values

    # Source values live on MPS.
    from_mps = torch.Tensor(table).to(mps_device)
    mps_values = torch.Tensor([-1, -1]).to(mps_device)
    from_mps[1:, 1] = mps_values

    self.assertEqual(from_cpu, from_mps)
def test_view_slice_reshape(self):
    """Add a scalar to a sliced view on MPS and compare with the CPU result."""
    source_mps = torch.randn([1, 4, 4], device="mps")
    view_mps = source_mps[0, :1, 1:]

    source_cpu = source_mps.to("cpu")
    view_cpu = source_cpu[0, :1, 1:]

    self.assertEqual(view_mps + 1, view_cpu + 1)
def test_slice_reshape(self):
    """Slice then ``view`` a tensor on MPS and compare with CPU, before and after an add."""
    full_mps = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
    full_cpu = full_mps.detach().clone().to("cpu")

    viewed_mps = full_mps[:, 3:].view(2, 3, 4, 1)
    viewed_cpu = full_cpu[:, 3:].view(2, 3, 4, 1)
    self.assertEqual(viewed_mps, viewed_cpu)

    shifted_mps = viewed_mps + 2
    shifted_cpu = viewed_cpu + 2
    self.assertEqual(shifted_mps, shifted_cpu)
def test_reshape_storage_offset(self):
    """Compare an MSE loss and its input gradients on MPS against CPU.

    A linear layer output is concatenated with a repeated token and
    transposed before the loss is computed.
    """
    # https://github.com/pytorch/pytorch/issues/95883
    B = 4
    T = 1

    lin_cpu = nn.Linear(10, 256)
    lin_mps = nn.Linear(10, 256, device="mps")

    # Use the same weights and bias as the ones from the cpu
    lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
    lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()

    # Keep handles on the leaf inputs. The x_* names below are rebound to
    # non-leaf results, whose ``.grad`` is None, so comparing those would
    # verify nothing.
    inp_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
    inp_cpu = inp_mps.detach().clone().cpu().requires_grad_()

    x_mps = lin_mps(inp_mps)
    x_cpu = lin_cpu(inp_cpu)

    self.assertEqual(x_mps.shape, (B, T, 256))
    self.assertEqual(x_cpu.shape, (B, T, 256))

    cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
    cls_token_cpu = cls_token_mps.detach().clone().cpu()
    x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
    x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)

    x_mps = x_mps.transpose(0, 1)
    x_cpu = x_cpu.transpose(0, 1)

    target_mps = torch.rand_like(x_mps)
    target_cpu = target_mps.detach().clone().cpu()
    loss_mps = F.mse_loss(x_mps, target_mps)
    loss_cpu = F.mse_loss(x_cpu, target_cpu)
    self.assertEqual(loss_mps, loss_cpu)

    loss_mps.backward()
    loss_cpu.backward()

    # Make sure real gradients are compared, not None against None.
    self.assertIsNotNone(inp_cpu.grad)
    self.assertEqual(inp_mps.grad, inp_cpu.grad)
def test_stack_storage_offset(self):
    """``torch.stack`` over sliced views on MPS must match the CPU result."""
    # https://github.com/pytorch/pytorch/issues/87856
    row_cpu = torch.tensor([[1, 2]])
    row_mps = row_cpu.detach().clone().to("mps")

    stacked_cpu = torch.stack((row_cpu[:, :1], row_cpu[:, -1:]), dim=-1)
    stacked_mps = torch.stack((row_mps[:, :1], row_mps[:, -1:]), dim=-1)
    self.assertEqual(stacked_cpu, stacked_mps)

    # Second scenario: stack the two halves of a 1-D tensor.
    t_mps = torch.tensor([1, 2, 3, 4], device="mps")
    t_cpu = t_mps.detach().cpu().detach()

    tail_mps, head_mps = t_mps[2:], t_mps[:2]
    tail_cpu, head_cpu = t_cpu[2:], t_cpu[:2]

    res_mps = torch.stack((head_mps, tail_mps), dim=-1)
    res_cpu = torch.stack((head_cpu, tail_cpu), dim=-1)
    self.assertEqual(res_mps, res_cpu)
def test_unsafe_chunk(self):
    """Multiply two chunks from ``unsafe_chunk`` on MPS and compare with CPU."""
    # https://github.com/pytorch/pytorch/issues/91065
    source_cpu = torch.rand(5, dtype=torch.float32, device="cpu")
    chunks_cpu = source_cpu.unsafe_chunk(4, 0)
    product_cpu = chunks_cpu[0] * chunks_cpu[2]

    source_mps = source_cpu.to("mps")
    chunks_mps = source_mps.unsafe_chunk(4, 0)
    product_mps = chunks_mps[0] * chunks_mps[2]

    self.assertEqual(product_cpu, product_mps)
def test_slice_casting(self):
    """Cast a sliced uint8 tensor to bool on MPS and compare with CPU."""
    # generate random binary numbers
    probabilities = torch.empty(1, 1, 128, 128).uniform_(0, 1)
    cpu_in = torch.bernoulli(probabilities).to(torch.uint8)
    mps_in = cpu_in.detach().clone().to("mps")

    # check copy_cast(uint8 -> bool) on tensors with storage offset
    window = (slice(None), slice(None), slice(11, 12), slice(None, 12))
    cpu_out = cpu_in[window].to(torch.bool)
    mps_out = mps_in[window].to(torch.bool)

    self.assertEqual(cpu_out, mps_out)
def test_slice_reshape_contg_view(self):
    """Add a scalar to a (1, 4800, 2) tensor on MPS and compare with CPU."""
    data_mps = torch.randn(1, 4800, 2, device="mps")
    data_cpu = data_mps.detach().clone().cpu()

    self.assertEqual(data_mps + 2, data_cpu + 2)
def test_contiguous_slice_2d(self):
    """Add a scalar to slices over the first two dimensions on MPS and compare with CPU.

    Shapes of rank 2 to 5 with extents 1 and 2 are swept, followed by a few
    larger fixed shapes.
    """
    def helper(shape):
        for i in range(0, shape[0]):
            for j in range(0, shape[1]):
                t_mps = torch.randn(shape, device="mps")
                t_cpu = t_mps.detach().clone().cpu()

                index_patterns = (
                    (slice(i, None), slice(None, j)),   # [i:, :j]
                    (slice(i, None), j),                # [i:, j]
                    (i, slice(None, j)),                # [i, :j]
                    (slice(None, i), slice(None, j)),   # [:i, :j]
                    (slice(None, i), j),                # [:i, j]
                    (slice(None, i), slice(j, None)),   # [:i, j:]
                )
                for pattern in index_patterns:
                    self.assertEqual(t_mps[pattern] + 1, t_cpu[pattern] + 1)

    def sweep(prefix):
        # Depth-first extension of the shape with extents 1 and 2. Every
        # shape of rank 2..5 is handed to ``helper`` as soon as it is formed.
        for extent in range(1, 3):
            shape = prefix + [extent]
            if len(shape) >= 2:
                helper(shape)
            if len(shape) < 5:
                sweep(shape)

    sweep([])

    for fixed_shape in ([9, 15, 4], [9, 3, 2], [3, 4, 18, 22], [3, 4, 18, 22, 150]):
        helper(fixed_shape)
def test_contiguous_slice_3d(self):
    """Multiply two single-element slices of a sliced 3-D tensor on MPS and compare with CPU."""
    full_mps = torch.randn(2, 3, 3, device="mps")
    full_cpu = full_mps.detach().clone().cpu()

    # Only the first entry along dim 0 is kept.
    first_mps = full_mps[:1]
    first_cpu = full_cpu[:1]

    def corner_product(t):
        return t[:, 0:1, 0:1] * t[:, 1:2, 1:2]

    self.assertEqual(corner_product(first_mps), corner_product(first_cpu))
def test_view_slice(self):
    """Gather elements one at a time through index tensors on MPS and compare with CPU."""
    # https://github.com/pytorch/pytorch/issues/83995
    NUM_SAMPLES = 60
    s = (0, 1)

    X = torch.rand(8000, 3, dtype=torch.float32, device='cpu')
    # The copy must live on MPS. The original code used .to("cpu") here, so
    # the MPS gather path was never exercised.
    X_mps = X.detach().clone().to("mps")

    idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s))
    pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1]))
    idx_mps = idx.to("mps")
    pts_mps = pts.to("mps")
    # Overwrite the columns listed in ``s`` with one repeated row index.
    pts[:, s] = idx
    pts_mps[:, s] = idx_mps

    actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float)
    actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")

    for i in range(NUM_SAMPLES):
        for j in range(X.shape[1]):
            actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j]
            actual_pts[i, j] = X[pts[i, j], j]
            self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j])
def test_slice_scatter(self):
    """Copying into a strided slice of another tensor must leave the source unchanged."""
    rows, cols = shape = (4, 4)
    source = torch.randint(10, shape, device="mps")
    snapshot = source.clone()

    # The destination is an every-other-column view of a wider tensor.
    destination = torch.empty(rows, cols * 2, device="mps")
    destination[:, ::2].copy_(source)

    torch.testing.assert_close(source, snapshot)
def test_slice(self):
    """Basic row, column and single-row slices on MPS must match CPU."""
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    ref = torch.tensor(rows, device='cpu')
    dev = (torch.tensor(rows, device='mps', dtype=torch.float))

    # first two rows
    self.assertEqual(ref[:2, :], dev[:2, :])

    # first column
    self.assertEqual(ref[:, :1], dev[:, :1])

    # middle row kept as a 2-D slice, compared after moving to CPU
    self.assertEqual(ref[1:2, :], dev[1:2, :].to('cpu'))

    # middle row as a 1-D tensor, moved to CPU before comparing
    middle_ref = ref[1, :]
    middle_dev = dev[1, :].to('cpu')
    self.assertEqual(middle_ref, middle_dev)
@parametrize("torch_type", arg_values=[torch.float16, torch.float32, torch.bfloat16])
def test_slice_view_api(self, torch_type: torch.dtype):
    """Apply chains of view and slice operations on MPS and compare with CPU.

    Each case applies ``y_func`` to the input, ``z_func`` to that result and
    optionally ``r_func`` to the second result. Values are compared after
    each step, and storage offsets after the second step.
    """
    def helper(x_tensor, y_func, z_func, r_func=None):
        x_mps = x_tensor.detach().clone().to("mps")

        y = y_func(x_tensor)
        y_mps = y_func(x_mps)
        self.assertEqual(y, y_mps)

        z = z_func(y)
        z_mps = z_func(y_mps)
        self.assertEqual(z, z_mps)
        self.assertEqual(z.storage_offset(), z_mps.storage_offset())

        if r_func:
            r = r_func(z)
            r_mps = r_func(z_mps)
            self.assertEqual(r, r_mps)

    # Skip bfloat16 before MacOS15
    if product_version < 15.0 and torch_type == torch.bfloat16:
        return

    # Tests for previously encountered MPS bugs
    helper(
        torch.randn(4, 4, dtype=torch_type),
        lambda x: x[1],
        lambda y: y.reshape(2, 2),
        lambda z: z + 1
    )
    helper(
        torch.randn(2, 4, dtype=torch_type),
        lambda x: x[1],
        lambda y: y + torch.ones(4, device=y.device)
    )
    helper(
        torch.randn(4, 6, dtype=torch_type),
        lambda x: x[1],
        lambda y: y.reshape(3, 2).t(),
        lambda z: z + 1
    )
    helper(
        # reshape replaces the deprecated non-inplace Tensor.resize.
        torch.arange(4, dtype=torch_type).reshape(1, 2, 2),
        lambda x: x.permute(2, 0, 1),
        lambda y: y + 1
    )
    helper(
        torch.randn(4, 8, dtype=torch_type),
        lambda x: x.transpose(0, 1).reshape(-1),
        lambda y: y[:2],
        lambda z: z + 1
    )
    helper(
        torch.randn(1, dtype=torch_type),
        lambda x: x.expand(2, 3),
        lambda y: y + torch.ones(2, 3, device=y.device)
    )
def test_slice_reshape_contiguous(self):
    """A row slice reshaped to 2x2 must match CPU in values and storage offset."""
    base_cpu = torch.randn(4, 4)
    base_mps = base_cpu.detach().clone().to("mps")

    row_cpu, row_mps = base_cpu[1], base_mps[1]
    self.assertEqual(row_cpu, row_mps)

    square_cpu = row_cpu.reshape(2, 2)
    square_mps = row_mps.reshape(2, 2)
    self.assertEqual(square_cpu, square_mps)
    self.assertEqual(square_cpu.storage_offset(), square_mps.storage_offset())
def test_scalar_from_slice_unary(self):
    """``torch.ceil`` on scalars obtained by iterating an MPS tensor must match CPU."""
    # https://github.com/pytorch/pytorch/issues/82543
    values = torch.tensor([1.0, 1.2], device="mps")
    for element in values:
        on_mps = torch.ceil(element)
        on_cpu = torch.ceil(element.to("cpu"))
        self.assertEqual(on_mps.cpu(), on_cpu)
def test_scalar_from_slice_binary(self):
    """Binary ops between a scalar taken from an MPS tensor and ``1.0`` must match CPU."""
    # https://github.com/pytorch/pytorch/issues/82543
    def helper(binary_op):
        values = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps")
        for element in values:
            on_mps = binary_op(element, 1.0)
            on_cpu = binary_op(element.cpu(), 1.0)
            self.assertEqual(on_mps.cpu(), on_cpu)

    for binary_op in (torch.sub, torch.add, torch.not_equal, torch.eq):
        helper(binary_op)
def test_slice_contiguous_view(self):
    """Comparison operators and ``torch.stack`` on contiguous slices must match CPU."""
    # https://github.com/pytorch/pytorch/issues/77750

    # Each entry receives (x, y) where x is the tail slice and y the head slice.
    # ">" maps to a strict comparison; the original branch evaluated ">=".
    operations = {
        "<=": lambda x, y: x <= y,
        "<": lambda x, y: x < y,
        ">=": lambda x, y: x >= y,
        ">": lambda x, y: x > y,
        "==": lambda x, y: x == y,
        "!=": lambda x, y: x != y,
        "stack": lambda x, y: torch.stack((y, x), dim=-1),
    }

    def helper(op_name):
        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
        t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")

        # contiguous view
        x_mps = t_mps[2:]  # 3, 4
        y_mps = t_mps[:2]  # 1, 2

        x_cpu = t_cpu[2:]
        y_cpu = t_cpu[:2]

        apply = operations[op_name]
        self.assertEqual(apply(x_mps, y_mps), apply(x_cpu, y_cpu))

    for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
        helper(op)
def test_slice_of_slice(self):
    """``ne`` on an indexed-then-unsqueezed element must match between MPS and CPU."""
    base_cpu = torch.tensor([0.5, 0.5], device="cpu")
    base_mps = torch.tensor([0.5, 0.5], device="mps")

    # Pick the second element, then add a leading dimension back.
    picked_cpu = base_cpu[1][None]
    picked_mps = base_mps[1][None]

    self.assertEqual(picked_cpu.ne(0), picked_mps.ne(0))
def test_index_storage_offset(self):
    """Views with different storage offsets must keep their own values when copied to MPS."""
    # https://github.com/pytorch/pytorch/issues/78107
    a = torch.tensor([8.2670e-01, -1.0293e+00])

    # both views come from 'a': the first has a storage offset of 0 and the
    # second a storage offset of 1. When copying from 'cpu' to 'mps' the
    # offset of the second view needs to be taken into account, otherwise it
    # ends up with the same value as the first.
    first_cpu, second_cpu = a[0], a[1]
    first_mps = first_cpu.to('mps')
    second_mps = second_cpu.to('mps')

    self.assertEqual(first_mps > second_mps, first_cpu > second_cpu)
    self.assertEqual(second_mps > first_mps, second_cpu > first_cpu)
def test_flatten(self):
    """``flatten`` with default, ``start_dim`` and ``end_dim`` arguments must match CPU."""
    nested = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    ref = torch.tensor(nested, device='cpu')
    dev = torch.tensor(nested, device='mps')

    for flatten_kwargs in ({}, {'start_dim': 1}, {'end_dim': 1}):
        flat_ref = ref.flatten(**flatten_kwargs)
        flat_dev = dev.flatten(**flatten_kwargs).to('cpu')
        self.assertEqual(flat_ref, flat_dev)
# Test repeat
def test_repeat(self):
    """Compare ``Tensor.repeat`` outputs and input gradients on MPS against CPU."""
    def helper(shape, repeats):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = ref_in.detach().clone().to('mps').requires_grad_()

        mps_out = mps_in.repeat(repeats)
        ref_out = ref_in.repeat(repeats)

        # Random upstream gradient shared by both backward passes.
        upstream = torch.randn(ref_out.shape)
        mps_out.backward(gradient=upstream.to('mps'))
        ref_out.backward(gradient=upstream)

        self.assertEqual(mps_out, ref_out)
        self.assertEqual(mps_in.grad, ref_in.grad)

    cases = [
        ((2, 3, 4, 5), (2, 3, 4, 5)),
        ((2, 3, 4), (4, 3, 2, 5, 7, 2)),
        ((3, 4, 5), (2, 3, 4, 5)),
        ((3, 4, 5), (2, 2, 2)),
    ]
    for shape, repeats in cases:
        helper(shape, repeats)
def test_torch_repeat_interleave(self, device="mps"):
    """Check result dtype and size of ``repeat_interleave`` with tensor repeats.

    Covers the single-argument method form and the ``dim=0`` form with and
    without an explicit ``output_size``, for int and long repeat counts.
    """
    y = torch.tensor([[1, 2], [3, 4]], device=device)
    # exercise single argument function signature
    temp = y.repeat_interleave(2)
    self.assertEqual(torch.Size([8]), temp.size())

    for dtype in [torch.int, torch.long]:
        # Build the repeat counts on the requested device; the original
        # hard-coded "mps" here and ignored the ``device`` parameter.
        lengths = torch.tensor([1, 2], dtype=dtype, device=device)
        output_size = torch.sum(lengths)

        a = torch.repeat_interleave(
            y,
            lengths,
            dim=0,
        )
        self.assertEqual(a.dtype, y.dtype)
        self.assertEqual(a.size(), torch.Size([3, 2]))

        a_with_output = torch.repeat_interleave(
            y,
            lengths,
            dim=0,
            output_size=output_size,
        )
        self.assertEqual(a_with_output.dtype, y.dtype)
        self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
def test_repeat_interleave(self, device="mps"):
    """Check ``torch.repeat_interleave`` results and error cases on the given device."""
    counts = torch.tensor([0, 1, 2, 3], device=device)
    expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
    # Prior to macos 13.3, input of dtype=torch.int64 returns dtype=torch.int32
    self.assertEqual(torch.repeat_interleave(counts), expected, exact_dtype=product_version >= 13.3)

    # Single-argument calls with a 2-D, a floating point and a
    # negative-valued input must all raise.
    invalid_inputs = (
        lambda: torch.arange(4, device=device).reshape(2, 2),
        lambda: torch.arange(4.0, device=device),
        lambda: torch.tensor([1, 2, -1, 3, 4], device=device),
    )
    for make_invalid in invalid_inputs:
        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(make_invalid())

    y = torch.tensor([[1, 2], [3, 4]], device=device)

    # An int, a 0-d tensor and a 1-element tensor must behave the same.
    flat_results = [
        torch.repeat_interleave(y, 2),
        torch.repeat_interleave(y, torch.tensor(2, device=device)),
        torch.repeat_interleave(y, torch.tensor([2], device=device)),
    ]
    flat_expected = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
    for flat in flat_results:
        self.assertEqual(flat, flat_expected)

    along_columns = torch.repeat_interleave(y, 3, dim=1)
    along_columns_expected = torch.tensor([[1, 1, 1, 2, 2, 2],
                                           [3, 3, 3, 4, 4, 4]], device=device)
    self.assertEqual(along_columns, along_columns_expected)

    along_rows = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
    along_rows_expected = torch.tensor([[1, 2],
                                        [3, 4],
                                        [3, 4]], device=device)
    self.assertEqual(along_rows, along_rows_expected)

    # A repeats tensor of the wrong length or rank must raise.
    with self.assertRaises(RuntimeError):
        torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0)

    with self.assertRaises(RuntimeError):
        torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0)

    # test zero sized dimension
    empty_cols = torch.zeros((5, 0), device=device)
    repeated = torch.repeat_interleave(empty_cols, repeats=3, dim=1)
    self.assertEqual(repeated, empty_cols.new_zeros(5, 0, device=device))

    empty_long = torch.tensor([], dtype=torch.int64, device=device)
    self.assertEqual(torch.repeat_interleave(empty_long, empty_long), empty_long)
def test_repeat_interleave_simple(self):
    """Compare ``torch.repeat_interleave`` with tensor repeats on MPS against CPU."""
    def helper(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None):
        data_mps = torch.randn(shape, dtype=dtype, device="mps")
        data_cpu = data_mps.detach().clone().cpu()
        repeats_cpu = num_repeats.detach().clone().cpu()

        result_mps = torch.repeat_interleave(data_mps, num_repeats, dim)
        result_cpu = torch.repeat_interleave(data_cpu, repeats_cpu, dim)

        self.assertEqual(result_mps, result_cpu)

    helper(shape=3, num_repeats=torch.tensor([100], device="mps"))
    helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
    helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
    helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
    helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)
def test_count_nonzero(self):
    """Compare ``torch.count_nonzero`` on MPS against CPU for several dtypes and dims."""
    def helper(dtype):
        values = [
            [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
            [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
        ]
        ref = torch.tensor(values, dtype=dtype)
        dev = torch.tensor(values, dtype=dtype).to('mps')

        # None counts over all elements, then dim=1, then dim=(0, 1).
        for dim in (None, 1, (0, 1)):
            self.assertEqual(
                torch.count_nonzero(ref, dim=dim),
                torch.count_nonzero(dev, dim=dim)
            )

    for dtype in (torch.int32, torch.int64, torch.float16, torch.float32):
        helper(dtype)
def _test_module_empty_input(self, module, inp, check_size=True):
    """Run ``module`` forward and backward on ``inp`` and require all-zero gradients.

    When ``check_size`` is true the output size must equal the input size.
    The gradient of every trainable parameter and of ``inp`` is compared
    against zeros.
    """
    inp.requires_grad_(True)
    result = module(inp)
    upstream = torch.rand_like(result)
    result.backward(upstream)

    if check_size:
        self.assertEqual(result.size(), inp.size())

    trainable = (param for param in module.parameters() if param.requires_grad)
    for param in trainable:
        self.assertEqual(param.grad, torch.zeros_like(param.grad))

    self.assertEqual(inp.grad, torch.zeros_like(inp))
# Test dtype casting, with and without simultaneous device change
def test_to(self):
    """Check dtype casts on MPS, with and without a simultaneous device change."""
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    self.assertEqual(cpu_x.int(), mps_x.int().cpu())
    self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
    self.assertEqual(cpu_x.float(), mps_x.float().cpu())

    # Scalar casts on MPS: the fractional part is dropped for int, and any
    # nonzero value becomes True for bool.
    self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
                     torch.tensor(1, dtype=torch.int32))
    self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
                     torch.tensor(1, dtype=torch.int32))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
                     torch.tensor(1.0))

    # Cast combined with a device move, in both directions.
    self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
                     torch.tensor(4, dtype=torch.int32))
    self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
                     torch.tensor(4, dtype=torch.int32))
    self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                     torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))

    # Cast int8 and uint8 to float and compare results
    # See https://github.com/pytorch/pytorch/issues/80009 for more details
    cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8)
    # The signed values need int8. The original built this tensor as uint8,
    # so the int8 path named above was never tested.
    cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.int8)
    for x_cpu in [cpu_byte, cpu_char]:
        x_mps = x_cpu.to('mps')
        self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32))
def test_setitem_scalar(self) -> None:
    """Assign scalars into single elements of MPS tensors and read them back."""
    device = 'mps'
    sizes = range(3, 6)
    for dtype, rows, cols in product([torch.int32, torch.float32, torch.int64], sizes, sizes):
        grid = torch.zeros(rows, cols, dtype=dtype, device=device)
        self.assertEqual(grid.sum(), 0)

        grid[1, 1] = 1
        grid[2, 1] = cols
        grid[1, 2] = rows

        self.assertEqual(grid[1, 1], 1)
        self.assertEqual(grid[1, 2], rows)
        self.assertEqual(grid[2, 1], cols)
        # Only the three written cells contribute to the total.
        self.assertEqual(grid.sum(), 1 + rows + cols)
def test_stride_of_strides(self) -> None:
    """Copy a strided view of a strided MPS view to CPU and compare with the CPU equivalent."""
    x = torch.rand(32, 1, device='mps')
    y = x.as_strided(size=(32, 2), stride=(1, 0))
    # Casting stride of strided tensor to CPU used to crash with "buffer is not large enough." assert
    # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
    restrided = y.as_strided(size=(32, 3), stride=(1, 0))
    z = restrided.to("cpu")
    expected = x.to("cpu").as_strided(size=(32, 3), stride=(1, 0))
    self.assertEqual(expected, z)
def test_type_casting(self):
    """Check that ``Tensor.type(<legacy tensor type>)`` agrees between CPU and MPS."""
    # https://github.com/pytorch/pytorch/issues/81567
    data = [9.0, 3.0, 5.0, 4.0]
    tensor_types = (
        torch.LongTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
    )
    for tensor_type in tensor_types:
        cpu_src = torch.tensor(data)
        mps_src = cpu_src.to(torch.device('mps'))
        cpu_cast = cpu_src.type(tensor_type)
        mps_cast = mps_src.type(tensor_type)
        self.assertEqual(cpu_cast, mps_cast)
def test_to_casting(self):
    """Check that ``Tensor.to(dtype)`` agrees between CPU and MPS."""
    # https://github.com/pytorch/pytorch/issues/81567
    data = [9.0, 3.0, 5.0, 4.0]
    target_dtypes = (
        torch.int64,
        torch.float,
        torch.int32,
        torch.short,
        torch.half,
        torch.int8,
        torch.uint8,
    )
    for target_dtype in target_dtypes:
        cpu_src = torch.tensor(data)
        mps_src = cpu_src.to(torch.device('mps'))
        cpu_cast = cpu_src.to(target_dtype)
        mps_cast = mps_src.to(target_dtype)
        self.assertEqual(cpu_cast, mps_cast)
def test_storage_offset_greater_than_src_nbytes(self):
    """Round-trip CPU slice views with growing storage offsets through MPS."""
    # https://github.com/pytorch/pytorch/issues/80844
    n_tensors = 100
    n_tensor_elems = 784
    elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)

    # create a list of contiguous view tensors (view tensor created by the slice op)
    slices = [
        elems[n_tensor_elems * idx : n_tensor_elems * (idx + 1)]
        for idx in range(0, n_tensors - 1)
    ]

    for i, chunk in enumerate(slices):
        t = chunk.view(1, n_tensor_elems)
        t_mps = t.to("mps")
        self.assertEqual(t, t_mps.cpu(), f"i={i}")
# See https://github.com/pytorch/pytorch/issues/82427
# and https://github.com/pytorch/pytorch/issues/83692
def test_full_bugs(self):
    """Exercise ``torch.full`` on MPS with a bool fill value and with uint8 output."""
    # Test should not crash
    x = torch.full((3, 3), True, device='mps')

    # torch.full should work for uint8
    shape, fill_value = (2, 2), 247
    y_mps = torch.full(shape, fill_value, device='mps', dtype=torch.uint8)
    y_cpu = torch.full(shape, fill_value, device='cpu', dtype=torch.uint8)
    self.assertEqual(y_mps, y_cpu)
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
# See https://github.com/pytorch/pytorch/issues/84995
def test_div_bugs(self):
    """Dividing 1..10 by 101 with trunc/floor rounding must give all zeros for integral dtypes."""
    for dtype, mode in itertools.product(integral_types(), ['trunc', 'floor']):
        # int64 is excluded from this check.
        if dtype == torch.int64:
            continue
        numerators = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
        quotients = torch.div(numerators, 101, rounding_mode=mode)
        self.assertEqual(quotients.sum(), 0)
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):
    """Expanding a bool column and a bool row to 2x2 must give different tensors."""
    column = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
    row = torch.tensor([0, 1], dtype=torch.bool, device='mps')
    expanded_column = column.expand(2, 2)
    expanded_row = row.expand(2, 2)
    self.assertFalse(torch.equal(expanded_column, expanded_row))
def test_int_expand(self):
    """Expanding an int8 column and an int8 row to 2x2 must give different tensors."""
    column = torch.tensor([[1], [0]], dtype=torch.int8, device='mps')
    row = torch.tensor([0, 1], dtype=torch.int8, device='mps')
    expanded_column = column.expand(2, 2)
    expanded_row = row.expand(2, 2)
    self.assertFalse(torch.equal(expanded_column, expanded_row))
# Empty unary op should return tensor of the same size
def test_empty_neg(self):
    """Negating an empty MPS tensor must compare equal to the input."""
    empty = torch.tensor([[]], device='mps')
    negated = -empty
    self.assertEqual(empty, negated)
def _test_unique_scalar_empty(self, dtype, device, f):
    """Check a unique-style callable on a 0-dim tensor and on a zero-sized tensor.

    ``f`` is called with ``return_inverse=True`` and ``return_counts=True`` and
    must return ``(unique, inverse, counts)``.
    """
    # test scalar
    scalar_in = torch.tensor(0, dtype=dtype, device=device)
    unique, inverse, counts = f(scalar_in, return_inverse=True, return_counts=True)
    self.assertEqual(unique, torch.tensor([0], dtype=dtype, device=device))
    self.assertEqual(inverse, torch.tensor(0, device=device))
    self.assertEqual(counts, torch.tensor([1], device=device))

    # test zero sized tensor
    empty_shape = (0, 0, 3)
    empty_in = torch.zeros(empty_shape, dtype=dtype, device=device)
    unique, inverse, counts = f(empty_in, return_inverse=True, return_counts=True)
    self.assertEqual(unique, torch.tensor([], dtype=dtype, device=device))
    self.assertEqual(inverse, torch.empty(empty_shape, dtype=torch.long, device=device))
    self.assertEqual(counts, torch.tensor([], dtype=torch.long, device=device))
def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
    """Compare the outputs of a unique-style callable ``f`` with expected tensors.

    Every combination of ``return_inverse`` / ``return_counts`` is exercised on
    ``x``; in addition ``x`` viewed as ``additional_shape`` is checked with both
    flags enabled.
    """
    for return_inverse, return_counts in product([True, False], repeat=2):
        # test with expected
        outputs = f(x, return_inverse=return_inverse, return_counts=return_counts)
        # A lone tensor is returned when both flags are off; normalise to a tuple.
        ret = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
        self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
        self.assertEqual(expected_unique, ret[0])
        if return_inverse:
            self.assertEqual(expected_inverse, ret[1])
        if return_counts:
            # counts come last, after the optional inverse
            self.assertEqual(expected_counts, ret[1 + int(return_inverse)])

        # tests per-element unique on a higher rank tensor.
        reshaped = x.view(additional_shape)
        y_unique, y_inverse, y_counts = f(reshaped, return_inverse=True, return_counts=True)
        self.assertEqual(expected_unique, y_unique)
        self.assertEqual(expected_inverse.view(additional_shape), y_inverse)
        self.assertEqual(expected_counts, y_counts)
def test_unique_all_dtypes(self, device="mps"):
    """Run sorted and unsorted ``unique`` for several dtypes on contiguous and strided inputs."""

    def as_tuple(result):
        # ``unique`` returns a bare tensor when no extra outputs are requested.
        return (result,) if isinstance(result, torch.Tensor) else result

    def check_dtype(dtype):
        if dtype is torch.bool:
            x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
            expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
            expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
            expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
        else:
            x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
            expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
            expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
            expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)

        # Same values stored in every other slot of a buffer twice as long.
        x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x)
        inputs = (x, x_sliced)

        # test sorted unique
        sorted_fns = (
            lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
            lambda x, **kwargs: x.unique(sorted=True, **kwargs),
        )
        for fn, inp in product(sorted_fns, inputs):
            self._test_unique_with_expects(device, dtype, fn, inp, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
            self._test_unique_scalar_empty(dtype, device, fn)

        # test unsorted unique
        unsorted_fns = (
            lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
            lambda x, **kwargs: x.unique(sorted=False, **kwargs)
        )
        for fn, inp in product(unsorted_fns, inputs):
            self._test_unique_scalar_empty(dtype, device, fn)
            for return_inverse, return_counts in product((True, False), repeat=2):
                ret = as_tuple(fn(inp, return_inverse=return_inverse, return_counts=return_counts))
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                inp_list = inp.tolist()
                unique_list = ret[0].tolist()
                # Order is unspecified here, so compare after sorting.
                self.assertEqual(expected_unique.tolist(), sorted(unique_list))
                if return_inverse:
                    # Each inverse index must map back to the original element.
                    for pos, unique_idx in enumerate(ret[1].tolist()):
                        self.assertEqual(inp_list[pos], unique_list[unique_idx])
                if return_counts:
                    counts_list = ret[1 + int(return_inverse)].tolist()
                    # Each reported count must match the number of occurrences in the input.
                    for value, reported in zip(unique_list, counts_list):
                        self.assertEqual(reported, inp_list.count(value))

    for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]:
        check_dtype(dtype)
def test_unique(self):
    """Compare ``torch.unique`` on MPS with the CPU result for several flag combinations."""

    def check(cpu_x, return_inverse, return_counts):
        mps_x = cpu_x.detach().clone().to('mps')
        flags = {"return_inverse": return_inverse, "return_counts": return_counts}
        result = torch.unique(mps_x, **flags)
        result_cpu = torch.unique(cpu_x, **flags)
        self.assertEqual(result, result_cpu)

    check(torch.tensor([1, 2, 4, 2, 1]), False, False)
    # (input length, return_inverse, return_counts) for random inputs in [0, 3)
    random_cases = (
        (10, False, False),
        (10, True, False),
        (10, False, True),
        (10, True, True),
        (1, True, True),
        (0, True, True),
    )
    for length, return_inverse, return_counts in random_cases:
        check(torch.randint(3, (length, )), return_inverse, return_counts)

    # Regression test for https://github.com/pytorch/pytorch/issues/104879
    x = torch.arange(2, device="mps")
    self.assertEqual(x.reshape(1, 1, 2).unique(), x)
def test_unique_consecutive(self):
    """Compare ``torch.unique_consecutive`` on MPS with the CPU result along dims 0 and 1."""

    def check(cpu_x, dim, return_inverse, return_counts):
        mps_x = cpu_x.detach().clone().to('mps')
        kwargs = {"dim": dim, "return_inverse": return_inverse, "return_counts": return_counts}
        result = torch.unique_consecutive(mps_x, **kwargs)
        result_cpu = torch.unique_consecutive(cpu_x, **kwargs)
        self.assertEqual(result, result_cpu)

    fixed_2d = [[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]

    # 1-D inputs
    check(torch.tensor([1, 2, 4, 2, 1]), 0, False, False)
    check(torch.randint(3, (10, )), 0, False, False)
    check(torch.randint(3, (10, )), 0, True, False)
    check(torch.randint(3, (10, )), 0, False, True)
    check(torch.randint(3, (10, )), 0, True, True)
    check(torch.randint(3, (10, )), 0, True, True)
    check(torch.randint(3, (1, )), 0, True, True)
    check(torch.randint(3, (0, )), 0, True, True)

    # 2-D inputs, dim 0
    check(torch.tensor(fixed_2d), 0, False, False)
    check(torch.tensor(fixed_2d), 0, True, True)
    for shape in ((20, 2), (1, 2), (0, 2)):
        check(torch.randint(2, shape), 0, True, True)

    # 2-D inputs, dim 1
    check(torch.tensor(fixed_2d), 1, False, False)
    check(torch.tensor(fixed_2d), 1, True, True)
    for shape in ((2, 20), (2, 1), (2, 0)):
        check(torch.randint(2, shape), 1, True, True)
# See https://github.com/pytorch/pytorch/issues/85675
def test_cat_non_contiguous(self):
    """Concatenate non-contiguous slices of a channels_last tensor along every dim."""

    def split_and_concat(tensor, dim):
        head = tensor[:, :, :2, :]
        tail = tensor[:, :, 2:, :]
        self.assertFalse(head.is_contiguous())
        self.assertFalse(tail.is_contiguous())
        return torch.concat((head, tail), dim=dim)

    for dtype in MPS_DTYPES:
        if dtype == torch.bool:
            continue
        data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
        data = data.to(memory_format=torch.channels_last)
        mps_data = data.to("mps")
        self.assertEqual(data, mps_data)
        for dim in range(data.dim()):
            cpu_result = split_and_concat(data, dim)
            mps_result = split_and_concat(mps_data, dim)
            self.assertEqual(cpu_result, mps_result.to("cpu"))
            # TODO: enable memory format test
            # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
# See https://github.com/pytorch/pytorch/issues/85967
def test_from_numpy_non_contiguous(self):
    """Build CPU and MPS tensors from a column-sliced NumPy array and compare them."""
    sliced = np.arange(9).reshape(3, 3)[:, :2]
    from_cpu = torch.tensor(sliced, device="cpu")
    from_mps = torch.tensor(sliced, device="mps")
    self.assertEqual(from_cpu, from_mps.to("cpu"))
# See https://github.com/pytorch/pytorch/issues/86954
def test_copy_non_contiguous(self):
    """Copy permuted and strided tensors between CPU and MPS."""
    # Permuted tensor: the layout must survive the move to MPS.
    permuted = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
    self.assertFalse(permuted.is_contiguous())
    permuted_mps = permuted.to('mps')
    self.assertFalse(permuted_mps.is_contiguous())
    self.assertEqual(permuted, permuted_mps.to('cpu'))

    # Permuted, then sliced with an offset and a step.
    strided = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
    strided_mps = strided.to('mps')
    self.assertEqual(strided, strided_mps.to('cpu'))

    # Write a strided CPU source into strided views of CPU and MPS destinations.
    dst_cpu = torch.full((4, 4, 4, 4), 13, device="cpu")
    dst_mps = torch.full((4, 4, 4, 4), 13, device="mps")
    src = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
    dst_cpu.permute(3, 2, 1, 0)[1::, ::2] = src
    # As dst_mps is on MPS and src on CPU, this dispatches to a copy operator
    dst_mps.permute(3, 2, 1, 0)[1::, ::2] = src
    self.assertEqual(dst_cpu, dst_mps.to('cpu'))
# See https://github.com/pytorch/pytorch/issues/95417
def test_copy_storage_offset(self):
    """Assign int64 updates into a slice of float32 tensors on CPU and MPS."""
    base = {dev: torch.zeros(5, device=dev, dtype=torch.float32) for dev in ("cpu", "mps")}
    update = {dev: torch.tensor([1, 1], device=dev, dtype=torch.int64) for dev in ("cpu", "mps")}
    x_cpu, x_mps = base["cpu"], base["mps"]

    x_cpu[2:4] = update["cpu"]
    x_mps[2:4] = update["mps"]  # implicit type casting and copy
    self.assertEqual(x_cpu, x_mps)

    x_cpu[2:4] = update["mps"]  # implicit device moving and copy
    self.assertEqual(x_cpu, x_mps)
def test_copy_broadcasting(self):
    """Check ``copy_`` with broadcasting and dtype conversion on MPS against CPU."""

    def check(src_shape, dst_shape, src_dtype, dst_dtype):
        cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)
        cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
        cpu_result = cpu_dst.copy_(cpu_src)
        # The MPS destination is taken from cpu_dst after the CPU copy above.
        mps_result = cpu_dst.to("mps").copy_(cpu_src.to("mps"))
        self.assertEqual(cpu_result, mps_result)

    # (source shape, destination shape); the last four have zero-sized dims.
    shape_pairs = (
        ((2, 1), (2, 3)),
        ((2, 1), (2, 2)),
        ((3, 1, 4, 1), (3, 4, 4, 5)),
        ((3,), (2, 3)),
        ((2,), (2, 2)),
        ((4, 1, 5), (3, 4, 4, 5)),
        ((4, 1, 5), (4, 0, 5)),
        ((1, 5), (4, 0, 5)),
        ((3, 1, 0), (3, 5, 0)),
        ((0, 1, 0), (0, 5, 0)),
    )
    test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]
    for src_dtype, dst_dtype in itertools.product(test_dtypes, test_dtypes):
        for src_shape, dst_shape in shape_pairs:
            check(src_shape, dst_shape, src_dtype, dst_dtype)

    # Regression test for https://github.com/pytorch/pytorch/issues/107867
    self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)
# See https://github.com/pytorch/pytorch/pull/84742
# and https://github.com/pytorch/pytorch/pull/78319
def test_binops_dtype_precedence(self):
    """Test dtype precedence (casting order) in binary operations against the CPU result.

    Each pair of sample values is combined with add/sub/mul/div in several
    operand layouts: 0-dim tensors, one-element vectors and ``torch.full``
    tensors.
    """
    # Example values for all dtypes supported on the MPS backend
    sample_vals = {
        torch.bool: [False, True],
        torch.int16: [-15, 0, 1, 10],
        torch.int32: [-376, 0, 1, 13],
        torch.int64: [-8, 0, 1, 77],
        torch.float16: [-234.5, 0.0, 1.0, 2.0],
        torch.float32: [-1.0, 0.0, 0.1, 111.99],
    }
    full_shape = (10,)

    def scalar(val, dtype, device):
        return torch.tensor(val, dtype=dtype, device=device)

    def vector(val, dtype, device):
        return torch.tensor([val], dtype=dtype, device=device)

    def filled(val, dtype, device):
        return torch.full(full_shape, val, dtype=dtype, device=device)

    # (left operand builder, right operand builder), in the order they are checked.
    layouts = (
        (scalar, scalar),
        (vector, vector),
        (scalar, vector),
        (vector, scalar),
        # Test tensors created with torch.full
        (filled, scalar),
        (scalar, filled),
        (scalar, filled),
    )

    # Test all combinations of dtypes, operations, dimensionality
    for dtype1, dtype2, binop in itertools.product(
            sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul', 'div']):
        # bool minus bool is generally unsupported, so skip
        if binop == 'sub' and torch.bool in (dtype1, dtype2):
            continue
        for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
            for make_lhs, make_rhs in layouts:
                mps_out = getattr(make_lhs(val1, dtype1, 'mps'), binop)(make_rhs(val2, dtype2, 'mps'))
                cpu_out = getattr(make_lhs(val1, dtype1, 'cpu'), binop)(make_rhs(val2, dtype2, 'cpu'))
                self.assertEqual(mps_out, cpu_out)
def test_nansum(self):
    """Check ``torch.nansum`` on MPS (functional and ``out=`` forms) against CPU."""

    def check(dtype, noncontiguous, dim):
        zero_cpu = torch.zeros((), dtype=dtype)

        # Randomly scale the values
        scale = random.randint(10, 100)
        x_cpu: torch.Tensor = make_tensor(
            (5, 5), dtype=dtype, device='cpu',
            low=-scale, high=scale, noncontiguous=noncontiguous)

        if not dtype.is_floating_point:
            # Integer inputs are used as-is; no NaNs are written into them.
            x_no_nan_cpu = x_cpu
        else:
            # Build a NaN-free reference with zeros in the masked slots, then
            # write NaN into those same slots of the input.
            nan_mask_cpu = x_cpu < (0.2 * scale)
            x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
            x_cpu[nan_mask_cpu] = np.nan

        x_mps = x_cpu.to('mps')
        actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
        expect_out_cpu = torch.empty(0, dtype=dtype)
        dim_kwargs = {} if dim is None else {"dim": dim}
        expect = torch.sum(x_no_nan_cpu, **dim_kwargs)

        # Sanity check on CPU
        actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
        self.assertEqual(expect, actual_cpu)

        # Test MPS
        actual_mps = torch.nansum(x_mps, **dim_kwargs)
        # Test out= variant
        torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
        torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
        self.assertEqual(expect, actual_mps)
        self.assertEqual(expect_out_cpu, actual_out_mps)

    all_dtypes = (torch.float16, torch.float32, torch.int32, torch.int64)
    all_layouts = (True, False)  # noncontiguous
    all_dims = (0, 1, None)
    for dtype, noncontiguous, dim in itertools.product(all_dtypes, all_layouts, all_dims):
        with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
            check(dtype, noncontiguous, dim)
def test_cumsum_all_dtypes(self):
    """Run ``cumsum`` with an explicit output dtype on MPS and compare with CPU."""

    def check(dtype):
        ones_mps = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
        ones_cpu = torch.tensor([1, 1, 1, 1], device="cpu")
        out_mps = ones_mps.cumsum(0, dtype=dtype)
        out_cpu = ones_cpu.cumsum(0, dtype=dtype)
        self.assertEqual(out_mps.cpu(), out_cpu)

    for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]:
        check(dtype)

    # int64 either works or fails with this specific message.
    try:
        check(torch.int64)
    except Exception as e:
        message = str(e)
        self.assertEqual(message, "MPS does not support cumsum_out_mps op with int64 input." +
                         " Support has been added in macOS 13.3")
def test_cumsum_bool(self):
    """Cumulative sum of 2**16 ``True`` values must match between CPU and MPS."""
    flags = torch.ones(2**16, dtype=torch.bool)
    expected = flags.cumsum(0)
    actual = flags.to("mps").cumsum(0)
    self.assertEqual(expected, actual)
def test_cumsum_minus_one_axis(self):
    """Compare ``cumsum`` along the last axis (``dim=-1``) between MPS and CPU.

    The comparison is repeated for float32 and for several integer dtypes.
    """

    def helper(dtype):
        # Test with axis -1
        if dtype == torch.float32:
            cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
        else:
            # Build the input in the requested dtype; hard-coding float32 here
            # would leave the integer dtypes in the list below untested.
            cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        cpu_y = cpu_x.cumsum(-1)
        y = x.cumsum(-1)
        self.assertEqual(y, cpu_y)

    for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]:
        helper(dtype)
def test_cumprod_all_dtypes(self):
    """Run ``cumprod`` with an explicit output dtype on MPS and compare with CPU."""

    def check(dtype):
        ones_mps = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
        ones_cpu = torch.tensor([1, 1, 1, 1], device="cpu")
        out_mps = ones_mps.cumprod(0, dtype=dtype)
        out_cpu = ones_cpu.cumprod(0, dtype=dtype)
        self.assertEqual(out_mps.cpu(), out_cpu)

    for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]:
        check(dtype)

    # int64 either works or fails with this specific message.
    try:
        check(torch.int64)
    except Exception as e:
        message = str(e)
        self.assertEqual(message, "MPS does not support cumprod_out_mps op with int64 input."
                         + " Support has been added in macOS 13.3")
def test_cumprod_minus_one_axis(self):
    """Compare ``cumprod`` along the last axis (``dim=-1``) between MPS and CPU.

    The comparison is repeated for float32 and for several integer dtypes.
    """

    def helper(dtype):
        # Test with axis -1
        if dtype == torch.float32:
            cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
        else:
            # Build the input in the requested dtype; hard-coding float32 here
            # would leave the integer dtypes in the list below untested.
            cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        cpu_y = cpu_x.cumprod(-1)
        y = x.cumprod(-1)
        self.assertEqual(y, cpu_y)

    for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]:
        helper(dtype)
def test_median_int16(self):
    """Median of a random int16 tensor must match between MPS and CPU."""

    def check(shape, dtype):
        cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')
        mps_median = torch.median(mps_x)
        cpu_median = torch.median(cpu_x)
        self.assertEqual(mps_median, cpu_median)

    check((2, 8, 4, 5), torch.int16)
def test_activation_checkpoint_does_not_error(self):
    """Run forward and backward through ``checkpoint`` on MPS in both reentrant modes."""
    from torch.utils.checkpoint import checkpoint

    def fn(x):
        return x.sin().cos().exp()

    for use_reentrant in (True, False):
        a = torch.tensor(1., device="mps", requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=use_reentrant)
        out.backward()
def test_as_strided(self):
    """Compare ``torch.as_strided`` views, with and without storage offsets, on MPS and CPU."""
    size, stride = (2, 2), (1, 2)

    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    values_1 = [[1.0, 1.0], [1.0, 1.0]]
    cpu_x = torch.tensor(values, device='cpu')
    ones1 = torch.tensor(values_1, device='mps')
    x = cpu_x.detach().clone().to('mps').requires_grad_()

    strided_cpu = torch.as_strided(cpu_x, size, stride)
    strided_mps = torch.as_strided(x, size, stride)
    self.assertEqual(strided_mps, strided_cpu)

    # The strided views must also work as operands of further ops.
    strided_cpu_out = strided_cpu + ones1.to('cpu')
    strided_mps_out = strided_mps + ones1
    self.assertEqual(strided_cpu_out, strided_mps_out)

    # test with storage offsets
    cpu_x = torch.rand(3, 3, device='cpu')
    mps_x = cpu_x.to('mps')
    strided_cpu1 = torch.as_strided(cpu_x, size, stride, 0)
    strided_mps1 = torch.as_strided(mps_x, size, stride, 0)
    strided_cpu2 = torch.as_strided(cpu_x, size, stride, 1)
    strided_mps2 = torch.as_strided(mps_x, size, stride, 1)
    strided_cpu_out = strided_cpu1 - strided_cpu2
    strided_mps_out = strided_mps1 - strided_mps2
    self.assertEqual(strided_cpu_out, strided_mps_out)
def test_unfold(self):
    """Size-2, step-1 windows over ``arange(1., 8)`` must match between CPU and MPS."""
    cpu_seq = torch.arange(1., 8)
    mps_seq = torch.arange(1., 8, device="mps")
    cpu_windows = cpu_seq.unfold(0, 2, 1)
    mps_windows = mps_seq.unfold(0, 2, 1)
    self.assertEqual(cpu_windows, mps_windows)
def test_unfold_all_devices_and_dtypes(self):
    """Check the output shape of ``unfold`` on zero-sized MPS tensors of several dtypes."""
    supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
    expected_shape = (0, 1, 1, 0, 3)
    for dt in supported_dtypes:
        empty = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
        unfolded = empty.unfold(2, 3, 2)
        self.assertEqual(expected_shape, unfolded.shape)
def test_unfold_scalars(self):
    """Check ``unfold`` on a 0-dimensional MPS tensor."""
    scalar = torch.tensor(0.5, device="mps")
    # unfold on a 0-dimensional tensor should always return a 1-d dimensional
    # tensor of shape [size] (i.e., the second parameter to unfold)
    for step in (1, 2):
        self.assertEqual(torch.empty(0, device="mps"), scalar.unfold(0, 0, step))
    self.assertEqual(torch.tensor([0.5], device="mps"), scalar.unfold(0, 1, 1))
def test_bincount_simple(self):
    """Compare unweighted and weighted ``bincount`` on MPS with the CPU result."""
    indices = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
    indices_cpu = indices.to("cpu")
    weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
    weights_cpu = weights.to("cpu")

    # Unweighted, via the functional form.
    counts = torch.bincount(indices)
    counts_cpu = torch.bincount(indices_cpu)
    self.assertEqual(counts, counts_cpu)

    # Weighted, via the tensor method.
    weighted = indices.bincount(weights)
    weighted_cpu = indices_cpu.bincount(weights_cpu)
    self.assertEqual(weighted, weighted_cpu)
def test_bincount_reduction(self):
    """Exercise ``bincount`` on MPS: error cases, empty input, minlength, weights and strided inputs."""
    device = "mps"
    non_negative_1d = '1-d non-negative integral'

    # negative input throws
    with self.assertRaisesRegex(RuntimeError, non_negative_1d):
        torch.bincount(torch.tensor([1, -1], device=device, dtype=torch.int32))
    # n-d input, with n > 1 throws
    with self.assertRaisesRegex(RuntimeError, non_negative_1d):
        torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
    # minlength < 0 throws
    with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
        torch.bincount(torch.tensor([1, 3], device=device),
                       torch.tensor([.2, .2], device=device),
                       minlength=-1)
    # n-d weights, with n > 1 throws
    with self.assertRaisesRegex(RuntimeError, '1-d'):
        torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                       torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float))
    # input and weights dim mismatch
    with self.assertRaisesRegex(RuntimeError, 'same length'):
        torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                       torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float))

    # 1-d input with no elements and default minlength
    empty_counts = torch.bincount(torch.tensor([], device=device, dtype=torch.long))
    self.assertEqual(empty_counts, torch.zeros(0, dtype=torch.long, device=device))
    # 1-d input with no elements and specified minlength
    empty_counts_min = torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10)
    self.assertEqual(empty_counts_min, torch.zeros(10, dtype=torch.long, device=device))

    # test tensor method without weights
    long_counts = torch.tensor(
        [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
    self.assertEqual(
        torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
        long_counts)

    # test avoiding overflow for uint8 (#76979)
    overflow_values = [0, 1, 2, 3, 255]
    count_uint8 = torch.tensor(overflow_values, dtype=torch.uint8, device=device).bincount()
    count_int16 = torch.tensor(overflow_values, dtype=torch.int16, device=device).bincount()
    self.assertEqual(count_uint8, count_int16)

    # test minlength functionality
    int_counts = torch.bincount(
        torch.tensor([1, 1, 1, 1], device=device, dtype=torch.int32), minlength=5)
    self.assertEqual(
        torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
        int_counts)

    # test weights
    weighted_indices = [0, 1, 1, 1, 4]
    float_weighted = torch.bincount(
        torch.tensor(weighted_indices, device=device, dtype=torch.int32),
        torch.tensor([.1, .2, .3, .4, .5], device=device))
    self.assertEqual(
        torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), float_weighted)
    int8_weighted = torch.bincount(
        torch.tensor(weighted_indices, device=device, dtype=torch.int32),
        torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
    self.assertEqual(
        torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), int8_weighted)

    # test non-contiguous inputs and weights
    inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device, dtype=torch.int32)
    weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
    for col in [0, 1]:
        assert not inputs[:, col].is_contiguous(), "Inputs are supposed to be non-contiguous"
        assert not weights[:, col].is_contiguous(), "Weights are supposed to be non-contiguous"
    # inputs are non-contiguous but weights are contiguous
    self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
    # inputs and weights are non-contiguous
    expected_weighted = [1, 9, 0, 0, 5]
    self.assertEqual(
        inputs[:, 1].bincount(weights[:, 1]),
        torch.tensor(expected_weighted, dtype=torch.float32))
    # weights are non-contiguous but inputs are contiguous
    self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
                     torch.tensor(expected_weighted, dtype=torch.float32))

    # test bincount on non-contiguous slices
    all0s = torch.zeros((32, 2), dtype=torch.int32, device=device)
    self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))

    all1s = torch.ones((32, 2), dtype=torch.int32, device=device)
    self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))

    # test large number of bins - global memory use
    big_exp = torch.zeros(100, device=device)
    big_exp[-1] = 50.0
    big_w = torch.tensor([.5] * 100, device=device)
    big_out = torch.tensor([99] * 100, device=device, dtype=torch.int32).bincount(big_w)
    self.assertEqual(big_exp, big_out)

    # test large input size
    big_exp = torch.zeros(2, device=device, dtype=torch.int64)
    big_exp[1] = 10
    big_out = torch.ones(10, dtype=torch.int8, device=device).bincount()
    self.assertEqual(big_exp, big_out)
def test_bincount(self):
    """Compare ``bincount`` on larger random MPS inputs with the CPU result."""
    device = "mps"
    input_size = (5000,)
    w = torch.randn(input_size, dtype=torch.float, device=device)
    w_cpu = w.cpu()

    # (exclusive upper bound of the random values, dtype of the input)
    random_cases = ((50, torch.int8), (500, torch.int32), (2000, torch.int32))
    for high, dtype in random_cases:
        t = torch.randint(high, input_size, dtype=dtype, device=device)
        t_cpu_counts = t.cpu().bincount()
        self.assertEqual(t_cpu_counts, t.bincount())
        t_cpu_weighted = t.cpu().bincount(w_cpu)
        self.assertEqual(t_cpu_weighted, t.bincount(w))

    # One large value among zeros, with a minlength larger than that value.
    t = torch.zeros([10], dtype=torch.int32, device=device)
    t[0] = 35488
    counted = t.bincount(minlength=65536)
    self.assertEqual(torch.sum(counted), 10)
def test_sum_backward(self):
    """Check ``torch.sum`` and its gradient on MPS against CPU for a fixed 3x3 input."""

    def check():
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
        mps_x = cpu_x.detach().clone().to('mps').requires_grad_()

        total_mps = torch.sum(mps_x)
        total_cpu = torch.sum(cpu_x)

        total_mps.backward()
        total_cpu.backward()
        self.assertEqual(total_mps, total_cpu)
        self.assertEqual(mps_x.grad, cpu_x.grad)

    check()
# L1 loss
def test_l1_loss(self):
    """Compare ``L1Loss`` forward (and backward for reduced outputs) between MPS and CPU."""

    def check(shape, reduction):
        # create the criterion
        criterion = torch.nn.L1Loss(reduction=reduction)

        input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        target_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        # forward pass
        output_cpu = criterion(input_cpu, target_cpu)
        output_mps = criterion(input_mps, target_mps)
        self.assertEqual(output_cpu, output_mps)

        # backward pass
        if reduction == 'none':
            return
        # chose 2 just to make the grad_output > 1 in backward pass
        output_cpu.backward(gradient=torch.full_like(output_cpu, 2))
        output_mps.backward(gradient=torch.full_like(output_mps, 2))
        self.assertEqual(input_cpu.grad, input_mps.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
    )
    for shape, reduction in cases:
        check(shape, reduction)
# Mean Squared Error
def test_mse_loss(self):
    """Compare ``MSELoss`` forward (and backward for reduced outputs) between MPS and CPU."""

    def check(shape, reduction):
        # create the criterion
        criterion = torch.nn.MSELoss(reduction=reduction)

        input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        target_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        # forward pass
        output_cpu = criterion(input_cpu, target_cpu)
        output_mps = criterion(input_mps, target_mps)
        self.assertEqual(output_cpu, output_mps)

        # backward pass
        if reduction == 'none':
            return
        # chose 2 just to make the grad_output > 1 in backward pass
        output_cpu.backward(gradient=torch.full_like(output_cpu, 2))
        output_mps.backward(gradient=torch.full_like(output_mps, 2))
        self.assertEqual(input_cpu.grad, input_mps.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
    )
    for shape, reduction in cases:
        check(shape, reduction)
def test_mse_loss_strided_output(self):
    """Compute unreduced ``MSELoss`` on permuted, sliced Conv1d outputs on MPS and CPU."""
    # https://github.com/pytorch/pytorch/issues/124621
    criterion = nn.MSELoss(reduction='none')
    model_cpu = nn.Sequential(
        nn.Conv1d(3, 3, 1),
    )
    model_mps = copy.deepcopy(model_cpu).to("mps")

    x = torch.randn(128, 10, 3).permute(0, 2, 1)
    # The MPS input is permuted twice, which restores the shape of ``x``.
    x_mps = x.detach().clone().to("mps").permute(0, 2, 1).permute(0, 2, 1)

    y = model_cpu(x)
    y_mps = model_mps(x_mps)

    # Permute and slice the model outputs before feeding them to the loss.
    y = y.permute(0, 2, 1)[:, :5, :]
    y_mps = y_mps.permute(0, 2, 1)[:, :5, :]

    y_hat = torch.randn(128, 5, 3)
    y_hat_mps = y_hat.detach().clone().to("mps")

    loss_cpu = criterion(y, y_hat)
    loss_mps = criterion(y_mps, y_hat_mps)
    self.assertEqual(loss_cpu, loss_mps)
# Binary Cross Entropy
def test_bce_loss_simple(self):
    """Compare ``BCELoss`` forward (and backward for reduced outputs) between MPS and CPU."""

    def check(shape, reduction):
        # create the criterion
        criterion = torch.nn.BCELoss(reduction=reduction)

        # input and target must be within [0..1]
        input_np = np.random.random_sample(size=shape).astype(np.float32)
        target_np = np.random.random_sample(size=shape).astype(np.float32)
        input_cpu = torch.tensor(input_np, device='cpu', dtype=torch.float, requires_grad=True)
        target_cpu = torch.tensor(target_np, device='cpu', dtype=torch.float, requires_grad=False)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        # forward pass
        output_cpu = criterion(input_cpu, target_cpu)
        output_mps = criterion(input_mps, target_mps)
        self.assertEqual(output_cpu, output_mps)

        # backward pass
        if reduction == 'none':
            return
        # chose 0.6 just to have the grad_output != 1
        output_cpu.backward(gradient=torch.full_like(output_cpu, 0.6))
        output_mps.backward(gradient=torch.full_like(output_mps, 0.6))
        self.assertEqual(input_cpu.grad, input_mps.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
        ([1, 1, 32, 32], 'mean'),
    )
    for shape, reduction in cases:
        check(shape, reduction)
def test_bce_loss_always_nonnegative(self):
    # With prediction == target (all ones, then all zeros) the loss must not
    # come out negative.
    for fill in (torch.ones, torch.zeros):
        target = fill(5, device='mps')
        prediction = fill(5, device='mps')
        loss = nn.BCELoss()(prediction, target)
        self.assertEqual((loss < 0).sum(), 0)
def test_bce_loss_size_mismatch(self):
    # A (25,) input paired with a (25, 1) target is expected to raise a
    # ValueError mentioning the target size.
    criterion = nn.BCELoss()
    flat_input = torch.rand(25, device='mps')
    column_target = torch.rand(25, 1, device='mps')
    with self.assertRaisesRegex(ValueError, r'Using a target size \('):
        criterion(flat_input, column_target)
def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
    # BCEWithLogitsLoss(x) must agree with BCELoss(sigmoid(x)) in both the loss
    # value and the gradient w.r.t. x, for every reduction mode.
    rows, cols = 1024, 256
    target = torch.rand(rows, cols, device='mps')

    for reduction in ('none', 'mean', 'sum'):
        sig_input = torch.rand(rows, cols, device='mps') - 0.5
        logit_input = sig_input.clone().detach()
        sig_input.requires_grad = True
        logit_input.requires_grad = True
        weight = torch.rand(cols, device='mps')

        bce = nn.BCELoss(weight, reduction=reduction)
        loss_sig = bce(torch.sigmoid(sig_input), target)
        bce_with_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)
        loss_logits = bce_with_logits(logit_input, target)
        self.assertEqual(loss_logits, loss_sig)

        # An unreduced loss is not a scalar, so it needs an explicit upstream gradient.
        if reduction == 'none':
            backward_args = (torch.rand(rows, cols, device='mps'),)
        else:
            backward_args = ()
        loss_sig.backward(*backward_args)
        loss_logits.backward(*backward_args)

        self.assertEqual(sig_input.grad, logit_input.grad)
def test_bce_with_logits_has_correct_grad_at_zero(self):
    # With zero logits and zero targets, the summed loss is expected to give a
    # gradient of 0.5 for every element.
    logits = torch.zeros(3, 1, requires_grad=True, device='mps')
    target = torch.zeros(3, 1, device='mps')
    criterion = nn.BCEWithLogitsLoss(reduction='sum')
    criterion(logits, target).backward()
    self.assertEqual(logits.grad, torch.full((3, 1), 0.5, device='mps'))
def test_bce_with_logits_broadcasts_weights(self):
    # A weight of shape (4,) or (16, 1) must give the same loss as that weight
    # explicitly expanded to the full (16, 4) shape.
    target = torch.rand(16, 4, device='mps')
    output = torch.rand(16, 4, device='mps') - 0.5

    for weight_shape in ((4,), (16, 1)):
        compact_weight = torch.rand(*weight_shape, device='mps')
        loss_compact = nn.BCEWithLogitsLoss(compact_weight)(output, target)

        full_weight = compact_weight.expand(16, 4).contiguous()
        loss_full = nn.BCEWithLogitsLoss(full_weight)(output, target)

        self.assertEqual(loss_compact, loss_full)
def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
    # A pos_weight made entirely of ones must give the same loss as passing no
    # pos_weight at all.
    target = torch.rand(64, 4, device='mps')
    output = torch.rand(64, 4, device='mps') - 0.5
    pos_weight = torch.ones(64, 4, device='mps')

    loss_default = nn.BCEWithLogitsLoss()(output, target)
    loss_unit_weight = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
    self.assertEqual(loss_default, loss_unit_weight)
def test_bce_with_logits_broadcasts_pos_weights(self):
    # A (4,) pos_weight must give the same loss as the same values expanded to
    # (1, 4) and to (64, 4).
    target = torch.rand(64, 4, device='mps')
    output = torch.rand(64, 4, device='mps') - 0.5
    pos_weight = torch.rand(4, device='mps')

    baseline = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
    expanded_losses = [
        nn.BCEWithLogitsLoss(pos_weight=pos_weight.expand(*size))(output, target)
        for size in ((1, 4), (64, 4))
    ]
    for expanded_loss in expanded_losses:
        self.assertEqual(baseline, expanded_loss)
def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
    # Same check as the un-weighted variant, but with an all-ones pos_weight:
    # the gradient at zero logits / zero targets is expected to be 0.5.
    logits = torch.zeros(3, 1, requires_grad=True, device='mps')
    target = torch.zeros(3, 1, device='mps')
    pos_weight = torch.ones(3, 1, device='mps')

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')
    criterion(logits, target).backward()

    self.assertEqual(logits.grad, torch.full((3, 1), 0.5, device='mps'))
def test_bce_with_logits_stability(self):
    # The loss has to stay finite for a strongly negative logit (-120) with
    # target 1, both with and without pos_weight.
    output = torch.tensor([0., -120.], device='mps')
    target = torch.tensor([0., 1.], device='mps')
    pos_weight = torch.tensor([1., 1.], device='mps')

    criteria = (
        nn.BCEWithLogitsLoss(),
        nn.BCEWithLogitsLoss(pos_weight=pos_weight),
    )
    for criterion in criteria:
        loss = criterion(output, target)
        self.assertTrue(torch.isfinite(loss).all().item())
def test_bce_loss_broadcasts_weights(self):
    # BCELoss with a (4,) or (16, 1) weight must equal BCELoss with that weight
    # expanded to the full (16, 4) shape.
    to_prob = nn.Sigmoid()
    target = torch.rand(16, 4, device='mps')
    output = torch.rand(16, 4, device='mps') - 0.5

    def weighted_loss(weight):
        return nn.BCELoss(weight)(to_prob(output), target)

    for weight_shape in ((4,), (16, 1)):
        compact_weight = torch.rand(*weight_shape, device='mps')
        loss_compact = weighted_loss(compact_weight)
        loss_full = weighted_loss(compact_weight.expand(16, 4).contiguous())
        self.assertEqual(loss_compact, loss_full)
def test_cross_entropy_loss(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/116095
    # Nothing is asserted: the test passes if forward + backward complete for
    # float16 predictions.
    criterion = nn.CrossEntropyLoss()
    logits = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
    labels = torch.ones(3, dtype=torch.long, device='mps')
    criterion(logits, labels).backward()
def test_log_softmax(self):
    # log_softmax along dim 0: forward output and input gradient are compared
    # between CPU and MPS.
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
    mps_x = torch.tensor(values, device='mps', requires_grad=True)

    cpu_out = F.log_softmax(cpu_x, dim=0)
    mps_out = F.log_softmax(mps_x, dim=0)
    self.assertEqual(cpu_out, mps_out.to('cpu'))

    # Backpropagate an all-ones upstream gradient on both devices.
    ones_cpu = torch.ones_like(cpu_out)
    ones_mps = torch.ones_like(cpu_out).to('mps')
    cpu_out.backward(gradient=ones_cpu)
    mps_out.backward(gradient=ones_mps)

    self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
def test_log_softmax_large_numbers(self):
    # Same comparison as test_log_softmax, along the last dim, with magnitudes
    # ranging from 10 up to 1e6 (positive row and negative row).
    magnitudes = [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0]
    values = [
        magnitudes,
        [-m for m in magnitudes],
    ]
    cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
    mps_x = torch.tensor(values, device='mps', requires_grad=True)

    cpu_out = F.log_softmax(cpu_x, dim=-1)
    mps_out = F.log_softmax(mps_x, dim=-1)
    self.assertEqual(cpu_out, mps_out.to('cpu'))

    ones_cpu = torch.ones_like(cpu_out)
    ones_mps = torch.ones_like(cpu_out).to('mps')
    cpu_out.backward(gradient=ones_cpu)
    mps_out.backward(gradient=ones_mps)

    self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
def test_eq(self):
    # Element-wise torch.eq on two float tensors that differ in two positions.
    lhs_values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    rhs_values = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]

    def eq_on(device):
        lhs = torch.tensor(lhs_values, device=device)
        rhs = torch.tensor(rhs_values, device=device)
        return torch.eq(lhs, rhs)

    mask_mps = eq_on('mps')
    mask_cpu = eq_on('cpu')
    self.assertEqual(mask_cpu, mask_mps.to('cpu'))
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_signed_vs_unsigned_comparison(self):
    # uint8 tensors are compared against the signed scalar -1 with ==, > and <;
    # MPS has to produce the same masks as CPU.
    values = (-1, 2, 3)
    cpu_x = torch.tensor(values, device='cpu', dtype=torch.uint8)
    mps_x = torch.tensor(values, device='mps', dtype=torch.uint8)
    # in the comparison of signed vs. unsigned we should always cast to unsigned
    for compare in (operator.eq, operator.gt, operator.lt):
        self.assertEqual(compare(cpu_x, -1), compare(mps_x, -1))
def test_eq_int64(self):
    # Integer counterpart of test_eq: the tensors are built from Python ints.
    lhs_values = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
    rhs_values = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]

    masks = {}
    for device in ('mps', 'cpu'):
        lhs = torch.tensor(lhs_values, device=device)
        rhs = torch.tensor(rhs_values, device=device)
        masks[device] = torch.eq(lhs, rhs)

    self.assertEqual(masks['cpu'], masks['mps'].to('cpu'))
def test_ne(self):
    # torch.ne on two random float tensors: the MPS mask must equal the CPU mask.
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        mask_mps = torch.ne(lhs_mps, rhs_mps)
        self.assertEqual(torch.ne(lhs_cpu, rhs_cpu), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_ne_scalar(self):
    # torch.ne of a random float tensor against the scalar 0.0, MPS vs CPU.
    def helper(shape):
        data_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        data_mps = data_cpu.detach().clone().to('mps')
        mask_mps = torch.ne(data_mps, 0.0)
        self.assertEqual(torch.ne(data_cpu, 0.0), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_lt(self):
    # torch.lt on two random float tensors: the MPS mask must equal the CPU mask.
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        mask_mps = torch.lt(lhs_mps, rhs_mps)
        self.assertEqual(torch.lt(lhs_cpu, rhs_cpu), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_lt_scalar(self):
    # torch.lt of a random float tensor against the scalar 0.0, MPS vs CPU.
    def helper(shape):
        data_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        data_mps = data_cpu.detach().clone().to('mps')
        mask_mps = torch.lt(data_mps, 0.0)
        self.assertEqual(torch.lt(data_cpu, 0.0), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_le(self):
    # torch.le on two random float tensors: the MPS mask must equal the CPU mask.
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        mask_mps = torch.le(lhs_mps, rhs_mps)
        self.assertEqual(torch.le(lhs_cpu, rhs_cpu), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_le_scalar(self):
    # torch.le of a random float tensor against the scalar 0.0, MPS vs CPU.
    def helper(shape):
        data_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        data_mps = data_cpu.detach().clone().to('mps')
        mask_mps = torch.le(data_mps, 0.0)
        self.assertEqual(torch.le(data_cpu, 0.0), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_ge(self):
    # torch.ge on two random float tensors: the MPS mask must equal the CPU mask.
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        mask_mps = torch.ge(lhs_mps, rhs_mps)
        self.assertEqual(torch.ge(lhs_cpu, rhs_cpu), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_ge_scalar(self):
    # torch.ge of a random float tensor against the scalar 0.0, MPS vs CPU.
    def helper(shape):
        data_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        data_mps = data_cpu.detach().clone().to('mps')
        mask_mps = torch.ge(data_mps, 0.0)
        self.assertEqual(torch.ge(data_cpu, 0.0), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_gt(self):
    # torch.gt on two random float tensors: the MPS mask must equal the CPU mask.
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        mask_mps = torch.gt(lhs_mps, rhs_mps)
        self.assertEqual(torch.gt(lhs_cpu, rhs_cpu), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_gt_scalar(self):
    # torch.gt of a random float tensor against the scalar 0.0, MPS vs CPU.
    def helper(shape):
        data_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        data_mps = data_cpu.detach().clone().to('mps')
        mask_mps = torch.gt(data_mps, 0.0)
        self.assertEqual(torch.gt(data_cpu, 0.0), mask_mps.to('cpu'))

    helper((2, 3, 4, 5))
def test_argmax(self):
    # https://github.com/pytorch/pytorch/issues/98191
    ints_cpu = torch.tensor([[0, 1], [2, 1], [1, 0]])
    expected = torch.argmax(ints_cpu, dim=1)
    ints_mps = ints_cpu.to(torch.device('mps'))
    self.assertEqual(expected, torch.argmax(ints_mps, dim=1))

    # https://github.com/pytorch/pytorch/issues/92311
    floats_mps = torch.randn(10, 2, device='mps', dtype=torch.float32)
    floats_cpu = floats_mps.detach().clone().cpu()
    actual = torch.argmax(floats_mps, dim=1)
    self.assertEqual(torch.argmax(floats_cpu, dim=1), actual)
# Test forward argmin argmax
def test_argmin_argmax(self):
    # torch.argmax / torch.argmin are checked against CPU as a full reduction
    # and along each of the four dims, with and without keepdim.
    def helper(n, c, h, w, reduction_type, dtype=torch.float32):
        arg_reduction_fn = torch.argmax if reduction_type == "max" else torch.argmin

        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            # bool inputs draw from {0, 1}; every other dtype draws integers in [0, 50)
            high = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(high, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        self.assertEqual(arg_reduction_fn(x), arg_reduction_fn(cpu_x))

        for dim in range(4):
            for keepdim in (False, True):
                mps_index = arg_reduction_fn(x, dim=dim, keepdim=keepdim)
                cpu_index = arg_reduction_fn(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(mps_index, cpu_index)

    for reduction_type in ("max", "min"):
        for dtype in (torch.float32, torch.int32, torch.float16, torch.int64):
            helper(2, 8, 4, 4, reduction_type, dtype)
@unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
def test_reduction_sum_max_long_val(self):
    # torch.sum over values at and just below sys.maxsize must give the same
    # result on MPS as on CPU.
    near_max = [sys.maxsize - offset for offset in (0, 10, 5, 18)]
    x_mps = torch.tensor(near_max, device="mps")
    x_cpu = x_mps.detach().clone().cpu()

    self.assertEqual(torch.sum(x_mps), torch.sum(x_cpu))
# Test forward max
# Note - don't test grad now
def test_max_el(self):
    # torch.max is checked against CPU in three forms: full reduction,
    # (values, indices) per dim, and the out= variant per dim.
    def helper(n, c, h, w, dtype=torch.float32):
        full_shape = (n, c, h, w)
        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
        else:
            # bool inputs draw from {0, 1}; every other dtype draws integers in [0, 50)
            high = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(high, full_shape, device='cpu', dtype=dtype, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        expected_max = torch.max(cpu_x)
        self.assertEqual(torch.max(x), expected_max)

        for dim in range(4):
            for keepdim in (True, False):
                values, indices = torch.max(x, dim=dim, keepdim=keepdim)
                ref_values, ref_indices = torch.max(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(values, ref_values)
                self.assertEqual(indices, ref_indices)

        # out= variant: results are written into pre-allocated tensors of ones
        for dim in range(4):
            for keepdim in (False, True):
                out_shape = list(full_shape)
                if keepdim:
                    out_shape[dim] = 1
                else:
                    del out_shape[dim]
                values = torch.ones(out_shape, device='mps', dtype=dtype)
                indices = torch.ones(out_shape, device='mps', dtype=torch.int64)
                torch.max(x, dim=dim, keepdim=keepdim, out=(values, indices))
                ref_values, ref_indices = torch.max(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(values, ref_values)
                self.assertEqual(indices, ref_indices)

    helper(2, 8, 4, 5, torch.float32)
    helper(2, 8, 4, 5, torch.int32)
    # helper(2, 8, 4, 5, torch.int64)
def test_median(self):
    # torch.median is checked against CPU both as a full reduction and as the
    # (values, indices) form along each dim, for int32 and float32 inputs.
    def compare_with_cpu(cpu_x):
        mps_x = cpu_x.detach().clone().to('mps')
        self.assertEqual(torch.median(cpu_x), torch.median(mps_x))

        for dim in range(3):
            for keepdim in (True, False):
                cpu_values, cpu_indices = torch.median(cpu_x, dim=dim, keepdim=keepdim)
                mps_values, mps_indices = torch.median(mps_x, dim=dim, keepdim=keepdim)
                self.assertEqual(cpu_values, mps_values)
                self.assertEqual(cpu_indices, mps_indices)

    def helper_dtype_int32(n1, n2, n3):
        compare_with_cpu(torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32))

    def helper_dtype_float32(n1, n2, n3):
        compare_with_cpu(torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32))

    helper_dtype_int32(10, 10, 10)  # median at even place
    helper_dtype_int32(3, 3, 3)  # median at odd place
    helper_dtype_int32(1, 1, 1)
    helper_dtype_int32(1, 2, 3)

    for size in ((10, 10, 10), (3, 3, 3), (1, 1, 1)):
        helper_dtype_float32(*size)
def test_any(self):
    # torch.any on MPS is compared with CPU for float, int and bool inputs,
    # as a full reduction and along every dim (with and without keepdim).
    #
    # Changes from the previous version: the dim-0 keepdim check is no longer
    # run twice, the element count uses math.prod, and the per-dim checks loop
    # over the actual rank of `shape` instead of hard-coding dims 2 and 3
    # behind `len(shape) > 2` (which would have raised for a 3-D shape).
    def helper(shape):
        numel = math.prod(shape)

        float_inputs = [
            torch.randn(numel, dtype=torch.float),
            torch.arange(0, numel, dtype=torch.float),
            torch.ones(numel, dtype=torch.float),
            torch.zeros(numel, dtype=torch.float),
        ]
        int_inputs = [
            torch.arange(0, numel, dtype=torch.int),
            torch.ones(numel, dtype=torch.int),
            torch.zeros(numel, dtype=torch.int),
        ]
        input_xs = [t.reshape(shape) for t in float_inputs + int_inputs]
        # The bool inputs are the int inputs converted with .bool()
        input_xs += [t.reshape(shape).bool() for t in int_inputs]

        for cpu_x in input_xs:
            x = cpu_x.detach().clone().to('mps')

            self.assertEqual(torch.any(x), torch.any(cpu_x))

            for dim in range(len(shape)):
                for keepdim in (False, True):
                    y = torch.any(x, dim=dim, keepdim=keepdim)
                    ref_y = torch.any(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y, ref_y)

    helper((1, 1, 1, 1))
    helper((1, 1, 3, 3))
    helper((7, 13))
    helper((2, 8, 4, 5))
def test_reduction_ops_5D(self):
    # torch.any / torch.all along a single dim of a 5-D zero tensor, followed
    # by a 6-D regression check for Tensor.all.
    shape = (1, 1, 2, 1, 1)

    def helper(fn, dim):
        reduced_cpu = fn(torch.zeros(shape), dim=dim)
        reduced_mps = fn(torch.zeros(shape, device="mps"), dim=dim)
        self.assertEqual(reduced_cpu, reduced_mps.cpu())

    # NOTE: range(0, 4) covers dims 0..3 only; dim 4 of the 5-D tensor is not exercised.
    for fn, dim in itertools.product([torch.any, torch.all], range(0, 4)):
        helper(fn, dim)

    # 6D tensor reductions
    # Regression test for https://github.com/pytorch/pytorch/issues/95538
    x = (torch.rand(2, 3, 4, 3, 4, 2, device="mps") - .5).relu()
    self.assertEqual(x.all(), x.cpu().all())
    for dim in range(-5, 6):
        self.assertEqual(x.all(dim=dim), x.cpu().all(dim=dim))
def test_all(self):
    # torch.all on MPS is compared with CPU for float, int and bool inputs,
    # as a full reduction and along every dim (with and without keepdim),
    # plus an empty bool tensor.
    #
    # Changes from the previous version: the dim-0 keepdim check is no longer
    # run twice, the element count uses math.prod, and the per-dim checks loop
    # over the actual rank of `shape` instead of hard-coding dims 2 and 3
    # behind `len(shape) > 2` (which would have raised for a 3-D shape).
    def helper(shape):
        numel = math.prod(shape)

        float_inputs = [
            torch.randn(numel, dtype=torch.float),
            torch.arange(0, numel, dtype=torch.float),
            torch.ones(numel, dtype=torch.float),
            torch.zeros(numel, dtype=torch.float),
        ]
        int_inputs = [
            torch.arange(0, numel, dtype=torch.int),
            torch.ones(numel, dtype=torch.int),
            torch.zeros(numel, dtype=torch.int),
        ]
        input_xs = [t.reshape(shape) for t in float_inputs + int_inputs]
        # The bool inputs are the int inputs converted with .bool()
        input_xs += [t.reshape(shape).bool() for t in int_inputs]

        for cpu_x in input_xs:
            x = cpu_x.detach().clone().to('mps')

            self.assertEqual(torch.all(x), torch.all(cpu_x))

            for dim in range(len(shape)):
                for keepdim in (False, True):
                    y = torch.all(x, dim=dim, keepdim=keepdim)
                    ref_y = torch.all(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y, ref_y)

    helper((1, 1, 1, 1))
    helper((1, 1, 3, 3))
    helper((7, 13))
    helper((2, 8, 4, 5))

    # Empty tensor
    x_cpu = torch.tensor([], dtype=torch.bool)
    x_mps = x_cpu.to("mps")
    self.assertEqual(x_cpu.all(), x_mps.all().cpu())
# Test forward min
def test_min_el(self):
    # torch.min is checked against CPU as a full reduction and, for every dim
    # and keepdim setting, in both the returning form and the out= form.
    def helper(n, c, h, w):
        full_shape = (n, c, h, w)
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        self.assertEqual(torch.min(x), torch.min(cpu_x))

        for dim in range(4):
            for keepdim in (False, True):
                # Form 1: (values, indices) returned by the call
                values, indices = torch.min(x, dim=dim, keepdim=keepdim)
                ref_values, ref_indices = torch.min(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(values, ref_values)
                self.assertEqual(indices, ref_indices)

                # Form 2: results written into pre-allocated out= tensors of ones
                out_shape = list(full_shape)
                if keepdim:
                    out_shape[dim] = 1
                else:
                    del out_shape[dim]
                values = torch.ones(out_shape, device='mps', dtype=torch.float)
                indices = torch.ones(out_shape, device='mps', dtype=torch.int64)
                torch.min(x, dim=dim, keepdim=keepdim, out=(values, indices))
                ref_values, ref_indices = torch.min(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(values, ref_values)
                self.assertEqual(indices, ref_indices)

    helper(2, 8, 4, 5)
# Test forward sum
def test_sum(self):
    # torch.sum is checked against CPU as a full reduction and over several
    # dim lists (including the empty list), with and without keepdim.
    def helper(n, c, h, w, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            # bool inputs draw from {0, 1}; every other dtype draws integers in [0, 50)
            high = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(high, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        self.assertEqual(torch.sum(x), torch.sum(cpu_x))

        for dims in ([], [0], [0, 1], [2, 3]):
            for keepdim in (False, True):
                mps_sum = torch.sum(x, dim=dims, keepdim=keepdim)
                cpu_sum = torch.sum(cpu_x, dim=dims, keepdim=keepdim)
                self.assertEqual(mps_sum, cpu_sum)

    helper(2, 8, 4, 5)
    for dtype in (torch.int32, torch.int64, torch.bool):
        helper(2, 8, 4, 5, dtype=dtype)

    # Regression test for https://github.com/pytorch/pytorch/issues/136132
    summed = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2)
    self.assertEqual(summed.numel(), 8)
    self.assertEqual(summed.max().item(), 30.0)
# Test forward prod
def test_prod(self):
    # torch.prod is checked against CPU as a full reduction and along each dim,
    # with and without keepdim.
    def helper(shape, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        elif dtype == torch.bool:
            cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')
        else:
            # integer inputs are drawn from [1, 6)
            cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        self.assertEqual(torch.prod(x), torch.prod(cpu_x))

        for dim, _ in enumerate(shape):
            for keepdim in (False, True):
                if keepdim:
                    mps_prod = torch.prod(x, dim=dim, keepdim=True)
                    cpu_prod = torch.prod(cpu_x, dim=dim, keepdim=True)
                else:
                    mps_prod = torch.prod(x, dim=dim)
                    cpu_prod = torch.prod(cpu_x, dim=dim)
                self.assertEqual(mps_prod, cpu_prod)

    for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
        helper((2, 3), dtype)
# Test forward mean
def test_mean(self):
    # torch.mean is checked against CPU as a full reduction and over several
    # dim lists (including the empty list), with and without keepdim.
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        self.assertEqual(torch.mean(x), torch.mean(cpu_x))

        for dims in ([], [0], [0, 1], [2, 3]):
            for keepdim in (False, True):
                mps_mean = torch.mean(x, dim=dims, keepdim=keepdim)
                cpu_mean = torch.mean(cpu_x, dim=dims, keepdim=keepdim)
                self.assertEqual(mps_mean, cpu_mean)

    helper(2, 8, 4, 5)
# Test std
def test_std(self):
    """Compare torch.std on MPS with CPU for biased and unbiased estimates.

    For each ``unbiased`` setting the full reduction is checked first, then
    reductions over several ``dim`` lists with and without ``keepdim``.
    """
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        for unbiased in (False, True):
            self.assertEqual(torch.std(dev, unbiased=unbiased), torch.std(ref, unbiased=unbiased))

            for dims in ([], [0], [0, 1], [2, 3]):
                for keep in (False, True):
                    got = torch.std(dev, dim=dims, keepdim=keep, unbiased=unbiased)
                    want = torch.std(ref, dim=dims, keepdim=keep, unbiased=unbiased)
                    self.assertEqual(got, want)

    helper((4, 5, 6, 7))
    # verify if a change in shape of input would cause problems with graph caching
    helper((9, 5, 6, 7))
# Test var
def test_var_simple(self):
    """Compare variance on MPS with CPU for every unbiased/keepdim combination."""
    def helper():
        dims_of_input = [2, 3, 4, 5]
        ref = torch.randn(dims_of_input, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        for unbiased, keepdim in product([False, True], [False, True]):
            # Tensor-method form, reducing the last axis.
            self.assertEqual(
                dev.var(-1, keepdim=keepdim, unbiased=unbiased),
                ref.var(-1, keepdim=keepdim, unbiased=unbiased),
            )

            # Full reduction (keepdim is not passed here).
            self.assertEqual(torch.var(dev, unbiased=unbiased), torch.var(ref, unbiased=unbiased))

            # Function form over explicit dim lists.
            for dims in ([], [0], [0, -1], [2, 3]):
                got = torch.var(dev, dim=dims, keepdim=keepdim, unbiased=unbiased)
                want = torch.var(ref, dim=dims, keepdim=keepdim, unbiased=unbiased)
                self.assertEqual(got, want)

    helper()
# Test forward amax
def test_amax(self):
    """Compare torch.amax forward output and input gradient between MPS and CPU."""
    def helper(shape, dim, keepdim):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_in = ref_in.detach().clone().to('mps').requires_grad_()

        dev_out = torch.amax(dev_in, dim=dim, keepdim=keepdim)
        ref_out = torch.amax(ref_in, dim=dim, keepdim=keepdim)

        # The same random upstream gradient is fed to both backward passes.
        ref_grad = torch.randn(ref_out.shape)
        dev_grad = ref_grad.to('mps')
        ref_out.backward(gradient=ref_grad)
        dev_out.backward(gradient=dev_grad)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)

    for dim, keepdim in product(([], [0], [0, 1], [2, 3]), [False, True]):
        helper((2, 8, 4, 5), dim, keepdim)
# Test forward amin
def test_amin(self):
    """Compare torch.amin forward output and input gradient between MPS and CPU."""
    def helper(shape, dim, keepdim):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_in = ref_in.detach().clone().to('mps').requires_grad_()

        dev_out = torch.amin(dev_in, dim=dim, keepdim=keepdim)
        ref_out = torch.amin(ref_in, dim=dim, keepdim=keepdim)

        # The same random upstream gradient is fed to both backward passes.
        ref_grad = torch.randn(ref_out.shape)
        dev_grad = ref_grad.to('mps')
        ref_out.backward(gradient=ref_grad)
        dev_out.backward(gradient=dev_grad)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)

    for dim, keepdim in product(([], [0], [0, 1], [2, 3]), [False, True]):
        helper((2, 8, 4, 5), dim, keepdim)
# Test minimum and maximum
def test_minimum_maximum(self):
    """Check element-wise torch.minimum and torch.maximum on MPS against CPU."""
    def helper(n, c, h, w):
        cpu_a = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        cpu_b = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        mps_a = cpu_a.detach().clone().to('mps')
        mps_b = cpu_b.detach().clone().to('mps')

        for binary_op in (torch.minimum, torch.maximum):
            expected = binary_op(cpu_a, cpu_b)
            actual = binary_op(mps_a, mps_b)
            self.assertEqual(expected, actual)

    helper(1, 1, 4, 5)
def test_clamp_fp16_fp32(self):
    """Clamp a float32 tensor with float16 tensor bounds; MPS must match CPU."""
    cpu_x = torch.randn(10, device='cpu', dtype=torch.float, requires_grad=False)
    x = cpu_x.detach().clone().to('mps')

    dtype = torch.float16

    def make_bounds(device):
        # Lower bound of ones and upper bound of tens, both in half precision.
        lower = torch.ones(10, device=device).to(dtype)
        upper = torch.ones(10, device=device).to(dtype) * 10
        return lower, upper

    mps_lower, mps_upper = make_bounds("mps")
    clamp_result_mps = torch.clamp(x, mps_lower, mps_upper)

    cpu_lower, cpu_upper = make_bounds("cpu")
    clamp_result_cpu = torch.clamp(cpu_x, cpu_lower, cpu_upper)

    self.assertEqual(clamp_result_mps, clamp_result_cpu)
def test_clamp_nan(self):
    """Check torch.clamp on input containing NaN, MPS against CPU.

    Exercised with both bounds, only ``min``, and only ``max``.
    """
    values = [torch.nan, 1, 2]
    t_mps = torch.tensor(values, device="mps")
    t_cpu = torch.tensor(values, device="cpu")

    bound_sets = (
        {"min": -100, "max": 100},
        {"min": -100},
        {"max": 100},
    )
    for bounds in bound_sets:
        clamped_mps = torch.clamp(t_mps, **bounds)
        clamped_cpu = torch.clamp(t_cpu, **bounds)
        self.assertEqual(clamped_mps, clamped_cpu)
# Test clamp_min
def test_clamp_min(self):
    """Compare torch.clamp_min on MPS and CPU for a scalar and a tensor bound."""
    def helper(n, c, h, w):
        ref = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        ref_bound = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        dev_bound = ref_bound.detach().clone().to('mps')

        # Scalar lower bound.
        self.assertEqual(torch.clamp_min(dev, min=5.0), torch.clamp_min(ref, min=5.0))

        # Element-wise tensor lower bound.
        self.assertEqual(torch.clamp_min(dev, min=dev_bound), torch.clamp_min(ref, min=ref_bound))

    helper(2, 8, 4, 5)
# Test clamp_max
def test_clamp_max(self):
    """Compare torch.clamp_max on MPS and CPU for a scalar and a tensor bound."""
    def helper(n, c, h, w):
        ref = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        ref_bound = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        dev_bound = ref_bound.detach().clone().to('mps')

        # Scalar upper bound.
        self.assertEqual(torch.clamp_max(dev, max=100.0), torch.clamp_max(ref, max=100.0))

        # Element-wise tensor upper bound.
        self.assertEqual(torch.clamp_max(dev, max=dev_bound), torch.clamp_max(ref, max=ref_bound))

    helper(2, 8, 4, 5)
# Test clamp
def test_clamp(self):
    """Compare torch.clamp on MPS with CPU across bound kinds and layouts.

    Covers scalar bounds, tensor bounds, max-only variants of both, inputs
    and bounds permuted with ``movedim``, and the in-place ``clamp_``.
    """
    def helper(n, c, h, w):
        import numpy as np

        upper_bound = 1000
        half_upper_bound = upper_bound / 2
        dims = (n, c, h, w)

        def cpu_and_mps(array):
            # Build the CPU tensor and an independent copy on MPS.
            cpu_t = torch.tensor(array, device='cpu', dtype=torch.float, requires_grad=False)
            return cpu_t, cpu_t.detach().clone().to('mps')

        def permuted(t):
            # Move the leading axis to the end.
            return t.movedim(0, -1)

        # input in [0..1000)
        cpu_x, x = cpu_and_mps(upper_bound * np.random.random_sample(size=dims).astype(np.float32))
        # lower bounds in [0..500)
        cpu_min_t, min_t = cpu_and_mps(half_upper_bound * np.random.random_sample(size=dims).astype(np.float32))
        # upper bounds in [500..1000), so every max is greater than its min
        cpu_max_t, max_t = cpu_and_mps(
            (half_upper_bound * np.random.random_sample(size=dims).astype(np.float32)) + half_upper_bound
        )

        # [200..600]: just an arbitrary range between [0..1000]
        self.assertEqual(
            torch.clamp(x, min=200.0, max=600.0),
            torch.clamp(cpu_x, min=200.0, max=600.0),
        )

        # test optional scalar refs and cached graph keys by passing only max
        self.assertEqual(torch.clamp(x, max=600.0), torch.clamp(cpu_x, max=600.0))

        # tensor bounds on both sides
        self.assertEqual(
            torch.clamp(x, min=min_t, max=max_t),
            torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t),
        )

        # test optional tensor refs and cached graph keys by passing only max
        self.assertEqual(torch.clamp(x, max=max_t), torch.clamp(cpu_x, max=cpu_max_t))

        # test strided x
        self.assertEqual(
            torch.clamp(permuted(x), min=200.0, max=600.0),
            torch.clamp(permuted(cpu_x), min=200.0, max=600.0),
        )

        # test strided x, min_t, max_t
        self.assertEqual(
            torch.clamp(permuted(x), min=permuted(min_t), max=permuted(max_t)),
            torch.clamp(permuted(cpu_x), min=permuted(cpu_min_t), max=permuted(cpu_max_t)),
        )

        # test strided min_t, max_t
        contiguous_mps = permuted(x).clone(memory_format=torch.contiguous_format)
        strided_bounds_mps = torch.clamp(contiguous_mps, min=permuted(min_t), max=permuted(max_t))
        contiguous_cpu = permuted(cpu_x).clone(memory_format=torch.contiguous_format)
        strided_bounds_cpu = torch.clamp(contiguous_cpu, min=permuted(cpu_min_t), max=permuted(cpu_max_t))
        self.assertEqual(strided_bounds_mps, strided_bounds_cpu)

        # test inplace clamping
        x.clamp_(min=200.0, max=600.0)
        cpu_x.clamp_(min=200.0, max=600.0)
        self.assertEqual(cpu_x, x)

    helper(2, 8, 4, 5)
def test_divmode(self):
    """Compare division on MPS with CPU for each rounding mode and dtype.

    ``helper`` takes a ``rounding_mode`` that is forwarded to ``torch.div``,
    except for the pseudo-mode ``"floor_divide"``, which selects
    ``torch.floor_divide`` instead.
    """
    def helper(shape, rounding_mode):
        for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
            # Two combinations are not exercised by this test: any mode whose
            # name contains "floor" with int64, and "trunc" with float16.
            skip_floor_int64 = (
                rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64
            )
            skip_trunc_fp16 = (
                rounding_mode is not None and "trunc" in rounding_mode and dtype == torch.float16
            )
            if skip_floor_int64 or skip_trunc_fp16:
                continue

            if dtype in [torch.float32, torch.float16]:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            else:
                # randint's upper bound is exclusive, so the integer operands
                # lie in [-10, -1] and the divisor is never zero.
                cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
                cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)

            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')

            if rounding_mode == "floor_divide":
                result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
                result_div_mps = torch.floor_divide(mps_x, mps_y)
            else:
                result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
                result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
            self.assertEqual(result_div_mps, result_div_cpu)

    for mode in (None, "floor", "trunc", "floor_divide"):
        helper((2, 8, 4, 5), mode)
def test_rounding(self):
    """Check floor, ceil, trunc and round on MPS against CPU."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_x = cpu_x.detach().clone().to('mps')

        for rounding_op in (torch.floor, torch.ceil, torch.trunc, torch.round):
            expected = rounding_op(cpu_x)
            actual = rounding_op(mps_x)
            self.assertEqual(actual, expected)

    for shape in ((2, 6, 3, 5), (2, 8, 4, 5)):
        helper(shape)
def test_remainder(self):
    """Check torch.remainder on MPS against CPU.

    Covers an int32 tensor with an int32 0-dim tensor divisor, and an int32
    tensor with a negative Python float divisor.
    """
    # Mixed-sign int32 dividend, int32 tensor divisor.
    dividend = [-3, -2, -1, 1, 2, 3]
    by_tensor = {
        dev: torch.remainder(
            torch.tensor(dividend, dtype=torch.int32, device=dev),
            torch.tensor(2, device=dev, dtype=torch.int32),
        )
        for dev in ("cpu", "mps")
    }
    self.assertEqual(by_tensor["cpu"], by_tensor["mps"])

    # int32 dividend, negative float scalar divisor.
    by_scalar = {
        dev: torch.remainder(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device=dev), -1.5)
        for dev in ("cpu", "mps")
    }
    self.assertEqual(by_scalar["cpu"], by_scalar["mps"])
def test_expand(self):
    """Check an as_strided view with a zero stride on MPS against CPU."""
    def helper(n, c):
        # n and c are accepted but not used by the body.
        column = torch.tensor([[1.0], [4.0], [7.0]], device='cpu')
        column_mps = column.detach().clone().to('mps')

        # Stride 0 on the second axis repeats each value four times.
        size, stride = (3, 4), (1, 0)
        expected = torch.as_strided(column, size, stride)
        actual = torch.as_strided(column_mps, size, stride)
        self.assertEqual(actual, expected)

    helper(3, 1)
def test_im2col(self):
    """Check unfold with dilation, padding and stride on MPS against CPU."""
    def helper(x):
        return F.unfold(x, kernel_size=(10, 15), dilation=2, padding=5, stride=3)

    image_cpu = torch.rand(1, 1, 200, 100)
    image_mps = image_cpu.detach().clone().to('mps')

    expected = helper(image_cpu)
    actual = helper(image_mps)
    self.assertEqual(expected, actual)
def test_select(self):
    """Check several as_strided views, one with a storage offset, MPS against CPU."""
    def helper(n, c):
        base_cpu = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
        base_mps = base_cpu.detach().clone().to('mps').requires_grad_()

        # (size, stride, extra keyword arguments) for each view under test.
        views = (
            ((3, 1), (3, 1), {}),
            ((1, 3), (3, 1), {}),
            ((3, 1), (3, 1), {"storage_offset": 1}),
        )
        for size, stride, extra in views:
            expected = torch.as_strided(base_cpu, size, stride, **extra)
            actual = torch.as_strided(base_mps, size, stride, **extra)
            self.assertEqual(actual, expected)

    helper(3, 3)
def test_sort(self):
    """Check torch.sort on MPS: the ``out=`` variant, argsort, and a fixed case."""
    device = 'mps'
    for row_length in (4, 2049):
        data = torch.rand(4, row_length, device=device)
        sorted_vals, sorted_idx = torch.sort(data)

        # The out= overload must produce exactly the same values and indices.
        out_vals = torch.tensor((), device=device)
        out_idx = torch.tensor((), device=device, dtype=torch.long)
        torch.sort(data, out=(out_vals, out_idx))
        self.assertEqual(sorted_vals, out_vals, atol=0, rtol=0)
        self.assertEqual(sorted_idx, out_idx, atol=0, rtol=0)

        # argsort, as function and as method, must return sort's indices.
        self.assertEqual(torch.argsort(data), sorted_idx)
        self.assertEqual(data.argsort(), sorted_idx)

        # Small descending input with a known ascending result.
        descending = torch.tensor((50, 40, 30, 20, 10), device=device)
        ascending = torch.tensor((10, 20, 30, 40, 50), device=device)
        self.assertEqual(torch.sort(descending)[0], ascending, atol=0, rtol=0)
def test_upsample_nearest2d(self):
    """Check nn.UpsamplingNearest2d forward and backward on MPS against CPU.

    Each pair of scale factors is tested by ``scale_factor`` and by explicit
    output size, for channels-last and contiguous inputs.
    """
    def helper(N, C, H, W, memory_format):
        inputCPU = torch.arange(
            N * C * H * W, device='cpu', dtype=torch.float, requires_grad=True
        ).reshape(N, C, H, W).to(memory_format=memory_format)
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().to('mps').requires_grad_()

        factors = [1, 2, 5, 10, 40]
        for scale_h, scale_w in product(factors, repeat=2):
            # Module configured by scale factor.
            by_scale = nn.UpsamplingNearest2d(scale_factor=(scale_h, scale_w))
            outputCPU = by_scale(inputCPU)
            outputMPS = by_scale(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # Module configured by the equivalent explicit output size.
            by_size = nn.UpsamplingNearest2d((scale_h * H, scale_w * W))
            outputCPU = by_size(inputCPU)
            outputMPS = by_size(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # Backward runs only through the size-based outputs.
            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
            self.assertEqual(inputCPU.grad, inputMPS.grad)

    for memory_format in [torch.channels_last, torch.contiguous_format]:
        for n, c, h, w in ((1, 1, 4, 4), (7, 5, 3, 2)):
            helper(n, c, h, w, memory_format=memory_format)
def test_upsample_bilinear2d(self):
    """Check nn.UpsamplingBilinear2d forward and backward on MPS against CPU.

    Each pair of scale factors is tested by ``scale_factor`` and by explicit
    output size.
    """
    def helper(N, C, H, W):
        inputCPU = torch.arange(
            N * C * H * W, device='cpu', dtype=torch.float, requires_grad=True
        ).reshape(N, C, H, W)
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

        factors = [1, 2, 5, 10, 40]
        for scale_h, scale_w in product(factors, repeat=2):
            # Module configured by scale factor.
            by_scale = nn.UpsamplingBilinear2d(scale_factor=(scale_h, scale_w))
            outputCPU = by_scale(inputCPU)
            outputMPS = by_scale(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # Module configured by the equivalent explicit output size.
            by_size = nn.UpsamplingBilinear2d((scale_h * H, scale_w * W))
            outputCPU = by_size(inputCPU)
            outputMPS = by_size(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # Backward runs only through the size-based outputs.
            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
            self.assertEqual(inputCPU.grad, inputMPS.grad)

    for n, c, h, w in ((1, 1, 4, 4), (7, 5, 3, 2)):
        helper(n, c, h, w)
def test_interpolate(self):
    """Check nn.functional.interpolate forward and backward on MPS against CPU.

    Covers 1D and 2D inputs, resizing by output size and by scale factor,
    and bilinear mode with ``align_corners=True``.
    """
    def helper(shape, output_size, scales, mode, align_corners=False):
        inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

        # A scale factor takes precedence over an explicit size.
        options = {"mode": mode}
        if scales is not None:
            options["scale_factor"] = scales
        else:
            options["size"] = output_size

        # align_corners is used for 2D interpolation only
        if align_corners is True and len(shape) > 3 and mode == 'bilinear':
            options["align_corners"] = align_corners

        outputCPU = nn.functional.interpolate(inputCPU, **options)
        outputMPS = nn.functional.interpolate(inputMPS, **options)
        self.assertEqual(outputCPU, outputMPS)

        # backward pass (chose 0.6 just to have the grad_output != 1)
        outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
        outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
        self.assertEqual(inputCPU.grad, inputMPS.grad)

    # 1D interpolation
    for mode in ['nearest', 'nearest-exact']:
        helper([2, 3, 4], [3], None, mode)  # downsample with size
        helper([2, 3, 4], [6], None, mode)  # upsample with size
        helper([2, 3, 4], None, [0.6], mode)  # downsample with scale factor
        helper([2, 3, 4], None, [1.7], mode)  # upsample with scale factor

    # 2D interpolation
    for mode in ['nearest', 'nearest-exact', 'bilinear']:
        helper([2, 3, 4, 5], [3, 4], None, mode)  # downsample with size
        helper([2, 3, 4, 5], [6, 7], None, mode)  # upsample with size
        helper([2, 3, 4, 5], None, [0.6, 0.7], mode)  # downsample with scale factor
        helper([2, 3, 4, 5], None, [1.4, 1.7], mode)  # upsample with scale factor

    # align_corners=True
    helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
    helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)
# Test concat forward
def test_cat1(self):
    """Check torch.cat along dim 1 on MPS against CPU, including empty inputs."""
    def helper(shape_x, shape_y, shape_z):
        cpu_parts = [
            torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            for shape in (shape_x, shape_y, shape_z)
        ]
        mps_parts = [part.detach().clone().to('mps') for part in cpu_parts]

        joined_mps = torch.cat(mps_parts, dim=1)
        joined_cpu = torch.cat(cpu_parts, dim=1)
        self.assertEqual(joined_mps, joined_cpu)

    shape_triples = (
        ([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5]),
        ([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]),
        # zero-sized leading dimension on every input
        ([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5]),
        # one input of shape [0], in each position
        ([2, 2, 6, 5], [0], [2, 5, 6, 5]),
        ([0], [2, 3, 6, 5], [2, 5, 6, 5]),
        ([2, 3, 4, 5], [2, 5, 4, 5], [0]),
        # zero-sized concatenation dimension on one or two inputs
        ([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5]),
        ([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]),
        ([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5]),
    )
    for shape_x, shape_y, shape_z in shape_triples:
        helper(shape_x, shape_y, shape_z)
# Test stack forward
def test_stack(self):
    """Check torch.stack along dim 1 on MPS against CPU for several dtypes."""
    # All shapes must be same
    def helper(shape, dtype=torch.float32):
        if dtype == torch.float32:
            def make_pair():
                # float32 inputs are random normal and track gradients.
                cpu_t = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                return cpu_t.detach().clone().to('mps').requires_grad_(), cpu_t
        else:
            # bool inputs draw from {0, 1}; every other dtype draws from [0, 50).
            upper = 2 if dtype == torch.bool else 50

            def make_pair():
                cpu_t = torch.randint(upper, shape, device='cpu', dtype=dtype, requires_grad=False)
                return cpu_t.detach().clone().to('mps'), cpu_t

        x, cpu_x = make_pair()
        y, cpu_y = make_pair()
        z, cpu_z = make_pair()

        stacked_mps = torch.stack([x, y, z], dim=1)
        stacked_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
        self.assertEqual(stacked_mps, stacked_cpu)

    helper([2, 8, 4, 5])
    for dtype in (torch.float16, torch.int32, torch.int64, torch.bool):
        helper([2, 8, 4, 5], dtype=dtype)
    # Empty test - Currently failing! Empty tensor not handled!
    # helper([0, 2, 4, 5])
# Test abs
def test_abs(self):
    """Check torch.abs on MPS against CPU."""
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.abs(dev)
        expected = torch.abs(ref)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_log(self):
    """Check torch.log on MPS against CPU for random normal input."""
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.log(dev)
        expected = torch.log(ref)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_log_ten(self):
    """Check torch.log10 on MPS against CPU for random normal input."""
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.log10(dev)
        expected = torch.log10(ref)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_log_two(self):
    """Check torch.log2 on MPS against CPU for random normal input."""
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.log2(dev)
        expected = torch.log2(ref)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_log1p(self):
    """Check torch.log1p on MPS against CPU for random normal input."""
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.log1p(dev)
        expected = torch.log1p(ref)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_logaddexp(self):
    """Check torch.logaddexp on MPS against CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')

        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        actual = torch.logaddexp(lhs_mps, rhs_mps)
        expected = torch.logaddexp(lhs_cpu, rhs_cpu)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_logaddexp2(self):
    """Check torch.logaddexp2 on MPS against CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')

        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        actual = torch.logaddexp2(lhs_mps, rhs_mps)
        expected = torch.logaddexp2(lhs_cpu, rhs_cpu)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_logsumexp(self):
    """Check torch.logsumexp over the last dimension on MPS against CPU."""
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.logsumexp(dev, -1)
        expected = torch.logsumexp(ref, -1)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
# Test concat forward
def test_cat2(self):
    """Check torch.cat along dim 1 on MPS against CPU."""
    def compare_cat(*shapes):
        # Create each CPU tensor and its MPS copy in turn, then concatenate.
        cpu_parts = []
        mps_parts = []
        for shape in shapes:
            cpu_t = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_parts.append(cpu_t)
            mps_parts.append(cpu_t.detach().clone().to('mps'))

        joined_mps = torch.cat(mps_parts, dim=1)
        joined_cpu = torch.cat(cpu_parts, dim=1)
        self.assertEqual(joined_mps, joined_cpu)

    # Four-input variant; nothing in this test calls it.
    def helper1(shape_x, shape_y, shape_z, shape_w):
        compare_cat(shape_x, shape_y, shape_z, shape_w)

    # Three-input variant.
    def helper(shape_x, shape_y, shape_z):
        compare_cat(shape_x, shape_y, shape_z)

    helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
    helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
    # Empty test - Currently failing! Empty tensor not handled!
    # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
# Test isnan
def test_isnan(self):
    """Check torch.isnan on MPS against CPU for input that contains NaNs."""
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)

        # Overwrite one randomly chosen index along dim 0 with NaN.
        picked = [random.randrange(0, shape[0])]
        ref.index_put_(indices=[torch.tensor(picked)], values=torch.tensor(float('nan')))
        dev = ref.detach().clone().to('mps')

        actual = torch.isnan(dev)
        expected = torch.isnan(ref)
        self.assertEqual(actual, expected)

    helper((8, 2, 4, 5))
# Test reciprocal
def test_reciprocal(self):
    """Check torch.reciprocal forward and input gradient on MPS against CPU."""
    def helper(shape):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_in = ref_in.detach().clone().to('mps').requires_grad_()

        dev_out = torch.reciprocal(dev_in)
        ref_out = torch.reciprocal(ref_in)

        # An all-ones upstream gradient drives both backward passes.
        ones_cpu = torch.ones_like(ref_out)
        ones_mps = ones_cpu.to('mps')
        dev_out.backward(gradient=ones_mps)
        ref_out.backward(gradient=ones_cpu)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)

    helper((2, 8, 4, 5))
# Test sqrt
def test_sqrt(self):
    """Check torch.sqrt forward and input gradient on MPS against CPU."""
    def helper(shape):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_in = ref_in.detach().clone().to('mps').requires_grad_()

        dev_out = torch.sqrt(dev_in)
        ref_out = torch.sqrt(ref_in)

        # An all-ones upstream gradient drives both backward passes.
        ones_cpu = torch.ones_like(ref_out)
        ones_mps = ones_cpu.to('mps')
        dev_out.backward(gradient=ones_mps)
        ref_out.backward(gradient=ones_cpu)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)

    helper((2, 8, 4, 5))
# Test selu, elu, celu
def test_elu(self):
    """Check ELU, CELU and SELU forward and backward on MPS against CPU.

    Runs over several ``alpha`` values and both channels-last and contiguous
    inputs. Gradients accumulate on the same input across the three
    activations, identically on both devices.
    """
    def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float)
        ref_in = ref_in.to(memory_format=memory_format).requires_grad_()
        dev_in = ref_in.detach().clone().to('mps').requires_grad_(True)

        activations = (torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU())
        for activation in activations:
            dev_out = activation(dev_in)
            ref_out = activation(ref_in)

            ref_grad = torch.randn(ref_out.shape)
            dev_grad = ref_grad.to('mps')
            dev_out.backward(gradient=dev_grad)
            ref_out.backward(gradient=ref_grad)

            self.assertEqual(dev_out, ref_out)
            self.assertEqual(dev_in.grad, ref_in.grad)

    layouts = [torch.channels_last, torch.contiguous_format]
    shapes = [(2, 8, 4, 5)]
    alphas = [0.000001, 1.0, 2.3, 0.34, 23]
    for layout, shape, alpha in product(layouts, shapes, alphas):
        helper(shape, alpha, layout)
def test_elu_strided_output(self):
    """Check F.elu on a transposed (non-contiguous) input, MPS against CPU."""
    # https://github.com/pytorch/pytorch/issues/124834
    source = torch.randn(1, 1024, 500)
    transposed = source.transpose(1, 2)

    alpha = float(1)
    inplace = False

    expected = F.elu(transposed.to('cpu'), alpha, inplace)
    actual = F.elu(transposed.to('mps'), alpha, inplace)
    self.assertEqual(expected, actual)
# Test glu
def test_glu(self):
    """Check nn.GLU forward and backward on MPS against CPU along every dim."""
    def helper(shape, dim=0):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_in = ref_in.detach().clone().to('mps').requires_grad_()

        gate = torch.nn.GLU(dim=dim)
        dev_out = gate(dev_in)
        ref_out = gate(ref_in)

        ref_grad = torch.randn(ref_out.shape)
        dev_grad = ref_grad.to('mps')
        dev_out.backward(gradient=dev_grad)
        ref_out.backward(gradient=ref_grad)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)

    for shape in [[4], (2, 4), (2, 8, 4, 6)]:
        for dim, _ in enumerate(shape):
            helper(shape, dim)
# Test softplus
def test_softplus(self):
    """Check nn.Softplus forward and backward on MPS against CPU.

    Runs over the product of several shapes (including the 0-dim shape
    ``()``), ``beta`` values, ``threshold`` values, and float16/float32.
    """
    def helper(shape, beta, threshold, dtype):
        ref_in = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
        dev_in = ref_in.detach().clone().to('mps').requires_grad_()

        softplus = torch.nn.Softplus(beta=beta, threshold=threshold)
        dev_out = softplus(dev_in)
        ref_out = softplus(ref_in)

        # The upstream gradient is sized from the MPS output.
        ref_grad = torch.randn(dev_out.shape)
        dev_grad = ref_grad.to('mps')
        dev_out.backward(gradient=dev_grad)
        ref_out.backward(gradient=ref_grad)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)

    shapes = [(), (2, 3), (10, 10), (2, 3, 4, 5)]
    betas = [0.5, 1, 2, 3, 4]
    thresholds = [0.5, 20, 30, 40, 50]
    dtypes = [torch.float16, torch.float32]
    for case in product(shapes, betas, thresholds, dtypes):
        helper(*case)
# Test silu
def test_silu(self):
    """ Compare nn.SiLU forward and backward on MPS against CPU """
    def helper(shape):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()
        out_mps = torch.nn.SiLU()(mps_in)
        out_cpu = torch.nn.SiLU()(cpu_in)

        upstream = torch.randn(out_cpu.shape)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(mps_in.grad, cpu_in.grad)

    # The 0-dim shape [] is covered too
    for shape in ([], (2, 3), (2, 8, 4, 5)):
        helper(shape)
def test_cast_mps_to_cpu(self):
    """ Test that an MPS -> CPU copy with a dtype change matches a CPU-only cast """
    def helper(src_dtype, dst_dtype):
        original = torch.rand((1, 3, 128, 128), dtype=src_dtype)
        on_mps = original.to('mps')
        back_on_cpu = on_mps.to('cpu', dtype=dst_dtype)
        # The round trip has to agree with casting the initial tensor directly
        self.assertEqual(back_on_cpu, original.to(dtype=dst_dtype))

    for src_dtype, dst_dtype in [(torch.half, torch.float), (torch.float, torch.half)]:
        helper(src_dtype, dst_dtype)
def test_cast_mps_to_mps(self):
    """ Test that a dtype cast performed on MPS matches the same cast on CPU """
    def helper(src_dtype, dst_dtype):
        cpu_src = torch.rand((1, 3, 128, 128), dtype=src_dtype)
        mps_src = cpu_src.to('mps')
        mps_cast = mps_src.to(dtype=dst_dtype)
        cpu_cast = cpu_src.to(dtype=dst_dtype)
        self.assertEqual(mps_cast.cpu(), cpu_cast)

    dtype_pairs = [
        (torch.half, torch.float),
        (torch.float, torch.half),
        (torch.half, torch.long),
        (torch.float, torch.int),
    ]
    for src_dtype, dst_dtype in dtype_pairs:
        helper(src_dtype, dst_dtype)
def test_avg_pool2d_count_include_pad(self):
    """ Compare AvgPool2d (ceil_mode and count_include_pad enabled) on MPS against CPU """
    cpu_in = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True)
    mps_in = cpu_in.detach().clone().to('mps').requires_grad_()
    pool = torch.nn.AvgPool2d(
        kernel_size=(3, 3),
        padding=(1, 1),
        stride=(1, 1),
        ceil_mode=True,
        count_include_pad=True,
    )
    expected = pool(cpu_in)
    actual = pool(mps_in)
    self.assertEqual(actual, expected)

    # Backward with one shared random upstream gradient
    upstream = torch.randn(expected.shape)
    upstream_mps = upstream.to('mps')
    expected.backward(gradient=upstream)
    actual.backward(gradient=upstream_mps)
    self.assertEqual(mps_in.grad, cpu_in.grad)
# Test adaptive avg pool2d - when the input size is a multiple of output size
# Not testing for channels last right now
def test_adaptive_avg_pool2d_simple(self):
    """ Compare AdaptiveAvgPool2d forward and backward on MPS against CPU """
    def helper(input_shape, out_shape, channels_last):
        cpu_in = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_in = cpu_in.to(memory_format=torch.channels_last)
            cpu_in.retain_grad()
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        out_mps = torch.nn.AdaptiveAvgPool2d(out_shape)(mps_in)
        out_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_in)

        upstream = torch.randn(out_cpu.shape)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(mps_in.grad, cpu_in.grad)

    # Output no larger than the input
    shrinking_cases = [
        ((2, 2, 4, 4), (2, 2)),
        ((2, 2, 9, 9), (3, 3)),
        ((2, 2, 9, 9), (9, 9)),
        ((2, 2, 16, 16), (2, 2)),
        ((2, 2, 16, 16), (2, 16)),
        ((2, 16, 16), (4, 4)),
    ]
    # Output shape larger than input shape
    growing_cases = [
        ((2, 2, 4, 4), (8, 8)),
        ((2, 2, 2, 2), (4, 4)),
        ((2, 2, 3, 3), (9, 9)),
        ((2, 2, 2, 2), (16, 16)),
        ((2, 2, 2, 16), (16, 16)),
        ((2, 4, 4), (16, 16)),
    ]
    for input_shape, out_shape in shrinking_cases + growing_cases:
        helper(input_shape, out_shape, False)

    # 3x3 -> 7x7 is not an integer multiple; any failure here is deliberately
    # ignored, so this call is only a smoke run.
    try:
        helper((2, 2, 3, 3), (7, 7), False)
    except Exception:
        pass
# Test adaptive max pool2d - when the input size is a multiple of output size
# Not testing for channels last right now
def test_adaptive_max_pool2d_simple(self):
    """ Compare AdaptiveMaxPool2d (values, optional indices, input grad) on MPS against CPU """
    def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
        if dtype in [torch.float16, torch.float32]:
            cpu_in = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
        else:
            cpu_in = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
        if channels_last:
            cpu_in = cpu_in.to(memory_format=torch.channels_last)
            cpu_in.retain_grad()
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        pool = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)
        idx_mps = idx_cpu = None
        if return_indices:
            out_mps, idx_mps = pool(mps_in)
            out_cpu, idx_cpu = pool(cpu_in)
        else:
            out_mps = pool(mps_in)
            out_cpu = pool(cpu_in)

        upstream = torch.randn(out_cpu.shape)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        self.assertEqual(out_mps, out_cpu)
        if return_indices:
            self.assertEqual(idx_mps, idx_cpu)
        self.assertEqual(mps_in.grad, cpu_in.grad)

    cases = [
        ((2, 2, 4, 4), (2, 2)),
        ((2, 2, 9, 9), (3, 3)),
        ((2, 2, 9, 9), (9, 9)),
        ((2, 2, 16, 16), (2, 2)),
        ((2, 2, 16, 16), (2, 16)),
        ((2, 16, 16), (4, 4)),
    ]
    for dtype in [torch.float32]:
        for return_indices in [False, True]:
            for input_shape, out_shape in cases:
                helper(input_shape, out_shape, return_indices, dtype)
def test_gelu_simple(self):
    """ Compare nn.GELU forward and backward on MPS against CPU; integral inputs must raise """
    def helper(shape, dtype=torch.float, contiguous=True):
        cpu_in = torch.randn(shape, device='cpu', dtype=dtype)
        mps_in = cpu_in.detach().clone().to('mps')

        transposable = len(shape) >= 2 and 0 not in shape
        if transposable and not contiguous:
            # Transposing will make the tensor non-contiguous
            cpu_in = cpu_in.transpose(0, 1)
            mps_in = mps_in.transpose(0, 1)
            assert not mps_in.is_contiguous()

        cpu_in.requires_grad_()
        mps_in.requires_grad_()

        out_mps = torch.nn.GELU()(mps_in)
        # The CPU reference is always computed in float
        out_cpu = torch.nn.GELU()(cpu_in.to(torch.float))

        upstream = torch.ones_like(out_cpu)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        if dtype == torch.float:
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 1e-2, 1e-2
        self.assertEqual(out_mps, out_cpu.to(dtype), atol=atol, rtol=rtol)

        assert mps_in.grad is not None  # Check that the grad is well-populated
        self.assertEqual(mps_in.grad, cpu_in.grad, atol=atol, rtol=rtol)

    # Empty and 0-dim shapes are covered too
    shapes = [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]
    for dtype in [torch.float, torch.half]:
        for shape in shapes:
            for contiguous in [True, False]:
                helper(shape, dtype, contiguous)

    # Test that gelu would raise an assert for integral types
    for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        with self.assertRaises(RuntimeError):
            torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps"))
def test_mish_simple(self):
    """ Compare nn.Mish forward and backward on MPS against CPU """
    def helper(shape, dtype=torch.float, contiguous=True):
        cpu_in = torch.randn(shape, device='cpu', dtype=dtype)
        mps_in = cpu_in.detach().clone().to('mps')

        transposable = len(shape) >= 2 and 0 not in shape
        if transposable and not contiguous:
            # Transposing will make the tensor non-contiguous
            cpu_in = cpu_in.transpose(0, 1)
            mps_in = mps_in.transpose(0, 1)
            assert not mps_in.is_contiguous()

        cpu_in.requires_grad_()
        mps_in.requires_grad_()

        out_mps = torch.nn.Mish()(mps_in)
        out_cpu = torch.nn.Mish()(cpu_in)

        upstream = torch.ones_like(out_cpu)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        if dtype == torch.float:
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 1e-2, 1e-2
        self.assertEqual(out_mps, out_cpu.to(dtype), atol=atol, rtol=rtol)

        assert mps_in.grad is not None  # Check that the grad is well-populated
        self.assertEqual(mps_in.grad, cpu_in.grad, atol=atol, rtol=rtol)

    # Empty and 0-dim shapes are covered too
    shapes = [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]
    for dtype in [torch.float, torch.half]:
        for shape in shapes:
            for contiguous in [True, False]:
                helper(shape, dtype, contiguous)
def test_gelu(self):
    """ Check F.gelu on CPU and MPS against an erf-based closed-form reference

    Previously the helper compared the input tensor with a NumPy copy of
    itself, so GELU was never exercised; it now applies F.gelu and compares
    against x * 0.5 * (1 + erf(x / sqrt(2))) evaluated on CPU.
    """
    def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
        # dtype in which the CPU reference is evaluated
        ref_dtype = {
            torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double
        }[dtype]
        # When the caller gives no tolerances, use explicit float32-level ones
        # instead of the framework defaults.
        if atol is None and rtol is None:
            atol, rtol = 1e-5, 1e-3

        def _gelu_ref(X):
            # Exact GELU: x * Phi(x) with Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
            return X * 0.5 * (1.0 + torch.erf(X / math.sqrt(2.0)))

        for d in ['cpu', 'mps']:
            X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)
            if not contiguous:
                # Keep every other column so the input is a strided view
                X = X[:, ::2]
            res = F.gelu(X).detach()
            ref = _gelu_ref(X.detach().to(ref_dtype).cpu())
            self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)

    for n in [1, 5, 10]:
        for m in [1, 5, 10]:
            _test_gelu(n, m, torch.float32, True)
            _test_gelu(n, m, torch.float32, False)

    # Test multi threaded
    num_threads = torch.get_num_threads()
    torch.set_num_threads(4)
    try:
        _test_gelu(32, 32, torch.float32, False)
    finally:
        # Always restore the thread count for the tests that follow
        torch.set_num_threads(num_threads)
def test_gelu_tanh(self):
    """ Compare the tanh-approximated GELU on MPS against CPU """
    def helper(shape):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float)
        mps_in = cpu_in.detach().clone().to('mps')

        out_mps = torch.nn.functional.gelu(mps_in, approximate='tanh')
        out_cpu = torch.nn.functional.gelu(cpu_in, approximate='tanh')
        self.assertEqual(out_mps, out_cpu)

    helper((2, 8, 4, 5))
# Test hardtanh
def test_hardtanh(self):
    """ Compare nn.Hardtanh on MPS against CPU; backward is checked only when not in-place """
    def helper(shape, min_val, max_val, inplace=False):
        # Gradients are only tracked for the out-of-place variant
        track_grad = not inplace
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=track_grad)
        mps_in = cpu_in.detach().clone().to('mps')
        if track_grad:
            mps_in.requires_grad_()

        out_mps = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(mps_in)
        out_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_in)
        self.assertEqual(out_mps, out_cpu)

        if track_grad:
            upstream = torch.randn(out_cpu.shape)
            out_mps.backward(gradient=upstream.to('mps'))
            out_cpu.backward(gradient=upstream)
            self.assertEqual(mps_in.grad, cpu_in.grad)

    # Empty and 0-dim shapes are covered too
    bounds = [(-1, 1), (-2, -1), (3, 4)]
    for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
        for min_val, max_val in bounds:
            helper(shape, min_val, max_val)
            helper(shape, min_val, max_val, inplace=True)
def test_hardswish(self):
    """ Compare nn.Hardswish on MPS against CPU for every inplace / requires_grad combination """
    def helper(shape, inplace=False, requires_grad=True):
        module = nn.Hardswish(inplace=inplace)

        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_(requires_grad)

        if inplace and requires_grad:
            # check that both raise runtime error
            with self.assertRaises(RuntimeError):
                module(cpu_in)
            with self.assertRaises(RuntimeError):
                module(mps_in)
            return

        out_cpu = module(cpu_in)
        out_mps = module(mps_in)

        upstream_cpu = torch.ones_like(out_cpu)
        upstream_mps = upstream_cpu.to('mps')

        self.assertEqual(out_cpu, out_mps)

        if requires_grad:
            out_cpu.backward(gradient=upstream_cpu)
            out_mps.backward(gradient=upstream_mps)
            self.assertEqual(cpu_in.grad, mps_in.grad)

    for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
        for requires_grad in (False, True):
            for inplace in (False, True):
                helper(shape, inplace=inplace, requires_grad=requires_grad)
def test_transpose_2D(self):
    """ Test that transposing a 3x3 matrix on MPS matches the CPU result """
    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    cpu_transpose = torch.transpose(cpu_x, 0, 1)
    mps_transpose = torch.transpose(mps_x, 0, 1)
    # Compare on CPU so the check reads back the materialized MPS values
    self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))
def test_transpose_3D(self):
    """ Test every dimension pair of a 3-D transpose on MPS against CPU """
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    for dim0, dim1 in [(0, 1), (0, 2), (1, 2)]:
        expected = torch.transpose(cpu_x, dim0, dim1)
        actual = torch.transpose(mps_x, dim0, dim1).to('cpu')
        self.assertEqual(expected, actual)
def test_transpose_4D(self):
    """ Test every dimension pair of a 4-D transpose on MPS against CPU """
    values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
              [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    dim_pairs = [(0, 1), (0, 2), (0, 3), (3, 1), (3, 2), (1, 2)]
    for dim0, dim1 in dim_pairs:
        expected = torch.transpose(cpu_x, dim0, dim1)
        actual = torch.transpose(mps_x, dim0, dim1).to('cpu')
        self.assertEqual(expected, actual)
# Test sign
def test_sign(self):
    """ Compare torch.sign forward and backward on MPS against CPU """
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        sign_result = torch.sign(x)
        sign_result_cpu = torch.sign(cpu_x)

        cpu_grad = torch.ones_like(sign_result_cpu)
        grad = cpu_grad.to('mps')

        sign_result.backward(gradient=grad)
        sign_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(sign_result, sign_result_cpu)
        # The backward pass was run on both devices, so verify its result too
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
def test_signbit(self):
    """ Compare torch.signbit on MPS against CPU for int, float and int64 inputs """
    def helper(shape, dtype):
        cpu_in = torch.randn(shape, device='cpu').to(dtype)
        mps_in = cpu_in.clone().to('mps')

        out_mps = torch.signbit(mps_in)
        out_cpu = torch.signbit(cpu_in)
        self.assertEqual(out_mps, out_cpu)

    for dtype in (torch.int, torch.float, torch.int64):
        helper((2, 8, 4, 5), dtype)
# Test neg
def test_neg(self):
    """ Compare torch.neg forward and backward on MPS against CPU """
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        neg_result = torch.neg(x)
        neg_result_cpu = torch.neg(cpu_x)

        cpu_grad = torch.ones_like(neg_result_cpu)
        grad = cpu_grad.to('mps')

        neg_result.backward(gradient=grad)
        neg_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(neg_result, neg_result_cpu)
        # The backward pass was run on both devices, so verify its result too
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
def test_neg_strided_input(self):
    """ Test that neg() of a permuted, sliced MPS view exactly cancels the view itself """
    # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337
    base = torch.arange(18.0, device='mps').reshape(2, 3, 3)
    view = base.permute(1, 0, 2)[..., 1]
    residual = view + view.neg()
    largest = residual.abs().max().item()
    self.assertEqual(largest, 0.0)
# Test index add
def test_index_add(self):
    """ Compare torch.index_add on MPS against CPU """
    def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):
        cpu_in = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False)
        mps_in = cpu_in.detach().clone().to('mps')

        cpu_index = torch.tensor(index, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False)
        mps_source = cpu_source.detach().clone().to('mps')

        out_mps = torch.index_add(mps_in, dim=dim, index=mps_index, source=mps_source, alpha=alpha)
        out_cpu = torch.index_add(cpu_in, dim=dim, index=cpu_index, source=cpu_source, alpha=alpha)
        self.assertEqual(out_mps, out_cpu)

    cases = [
        ((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5),
        ((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0),
        ((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5),
        ((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0),
        ((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4),
        ((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0),
        # test result dim=1
        ((2,), 0, [1], (1,), 6.0),
        (2, 0, 1, 1, 6),
    ]
    for shape, dim, index, source_shape, alpha in cases:
        helper(shape, dim, index, source_shape, alpha)

    # test float16
    helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16)
def test_index_64bit(self):
    """ Test that index operations work for 4Gb+ tensors """
    if product_version < 14.0:
        raise unittest.SkipTest("Sonoma is needed for large tensors, see https://github.com/pytorch/pytorch/issues/84039")
    # Cleanup memory
    gc.collect()
    torch.mps.empty_cache()
    # Check that index operations work for 4+GB tensors
    x = torch.rand(16000, 67120, device="mps")
    try:
        self.assertGreater(x.element_size() * x.numel(), 2**32)
        idx = torch.arange(0, 2, device="mps")
        x_sampled = x[:, idx]
        self.assertEqual(x[:, 0], x_sampled[:, 0])
    finally:
        # Reclaim memory after running the test, even when an assertion above
        # fails, so the multi-GB tensor does not stay alive for later tests
        del x
        gc.collect()
        torch.mps.empty_cache()
def test_mm_large(self):
    """ Test that MM works for matrices with index larger than 32K """
    lhs = torch.rand(10, 1, device="mps")
    rhs = torch.rand(1, 32769, device="mps")
    # This used to crash with:
    # error: subRange.start (24576) is not less than length of dimension[0] (16384)
    # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
    sliced_product = torch.mm(lhs, rhs[:, 16384:32768])
    self.assertNotEqual(sliced_product.abs().max().item(), 0.0)

    def compare_mm(m, n, k, dtype=torch.float):
        a = torch.rand(m, n, device="mps", dtype=dtype)
        b = torch.rand(n, k, device="mps", dtype=dtype)
        from_mps = torch.mm(a, b).cpu()
        from_cpu = torch.mm(a.cpu(), b.cpu())
        self.assertEqual(from_mps, from_cpu)

    # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal
    compare_mm(1024, 1, 32769)
    # one more time, but with dimensions inverted
    # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
    compare_mm(32769, 1, 1025)

    if product_version >= 14.0:
        # Test bfloat16 mm
        compare_mm(1024, 1, 32769, torch.bfloat16)
@unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test")
@unittest.skipIf(product_version < 14.0, "Can't allocate 4Gb tensor on MacOS 13")
def test_copy_large(self):
    """ Test that copy of 4Gb+ tensors works """
    host = torch.ones((2**30 + 11,), dtype=torch.float32)
    device_copy = host.to(device="mps")
    one = torch.tensor(1.0, device="mps")
    self.assertTrue(torch.all(device_copy == one))
    # Drop both large buffers explicitly
    del device_copy
    del host
# Test flip
def test_flip(self):
    """ Compare torch.flip on MPS against CPU, including degenerate inputs """
    def helper(shape, dims):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_in = cpu_in.detach().clone().to('mps')

        out_mps = torch.flip(mps_in, dims=dims)
        out_cpu = torch.flip(cpu_in, dims=dims)
        self.assertEqual(out_mps, out_cpu)

    cases = [
        ((2, 8, 4, 5), [0]),
        ((8, 8, 4, 5), [0, 1]),
        ((2, 8, 4, 5), (0, 1, 2, 3)),
        ((2, 3, 3), (-1,)),
        # empty dims
        ((2, 8, 4, 5), []),
        # input.numel() == 1
        ((1,), (0,)),
        # input.numel() == 0
        ((0,), (0,)),
        # none of dims that needs to be flipped
        ((1, 3), [0]),
    ]
    for shape, dims in cases:
        helper(shape, dims)
# Test index select
def test_index_select(self):
    """ Compare torch.index_select on MPS against CPU """
    def helper(shape, dim, index, idx_dtype=torch.int32):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_in = cpu_in.detach().clone().to('mps')

        cpu_index = torch.tensor(index, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        out_mps = torch.index_select(mps_in, dim=dim, index=mps_index)
        out_cpu = torch.index_select(cpu_in, dim=dim, index=cpu_index)
        self.assertEqual(out_mps, out_cpu)

    cases = [
        ((2, 8, 4, 5), 0, [1]),
        ((8, 8, 4, 5), 0, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 1, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 2, [3, 0, 1]),
        ((2, 8, 4, 5), 3, [2, 3, 0]),
        ((2, 3, 3), -1, [1, 2]),
        # 0-dim input
        ((), 0, [0]),
        # empty index on a 1-D input of length 5
        (5, 0, []),
    ]
    for shape, dim, index in cases:
        helper(shape, dim, index)
def test_index_select_scalar(self):
    """ Compare torch.index_select on a 0-dim tensor on MPS against CPU; an empty index must raise """
    def helper(value, dim, index, idx_dtype=torch.int32):
        cpu_in = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
        mps_in = cpu_in.detach().clone().to('mps')

        cpu_index = torch.tensor(index, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        out_mps = torch.index_select(mps_in, dim=dim, index=mps_index)
        out_cpu = torch.index_select(cpu_in, dim=dim, index=cpu_index)
        self.assertEqual(out_mps, out_cpu)

    helper(22, 0, [0])
    with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
        helper(22, 0, [])
def test_embedding_dense_backward(self):
    """ Compare gradients through an nn.Embedding built with max_norm on MPS against CPU """
    def forward_backward(embedding, proj, indices):
        # weight must be cloned for this to be differentiable
        a = embedding.weight.clone() @ proj.t()
        a.retain_grad()
        # the embedding lookup modifies weight in-place
        b = embedding(indices) @ proj.t()
        b.retain_grad()
        loss = (a.unsqueeze(0) + b).sigmoid().prod()
        loss.backward()
        return a.grad, b.grad

    def helper(n, d, m, idx):
        embedding_mps = nn.Embedding(n, d, max_norm=True, device='mps')
        # Snapshot the initial weight before the MPS forward pass runs, so the
        # CPU embedding starts from the same values
        initial_weight = embedding_mps.weight.detach().cpu()
        proj_mps = torch.randn((m, d), requires_grad=True, device='mps')
        idx_mps = torch.tensor(idx, device='mps')
        a_grad_mps, b_grad_mps = forward_backward(embedding_mps, proj_mps, idx_mps)

        embedding_cpu = nn.Embedding(n, d, max_norm=True, _weight=initial_weight)
        proj_cpu = proj_mps.to('cpu')
        idx_cpu = torch.tensor(idx)
        a_grad_cpu, b_grad_cpu = forward_backward(embedding_cpu, proj_cpu, idx_cpu)

        self.assertEqual(b_grad_cpu, b_grad_mps)
        self.assertEqual(a_grad_cpu, a_grad_mps)

    helper(3, 5, 7, [0, 1, 2])
    helper(3, 6, 7, [0, 1, 2])  # verify if changes in shape would cause cached graph lookup problems
    helper(3, 5, 7, 2)  # test scalar index
# Test pytorch gather
def test_gather(self):
    """ Compare torch.gather forward and backward on MPS against CPU """
    def helper(shape, dim, idx_shape, idx_dtype=torch.int64):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        raw_index = np.random.randint(0, shape[dim], idx_shape)
        cpu_index = torch.tensor(raw_index, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        out_mps = torch.gather(mps_in, dim=dim, index=mps_index)
        out_cpu = torch.gather(cpu_in, dim=dim, index=cpu_index)

        # The output has the index's shape, and so does the upstream gradient
        upstream = torch.randn(idx_shape, device='cpu', dtype=torch.float)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(cpu_in.grad, mps_in.grad)

    cases = [
        ((6, 3, 3), 0, (3, 3, 3)),
        ((2, 3, 3, 3), 0, (10, 3, 3, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5)),
        ((2, 8, 4, 5), 0, (10, 6, 3, 2)),
        ((8, 8, 4, 5), 0, (6, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (6, 7, 2, 3)),
        ((2, 8, 4, 5), 1, (2, 5, 3, 4)),
        ((2, 8, 4, 5), 2, (1, 8, 10, 3)),
        ((2, 8, 4, 5), 3, (2, 5, 3, 12)),
    ]
    for shape, dim, idx_shape in cases:
        helper(shape, dim, idx_shape)
# Test pytorch gather for scalar input
def test_gather_scalar(self):
    """ Compare torch.gather on a 0-dim input (forward and backward) on MPS against CPU """
    idx_dtype = torch.int64
    cpu_in = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
    mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

    cpu_index = torch.tensor([0], device='cpu', dtype=idx_dtype)
    mps_index = cpu_index.detach().clone().to('mps')

    out_mps = torch.gather(mps_in, dim=0, index=mps_index)
    out_cpu = torch.gather(cpu_in, dim=0, index=cpu_index)

    upstream = torch.randn([1], device='cpu', dtype=torch.float)
    out_mps.backward(gradient=upstream.to('mps'))
    out_cpu.backward(gradient=upstream)

    self.assertEqual(out_mps, out_cpu)
    self.assertEqual(cpu_in.grad, mps_in.grad)
# Test pytorch scatter_add and scatter
def test_scatter_add(self):
    """ Compare torch.scatter_add and torch.scatter (with src) on MPS against CPU """
    def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_src = cpu_src.detach().clone().to('mps').requires_grad_()

        if do_add:
            # Indices should be taken from range of axis along which gathering is done
            raw_index = np.random.randint(0, shape[dim], idx_shape)
            scatter_fn = torch.scatter_add
        else:
            # Fixed index for plain scatter (idx_shape is not used to build it)
            raw_index = np.array([[0, 1, 2],
                                  [1, 2, 3],
                                  [2, 3, 4],
                                  [3, 4, 5],
                                  [4, 5, 6]])
            scatter_fn = torch.scatter

        cpu_index = torch.tensor(raw_index, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        out_mps = scatter_fn(mps_in, dim=dim, index=mps_index, src=mps_src)
        out_cpu = scatter_fn(cpu_in, dim=dim, index=cpu_index, src=cpu_src)

        # Backward is only exercised when index and src shapes agree
        check_grad = idx_shape == src_shape
        if check_grad:
            upstream = torch.randn(shape, device='cpu', dtype=torch.float)
            out_mps.backward(gradient=upstream.to('mps'))
            out_cpu.backward(gradient=upstream)

        self.assertEqual(out_mps, out_cpu)
        if check_grad:
            self.assertEqual(cpu_in.grad, mps_in.grad)
            self.assertEqual(cpu_src.grad, mps_src.grad)

    add_cases = [
        ((2, 3), 0, (5, 3), (5, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)),
        ((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)),
        ((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)),
        ((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)),
        ((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)),
        ((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)),
    ]
    for shape, dim, idx_shape, src_shape in add_cases:
        helper(shape, dim, idx_shape, src_shape)

    # Test scatter src
    helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
    helper((10, 3), 0, (5, 3), (5, 8), do_add=False)
# Test pytorch scatter_add and scatter for scalar input
def test_scatter_add_scalar(self):
    """ Compare torch.scatter_add and torch.scatter on 0-dim inputs on MPS against CPU """
    def helper(idx_dtype=torch.int64, do_add=True):
        cpu_in = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
        mps_src = cpu_src.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        cpu_index = torch.tensor([0], device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        scatter_fn = torch.scatter_add if do_add else torch.scatter
        out_mps = scatter_fn(mps_in, dim=0, index=mps_index, src=mps_src)
        out_cpu = scatter_fn(cpu_in, dim=0, index=cpu_index, src=cpu_src)

        upstream = torch.tensor(1.2, device='cpu', dtype=torch.float)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(cpu_in.grad, mps_in.grad)
        self.assertEqual(cpu_src.grad, mps_src.grad)

    helper()
    helper(do_add=False)
# Test pytorch scatter_reduce
def test_scatter_reduce(self):
    """ Compare torch.scatter with reduce="add" / "multiply" on MPS against CPU """
    def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_src = cpu_src.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        raw_index = np.random.randint(0, shape[dim], idx_shape)
        cpu_index = torch.tensor(raw_index, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        out_mps = torch.scatter(mps_in, dim=dim, index=mps_index, src=mps_src, reduce=reduce_str)
        out_cpu = torch.scatter(cpu_in, dim=dim, index=cpu_index, src=cpu_src, reduce=reduce_str)
        self.assertEqual(out_mps, out_cpu)

    cases = [
        ((2, 3), 0, (5, 3), (5, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)),
        ((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)),
        ((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)),
        ((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)),
        ((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)),
        ((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)),
    ]
    # for reduce in ["sum", "prod", "amax", "amin"]:
    for reduce_type in ["add", "multiply"]:
        for shape, dim, idx_shape, src_shape in cases:
            helper(shape, dim, idx_shape, src_shape, reduce_str=reduce_type)
def test_is_nonzero(self):
    """ Test torch.is_nonzero on single-element MPS tensors of several dtypes """
    zero_valued = [torch.tensor([0.]), torch.tensor([False])]
    nonzero_valued = [torch.tensor([1.5]), torch.tensor([3])]

    self.assertFalse(torch.is_nonzero(zero_valued[0].to('mps')))
    self.assertTrue(torch.is_nonzero(nonzero_valued[0].to('mps')))
    self.assertFalse(torch.is_nonzero(zero_valued[1].to('mps')))
    self.assertTrue(torch.is_nonzero(nonzero_valued[1].to('mps')))
# Test triu
def test_triu(self):
    """ Compare torch.triu forward and backward on MPS against CPU for several diagonals """
    def helper(shape, diag=0):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        out_mps = torch.triu(mps_in, diag)
        out_cpu = torch.triu(cpu_in, diag)

        upstream = torch.randn(out_cpu.shape)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(mps_in.grad, cpu_in.grad)

    for diag in (0, 1, 2, 3, -1, -2, -3):
        helper((2, 8, 4, 5), diag=diag)
# Test inverse
def test_inverse(self):
    """ Compare torch.linalg.inv of random square matrices on MPS against CPU """
    def helper(n):
        cpu_matrix = torch.randn(n, n, device='cpu')
        mps_matrix = cpu_matrix.to('mps')

        expected = torch.linalg.inv(cpu_matrix)
        actual = torch.linalg.inv(mps_matrix)
        self.assertEqual(expected, actual)

    for n in (2, 6, 3, 8):
        helper(n)
# Test tril
def test_tril(self):
    """ Compare torch.tril forward and backward on MPS against CPU for several diagonals """
    def helper(shape, diag=0):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        out_mps = torch.tril(mps_in, diag)
        out_cpu = torch.tril(cpu_in, diag)

        upstream = torch.randn(out_cpu.shape)
        out_mps.backward(gradient=upstream.to('mps'))
        out_cpu.backward(gradient=upstream)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(mps_in.grad, cpu_in.grad)

    for diag in (0, 1, 2, 3, -1, -2, -3):
        helper((2, 8, 4, 5), diag=diag)
# test eye
def test_eye(self):
    """ Compare torch.eye on MPS against CPU for square and non-square shapes and every listed dtype """
    def helper(n, m, dtype):
        # Always forward dtype: the non-square case used to drop it, so
        # non-square matrices were only ever checked with the default dtype.
        cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
        result = torch.eye(n, m, dtype=dtype, device='mps')
        self.assertEqual(result, cpu_result)

    for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]:
        helper(2, 2, dtype)
        helper(2, 3, dtype)
        helper(0, 2, dtype)
        helper(0, 0, dtype)
        helper(3, 8, dtype)
        helper(8, 3, dtype)
# Test diag
def test_diag(self):
    """ Compare torch.diag (forward only) on MPS against CPU for 1-D and 2-D inputs """
    def helper(shape, diag=0):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        out_mps = torch.diag(mps_in, diag)
        out_cpu = torch.diag(cpu_in, diag)

        # NOTE: only the forward result is compared; backward is not exercised here.
        self.assertEqual(out_mps, out_cpu)

    shapes = [(5, 5), (5, 6), (6, 5), (5,), (6,)]
    diagonals = [0, 1, 2, 3, 4, -1, -2, -3, -4]
    for shape in shapes:
        for diag in diagonals:
            helper(shape, diag=diag)
# Test linspace
def test_linspace(self):
    """ Compare torch.linspace on MPS against a NumPy-built reference """
    def helper(start, end, steps, dtype=torch.float32):
        expected = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
        actual = torch.linspace(start, end, steps, dtype=dtype, device='mps')
        self.assertEqual(expected, actual)

    # (start, end, steps): increasing, constant, decreasing, zero steps
    ranges = [(2, 5, 10), (2, 2, 10), (5, 2, 10), (2, 2, 0)]
    for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
        for start, end, steps in ranges:
            helper(start, end, steps, dtype)
# Test arange
def test_arange(self):
    # Each MPS arange result is compared with the equivalent NumPy arange.
    expected = np.arange(10)
    self.assertEqual(expected, torch.arange(10, device='mps'))

    # negative step
    expected = np.arange(7, 1, -1)
    self.assertEqual(expected, torch.arange(7, 1, -1, device='mps'))

    # fractional step
    expected = np.arange(1, 2, .3, dtype=np.float32)
    self.assertEqual(expected, torch.arange(1, 2, .3, device='mps'))

    # fractional end
    expected = np.arange(6.3, dtype=np.float32)
    self.assertEqual(expected, torch.arange(6.3, device='mps'))
def test_arange_empty(self):
    # An empty arange written into an `out=` tensor must agree across devices.
    results = {}
    for device in ("mps", "cpu"):
        out = torch.tensor([], device=device)
        results[device] = torch.arange(0, 0, 1, out=out)
    self.assertEqual(results["mps"], results["cpu"])
# Test range
def test_range(self):
    # Compare torch.range on MPS with NumPy references. The references are
    # written so that they contain the end point where torch.range does.
    self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
    self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
    self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
    # Fractional end point. This assertion used to call torch.arange, which
    # duplicated test_arange and left torch.range untested for this input.
    self.assertEqual(np.arange(6.3, dtype=np.float32), torch.range(0, 6.3, device='mps'))
# Test softmax
def test_softmax(self):
    # Softmax parity between MPS and CPU: forward always, backward only for
    # inputs that are not channels-last.
    def helper(shape, dim, channels_last=False):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        mps_out = F.softmax(x, dim=dim)
        cpu_out = F.softmax(cpu_x, dim=dim)

        # Backward is currently NOT tested for channels-last inputs.
        check_backward = not channels_last
        if check_backward:
            cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')
            mps_out.backward(gradient=grad)
            cpu_out.backward(gradient=cpu_grad)

        self.assertEqual(mps_out, cpu_out)
        if check_backward:
            self.assertEqual(x.grad, cpu_x.grad)

    def helper2(dim):
        # Same check for a 0-dim (scalar) input.
        cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        mps_out = F.softmax(x, dim=dim)
        cpu_out = F.softmax(cpu_x, dim=dim)

        cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float)
        mps_out.backward(gradient=cpu_grad.to('mps'))
        cpu_out.backward(gradient=cpu_grad)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(x.grad, cpu_x.grad)

    helper2(0)

    shapes = [(2, 4, 8, 5), (3, 4, 6, 7, 2)]
    dims = [0, 1, 2, 3, -1, -2, -3]
    for channels_last, shape in product([False], shapes):
        # channels_last only applies to 4-D inputs
        if channels_last and len(shape) != 4:
            continue
        for dim in dims:
            helper(shape, dim, channels_last)
def test_nan_to_num(self):
    # nan / +inf / -inf replacement must agree between MPS and CPU.
    cpu_in = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
    mps_in = cpu_in.detach().clone().to('mps').requires_grad_()
    replacements = dict(nan=2.0, posinf=1.0, neginf=-1.0)
    cpu_out = torch.nan_to_num(cpu_in, **replacements)
    mps_out = torch.nan_to_num(mps_in, **replacements)
    self.assertEqual(mps_out, cpu_out)
# Test where
def test_where(self):
    # torch.where forward/backward parity between MPS and CPU, including
    # broadcasting between condition, x and y, plus an `out=` resize check.
    def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):
        cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
        cond = cpu_cond.detach().clone().to('mps')

        cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
        y = cpu_y.detach().clone().to('mps').requires_grad_()

        cpu_out = torch.where(cpu_cond, cpu_x, cpu_y)
        out = torch.where(cond, x, y)

        cpu_grad = torch.randn(cpu_out.shape)
        cpu_out.backward(gradient=cpu_grad)
        out.backward(gradient=cpu_grad.to('mps'))

        self.assertEqual(out, cpu_out)
        self.assertEqual(x.grad, cpu_x.grad)
        self.assertEqual(y.grad, cpu_y.grad)

    # (condition shape, x shape, y shape)
    same_shape_cases = [(s, s, s) for s in [(0, 3), [], (2, 3), (9,)]]
    broadcast_cases = [
        ((2, 3, 1), (2, 3, 4), (2, 1, 4)),
        ((2, 1, 1), (2, 3, 4), (1, 3, 4)),
        ((1, 1, 1), (1, 1, 4), (2, 3, 1)),
        ([], (1, 1, 4), (2, 3, 1)),
        ([], (2, 3, 4), []),
        ((5, 2, 3), (2, 3), (2, 3)),
        ((2, 3), (5, 2, 3), (2, 3)),
        ((2, 3), (2, 3), (5, 2, 3)),
        ((2, 3), (5, 2, 3), (6, 5, 2, 3)),
    ]
    for cond_shape, x_shape, y_shape in same_shape_cases + broadcast_cases:
        helper(cond_shape, x_shape, y_shape)

    # Test that the output is correctly resized
    # TODO: Remove me when out OpInfo testing is enabled on MPS
    output = torch.tensor(0.0, device="mps")
    cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
    inp = torch.rand(3, 3, device="mps")
    other = torch.rand(3, 3, device="mps")
    out = torch.where(cond, inp, other, out=output)
    # the returned tensor must be the very object passed as `out`
    self.assertIs(out, output)
    self.assertEqual(out.shape, (3, 3))
# Test normal
def test_normal(self):
    # Smoke test for the torch.normal overloads on MPS (scalar/tensor mean and
    # std, with and without `out=`); only output sizes are asserted.
    def helper(shape, mean=0.0, std=1.0):
        def constant_pair(value):
            # CPU tensor of `shape` filled with `value`, plus its MPS copy.
            filled = np.ones(shape)
            filled *= value
            cpu_tensor = torch.tensor(filled, device='cpu', dtype=torch.float, requires_grad=False)
            return cpu_tensor, cpu_tensor.detach().clone().to('mps')

        mps_out = torch.normal(mean, std, shape, device='mps')

        _, mean_tensor = constant_pair(mean)
        _, std_tensor = constant_pair(std)

        # test out
        for mean_arg, std_arg in ((mean_tensor, std), (mean, std_tensor), (mean_tensor, std_tensor)):
            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_arg, std_arg, out=mps_out)

        # test without out
        mps_out = torch.normal(mean_tensor, std)
        self.assertEqual(mps_out.size(), mean_tensor.size())

        mps_out = torch.normal(mean, std_tensor)
        self.assertEqual(mps_out.size(), std_tensor.size())

        inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
        mps_out = torch.normal(mean_tensor, std_tensor)
        self.assertEqual(mps_out.size(), inferred_shape)

    helper((2, 3, 4, 5, 6))
    helper((100, 100), 2.5, 1.2)
def test_bernoulli(self):
    # Sanity checks for bernoulli sampling on MPS.
    shape = (10, 10)
    all_ones = torch.ones(shape, device='mps')
    all_zeros = torch.zeros(shape, device='mps')

    # p = 0.5: mean and std cannot be checked reliably, so only make sure
    # the sample is not constant.
    half_sample = torch.bernoulli(all_ones * 0.5).to('cpu')
    self.assertNotEqual(half_sample.mean(), 0.)
    self.assertNotEqual(half_sample.std() ** 2, 0.)

    # p = 0 and p = 1 are deterministic
    self.assertEqual(torch.bernoulli(all_zeros), all_zeros)
    self.assertEqual(torch.bernoulli(all_ones), all_ones)

    # Check it works for different dtypes
    for dtype in (torch.float16, torch.int8, torch.int16, torch.int32, torch.int64):
        sample = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5)
        # Check that output is not all zeros or ones
        if product_version > 13.0:
            self.assertEqual(sample.unique(), torch.arange(2, device='mps', dtype=dtype))
        else:
            self.assertEqual(sample.min().item(), 0.)
            self.assertEqual(sample.max().item(), 1.)
def test_mps_generator(self):
    # Seeding and state save/restore on an explicitly created MPS Generator.
    def draw(generator):
        return torch.randn(5, device='mps', generator=generator)

    gen = torch.Generator(device='mps')

    # identical seeds must give identical samples
    gen.manual_seed(999)
    first = draw(gen)
    gen.manual_seed(999)
    second = draw(gen)  # drawn at offset `0`
    self.assertEqual(first, second)

    # save generator's state (offset = 1) to restore it later
    saved_state = gen.get_state()

    # the draw at offset `1` must differ from the previous one
    after_save = draw(gen)
    self.assertNotEqual(after_save, second)

    # restoring the saved state must reproduce the draw made right after saving
    gen.set_state(saved_state)
    replayed = draw(gen)
    self.assertEqual(replayed, after_save)
@serialTest()
def test_default_mps_generator(self):
    # Same checks as for an explicit generator, but on the "default" MPS
    # generator reached through the global seeding / RNG-state APIs.
    def draw():
        return torch.randn(5, device='mps')

    # the global torch.manual_seed() seeds the default MPS generator ...
    torch.manual_seed(230)
    first = draw()
    # ... and torch.mps.manual_seed() should set the same generator
    torch.mps.manual_seed(230)
    second = draw()  # drawn at offset `0`
    self.assertEqual(first, second)

    # save the default generator's state (offset = 1) to restore it later
    saved_state = torch.mps.get_rng_state()

    # the draw at offset `1` must differ from the previous one
    after_save = draw()
    self.assertNotEqual(after_save, second)
    # since we called randn twice after seeding, the offset should be 2
    self.assertEqual(torch.mps._get_default_mps_generator().get_offset(), 2)

    # restoring the saved state must reproduce the draw made right after saving
    torch.mps.set_rng_state(saved_state)
    replayed = draw()
    self.assertEqual(replayed, after_save)
def test_device_synchronize(self):
    # Run a few ops, each followed by torch.mps.synchronize() so that the
    # MPS stream has finished before the next op is issued.
    deconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    deconv = deconv.to(device='mps', dtype=torch.float)

    inp = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
    torch.mps.synchronize()

    out = deconv(inp)
    torch.mps.synchronize()

    out.backward(torch.randn_like(out))
    torch.mps.synchronize()
@serialTest()
def test_mps_allocator_module(self):
    # Start from a clean slate: garbage collect and empty the cached blocks.
    gc.collect()
    torch.mps.empty_cache()

    # After the cleanup above MPSAllocator must report zero bytes in use.
    allocated_before = torch.mps.current_allocated_memory()
    self.assertEqual(allocated_before, 0)
    # total memory allocated from the Metal driver
    driver_before = torch.mps.driver_allocated_memory()

    # Allocate a new 8 MB tensor to force allocation of a new Metal Heap.
    # The local reference keeps the tensor alive while the counters are read.
    x = torch.ones(1024 * 1024 * 8, device="mps")

    allocated_after = torch.mps.current_allocated_memory()
    driver_after = torch.mps.driver_allocated_memory()

    # both counters must have grown at this point
    self.assertGreater(allocated_after, allocated_before)
    self.assertGreater(driver_after, driver_before)
def test_mps_allocator_stats(self):
    # The recommended maximum memory is printed (in GB) and must be positive.
    max_memory = torch.mps.recommended_max_memory()
    message = f"Recommended Max Memory : {max_memory/ 1024 ** 3} GB"
    print(message)
    self.assertGreater(max_memory, 0)
# to verify this test, run XCode Instruments "Metal System Trace" or "Logging" tool,
# press record, then run this python test, and press stop. Next expand
# the os_signposts->PyTorchMPS and check if events or intervals are logged
# like this example:
# "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
def test_mps_profiler_module(self):
    # Exercises both profiler entry points: the `profile` context manager
    # (event mode) and the explicit start()/stop() pair (interval mode).
    def random_input():
        return torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)

    with torch.mps.profiler.profile(mode="event", wait_until_completed=False):
        # just running some ops to capture the OS Signposts traces for profiling
        deconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        deconv = deconv.to(device='mps', dtype=torch.float)
        x = deconv(random_input())

    torch.mps.profiler.start(mode="interval", wait_until_completed=True)
    # just running some ops to capture the OS Signposts traces for profiling
    x = deconv(random_input())
    torch.mps.profiler.stop()
def test_mps_event_module(self):
    # Two timing-enabled events recorded around some MPS work must report a
    # positive elapsed time.
    start_event = torch.mps.Event(enable_timing=True)
    start_event.record()

    deconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    deconv = deconv.to(device='mps', dtype=torch.float)
    inp = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
    out = deconv(inp)

    end_event = torch.mps.Event(enable_timing=True)
    end_event.record()

    self.assertGreater(start_event.elapsed_time(end_event), 0.0)
def test_jit_save_load(self):
    # A scripted module holding an MPS tensor attribute must survive a
    # save/load round trip through an in-memory buffer.
    module = torch.nn.Module()
    module.x = torch.rand(3, 3, device='mps')

    stream = io.BytesIO()
    scripted = torch.jit.script(module)
    torch.jit.save(scripted, stream)
    stream.seek(0)

    restored = torch.jit.load(stream)
    self.assertEqual(restored.x, module.x)
# Test random_, random_.to and random_.from
def test_random(self):
    def helper(shape, low, high, dtype=torch.int32):
        samples = torch.randint(low, high, shape, dtype=dtype, device='mps')
        # We can't check reliably the mean and std.
        # Just make sure we don't return constant values
        self.assertNotEqual(samples.float().mean().item(), 0.)
        self.assertNotEqual(samples.float().std().item(), 0.)

    # (low, high, extra keyword arguments)
    randint_cases = (
        (0, 10, {}),
        (23, 89, {}),
        (23, 89, {"dtype": torch.float32}),
        (23, 89, {"dtype": torch.int64}),
        (0, 2, {"dtype": torch.bool}),
    )
    for low, high, kwargs in randint_cases:
        helper([100, 100], low, high, **kwargs)

    # Test random_
    for dtype in (torch.bool, torch.int8, torch.uint8, torch.int32, torch.float16, torch.float32):
        filled = torch.empty(10, 10, dtype=dtype, device='mps')
        filled.random_()
        self.assertNotEqual(filled.max().item(), 0)
# Test exponential
def test_exponential(self):
    # Fills MPS tensors via exponential_ and prints the sample mean/variance
    # next to 1/lamda and 1/lamda**2; nothing is asserted.
    def helper(shape, lamda, dtype=torch.float32):
        sample = torch.zeros(shape, device='mps', dtype=dtype)
        sample.exponential_(lamda)

        print(sample.to('cpu').float().mean(), 1 / lamda)
        print(sample.to('cpu').float().std() ** 2, 1 / (lamda**2))

    for dtype in (torch.float32, torch.float16):
        for lamda in (2, 1, 3, 0.5):
            helper([100, 100], lamda, dtype)
def test_exponential_1(self):
    # Shape checks for samples drawn from torch.distributions.Exponential.
    rate = torch.randn(5, 5).abs().requires_grad_()
    rate_1d = torch.randn(1).abs().requires_grad_()

    # tensor-valued rates
    sample_size = Exponential(rate).sample().size()
    self.assertEqual(sample_size, (5, 5))
    sample_size = Exponential(rate).sample((7,)).size()
    self.assertEqual(sample_size, (7, 5, 5))
    sample_size = Exponential(rate_1d).sample((1,)).size()
    self.assertEqual(sample_size, (1, 1))
    sample_size = Exponential(rate_1d).sample().size()
    self.assertEqual(sample_size, (1,))

    # python-float rates
    for scalar_rate in (0.2, 50.0):
        sample_size = Exponential(scalar_rate).sample((1,)).size()
        self.assertEqual(sample_size, (1,))
# Test add
def test_add_sub(self):
    # add/sub parity between MPS and CPU, out-of-place and in-place, with an
    # `alpha` multiplier and with 0-dim operands on either side.
    def helper(shape, alpha, op_name, inplace):
        if op_name == "add":
            op = torch.Tensor.add_ if inplace else torch.add
        elif op_name == "sub":
            op = torch.Tensor.sub_ if inplace else torch.sub

        def scalar_pair(dtype):
            # fresh 0-dim tensor on CPU plus its MPS copy
            cpu_scalar = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
            return cpu_scalar, cpu_scalar.detach().clone().to('mps')

        for dtype in (torch.float16, torch.float32):
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            mps_y = cpu_y.detach().clone().to('mps')

            cpu_out = op(cpu_x, cpu_y, alpha=alpha)
            mps_out = op(mps_x, mps_y, alpha=alpha)
            # fp16 isn't accurate when alpha is passed
            # TODO: remove or fix 'tol' when we fix problems with fp16
            tol = 2e-3 if dtype is torch.float16 else None
            self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol)

            # in-place output cannot be broadcasted, so a scalar primary
            # operand is only valid out-of-place or when y is 0-dim as well.
            if cpu_y.shape == () or not inplace:
                cpu_s, mps_s = scalar_pair(dtype)
                # primary tensor is scalar
                self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y))

            # a fresh scalar is needed: the in-place op above may have modified the previous one
            cpu_s, mps_s = scalar_pair(dtype)
            # secondary tensor is scalar
            self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol)

    # (shape, alpha)
    configs = (
        ((), 0.0),
        ((2, 8, 4, 5), 0.0),
        ((2, 8, 4, 5), 0.1),
        ((2, 8, 4, 5), 1.0),
        ((2, 8, 3, 5), 0.1),
        ((2, 8, 3, 5), 0.2),
    )
    for op_name, inplace in product(["add", "sub"], [True, False]):
        for shape, alpha in configs:
            helper(shape, alpha, op_name, inplace)
# Test add
def test_add_scalars(self):
    # torch.add on two 0-dim tensors with an alpha multiplier, MPS vs CPU.
    def helper(alpha):
        for dtype in (torch.float16, torch.float32):
            cpu_x = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            expected = torch.add(cpu_x, cpu_y, alpha=alpha)
            actual = torch.add(x, y, alpha=alpha)
            # fp16 isn't accurate when alpha is passed
            tol = 1e-3 if dtype is torch.float16 else None
            self.assertEqual(actual, expected, rtol=tol, atol=tol)

    for alpha in (1.0, 0.0, 0.1, 0.2):
        helper(alpha)

    # Test int32 tensor + int64 scalar add
    # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
    ones = torch.ones(4, dtype=torch.int32, device='mps')
    int_expected = torch.full((4,), 2, dtype=torch.int32, device='mps')
    self.assertEqual(ones + 1, int_expected)
    float_expected = torch.full((4,), 2.5, device='mps')
    self.assertTrue(torch.equal(ones + 1.5, float_expected))
def test_types_binary_op(self):
    # Multiplying a float32 tensor by a bool / int64 tensor must give the
    # same result on MPS as on CPU.
    def scaled_range(mask, device):
        values = torch.arange(5, dtype=torch.float32, device=device)
        return values * torch.tensor(mask, device=device)

    # Float * Bool
    bool_mask = [True, False, True, False, True]
    cpu_x = scaled_range(bool_mask, "cpu")
    mps_x = scaled_range(bool_mask, "mps")
    self.assertEqual(cpu_x, mps_x)

    # Float * Int64
    int_mask = [1, 0, 1, 0, 1]
    cpu_y = scaled_range(int_mask, "cpu")
    mps_y = scaled_range(int_mask, "mps")
    self.assertEqual(cpu_y, mps_y)
def test_unary_ops(self):
    # Unary op parity between MPS and CPU on dense float and integer inputs,
    # on a strided slice and on a reshaped view.
    def helper(shape, op):
        def float_pair(dtype):
            cpu_t = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            return cpu_t, cpu_t.detach().clone().to('mps')

        for dtypef in (torch.float32,):
            cpu_x, mps_x = float_pair(dtypef)
            self.assertEqual(op(cpu_x), op(mps_x))

        for dtypei in (torch.int32, torch.int16):
            cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=dtypei, requires_grad=False)
            mps_x = cpu_x.to('mps')
            self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4)

        # test slice
        for dtypef in (torch.float32,):
            cpu_x, mps_x = float_pair(dtypef)
            self.assertEqual(op(cpu_x[:, ::2, :, :]), op(mps_x[:, ::2, :, :]))

        # test view
        for dtypef in (torch.float32,):
            cpu_x, mps_x = float_pair(dtypef)
            # create view of tensor by merging the last two dimensions
            merged_shape = [*shape[:-2], shape[-1] * shape[-2]]
            self.assertEqual(op(cpu_x.view(*merged_shape)), op(mps_x.view(*merged_shape)))

    helper((2, 8, 4, 5), torch.exp)
    for unary_op in (torch.exp2, torch.expm1, torch.log, torch.cos, torch.erfinv):
        helper((2, 8, 3, 5), unary_op)
def test_non_dense_in_storage_unary_ops(self):
    # Unary ops applied to an every-other-element slice of a 1-D tensor.
    def helper(op):
        for dtypef in (torch.float32,):
            cpu_x = torch.randn(100, device='cpu', dtype=dtypef, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')
            cpu_strided, mps_strided = cpu_x[::2], mps_x[::2]
            self.assertEqual(op(cpu_strided), op(mps_strided))

        for dtypei in (torch.int32, torch.int16, torch.int8):
            cpu_x = torch.randint(127, device='cpu', size=(100,), dtype=dtypei, requires_grad=False)
            mps_x = cpu_x.to('mps')
            cpu_strided, mps_strided = cpu_x[::2], mps_x[::2]
            self.assertEqual(op(cpu_strided), op(mps_strided), rtol=1e-4, atol=1e-4)

    for unary_op in (torch.exp, torch.exp2, torch.expm1, torch.log, torch.cos):
        helper(unary_op)
def test_unary_ops_storage_offset_strided(self):
    def helper(shape, op, inplace, dtype=torch.float32):
        def random_pair():
            cpu_t = torch.randn(shape, device='cpu', dtype=dtype)
            return cpu_t, cpu_t.detach().clone().to('mps')

        # input with a storage offset: apply the op to row 1 of the tensor
        cpu_x, mps_x = random_pair()
        mps_row_result = op(mps_x[1])
        cpu_row_result = op(cpu_x[1])
        self.assertEqual(mps_row_result, cpu_row_result)

        if inplace:
            return

        # Out-of-place ops are also checked with a transposed `out=` tensor.
        # See https://github.com/pytorch/pytorch/issues/100764
        cpu_x, mps_x = random_pair()
        cpu_out = torch.empty(shape, device='cpu', dtype=dtype).t()
        mps_out = cpu_out.detach().clone().to('mps')
        op(cpu_x, out=cpu_out)
        op(mps_x, out=mps_out)
        self.assertEqual(mps_out, cpu_out)

    for out_of_place_op in (torch.exp, torch.cos, torch.neg, torch.tanh):
        helper((5, 5), out_of_place_op, False)
    helper((5, 5), torch.tanh_, True)
def test_atan2(self):
    # torch.atan2 parity between MPS and CPU for several input sizes.
    def helper(shape):
        cpu_input = torch.randn(shape)
        mps_input = cpu_input.detach().clone().to("mps")

        cpu_other = torch.randn(shape)
        mps_other = cpu_other.detach().clone().to("mps")

        expected = torch.atan2(cpu_input, cpu_other)
        actual = torch.atan2(mps_input, mps_other)

        self.assertEqual(expected, actual.to("cpu"))

    for shape in (4, 10000, (10000, 40)):
        helper(shape)
def test_multinomial(self):
    # Draws from torch.multinomial on MPS and prints the outcome; nothing is
    # asserted beyond the calls completing.
    # Test with num_dist = 1
    def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
        cpu_probs = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
        mps_probs = cpu_probs.detach().clone().to('mps')

        drawn = torch.multinomial(mps_probs, num_samples, replacement=replacement)
        if replacement:
            # Compare "real" with theoretical values
            print(drawn.to('cpu').float().mean(), compare_mean)
            print(drawn.to('cpu').float().std() ** 2, compare_var)
        else:
            print(drawn.to('cpu'))

    # TODO: Add tests for data types
    uniform_mean = (0 + 1 + 2 + 3 + 4) / 5
    uniform_var = (6 - 2 * 2)
    helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)
    helper(np.array([[.2, .2, .2, .2, .2]]), uniform_mean, uniform_var, 10000)
    helper(np.array([[1, 1, 1, 1, 1]]), uniform_mean, uniform_var, 10000)
    helper(np.array([1, 1, 1, 1, 1]), uniform_mean, uniform_var, 10000)
    helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
def test_cumsum_dim_check(self):
    # Negative dims alias the positive ones; out-of-range dims raise IndexError.
    x = torch.rand((3, 3), device="mps")
    for positive_dim, negative_dim in ((1, -1), (0, -2)):
        self.assertEqual(x.cumsum(positive_dim), x.cumsum(negative_dim))
    for bad_dim in (2, -3):
        with self.assertRaises(IndexError):
            x.cumsum(bad_dim)
def test_cumprod_dim_check(self):
    # Negative dims alias the positive ones; out-of-range dims raise IndexError.
    x = torch.rand((3, 3), device="mps")
    for positive_dim, negative_dim in ((1, -1), (0, -2)):
        self.assertEqual(x.cumprod(positive_dim), x.cumprod(negative_dim))
    for bad_dim in (2, -3):
        with self.assertRaises(IndexError):
            x.cumprod(bad_dim)
class TestLogical(TestCaseMPS):
    # MPS-vs-CPU parity tests for logical ops, min/max reductions and isin.

    def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
        # Convenience wrapper around torch.tensor that defaults to the CPU.
        return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)

    def _check_binary_logical_op(self, op):
        # Apply the binary logical `op` to a fixed set of operand pairs and
        # require the MPS result to equal the CPU result.
        wrap = self._wrap_tensor
        operand_pairs = [
            (wrap([1, 1, 0, 0]), wrap([1, 0, 0, 1])),
            (
                wrap([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
                wrap([1, 0, 0, 1], dtype=torch.float),
            ),
            (wrap([True, True, False, False]), wrap([True, False, False, True])),
            # 1-D tensor against a 0-dim operand
            (wrap((1, 0, 1, 0)), wrap(1)),
            (wrap((1, 0, 1, 0)), wrap(0)),
            (wrap((1, 0, 1, 0)), wrap(True)),
            (wrap((1, 0, 1, 0)), wrap(False)),
        ]
        for cpu_x, cpu_other in operand_pairs:
            x = cpu_x.detach().clone().to('mps')
            other = cpu_other.detach().clone().to('mps')

            result = op(x, other)
            result_cpu = op(cpu_x, cpu_other)

            self.assertEqual(result, result_cpu)

    def test_logical_not(self):
        def helper(x):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.logical_not(x)
            result_cpu = torch.logical_not(cpu_x)

            self.assertEqual(result, result_cpu)

        helper(self._wrap_tensor([1, 1, 0, 0]))
        helper(self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True))
        helper(self._wrap_tensor([True, True, False, False]))
        helper(self._wrap_tensor(1))
        helper(self._wrap_tensor(0))
        helper(self._wrap_tensor(True))
        helper(self._wrap_tensor(False))

    def test_logical_and(self):
        self._check_binary_logical_op(torch.logical_and)

    def test_logical_or(self):
        self._check_binary_logical_op(torch.logical_or)

    def test_logical_xor(self):
        self._check_binary_logical_op(torch.logical_xor)

    def test_min_max(self):
        # Full reductions max()/min() on random inputs, MPS vs CPU.
        def helper(dtype):
            for _ in range(10):
                if dtype in (torch.float32, torch.float16):
                    x = torch.randn((30, 15), device='mps', dtype=dtype)
                else:
                    x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
                x_cpu = x.to("cpu")

                self.assertEqual(x.max(), x_cpu.max())
                self.assertEqual(x.min(), x_cpu.min())

        for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]:
            helper(dtype)

    def test_min_max_nan_propagation(self):
        # max/amax/min/amin on an input containing NaN must agree with the CPU.
        def helper(dtype):
            # Create the input in the requested dtype. The dtype argument used
            # to be ignored, so float32 was silently tested three times.
            cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu", dtype=dtype)
            mps_x = cpu_x.detach().clone().to('mps')

            for reduction in (torch.max, torch.amax, torch.min, torch.amin):
                self.assertEqual(reduction(cpu_x), reduction(mps_x).to('cpu'))

        tested_dtypes = [torch.float32, torch.float16]
        # NOTE(review): bfloat16 on MPS presumably needs macOS 14 or newer;
        # confirm the version gate against the rest of this file.
        if product_version >= 14.0:
            tested_dtypes.append(torch.bfloat16)
        for dtype in tested_dtypes:
            helper(dtype)

    def test_isin(self):
        def helper(dtype):
            shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),
                      ([5], [10]), ([0], [5]), ([5], [0])]
            for elements_shape, test_elements_shape in shapes:
                for inverted in [True, False]:
                    if dtype.is_floating_point:
                        # Half is not supported for CPU isin. Compute reference in FP32
                        # NOTE(review): the MPS inputs below are copies of these
                        # float32 tensors, so half dtypes are not run on MPS either.
                        A = torch.randn(size=elements_shape, device='cpu', dtype=torch.float32)
                        B = torch.randn(size=test_elements_shape, device='cpu', dtype=torch.float32)
                    else:
                        A = torch.randint(0, 100, size=elements_shape, device='cpu', dtype=dtype)
                        B = torch.randint(0, 100, size=test_elements_shape, device='cpu', dtype=dtype)

                    A_mps = A.clone().detach().to('mps')
                    B_mps = B.clone().detach().to('mps')

                    # A former `cpu_ref.type(dtype)` statement was dropped here:
                    # its result was discarded, so it had no effect.
                    cpu_ref = torch.isin(A, B, invert=inverted)
                    mps_out = torch.isin(A_mps, B_mps, invert=inverted)
                    self.assertEqual(mps_out, cpu_ref)

        tested_dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]
        if product_version < 14.0:
            # Int types expected to fail on MacOS < 14.0
            tested_dtypes = [torch.float32, torch.float16, torch.bfloat16]

        for dtype in tested_dtypes:
            helper(dtype)

    def test_isin_asserts(self):
        # Mismatched dtypes must be rejected.
        A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.dtype()*'):
            torch.isin(A, B)

        # Mixing MPS and CPU operands must be rejected.
        C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        D = torch.randn(size=[1, 4], device='cpu', dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.is_mps()*'):
            torch.isin(C, D)
class TestSmoothL1Loss(TestCaseMPS):
    # smooth_l1_loss parity between MPS and CPU for each reduction mode.

    def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
        # Computes the loss (beta=1.0) on both devices, compares the values
        # and, when requires_grad is set, the input gradients as well.
        # Returns the (cpu_loss, mps_loss) pair.
        cpu_input = torch.randn(4, 7, requires_grad=requires_grad)
        cpu_target = torch.randn(4, 7)

        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()
        mps_target = cpu_target.detach().clone().to('mps')

        loss_kwargs = dict(beta=1.0, reduction=reduction)
        cpu_loss = F.smooth_l1_loss(cpu_input, cpu_target, **loss_kwargs)
        mps_loss = F.smooth_l1_loss(mps_input, mps_target, **loss_kwargs)

        self.assertEqual(cpu_loss, mps_loss)

        if requires_grad:
            cpu_loss.backward()
            mps_loss.backward()
            self.assertEqual(cpu_input.grad, mps_input.grad.to("cpu"))

        return cpu_loss, mps_loss

    def test_smooth_l1_loss_reduction_none(self):
        self._smooth_l1_loss_helper(reduction="none")

    def test_smooth_l1_loss_reduction_mean(self):
        self._smooth_l1_loss_helper(reduction="mean")

    def test_smooth_l1_loss_reduction_sum(self):
        self._smooth_l1_loss_helper(reduction="sum")

    def test_smooth_l1_loss_reduction_mean_backward(self):
        self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)

    def test_smooth_l1_loss_reduction_mean_sum_backward(self):
        self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)
class TestNLLLoss(TestCaseMPS):
def test_nll_loss_mismatched_batch(self, device='mps'):
    # A target whose batch size differs from the input's must be rejected.
    logits = torch.randn((10, 3), requires_grad=True, device=device)
    # the target should have size (10,)
    bad_target = torch.zeros((3,), dtype=torch.int64, device=device)
    with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
        F.nll_loss(logits, bad_target)
def test_nll_loss_out_of_bounds_ignore_index(self):
    # Targets equal to ignore_index (custom 255 or the default) must give the
    # same losses on MPS as on CPU.
    def test_nll_loss_out_of_bounds_ignore_index_helper(device):
        x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
                         0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
        t1 = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
        t2 = torch.tensor([0, 1, 1, 0, -100, 2], dtype=torch.int64, device=device)
        losses = []
        for reduction in ('mean', 'none'):
            # out of bound ignore_index
            losses.append(F.nll_loss(x, t1, ignore_index=255, reduction=reduction))
            # default ignore_index
            losses.append(F.nll_loss(x, t2, reduction=reduction))
        return losses

    cpu_losses = test_nll_loss_out_of_bounds_ignore_index_helper(device='cpu')
    mps_losses = test_nll_loss_out_of_bounds_ignore_index_helper(device='mps')

    for cpu_loss, mps_loss in zip(cpu_losses, mps_losses):
        self.assertEqual(cpu_loss, mps_loss)
def test_nll_loss_invalid_target_dim(self):
    # A 2-D target for a 2-D input must raise the same error on CPU and MPS.
    def _test_nll_loss_invalid_target_dim(device):
        x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
                         0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
        bad_target = torch.zeros((6, 2), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
            F.nll_loss(x, bad_target)

    for device in ('cpu', 'mps'):
        _test_nll_loss_invalid_target_dim(device=device)
def test_nll_loss_invalid_weights(self):
    # Weight tensors that do not match the 3 classes must be rejected with
    # the same message on CPU and MPS.
    def _test_nll_loss_invalid_weights(device):
        x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
                         0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
        t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
        msg = "weight tensor should be defined either for all 3 classes or no classes"
        # wrong number of entries, then wrong number of dimensions
        for bad_shape in (4, (1, 3)):
            weight = torch.zeros(bad_shape, device=device)
            with self.assertRaisesRegex(RuntimeError, msg):
                F.nll_loss(x, t, weight=weight)

    for device in ('cpu', 'mps'):
        _test_nll_loss_invalid_weights(device=device)
def _nll_loss_helper(self, input_size, reduction, expected):
    # Weighted nll_loss forward and backward parity between CPU and MPS for
    # an input of `input_size`. `expected` is accepted but not used here.
    num_channels = input_size[1]
    target_size = (input_size[0], *input_size[2:])

    # CPU
    cpu_input = torch.rand(input_size, requires_grad=True, device='cpu')
    cpu_target = torch.randint(num_channels, target_size, device='cpu')
    cpu_weights = torch.randn(num_channels)

    # MPS
    mps_input = cpu_input.detach().clone().to('mps').requires_grad_()
    mps_target = cpu_target.detach().clone().to('mps')
    mps_weights = cpu_weights.to("mps")

    cpu_loss = F.nll_loss(cpu_input, cpu_target, weight=cpu_weights, reduction=reduction)
    mps_loss = F.nll_loss(mps_input, mps_target, weight=mps_weights, reduction=reduction)
    self.assertEqual(cpu_loss, mps_loss.to('cpu'))

    cpu_loss.sum().backward()
    mps_loss.sum().backward()
    self.assertEqual(cpu_input.grad, mps_input.grad.to('cpu'))
def _nll_loss_1d_helper(self, input_size, reduction):
    """Compare unweighted ``F.nll_loss`` on a 1D input between CPU and MPS.

    The class count is ``input_size[0]`` and the target is a random scalar
    (zero-dimensional) class index. The loss value and the gradient with
    respect to the input must match across devices.
    """
    n_classes = input_size[0]

    # CPU reference tensors
    x_cpu = torch.rand(input_size, requires_grad=True, device='cpu')
    y_cpu = torch.randint(n_classes, [], device='cpu')

    # MPS copies; the input becomes a fresh leaf so it gets its own .grad
    x_mps = x_cpu.detach().clone().to('mps').requires_grad_()
    y_mps = y_cpu.detach().clone().to('mps')

    loss_cpu = F.nll_loss(x_cpu, y_cpu, reduction=reduction)
    loss_mps = F.nll_loss(x_mps, y_mps, reduction=reduction)
    self.assertEqual(loss_cpu, loss_mps.to('cpu'))

    loss_cpu.sum().backward()
    loss_mps.sum().backward()
    self.assertEqual(x_cpu.grad, x_mps.grad.to('cpu'))
def test_nll_loss_1d(self, device='cpu'):
    """Run the 1D NLL-loss CPU/MPS comparison for every reduction mode."""
    for reduction in ("none", "mean", "sum"):
        self._nll_loss_1d_helper([10], reduction)
def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
    """Run the NLL-loss CPU/MPS comparison with ``reduction="none"``.

    Each case pairs an input size with the shape of the empty tensor that is
    handed to the helper as its ``expected`` argument.
    """
    cases = (
        ([1, 3], [0]),
        ([3, 5, 7], [5, 7]),
        ([2, 3, 1, 7], [2, 1, 7]),
        ([2, 3, 5, 1], [2, 5, 1]),
        ([2, 3, 5, 7, 1], [2, 5, 7, 1]),
    )
    for input_size, expected_shape in cases:
        self._nll_loss_helper(input_size, "none", torch.empty(expected_shape, device=device))
def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
    """Run the NLL-loss CPU/MPS comparison with ``reduction="mean"``.

    A NaN scalar tensor is handed to the helper as its ``expected`` argument
    for every input size.
    """
    nan = torch.tensor(float('nan'), device=device)
    input_sizes = (
        [1, 3],
        [1, 3, 5, 7],
        [2, 3, 1, 7],
        [2, 3, 5, 1],
        [2, 3, 5, 7, 1],
    )
    for input_size in input_sizes:
        self._nll_loss_helper(input_size, "mean", nan)
def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
    """Run the NLL-loss CPU/MPS comparison with ``reduction="sum"``.

    A zero scalar tensor is handed to the helper as its ``expected`` argument
    for every input size.
    """
    zero = torch.tensor(0, device=device)
    input_sizes = (
        [1, 3],
        [1, 3, 5, 7],
        [2, 3, 1, 7],
        [2, 3, 5, 1],
        [2, 3, 5, 7, 1],
    )
    for input_size in input_sizes:
        self._nll_loss_helper(input_size, "sum", zero)
def test_nll_loss_byte_target_matches_long(self, device='cpu'):
    """Check ``nn.NLLLoss`` with ``uint8`` targets against ``int64`` targets.

    For every reduction mode the loss and the input gradient are computed on
    CPU and on MPS, once with ``torch.long`` targets and once with
    ``torch.uint8`` targets. The test asserts that

    * MPS agrees with CPU for both target dtypes, and
    * on each device the ``uint8`` results equal the ``int64`` results.
    """
    N, C = 10, 4
    input = torch.randn(N, C, device=device, requires_grad=True)
    target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)

    def compute_result_and_gradient(reduction, target_dtype):
        """Return ``({dev: loss}, {dev: input grad})`` for 'cpu' and 'mps'."""
        result, grad = {}, {}
        for dev in ['cpu', 'mps']:
            # Fresh leaf per device so each backward pass fills its own .grad
            input_ = input.to(dev).detach()
            input_.requires_grad_()
            target_dev = target.to(dev)

            prob = F.log_softmax(input_, dim=-1)
            loss = nn.NLLLoss(reduction=reduction)
            result[dev] = loss(prob, target_dev.to(target_dtype))
            result[dev].sum().backward()
            grad[dev] = input_.grad

        return result, grad

    for reduction in ["none", "mean", "sum"]:
        result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
        result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)

        # MPS must agree with CPU for both target dtypes
        for result, grad in ((result_long, grad_long), (result_byte, grad_byte)):
            self.assertEqual(result['mps'].to('cpu'), result['cpu'])
            self.assertEqual(grad['mps'].to('cpu'), grad['cpu'])

        # uint8 targets must give the same answer as int64 targets
        for dev in ('cpu', 'mps'):
            self.assertEqual(result_long[dev], result_byte[dev])
            self.assertEqual(grad_long[dev], grad_byte[dev])
class TestTopK(TestCase):
    """Compare ``torch.topk`` on MPS against the CPU result."""

    def _test_topk(self, shape, largest):
        """Check ``torch.topk`` values and indices for every dim and every valid ``k``.

        Args:
            shape: either a tuple of dimension sizes or a bare int describing
                a 1D tensor of that length.
            largest: forwarded to ``torch.topk`` (largest vs. smallest elements).

        For each dimension, ``k`` runs from 1 up to and including the size of
        that dimension, so zero-sized dimensions are skipped.
        """
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        # A bare int is a 1D shape; normalising it lets both kinds of shape
        # share one loop with the same inclusive upper bound on k.
        dims = shape if isinstance(shape, tuple) else (shape,)
        for curr_dim, dim_size in enumerate(dims):
            for k in range(1, dim_size + 1):
                topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
                self.assertEqual(topk_values, topk_values_cpu)
                self.assertEqual(topk_indices, topk_indices_cpu)

    def test_topk(self):
        """Run ``_test_topk`` over empty and non-empty shapes, largest and smallest."""
        largest_vals = [True, False]
        shapes = [
            # Zero Element Tensors
            0,
            (1, 0),
            (0, 1),
            (1, 0, 1),
            # Multiple Element Tensors
            1,
            2,
            (5, 1),
            (1, 5),
            (5, 9, 7, 4),
        ]

        for shape in shapes:
            for largest_val in largest_vals:
                with self.subTest(shape=shape, largest_val=largest_val):
                    self._test_topk(shape, largest_val)
class TestNNMPS(NNTestCase):
    """``torch.nn`` module and convolution tests, several of them targeting MPS."""

    def _create_basic_net(self):
        """Return ``(layer, net, sequential)`` fixtures with parameters and buffers.

        ``net`` contains one ``Layer``; ``sequential`` wraps the same ``net``
        instance twice.
        """
        class Layer(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer_dummy_param = Parameter(torch.empty(3, 5))
                self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7))

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = Layer()
                self.dummy_param = Parameter(torch.empty(3, 5))
                self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1))

        l = Layer()
        n = Net()
        s = nn.Sequential(n, n)

        return l, n, s

    def test_requires_grad_(self):
        """``Module.requires_grad_`` returns the module, toggles parameters, leaves buffers alone."""
        m = self._create_basic_net()[-1]
        # Sanity-check the fixture before relying on it
        assert len(list(m.buffers())) > 0, 'invalid test'
        assert all(not b.requires_grad for b in m.buffers()), 'invalid test'
        assert len(list(m.parameters())) > 0, 'invalid test'
        assert all(p.requires_grad for p in m.parameters()), 'invalid test'
        for requires_grad in (False, True):
            self.assertIs(m.requires_grad_(requires_grad), m)
            for p in m.parameters():
                self.assertEqual(p.requires_grad, requires_grad)
            for b in m.buffers():
                self.assertFalse(b.requires_grad)

    def test_module_backcompat(self):
        """A pickled legacy ``Linear`` module still loads and maps (2, 3) -> (2, 5)."""
        from torch.serialization import SourceChangeWarning
        path = download_file('https://download.pytorch.org/test_data/linear.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            # The file stores a whole pickled module, so full unpickling is
            # requested explicitly instead of relying on the default.
            # NOTE: full unpickling executes code from the file; this is only
            # appropriate because the file is trusted test data.
            m = torch.load(path, weights_only=False)
        input = torch.randn(2, 3, dtype=torch.float)
        self.assertEqual(m(input).size(), (2, 5))

    def test_conv_backcompat(self):
        """A pickled legacy ``Conv2d`` module still loads and runs on a 1x1 input."""
        from torch.serialization import SourceChangeWarning
        # This file was generated by running on PyTorch 1.0.1 on Python 2:
        #
        #     import torch
        #     from torch import nn
        #     m = nn.Conv2d(1, 1, 1)
        #     torch.save(m, 'legacy_conv2d.pt')
        #
        # NB: This Pickle also contains some Unicode data!
        path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            # Whole-module pickle: request full unpickling explicitly.
            # NOTE: only appropriate because the file is trusted test data.
            m = torch.load(path, encoding='utf-8', weights_only=False)
        input = torch.randn((1, 1, 1, 1), dtype=torch.float)
        self.assertEqual(m(input).size(), (1, 1, 1, 1))

    def test_conv_expand(self):
        """``F.conv2d`` on MPS with an expanded (stride-0) kernel should not crash."""
        device = 'mps'
        input_ = torch.rand(2, 3, 16, 16, device=device)
        kernel = torch.rand(1, 1, 3, 11, device=device)
        tmp_kernel = kernel.expand(-1, 3, -1, -1)
        # The result is not inspected; the test only requires that this runs.
        F.conv2d(input_, tmp_kernel, groups=1, padding=0, stride=1)

    def test_permute(self):
        """``permute`` on MPS matches CPU in both values and size."""
        M_cpu = torch.randn(5, 5)
        M_mps = M_cpu.to('mps')

        output_cpu = M_cpu.permute(1, 0)
        output_mps = M_mps.permute(1, 0)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_print_non_contiguous(self):
        """Printing the result of ``nonzero`` on MPS should not crash."""
        print(torch.ones(100, 100, device='mps').nonzero())
        print(torch.ones(100, 100, device='mps').nonzero().contiguous())

    def test_zero_grad(self):
        """``Module.zero_grad`` resets grads to ``None`` by default, or to zeros on request."""
        i = torch.randn(2, 5, requires_grad=True)
        module = nn.Linear(5, 5)
        for p in module.parameters():
            p.requires_grad = False
        module.zero_grad()

        module.weight.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)  # uninitialized grad

        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        module.zero_grad()
        self.assertIsNone(module.weight.grad)

        module.bias.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)
        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertIsNotNone(module.bias.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        self.assertGreater(module.bias.grad.data.abs().sum(), 0)

        # Force set to zeros.
        module.zero_grad(set_to_none=False)
        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())

        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)

    def test_no_grad(self):
        """Outputs computed under ``torch.no_grad`` do not require grad and cannot be backpropagated."""
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
            input = torch.randn(1, 2, 10, 10).to(dtype)
            x = input
            y = input.clone()

            output = module(x)
            self.assertTrue(output.requires_grad)
            output.backward(torch.ones(1, 5, 10, 10))

            with torch.no_grad():
                output2 = module(y)
                self.assertFalse(output2.requires_grad)
                self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))

    def test_invalid_conv1d(self):
        """``Conv1d`` raises for a kernel larger than the input and for a negative stride."""
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError,
                                        r'Calculated padded input size per channel: \(4\). ' +
                                        r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

    def test_conv2d_discontiguous_weight(self):
        """``conv2d`` with a non-contiguous weight gives the expected sum."""
        # Test for https://github.com/pytorch/pytorch/issues/55781
        x = torch.ones(64, 16, 16, 16)
        weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
        self.assertFalse(weight.is_contiguous())
        y = torch.nn.functional.conv2d(x, weight, None)
        if torch.backends.mkldnn.is_available():
            # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
            with torch.backends.mkldnn.flags(enabled=False):
                y_ = torch.nn.functional.conv2d(x, weight, None)
                self.assertEqual(y, y_)
        self.assertEqual(y.sum(), 4186112.)

    def test_invalid_conv2d(self):
        """``Conv2d`` raises for oversized kernels, non-positive strides and mixed devices."""
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
            input = torch.empty(1, 1, 4, 4).to(dtype)
            self.assertRaises(RuntimeError, lambda: module(input))

            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
            input = torch.randn(1, 3, 1, 1)
            with self.assertRaisesRegex(RuntimeError,
                                        r'Calculated padded input size per channel: \(1 x 1\). ' +
                                        r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

            # Zero stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

            # Input and weights on different devices
            self.assertRaisesRegex(RuntimeError,
                                   'must be on the same device',
                                   lambda: torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps')))
            self.assertRaisesRegex(RuntimeError,
                                   'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same',
                                   lambda: torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3)))

    def test_conv2d_valid_padding(self, device='mps'):
        """``F.conv2d(padding='valid')`` equals ``F.conv2d`` with no padding."""
        # Test F.conv2d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
        y = torch.rand(1, 1, 1, 4, device=device).to(torch.float)
        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding='valid')
        self.assertEqual(expect.to('cpu'), actual.to('cpu'))

    def test_conv2d_backward_collision(self):
        """Backward through two ``Conv2d`` layers with equal output shapes should not crash."""
        # Test for https://github.com/pytorch/pytorch/issues/112998
        x = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True)
        m1 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps")
        m2 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps")
        y1, y2 = m1(x), m2(x)
        self.assertEqual(y1.shape, y2.shape)
        y1.sum().backward()
        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
        y2.sum().backward()

    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
    def test_conv3d_backward_collision(self):
        """Backward through two ``Conv3d`` layers with equal output shapes should not crash."""
        # Conv3D is only available from MacOS 13.2 onwards
        x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
        m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
        m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
        y1, y2 = m1(x), m2(x)
        self.assertEqual(y1.shape, y2.shape)
        y1.sum().backward()
        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
        y2.sum().backward()

    def test_gemm_permute_transpose(self):
        """A view -> permute -> transpose chain gives the same tensor on MPS as on CPU."""
        batch_size = 32
        n = 20
        hidden = 768
        num_attention_heads = 12
        attention_head_size = hidden // num_attention_heads

        def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
            # Split the last dim into (heads, head_size) and move heads forward
            new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
            x = x.view(new_x_shape)
            return x.permute(0, 2, 1, 3)

        def attention2(key, *, workaround=False, device):
            # `workaround` and `device` are accepted but not used in the body
            key = transpose_for_scores(key)
            res = key.transpose(-1, -2)
            return res

        A = torch.randn(batch_size, n, hidden)
        A_mps = A.detach().clone().to("mps")

        r1 = attention2(A, device="cpu")
        r2 = attention2(A_mps, device="mps")

        r2_cpu = r2.to("cpu")
        self.assertEqual(r1, r2_cpu)

    def test_group_norm_backward(self, device='mps'):
        """Gradients through repeated ``GroupNorm`` applications must not contain NaN."""
        # See https://github.com/pytorch/pytorch/issues/88331 for more detail
        shape = [1, 4, 16, 16]
        x = torch.full(shape, 7.0, device=device)

        target = torch.ones((1, 3, 128, 128), device=device)

        conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
        conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
        norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device)

        with torch.enable_grad():
            x = x.detach().requires_grad_()
            out = 5.5 * x
            out = conv_in(out)
            out = out + norm(out)
            out = out + norm(out)
            out = out + norm(out)
            out = F.interpolate(out, scale_factor=8.0, mode="nearest")
            out = norm(out)
            out = conv_out(out)

            loss = (out - target).norm(dim=-1).sum()
            grad = -torch.autograd.grad(loss, x)[0]
            self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd')
# def test_conv2d_same_padding(self, device='mps'):
# x = torch.rand(1, 1, 10, 11, device=device)
# y = torch.rand(1, 1, 4, 5, device=device)
# expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
# actual = F.conv2d(x, y, padding='same')
# self.assertEqual(expect.to('cpu'), actual.to('cpu'))
# # With dilation
# y = torch.rand(1, 1, 3, 4, device=device)
# expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
# actual = F.conv2d(x, y, padding='same', dilation=2)
# self.assertEqual(expect, actual)
# # Dilation with asymmetric padding
# y = torch.rand(1, 1, 4, 4, device=device)
# expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
# actual = F.conv2d(x, y, padding='same', dilation=3)
# self.assertEqual(expect, actual)
class TestPad(TestCaseMPS):
    """Padding operators on MPS, checked against the CPU implementation."""

    def test_constant_pad(self):
        """Constant padding with negative pads, and on a 10-dimensional input."""
        crop = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
        x_cpu = torch.randn(1, 16, 16, 16)
        x_mps = x_cpu.detach().clone().to("mps")
        out_cpu = crop(x_cpu)
        out_mps = crop(x_mps)
        self.assertEqual(out_cpu, out_mps.to("cpu"))

        # Arbitrary input dimensions
        x_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
        x_mps = x_cpu.detach().clone().to("mps")
        pad_kwargs = {"pad": (1, 1, 0, 0, 0, 0), "value": 3.5}
        out_cpu = F.pad(x_cpu, **pad_kwargs)
        out_mps = F.pad(x_mps, **pad_kwargs)
        self.assertEqual(out_cpu, out_mps.to("cpu"))

    def test_circular_pad(self):
        """Circular padding followed by ``conv2d`` matches CPU."""
        # https://github.com/pytorch/pytorch/issues/80856
        kernel_cpu = torch.ones(3, 3, 9, 9)
        kernel_mps = kernel_cpu.detach().clone().to("mps")

        image_cpu = torch.rand(1, 3, 32, 32)
        image_mps = image_cpu.detach().clone().to("mps")

        padded_cpu = F.pad(image_cpu, (2, 2, 2, 2), mode='circular')
        padded_mps = F.pad(image_mps, (2, 2, 2, 2), mode='circular')

        conv_cpu = F.conv2d(padded_cpu, kernel_cpu)
        conv_mps = F.conv2d(padded_mps, kernel_mps)

        self.assertEqual(conv_cpu, conv_mps.cpu())

    def test_constant_pad_4d_warning(self):
        """Constant padding described by eight pad values on a 6D input matches CPU."""
        pad_spec = [0, 0, 0, 0, 0, 0, 1, 0]
        src_cpu = torch.rand((1, 2, 2, 2, 1, 1))
        src_mps = src_cpu.detach().clone().to('mps')
        padded_cpu = F.pad(src_cpu, pad_spec)
        padded_mps = F.pad(src_mps, pad_spec)
        self.assertEqual(padded_cpu, padded_mps)

    def test_pad(self):
        """Forward and backward of reflection/replication/constant padding modules match CPU."""
        constant_pads = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)

        def helper(shape, padding, op, value=0):
            x_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x_cpu.retain_grad()
            x_mps = x_cpu.detach().clone().to('mps').requires_grad_()

            # Only the constant pads take a fill value
            pad_module = op(padding, value) if op in constant_pads else op(padding)
            y_cpu = pad_module(x_cpu)
            y_mps = pad_module(x_mps)
            self.assertEqual(y_cpu, y_mps)

            # backward pass (chose 0.6 just to have the grad_output != 1)
            y_cpu.backward(gradient=torch.full_like(y_cpu, 0.6))
            y_mps.backward(gradient=torch.full_like(y_mps, 0.6))
            self.assertEqual(x_cpu.grad, x_mps.grad)

        # (input shape, padding, pad module), executed in this order
        cases = [
            # 1D Padding
            ((2, 4, 3), 2, nn.ReflectionPad1d),
            # verify if a change in shape of input would cause problems with graph caching
            ((2, 4, 4), (1, 3), nn.ReflectionPad1d),
            # Replication 1D
            ((2, 1, 6), 3, nn.ReplicationPad1d),
            # Constant Pad 1D
            ((2, 3, 4), 2, nn.ConstantPad1d),
            # Constant Pad 1D with single dimension input
            (16, (1, 2), nn.ConstantPad1d),

            # 2D Padding
            ((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d),
            # verify if a change in shape of input would cause problems with graph caching
            ((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d),
            # this should make the padding (2, 2, 2, 2)
            ((2, 1, 6, 8), 2, nn.ReplicationPad2d),
            # verify if a change in shape of padding would cause problems with graph caching
            ((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d),
            # Constant Pad 2D
            ((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d),
            # input size < pad size
            ((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d),
            # pad dims < input dims
            ((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d),
            # pad dims == input dims
            ((1, 3), (0, 2, 0, 1), nn.ConstantPad2d),
            # input.numel() == 0 but output.numel() > 0
            ((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d),
            # pad dims < input dims - 2
            ((1, 2, 3, 4), (1, 2), nn.ConstantPad2d),

            # 3D Padding
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d),
            # verify if a change in shape of padding would cause problems with graph caching
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d),
            # case where input_d == pad_front/back for ReplicationPad3d
            ((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d),
            # Constant Pad 3D
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d),
            # input size < pad size
            ((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d),
            # check the workaround for the right padding bug in Monterey
            ((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d),
        ]
        for shape, padding, op in cases:
            helper(shape, padding, op)

    def test_constant_pad_nd_preserves_memory_format(self):
        """``constant_pad_nd`` keeps contiguous and channels-last layouts."""
        contiguous_in = torch.rand((1, 2, 5, 3))
        contiguous_out = torch.constant_pad_nd(contiguous_in, [1, 2], 0.5)
        self.assertTrue(contiguous_out.is_contiguous(memory_format=torch.contiguous_format))

        channels_last_in = contiguous_in.contiguous(memory_format=torch.channels_last)
        channels_last_out = torch.constant_pad_nd(channels_last_in, [1, 2], 0.5)
        self.assertTrue(channels_last_out.is_contiguous(memory_format=torch.channels_last))
class TestLinalgMPS(TestCaseMPS):
    """Linear-algebra operators on MPS, checked against NumPy or plain ``torch.mm``."""

    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
        """Check ``f(t, m, v, alpha=, beta=)`` against its ``out=`` form and a NumPy reference.

        ``alpha`` defaults to 1.2 and ``beta`` to 0.8. When ``transpose_out``
        is set, the ``out=`` tensor is given a transposed memory layout.
        """
        dtype = t.dtype
        if alpha is None:
            alpha = 1.2
        if beta is None:
            beta = 0.8

        def as_numpy(x):
            return x.to(dtype).cpu().numpy()

        direct = f(t, m, v, alpha=alpha, beta=beta)

        # Pre-fill with NaN so stale values in `out` cannot pass unnoticed
        via_out = torch.full_like(direct, math.nan)
        if transpose_out:
            via_out = via_out.t().clone(memory_format=torch.contiguous_format).t()
        f(t, m, v, alpha=alpha, beta=beta, out=via_out)

        reference = alpha * (as_numpy(m) @ as_numpy(v))
        if beta != 0:
            reference += as_numpy(torch.mul(t, beta))
        reference = torch.from_numpy(reference).to(dtype)

        self.assertEqual(direct, via_out)
        self.assertEqual(direct, reference)

    def test_addmm(self, device="mps", dtype=torch.float32):
        """``torch.addmm``: plain case, ``beta=0`` with a NaN bias, and all layout combinations."""
        def rand(*size):
            return torch.randn(*size, device=device).to(dtype)

        def maybe_transpose(cond, mat):
            # Same values, but column-major memory layout when `cond` is set
            if cond:
                return mat.t().clone(memory_format=torch.contiguous_format).t()
            return mat

        M = rand(10, 25)
        m1 = rand(10, 50)
        m2 = rand(50, 25)
        self._test_addmm_addmv(torch.addmm, M, m1, m2)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = rand(10, 50)
        m2 = rand(50, 25)
        self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            M = maybe_transpose(t1, rand(10, 25))
            m1 = maybe_transpose(t2, rand(10, 50))
            m2 = maybe_transpose(t3, rand(50, 25))
            self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)

    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
        """Check ``f(t, m, v, alpha=, beta=)`` against a NumPy outer-product reference.

        ``alpha`` defaults to 1.2 and ``beta`` to 0.8.
        """
        dtype = t.dtype
        if alpha is None:
            alpha = 1.2
        if beta is None:
            beta = 0.8

        def as_numpy(x):
            return x.to(dtype).cpu().numpy()

        actual = f(t, m, v, alpha=alpha, beta=beta)

        reference = alpha * np.outer(as_numpy(m), as_numpy(v))
        if beta != 0:
            reference += as_numpy(torch.mul(t, beta))
        reference = torch.from_numpy(reference).to(dtype)

        self.assertEqual(actual, reference)

    def test_addr(self, device="mps", dtype=torch.float32):
        """``torch.addr``: plain case and ``beta=0`` with a NaN bias."""
        def rand(*size):
            return torch.randn(*size, device=device).to(dtype)

        M = rand(10, 25)
        m1 = rand(10)
        m2 = rand(25)
        self._test_addr(torch.addr, M, m1, m2)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = rand(10)
        m2 = rand(25)
        self._test_addr(torch.addr, M, m1, m2, beta=0)

    def test_matrix_rank(self, device="mps", dtype=torch.float32):
        """``torch.linalg.matrix_rank`` agrees with itself across variants and with NumPy."""
        matrix_rank = torch.linalg.matrix_rank

        def run_test(shape0, shape1, batch):
            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
            rank_a = matrix_rank(a)

            # rank(a) == rank(a^H)
            self.assertEqual(rank_a, matrix_rank(a.mH))

            # hermitian=True must not change the answer for a @ a^H / a^H @ a
            aaH = torch.matmul(a, a.mH)
            rank_aaH = matrix_rank(aaH)
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
            self.assertEqual(rank_aaH, rank_aaH_hermitian)
            aHa = torch.matmul(a.mH, a)
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))

            # check against NumPy
            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))

            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))

            # hermitian flag for NumPy was added in 1.14.0
            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
                self.assertEqual(rank_aaH_hermitian,
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
                self.assertEqual(matrix_rank(aaH, 0.01, True),
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))

            # check out= variant
            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
            ans = matrix_rank(a, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, rank_a)

        shapes = (3, 13)
        batches = ((), (0, ), (4, ), (3, 5, ))
        shape_pairs = itertools.product(shapes, reversed(shapes))
        for (shape0, shape1), batch in zip(shape_pairs, batches):
            # escape only when NotImplementedError of downstream function is raised
            # TODO: remove this once the required function is implemented
            try:
                run_test(shape0, shape1, batch)
            except NotImplementedError as e:
                with self.assertRaisesRegex(
                        NotImplementedError,
                        "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
                    raise e

    def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
        """``torch.linalg.pinv`` satisfies the pseudo-inverse identities and matches NumPy."""
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        tol = {"atol": precision, "rtol": precision}

        def run_test_main(A, hermitian):
            # Testing against definition for pseudo-inverses
            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
            np_A = A.cpu().numpy()
            np_A_pinv = A_pinv.cpu().numpy()
            if A.numel() > 0:
                self.assertEqual(A, np_A @ np_A_pinv @ np_A, **tol)
                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, **tol)
                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), **tol)
                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), **tol)
            else:
                # Nothing to compare numerically; the last two dims must be swapped
                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))

            # Check out= variant
            out = torch.empty_like(A_pinv)
            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, A_pinv)

        def run_test_numpy(A, hermitian):
            # Check against NumPy output
            # Test float rcond, and specific value for each matrix
            rconds = [float(torch.rand(1)), ]
            # Test different types of rcond tensor
            for rcond_type in MPS_DTYPES:
                rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
            # Test broadcasting of rcond
            if A.ndim > 2:
                rconds.append(torch.rand(A.shape[-3], device=device))
            for rcond in rconds:
                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
                self.assertEqual(actual, torch_rtol, **tol)
                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
                self.assertEqual(actual, expected, **tol)

        def allow_missing_eigh(check, A, hermitian):
            # escape only when NotImplementedError of downstream function is raised
            # TODO: remove this once the required function is implemented
            try:
                check(A, hermitian)
            except NotImplementedError as e:
                with self.assertRaisesRegex(
                        NotImplementedError,
                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
                    raise e

        general_sizes = [
            # square matrices
            (5, 5), (3, 5, 5), (3, 2, 5, 5),
            # more rows than columns
            (3, 2), (5, 3, 2), (2, 5, 3, 2),
            # more columns than rows
            (2, 3), (5, 2, 3), (2, 5, 2, 3),
            # zero numel matrices
            (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3),
        ]
        for sizes in general_sizes:
            A = torch.randn(*sizes, dtype=dtype, device=device)
            run_test_main(A, False)
            run_test_numpy(A, False)

        # Check hermitian = True
        hermitian_sizes = [
            # square matrices
            (5, 5), (3, 5, 5), (3, 2, 5, 5),
            # zero numel square matrices
            (0, 0), (3, 0, 0),
        ]
        for sizes in hermitian_sizes:
            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
            allow_missing_eigh(run_test_main, A, True)
            allow_missing_eigh(run_test_numpy, A, True)

    @parametrize("m", [1, 32, 64])
    @parametrize("n", [48, 64])
    @parametrize("q_group", [32, 64, 128, 256])
    @parametrize("num_groups", [1, 2])
    def test__int4_mm(self, m, n, q_group, num_groups):
        """``torch._weight_int4pack_mm`` stays within 5% mean relative error of ``torch.mm``."""
        k = q_group * num_groups
        inner_k_tiles = 2

        torch.manual_seed(1)
        a_f32 = torch.rand((m, k), device="mps")
        b_f32 = torch.rand((k, n), device="mps")

        def convert_weight_to_int4pack(b):
            # Quantize on CPU, then move the quantized weight and its
            # scales/zeros to MPS before packing.
            quantized, scales_and_zeros = _group_quantize_tensor(
                b.to("cpu"), n_bit=4, q_group_size=q_group
            )
            quantized = quantized.to("mps")
            scales_and_zeros = scales_and_zeros.to("mps")
            packed = torch._convert_weight_to_int4pack(quantized, inner_k_tiles)
            return packed, scales_and_zeros

        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
            return torch._weight_int4pack_mm(
                a, b_int4pack, q_group, b_scales_and_zeros
            )

        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32)

        dtypes_to_check = [torch.float16, torch.float32]
        if product_version > 14.0:
            dtypes_to_check = dtypes_to_check + [torch.bfloat16]

        for dtype in dtypes_to_check:
            a = a_f32.to(dtype=dtype)
            b = b_f32.to(dtype=dtype)
            b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype)
            ref = torch.mm(a, b)
            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertLess(mean_err, 0.05)

    @parametrize("m", [1, 32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [32, 64])
    def test__int8_mm(self, m, k, n):
        """``torch._weight_int8pack_mm`` stays within 5% mean relative error of ``torch.mm``."""
        torch.manual_seed(1)
        a_f32 = torch.rand((m, k), device="mps")
        b_f32 = torch.rand((n, k), device="mps")

        def convert_weight_to_int8pack(b):
            # The third value returned by the quantizer is not needed here
            packed, scales, _ = _dynamically_quantize_per_channel(
                b, -128, 127, torch.int8
            )
            return packed, scales

        def weight_int8pack_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)

        b_int8pack, b_scales_f32 = convert_weight_to_int8pack(b_f32)

        dtypes_to_check = [torch.float16, torch.float32]
        if product_version > 14.0:
            dtypes_to_check = dtypes_to_check + [torch.bfloat16]

        for dtype in dtypes_to_check:
            a = a_f32.to(dtype=dtype)
            b = b_f32.to(dtype=dtype)
            b_scales = b_scales_f32.to(dtype=dtype)
            res = weight_int8pack_mm(a, b_int8pack, b_scales)
            # The weight is stored as (n, k), so the reference multiplies by its transpose
            ref = torch.mm(a, b.transpose(0, 1))

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertLess(mean_err, 0.05)
class TestSDPA(TestCaseMPS):
    """``F.scaled_dot_product_attention`` on MPS, checked against CPU."""

    def _compare_tensors(self, y, ref):
        """Assert that the mean relative error of ``y`` against ``ref`` is below 1%.

        The denominator is ``|ref|`` clamped from below at 1e-6 so that
        near-zero reference entries do not blow up the ratio.
        """
        denom = torch.maximum(ref.abs(), torch.tensor([1e-6], device=ref.device, dtype=ref.dtype))
        err = ((y - ref).abs() / denom).mean().item()
        self.assertLess(err, 0.01)

    def _test_sdpa_no_mask(
        self,
        is_causal: bool,
        dtype: torch.dtype,
        L: int = 1,
        S: int = 72,
        NH: int = 32,
        HS: int = 128,
        requires_grad: bool = False
    ):
        """Compare unmasked SDPA between MPS and CPU.

        Args:
            is_causal: forwarded to ``scaled_dot_product_attention``.
            dtype: dtype of the query/key/value tensors.
            L: query sequence length (dim 2 of ``q``).
            S: key/value sequence length (dim 2 of ``k`` and ``v``).
            NH: size of dim 1 of all three tensors.
            HS: size of the last dim of all three tensors.
            requires_grad: whether ``q`` requires grad; when it does and grad
                mode is enabled, the gradient of ``q`` is compared as well.
        """
        torch.manual_seed(1729)
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            # Detach first so the CPU query is an independent leaf with its own .grad
            q_cpu = q.detach().cpu().requires_grad_(requires_grad)
            k_cpu = k.cpu()
            v_cpu = v.cpu()

            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
            y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal)

            self._compare_tensors(y.cpu(), y_ref)

            if requires_grad and torch.is_grad_enabled():
                y.sum().backward()
                y_ref.sum().backward()

                self._compare_tensors(q.grad.cpu(), q_cpu.grad)

    def test_sdpa_no_mask_no_causal_fp32(self):
        self._test_sdpa_no_mask(False, torch.float32)

    def test_sdpa_no_mask_no_causal_fp16(self):
        self._test_sdpa_no_mask(False, torch.float16)

    def test_sdpa_no_mask_causal_fp32(self):
        self._test_sdpa_no_mask(True, torch.float32)

    def test_sdpa_no_mask_causal_fp16(self):
        self._test_sdpa_no_mask(True, torch.float16)

    def test_sdpa_no_mask_causal_fp16_L7(self):
        self._test_sdpa_no_mask(True, torch.float16, 7)

    def test_sdpa_no_mask_causal_fp16_L7_S17(self):
        self._test_sdpa_no_mask(True, torch.float16, 7, 17)

    def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self):
        self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)

    def test_sdpa_no_mask_no_causal_fp32_grad(self):
        """Run the grad-enabled comparison, then repeat under ``torch.no_grad``."""
        self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)

        with torch.no_grad():
            # Grad mode is off here, so only the forward result is compared
            self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)

    def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
        """Compare SDPA with an explicit boolean mask between MPS and CPU.

        The mask is one row of an ``S x S`` lower-triangular matrix, selected
        by index and shaped ``(1, 1, 1, S)``.
        """
        torch.manual_seed(1729)
        causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            # Row of the mask to use. 42 is only a valid row index when S > 42;
            # for shorter sequences fall back to a row that exists, otherwise
            # the gather below would read out of bounds.
            i = 42 if S > 42 else S // 2

            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")

            input_pos = torch.tensor([i], dtype=torch.int32, device='mps')
            mask = causal_mask[None, None, input_pos]

            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
            y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)

            self._compare_tensors(y.cpu(), y_ref)

    def test_sdpa_mask_fp32(self):
        self._test_sdpa_mask(torch.float32)

    def test_sdpa_mask_fp16(self):
        self._test_sdpa_mask(torch.float16)

    def test_sdpa_mask_fp16_L6(self):
        self._test_sdpa_mask(torch.float16, 6)

    def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
        # NOTE(review): the name says L6 but L=7 is passed; the value is left
        # as-is to keep the tested configuration unchanged. Confirm which one
        # was intended.
        self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
class TestGatherScatter(TestCaseMPS):
    """Strided slicing, column assignment and dtype casts on MPS versus CPU."""

    def test_slicing_with_step(self):
        """Assigning through a stepped slice gives the same tensor on MPS and CPU."""
        # Slicing with step
        # https://github.com/pytorch/pytorch/issues/78886
        mps_vec = torch.zeros(10, dtype=torch.float32, device="mps")
        mps_vec[::2] = 1.0

        cpu_vec = torch.zeros(10, dtype=torch.float32, device="cpu")
        cpu_vec[::2] = 1.0

        self.assertEqual(cpu_vec, mps_vec)

    def test_cast_gather_scatter(self):
        """uint8 -> long -> float -> in-place divide stays in sync between MPS and CPU."""
        for _ in range(50):
            raw = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8)
            with torch.no_grad():
                on_mps = torch.tensor(raw, dtype=torch.uint8, device="mps").unsqueeze(0)
                on_cpu = torch.tensor(raw, dtype=torch.uint8, device="cpu").unsqueeze(0)

                on_mps = on_mps.long()
                on_cpu = on_cpu.long()
                self.assertEqual(on_mps.cpu(), on_cpu)

                on_mps = on_mps.float()
                on_cpu = on_cpu.float()
                self.assertEqual(on_mps.cpu(), on_cpu)

                on_mps /= 255
                on_cpu /= 255
                self.assertEqual(on_mps.cpu(), on_cpu)

    def test_slicing_replace_column(self):
        """Overwriting the first column gives the same tensor on MPS and CPU."""
        # https://github.com/pytorch/pytorch/issues/78074
        def _helper(tensor_data):
            cpu_mat = torch.tensor(tensor_data)
            mps_mat = cpu_mat.to('mps')

            cpu_mat[:, 0] = 7
            mps_mat[:, 0] = 7

            self.assertEqual(cpu_mat, mps_mat)

        # 2, 3 and 4 rows of three columns each
        for data in (
            [[1, 2, 3], [4, 5, 6]],
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
        ):
            _helper(data)

    def test_inplace_scatter(self):
        """In-place and out-of-place column updates give the same tensor on MPS and CPU."""
        # https://github.com/pytorch/pytorch/issues/79672
        mps_device = torch.device("mps")
        cpu_device = torch.device("cpu")
        a_mps = torch.ones((2, 2),).to(mps_device)
        b_mps = torch.ones((2, 2),).to(mps_device)

        a_cpu = torch.ones((2, 2),).to(cpu_device)
        b_cpu = torch.ones((2, 2),).to(cpu_device)

        # In-place add into a column view
        a_mps[:, 0] += b_mps[:, 0]
        a_cpu[:, 0] += b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)

        # Out-of-place add, then assign back into the column
        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)
# These tests were taken from test/test_view_ops.py.
# They are a subset of those tests, as currently only this subset is working.
# This whole `class` will be removed when we add generic device testing. No
# tests have been added beyond what is already part of test_view_ops.py.
class TestViewOpsMPS(TestCaseMPS):
    # Class-level flag; presumably read by the base test case's comparison
    # helpers to also require matching dtypes -- TODO confirm against the base class.
    exact_dtype = True
def test_permute_slicing(self):
    # Regression test for the crash reported in
    # https://github.com/pytorch/pytorch/issues/94190
    x_host = torch.randn([3, 2, 2]).float()
    x_dev = x_host.detach().clone().to('mps')

    order = (2, 0, 1)
    permuted_host = x_host.permute(order) * 2.0
    permuted_dev = x_dev.permute(order) * 2.0

    # printing this tensor crashed prior to fix PR#94259
    print(torch.zeros_like(permuted_dev))

    # also covers the fill_scalar_mps() fix mentioned in issue #94190
    zeros_host = torch.zeros_like(permuted_host)
    zeros_dev = torch.zeros_like(permuted_dev)
    self.assertEqual(zeros_host, zeros_dev)
    self.assertEqual(x_host[:, 1, :].fill_(1), x_dev[:, 1, :].fill_(1))
def is_view_of(self, base, other):
    # True only when `other` is a view, is a different object from `base`,
    # has `base` as its `_base`, and lives on the same device.
    if not other._is_view():
        return False
    if other is base or other._base is not base:
        return False
    if base.device != other.device:
        return False
    # Note: only validates storage on native device types
    # because some accelerators, like XLA, do not expose storage
    on_mps = base.device.type == 'mps'
    if on_mps and base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
        return False
    return True
# Returns true if v1 and v2 are views of the same base
def is_view_of_same_base(self, v1, v2):
if (not v1._is_view() or v1 is v2):
return False
return self.is_view_of(v1._base, v2)
# Returns the input tensor as is if contiguous=True, else performs the transpose
def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
if contiguous:
return x
else:
return x.transpose(dim0, dim1)
def test_diagonal_view(self, device="mps"):
    # 2-D: the main diagonal is a view, so a write through it reaches the base.
    base = torch.ones((5, 5), device=device)
    diag = torch.diagonal(base)
    self.assertTrue(self.is_view_of(base, diag))
    diag[0] = 0
    self.assertEqual(base[0, 0], diag[0])

    # 3-D: offset diagonal taken over the last two dimensions.
    base = torch.ones((3, 3, 3), device="mps")
    diag = torch.diagonal(base, offset=1, dim1=1, dim2=2)
    self.assertTrue(self.is_view_of(base, diag))
    diag[0, 0] = 0
    self.assertEqual(base[0, 0, 1], diag[0, 0])
def test_select_view(self, device="mps") -> None:
    # select() along dim 0 must alias row 2 of the base tensor.
    base = torch.ones((5, 5), device=device)
    row = base.select(0, 2)
    self.assertTrue(self.is_view_of(base, row))
    row[0] = 0
    self.assertEqual(base[2, 0], row[0])
def test_unbind_view(self, device="mps") -> None:
    # Every tensor returned by unbind() must be a view of the base.
    base = torch.zeros((5, 5), device=device)
    for row_idx, row in enumerate(torch.unbind(base)):
        self.assertTrue(self.is_view_of(base, row))
        row[0] = row_idx + 1
        self.assertEqual(base[row_idx, 0], row[0])
def test_expand_view(self, device="mps") -> None:
    # expand() must alias the base: a write to any expanded column shows
    # up in the single base column.
    base = torch.ones((5, 1), device=device)
    expanded = base.expand(5, 5)
    self.assertTrue(self.is_view_of(base, expanded))
    expanded[2, 2] = 0
    self.assertEqual(base[2, 0], expanded[2, 2])
def test_expand_as_view(self, device="mps"):
    # Same check as for expand(), with the target shape taken from a template.
    base = torch.ones((5, 1), device=device)
    template = torch.empty((5, 5), device=device)
    expanded = base.expand_as(template)
    self.assertTrue(self.is_view_of(base, expanded))
    expanded[2, 2] = 0
    self.assertEqual(base[2, 0], expanded[2, 2])
def test_narrow_view(self, device="mps"):
    # narrow(dim=1, start=2, length=2): column 0 of the view is column 2 of the base.
    base = torch.ones((5, 5), device=device)
    window = torch.narrow(base, 1, 2, 2)
    self.assertTrue(self.is_view_of(base, window))
    window[0, 0] = 0
    self.assertEqual(base[0, 2], window[0, 0])
def test_permute_view(self, device="mps") -> None:
    # permute(1, 0) swaps the axes, so view[0, 1] aliases base[1, 0].
    base = torch.ones((5, 5), device=device)
    swapped = base.permute(1, 0)
    self.assertTrue(self.is_view_of(base, swapped))
    swapped[0, 1] = 0
    self.assertEqual(base[1, 0], swapped[0, 1])
def test_transpose_view(self, device="mps"):
    # All three spellings of the transpose must return a view of the base.
    for transpose_fn in (torch.swapdims, torch.swapaxes, torch.transpose):
        base = torch.ones((5, 5), device=device)
        swapped = transpose_fn(base, 0, 1)
        self.assertTrue(self.is_view_of(base, swapped))
        swapped[0, 1] = 0
        self.assertEqual(base[1, 0], swapped[0, 1])
def test_transpose_inplace_view(self, device="mps"):
    # Apply each in-place transpose variant to a fresh alias of a fresh
    # base and check that the alias is still a view of that base.
    for inplace_name in ("swapdims_", "swapaxes_", "transpose_"):
        base = torch.ones(5, 5, device=device)
        alias = base.view_as(base)
        alias = getattr(alias, inplace_name)(0, 1)
        self.assertTrue(self.is_view_of(base, alias))
        alias[0, 1] = 0
        self.assertEqual(base[1, 0], alias[0, 1])
def test_t_view(self, device="mps"):
    # Tensor.t() must return a view: view[0, 1] aliases base[1, 0].
    base = torch.ones((5, 5), device=device)
    transposed = base.t()
    self.assertTrue(self.is_view_of(base, transposed))
    transposed[0, 1] = 0
    self.assertEqual(base[1, 0], transposed[0, 1])
def test_inplace_view_add(self):
    # https://github.com/pytorch/pytorch/issues/96153
    # NOTE(review): despite the test name, the additions below are
    # out-of-place (`x = x + 1`); confirm whether `+=` was intended.
    def second_row_as_2x3(dev):
        return torch.ones((2, 6,), device=dev)[1].reshape(2, 3)

    on_mps = second_row_as_2x3('mps')
    on_cpu = second_row_as_2x3('cpu')
    on_mps = on_mps + 1
    on_cpu = on_cpu + 1
    self.assertEqual(on_mps, on_cpu)
def test_t_inplace_view(self, device="mps"):
    # In-place t_() on an alias must keep it a view of the original base.
    base = torch.ones(5, 5, device=device)
    alias = base.view_as(base)
    alias = alias.t_()
    self.assertTrue(self.is_view_of(base, alias))
    alias[0, 1] = 0
    self.assertEqual(base[1, 0], alias[0, 1])
def test_T_view(self, device="mps"):
    # Each transpose-like attribute must produce a view of the base.
    for attr_name in ("T", "H", "mT", "mH"):
        base = torch.ones((5, 5), device=device)
        transposed = getattr(base, attr_name)
        self.assertTrue(self.is_view_of(base, transposed))
        transposed[0, 1] = 0
        self.assertEqual(base[1, 0], transposed[0, 1])
def test_unfold_view(self, device="mps"):
    # unfold(dim=0, size=3, step=2): window 1 starts at base index 2.
    base = torch.ones(10, device=device)
    windows = base.unfold(0, 3, 2)
    self.assertTrue(self.is_view_of(base, windows))
    windows[1, 0] = 0
    self.assertEqual(base[2], windows[1, 0])
def test_squeeze_view(self, device="mps"):
    # squeeze() must return a view whose _base is the original tensor.
    base = torch.ones(5, 1, 5, device=device)
    squeezed = torch.squeeze(base)
    self.assertTrue(self.is_view_of(base, squeezed))
    squeezed[0, 1] = 0
    self.assertIs(base, squeezed._base)
def test_squeeze_inplace_view(self, device="mps"):
    # In-place squeeze_() on an alias keeps the original tensor as _base.
    base = torch.ones(5, 5, device=device)
    alias = base.view_as(base)
    alias = alias.squeeze_()
    self.assertTrue(self.is_view_of(base, alias))
    alias[0, 1] = 0
    self.assertIs(base, alias._base)
def test_unsqueeze_view(self, device="mps"):
    # unsqueeze(dim=1) inserts a size-1 axis; view[0, 0, 1] aliases base[0, 1].
    base = torch.ones(5, 5, device=device)
    widened = torch.unsqueeze(base, 1)
    self.assertTrue(self.is_view_of(base, widened))
    widened[0, 0, 1] = 0
    self.assertEqual(base[0, 1], widened[0, 0, 1])
def test_unsqueeze_inplace_view(self, device="mps"):
    # In-place unsqueeze_(1) on an alias must keep it a view of the base.
    base = torch.ones(5, 5, device=device)
    alias = base.view_as(base)
    alias = alias.unsqueeze_(1)
    self.assertTrue(self.is_view_of(base, alias))
    alias[0, 0, 1] = 0
    self.assertEqual(base[0, 1], alias[0, 0, 1])
def test_as_strided_view(self, device="mps"):
    # A flat (25,) as_strided view: element 6 aliases base[1, 1].
    base = torch.ones(5, 5, device=device)
    flat = torch.as_strided(base, (25,), (1,))
    self.assertTrue(self.is_view_of(base, flat))
    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_as_strided_inplace_view(self, device="mps"):
    # In-place as_strided_ on an alias must keep it a view of the base.
    base = torch.ones(5, 5, device=device)
    alias = base.view_as(base)
    alias = alias.as_strided_((25,), (1,))
    self.assertTrue(self.is_view_of(base, alias))
    alias[6] = 0
    self.assertEqual(base[1, 1], alias[6])
def test_view_view(self, device="mps"):
    # view(25) flattens without copying: element 6 aliases base[1, 1].
    base = torch.ones(5, 5, device=device)
    flat = base.view(25)
    self.assertTrue(self.is_view_of(base, flat))
    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_view_as_view(self, device="mps"):
    # view_as() with a (25,) template; the template is created without a
    # device argument, exactly as before.
    base = torch.ones(5, 5, device=device)
    template = torch.empty((25,))
    flat = base.view_as(template)
    self.assertTrue(self.is_view_of(base, flat))
    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_contiguous_self(self, device="mps"):
    # contiguous() on a freshly created tensor must return the same object.
    base = torch.ones(5, 5, device=device)
    result = base.contiguous()
    self.assertIs(result, base)
def test_contiguous_nonview(self, device="mps"):
    # contiguous() of a transposed tensor must not be a view of the base,
    # so a write to the result leaves the base untouched.
    base = torch.ones(5, 5, device=device)
    copied = base.t().contiguous()
    self.assertFalse(self.is_view_of(base, copied))
    copied[0, 0] = 0
    self.assertNotEqual(base[0, 0], copied[0, 0])
def test_reshape_view(self, device="mps"):
    # reshape() of a contiguous tensor returns a view: element 6 aliases base[1, 1].
    base = torch.ones(5, 5, device=device)
    flat = torch.reshape(base, (25,))
    self.assertTrue(self.is_view_of(base, flat))
    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_reshape_as_view(self, device="mps"):
    # reshape_as() with a (25,) template on the same device returns a view.
    base = torch.ones(5, 5, device=device)
    template = torch.empty((25,), device=device)
    flat = base.reshape_as(template)
    self.assertTrue(self.is_view_of(base, flat))
    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_reshape_nonview(self, device="mps"):
    # Reshaping a transposed tensor must not give a view of the base, so a
    # write to the result leaves the base untouched.
    base = torch.ones(5, 5, device=device)
    copied = torch.reshape(base.t(), (25,))
    self.assertFalse(self.is_view_of(base, copied))
    copied[6] = 0
    self.assertNotEqual(base[1, 1], copied[6])
def test_flatten_view(self, device="mps"):
    def check_write_through(base, flat):
        # Zero the first element of `flat` and expect the first element
        # of `base` to compare equal to it afterwards.
        first_of_base = (0,) * base.ndim
        first_of_flat = (0,) * flat.ndim
        flat[first_of_flat] = 0
        self.assertEqual(base[first_of_base], flat[first_of_flat])

    base = torch.ones(1, 2, 3, 4, device=device)
    flat = base.flatten()
    self.assertTrue(self.is_view_of(base, flat))
    check_write_through(base, flat)

    # zero-dimensional tensor
    base = torch.tensor(1, device=device)
    flat = base.flatten()
    check_write_through(base, flat)
    self.assertTrue(self.is_view_of(base, flat))

    # here `base` is itself a (transposed) view, so compare bases instead
    base = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
    flat = base.flatten(0, 1)
    check_write_through(base, flat)
    self.assertTrue(self.is_view_of_same_base(base, flat))

    # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
    base = torch.ones(720, device=device).as_strided(
        (2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
    #   [--1--|---2---|-3-] [--1--|----2---|-3-]
    step1 = base.flatten(0, 1)
    step2 = step1.flatten(1, 3)
    step3 = step2.flatten(2, 2)
    for flat in (step1, step2, step3):
        check_write_through(base, flat)
        self.assertTrue(self.is_view_of_same_base(base, flat))
def test_flatten_nonview(self, device="mps"):
    # flatten() must return a fresh tensor (not a view) when the requested
    # dims cannot be merged without copying.
    def assert_is_nonview(t, nv):
        # `nv` must not be a view, and a write to it must not reach `t`.
        idx_t = (0,) * t.ndim
        idx_nv = (0,) * nv.ndim
        self.assertFalse(nv._is_view())
        nv[idx_nv] = 0
        self.assertNotEqual(t[idx_t], nv[idx_nv])

    t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
    nv = t.flatten(1, 3)
    assert_is_nonview(t, nv)

    t = torch.ones(2, 2, device=device).T
    nv = t.flatten()
    assert_is_nonview(t, nv)

    # flatten returns the original object if start_dim=end_dim
    t = torch.ones(2, 2, device=device)
    nv = t.flatten(1, 1)
    self.assertIs(t, nv)
def test_basic_indexing_slice_view(self, device="mps"):
    # Plain slicing returns a view anchored at base[0, 0].
    base = torch.ones(5, 5, device=device)
    corner = base[:2, :3]
    self.assertTrue(self.is_view_of(base, corner))
    corner[0, 0] = 0
    self.assertEqual(base[0, 0], corner[0, 0])
def test_basic_indexing_ellipses_view(self, device="mps"):
    # Indexing with an ellipsis plus a slice also returns a view.
    base = torch.ones(5, 5, device=device)
    left_cols = base[..., :2]
    self.assertTrue(self.is_view_of(base, left_cols))
    left_cols[0, 0] = 0
    self.assertEqual(base[0, 0], left_cols[0, 0])
def test_basic_indexing_newaxis_view(self, device="mps"):
    # None (newaxis) + slice + integer index: view[0, 0] aliases base[0, 3].
    base = torch.ones(5, 5, device=device)
    picked = base[None, :2, 3]
    self.assertTrue(self.is_view_of(base, picked))
    picked[0, 0] = 0
    self.assertEqual(base[0, 3], picked[0, 0])
def test_chunk_view(self, device="mps"):
    # Every chunk must be a view; chunk k starts at row k of the base.
    base = torch.zeros(3, 3, device=device)
    for chunk_idx, piece in enumerate(torch.chunk(base, 3)):
        self.assertTrue(self.is_view_of(base, piece))
        piece[0, 0] = chunk_idx + 1
        self.assertEqual(base[chunk_idx, 0], piece[0, 0])
def test_split_view(self, device="mps"):
    # Every split piece must be a view; piece k starts at row k of the base.
    base = torch.zeros(3, 3, device=device)
    for piece_idx, piece in enumerate(torch.split(base, [1, 1, 1])):
        self.assertTrue(self.is_view_of(base, piece))
        piece[0, 0] = piece_idx + 1
        self.assertEqual(base[piece_idx, 0], piece[0, 0])
def test_movedim_view(self, device="mps"):
    def check_writes_reach_base(dev, move):
        base = torch.zeros(3, 3, device=dev)
        moved = move(base)
        self.assertTrue(self.is_view_of(base, moved))

        # Write random values at random positions of the moved tensor and
        # verify that the (transposed) position in the base changed too.
        for _ in range(3):
            row = random.randint(0, 2)
            col = random.randint(0, 2)
            moved[row, col] = random.random()
            self.assertEqual(base[col, row], moved[row, col])

    for fn in [torch.movedim, torch.moveaxis]:
        # tuple form first, then the scalar form of source/destination
        for src, dst in (((0, 1), (1, 0)), (0, 1)):
            check_writes_reach_base(device, partial(fn, source=src, destination=dst))
# Testing that the generated view_copy kernel and its derivative are implemented correctly
def test_view_copy(self, device="mps"):
    src = torch.randn(4, device=device, requires_grad=True)
    src_ref = src.clone().detach().requires_grad_()
    as_view = src_ref.view(2, 2)
    as_copy = torch.view_copy(src, (2, 2))

    # view_copy ops don't preserve view relationship
    self.assertTrue(self.is_view_of(src_ref, as_view))
    self.assertFalse(self.is_view_of(src, as_copy))

    as_copy.sum().backward()
    as_view.sum().backward()

    # forward and backward give the same shape + result
    self.assertEqual(as_copy, as_view)
    self.assertEqual(src.grad, src_ref.grad)
def test_view_copy_out(self, device="mps"):
    # diagonal_copy: the out= variant must match the functional result.
    square = torch.randn(2, 2, device=device)
    diag_out = torch.empty(2, device=device)
    torch.diagonal_copy(square, out=diag_out)
    diag_expected = torch.diagonal_copy(square)
    self.assertEqual(diag_expected, diag_out)

    # split_copy: same check with a tuple of output tensors.
    vec = torch.randn(4, device=device)
    first_out = torch.empty(2, device=device)
    second_out = torch.empty(2, device=device)
    torch.split_copy(vec, 2, out=(first_out, second_out))
    first_expected, second_expected = torch.split_copy(vec, 2)
    self.assertEqual(first_expected, first_out)
    self.assertEqual(second_expected, second_out)
def test_detached_view_copy(self, device="mps"):
    # https://github.com/pytorch/pytorch/issues/86052
    source = torch.arange(2)
    # .detach() makes the element not a view, but a contiguous tensor
    # with a non-zero offset
    element = source[1].detach()
    moved = element.to(device)
    self.assertEqual(element, moved.cpu())
def test_empty_reshape(self, device="mps"):
    empty = torch.randn(0, 6, device=device)
    self.assertEqual((1, 0, 6, 1, 1), empty.reshape(1, 0, 6, 1, 1).shape)
    # should be viewable -- i.e. data_ptr is the same.
    original_ptr = empty.data_ptr()
    self.assertEqual(original_ptr, empty.reshape(1, 0, 6, 1, 1).data_ptr())

    # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
    self.assertRaises(RuntimeError, lambda: empty.reshape(0, -1))
def test_expand(self, device="mps"):
    # Checks Tensor.expand / expand_as: resulting sizes, chained expands,
    # a non-contiguous source, agreement with unsqueeze, -1 placeholders
    # and the empty-tensor case.
    tensor = torch.rand(1, 8, 1, device=device)
    tensor2 = torch.rand(5, device=device)
    template = torch.rand(4, 8, 5, device=device)
    target = template.size()
    # the three call styles (template, varargs, torch.Size) must agree
    self.assertEqual(tensor.expand_as(template).size(), target)
    self.assertEqual(tensor.expand(4, 8, 5).size(), target)
    self.assertEqual(tensor.expand(target).size(), target)
    self.assertEqual(tensor2.expand_as(template).size(), target)
    self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
    self.assertEqual(tensor2.expand(target).size(), target)

    # test double expand
    self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))

    # test non-contiguous
    noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
    self.assertFalse(noncontig.is_contiguous())
    self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

    # make sure it's compatible with unsqueeze
    expanded = tensor2.expand(1, 1, 5)
    unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
    self.assertEqual(expanded, unsqueezed)
    self.assertEqual(expanded.stride(), unsqueezed.stride())

    # test -1 as target size
    self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
    # tensor2 is 1-D, so both leading dims would be new; expand must raise
    self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))

    # test expanding empty to empty
    self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))
def test_view_empty(self, device="mps"):
    # view() on a zero-element tensor must produce the requested shape.
    empty = torch.randn(0, 6, device=device)
    viewed_shape = empty.view(1, 0, 6, 1, 1).shape
    self.assertEqual((1, 0, 6, 1, 1), viewed_shape)
def test_reshape(self, device="mps"):
    # Checks reshape / reshape_as: when the result shares the data pointer
    # with the source, when it does not, and which shapes are rejected.
    x = torch.randn(3, 3, device=device)
    # contiguous source: reshape keeps the same data pointer
    self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
    self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
    self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))

    # non-contiguous slice of a 3-D tensor
    y = torch.randn(4, 4, 4, device=device)[:, 0, :]
    # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
    if device != "meta":
        self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
    self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
    self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

    # zero-dimensional tensor
    s = torch.randn((), device=device)
    self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
    self.assertEqual(s.reshape(-1).shape, (1,))
    self.assertRaises(RuntimeError, lambda: s.reshape(2))

    empty = torch.tensor([], device=device)
    self.assertEqual(empty, empty.reshape(-1))
    self.assertEqual(empty, empty.reshape([0]))
    # TODO: fix these once we have multi-dimensional empty tensors
    self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
    self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
    self.assertRaises(RuntimeError, lambda: empty.reshape(1))

    x = torch.randn(3, 3, device=device)
    # the first two templates are created without a device argument
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
    self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)))
def test_narrow(self, device="mps"):
    # NOTE: `device` is not used here; the tensors are created without a
    # device argument.
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    # ((dim, start, length), expected rows)
    cases = (
        ((0, 0, 1), [[0, 1, 2]]),
        ((0, 0, 2), [[0, 1, 2], [3, 4, 5]]),
        ((0, 1, 1), [[3, 4, 5]]),
        ((0, -1, 1), [[6, 7, 8]]),
        ((0, -2, 2), [[3, 4, 5], [6, 7, 8]]),
        ((0, -3, 3), [[0, 1, 2], [3, 4, 5], [6, 7, 8]]),
        ((-1, -1, 1), [[2], [5], [8]]),
        ((-2, -1, 1), [[6, 7, 8]]),
    )
    for (dim, start, length), expected in cases:
        self.assertEqual(x.narrow(dim, start, length), torch.tensor(expected))
def test_narrow_tensor(self, device="mps"):
    # A 0-dim integer tensor is accepted as `start`; a float tensor or any
    # tensor with a dimension must raise.
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
    bad_starts = (torch.tensor(0.), torch.tensor([0]), torch.tensor([0, 1]))
    for bad_start in bad_starts:
        with self.assertRaises(Exception):
            x.narrow(0, bad_start, 1)
def test_t(self, device="mps"):
    # Checks Tensor.t() for 0-D to 3-D inputs, each first as a dense tensor
    # and then after to_sparse().
    # NOTE(review): `device` is not used; every tensor here is created
    # without a device argument.
    # Test 0D tensors
    x = torch.randn(())
    self.assertEqual(x, x.t())
    x = x.to_sparse()
    self.assertEqual(x, x.t())

    # Test 1D tensors
    x = torch.arange(4)
    self.assertEqual(x, x.t())
    x = x.to_sparse()
    self.assertEqual(x, x.t())

    # Test 2D tensors
    x = torch.rand((2, 2))
    self.assertEqual(x.t(), x.transpose(0, 1))
    x = x.to_sparse()
    self.assertEqual(x.t(), x.transpose(0, 1))

    # Test 3D tensor
    # t() must reject it, with a different message for dense and sparse
    x = torch.rand((2, 2, 2))
    with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
        x.t()
    x = x.to_sparse()
    with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
        x.t()
def test_split(self, device="mps"):
    def check_pieces(tensor, split_arg, dim, target_sizes):
        # Each piece must have the expected size and be exactly equal to
        # the matching narrow() of the source along `dim`.
        offset = 0
        for wanted_size, piece in zip(target_sizes, tensor.split(split_arg, dim)):
            self.assertEqual(piece.size(), wanted_size)
            self.assertEqual(tensor.narrow(dim, offset, wanted_size[dim]), piece, atol=0, rtol=0)
            offset = offset + wanted_size[dim]

    # Fixed split size; the last piece is smaller
    check_pieces(torch.rand(7, 4), 3, 0, ([3, 4], [3, 4], [1, 4]))

    # Variable sections split
    tensor = torch.randn(20, 10)
    check_pieces(tensor, [5, 5, 10], 0, [[5, 10], [5, 10], [10, 10]])
    check_pieces(tensor, [2, 2, 6], 1, ([20, 2], [20, 2], [20, 6]))
def test_chunk(self, device="mps"):
    source = torch.rand(4, 7)
    dim = 1
    expected_sizes = ([4, 3], [4, 3], [4, 1])
    offset = 0
    for wanted_size, piece in zip(expected_sizes, source.chunk(3, dim)):
        self.assertEqual(piece.size(), wanted_size)
        self.assertEqual(source.narrow(dim, offset, wanted_size[dim]), piece,
                         atol=0, rtol=0)
        offset = offset + wanted_size[dim]

    # Invalid chunk sizes
    error_regex = 'chunk expects.*greater than 0'
    for bad_count in (0, -2):
        with self.assertRaisesRegex(RuntimeError, error_regex):
            source.chunk(bad_count)
def test_unsqueeze(self, device="mps") -> None:
    # Contiguous source: unsqueeze / unsqueeze_ must match the equivalent view().
    src = torch.randn(2, 3, 4)
    out = src.unsqueeze(1)
    self.assertEqual(out, src.view(2, 1, 3, 4))
    out = src.clone().unsqueeze_(2)
    self.assertEqual(out, src.view(2, 3, 1, 4))

    # Non-contiguous source (a column slice): compare against a contiguous copy.
    src = src[:, 1]
    self.assertFalse(src.is_contiguous())
    out = src.unsqueeze(1)
    self.assertEqual(out, src.contiguous().view(2, 1, 4))
    out = src.clone().unsqueeze_(2)
    self.assertEqual(out, src.contiguous().view(2, 4, 1))
# unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
def test_big_transpose(self, device="mps"):
    # Materialize a transposed copy on the device and compare it with the
    # transpose computed by NumPy from a CPU copy.
    source = torch.rand(456, 789, device=device)
    via_torch = source.t().contiguous()
    via_numpy = torch.from_numpy(source.cpu().numpy().transpose())
    self.assertEqual(via_torch, via_numpy)
def test_T(self, device="mps"):
    # For a 3-D tensor, .T must equal a full reversal of the dimensions.
    cube = torch.randn(2, 3, 4, device=device)
    via_attr = cube.T
    via_permute = cube.permute(2, 1, 0)
    self.assertEqual(via_permute, via_attr)
    # For a 1-D tensor, .T must compare equal to the tensor itself.
    vec = torch.randn(10, device=device)
    self.assertEqual(vec, vec.T)
def test_transposes(self, device="mps", dtype=torch.float32):
    for op in ("T", "H", "mT", "mH", "adjoint"):
        # the "m*" attributes and adjoint() are also checked on a 3-D input
        also_3d = op[0] == "m" or op == "adjoint"
        shapes = ((2, 3), (2, 3, 4)) if also_3d else ((2, 3),)
        conjugated = op[-1] == "H" or op == "adjoint"
        for shape in shapes:
            mat = make_tensor(shape, device=device, dtype=dtype)
            actual = getattr(mat, op)
            if op == "adjoint":
                # adjoint is a method; the other names are attributes
                actual = actual()
            expected = mat
            if mat.ndim != 0:
                expected = expected.transpose(-2, -1)
            if conjugated:
                expected = expected.conj()
            self.assertEqual(expected, actual)
def test_transposes_errors(self, device="mps", dtype=torch.float32):
    for op in ("H", "mT", "mH", "adjoint"):
        # "H" is additionally checked with a 3-D input
        bad_shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
        for shape in bad_shapes:
            mat = make_tensor(shape, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                result = getattr(mat, op)
                if op == "adjoint":
                    # adjoint is a method, so the error surfaces on the call
                    result = result()
def test_python_types(self, device="mps"):
    # Two float32 tensors created the same way must report the same dtype.
    first_f32 = torch.randn((1, 2), device=device, dtype=torch.float32)
    second_f32 = torch.randn((1, 2), device=device, dtype=torch.float32)
    self.assertEqual(first_f32.dtype, second_f32.dtype)

    # Python `int` as dtype must resolve like torch.int64 ...
    with_torch_int = torch.arange(10, 20, dtype=torch.int64, device=device)
    with_py_int = torch.arange(10, 20, dtype=int, device=device)
    self.assertEqual(with_torch_int.dtype, with_py_int.dtype)

    # ... and Python `bool` like torch.bool.
    with_torch_bool = torch.tensor([True, False], dtype=torch.bool, device=device)
    with_py_bool = torch.tensor([True, False], dtype=bool, device=device)
    self.assertEqual(with_torch_bool.dtype, with_py_bool.dtype)
# TODO: is resize best put in test_view_ops?
def test_resize_as_preserves_strides(self, device="mps"):
    # resize_as_ with the tensor itself must leave its strides unchanged.
    # NOTE: the tensor is created without a device argument; `device` is unused.
    transposed = torch.empty(2, 3).t()
    strides_before = transposed.stride()
    transposed.resize_as_(transposed)
    self.assertEqual(transposed.stride(), strides_before)
def test_memory_format_resize_as(self, device="mps"):
    def check(shape, memory_format, device="mps"):
        # resize_as_ with preserve_format must make the flat tensor
        # contiguous in the template's memory format.
        template = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
        flat = torch.randn(template.numel(), device=device)
        flat.resize_as_(template, memory_format=torch.preserve_format)
        self.assertTrue(flat.is_contiguous(memory_format=memory_format))

    check((10, 3, 32, 32), torch.channels_last, device="mps")
    check((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps")
def test_memory_format_resize_(self, device="mps"):
    def check(shape, numel, memory_format, device="mps"):
        # resize_ with an explicit memory format must make the tensor
        # contiguous in that format.
        flat = torch.randn(numel, device=device)
        flat.resize_(shape, memory_format=memory_format)
        self.assertTrue(flat.is_contiguous(memory_format=memory_format))

    check((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps")
    check((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps")
# TODO: OpInfo this
def _test_atleast(self, device, torch_fn):
    # Run gradcheck and gradgradcheck of `torch_fn` on double inputs of
    # rank 0 and 1, then on five inputs of rank 0..4 at once.
    # NOTE: `device` is accepted but not used; the inputs are created
    # without a device argument.
    # 0-dim
    scalar = torch.tensor(0.5, dtype=torch.double, requires_grad=True)
    gradcheck(torch_fn, scalar)
    gradgradcheck(torch_fn, scalar)
    # 1-dim
    vec = torch.rand(4, dtype=torch.double, requires_grad=True)
    gradcheck(torch_fn, vec)
    gradgradcheck(torch_fn, vec)
    # 2,3,4-dim
    mat = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
    cube = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
    hyper = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)

    all_ranks = (scalar, vec, mat, cube, hyper)
    gradcheck(torch_fn, all_ranks)
    gradgradcheck(torch_fn, all_ranks)
def test_atleast_gradient(self, device="mps"):
    # Same gradient checks for each of the atleast_{1,2,3}d functions.
    for atleast_fn in (torch.atleast_1d, torch.atleast_2d, torch.atleast_3d):
        self._test_atleast(device, atleast_fn)
def test_view(self, device="mps"):
    # Checks Tensor.view / view_as: size inference with -1, zero-element
    # tensors, and the shapes that must be rejected.
    tensor = torch.rand(15, device=device)
    template = torch.rand(3, 5, device=device)
    empty = torch.empty(0, device=device)
    target = template.size()
    # every spelling of the (3, 5) target must give the same size
    self.assertEqual(tensor.view_as(template).size(), target)
    self.assertEqual(tensor.view(3, 5).size(), target)
    self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
    self.assertEqual(tensor.view(-1, 5).size(), target)
    self.assertEqual(tensor.view(3, -1).size(), target)
    # write through a (5, 3) view; no assertion follows this fill
    tensor_view = tensor.view(5, 3)
    tensor_view.fill_(random.uniform(0, 1))
    self.assertEqual(empty.view_as(empty), empty)
    self.assertEqual(empty.view(0), empty)
    self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
    self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)

    # test size inference with empty tensors
    self.assertEqual(empty.view(-1).size(), torch.Size([0]))
    self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))

    # -1 next to a zero-sized dimension is ambiguous and must raise
    with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
        empty.view(-1, 0)

    with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
        empty.view(3, 0, -1, 0)

    # shapes that do not match the 15 elements, or use -1 twice
    self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
    self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
    self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
def test_contiguous(self, device="mps"):
    x = torch.randn(1, 16, 5, 5, device=device)
    self.assertTrue(x.is_contiguous())
    altered_stride = list(x.stride())
    altered_stride[0] = 20
    # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
    x.set_(x.storage(), 0, x.size(), altered_stride)
    self.assertTrue(x.is_contiguous())
def test_resize_mps_dtypes(self, device="mps"):
    # resize_ from (3, 2) down to (2, 2) for every dtype in MPS_DTYPES.
    new_shape = (2, 2)
    for dtype in MPS_DTYPES:
        x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device)
        x.resize_(new_shape)
        self.assertEqual(new_shape, x.shape)
def test_resize_as_mps_dtypes(self, device="mps"):
    # resize_as_ from (3, 2) to a (2, 3) template for every dtype in MPS_DTYPES.
    for dtype in MPS_DTYPES:
        x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device)
        template = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype, device=device)
        x.resize_as_(template)
        self.assertEqual(template.shape, x.shape)
def test_resize_overflow(self, device="mps"):
    # NOTE: the float64 tensor is created without a device argument;
    # `device` is unused.
    x = torch.empty((), dtype=torch.float64)
    # (expected error pattern, shape that must trigger it)
    overflow_cases = (
        ('Storage size calculation overflowed', [2, 4, 2**29, 2**29]),
        ('overflow', [8, 8, 2**29, 2**29]),
    )
    for pattern, huge_shape in overflow_cases:
        with self.assertRaisesRegex(RuntimeError, pattern):
            x.resize_(huge_shape)
def test_view_all_dtypes_and_devices(self, device="mps"):
    # Flattening a (3, 2) tensor with view(6) must work for float and bool.
    for dtype in (torch.float, torch.bool):
        x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device)
        flat_shape = x.view(6).shape
        self.assertEqual(flat_shape, [6])
class TestConvolutionMPS(TestCaseMPS):
def test_conv1d_all_strides_paddings(self):
    # https://github.com/pytorch/pytorch/issues/82921
    def compare_on_both_devices(stride, padding):
        inp_cpu = torch.randn(1, 57, 40)
        conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False)
        # a deep copy keeps the CPU module and its weights on the CPU
        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
        out_cpu = conv_cpu(inp_cpu)
        inp_mps = inp_cpu.to(device='mps')
        out_mps = conv_mps(inp_mps)
        self.assertEqual(out_cpu, out_mps.cpu())

    # stride and padding each take the values 1, 2, 3 (stride outermost)
    for stride, padding in product(range(1, 4), repeat=2):
        compare_on_both_devices(stride, padding)
def test_conv1d_channels_last(self):
    # https://github.com/pytorch/pytorch/issues/81557
    model = torch.nn.Conv1d(1, 128, 3)
    # (128, 176, 1) data permuted to (128, 1, 176), i.e. a non-contiguous input
    inp_cpu = torch.arange((128 * 176), dtype=torch.float32).view(128, 176, 1).permute(0, 2, 1)
    out_cpu = model(inp_cpu)

    inp_mps = inp_cpu.detach().clone().to("mps")
    # the CPU output was already computed above, before the module is moved
    model_on_mps = model.to("mps")
    out_mps = model_on_mps(inp_mps)

    self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)
def test_conv_transpose_1d_all_strides(self):
    # Compare ConvTranspose1d between CPU and MPS for strides 1, 2 and 3.
    # https://github.com/pytorch/pytorch/issues/82711
    def helper(stride):
        y_cpu = torch.ones(1, 1, 2)
        deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1)
        # The weight is overwritten with an all-ones tensor of width 2,
        # although the module was constructed with kernel_size=1.
        deconv_cpu.weight.data = torch.ones(1, 1, 2)
        deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps')
        x_cpu = deconv_cpu(y_cpu)
        y_gpu = y_cpu.to(device='mps')
        x_gpu = deconv_gpu(y_gpu)
        self.assertEqual(x_cpu, x_gpu.cpu())

    # Plain loop: the helper is called for its assertions only, so there is
    # no reason to build a throwaway list of its None results.
    for stride in (1, 2, 3):
        helper(stride)
def test_conv_transpose_1d_nn_functional(self):
    # https://github.com/pytorch/pytorch/issues/82563
    tin = torch.rand((1, 512, 1245), dtype=torch.float32)
    tparams = torch.rand((512, 256, 16), dtype=torch.float32)
    tbias = torch.rand((256), dtype=torch.float32)

    def run_on(dev):
        # same functional call, with all three operands moved to `dev`
        return torch.nn.functional.conv_transpose1d(
            tin.to(dev), tparams.to(dev), tbias.to(dev), stride=8, padding=4)

    out_cpu = run_on('cpu')
    out_mps = run_on('mps')

    self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)
def test_conv_backward_1d_channels_last(self):
    # Compare Conv1d forward results and the weight/input gradients between
    # CPU and MPS for inputs obtained by permuting (N, L, C) data to (N, C, L).
    def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
        # https://github.com/pytorch/pytorch/issues/84511
        conv_cpu = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
        conv_mps = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
        # give the MPS module the same parameters as the CPU module
        conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
        conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)

        data = torch.rand(shape, dtype=torch.float32)
        x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
        x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
        res_cpu = conv_cpu(x_cpu)
        res_mps = conv_mps(x_mps)
        self.assertEqual(res_cpu, res_mps)
        # backward() returns None, so its result is deliberately not kept
        res_cpu.sum().backward()
        res_mps.sum().backward()

        self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x_cpu.grad, x_mps.grad)

    helper(shape=(1, 176, 1))
    helper(shape=(2, 12, 1))
    helper(shape=(3, 176, 1))
    helper(shape=(4, 376, 1))
    helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
    helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)
def test_conv1d_contiguous(self):
    # Conv1d on a contiguous all-ones input: shape and values must match
    # between CPU and MPS.
    model = torch.nn.Conv1d(1, 128, 3)
    inp_cpu = torch.ones(128, 1, 176)
    out_cpu = model(inp_cpu)

    inp_mps = inp_cpu.detach().clone().to("mps")
    # the CPU output was already computed above, before the module is moved
    model_on_mps = model.to("mps")
    out_mps = model_on_mps(inp_mps)

    self.assertEqual(out_cpu.shape, out_mps.shape)
    self.assertEqual(out_cpu, out_mps.cpu())
def test_conv2d_all_strides_paddings(self):
    # Compare Conv2d forward results and gradients between CPU and MPS for
    # every stride pair in 1..3 and for contiguous / channels_last inputs
    # and weights.
    # https://github.com/pytorch/pytorch/issues/83180
    def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
        # N, C, H, W: shape of the random input; H is also the kernel size.
        x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
        x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()

        if permute_data:
            # NOTE(review): the permuted tensors are discarded, so
            # permute_data currently has no effect on the inputs. Assigning
            # them back would change the channel dimension, so this is left
            # as is -- confirm the original intent.
            x_cpu.permute(0, 2, 3, 1)
            x_mps.permute(0, 2, 3, 1)

        for strideX in range(1, 4):
            for strideY in range(1, 4):
                # in_channels must equal the input's channel dimension (C).
                # It used to be N, which only worked because every caller
                # passes N == C.
                conv_cpu = torch.nn.Conv2d(
                    in_channels=C, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
                conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()

                conv_mps = torch.nn.Conv2d(
                    in_channels=C, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
                # give the MPS module the same parameters as the CPU module
                conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
                conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()

                res_cpu = conv_cpu(x_cpu)
                res_mps = conv_mps(x_mps)
                self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)

                # backward() returns None; the former comparison of the two
                # return values compared None with None and is dropped.
                res_cpu.sum().backward()
                res_mps.sum().backward()

                self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                # x_*.grad accumulates over the stride loops on both devices
                self.assertEqual(x_cpu.grad, x_mps.grad)

    for mem_format_input in [torch.contiguous_format, torch.channels_last]:
        for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
            for permute_data in [True, False]:
                helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
                helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
                helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
def test_conv_transpose_2d_strided(self):
    # Forward output of strided ConvTranspose2d modules on MPS must match CPU.
    def helper(cpu_module, memory_format):
        # Duplicate the module, then point the copy's parameters at MPS clones.
        mps_module = copy.deepcopy(cpu_module).requires_grad_()
        for param_name in ("weight", "bias"):
            cpu_data = getattr(cpu_module, param_name).data
            getattr(mps_module, param_name).data = cpu_data.detach().clone().to("mps").requires_grad_()

        cpu_input = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
        mps_input = cpu_input.detach().clone().to("mps")

        expected = cpu_module(cpu_input)
        actual = mps_module(mps_input)
        self.assertEqual(expected, actual)

    module_factories = (
        # With square kernels and equal stride
        lambda: nn.ConvTranspose2d(16, 33, 3, stride=2),
        # non-square kernels and unequal stride and with padding
        lambda: nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)),
    )
    for mem_format_input in [torch.contiguous_format, torch.channels_last]:
        for build_module in module_factories:
            helper(build_module().requires_grad_(), mem_format_input)
def test_conv_transpose_2d_specified_output(self):
    # Downsample with Conv2d, then upsample with ConvTranspose2d using an
    # explicit `output_size`; values and sizes must agree between CPU and MPS.
    def copy_params_to_mps(cpu_module, mps_module):
        for param_name in ("weight", "bias"):
            cpu_data = getattr(cpu_module, param_name).data
            getattr(mps_module, param_name).data = cpu_data.detach().clone().to("mps").requires_grad_()

    input_cpu = torch.randn(1, 16, 12, 12)
    input_mps = input_cpu.detach().clone().to("mps")

    conv_kwargs = dict(stride=2, padding=1)

    downsample_cpu = nn.Conv2d(16, 16, 3, **conv_kwargs)
    downsample_mps = nn.Conv2d(16, 16, 3, device="mps", **conv_kwargs)
    copy_params_to_mps(downsample_cpu, downsample_mps)

    upsample_cpu = nn.ConvTranspose2d(16, 16, 3, **conv_kwargs)
    upsample_mps = nn.ConvTranspose2d(16, 16, 3, device="mps", **conv_kwargs)
    copy_params_to_mps(upsample_cpu, upsample_mps)

    # Downsampling stage.
    hidden_cpu = downsample_cpu(input_cpu)
    hidden_mps = downsample_mps(input_mps)
    self.assertEqual(hidden_cpu, hidden_mps)
    self.assertEqual(hidden_cpu.size(), hidden_mps.size())

    # Upsampling stage, asking for the original spatial size back.
    restored_cpu = upsample_cpu(hidden_cpu, output_size=input_cpu.size())
    restored_mps = upsample_mps(hidden_mps, output_size=input_mps.size())
    self.assertEqual(restored_cpu, restored_mps)
    self.assertEqual(restored_cpu.size(), restored_mps.size())
def test_conv2d_single_stride(self):
    # Conv2d with a scalar stride of 1, 2 and 3 must match between CPU and MPS.
    cpu_input = torch.randn(2, 2, 3, 6)
    mps_input = cpu_input.to(device='mps')
    for stride in (1, 2, 3):
        cpu_conv = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride)
        # The deep copy carries the same weights over to the MPS device.
        mps_conv = copy.deepcopy(cpu_conv).to(device='mps')
        expected = cpu_conv(cpu_input)
        actual = mps_conv(mps_input).cpu()
        self.assertEqual(expected, actual, rtol=1e-03, atol=1e-05)
@unittest.skipIf(product_version < 13.2, "Skipped on macOS < 13.2")
def test_conv3d_single_stride(self):
    # Conv3d with a scalar stride of 1, 2 and 3 must match between CPU and MPS.
    # Conv3d is only available from MacOS 13.2 onwards
    # NOTE(review): the input is 4-D, which Conv3d presumably treats as a single
    # unbatched (C, D, H, W) sample - confirm this is the intended coverage.
    y_cpu = torch.randn(2, 2, 3, 6)
    y_gpu = y_cpu.to(device='mps')
    for stride in range(1, 4):
        conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
        # The deep copy carries the same weights over to the MPS device.
        conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
        x_cpu = conv_cpu(y_cpu)
        x_gpu = conv_gpu(y_gpu)
        self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
def test_grid_sample(self):
    # Checks F.grid_sample on MPS against hard-coded ground-truth outputs for a
    # fixed 1x1x2x5 input and a fixed grid.
    #
    # NOTE(review): the nested `test` helper below (random-shape CPU vs MPS
    # comparisons, including gradients) is defined but never called anywhere in
    # this method, so only the ground-truth loop at the bottom executes.
    def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
        def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
            for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
                # grid_dim_contig_order specifies the dimension order that can
                # make grid to be contiguous.
                # i.e., grid.permute(grid_dim_contig_order) is contiguous.
                # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
                # initialized with contiguous tensor of shape [N, 2, H, W]
                # and permuted to [N, H, W, 2] afterwards.
                grid_shape = [N, H, W, 2]
                grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
                # Inverse permutation of grid_dim_contig_order.
                grid_fwd_permute = [None, None, None, None]
                for i, d in enumerate(grid_dim_contig_order):
                    grid_fwd_permute[d] = i

                def get_grid(device='cpu', data=None):
                    # Returns a grid of shape `grid_shape` whose memory layout follows
                    # grid_dim_contig_order; random unless `data` is supplied.
                    if data is not None:
                        assert list(data.shape) == grid_shape
                        data = data.permute(grid_dim_contig_order).to(device)
                    else:
                        data = torch.randn(grid_init_shape, device=device)
                    grid = data.permute(grid_fwd_permute)
                    assert grid.permute(grid_dim_contig_order).is_contiguous()
                    return grid

                # CPU reference: forward and backward through grid_sample.
                input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
                grid_cpu = get_grid().requires_grad_()
                out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                        align_corners=align_corners)
                self.assertEqual(out_cpu.size(), torch.Size([N, C, H, W]))

                gradients = torch.randn_like(out_cpu)
                out_cpu.backward(gradients)

                # Compare against unvectorized CPU fallback

                # NOTE [ grid_sample CPU fallback ]
                # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
                # 32-bit floats. So we also have a fallback that is used only for float tensors
                # requiring 64-bit indexing. That requires too much memory to run on CI, so we
                # also export the fallback and test it here to ensure feature parity with
                # the vectorized version.
                input_fallback = input_cpu.float().detach_().requires_grad_()
                grid_fallback = grid_cpu.float().detach_().requires_grad_()
                out_fallback = torch._grid_sampler_2d_cpu_fallback(
                    input_fallback, grid_fallback,
                    F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
                    F.GRID_SAMPLE_PADDING_MODES[padding_mode],
                    align_corners)
                self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)

                out_fallback.backward(gradients.float())
                if input_requires_grad:
                    self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
                self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)

                # MPS run on copies of the same input and grid data.
                input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad)
                grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
                out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                self.assertEqual(out_cpu, out_mps)
                out_mps.backward(gradients.to("mps"))
                if input_requires_grad:
                    self.assertEqual(input_cpu.grad, input_mps.grad)
                self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0)

                # check that zero-dimensional input strides don't error out
                base_input = torch.randn(N, C, 1, IW)
                input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad)
                out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                        align_corners=align_corners)

                input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
                out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                self.assertEqual(out_cpu, out_mps)

        # test same size output
        test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)

        # test larger output
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(IH + 1, 12)
        W = random.randint(IW + 1, 12)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # test smaller output
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(2, IH)
        W = random.randint(2, IW)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # test 1x1 input
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = 1
        IW = 1
        H = random.randint(2, 5)
        W = random.randint(2, 5)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # testing empty grid
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        W = random.randint(3, IW + 2)
        test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)

        # testing empty channel
        N = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(3, IH + 2)
        W = random.randint(3, IW + 2)
        test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)

        # testing empty batch
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(3, IH + 2)
        W = random.randint(3, IW + 2)
        test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)

    # NOTE(review): only 'bilinear'/'nearest' and 'zeros'/'reflection' are iterated,
    # so the 'border' and 'bicubic' ground-truth branches below are never reached.
    for mode in ('bilinear', 'nearest'):
        for padding_mode in ('zeros', 'reflection'):
            for align_corners in (True, False):
                # test known input
                input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
                grid = torch.tensor(
                    [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
                     [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
                # Select the expected output for this (mode, padding_mode, align_corners).
                if mode == 'bilinear':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
                                 [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                 [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                 [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                elif mode == 'nearest':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[0., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0., 8., 5., 7., 0.],
                                 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                elif mode == 'bicubic':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
                                 [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
                                 [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
                                 [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
                                 [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
                                 [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
                                 [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")

                else:
                    raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
                output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
                                       align_corners=align_corners)
                self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
                                 msg=f"groundtruth comparison failed for mode={mode}, "
                                 f"padding_mode={padding_mode}")
class TestAdvancedIndexing(TestCaseMPS):
# Torch dtypes that the dtype-parameterised tests in this class iterate over.
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
# NumPy dtypes in the same order as `supported_dtypes`; the tests that build
# their inputs through NumPy pair the two lists by position.
supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
@unittest.skipIf(product_version < 14.0, "Skipped on macOS < 14")
def test_nonzero_no_warning(self):
    # Neither the functional nor the method form of nonzero may warn on MPS.
    sample = torch.randn((2, 2), device="mps")
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones that would normally be de-duplicated.
        warnings.simplefilter("always")
        torch.nonzero(sample)
        sample.nonzero()
        self.assertEqual(len(caught), 0)
def test_nonzero(self):
    # Every spelling of nonzero on MPS must reproduce NumPy's nonzero exactly,
    # for each supported dtype and a handful of 1-D/2-D/3-D shapes.
    device = "mps"
    shapes = [
        torch.Size((12,)),
        torch.Size((12, 1)),
        torch.Size((1, 12)),
        torch.Size((6, 2)),
        torch.Size((3, 2, 2)),
        torch.Size((5, 5, 5)),
    ]

    def gen_nontrivial_input(shape, dtype, device):
        # Random 0/1 tensor; bfloat16 is generated as float and converted because
        # randint does not support it directly on every platform.
        if dtype == torch.bfloat16:
            return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
        return torch.randint(2, shape, device=device, dtype=dtype)

    for dtype, shape in product(self.supported_dtypes, shapes):
        tensor = gen_nontrivial_input(shape, dtype, device)

        functional_result = torch.nonzero(tensor, as_tuple=False)
        method_result = tensor.nonzero(as_tuple=False)
        # Start from an empty long tensor so nonzero has to resize the `out` buffer.
        out_result = torch.empty([], dtype=torch.long, device=device).resize_(0)
        torch.nonzero(tensor, out=out_result)

        # NumPy cannot represent bfloat16, so widen to float first in that case.
        if dtype != torch.bfloat16:
            np_array = tensor.cpu().numpy()
        else:
            np_array = tensor.float().cpu().numpy()
        np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()

        for result in (functional_result, method_result, out_result):
            self.assertEqual(result.cpu(), np_result, atol=0, rtol=0)

        # The tuple form stacks back into the same (num_nonzero, ndim) matrix.
        functional_tuple = torch.nonzero(tensor, as_tuple=True)
        method_tuple = tensor.nonzero(as_tuple=True)
        stacked_functional = torch.stack(functional_tuple).t().cpu()
        stacked_method = torch.stack(method_tuple).t().cpu()
        self.assertEqual(stacked_functional, np_result, atol=0, rtol=0)
        self.assertEqual(stacked_method, np_result, atol=0, rtol=0)
def test_nonzero_astuple_out(self):
    # `as_tuple=True` cannot be combined with `out=`; also checks how scripting
    # and tracing cope with the `as_tuple` keyword.
    device = "mps"
    t = torch.randn((3, 3, 3), device=device)
    # Empty long tensor used as the `out` destination.
    out = torch.empty([], dtype=torch.long, device=device).resize_(0)

    with self.assertRaises(RuntimeError):
        torch.nonzero(t, as_tuple=True, out=out)

    explicit_false = torch.nonzero(t, as_tuple=False, out=out)
    default_form = torch.nonzero(t, out=out)
    self.assertEqual(explicit_false, default_form)

    # Verifies that JIT script cannot handle the as_tuple kwarg
    # See Issue https://github.com/pytorch/pytorch/issues/45499.
    def _foo(t):
        tuple_result = torch.nonzero(t, as_tuple=True)
        nontuple_result = torch.nonzero(t, as_tuple=False)
        out = torch.empty_like(nontuple_result)
        torch.nonzero(t, as_tuple=False, out=out)
        return tuple_result, nontuple_result, out

    with self.assertRaises(RuntimeError):
        torch.jit.script(_foo)

    # Verifies that JIT tracing works fine
    traced_foo = torch.jit.trace(_foo, t)
    traced_tuple, traced_nontuple, traced_out = traced_foo(t)
    expected_tuple = torch.nonzero(t, as_tuple=True)
    expected_nontuple = torch.nonzero(t)

    comparisons = (
        (traced_tuple, expected_tuple),
        (traced_nontuple, expected_nontuple),
        (traced_out, expected_nontuple),
    )
    for traced, expected in comparisons:
        self.assertEqual(traced, expected)
def test_nonzero_discontiguous(self):
    # nonzero must agree for contiguous and strided inputs, and must write into
    # a supplied `out` tensor (contiguous or strided) without replacing it.
    device = "mps"
    shape = (4, 4)
    dense = torch.randint(2, shape, device=device)
    # Every other column of a double-width buffer gives a non-contiguous copy.
    strided = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(dense)

    expected = dense.nonzero(as_tuple=False)
    from_strided = strided.nonzero(as_tuple=False)
    self.assertEqual(expected, from_strided, atol=0, rtol=0)

    contiguous_out = torch.empty_like(expected)
    ptr_before = contiguous_out.data_ptr()
    # expect contiguous_out storage to be reused
    torch.nonzero(dense, out=contiguous_out)
    self.assertEqual(ptr_before, contiguous_out.data_ptr())
    self.assertEqual(expected, contiguous_out, atol=0, rtol=0)

    # discontiguous out
    strided_out = torch.empty(expected.size(0), expected.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
    ptr_before = strided_out.data_ptr()
    strides_before = strided_out.stride()
    torch.nonzero(dense, out=strided_out)
    self.assertEqual(ptr_before, strided_out.data_ptr())
    self.assertEqual(expected, strided_out, atol=0, rtol=0)
    self.assertEqual(strides_before, strided_out.stride())
def test_nonzero_non_diff(self):
    # The result of nonzero must not require grad even when its input does.
    source = torch.randn(10, requires_grad=True, device="mps")
    indices = source.nonzero()
    self.assertFalse(indices.requires_grad)
def test_nonzero_multi_threading(self):
    # Test that MPS doesn't crash if nonzero called concurrently
    # See https://github.com/pytorch/pytorch/issues/100285
    x = torch.rand(3, 3, device="mps")
    workers = [threading.Thread(target=torch.nonzero, args=(x,)) for _ in range(2)]
    # Start both threads before waiting on either so the calls can overlap.
    for worker in workers:
        worker.start()
    # Join so the concurrent calls finish inside this test instead of
    # outliving it and running while later tests execute.
    for worker in workers:
        worker.join()
def test_sliced_view_cast(self):
    # Smoke test: reinterpreting a sliced float16 tensor as float32 and adding
    # to it. This used to crash on MacOS Sequoia
    # See https://github.com/pytorch/pytorch/issues/137800
    half_matrix = torch.rand(16, 16, device='mps', dtype=torch.float16)
    first_two_columns = half_matrix[:, 0:2]
    # Only completing the operation matters; the value is not inspected.
    _ = first_two_columns.view(torch.float32) + 1
def test_masked_select(self):
    # masked_select with a `>= 0.5` mask must return the same values on CPU and MPS.
    cpu_values = torch.randn(3, 4)
    mps_values = cpu_values.to("mps")

    expected = torch.masked_select(cpu_values, cpu_values.ge(0.5))
    actual = torch.masked_select(mps_values, mps_values.ge(0.5))

    self.assertEqual(expected, actual)
# examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
def test_indexing_get(self):
    # Gathering elements (0,0), (1,1), (2,0) with two index lists must match CPU.
    row_index = [0, 1, 2]
    col_index = [0, 1, 0]
    for dtype in self.supported_dtypes:
        cpu_matrix = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
        mps_matrix = cpu_matrix.detach().clone().to("mps")

        expected = cpu_matrix[row_index, col_index]
        actual = mps_matrix[row_index, col_index]
        # The dtype is passed as the failure message.
        self.assertEqual(expected, actual, str(dtype))
def test_indexing_select_corners(self):
    # Picking the four corners of a 4x3 matrix with 2-D index tensors must match CPU.
    for dtype in self.supported_dtypes:
        cpu_matrix = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        mps_matrix = cpu_matrix.detach().clone().to("mps")

        cpu_rows = torch.tensor([[0, 0], [3, 3]])
        mps_rows = cpu_rows.detach().clone().to("mps")
        cpu_cols = torch.tensor([[0, 2], [0, 2]])
        mps_cols = cpu_cols.detach().clone().to("mps")

        expected = cpu_matrix[cpu_rows, cpu_cols]
        actual = mps_matrix[mps_rows, mps_cols]
        # The dtype is passed as the failure message.
        self.assertEqual(expected, actual, str(dtype))
# FIXME: uint8 fails for this testcase, needs further debugging
def test_slicing_using_advanced_index_for_column(self):
    # A plain slice and a slice combined with a list index on the column axis
    # must both give the same result on MPS as on CPU.
    # FIXME: use supported_dtypes once uint8 is fixed
    dtypes_under_test = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]
    for dtype in dtypes_under_test:
        cpu_matrix = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        mps_matrix = cpu_matrix.detach().clone().to("mps")

        # basic slicing on both axes
        sliced_cpu = cpu_matrix[1:4, 1:3]
        sliced_mps = mps_matrix[1:4, 1:3]
        self.assertEqual(sliced_cpu, sliced_mps, str(dtype))

        # using advanced index for column
        picked_cpu = cpu_matrix[1:4, [1, 2]]
        picked_mps = mps_matrix[1:4, [1, 2]]
        self.assertEqual(picked_cpu, picked_mps, str(dtype))
def test_boolean_array_indexing(self):
    # Selecting elements with a boolean mask (`x > 5`) must match between CPU and MPS.
    def check(dtype):
        cpu_matrix = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        mps_matrix = cpu_matrix.detach().clone().to("mps")

        expected = cpu_matrix[cpu_matrix > 5]
        actual = mps_matrix[mps_matrix > 5]
        self.assertEqual(expected, actual, str(dtype))

    for dtype in self.supported_dtypes:
        # MPS support binary op with uint8 natively starting from macOS 13.0
        if not (product_version < 13.0 and dtype == torch.uint8):
            check(dtype)
def test_advanced_indexing_3D_get(self):
    # Reads through mixed list/int/slice indices on a 3-D tensor must match CPU.
    everything = slice(None)
    index_patterns = (
        ([1, 2], 3, everything),
        ([0, 2], everything, everything),
        (everything, [1, 0], [1]),
    )

    def helper(x_cpu):
        x_mps = x_cpu.detach().clone().to("mps")
        for pattern in index_patterns:
            self.assertEqual(x_cpu[pattern], x_mps[pattern])

    # Hand-written float32 input first.
    x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                           [0.5, 0.6, 0.7, 0.8],
                           [0.9, 1.0, 1.1, 1.2],
                           [1.3, 1.4, 1.5, 1.6]],

                          [[2.0, 2.1, 2.2, 2.3],
                           [2.4, 2.5, 2.6, 2.7],
                           [2.8, 2.9, 3.0, 3.1],
                           [3.2, 3.3, 3.4, 3.5]],

                          [[4.0, 4.1, 4.2, 4.3],
                           [4.4, 4.5, 4.6, 4.7],
                           [4.8, 4.9, 5.0, 5.1],
                           [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
    helper(x_cpu)

    # torch.randn / torch.rand don't work with all dtypes, so generate the
    # input data for every dtype in NumPy and then move it to torch.
    for np_dtype, torch_dtype in zip(self.supported_np_dtypes, self.supported_dtypes):
        np_input = np.random.random_sample(size=[3, 4, 4]).astype(np_dtype)
        helper(torch.tensor(np_input, device='cpu', dtype=torch_dtype))
def test_advanced_indexing_3D_put(self):
    # Writes through mixed list/int/slice indices on a 3-D tensor must match CPU.
    # The assigned value is itself a view (the last element of a 2-element tensor).
    everything = slice(None)
    index_patterns = (
        ([1, 2], 3, everything),
        ([0, 2], everything, everything),
        (everything, [1, 0], [1]),
    )

    def helper(x_cpu):
        dtype = x_cpu.dtype
        x_mps = x_cpu.detach().clone().to("mps")

        value_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")[1:]
        value_mps = torch.tensor([88, 99], dtype=dtype, device="mps")[1:]

        # Assignments are applied cumulatively; compare after each one.
        for pattern in index_patterns:
            x_cpu[pattern] = value_cpu
            x_mps[pattern] = value_mps
            self.assertEqual(x_cpu, x_mps)

    # Hand-written float32 input first.
    x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                           [0.5, 0.6, 0.7, 0.8],
                           [0.9, 1.0, 1.1, 1.2],
                           [1.3, 1.4, 1.5, 1.6]],

                          [[2.0, 2.1, 2.2, 2.3],
                           [2.4, 2.5, 2.6, 2.7],
                           [2.8, 2.9, 3.0, 3.1],
                           [3.2, 3.3, 3.4, 3.5]],

                          [[4.0, 4.1, 4.2, 4.3],
                           [4.4, 4.5, 4.6, 4.7],
                           [4.8, 4.9, 5.0, 5.1],
                           [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
    helper(x_cpu)

    # torch.randn / torch.rand don't work with all dtypes, so generate the
    # input data for every dtype in NumPy and then move it to torch.
    for np_dtype, torch_dtype in zip(self.supported_np_dtypes, self.supported_dtypes):
        np_input = np.random.random_sample(size=[3, 4, 4]).astype(np_dtype)
        helper(torch.tensor(np_input, device='cpu', dtype=torch_dtype))
def test_index_put_with_view_indices(self):
    # index_put_ with accumulate=True, where the index tensors are rows of a
    # transposed view, must match between CPU and MPS.
    index_pairs = [[0, 1], [0, 1]]

    def run(device, dtype):
        target = torch.zeros([5, 3], device=device, dtype=dtype)
        indices = torch.tensor(index_pairs, dtype=torch.int64, device=device)
        values = torch.ones(indices.shape[0], device=device, dtype=dtype)
        target.index_put_(tuple(indices.t()), values, accumulate=True)
        return target

    for dtype in [torch.int32, torch.float]:
        expected = run("cpu", dtype)
        actual = run("mps", dtype)
        self.assertEqual(expected, actual)
# tests from 'test_indexing.py'
def test_advancedindex_big(self, device="mps"):
    # Indexing an arange with a list of positions must return those same values.
    positions = [0, 123, 44488, 68807, 123343]
    reference = torch.arange(0, 123344, dtype=torch.int, device=device)
    expected = torch.tensor(positions, dtype=torch.int)
    self.assertEqual(reference[positions, ], expected)
def test_set_item_to_scalar_tensor(self, device="mps"):
    # Assigning a 0-dim tensor into a column must route gradients back to it:
    # the column has `rows` entries, so the scalar's grad is rows * value.
    rows = random.randint(1, 10)
    cols = random.randint(1, 10)
    matrix = torch.randn([rows, cols], device=device)
    value = 1.0
    scalar = torch.tensor(value, requires_grad=True, device=device)
    matrix[:, 0] = scalar
    matrix.sum().backward()
    self.assertEqual(scalar.grad, rows * value)
def test_single_int(self, device="mps"):
    # A single integer index drops the leading dimension.
    tensor = torch.randn(5, 7, 3, device=device)
    selected = tensor[4]
    self.assertEqual(selected.shape, (7, 3))
def test_multiple_int(self, device="mps"):
    # Each integer index removes the dimension it is applied to.
    tensor = torch.randn(5, 7, 3, device=device)
    shape_checks = (
        (tensor[4], (7, 3)),
        (tensor[4, :, 1], (7,)),
    )
    for selected, expected_shape in shape_checks:
        self.assertEqual(selected.shape, expected_shape)
def test_none(self, device="mps"):
    # Indexing with None inserts a size-1 dimension at that position.
    tensor = torch.randn(5, 7, 3, device=device)
    everything = slice(None)
    cases = (
        (None, (1, 5, 7, 3)),
        ((everything, None), (5, 1, 7, 3)),
        ((everything, None, None), (5, 1, 1, 7, 3)),
        ((Ellipsis, None), (5, 7, 3, 1)),
    )
    for index, expected_shape in cases:
        self.assertEqual(tensor[index].shape, expected_shape)
def test_step(self, device="mps"):
    # Stepped slices of arange(10) must pick the expected elements.
    sequence = torch.arange(10, device=device)
    # A step of 1 is the identity.
    self.assertEqual(sequence[::1], sequence)
    cases = (
        (slice(None, None, 2), [0, 2, 4, 6, 8]),
        (slice(None, None, 3), [0, 3, 6, 9]),
        (slice(None, None, 11), [0]),
        (slice(1, 6, 2), [1, 3, 5]),
    )
    for index, expected in cases:
        self.assertEqual(sequence[index].tolist(), expected)
def test_step_assignment(self, device="mps"):
    # Assigning through a stepped slice must touch only the selected elements.
    grid = torch.zeros(4, 4, device=device)
    odd_columns = torch.tensor([3., 4.], device=device)
    grid[0, 1::2] = odd_columns
    self.assertEqual(grid[0].tolist(), [0, 3, 0, 4])
    # Every other row stays zero.
    self.assertEqual(grid[1:].sum(), 0)
def test_bool_indices(self, device="mps"):
    # Boolean masks select along the first dimension, and a uint8 mask must
    # behave like the equivalent bool mask while emitting warnings.
    data = torch.randn(5, 7, 3, device=device)
    row_mask = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
    self.assertEqual(data[row_mask].shape, (3, 7, 3))
    self.assertEqual(data[row_mask], torch.stack([data[0], data[2], data[3]]))

    flags = torch.tensor([True, False, True], dtype=torch.bool, device=device)
    bool_mask = torch.tensor([True, False, False], dtype=torch.bool, device=device)
    byte_mask = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
    with warnings.catch_warnings(record=True) as caught:
        # The uint8 mask is used exactly twice; two warnings are expected below.
        self.assertEqual(flags[bool_mask].shape, flags[byte_mask].shape)
        self.assertEqual(flags[bool_mask], flags[byte_mask])
        self.assertEqual(flags[bool_mask], torch.tensor([True], dtype=torch.bool, device=device))
        self.assertEqual(len(caught), 2)
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_bool_indices_accumulate(self, device="mps"):
    # Accumulating through an all-False bool mask must leave the target unchanged.
    byte_zeros = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
    empty_mask = byte_zeros > 0
    target = torch.ones(size=(10, 10), device=device)
    target.index_put_((empty_mask, ), target[empty_mask], accumulate=True)
    self.assertEqual(target, torch.ones(size=(10, 10), device=device))
def test_multiple_bool_indices(self, device="mps"):
    # Two bool masks separated by a slice are combined into one leading dimension.
    data = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    first_axis_mask = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
    last_axis_mask = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
    selected = data[first_axis_mask, :, last_axis_mask]
    self.assertEqual(selected.shape, (3, 7))
def test_byte_mask(self, device="mps"):
    # A uint8 (byte) mask must select like a bool mask and emit warnings.
    data = torch.randn(5, 7, 3, device=device)
    byte_mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    with warnings.catch_warnings(record=True) as caught:
        # The byte mask is used exactly twice; two warnings are expected below.
        self.assertEqual(data[byte_mask].shape, (3, 7, 3))
        self.assertEqual(data[byte_mask], torch.stack([data[0], data[2], data[3]]))
        self.assertEqual(len(caught), 2)

    # A comparison mask that matches nothing yields an empty tensor.
    single = torch.tensor([1.], device=device)
    self.assertEqual(single[single == 0], torch.tensor([], device=device))
def test_byte_mask_accumulate(self, device="mps"):
    # Accumulating through an all-zero uint8 mask must leave the target
    # unchanged and emit two warnings (one per use of the byte mask).
    zero_mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
    target = torch.ones(size=(10, 10), device=device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        target.index_put_((zero_mask, ), target[zero_mask], accumulate=True)
        self.assertEqual(target, torch.ones(size=(10, 10), device=device))
        self.assertEqual(len(caught), 2)
def test_index_put_accumulate_expanded_values(self, device="mps"):
    # index_put_ with accumulate=True must broadcast `values` to the indexed
    # shape the same way on MPS as on CPU.
    def run_case(target_cpu, indices_cpu, values_cpu):
        target_dev = target_cpu.to(device)
        indices_dev = [index.to(device) for index in indices_cpu]
        # Both targets are updated in place, so results build up across values.
        for value in values_cpu:
            out_mps = target_dev.index_put_(indices_dev, value.to(device), accumulate=True)
            out_cpu = target_cpu.index_put_(indices_cpu, value, accumulate=True)
            self.assertEqual(out_mps.cpu(), out_cpu)

    # 2-D target: 0-d and 1-d values.
    run_case(
        torch.zeros((5, 2)),
        [
            torch.tensor([0, 1, 2, 3]),
            torch.tensor([1, ]),
        ],
        [torch.tensor(1.0), torch.tensor([1.0, ])],
    )

    # 3-D target: 1-d and 2-d values.
    run_case(
        torch.zeros(4, 3, 2),
        [
            torch.tensor([0, ]),
            torch.arange(3)[:, None],
            torch.arange(2)[None, :],
        ],
        [torch.tensor([-1.0, -2.0]), torch.tensor([[-1.0, -2.0], ])],
    )
def test_index_put_accumulate_non_contiguous(self, device="mps"):
    # index_put_ with accumulate=True on a non-contiguous view must match CPU
    # and must leave the view non-contiguous.
    base_cpu = torch.zeros((5, 2, 2))
    base_dev = base_cpu.to(device)
    view_dev = base_dev[:, 0, :]
    view_cpu = base_cpu[:, 0, :]
    for view in (view_dev, view_cpu):
        self.assertFalse(view.is_contiguous())

    indices_cpu = [torch.tensor([0, 1]), ]
    indices_dev = [index.to(device) for index in indices_cpu]
    update = torch.randn(2, 2)
    out_mps = view_dev.index_put_(indices_dev, update.to(device), accumulate=True)
    out_cpu = view_cpu.index_put_(indices_cpu, update, accumulate=True)
    for view in (view_dev, view_cpu):
        self.assertFalse(view.is_contiguous())

    self.assertEqual(out_mps.cpu(), out_cpu)
def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
    # index_put_ with accumulate=True where the first index is None must match
    # between CPU and MPS, for 0-d and 1-d values.
    # TODO: replace with a better solution.
    # Currently, here using torchscript to put None into indices.
    # on C++ it gives indices as a list of 2 optional tensors: first is null and
    # the second is a valid tensor.
    @torch.jit.script
    def func(x, i, v):
        idx = [None, i]
        x.index_put_(idx, v, accumulate=True)
        return x

    row_count = 4
    target_cpu = torch.arange(row_count * 2, dtype=torch.float32).reshape(row_count, 2)
    target_dev = target_cpu.to(device)
    column_index = torch.tensor([1, 0])
    column_index_dev = column_index.to(device)

    # Both targets are modified in place, so the second value accumulates on the first.
    for value in (torch.tensor(10.0), torch.tensor([1.0, 2.0])):
        out_mps = func(target_dev, column_index_dev, value.to("mps"))
        out_cpu = func(target_cpu, column_index, value)
        self.assertEqual(out_mps.cpu(), out_cpu)
def test_index_put_accumulate_duplicate_indices(self, device="mps"):
    # index_put with accumulate=True must sum every contribution when the same
    # index appears many times; checked against a plain Python accumulation.
    for length in range(1, 128):
        # generate indices by random walk, this will create indices with
        # lots of duplicates interleaved with each other
        steps = torch.empty(length, dtype=torch.float32, device=device).uniform_(-1, 1)
        indices = steps.cumsum(0).long().to("mps")

        # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it
        largest_magnitude = indices.cpu().abs().max().to("mps")
        base = torch.randn(largest_magnitude + 1, device=device)
        values = torch.randn(indices.size(0), device=device)
        output = base.index_put((indices,), values, accumulate=True)

        # Reference result built element by element in Python.
        expected = base.tolist()
        for position, increment in zip(indices.tolist(), values.tolist()):
            expected[position] += increment

        self.assertEqual(output, expected)
def test_index_put_deterministic(self, device="mps"):
    # Runs index_put_ with heavily duplicated indices under every combination
    # of `accumulate` and deterministic mode. The non-accumulating,
    # non-deterministic combination is expected to miss the reference values
    # at least once; every other combination must hit them on each repetition.
    def helper(dtype, accumulate, deterministic, num_tests=128):
        acc_expected = torch.tensor([233, 187, 360], device=device, dtype=dtype)
        non_acc_expected = torch.tensor([38, 37, 39], device=device, dtype=dtype)
        expected = acc_expected if accumulate else non_acc_expected
        t_idx = torch.tensor(
            [0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 2,
             0, 0, 2, 1, 2, 1, 0, 0, 2, 0, 2, 1, 1, 2, 2, 0, 2, 1, 0, 2]
        )
        for _ in range(num_tests):
            try:
                torch.use_deterministic_algorithms(deterministic)
                target = torch.zeros(3, dtype=dtype, device=device)
                source = torch.arange(len(t_idx), device=device, dtype=dtype)
                target.index_put_((t_idx,), source, accumulate=accumulate)
                self.assertEqual(target, expected)
            finally:
                # Always restore the global flag, even when the assertion fails.
                torch.use_deterministic_algorithms(False)

    for accumulate, deterministic in product((False, True), (False, True)):
        dtype = torch.float if accumulate else torch.long
        if accumulate or deterministic:
            helper(dtype, accumulate, deterministic)
            continue
        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not equal!"):
            helper(dtype, accumulate, deterministic)
def test_multiple_byte_mask(self, device="mps"):
    # Two uint8 masks separated by a slice must select like bool masks and
    # emit two warnings.
    data = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    first_axis_mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    last_axis_mask = torch.ByteTensor([1, 1, 1]).to(device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        selected = data[first_axis_mask, :, last_axis_mask]
        self.assertEqual(selected.shape, (3, 7))
        self.assertEqual(len(caught), 2)
def test_byte_mask2d(self, device="mps"):
    # A 2-D boolean mask over the first two dims yields one row per True entry.
    data = torch.randn(5, 7, 3, device=device)
    selector = torch.randn(5, 7, device=device)
    expected_rows = (selector > 0).sum()
    selected = data[selector > 0]
    self.assertEqual(selected.shape, (expected_rows, 3))
def test_jit_indexing(self, device="mps"):
    # Mask assignment and slice assignment inside scripted functions must both
    # set the first 50 entries of arange(100) to 1.
    def fn1(x):
        x[x < 50] = 1.0
        return x

    def fn2(x):
        x[0:50] = 1.0
        return x

    scripted_functions = [torch.jit.script(fn1), torch.jit.script(fn2)]
    data = torch.arange(100, device=device, dtype=torch.float)
    # Reference: fifty ones followed by 50..99.
    ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
    for scripted in scripted_functions:
        # Each call gets its own copy because the functions mutate their input.
        out = scripted(data.detach().clone())
        self.assertEqual(out, ref)
def test_int_indices(self, device="mps"):
    # List indices replace the indexed dimension with the shape of the list.
    data = torch.randn(5, 7, 3, device=device)
    everything = slice(None)
    cases = (
        ([0, 4, 2], (3, 7, 3)),
        ((everything, [0, 4, 2]), (5, 3, 3)),
        ((everything, [[0, 1], [4, 3]]), (5, 2, 2, 3)),
    )
    for index, expected_shape in cases:
        self.assertEqual(data[index].shape, expected_shape)
def test_index_put_src_datatype(self):
    # index_put_ with accumulate=True must keep the source shape for float and int32.
    def helper(device, dtype):
        target = torch.ones(3, 2, 4, device=device, dtype=dtype)
        values = torch.ones(3, 2, 4, device=device, dtype=dtype)
        # The index tensor is created without an explicit device.
        first_dim_index = (torch.tensor([0, 2, 1]),)
        result = target.index_put_(first_dim_index, values, accumulate=True)
        self.assertEqual(result.shape, target.shape)

    for dtype in [torch.float, torch.int32]:
        helper(device="mps", dtype=dtype)
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_index_src_datatype(self):
    """List indexing, and assignment through the same index, preserve shape for several source dtypes."""
    def helper(device, dtype):
        wants_bool = dtype is torch.bool
        # Bool sources are built as uint8 ones and then turned into bool via a comparison.
        storage_dtype = torch.uint8 if wants_bool else dtype
        src = torch.ones(3, 2, 4, device=device, dtype=storage_dtype)
        if wants_bool:
            src = src == 1
        # test index
        res = src[[0, 2, 1], :, :]
        self.assertEqual(res.shape, src.shape)
        # test index_put, no accum
        src[[0, 2, 1], :, :] = res
        self.assertEqual(res.shape, src.shape)

    for dtype in [torch.float, torch.float16, torch.long, torch.bool]:
        helper(device="mps", dtype=dtype)
def test_int_indices2d(self, device="mps"):
    """Paired 2-D row/column index tensors pick out the four corner elements."""
    # From the NumPy indexing example
    grid = torch.arange(0, 12, device=device).view(4, 3)
    row_idx = torch.tensor([[0, 0], [3, 3]], device=device)
    col_idx = torch.tensor([[0, 2], [0, 2]], device=device)
    corners = grid[row_idx, col_idx]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
def test_int_indices_broadcast(self, device="mps"):
    """A column-shaped row index broadcast against a 1-D column index picks the four corners."""
    # From the NumPy indexing example
    grid = torch.arange(0, 12, device=device).view(4, 3)
    row_idx = torch.tensor([0, 3], device=device)
    col_idx = torch.tensor([0, 2], device=device)
    corners = grid[row_idx[:, None], col_idx]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
def test_empty_index(self, device="mps"):
    """Reading with an empty index gives zero elements; writing through an empty index or all-False mask is a no-op."""
    original = torch.arange(0, 12, device=device).view(4, 3)
    no_indices = torch.tensor([], dtype=torch.long, device=device)
    self.assertEqual(original[no_indices].numel(), 0)

    # empty assignment should have no effect but not throw an exception
    scratch = original.clone()
    scratch[no_indices] = -1
    self.assertEqual(original, scratch)

    all_false = torch.zeros(4, 3, device=device).bool()
    scratch[all_false] = -1
    self.assertEqual(original, scratch)
def test_empty_ndim_index(self, device="mps"):
    """Multi-dimensional empty index tensors, and indexing into a size-0 dimension."""
    vec = torch.randn(5, device=device)
    empty_2d_index = torch.empty(0, 2, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(0, 2, device=device), vec[empty_2d_index])

    cube = torch.randn(2, 3, 4, 5, device=device)
    self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
                     cube[:, torch.empty(0, 6, dtype=torch.int64, device=device)])

    hollow = torch.empty(10, 0, device=device)
    self.assertEqual(hollow[[1, 2]].shape, (2, 0))
    self.assertEqual(hollow[[], []].shape, (0,))
    # Any concrete index into the size-0 dimension is out of range.
    with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
        hollow[:, [0, 1]]
def test_empty_ndim_index_bool(self, device="mps"):
    """Indexing a 1-D tensor with an empty 2-D uint8 tensor raises IndexError."""
    vec = torch.randn(5, device=device)
    with self.assertRaises(IndexError):
        vec[torch.empty(0, 2, dtype=torch.uint8, device=device)]
def test_empty_slice(self, device="mps"):
    """An empty slice keeps the source strides and is reported as contiguous."""
    base = torch.randn(2, 3, 4, 5, device=device)
    picked_last = base[:, :, :, 1]
    emptied = picked_last[:, 1:1, :]
    self.assertEqual((2, 0, 4), emptied.shape)
    # this isn't technically necessary, but matches NumPy stride calculations.
    self.assertEqual((60, 20, 5), emptied.stride())
    self.assertTrue(emptied.is_contiguous())
def test_index_getitem_copy_bools_slices(self, device="mps"):
    """Checks which index kinds share the data pointer of the source and which do not.

    ``True`` / uint8 ``1`` must give a different data pointer, ``False`` / uint8 ``0`` must give
    an empty tensor with a leading 0 dimension, and ``None`` / ``...`` must keep the data pointer.
    """
    uint8_one = torch.tensor(1, dtype=torch.uint8, device=device)
    uint8_zero = torch.tensor(0, dtype=torch.uint8, device=device)
    candidates = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]
    for t in candidates:
        base_ptr = t.data_ptr()
        self.assertNotEqual(base_ptr, t[True].data_ptr())
        self.assertEqual(torch.empty(0, *t.shape), t[False])
        self.assertNotEqual(base_ptr, t[uint8_one].data_ptr())
        self.assertEqual(torch.empty(0, *t.shape), t[uint8_zero])
        self.assertEqual(base_ptr, t[None].data_ptr())
        self.assertEqual(base_ptr, t[...].data_ptr())
def test_index_setitem_bools_slices(self, device="mps"):
    """Assignment through bool, uint8-scalar, ``None`` and ``...`` indices.

    Writes through ``True`` / uint8 ``1`` / ``None`` / ``...`` must land; writes through
    ``False`` / uint8 ``0`` must leave the tensor unchanged; ``a[:] = ...`` on a 0-dim
    tensor must raise IndexError.
    """
    uint8_one = torch.tensor(1, dtype=torch.uint8, device=device)
    uint8_zero = torch.tensor(0, dtype=torch.uint8, device=device)
    candidates = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
    for t in candidates:
        # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
        # (some of these ops already prefix a 1 to the size)
        minus_one = torch.ones_like(t) * -1
        minus_one_prefixed = minus_one.unsqueeze(0).unsqueeze(0)

        t[True] = minus_one_prefixed
        self.assertEqual(t, minus_one)
        t[False] = 5
        self.assertEqual(t, minus_one)

        t[uint8_one] = minus_one_prefixed * 2
        self.assertEqual(t, minus_one * 2)
        t[uint8_zero] = 5
        self.assertEqual(t, minus_one * 2)

        t[None] = minus_one_prefixed * 3
        self.assertEqual(t, minus_one * 3)
        t[...] = minus_one_prefixed * 4
        self.assertEqual(t, minus_one * 4)

        if t.dim() == 0:
            with self.assertRaises(IndexError):
                t[:] = minus_one_prefixed * 5
def test_index_scalar_with_bool_mask(self, device="mps"):
    """A 0-dim uint8 mask and a 0-dim bool mask must index a scalar tensor identically."""
    uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
    boolMask = torch.tensor(True, dtype=torch.bool, device=device)
    scalars = (
        torch.tensor(1, device=device),
        torch.tensor(True, dtype=torch.bool, device=device),
    )
    for scalar in scalars:
        via_uint8 = scalar[uintMask]
        via_bool = scalar[boolMask]
        self.assertEqual(via_uint8, via_bool)
        self.assertEqual(via_uint8.dtype, via_bool.dtype)
def test_setitem_expansion_error(self, device="mps"):
    """Assigning a value whose extra leading dims are not all 1 must raise RuntimeError."""
    tensor_true = torch.tensor(True, device=device)
    target = torch.randn(2, 3, device=device)
    # check prefix with non-1s doesn't work
    oversized = target.expand(torch.Size([5, 1]) + target.size())
    # NumPy: ValueError
    for index in (True, tensor_true):
        with self.assertRaises(RuntimeError):
            target[index] = oversized
def test_getitem_scalars(self, device="mps"):
    """Python ints and 0-dim integer tensors must behave the same when used as indices."""
    zero = torch.tensor(0, dtype=torch.int64, device=device)
    one = torch.tensor(1, dtype=torch.int64, device=device)

    # non-scalar indexed with scalars
    mat = torch.randn(2, 3, device=device)
    self.assertEqual(mat[0], mat[zero])
    self.assertEqual(mat[0][1], mat[zero][one])
    self.assertEqual(mat[0, 1], mat[zero, one])
    self.assertEqual(mat[0, one], mat[zero, 1])

    # indexing by a scalar should slice (not copy)
    self.assertEqual(mat[0, 1].data_ptr(), mat[zero, one].data_ptr())
    second_row_ptr = mat[1].data_ptr()
    self.assertEqual(second_row_ptr, mat[one.int()].data_ptr())
    self.assertEqual(mat[1].data_ptr(), mat[one.short()].data_ptr())

    # scalar indexed with scalar
    scalar = torch.randn((), device=device)
    with self.assertRaises(IndexError):
        scalar[:]
    with self.assertRaises(IndexError):
        scalar[zero]
    self.assertEqual(scalar, scalar[...])
def test_setitem_scalars(self, device="mps"):
    """Assignment through a Python int index and through a 0-dim tensor index must agree."""
    # Created without a device argument, unlike the tensors being indexed.
    zero = torch.tensor(0, dtype=torch.int64)

    # non-scalar indexed with scalars
    mat = torch.randn(2, 3, device=device)
    via_number = mat.clone()
    via_tensor = mat.clone()
    new_row = torch.randn(3, device=device)
    via_number[0] = new_row
    via_tensor[zero] = new_row
    self.assertEqual(via_number, via_tensor)
    mat[1, zero] = 7.7
    self.assertEqual(7.7, mat[1, 0])

    # scalar indexed with scalars
    scalar = torch.randn((), device=device)
    with self.assertRaises(IndexError):
        scalar[:] = 8.8
    with self.assertRaises(IndexError):
        scalar[zero] = 8.8
    scalar[...] = 9.9
    self.assertEqual(9.9, scalar)
def test_basic_advanced_combined(self, device="mps"):
    """Mixing a slice with a list index: reads return a copy, assignment writes through."""
    # From the NumPy indexing example
    grid = torch.arange(0, 12, device=device).view(4, 3)
    self.assertEqual(grid[1:2, 1:3], grid[1:2, [1, 2]])
    self.assertEqual(grid[1:2, 1:3].tolist(), [[4, 5]])

    # Check that it is a copy
    snapshot = grid.clone()
    grid[1:2, [1, 2]].zero_()
    self.assertEqual(grid, snapshot)

    # But assignment should modify the original
    snapshot = grid.clone()
    grid[1:2, [1, 2]] = 0
    self.assertNotEqual(grid, snapshot)
def test_int_assignment(self, device="mps"):
    """Assigning a scalar or a 1-D tensor to a row selected by an integer index."""
    grid = torch.arange(0, 4, device=device).view(2, 2)
    grid[1] = 5
    self.assertEqual(grid.tolist(), [[0, 1], [5, 5]])

    grid = torch.arange(0, 4, device=device).view(2, 2)
    replacement = torch.arange(5, 7, device=device)
    grid[1] = replacement
    self.assertEqual(grid.tolist(), [[0, 1], [5, 6]])
def test_byte_tensor_assignment(self, device="mps"):
    """Assigning through a uint8 mask records one warning and overwrites only the masked rows."""
    grid = torch.arange(0., 16, device=device).view(4, 4)
    row_mask = torch.ByteTensor([True, False, True, False]).to(device)
    value = torch.tensor([3., 4., 5., 6.], device=device)

    with warnings.catch_warnings(record=True) as caught:
        grid[row_mask] = value
        self.assertEqual(len(caught), 1)

    # Rows 0 and 2 were selected by the mask; rows 1 and 3 keep their original contents.
    self.assertEqual(grid[0], value)
    self.assertEqual(grid[1], torch.arange(4., 8, device=device))
    self.assertEqual(grid[2], value)
    self.assertEqual(grid[3], torch.arange(12., 16, device=device))
def test_variable_slicing(self, device="mps"):
    """Slice bounds given as 0-dim tensors behave like the equivalent Python ints."""
    grid = torch.arange(0, 16, device=device).view(4, 4)
    bounds = torch.IntTensor([0, 1]).to(device)
    start, stop = bounds
    self.assertEqual(grid[start:stop], grid[0:1])
def test_ellipsis_tensor(self, device="mps"):
    """An index tensor combined with ``...`` selects columns (trailing) or rows (leading)."""
    grid = torch.arange(0, 9, device=device).view(3, 3)
    picks = torch.tensor([0, 2], device=device)
    columns = grid[..., picks]
    self.assertEqual(columns.tolist(), [[0, 2],
                                        [3, 5],
                                        [6, 8]])
    rows = grid[picks, ...]
    self.assertEqual(rows.tolist(), [[0, 1, 2],
                                     [6, 7, 8]])
def test_invalid_index(self, device="mps"):
    """String slice bounds are rejected with a TypeError mentioning 'slice indices'."""
    grid = torch.arange(0, 16, device=device).view(4, 4)
    with self.assertRaisesRegex(TypeError, 'slice indices'):
        grid["0":"1"]
def test_out_of_bound_index(self, device="mps"):
    """Out-of-range integer indices raise IndexError naming the offending index, dimension and size."""
    cube = torch.arange(0, 100, device=device).view(2, 5, 10)
    with self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5'):
        cube[0, 5]
    with self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2'):
        cube[4, 5]
    with self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10'):
        cube[0, 1, 15]
    with self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10'):
        cube[:, :, 12]
def test_zero_dim_index(self, device="mps"):
    """A 0-dim tensor compares equal to its ``item()`` and rejects integer indexing."""
    scalar = torch.tensor(10, device=device)
    self.assertEqual(scalar, scalar.item())

    def runner():
        # The first indexing expression is expected to raise before anything is printed.
        print(scalar[0])
        return scalar[0]

    with self.assertRaisesRegex(IndexError, 'invalid index'):
        runner()
def test_cpu_indices(self, device="mps"):
    """An index tensor created without a device argument works for both writes and reads."""
    idx = torch.tensor([0, 1])
    zeros = torch.zeros(2, device=device)
    target = torch.ones(10, device=device)
    target[idx] = zeros  # index_put_
    expected = torch.ones(10, device=device)
    expected[:2] = 0
    self.assertEqual(target, expected, atol=0, rtol=0)
    gathered = target[idx]  # index
    self.assertEqual(gathered, torch.zeros(2, device=device), atol=0, rtol=0)
def test_nextafter(self, device="mps"):
    """``nextafter`` must step in the same direction on ``device`` as it does on CPU."""
    for dtype in [torch.float16, torch.float32]:
        x = torch.tensor([1, -1, 0, 0, 2, -2], device=device, dtype=dtype)
        y = torch.tensor([2, -2, -1, 1, -3, 3], device=device, dtype=dtype)
        x_cpu = x.cpu()
        stepped = torch.nextafter(x, y)
        stepped_cpu = torch.nextafter(x_cpu, y.cpu())
        # greater is broken on MPS, see https://github.com/pytorch/pytorch/issues/125051
        # so both comparisons are carried out on CPU copies.
        went_up_device = stepped.cpu() > x_cpu
        went_up_cpu = stepped_cpu > x_cpu
        self.assertEqual(went_up_device, went_up_cpu)
class TestRNNMPS(TestCaseMPS):
    """LSTM and RNN-cell tests for MPS; LSTM results are checked against the CPU implementation."""

    def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
                     seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
        """Run one ``nn.LSTM`` configuration on CPU and on ``device`` and compare the results.

        The forward output and the final ``(h, c)`` states are always compared. When ``backward``
        is true, gradients w.r.t. the input, the initial states and every named parameter are
        compared as well, for several choices of which outputs feed the differentiated scalar.
        """
        rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            bidirectional=bidirectional,
            batch_first=batch_first,
            device="cpu"
        )
        num_directions = 2 if bidirectional else 1

        # Only the input layout depends on batch_first; the state shape is the same either way.
        if batch_first:
            input_shape = (batch_size, seq_len, input_size)
        else:
            input_shape = (seq_len, batch_size, input_size)
        state_shape = (num_layers * num_directions, batch_size, hidden_size)

        inp = torch.randn(*input_shape, device="cpu", dtype=dtype, requires_grad=backward)
        hx = torch.randn(*state_shape, device="cpu", dtype=dtype, requires_grad=backward)
        cx = torch.randn(*state_shape, device="cpu", dtype=dtype, requires_grad=backward)

        cpu_output, (cpu_hn, cpu_cn) = rnn(inp, (hx, cx))

        rnn = rnn.to(device)
        inp = inp.to(device)
        hx = hx.to(device)
        cx = cx.to(device)
        output, (hn, cn) = rnn(inp, (hx, cx))

        self.assertEqual(cpu_output, output)
        self.assertEqual(cpu_hn, hn)
        self.assertEqual(cpu_cn, cn)

        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
            """Run forward on ``device`` and return the output plus parameter/input/state gradients."""
            rnn = rnn.to(device)
            inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)

            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"

            # Scalar objective built only from the requested outputs.
            f = 0
            if output_grad_presented:
                f = f + 3 * output.sum()
            if states_grad_presented:
                f = f + (hx_out * cx_out).sum()

            param_names, params = zip(*rnn.named_parameters())
            # retain_graph so the second grad() call below can reuse the same graph.
            param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))

            input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
            return output, param_grads, input_grad, hx_grad, cx_grad

        if backward:
            grad_cases = [
                dict(output_grad_presented=True, states_grad_presented=True),
                dict(output_grad_presented=False, states_grad_presented=True),
                dict(output_grad_presented=True, states_grad_presented=False),
            ]

            for grad_case in grad_cases:
                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
                    get_backward_results(rnn, "cpu", inp, hx, cx, **grad_case)
                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
                    get_backward_results(rnn, device, inp, hx, cx, **grad_case)

                self.assertEqual(cpu_hx_grad, mps_hx_grad)
                self.assertEqual(cpu_cx_grad, mps_cx_grad)
                self.assertEqual(cpu_output, mps_output)
                self.assertEqual(cpu_input_grad, mps_input_grad)
                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")

    # Keyword-argument combinations passed to _lstm_helper by the forward/backward tests.
    LSTM_TEST_CASES = [
        {},  # default
        dict(batch_first=True),
        dict(bias=False),
        dict(bidirectional=True),
        dict(batch_first=True, bias=False),
        dict(bidirectional=True, bias=False),
        dict(bidirectional=True, batch_first=True),
        dict(bidirectional=True, batch_first=True, bias=False)
    ]

    def test_lstm_forward(self, device="mps", dtype=torch.float32):
        """Forward-only comparison for every LSTM_TEST_CASES entry and 1, 2 and 5 layers."""
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)

    def test_lstm_backward(self, device="mps", dtype=torch.float32):
        """Forward and gradient comparison for every LSTM_TEST_CASES entry and 1, 2 and 5 layers."""
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)

    def test_RNN_cell_no_broadcasting(self):
        """RNN/GRU/LSTM cells must reject inputs or states whose sizes would need broadcasting."""
        def test(cell_module, input, hx, input_size, hidden_size):
            cell = cell_module(input_size, hidden_size, device='mps')
            self.assertRaises(RuntimeError, lambda: cell(input, hx))

        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)

        hidden_size = 20
        input_size = 10
        input = torch.randn(3, input_size, device='mps')
        bad_hx = torch.randn(1, hidden_size, device='mps')
        good_hx = torch.randn(3, hidden_size, device='mps')

        # Test hidden/input batch size broadcasting
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test hx's hidden_size vs module's hidden_size broadcasting
        # Kept on MPS so the RuntimeError can only come from the size check, not a device mismatch.
        bad_hx = torch.randn(3, 1, device='mps')
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test input's input_size vs module's input_size broadcasting
        bad_input = torch.randn(3, 1, device='mps')
        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)

    def test_LSTM_cell(self):
        """Smoke test: iterate an LSTMCell a few times and backpropagate, with and without bias."""
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        for bias in (True, False):
            input = torch.randn(3, 10, device='mps')
            hx = torch.randn(3, 20, device='mps')
            cx = torch.randn(3, 20, device='mps')
            lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
            for _ in range(6):
                hx, cx = lstm(input, (hx, cx))

            (hx + cx).sum().backward()

    def test_LSTM_cell_forward_input_size(self):
        """An input whose feature size differs from the cell's input_size must raise."""
        input = torch.randn(3, 11, device='mps')
        hx = torch.randn(3, 20, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))

    def test_LSTM_cell_forward_hidden_size(self):
        """A state whose size differs from the cell's hidden_size must raise, in either state slot."""
        input = torch.randn(3, 10, device='mps')
        hx = torch.randn(3, 21, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
class TestFallbackWarning(TestCase):
    """Checks the error and warning behaviour for an op that has no MPS implementation."""

    # TODO: Remove once test_testing.py is running on MPS devices
    def test_no_warning_on_import(self):
        """Importing torch in a fresh interpreter with ``-W always`` must print nothing."""
        out = subprocess.check_output(
            [sys.executable, "-W", "always", "-c", "import torch"],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")

        self.assertEqual(out, "")

    def _get_not_implemented_op(self):
        """Return ``(fn, args, kwargs, source_string)`` for an op treated here as unimplemented on MPS."""
        # This can be changed once we actually implement 'lcm'
        # Should return fn, args, kwargs, string_version
        return (torch.lcm,
                [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
                "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")

    def test_error_on_not_implemented(self):
        """Calling the unimplemented op directly must raise NotImplementedError."""
        fn, args, kwargs, _ = self._get_not_implemented_op()

        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
            fn(*args, **kwargs)

    def test_warn_on_not_implemented_with_fallback(self):
        """With PYTORCH_ENABLE_MPS_FALLBACK=1 the op must run and emit exactly one warning.

        The check runs in a subprocess because the environment variable has to be set
        before torch is imported. The child exits with 1 if importing torch warned and
        with 2 if running the op did not produce exactly one warning.
        """
        _, _, _, op = self._get_not_implemented_op()
        # The script body must be valid, properly indented Python: it is passed verbatim to `-c`.
        script = f"""
import os
# MUST happen before pytorch's import
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import warnings

with warnings.catch_warnings(record=True) as w:
    import torch

if len(w) > 0:
    print(w)
    exit(1)

# This should run just fine and raise warning about perf
with warnings.catch_warnings(record=True) as w:
    {op}

if len(w) != 1:
    print(w)
    exit(2)
"""
        try:
            subprocess.check_output(
                [sys.executable, '-W', 'always', '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            # Decode once so every failure message shows readable text rather than a bytes repr.
            child_output = e.output.decode("utf-8")
            if e.returncode == 1:
                self.fail("There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set." +
                          child_output)
            elif e.returncode == 2:
                self.fail("There wasn't exactly one warning when running not implemented op with "
                          f"PYTORCH_ENABLE_MPS_FALLBACK set. {child_output}")
            else:
                self.fail("Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
                          child_output)
class TestNoRegression(TestCase):
    """Regression tests for assert_close, float64 errors, legacy constructors and serialization on MPS."""

    def test_assert_close(self):
        """``assert_close`` must report a mismatch between a finite MPS tensor and infinity."""
        ones = torch.ones(1, device="mps")
        zeros = torch.zeros(1, device="mps")
        pos_inf = ones / zeros
        # Still computed although only the disabled check below uses it.
        nan = zeros / zeros

        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            torch.testing.assert_close(ones, pos_inf)

        # TODO: The NaN test is failing when all the tests in test_mps are run
        # together but passes when run separately. There seems to be memory
        # corruption which needs to be fixed for this test to be enabled.
        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
        #     torch.testing.assert_close(a, nan)

    def test_double_error(self):
        """Creating or converting to a float64 tensor on MPS must raise TypeError."""
        expected_message = "the MPS framework doesn't support float64"
        with self.assertRaisesRegex(TypeError, expected_message):
            a = torch.ones(2, dtype=torch.float64, device="mps")

        a = torch.ones(2, device="mps")
        with self.assertRaisesRegex(TypeError, expected_message):
            a = a.double()

    def test_legacy_constructor(self):
        """The legacy ``Tensor.new`` constructor must be callable on an MPS tensor."""
        source = torch.ones(2, device="mps")
        source.new(1)

    def test_serialization_map_location(self):
        """Save/load round trips between CPU and MPS honour ``map_location``."""
        def roundtrip(rand_kwargs, load_kwargs, expected_device_type):
            # Save a fresh random tensor to a temp file, load it back, and check value and device.
            with tempfile.NamedTemporaryFile() as f:
                x = torch.rand(2, **rand_kwargs)
                torch.save(x, f)

                f.seek(0)
                x2 = torch.load(f, **load_kwargs)

                self.assertEqual(x, x2)
                self.assertEqual(x2.device.type, expected_device_type)

        # Ensures that cpu Tensor can be loaded on mps
        roundtrip({}, {"map_location": "mps"}, "mps")
        # Ensures that mps Tensors can be loaded on mps
        roundtrip({"device": "mps"}, {}, "mps")
        # Ensures that mps Tensors can be loaded on cpu
        roundtrip({"device": "mps"}, {"map_location": "cpu"}, "cpu")
        # Ensures that `mps:0` Tensors can be loaded on mps
        roundtrip({"device": "mps:0"}, {"map_location": "mps:0"}, "mps")
# Every dtype reported by get_all_dtypes() except the four removed below.
MPS_DTYPES = get_all_dtypes()
for t in (torch.double, torch.cdouble, torch.cfloat, torch.bfloat16):
    # list.remove raises ValueError if the dtype is missing, just like index() would.
    MPS_DTYPES.remove(t)

# Dtypes used by the gradient consistency test.
MPS_GRAD_DTYPES = [torch.float32, torch.float16]
class TestConsistency(TestCaseMPS):
    """Runs each op's sample inputs on CPU and on MPS and compares outputs (and gradients)."""

    # TODO: This is only used while some ops are being added.
    # This list should contain all ops and dtypes eventually
    # This can be generated automatically in the `new_mps_allowlist.txt` file
    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
    # You most likely do NOT want to modify this manually

    # Ops compared with (1e-2, 1e-2) tolerances when the dtype is float16.
    FP16_LOW_PRECISION_LIST = {
        'add', 'sub', 'div', 'addcdiv',
        '__rdiv__', '__rmul__',
        'nn.functional.huber_loss',
        'true_divide', 'kron',
        'gradient', 'var', 'std', 'std_mean', 'ldexp',
        'linalg.vector_norm', 'lerp',
        'addr', 'var_mean',
        'var_mean_unbiased',
        'acosh', 'asinh', 'asin',
        'masked.std',
        'nn.functional.normalize',
        'nn.functional.triplet_margin_loss',
        'nn.functional.triplet_margin_with_distance_loss',
        'nn.functional.batch_norm',
        'nn.functional.instance_norm',
        'round', 'xlogy', 'addcmul',
        'nn.functional.cross_entropy',
        'nn.functional.binary_cross_entropy',
        'nn.functional.nll_loss',
        'nn.functional.max_pool2d',
        'nn.functional.gelu',
        'nn.functional.glu',
        '_native_batch_norm_legit',
        '_batch_norm_with_update',
        'native_batch_norm',
        'softmax',
        '_softmax_backward_data',
        'log_softmax',
        'masked.softmax',
        'masked.log_softmax',
        'masked.softmin',
        'nn.functional.kl_div',
        'nn.functional.softmin',
        'cross', 'linalg.cross',
        'prod', 'masked.prod',
        'nextafter',
        'native_layer_norm',
        'nn.functional.layer_norm',
        'nn.functional.interpolate',
        'nn.functional.upsample_bilinear',
        'nn.functional.upsample_nearest',

        # for macOS 12
        'masked.normalize', 'masked.sum', 'masked.var',
        'outer',
        'sum_to_size', 'sum',
        'mul',
        'nansum', 'nanmean',
        'norm',
    }

    FP32_LOW_PRECISION_LIST = {
        # conv2d and conv_transpose2d results have a very small
        # difference compared to CPU/CUDA, so we use lower precision on FP32
        'nn.functional.conv2d',
        'nn.functional.conv_transpose2d',
        'matmul', '__rmatmul__',
        'linalg.multi_dot',
        'addbmm',
    }

    def _compute_tolerances(self, op, dtype):
        """Return the ``(atol, rtol)`` pair to use for ``op`` at ``dtype``.

        The checks run in order and the first match wins; ``(None, None)`` means
        no explicit tolerance is passed to ``assertEqual``.
        """
        if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype in [torch.float32, torch.complex64]:
            return (1e-4, 3e-5)

        if op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
            return (1e-2, 1e-2)

        if op.name in ['nn.functional.conv_transpose1d',
                       'nn.functional.conv_transpose2d',
                       'nn.functional.conv_transpose3d',
                       '__rmatmul__', 'addbmm', 'addmv',
                       'baddbmm', 'cov', 'matmul', 'mv'] and dtype == torch.float16:
            return (5e-2, 5e-2)
        if op.name == "masked.mean":
            return (7e-4, 2e-3)
        if op.name == "native_layer_norm":
            return (1e-4, 1.3e-5)
        if op.name in ["pow", "__rpow__"] and product_version < 13.3:
            # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
            # fixed in macOS 13.3+
            return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6)
        if op.name == "nn.functional.interpolate":
            return (1e-3, 1e-4)
        if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
            # TODO: Investigate why this is needed
            # See https://github.com/pytorch/pytorch/issues/120237
            return (3e-5, 3e-5)
        return (None, None)

    # Used for accept mode only
    NEW_ALLOW_LIST = defaultdict(list)
    NEW_ALLOW_LIST_GRAD = defaultdict(list)

    # NOTE: the helpers below are plain instance methods (like _compute_tolerances) so that they
    # keep working when the class members are copied onto the device-specific test class.
    def _cpu_samples(self, op, device, dtype):
        """Return ``op``'s sample inputs on ``device``, with grad enabled for float/complex dtypes."""
        return op.sample_inputs(
            device,
            dtype,
            requires_grad=(dtype.is_floating_point or dtype.is_complex),
            # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
            set_seed=False,
        )

    def _cpu_and_mps_call_args(self, op, cpu_sample):
        """Return ``(cpu_args, cpu_kwargs, mps_args, mps_kwargs)`` for one CPU sample.

        Every tensor in the sample is detached, copied to MPS and given the same
        ``requires_grad`` flag; non-tensor values are passed through unchanged.
        """
        mps_sample = cpu_sample.transform(
            lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

        cpu_args = [cpu_sample.input] + list(cpu_sample.args)
        cpu_kwargs = cpu_sample.kwargs
        mps_args = [mps_sample.input] + list(mps_sample.args)
        mps_kwargs = mps_sample.kwargs

        # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
        if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
            mps_args[1] = cpu_args[1]

        return cpu_args, cpu_kwargs, mps_args, mps_kwargs

    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES + [torch.complex64])
    def test_output_match(self, device, dtype, op):
        """Forward results on MPS must match the CPU results for every sample input."""
        self.assertEqual(device, "cpu")

        for cpu_sample in self._cpu_samples(op, device, dtype):
            #
            # Forward check
            #
            cpu_args, cpu_kwargs, mps_args, mps_kwargs = self._cpu_and_mps_call_args(op, cpu_sample)

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
                atol = 1.0
                rtol = 0.0

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
    def test_output_grad_match(self, device, dtype, op):
        """Forward results and input gradients on MPS must match CPU for every sample input."""
        self.assertEqual(device, "cpu")

        for cpu_sample in self._cpu_samples(op, device, dtype):
            #
            # Forward check
            #
            cpu_args, cpu_kwargs, mps_args, mps_kwargs = self._cpu_and_mps_call_args(op, cpu_sample)

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            # Unsorted `unique` samples are skipped without comparing anything.
            if op.name == "unique" and cpu_kwargs["sorted"] is False:
                continue

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name in ["renorm", "norm", "linalg.norm"] and dtype == torch.float16:
                atol = 7e-4
                rtol = 1.5e-3

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

            #
            # Backward check
            #
            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

            def req_grad(t):
                return isinstance(t, torch.Tensor) and t.requires_grad

            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
            diff_cpu_arg = tuple(t for t in pytree.tree_leaves((cpu_args, cpu_kwargs)) if req_grad(t))
            diff_mps_arg = tuple(t for t in pytree.tree_leaves((mps_args, mps_kwargs)) if req_grad(t))
            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))

            # Nothing differentiable came out of this sample, so there is no gradient to compare.
            if len(diff_cpu_out) == 0:
                continue
            # rand_like does not work with certain dtypes, so cast to double and cast back
            cpu_grad_outputs = tuple(torch.rand_like(t, dtype=torch.double).to(dtype=t.dtype) for t in diff_cpu_out)
            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)

            # Compare computed gradients with cpu given random grad_output vector
            # Sometimes when the derivative is 0, we just don't bother creating the graph
            # allow_unused is needed in those cases.
            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
class TestErrorInputs(TestCase):
    """Runs each op's declared error inputs on MPS and checks the expected exception is raised."""

    _ignore_not_implemented_error = True

    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
    def test_error_inputs(self, device, op):
        """Every error input must raise its declared error type with a message matching its regex."""
        self.assertEqual(device, "mps:0")

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        for error_case in op.error_inputs(device, set_seed=False):
            sample = error_case.sample_input
            call_args = [sample.input] + list(sample.args)
            call_kwargs = sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(call_args[1], torch.Tensor):
                call_args[1] = call_args[1].cpu()

            with self.assertRaisesRegex(error_case.error_type, error_case.error_regex):
                op(*call_args, **call_kwargs)
class TestComplex(TestCase):
    """Binary-op checks mixing MPS tensors with Python and tensor scalars, including complex ones."""

    def test_tensor_scalar_binops(self):
        """``+``, ``-`` and ``*`` over every operand pair must give the same result as on CPU."""
        # Regression test for https://github.com/pytorch/pytorch/issues/119088
        def to_cpu(value):
            if isinstance(value, torch.Tensor):
                return value.cpu()
            return value

        # Allocate tensors on mps
        with torch.device("mps"):
            operands = [torch.rand(2, dtype=dtype) for dtype in [torch.float, torch.half, torch.cfloat]]
        self.assertTrue(all(x.device.type == "mps" for x in operands))
        # Add scalars
        operands.extend([7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)])

        # Iterate over all permutations of types(int, float, complex, half) and ops (excluding div)
        for x, y in itertools.product(operands, repeat=2):
            x_cpu = to_cpu(x)
            y_cpu = to_cpu(y)
            for op_name in ("__add__", "__sub__", "__mul__"):
                res = getattr(x, op_name)(y)
                res_cpu = getattr(x_cpu, op_name)(y_cpu)
                self.assertEqual(to_cpu(res), res_cpu, f"{op_name}({x}, {y}) produces different results {res} vs {res_cpu}")
# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
    """MPS counterpart of the NumPy-reference test, plus a tensor-creation check per dtype."""

    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if not IS_CI:
            return

        header = (
            "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
            "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
        )
        # Assure no opinfo entry has dynamic_dtypes
        filtered_ops = [op for op in op_db if opinfo.utils.is_dynamic_dtype_set(op)]
        # One line per offending op, appended after the header.
        details = "".join("\n" + opinfo.utils.str_format_dynamic_dtype(op) for op in filtered_ops)
        err_msg = header + details

        assert len(filtered_ops) == 0, err_msg

    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
    # MPS still requires some fairly heavy special casing in the test framework.
    # When MPS becomes more consistent, this can probably be merged with that test using
    # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
    @suppress_warnings
    # MPS only supports float32
    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
    def test_numpy_ref_mps(self, device, dtype, op):
        """Compare ``op`` against its reference implementation ``op.ref`` on each input."""
        # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
        # does not support float64 Tensors.
        # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
        # get patched up and this workaround removed.
        broken_on_ref_inputs = op.name in ('where',)

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        if broken_on_ref_inputs:
            inputs = op.sample_inputs(device, dtype, set_seed=False)
        else:
            inputs = op.reference_inputs(device, dtype, set_seed=False)

        for sample_input in inputs:
            self.compare_with_reference(op, op.ref, sample_input)

    @dtypes(*get_all_dtypes())
    def test_tensor_creation(self, device, dtype):
        """``torch.ones`` must match CPU for supported dtypes and raise TypeError for the rest."""
        def ones(device):
            return torch.ones((2, 2), dtype=dtype, device=device)

        # bfloat16 only counts as supported when product_version is above 14.0.
        extra_dtypes = [torch.bfloat16, torch.complex64] if product_version > 14.0 else [torch.complex64]
        supported = MPS_DTYPES + extra_dtypes

        if dtype in supported:
            mps_tensor = ones(device)
            cpu_tensor = ones("cpu")
            self.assertEqual(mps_tensor.cpu(), cpu_tensor)
        else:
            with self.assertRaises(TypeError):
                ones(device)
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
# to achieve this.
# TestConsistency is instantiated for the "cpu" device type only; the remaining
# device-generic classes are instantiated for "mps" with allow_mps=True.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
instantiate_parametrized_tests(TestMPS)

if __name__ == "__main__":
    run_tests()