# Owner(s): ["module: mps"]

import io
import platform
import sys
import math
import random
import unittest
import warnings
import subprocess
import tempfile
import os
import copy
import gc
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from collections import defaultdict
from torch import inf
from torch.nn import Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
    (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
     NoTest, skipIfSlowGradcheckEnv, suppress_warnings)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial

from torch.testing._internal.common_methods_invocations import (
    op_db,
    DecorateInfo,
    UnaryUfuncInfo,
    ReductionOpInfo,
    SpectralFuncInfo,
    BinaryUfuncInfo,
)
from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
import numpy as np
import torch.utils._pytree as pytree
from itertools import product
import operator

test_consistency_op_db = copy.deepcopy(op_db)
test_error_inputs_op_db = copy.deepcopy(op_db)

# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`
_ref_test_ops = tuple(
    filter(
        lambda op: not isinstance(
            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
        )
        and op.ref is not None,
        op_db,
    )
)

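# Note (editorial summary of the filter above): only plain OpInfo entries that
# define a NumPy reference implementation (`op.ref`) are kept; unary/binary
# ufuncs, reductions, and spectral ops are excluded, presumably because they
# are exercised by their own dedicated reference tests.
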
def xfailIf(condition):
    def wrapper(func):
        if condition:
            return unittest.expectedFailure(func)
        else:
            return func
    return wrapper

def xfailIfMacOS14_4Plus(func):
    return unittest.expectedFailure(func) if product_version > 14.3 else func  # noqa: F821

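# Illustrative usage of the two helpers above (hypothetical tests, not part of
# this suite): the wrapped test is recorded as an expected failure only when
# the condition holds at decoration time, e.g.
#
#     @xfailIf(product_version < 13.0)
#     def test_needs_macos13_feature(self): ...
#
#     @xfailIfMacOS14_4Plus
#     def test_regressed_on_macos_14_4(self): ...
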
def mps_ops_grad_modifier(ops):
    XFAILLIST_GRAD = {

        # precision issues
        'digamma': [torch.float32],
        'special.polygammaspecial_polygamma_n_0': [torch.float16],
        'polygammapolygamma_n_0': [torch.float16],
        'nn.functional.binary_cross_entropy': [torch.float16],

        # Unimplemented ops
        '__getitem__': [torch.float16],
        '_segment_reduce': [torch.float16, torch.float32],
        '_chunk_cat': [torch.float16, torch.float32],
        'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
        'unfold': [torch.float16, torch.float32],
        'sparse.mmreduce': [torch.float32],  # csr not supported
        'unique_consecutive': [torch.float16, torch.float32],
        'special_modified_bessel_i0': [torch.float16, torch.float32],
        'scalar_tensor': [torch.float16, torch.float32],
        'cdist': [torch.float32],
        'masked.scatter': [torch.float16, torch.float32],
        'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
        'aminmax': [torch.float32, torch.float16],
        'polar': [torch.float32],

        # Correctness issues
        'atanh': [torch.float32],

        # Random output
        'exponential': [torch.float16, torch.float32],

        # CPU errors
        # derivative for aten::nextafter is not implemented on CPU
        'nextafter': None,
        # derivative for aten::floor_divide is not implemented on CPU
        'floor_divide': [torch.float16, torch.float32],
        # derivative for aten::narrow_copy is not implemented on CPU
        'narrow_copy': [torch.float16, torch.float32],
        # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
        'histogramdd': [torch.float16, torch.float32],
        # derivative for aten::histogram is not implemented
        'histogram': [torch.float16, torch.float32],
        # 'bool' object is not iterable
        'allclose': [torch.float16, torch.float32],
        'equal': [torch.float16, torch.float32],
        # 'float' object is not iterable
        'item': [torch.float16, torch.float32],
        # "mse_backward_cpu_out" not implemented for 'Half'
        'nn.functional.mse_loss': [torch.float16],
        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
        'nn.functional.smooth_l1_loss': [torch.float16],
        # cpu error: grad requires non-empty inputs
        'randn': [torch.float16, torch.float32],
        'signal.windows.bartlett': [torch.float32],
        'signal.windows.blackman': [torch.float32],
        'signal.windows.cosine': [torch.float32],
        'signal.windows.exponential': [torch.float32],
        'signal.windows.gaussian': [torch.float32],
        'signal.windows.general_cosine': [torch.float32],
        'signal.windows.general_hamming': [torch.float32],
        'signal.windows.hamming': [torch.float32],
        'signal.windows.hann': [torch.float32],
        'signal.windows.kaiser': [torch.float32],
        'signal.windows.nuttall': [torch.float32],
        'eye': [torch.float16, torch.float32],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],
    }

    MACOS_12_3_XFAILLIST_GRAD = {
        # Unsupported Border padding mode; the forward pass succeeds as it falls back to cpu
        'grid_sampler_2d': [torch.float32],
        # Unimplemented
        'logaddexp2': [torch.float32],
    }

    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'masked.softmin': [torch.float32, torch.float16],
        'masked.softmax': [torch.float32, torch.float16],
        'masked.log_softmax': [torch.float32, torch.float16],

        # Unsupported Border padding mode; the forward pass succeeds as it falls back to cpu
        'grid_sampler_2d': [torch.float32],

        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    SKIPLIST_GRAD = {
        'nn.functional.pairwise_distance': [torch.float16],
        # failed assertion `destination datatype must be fp32'
        'nn.functional.conv1d': [torch.float16],
        'nn.functional.conv2d': [torch.float16],
        'nn.functional.conv3d': [torch.float16],
        'nn.functional.conv_transpose1d': [torch.float16],
        'nn.functional.conv_transpose2d': [torch.float16],
        'nn.functional.conv_transpose3d': [torch.float16],
    }

    MACOS_13_3_XFAILLIST_GRAD = {
        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
        'linalg.matrix_rank': None,

        # Exception: Caused by sample input at index 3 on MPS
        'nn.functional.conv3d': [torch.float32],
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        key = op.name + op.variant_test_name
        if key in XFAILLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=XFAILLIST_GRAD[key]))

        if key in SKIPLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.skip,
                         dtypes=SKIPLIST_GRAD[key]))

        if key in ON_MPS_XFAILLIST:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=ON_MPS_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST_GRAD[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key]))

        if key in MACOS_13_3_XFAILLIST_GRAD and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST_GRAD[key]))
        yield op

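# A minimal sketch of how these modifiers are consumed (hypothetical class and
# dtype list, for illustration only): each modifier is a generator that tags
# OpInfo entries with xfail/skip decorators before they reach the `ops`
# decorator from common_device_type, e.g.
#
#     class TestConsistencyDemo(TestCaseMPS):
#         @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=[torch.float32])
#         def test_output_match(self, device, dtype, op): ...
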
def mps_ops_modifier(ops):
    # Supported complex OPS
    SUPPORTED_COMPLEX_OPS = {
        '__radd__',
        '__rmul__',
        '__getitem__',
        'abs',
        'add',
        'alias_copy',
        'argwhere',
        'atleast_1d',
        'atleast_2d',
        'atleast_3d',
        'as_strided',
        'as_strided_copy',
        'as_strided_scatter',
        'broadcast_tensors',
        'broadcast_to',
        'chalf',
        'cfloat',
        'chunk',
        'clone',
        'conj',
        'conj_physical',
        'contiguous',
        'diag',
        'diag_embed',
        'diagflat',
        'diagonal',
        'diagonal_copy',
        'diagonal_scatter',
        'dsplit',
        'empty',
        'empty_permuted',
        'empty_strided',
        'eye',
        'exp',
        'expand',
        'expand_as',
        'flatten',
        'fill',
        'full',
        'H',
        'hsplit',
        'imag',
        'index_select',
        'isfinite',
        'isinf',
        'isreal',
        'item',
        'kron',
        'linalg.diagonal',
        'linalg.svd',
        'linspace',
        'logspace',
        'linspacetensor_overload',
        'logspacetensor_overload',
        'mH',
        'mT',
        'masked_scatter',
        'masked_select',
        'meshgridlist_of_tensors',
        'meshgridvariadic_tensors',
        'movedim',
        'mul',
        'narrow',
        'narrow_copy',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv_transpose1d',
        'nn.functional.conv_transpose2d',
        'nn.functional.feature_alpha_dropoutwithout_train',
        'nn.functional.padcircular',
        'nn.functional.tanhshrink',
        'nn.functional.unfold',
        'nonzero',
        'ones',
        'outer',
        'permute',
        'positive',
        'randn',
        'ravel',
        'real',
        'repeat_interleave',
        'reshape_as',
        'reshape',
        'resolve_conj',
        'resolve_neg',
        'scalar_tensor',
        'select',
        'sgn',
        'slice',
        'split',
        'split_with_sizes',
        'split_with_sizes_copy',
        'splitlist_args',
        'squeeze',
        'squeezemultiple',
        'sub',
        'svd',
        't',
        'tanh',
        'tensor_split',
        'transpose',
        'T',
        'unbind',
        'unflatten',
        'unfold',
        'unfold_copy',
        'unsafe_chunk',
        'unsafe_split',
        'unsqueeze',
        'view_as',
        'view_as_real',
        'view',
        'vsplit',
        'zero_',
        'zeros',
    }

    AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
        '__rdiv__',
        '__rmatmul__',
        '_chunk_cat',
        'acos',
        'acosh',
        'all',
        'allclose',
        'any',
        'addcdiv',
        'addcmul',
        'addmmdecomposed',
        'addmv',
        'asin',
        'atan',
        'atanh',
        'bfloat16',
        'bmm',
        'bool',
        'cartesian_prod',
        'cat',
        'char',
        'column_stack',
        'combinations',
        'corrcoef',
        'constant_pad_nd',
        'cos',
        'cosh',
        'count_nonzero',
        'diff',
        'div',
        'divno_rounding_mode',
        'dot',
        'dstack',
        'einsum',
        'eq',
        'equal',
        'exp2',
        'expm1',
        'fft.fft',
        'fft.fft2',
        'fft.fftn',
        'fft.fftshift',
        'fft.ifft',
        'fft.ifft2',
        'fft.ifftn',
        'fft.ifftshift',
        'fft.irfftn',
        'fft.irfft2',
        'fft.irfft',
        'fft.hfftn',
        'fft.hfft2',
        'fft.hfft',
        'flip',
        'fliplr',
        'flipud',
        'float',
        'gradient',
        'half',
        'hstack',
        'inner',
        'int',
        'isclose',
        'isnan',
        'ldexp',
        'linalg.multi_dot',
        'linalg.pinv',
        'log10',
        'log1p',
        'log2',
        'log',
        'logical_and',
        'logical_not',
        'logical_or',
        'logical_xor',
        'long',
        'masked_fill',
        'masked.mean',
        'masked.prod',
        'masked.std',
        'masked.sum',
        'masked.var',
        'matmul',
        'mean',
        'mm',
        'mv',
        'ne',
        'neg',
        'nn.functional.padconstant',
        'nn.functional.padreflect',
        'nn.functional.padreplicate',
        'nn.functional.pixel_shuffle',
        'nn.functional.pixel_unshuffle',
        'nn.functional.rms_norm',
        'nn.functional.softsign',
        'pinverse',
        'prod',
        'reciprocal',
        'roll',
        'rot90',
        'rsqrt',
        'short',
        'sigmoid',
        'sin',
        'sinh',
        'sqrt',
        'square',
        'stack',
        'stft',
        'sum',
        'sum_to_size',
        'tan',
        'tensordot',
        'trace',
        'trapz',
        'trapezoid',
        'tril',
        'triu',
        'true_divide',
        'vstack',
        'where',
    }
    # These ops worked on MacOS12, but are broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758
    MACOS_12_3_XFAILLIST = {
        # Top 60
        # expected failures
        # The result of pow(9, 8) is showing 43046716, whereas it should've been 43046721.
        # fixed in macOS 13.3. Currently error is not raised.
        'pow': [torch.int16, torch.int64, torch.uint8, torch.int8],
        # expected failures
        '__rpow__': [torch.uint8, torch.int8],

        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'cdist': [torch.float32],
        'tan': [torch.uint8, torch.float32],

        # Data type support starts from macOS 13
        'nn.functional.avg_pool1d': [torch.int64],
        'nn.functional.avg_pool2d': [torch.int64],
        'nn.functional.local_response_norm': [torch.int64],
        '__radd__': [torch.uint8],
        '__rdiv__': [torch.uint8],
        '__rmul__': [torch.uint8],
        'abs': [torch.uint8],
        'acos': [torch.uint8],
        'acosh': [torch.uint8],
        'add': [torch.uint8],
        'asin': [torch.uint8],
        'asinh': [torch.uint8],
        'atan': [torch.uint8],
        'atanh': [torch.uint8],
        'ceil': [torch.uint8],
        'corrcoef': [torch.uint8],
        'cos': [torch.uint8],
        'cosh': [torch.uint8],
        'cov': [torch.uint8],
        'cumulative_trapezoid': [torch.uint8],
        'deg2rad': [torch.uint8],
        'diff': [torch.uint8],
        'eq': [torch.uint8],
        'equal': [torch.uint8],
        'erf': [torch.uint8],
        'exp2': [torch.uint8],
        'exp': [torch.uint8],
        'expm1': [torch.uint8],
        'floor': [torch.uint8],
        'fmax': [torch.uint8],
        'fmin': [torch.uint8],
        'fmod': [torch.uint8],
        'ge': [torch.uint8],
        'gt': [torch.uint8],
        'isclose': [torch.uint8],
        'isnan': [torch.uint8],
        'kron': [torch.uint8],
        'le': [torch.uint8],
        'log10': [torch.uint8],
        'log1p': [torch.uint8],
        'log2': [torch.uint8],
        'log': [torch.uint8],
        'logical_and': [torch.uint8],
        'logical_or': [torch.uint8],
        'logical_xor': [torch.uint8],
        'logit': [torch.uint8],
        'lt': [torch.uint8],
        'masked.mean': [torch.uint8],
        'masked.std': [torch.uint8],
        'masked.var': [torch.uint8],
        'maximum': [torch.uint8],
        'minimum': [torch.uint8],
        'mul': [torch.uint8],
        'ne': [torch.uint8],
        'neg': [torch.uint8],
        'nn.functional.cosine_embedding_loss': [torch.uint8],
        'nn.functional.margin_ranking_loss': [torch.uint8],
        'nn.functional.poisson_nll_loss': [torch.uint8],
        'nn.functional.softsign': [torch.uint8],
        'nn.functional.tanhshrink': [torch.uint8],
        'nn.functional.triplet_margin_loss': [torch.uint8],
        'nn.functional.triplet_margin_with_distance_loss': [torch.uint8],
        'nn.functional.pairwise_distance': [torch.uint8],
        'outer': [torch.uint8],
        'rad2deg': [torch.uint8],
        'reciprocal': [torch.uint8],
        'remainder': [torch.uint8],
        'round': [torch.uint8],
        'rsqrt': [torch.uint8],
        'sigmoid': [torch.uint8],
        'sign': [torch.uint8],
        'signbit': [torch.uint8],
        'sin': [torch.uint8],
        'sinh': [torch.uint8],
        'special.ndtr': [torch.uint8],
        'sqrt': [torch.uint8],
        'sub': [torch.uint8],
        'tanh': [torch.uint8],
        'trapezoid': [torch.uint8],
        'trapz': [torch.uint8],
        'true_divide': [torch.uint8],
        'trunc': [torch.uint8],
        'xlogy': [torch.uint8],
        'minbinary': [torch.uint8],
        'maxbinary': [torch.uint8],
        'divtrunc_rounding': [torch.uint8],
        'divfloor_rounding': [torch.uint8],
        'divno_rounding_mode': [torch.uint8],
        'floor_divide': [torch.uint8],
        'ldexp': [torch.uint8],
        # square internally calls into power, and will type cast to int64, which is supported starting from macOS 13
        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_BEFORE_13_3_XFAILLIST = {
        # Failure due to precision issues (still present on 13.3+) as well as non-standard behavior of
        # cpu ops for negative integers.
        # Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)):
        # - CPU output: tensor([102.668, 1.129e+15])
        # - MPS output: tensor([102.6681, inf])
        # In the latter case, inf is probably correct (this is what scipy does).
        'polygamma': [torch.float32, torch.uint8],
        'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
        'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8],
        'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8],
        'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8],
        'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8],
        'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
        'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8],

        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'tan': [torch.float32],
        'cdist': [torch.float32],

        # CPU Error: cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu')
        # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
        # Elements from index 30 and 5133 are both equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
        # Unsupported dtypes
        'cumsum': [torch.int64],
        'cumprod': [torch.int64],
        'cumulative_trapezoid': [torch.int64],
        'masked.cumsum': [torch.int64],
        'masked.cumprod': [torch.int64],
        'linalg.vander': [torch.int64],
    }

    MACOS_AFTER_13_1_XFAILLIST = {
        # before macOS 13.2 it falls back to cpu and passes the forward pass
        'grid_sampler_2d': [torch.float32],  # Unsupported Border padding mode
        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_13_3_XFAILLIST = {
        # Failure due to precision issue for fp16
        # on both cpu and mps there are test cases that might produce inf result
        # 'nn.functional.pairwise_distance': [torch.float16],

        # tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu')
        # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
        # Elements from index 30 and 5133 are both equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],

        # Failure due to precision issues as well as non-standard behavior of cpu ops for
        # negative integers. Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)):
        # - CPU output: tensor([102.668, 1.129e+15])
        # - MPS output: tensor([102.6681, inf])
        # In the latter case, inf is probably correct (this is what scipy does).
        'polygamma': [torch.float32, torch.uint8],
        'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
        'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8],
        'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8],
        'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8],
        'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8],
        'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
        'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8],
    }

    MACOS_BEFORE_14_4_XFAILLIST = {
        # These ops work fine in 14.4 but fail in 14.2 or 13.x
        'fft.hfft2': [torch.complex64],
    }

    # These ops are not expected to work
    UNIMPLEMENTED_XFAILLIST = {
        # Failures due to lack of op implementation on MPS backend
        'login': None,
        'log_sigmoid': None,
        'log_sigmoid_forward': None,
        'linalg.eig': None,
        'linalg.eigvals': None,
        'put': None,
        'nn.functional.conv_transpose3d': None,
        'rounddecimals_neg_3': None,
        'rounddecimals_3': None,
        'rounddecimals_0': None,
        '__rsub__': None,
        'angle': None,
        'cauchy_': None,
        'cauchy': None,
        'cholesky': None,
        'cholesky_inverse': None,
        'cholesky_solve': None,
        'cummax': None,
        'cummin': None,
        'erfc': None,
        'frexp': None,
        'gcd': None,
        'geqrf': None,
        'nn.functional.grid_sample': None,  # Unsupported Border padding mode
        'heaviside': None,
        'i0': None,
        'igamma': None,
        'igammac': None,
        'index_copy': None,
        'index_reduceprod': None,
        'index_reducemean': None,
        'index_reduceamax': None,
        'index_reduceamin': None,
        'isneginf': None,
        'isposinf': None,
        'kthvalue': None,
        'lcm': None,
        'linalg.cholesky': None,
        'linalg.cholesky_ex': None,
        'linalg.cond': None,
        'linalg.detsingular': None,
        'linalg.det': None,
        'linalg.eigh': None,
        'linalg.eigvalsh': None,
        'linalg.householder_product': None,
        'linalg.ldl_factor': None,
        'linalg.ldl_factor_ex': None,
        'linalg.ldl_solve': None,
        'linalg.lstsq': None,
        'linalg.lstsqgrad_oriented': None,
        'linalg.lu': None,
        'linalg.lu_factor': None,
        'linalg.lu_factor_ex': None,
        'linalg.lu_solve': None,
        'linalg.matrix_norm': [torch.float32],
        'linalg.norm': [torch.float32],
        'linalg.normsubgradients_at_zero': [torch.float32],
        'linalg.qr': None,
        'linalg.slogdet': None,
        'linalg.solve': None,
        'linalg.solve_ex': None,
        'linalg.svdvals': None,
        'linalg.tensorsolve': None,
        'linalg.vecdot': None,
        'logcumsumexp': None,
        'logdet': None,
        'lu': None,
        'lu_solve': None,
        'lu_unpack': None,
        'masked.median': None,
        'matrix_exp': None,
        'mode': None,
        'nanquantile': None,
        'nanmedian': None,
        'native_dropout_backward': None,
        'normnuc': None,
        'nn.functional.fractional_max_pool2d': None,
        'nn.functional.fractional_max_pool3d': None,
        'nn.functional.adaptive_avg_pool3d': None,
        'nn.functional.adaptive_max_pool3d': None,
        'nn.functional.interpolatearea': None,
        'nn.functional.interpolatebicubic': None,
        'nn.functional.interpolatetrilinear': None,
        # TODO: max_pool2d for integral types fails the numerical test
        'nn.functional.max_pool2d': (integral_types() if product_version < 14.0 else
                                     [torch.int64, torch.int32, torch.int16, torch.int8]),
        'nn.functional.max_unpool1dgrad': None,
        'nn.functional.max_unpool2dgrad': None,
        'nn.functional.max_unpool3dgrad': None,
        'nn.functional.avg_pool3d': None,
        'nn.functional.ctc_loss': None,
        'nn.functional.embedding_bag': None,
        'nn.functional.hardshrink': None,
        'nn.functional.max_pool3d': None,
        'nn.functional.max_unpool1d': None,
        'nn.functional.max_unpool2d': None,
        'nn.functional.max_unpool3d': None,
        'nn.functional.multi_margin_loss': None,
        'nn.functional.multilabel_margin_loss': None,
        'nn.functional.pdist': None,
        'nn.functional.rrelu': None,
        'nn.functional.norm': None,
        'ormqr': None,
        'pca_lowrank': None,
        'qr': None,
        'quantile': None,
        'rsub': None,
        'scatter_reduceamax': None,
        'scatter_reduceamin': None,
        'scatter_reducemin': None,
        'scatter_reducemean': None,
        'scatter_reduceprod': None,
        'scatter_reducesum': None,
        'segment_reduce': None,
        '_segment.reduce': None,
        'segment.reduce': None,
        'segment_reduce_offsets': None,
        '_segment_reduce_offsets': None,
        '_segment_reduce_lengths': None,
        '_segment_reducelengths': None,
        '_segment_reduceoffsets': None,
        'sinc': None,
        'sparse.mm': None,
        'sparse.mmreduce': None,
        'special.airy_ai': None,
        'special.bessel_j0': None,
        'special.bessel_j1': None,
        'special.bessel_y0': None,
        'special.bessel_y1': None,
        'special.chebyshev_polynomial_t': None,
        'special.chebyshev_polynomial_u': None,
        'special.entr': None,
        'special.erfcx': None,
        'special.hermite_polynomial_h': None,
        'special.hermite_polynomial_he': None,
        'special.i0e': None,
        'special.i1': None,
        'special.i1e': None,
        'special.laguerre_polynomial_l': None,
        'special.log_ndtr': None,
        'special.modified_bessel_i0': None,
        'special.modified_bessel_i1': None,
        'special.modified_bessel_k0': None,
        'special.modified_bessel_k1': None,
        'special.ndtri': None,
        'special.scaled_modified_bessel_k0': None,
        'special.scaled_modified_bessel_k1': None,
        'special.spherical_bessel_j0': None,
        'special.xlog1py': None,
        'special.zeta': None,
        'svd_lowrank': None,
        'symeig': None,
        'take': None,
        'to': None,
        'to_sparse': None,
        'unique': None,
        'vdot': None,
        'segment_reduce_': None,
        '_upsample_bilinear2d_aa': None,
        'geometric': None,
        'geometric_': None,
        'log_normal_': None,
        'log_normal': None,
        'cdouble': None,
        'double': None,
        'nn.functional.softminwith_dtype': None,
        'log_softmaxwith_dtype': None,
        'softmaxwith_dtype': None,
        'float_power': None,
        'full_like': None,
        'linalg.matrix_rankhermitian': None,
        'linalg.pinvhermitian': None,
        'nonzero_static': None,

        # MPS: input sizes must be divisible by output sizes
        'nn.functional.adaptive_avg_pool1d': None,
        'nn.functional.adaptive_avg_pool2d': None,

        # Unsupported dtypes
        # bmm is not supported for integral types
        'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # Cannot convert a MPS Tensor to float64 dtype. The tensor's
        # input data is created with double in common_methods_invocations.py
        'nn.functional.batch_norm': [torch.float32],
        'ones_like': None,
        'zeros_like': None,

        # Convolution for integral types is not supported on MPS
        'nn.functional.conv1d': [torch.int64],
        'nn.functional.conv2d': [torch.int64],
        'nn.functional.conv3d': [torch.int64],
        'nn.functional.conv_transpose1d': [torch.int64],
        'nn.functional.conv_transpose2d': [torch.int64],

        # Unsupported dtypes
        'dot': [torch.int64],
        'histc': [torch.float16],
        'index_add': [torch.int64],
        'log1p': [torch.int64],
        'sigmoid': [torch.int64],
        'atan2': [torch.int64],

        # GEMM on MPS is not supported for integral types
        'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        '__rmatmul__': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'inner': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mat': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'unravel_index': [torch.int32, torch.int64],

        # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as
        # the MPS framework doesn't support float64
        'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # returned output on CPU is float64
        'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],
    }

    if product_version < 14.0:
        # FFT and BFloat16 support was added in MacOS 14
        UNIMPLEMENTED_XFAILLIST.update({
            'bfloat16': None,
            'fft.fft': None,
            'fft.fft2': None,
            'fft.fftn': None,
            'fft.hfft': None,
            'fft.hfft2': None,
            'fft.hfftn': None,
            'fft.ifft': None,
            'fft.ifft2': None,
            'fft.ifftn': None,
            'fft.ihfft': None,
            'fft.ihfft2': None,
            'fft.ihfftn': None,
            'fft.irfft': None,
            'fft.irfft2': None,
            'fft.irfftn': None,
            'fft.rfft': None,
            'fft.rfft2': None,
            'fft.rfftn': None,
            'stft': None,
            # TestConsistencyCPU.test_output_match_isin_cpu fails for integer dtypes;
            # not reproducible on later OS versions. An assert was added to the op when used on < 14.0
            'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8],
        })

    UNDEFINED_XFAILLIST = {
        # Top 60 operators
        # topk fails with duplicate indices
        'topk': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # Failures due to the random output these ops generate using the
        # Philox engine, causing a mismatch with CPU results
        'multinomial': [torch.float16, torch.float32],  # random results
        'uniform': [torch.float16, torch.float32],
        'rand_like': [torch.float16, torch.float32],
        'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'randn_like': [torch.float16, torch.float32],
        'bernoulli': [torch.float16, torch.float32],
        'exponential': [torch.float16, torch.float32],
        'nn.functional.feature_alpha_dropoutwith_train': [torch.float16, torch.float32],
        'normal': [torch.float16, torch.float32],
        'normalin_place': [torch.float16, torch.float32],
        'normalnumber_mean': [torch.float16, torch.float32],
        'nn.functional.alpha_dropout': [torch.float16, torch.float32],
        'nn.functional.dropout': [torch.float16, torch.float32],
        'nn.functional.dropout2d': [torch.float16, torch.float32],
        'nn.functional.dropout3d': [torch.float16, torch.float32],
        # See https://github.com/pytorch/pytorch/issues/111479
        'nn.functional.multi_head_attention_forward': [torch.float32, torch.float16],

        # duplicate indices are used in the testcase - undefined behaviour
        'index_put': None,
        # zero to negative integer powers are undefined
        '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64],
        'resize_': [torch.float16, torch.float32],
        'resize_as_': [torch.float16, torch.float32],

        # CPU Errors:
        'addr': [torch.bool, torch.int16, torch.int32,
                 torch.int64, torch.uint8, torch.int8],  # "addmv_impl_cpu" not implemented for 'Half'
        'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                    torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values
        'as_strided_partial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                     torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values

        # random results
        # mps vs cpu:
        # Mismatched elements: 40 / 96 (41.7%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        # cuda(2.0.0.dev20230301+cu117) vs cpu:
        # Mismatched elements: 56 / 96 (58.3%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        'nn.functional.scaled_dot_product_attention': [torch.float32, torch.float16],

        # Failures because casting a negative float to uint8 is undefined
        'byte': [torch.float16, torch.float32],
        # float output for float16 input on MPS
        'logit': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
        'linalg.matrix_rank': None,
    }

    EMPTY_OPS_SKIPLIST = {
        # Fill tensors with uninitialized data, causing mismatch with CPU.
        # They occasionally match, thus skipping them.
        # See https://github.com/pytorch/pytorch/issues/100175
        'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16,
                              torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # CPU: empty is returning all 0's and there is a mismatch with MPS
        # allocation (MacOS 13). According to
        # https://pytorch.org/docs/2.0/generated/torch.empty.html, the returned
        # tensor's data is uninitialized, so either output is acceptable.
        'empty': [torch.bool, torch.float16, torch.float32, torch.int16,
                  torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_permuted': [torch.bool, torch.float16, torch.float32, torch.int16,
                           torch.int32, torch.int64, torch.uint8, torch.int8],
    }

    SKIPLIST = {
        # Unsupported
        # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
        'nn.functional.avg_pool2d': [torch.float16],

        # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
        'nn.functional.conv3d': None,
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        key = op.name + op.variant_test_name
        if key in EMPTY_OPS_SKIPLIST:
            addDecorator(op, DecorateInfo(
                         unittest.skip("Skipping empty ops."),
                         dtypes=EMPTY_OPS_SKIPLIST[key]))
        if key in SKIPLIST:
            addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
            if key in xfaillist:
                addDecorator(op, DecorateInfo(
                             unittest.expectedFailure,
                             dtypes=xfaillist[key]))

        if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_14_4_XFAILLIST[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST[key]))

        if key in MACOS_AFTER_13_1_XFAILLIST and torch.backends.mps.is_macos13_or_newer(2):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_AFTER_13_1_XFAILLIST[key]))

        if key in MACOS_13_3_XFAILLIST and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST[key]))

        # If the op is not supported for complex types, expect it to fail
        if key not in SUPPORTED_COMPLEX_OPS and (key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS or product_version < 14.0):
            addDecorator(op, DecorateInfo(unittest.expectedFailure, dtypes=[torch.complex32, torch.complex64]))

        yield op

def mps_ops_error_inputs_modifier(ops):
    # Error input samples do not take a dtype argument.
    XFAILLIST = {
        # Exceptions are not raised
        '__rmod__',
        '__rsub__',
        '__rpow__',
        'bernoulli',
        'clamp_max',
        'clamp_min',
        'masked_scatter',

        # unsupported float64 dtype
        'cat',
        'complex',
        'multinomial',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv3d',
        'gather',
        'scatter',
        'scatter_add',

        # unsupported complex dtypes
        'masked_fill',

        # MPS does not support tensor dimensions > 16
        'amax',
        'amin',
        'aminmax',

        # memory overlapping checks
        'index_select',

        # unimplemented
        'logcumsumexp',
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        if op.error_inputs_func is None:
            continue
        key = op.name + op.variant_test_name
        if key in XFAILLIST:
            addDecorator(op, DecorateInfo(unittest.expectedFailure))
        yield op

# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811
    NNTestCase = NoTest  # noqa: F811

product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))
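# Illustrative values: on macOS 14.4.1, platform.mac_ver()[0] is "14.4.1", so
# product_version == 14.4; off macOS the empty version string falls back to -1.
# `sysctl hw.memsize` reports physical RAM in bytes, e.g. 17179869184 on a
# 16 GB machine.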

# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'

def skipMPSMemoryLeakCheckIf(condition):
    def dec(fn):
        if getattr(fn, '_do_mps_memory_leak_check', True):
            fn._do_mps_memory_leak_check = not condition
        return fn
    return dec

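# Illustrative usage (hypothetical test, not part of the suite): run with
# PYTORCH_TEST_MPS_MEM_LEAK_CHECK=1 to enable the leak check globally, and opt
# individual tests out when their allocations are expected to persist, e.g.
#
#     @skipMPSMemoryLeakCheckIf(True)  # caches tensors on purpose
#     def test_warmup_populates_cache(self): ...
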
class MpsMemoryLeakCheck:
    def __init__(self, testcase, name=None):
        self.name = testcase.id() if name is None else name
        self.testcase = testcase

    def __enter__(self):
        # Perform a gc if required (i.e. if any memory is currently held)
        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
        if caching_allocator_mem_allocated > 0:
            gc.collect()
            torch.mps.empty_cache()

        # Acquire caching allocator and driver statistics before the test is run
        self.caching_allocator_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exec_type is not None:
            return
        # Compare caching allocator before/after statistics:
        # an increase in allocated memory is a discrepancy indicating a possible memory leak
        discrepancy_detected = False
        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
        if caching_allocator_mem_allocated > self.caching_allocator_before:
            discrepancy_detected = True

        # Short-circuit if no discrepancy was detected
        if not discrepancy_detected:
            return
        # Validate that the discrepancy persists after garbage collection and
        # is confirmed by the driver API
        gc.collect()
        torch.mps.empty_cache()

        discrepancy_detected = True
        # Query memory multiple times to ensure the leak was not transient
        for n in range(3):
            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
            driver_mem_allocated = torch.mps.driver_allocated_memory()

            caching_allocator_discrepancy = False
            driver_discrepancy = False

            if caching_allocator_mem_allocated > self.caching_allocator_before:
                caching_allocator_discrepancy = True

            if driver_mem_allocated > self.driver_before:
                driver_discrepancy = True

            if not (caching_allocator_discrepancy or driver_discrepancy):
                # Leak was a false positive, exit loop
                discrepancy_detected = False
                break

        if caching_allocator_discrepancy and not driver_discrepancy:
            # Just raise a warning if the leak is not validated by the driver API
            msg = ("MPS caching allocator reports a memory leak not "
                   f"verified by the driver API in {self.name}! "
                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
                   f"and is now reported as {caching_allocator_mem_allocated}. "
                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
            warnings.warn(msg)
        elif caching_allocator_discrepancy and driver_discrepancy:
            # A caching allocator discrepancy validated by the driver API is a failure
            msg = (f"MPS driver API confirmed a leak in {self.name}! "
                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
                   f"and is now reported as {caching_allocator_mem_allocated}. "
                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")

            raise RuntimeError(msg)
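
# Illustrative usage (sketch): the check wraps a test body as a context manager
# and raises only when the driver API confirms the discrepancy, e.g.
#   with MpsMemoryLeakCheck(self):
#       torch.ones(128, device='mps')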

# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
    _do_mps_memory_leak_check = True

    def __init__(self, method_name='runTest'):
        super().__init__(method_name)
        test_method = getattr(self, method_name, None)
        if test_method is not None:
            # Wraps the tested method if we should do MPS memory check.
            if TEST_MPS_MEM_LEAK_CHECK:
                if self._do_mps_memory_leak_check:
                    self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)

    def assertLeaksNoMpsTensors(self, name=None):
        name = self.id() if name is None else name
        return MpsMemoryLeakCheck(self, name)

    def wrap_with_mps_policy(self, method_name, policy):
        test_method = getattr(self, method_name)
        setattr(self, method_name, super().wrap_method_with_policy(test_method, policy))

    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
    def wrap_with_mps_memory_check(self, method):
        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)

class TestMemoryLeak(TestCaseMPS):
    def test_mps_memory_leak_detection(self):
        l = []

        @self.wrap_with_mps_memory_check
        def no_leak():
            pass

        # Trigger an intentional memory leak
        @self.wrap_with_mps_memory_check
        def leak_gpu0():
            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))

        no_leak()

        # Check that a runtime error for the memory leak was emitted, which
        # confirms that the leak detection worked.
        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
            leak_gpu0()

    def test_copy_cast_no_leak(self):

        def step(x):
            x = x.to(device='cpu', dtype=torch.float32)
            x = x.to(device='mps', dtype=torch.float16)

        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
        # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
        step(a)
        torch.mps.empty_cache()
        driver_before = torch.mps.driver_allocated_memory()
        step(a)
        torch.mps.empty_cache()
        driver_after = torch.mps.driver_allocated_memory()
        self.assertEqual(driver_before, driver_after, f"Detected {driver_after - driver_before} bytes leak of GPU memory")


class TestPixelShuffle(TestCaseMPS):
    def test_pixel_shuffle_unshuffle(self):
        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
                                                 upscale_factor=None, is_contiguous=True):

            def generate_input():
                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
                channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
                height = random.randint(5, 10)
                width = random.randint(5, 10)

                if num_input_dims == 1:
                    input = torch.rand(channels, requires_grad=True, device='mps')
                    assert is_contiguous
                elif num_input_dims == 2:
                    input = torch.rand(width, height, requires_grad=True, device='mps').T
                    if is_contiguous:
                        input = input.contiguous()
                else:
                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                    input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
                    input = input.transpose(-1, -2)
                    if is_contiguous:
                        input = input.contiguous()

                if not is_contiguous and len(input.reshape(-1)) > 0:
                    assert not input.is_contiguous()

                input = input.detach().clone()
                input.requires_grad = True
                return input

            # Function to imperatively ensure pixels are shuffled to the correct locations.
            # Used to validate the batch operations in pixel_shuffle.
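            # For nn.PixelShuffle with factor r, the expected mapping is
            #   output[..., c, h, w] == input[..., c * r**2 + r * (h % r) + (w % r), h // r, w // r]
            # which is exactly what the loop below checks element by element.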
            def _verify_pixel_shuffle(input, output, upscale_factor):
                for c in range(output.size(-3)):
                    for h in range(output.size(-2)):
                        for w in range(output.size(-1)):
                            height_idx = h // upscale_factor
                            width_idx = w // upscale_factor
                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
                                          (c * upscale_factor ** 2)
                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, width_idx])

            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
            input = generate_input()

            ps = nn.PixelShuffle(upscale_factor)
            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)

            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
                output = ps(input)
                _verify_pixel_shuffle(input, output, upscale_factor)
                output.backward(output.data)
                self.assertEqual(input.data, input.grad.data)

                # Ensure unshuffle properly inverts shuffle.
                unshuffle_output = pus(output)
                self.assertEqual(input, unshuffle_output)
            else:
                self.assertRaises(RuntimeError, lambda: ps(input))

        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
                                                    downscale_factor=None):
            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
            channels = random.randint(1, 4)
            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)

            if num_input_dims == 1:
                input = torch.rand(channels, requires_grad=True, device='mps')
            elif num_input_dims == 2:
                input = torch.rand(height, width, requires_grad=True, device='mps')
            else:
                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')

            pus = nn.PixelUnshuffle(downscale_factor)
            self.assertRaises(RuntimeError, lambda: pus(input))

        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
            # For 1D - 2D, this is an error case.
            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
            is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
            for is_contiguous in is_contiguous_check:
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
                )

            # Error cases for pixel_unshuffle.
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)

        def test_pixel_shuffle_unshuffle_1D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)

        def test_pixel_shuffle_unshuffle_2D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)

        def test_pixel_shuffle_unshuffle_3D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)

        def test_pixel_shuffle_unshuffle_4D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)

        def test_pixel_shuffle_unshuffle_5D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)

        test_pixel_shuffle_unshuffle_1D()
        test_pixel_shuffle_unshuffle_2D()
        test_pixel_shuffle_unshuffle_3D()
        test_pixel_shuffle_unshuffle_4D()
        test_pixel_shuffle_unshuffle_5D()

class MPSReluTest(TestCaseMPS):
    def _npRelu(self, np_features):
        return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)

    def testNpRelu(self):
        torch.testing.assert_close(
            np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
            self._npRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                         0.9]])))

    def _testRelu(self, np_features, device):
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        py_relu = torch.nn.ReLU(inplace=False)(py_tensor)
        py_relu_cpu = py_relu.to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)

    def _testReluInPlace(self, np_features, device):
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        py_relu = torch.nn.ReLU(inplace=True)(py_tensor)
        py_relu_cpu = py_relu.to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)
        # In-place ReLU modifies the initial input, which should then match the ReLU output
        self.assertEqual(np_relu, py_tensor.to("cpu"))

    def testNumbersCPU(self):
        for t in [np.int32]:
            # Force execution on CPU even if a GPU kernel is available for the type.
            self._testRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="cpu")
            self._testReluInPlace(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="cpu")

    def testNumbersGPU(self):
        for t in [np.float16, np.float32]:
            self._testRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="mps")
            self._testReluInPlace(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="mps")
            self._testRelu(np.array([]).astype(t), device="mps")
            self._testReluInPlace(np.array([]).astype(t), device="mps")

class MatmulTest(TestCaseMPS):
    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
        if expand_tensor_1_shape:
            tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
        else:
            tensor1_mps = torch.randn(shape_tensor_1, device="mps")

        if expand_tensor_2_shape:
            tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
        else:
            tensor2_mps = torch.randn(shape_tensor_2, device="mps")

        tensor1_cpu = tensor1_mps.to("cpu")
        tensor2_cpu = tensor2_mps.to("cpu")

        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)

        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))

    def test_vector_x_vector(self):
        # uses `dot`
        self._helper(3, 3)

    def test_matrix_x_vector(self):
        # uses `addmv`
        self._helper((3, 4), 4)

    def test_batched_matrix_x_broadcasted_vector(self):
        self._helper((10, 3, 4), 4)

    def test_batched_matrix_x_batched_matrix(self):
        # uses `bmm.out`
        self._helper((10, 3, 4), (10, 4, 5))

    def test_batched_matrix_x_broadcasted_matrix(self):
        self._helper((10, 3, 4), (4, 5))


class MPSLeakyReluTest(TestCaseMPS):
    def _npLeakyRelu(self, np_features, negative_slope=0.1):
        return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype)

    def testNpLeakyRelu(self):
        torch.testing.assert_close(
            np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
                      [0.1, -0.03, 0.5, -0.07, 0.9]]),
            self._npLeakyRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                         0.9]]),
                negative_slope=0.1))

    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')

        if not contiguous and not (0 in shape or len(shape) < 2):
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            mps_x = mps_x.transpose(0, 1)
            assert not mps_x.is_contiguous()

        cpu_x.requires_grad_()
        mps_x.requires_grad_()

        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass

        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')

        mps_leaky_relu.backward(gradient=mps_grad)
        cpu_leaky_relu.backward(gradient=cpu_grad)

        assert cpu_x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(cpu_x.grad, mps_x.grad)

    def testNumbersCPU(self):
        for t in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    self._testLeakyRelu(shape,
                                        dtype=t,
                                        negative_slope=0.2,
                                        contiguous=contiguous)

class TestAvgPool(TestCaseMPS):
    def _sum_pool2d(self, x, kernel_size):
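        # unfold with stride == kernel_size extracts each non-overlapping window
        # as a column of shape (N, prod(kernel_size), num_windows); summing over
        # dim=1 therefore yields the sum of every window.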
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        # unfold does not support 3D sliding windows, so split the tensor into
        # chunks along the depth dimension and sum each chunk
        h = kernel_size[0]
        split_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # sum_pool2d assumes the tensor is in (1, 1, n, m) view, so unsqueeze twice
        split_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in split_x]
        joined_x = torch.cat(split_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        size = math.prod(kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        size = math.prod(kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_avg_pool2d_with_zero_divisor(self):
        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))

    def test_doubletensor_avg_pool2d_with_divisor(self):
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertTrue(not torch.isnan(y).any())
        y = torch.nn.functional.avg_pool2d(
            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertTrue(not torch.isnan(y).any())


class TestMPS(TestCaseMPS):
    def test_exp(self, device="mps", dtype=torch.float):
        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
            b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
            a = torch.tensor(v, dtype=dtype, device="mps") * b
            self.compare_with_numpy(torch.exp, np.exp, a)

    def test_triu_inf(self, device="mps", dtype=torch.float):
        for diag in [-1, 0, 1]:
            mask = torch.full((3, 6, 6), float("-inf"))
            mask_mps = mask.clone().detach().to('mps')
            cpu_ref = torch.triu(mask, diagonal=diag)
            mps_out = torch.triu(mask_mps, diagonal=diag)
            self.assertEqual(cpu_ref, mps_out)

    def test_exp1(self, device="mps", dtype=torch.float):
        input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
        output = torch.exp(input)
        output_cpu = torch.exp(input.cpu())
        # If the exponentWithTensor: MPS call is used on an M1 running 14.5, this test fails with
        # Mismatched elements: 3 / 4 (75.0%)
        # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
        # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
        self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8)

    def test_exp_strided_output(self):
        x = torch.rand((256, 10), device='mps')
        x_cpu = x.to("cpu")

        x = x.permute(1, 0)
        x_cpu = x_cpu.permute(1, 0)

        res = x.exp()
        res_cpu = x_cpu.exp()
        self.assertEqual(res, res_cpu)

    def _testLeakyRelu(self, np_features, negative_slope, device):
        cpu_x = torch.from_numpy(np_features).requires_grad_()
        mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass
        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')
        cpu_leaky_relu.backward(gradient=cpu_grad)
        mps_leaky_relu.backward(gradient=mps_grad)
        torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))

    def testNumbersGPU(self):
        for t in [np.float32]:
            self._testLeakyRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                negative_slope=0.1,
                device="mps")

    def test_fill(self):

        def helper(val, shape, dtype):
            tensor = torch.zeros(shape, device='mps', dtype=dtype)
            tensor_mps = tensor.fill_(val)

            tensor_0 = torch.zeros(shape, device='cpu', dtype=dtype)
            tensor_cpu = tensor_0.fill_(val)

            self.assertEqual(tensor_mps, tensor_cpu)

        helper(0, [1024], torch.float32)
        helper(0.2, [2, 3], torch.float32)
        helper(0.2 + 0.5j, [2, 3], torch.complex64)

    def test_fill_storage_offset(self):
        shape = [2, 10]
        val = 0.2
        tensor = torch.ones(shape, device="mps")
        tensor_mps = tensor[:][1].fill_(val)
        tensor_0 = torch.ones(shape, device="cpu")
        tensor_cpu = tensor_0[:][1].fill_(val)

        self.assertEqual(tensor_mps, tensor_cpu)
        self.assertEqual(tensor, tensor_0)

        shape = [1, 10]
        val = 0.0
        tensor = torch.ones(shape, device="mps")
        val_tensor_mps = torch.tensor(val, device="mps")
        tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
        # Regression test for https://github.com/pytorch/pytorch/issues/114692
        tensor[:, 5].fill_(val_tensor_mps)
        tensor_0 = torch.ones(shape, device="cpu")
        val_tensor_cpu = torch.tensor(val, device="cpu")
        tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
        tensor_0[:, 5].fill_(val_tensor_cpu)

        self.assertEqual(tensor_mps.to(device="cpu"), tensor_cpu)
        self.assertEqual(tensor.to(device="cpu"), tensor_0)

    def test_cdist_large(self, device="mps"):
        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            x = torch.randn(100, 10, device=device)
            y = torch.randn(100, 10, device=device)
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertEqual(expected, actual)

    def test_cdist_large_batch(self, device="mps"):
        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            x = torch.randn(4, 3, 100, 10, device=device)
            y = torch.randn(4, 3, 100, 10, device=device)
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertEqual(expected, actual)

    def test_cdist_non_contiguous(self, device="mps"):
        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            x = torch.randn(5, 7, device=device).mT
            y = torch.randn(5, 3, device=device).mT
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertEqual(expected, actual)

            x = torch.randn(7, 5, device=device)
            y = torch.randn(5, 3, device=device).t()
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertTrue(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertEqual(expected, actual)

            x = torch.randn(5, 7, device=device).t()
            y = torch.randn(3, 5, device=device)
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(y.is_contiguous())
            self.assertEqual(expected, actual)

    def test_cdist_non_contiguous_batch(self, device="mps"):
        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertEqual(expected, actual)

            x = torch.randn(7, 2, 7, 5, device=device)
            y = torch.randn(7, 2, 5, 3, device=device).mT
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertTrue(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertEqual(expected, actual)

            x = torch.randn(4, 5, 7, device=device).mT
            y = torch.randn(4, 3, 5, device=device)
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(y.is_contiguous())
            self.assertEqual(expected, actual)

    def test_cdist_euclidean_large(self, device="mps"):
        def _test_euclidean_large_cdist(sizex, sizey=None):
            if sizey is None:
                sizey = sizex
            x = torch.randn(sizex, device=device, dtype=torch.float)
            y = torch.randn(sizey, device=device, dtype=torch.float)
            eps = 1e-6
            # to avoid extrema
            x = x - (((x - y) < eps).float() * 2 * eps)
            x.requires_grad = True
            y.requires_grad = True
            dist = torch.cdist(x, y, p=2)
            # Do a backward pass to check that it is valid for large
            # matrices
            loss = dist.sum()
            loss.backward()

        _test_euclidean_large_cdist((2000, 5))

    def test_cdist_same_inputs(self, device="mps"):
        # Test to detect issues in the cdist gradient calculation
        # when the distances are 0
        sizex = (1, 27, 32)
        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
            x = torch.randn(sizex, device=device, dtype=torch.float)
            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
            y = x.clone()
            eps = 1e-6
            x.requires_grad = True
            d = torch.cdist(x, y)
            d.backward(dist_grad)
            # Check that the backward pass does not contain invalid
            # values such as nan or inf
            assert torch.isfinite(x.grad).all()


    def _brute_cdist(self, x, y, p=2):
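        # Reference implementation: x[..., None, :] - y[..., None, :, :] broadcasts
        # to shape (..., r1, r2, m); taking the p-norm over the last dim computes
        # all pairwise distances at once.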
        r1 = x.shape[-2]
        r2 = y.shape[-2]
        if r1 == 0 or r2 == 0:
            return torch.empty(r1, r2, device=x.device)
        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)

    def test_cdist_norm(self, device="mps"):
        for r1 in [3, 4]:
            for m in [2, 3]:
                for r2 in [4, 6]:
                    for p in [0, 1, 1.5, 2.5, float('inf')]:
                        x = torch.randn(r1, m, device=device)
                        y = torch.randn(r2, m, device=device)
                        if p == 2:
                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
                                expected = self._brute_cdist(x, y, p=2)
                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
                        else:
                            actual = torch.cdist(x, y, p=p)
                            expected = self._brute_cdist(x, y, p=p)
                            self.assertEqual(expected, actual)

    def test_cdist_norm_batch(self, device="mps"):
        for r1 in [3, 4]:
            for m in [2, 3]:
                for r2 in [4, 6]:
                    for p in [0, 3, 1.5, 2.5, float('inf')]:
                        x = torch.randn(2, 3, 6, r1, m, device=device)
                        y = torch.randn(2, 3, 6, r2, m, device=device)
                        if p == 2:
                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
                                expected = self._brute_cdist(x, y, p=2)
                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
                        else:
                            actual = torch.cdist(x, y, p=p)
                            expected = self._brute_cdist(x, y, p=p)
                            self.assertEqual(expected, actual)

    def test_mm(self):
        B = torch.ones(5, 6).to("mps")
        C = torch.ones(6, 5).to("mps")
        D = torch.mm(B, C).cpu()
        torch.testing.assert_close(D, torch.full((5, 5), 6.0))

    def test_linalg_cross(self):
        def helper(dtype):
            device = "mps"
            if dtype is torch.int32 or dtype is torch.int64:
                x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
                y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
            else:
                x = torch.rand(100, 3, 100, dtype=dtype, device=device)
                y = torch.rand(100, 3, 100, dtype=dtype, device=device)
            x_cpu = x.to("cpu")
            y_cpu = y.to("cpu")
            res1 = torch.linalg.cross(x, y, dim=1)
            res2 = torch.tensor((), dtype=dtype, device=device)
            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
            torch.linalg.cross(x, y, dim=1, out=res2)
            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
            self.assertEqual(res1, res2)
            self.assertEqual(res1, res1_cpu)
            self.assertEqual(res2, res2_cpu)

            # test for broadcastable inputs
            if dtype is torch.int32 or dtype is torch.int64:
                x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device)
                y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device)
            else:
                x = torch.rand(1, 3, 2, dtype=dtype, device=device)
                y = torch.rand(4, 3, 1, dtype=dtype, device=device)
            x_cpu = x.to("cpu")
            y_cpu = y.to("cpu")
            res1 = torch.linalg.cross(x, y, dim=1)
            res2 = torch.tensor((), dtype=dtype, device=device)
            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
            torch.linalg.cross(x, y, dim=1, out=res2)
            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
            self.assertEqual(res1, res2)
            self.assertEqual(res1, res1_cpu)
            self.assertEqual(res2, res2_cpu)

        for dtype in [torch.int32, torch.int64, torch.float32]:
            helper(dtype)

    def test_cross(self):
        a = torch.randn(4, 3, device="mps")
        b = torch.randn(4, 3, device="mps")
        a_cpu = a.to("cpu")
        b_cpu = b.to("cpu")
        res = torch.cross(a, b, dim=1)
        res_cpu = torch.cross(a_cpu, b_cpu, dim=1)
        self.assertEqual(res, res_cpu)

    def test_addmm(self):
        A = torch.ones(5, 5).to("mps")
        B = torch.ones(5, 6).to("mps")
        C = torch.ones(6, 5).to("mps")
        D = torch.addmm(A, B, C).to("cpu")
        torch.testing.assert_close(D, torch.full((5, 5), 7.0))

    def test_bmm(self):
        batch1_cpu = torch.randn(10, 3, 4)
        batch2_cpu = torch.randn(10, 4, 5)

        batch1_mps = batch1_cpu.detach().clone().to("mps")
        batch2_mps = batch2_cpu.detach().clone().to("mps")

        output_cpu = torch.bmm(batch1_cpu, batch2_cpu)
        output_mps = torch.bmm(batch1_mps, batch2_mps)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_addr(self):
        A = torch.ones(5, 10).to("mps")
        B = torch.ones(5).to("mps")
        C = torch.ones(10).to("mps")
        D = torch.addr(A, B, C).to("cpu")
        torch.testing.assert_close(D, torch.full((5, 10), 2.0))

    def test_trace(self):
        M_cpu = torch.randn(3, 3)
        M_mps = M_cpu.detach().clone().to("mps")

        output_cpu = torch.trace(M_cpu)
        output_mps = torch.trace(M_mps)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_addbmm(self):
        M_cpu = torch.randn(3, 5)
        batch1_cpu = torch.randn(10, 3, 4)
        batch2_cpu = torch.randn(10, 4, 5)

        M_mps = M_cpu.detach().clone().to("mps")
        batch1_mps = batch1_cpu.detach().clone().to("mps")
        batch2_mps = batch2_cpu.detach().clone().to("mps")

        output_cpu = torch.addbmm(M_cpu, batch1_cpu, batch2_cpu)
        output_mps = torch.addbmm(M_mps, batch1_mps, batch2_mps)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_baddbmm(self):
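        # baddbmm computes beta * M + alpha * (batch1 @ batch2), batched over
        # dim 0, with M broadcast against the (b, n, p) result.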
        def helper(input_shape, batch1_shape, batch2_shape):
            M_cpu = torch.randn(input_shape)
            batch1_cpu = torch.randn(batch1_shape)
            batch2_cpu = torch.randn(batch2_shape)
            alpha = 1.2
            beta = 0.8

            M_mps = M_cpu.detach().clone().to("mps")
            batch1_mps = batch1_cpu.detach().clone().to("mps")
            batch2_mps = batch2_cpu.detach().clone().to("mps")

            output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
            output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)

            self.assertEqual(output_cpu, output_mps)
            self.assertEqual(output_cpu.size(), output_mps.size())

        helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
        helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
        helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))

    def test_local_scalar_dense_mps(self):
        x_cpu = torch.randn(1)
        y_mps = x_cpu.to("mps")
        torch.testing.assert_close(x_cpu.item(), y_mps.item())

    def test_linear_1d_weight(self):
        device = 'cpu'
        projected = torch.rand([8]).to(device)
        x = torch.rand([1, 2, 2, 8]).to(device)
        x_mps = x.to('mps')
        projected_mps = projected.to('mps')
        linear = F.linear(x, projected)
        linear_mps = F.linear(x_mps, projected_mps)

        self.assertEqual(linear, linear_mps)

        projected = torch.rand([1, 8]).to(device)
        x = torch.rand([1, 2, 2, 8]).to(device)
        x_mps = x.to('mps')
        projected_mps = projected.to('mps')
        linear = F.linear(x, projected)
        linear_mps = F.linear(x_mps, projected_mps)

        self.assertEqual(linear, linear_mps)

    def test_linear_bias(self):
        def helper(bias_shape):
            device = "cpu"
            x = torch.randn(2, 2, 2, 64, device=device)
            linear = torch.nn.Linear(64, 4, device=device)
            linear.bias = torch.nn.Parameter(torch.randn(bias_shape, dtype=torch.float32, device=device))
            y = linear(x)
            device = "mps"
            x_mps = x.to(device)
            linear.to(device)
            y_mps = linear(x_mps)
            self.assertEqual(y, y_mps)

        helper(())
        helper((2, 4))

    def test_linear_errors(self):
        # Mixed CPU<->MPS tensors
        size = (3, 3)

        # Unsupported dtypes
        with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
            torch.nn.functional.linear(torch.rand(size, device='mps'),
                                       torch.randint(-10, 10, size, dtype=torch.int8, device='mps'))

        # Weights on wrong device
        with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
            torch.nn.functional.linear(torch.rand(size, device='mps'),
                                       torch.rand(size, device='cpu'))

        # Input on wrong device
        with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
            torch.nn.functional.linear(torch.rand(size, device='cpu'),
                                       torch.rand(size, device='mps'))

    def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
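        # Compares an nn.Linear layer between CPU and MPS with shared weights/bias:
        # forward outputs, and (when backward_pass=True) input/weight/bias grads
        # plus second-order grads via a double backward.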
        cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
        mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)

        # Use the same weights and bias as the ones from the cpu
        mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")

        if bias:
            mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")

        linear_mps_input = torch.randn(shape).to('mps')
        linear_cpu_input = linear_mps_input.detach().clone().to('cpu')

        if backward_pass:
            linear_mps_input = linear_mps_input.requires_grad_()
            linear_cpu_input = linear_cpu_input.requires_grad_()

        linear_cpu_output = cpu_linear(linear_cpu_input)
        linear_mps_output = mps_linear(linear_mps_input)

        self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu'))
        self.assertEqual(linear_cpu_output.size(), linear_mps_output.size())

        if backward_pass:
            cpu_grad = torch.rand_like(linear_cpu_output, requires_grad=True)
            grad = cpu_grad.detach().to('mps').requires_grad_()

            linear_cpu_output.backward(gradient=cpu_grad, create_graph=True)
            linear_mps_output.backward(gradient=grad, create_graph=True)

            self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size())
            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

            self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size())
            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
            if bias:
                self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size())
                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

            # test gradgrad
            x_grad_out = torch.rand_like(linear_cpu_input)
            x_grad_out_mps = x_grad_out.to("mps")
            w_grad_out = torch.rand_like(cpu_linear.weight)
            w_grad_out_mps = w_grad_out.to("mps")

            linear_cpu_input.grad.detach().zero_()
            linear_mps_input.grad.detach().zero_()
            cpu_linear.weight.grad.detach().zero_()
            mps_linear.weight.grad.detach().zero_()
            if bias:
                b_grad_out = torch.rand_like(cpu_linear.bias)
                b_grad_out_mps = b_grad_out.to("mps")
                cpu_linear.bias.grad.detach().zero_()
                mps_linear.bias.grad.detach().zero_()

            linear_cpu_input.grad.backward(x_grad_out, retain_graph=True)
            linear_mps_input.grad.backward(x_grad_out_mps, retain_graph=True)
            cpu_linear.weight.grad.backward(w_grad_out, retain_graph=True)
            mps_linear.weight.grad.backward(w_grad_out_mps, retain_graph=True)
            if bias:
                cpu_linear.bias.grad.backward(b_grad_out, retain_graph=True)
                mps_linear.bias.grad.backward(b_grad_out_mps, retain_graph=True)

            self.assertEqual(cpu_grad.grad, grad.grad)
            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad)
            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad)
            if bias:
                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad)

    def test_linear1D(self):
        self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=False)

    def test_linear1D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=True)

    def test_linear2D(self):
        self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=False)

    def test_linear2D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=True)

    def test_linear2D_no_bias(self):
        self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=False, backward_pass=False)

    def test_linear2D_no_bias_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=False, backward_pass=True)

    def test_linear3D(self):
        self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=False)

    def test_linear3D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=True)

    def test_linear3D_no_bias(self):
        self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=False, backward_pass=False)

    def test_linear3D_no_bias_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=False, backward_pass=True)

    def test_uniform(self):
        low = torch.zeros(5, 5, requires_grad=True)
        high = (torch.ones(5, 5) * 3).requires_grad_()
        low_1d = torch.zeros(1, requires_grad=True)
        high_1d = (torch.ones(1) * 3).requires_grad_()
        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

        # Check log_prob computation when value outside range
        uniform = Uniform(low_1d, high_1d, validate_args=False)
        above_high = torch.tensor([4.0])
        below_low = torch.tensor([-1.0])
        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
        self.assertEqual(uniform.log_prob(below_low).item(), -inf)

        # check cdf computation when value outside range
        self.assertEqual(uniform.cdf(below_low).item(), 0)
        self.assertEqual(uniform.cdf(above_high).item(), 1)

        state = torch.get_rng_state()
        rand = low.new(low.size()).uniform_()
        torch.set_rng_state(state)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        self.assertEqual(low.grad, 1 - rand)
        self.assertEqual(high.grad, rand)
        low.grad.zero_()
        high.grad.zero_()

    def test_randperm(self, device="mps"):
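        # Checks randperm on MPS across dtypes and sizes: the sorted output must
        # equal arange(n), the default dtype is long, n == 0 yields an empty
        # tensor, and non-contiguous `out=` tensors are supported.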
        rng_device = None
        for n in (5, 100, 50000, 100000):
            for dtype in (torch.long, torch.half, torch.float):
                if n > 2049 and dtype == torch.half:  # Large n for torch.half will raise an exception, do not test here.
                    continue
                if n > 256 and dtype == torch.bfloat16:
                    continue
                with torch.random.fork_rng(devices=rng_device):
                    res1 = torch.randperm(n, dtype=dtype, device=device)
                    res2 = torch.empty(0, dtype=dtype, device=device)
                    torch.randperm(n, out=res2, dtype=dtype, device=device)
                    self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device))

        # Default type is long
        for n in (100, 10000):
            self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)

        # randperm of 0 elements is an empty tensor
        res1 = torch.randperm(0)
        res2 = torch.tensor(5, dtype=dtype, device=device)
        torch.randperm(0, out=res2)
        self.assertEqual(res1.numel(), 0)
        self.assertEqual(res2.numel(), 0)

        # Test non-contiguous tensors
        for n in (4, 5, 6, 10, 20):
            non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
            self.assertFalse(non_contiguous_tensor.is_contiguous())
            with torch.random.fork_rng(devices=rng_device):
                res = torch.randperm(n, dtype=torch.long, device=device)
                torch.randperm(n, out=non_contiguous_tensor)
                self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device))

    # Test forward maxpool2d
    def test_max_pool2d(self):
        def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):

            cpu_x = None
            if test_ties:
                cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True)
            else:
                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
                                      ceil_mode=ceil_mode, return_indices=return_indices)

            if not return_indices:
                y = pool(x)
                ref_y = pool(cpu_x)

                cpu_grad = torch.ones_like(ref_y)
                grad = cpu_grad.to('mps')

                y.backward(gradient=grad)
                ref_y.backward(gradient=cpu_grad)

                self.assertEqual(y, ref_y)
                self.assertEqual(x.grad, cpu_x.grad)
            else:
                y, idx = pool(x)
                ref_y, ref_idx = pool(cpu_x)

                cpu_grad = torch.ones_like(ref_y)
                grad = cpu_grad.to('mps')

                y.backward(gradient=grad)
                ref_y.backward(gradient=cpu_grad)

                self.assertEqual(y, ref_y)
                self.assertEqual(idx, ref_idx)
                self.assertEqual(x.grad, cpu_x.grad)

        # Test with no batch dimension
        helper((8, 4, 4), ks=2)
        helper((2, 8, 4, 4), ks=2)
        helper((1, 1000, 32, 32), ks=4)
        helper((1, 1000, 1, 4), ks=(1, 4))  # test for max_pool1d
        # Test padding
        helper((1, 1000, 32, 32), ks=4, padding=1)
        helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1))  # test for max_pool1d
        # Test dilation
        helper((1, 1000, 32, 32), ks=4, dilation=2)
        helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2))  # test for max_pool1d
        # Test ceil mode
        helper((1, 1000, 32, 32), ks=4, ceil_mode=True)
        helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True)  # test for max_pool1d

        # Test return indices
        for test_ties in [False, True]:
            # Test with no batch dimension
            helper((8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
            helper((2, 8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 32, 32), ks=4, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 1, 4), ks=(1, 4), return_indices=True, test_ties=test_ties)  # test for max_pool1d
            # Test padding
            helper((1, 1000, 32, 32), ks=4, padding=1, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1),
                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
            # Test dilation
            helper((1, 1000, 32, 32), ks=4, dilation=2, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2),
                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
            # Test ceil mode
            helper((1, 1000, 32, 32), ks=4, ceil_mode=True, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True,
                   return_indices=True, test_ties=test_ties)  # test for max_pool1d

    def test_adaptive_avg_pool2d_output_size_one(self):
        def helper(size, memory_format):
            x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
            if memory_format == 'non_contiguous':
                x = x[::2, ::2, ::2, ::2]
            else:
                x = x.to(memory_format=memory_format)

            net = torch.nn.AdaptiveAvgPool2d((1, 1))
            out = net(x)
            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))

            out.sum().backward()    # make sure it doesn't crash

            self.assertEqual(out, ref_out)
            if memory_format == torch.channels_last:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, c, c])
            else:
                self.assertTrue(out.is_contiguous())
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, 1, 1])

        helper((2, 3, 6, 6), torch.contiguous_format)

    def test_masked_scatter(self):
        def helper(shape):
            x_mps = torch.randn(shape, device="mps")
            x_cpu = x_mps.detach().clone().cpu()

            mask_mps = torch.rand(shape, device="mps") < 0.6
            mask_cpu = mask_mps.detach().clone().cpu()

            y_mps = torch.randn(shape, device="mps")
            y_cpu = y_mps.detach().clone().cpu()

            y_mps.masked_scatter_(mask_mps, x_mps)
            y_cpu.masked_scatter_(mask_cpu, x_cpu)

            self.assertEqual(y_mps, y_cpu)
        helper([2, 5])
        helper([10, 10])
        helper([5, 10, 3])
        helper([10, 5, 10, 3])
        helper([10, 5, 10, 3, 20])

    def test_masked_fill(self):
        device = "mps"
        dtype = torch.float32
        mask_dtype = torch.bool

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            num_dest = 10
            dst = torch.zeros(num_dest, dtype=dtype, device=device)
            mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
            val = random.random()
            dst2 = torch.zeros(num_dest, dtype=dtype)
            mask_cpu = mask.to("cpu")

            dst.masked_fill_(mask, val)
            for i in range(num_dest):
                if mask_cpu[i]:
                    dst2[i] = val
            self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)

            # test non-contiguous case
            dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
            dst2 = dst.contiguous()
            if dtype.is_complex:
                mask = dst.abs() > 0
            else:
                mask = dst > 0
            self.assertTrue(not dst.is_contiguous())
            self.assertTrue(dst2.is_contiguous())
            dst.masked_fill_(mask.to(mask_dtype), val)
            dst2.masked_fill_(mask.to(mask_dtype), val)
            self.assertEqual(dst, dst2, atol=0, rtol=0)

            if mask_dtype == torch.uint8:
                self.assertEqual(len(w), 3)

                warn = 'masked_fill_ received a mask with dtype torch.uint8,'
                for wi in w:
                    self.assertEqual(str(wi.message)[0:52], str(warn))
            else:
                self.assertEqual(len(w), 0)

    def test_nhwc_operation(self):
        def helper(shape, channels_last=False):
            import numpy as np
            np.random.seed(332)
            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
            if channels_last:
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            # Moving a channels_last tensor to MPS should preserve its values
            self.assertEqual(x, cpu_x)

        helper((2, 2, 2, 2), True)

    # Test forward batch norm
    def test_batch_norm(self):
        def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
                   track_running_stats=True, test_module=False):
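            # Compares batch norm between CPU and MPS: either the functional form
            # or (when test_module=True) the BatchNorm1d/2d/3d module selected by
            # the rank of `shape`, including running-stat updates.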

            import numpy as np
            np.random.seed(332)
            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
            if channels_last:
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            mean_shape = [shape[1]]
            cpu_running_mean = None
            cpu_running_var = None
            running_mean = None
            running_var = None
            if track_running_stats:
                mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
                cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
                var_arr = 32 * np.random.random_sample(size=mean_shape)
                cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
                running_mean = cpu_running_mean.detach().clone().to('mps')
                running_var = cpu_running_var.detach().clone().to('mps')

            weight = None
            cpu_weight = None
            bias = None
            cpu_bias = None
            if wts:
                cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
                cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            y = None
            ref_y = None

            if not test_module:
                y = torch.nn.functional.batch_norm(x, running_mean, running_var,
                                                   weight=weight,
                                                   bias=bias,
                                                   training=training,
                                                   momentum=momentum, eps=eps)
                ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                       weight=cpu_weight,
                                                       bias=cpu_bias,
                                                       training=training,
                                                       momentum=momentum, eps=eps)

            else:

                batchnorm_op = None
                mps_batchnorm_op = None

                if len(shape) == 3:
                    batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')
                elif len(shape) == 4:
                    batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')
                elif len(shape) == 5:
                    batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')

                if track_running_stats:
                    batchnorm_op.running_mean = cpu_running_mean
                    batchnorm_op.running_var = cpu_running_var
                    mps_batchnorm_op.running_mean = running_mean
                    mps_batchnorm_op.running_var = running_var
                if wts:
                    batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
                    batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
                    mps_batchnorm_op.weight = torch.nn.Parameter(weight)
                    mps_batchnorm_op.bias = torch.nn.Parameter(bias)

                ref_y = batchnorm_op(cpu_x)
                y = mps_batchnorm_op(x)

            self.assertEqual(y, ref_y)
            if not test_module:
                self.assertEqual(running_mean, cpu_running_mean)
                self.assertEqual(running_var, cpu_running_var)
2508 else:
2509 self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
2510 self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)
2511
2512 cpu_grad = torch.randn(ref_y.shape)
2513 grad = cpu_grad.to('mps')
2514 ref_y.backward(gradient=cpu_grad)
2515 y.backward(gradient=grad)
2516
2517 self.assertEqual(x.grad, cpu_x.grad)
Thomas4935b592022-11-23 02:18:03 +00002518 if (wts):
2519 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002520 self.assertEqual(weight.grad, cpu_weight.grad)
2521 self.assertEqual(bias.grad, cpu_bias.grad)
2522 else:
2523 self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
2524 self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)
2525
2526 for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
2527 for test_module in [False, True]:
2528 for track_running_stats in [True, False]:
Kulin Seth3d833212022-05-20 03:18:09 +00002529 for channels_last in [False]:
Thomas4935b592022-11-23 02:18:03 +00002530 if (channels_last and len(shape) != 4):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002531 continue
2532 # Running stats must be tracked in eval mode
Thomas4935b592022-11-23 02:18:03 +00002533 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002534 helper(shape, eps=0, momentum=1, channels_last=channels_last,
2535 track_running_stats=track_running_stats, test_module=test_module)
2536 helper(shape, channels_last=channels_last,
2537 track_running_stats=track_running_stats, test_module=test_module)
2538 helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
2539 track_running_stats=track_running_stats, test_module=test_module)
2540 helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
2541 track_running_stats=track_running_stats, test_module=test_module)
2542 helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
2543 track_running_stats=track_running_stats, test_module=test_module)
2544 helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last,
2545 track_running_stats=track_running_stats, test_module=test_module)
2546 helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
2547 track_running_stats=track_running_stats, test_module=test_module)
2548 helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
2549 track_running_stats=track_running_stats, test_module=test_module)
2550 helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
2551 track_running_stats=track_running_stats, test_module=test_module)
2552 helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
2553 track_running_stats=track_running_stats, test_module=test_module)
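# For reference, the math exercised above (standard batch-norm definition):
# in training mode y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias,
# computed over the batch and spatial dims, and with track_running_stats the
# buffers update as running_stat = (1 - momentum) * running_stat
# + momentum * batch_stat (unbiased variance for running_var). In eval mode
# the running stats replace the batch statistics.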
2554
Nikita Shulga583193e2023-04-11 17:23:36 +00002555 def test_batch_norm_backward(self):
Nikita Shulga24428582023-04-29 03:37:35 +00002556 inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
Nikita Shulga583193e2023-04-11 17:23:36 +00002557 x = torch.nn.BatchNorm2d(8).to("mps")
2558 y = torch.nn.BatchNorm2d(8).to("mps")
2559 y.weight.requires_grad = False
2560 y.bias.requires_grad = False
2561 outputs = y(x(inputs))
2562 # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
2563 outputs.sum().backward()
2564
Nikita Shulga24428582023-04-29 03:37:35 +00002565 def test_layer_norm_backward(self):
2566 inputs = torch.rand(4, 4, device="mps", requires_grad=True)
2567 x = torch.nn.LayerNorm(4).to("mps")
2568 y = torch.nn.LayerNorm(4).to("mps")
2569 y.weight.requires_grad = False
2570 y.bias.requires_grad = False
2571 outputs = y(x(inputs))
2572 # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
2573 outputs.sum().backward()
2574
Denis Vieriu80394bb2023-01-04 02:20:50 +00002575 def test_norm(self):
2576 a = torch.arange(9, dtype=torch.float, device="mps") - 4
2577 b = a.reshape((3, 3))
2578
2579 a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
2580 b_cpu = a_cpu.reshape((3, 3))
2581
2582 res = torch.norm(a)
2583 res_cpu = torch.norm(a_cpu)
2584 self.assertEqual(res, res_cpu)
2585
2586 res = torch.norm(b)
2587 res_cpu = torch.norm(b_cpu)
2588 self.assertEqual(res, res_cpu)
2589
2590 res = torch.norm(a, float('inf'))
2591 res_cpu = torch.norm(a_cpu, float('inf'))
2592 self.assertEqual(res, res_cpu)
2593
2594 res = torch.norm(b, float('inf'))
2595 res_cpu = torch.norm(b_cpu, float('inf'))
2596 self.assertEqual(res, res_cpu)
2597
2598 c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
2599 c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="cpu")
2600
2601 res = torch.norm(c, dim=0)
2602 res_cpu = torch.norm(c_cpu, dim=0)
2603 self.assertEqual(res, res_cpu)
2604
2605 res = torch.norm(c, dim=1)
2606 res_cpu = torch.norm(c_cpu, dim=1)
2607 self.assertEqual(res, res_cpu)
2608
2609 res = torch.norm(c, p=1, dim=1)
2610 res_cpu = torch.norm(c_cpu, p=1, dim=1)
2611 self.assertEqual(res, res_cpu)
2612
2613 d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
2614 d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
2615
2616 res = torch.norm(d, dim=(1, 2))
2617 res_cpu = torch.norm(d_cpu, dim=(1, 2))
2618 self.assertEqual(res, res_cpu)
2619
2620 res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
2621 res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
2622 self.assertEqual(res, res_cpu)
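# Worked check for the first case above: a holds -4..4, so torch.norm(a) is
# the 2-norm sqrt(16 + 9 + 4 + 1 + 0 + 1 + 4 + 9 + 16) = sqrt(60) ~= 7.746 on
# both devices.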
2623
Denis Vieriu89baa1a2023-04-26 01:34:24 +00002624 def test_linalg_vector_norm(self):
2625 x_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps")
2626 x_cpu = x_mps.detach().clone().cpu()
2627
2628 res_mps = torch.linalg.vector_norm(x_mps, ord=0)
2629 res_cpu = torch.linalg.vector_norm(x_cpu, ord=0)
2630 self.assertEqual(res_mps, res_cpu)
2631
2632 a_mps = torch.arange(27, dtype=torch.float, device="mps") - 4
2633 a_cpu = torch.arange(27, dtype=torch.float, device="cpu") - 4
2634
2635 B_mps = a_mps.reshape(3, 3, 3)
2636 B_cpu = a_cpu.reshape(3, 3, 3)
2637
2638 res_mps = torch.linalg.vector_norm(a_mps, ord=3.5)
2639 res_cpu = torch.linalg.vector_norm(a_cpu, ord=3.5)
2640 self.assertEqual(res_mps, res_cpu)
2641
2642 res_mps = torch.linalg.vector_norm(B_mps, ord=3.5)
2643 res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5)
2644 self.assertEqual(res_mps, res_cpu)
2645
2646 for dim in range(0, B_mps.dim()):
2647 res_mps = torch.linalg.vector_norm(B_mps, ord=3.5, dim=dim)
2648 res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5, dim=dim)
2649 self.assertEqual(res_mps, res_cpu)
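# For reference: torch.linalg.vector_norm(x, ord=p) computes
# (sum_i |x_i| ** p) ** (1 / p) for p > 0, while ord=0 counts the nonzero
# entries, so the first case above evaluates to 2.0 for [0, 0, 0, 2, 3].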
2650
2651
Kulin Seth77b68852022-06-10 13:25:41 +00002652 def test_layer_norm(self):
2653 # TODO: Test non-contiguous
2654 def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):
2655 cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
2656 x = cpu_x.detach().clone().to('mps').requires_grad_()
2657
2658 cpu_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='cpu', dtype=dtype)
2659 mps_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='mps', dtype=dtype)
2660 cpu_wt = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
2661 wt = cpu_wt.detach().clone().to('mps').requires_grad_()
2662 cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
2663 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2664
Thomas4935b592022-11-23 02:18:03 +00002665 if (elementwise_affine):
Kulin Seth77b68852022-06-10 13:25:41 +00002666 cpu_op.weight = torch.nn.Parameter(cpu_wt)
2667 mps_op.weight = torch.nn.Parameter(wt)
2668 cpu_op.bias = torch.nn.Parameter(cpu_bias)
2669 mps_op.bias = torch.nn.Parameter(bias)
2670
2671 cpu_result = cpu_op(cpu_x)
2672 result = mps_op(x)
2673
2674 cpu_grad = torch.randn(cpu_result.shape)
2675 grad = cpu_grad.to('mps')
2676
2677 cpu_result.backward(cpu_grad)
2678 result.backward(grad)
2679
2680 self.assertEqual(result, cpu_result)
2681 self.assertEqual(x.grad, cpu_x.grad)
Thomas4935b592022-11-23 02:18:03 +00002682 if (elementwise_affine):
Kulin Seth77b68852022-06-10 13:25:41 +00002683 self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad)
2684 self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad)
2685
2686 for elementwise_affine in [True, False]:
2687 helper((2, 2, 2, 2), (2, 2), elementwise_affine=elementwise_affine)
2688 helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
2689 helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)
2690
Nikita Shulga075a4942023-03-09 22:09:10 +00002691 # Regression test for https://github.com/pytorch/pytorch/issues/96113
2692 torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16))
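# For reference: LayerNorm normalizes each sample over the trailing
# normalized_shape dims, y = (x - E[x]) / sqrt(Var[x] + eps) * weight + bias
# with the biased variance; unlike BatchNorm it keeps no running statistics.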
2693
jhavukainen6a539e82024-05-22 21:48:49 +00002694 @xfailIf(product_version < 14.0)
2695 def test_ifft(self):
2696 # See: https://github.com/pytorch/pytorch/issues/124096
2697 device = torch.device("mps")
2698
2699 N = 64
2700 signal = torch.rand(N, device=device)
2701 fft_result = torch.fft.rfft(signal)
2702 ifft_result = torch.fft.irfft(fft_result, n=signal.shape[0])
2703
2704 # Expect the inverse FFT to recover the original signal
2705 self.assertEqual(ifft_result, signal)
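# For reference: rfft of a length-N real signal returns N // 2 + 1 complex
# coefficients, so irfft takes the explicit n=N above to pin the output
# length (otherwise an odd original length could not be disambiguated).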
2706
Kulin Sethe011a8e2022-05-13 18:28:53 +00002707 def test_instance_norm(self):
2708 def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
2709
2710 import numpy as np
2711 np.random.seed(332)
2712 arr = (256 - 128) * np.random.random_sample(size=shape) + 128
2713 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
Thomas4935b592022-11-23 02:18:03 +00002714 if (channels_last):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002715 cpu_x = cpu_x.to(memory_format=torch.channels_last)
2716 cpu_x.retain_grad()
2717 x = cpu_x.detach().clone().to('mps').requires_grad_()
2718
2719 mean_shape = [shape[1]]
2720 cpu_running_mean = None
2721 cpu_running_var = None
2722 running_mean = None
2723 running_var = None
Thomas4935b592022-11-23 02:18:03 +00002724 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002725 mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
2726 cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
2727 var_arr = 32 * np.random.random_sample(size=mean_shape)
2728 cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
2729 running_mean = cpu_running_mean.detach().clone().to('mps')
2730 running_var = cpu_running_var.detach().clone().to('mps')
2731
2732 weight = None
2733 cpu_weight = None
2734 bias = None
2735 cpu_bias = None
Thomas4935b592022-11-23 02:18:03 +00002736 if (wts):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002737 cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2738 weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2739 cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2740 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2741
2742 y = None
2743 ref_y = None
2744
Thomas4935b592022-11-23 02:18:03 +00002745 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002746 ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
2747 weight=cpu_weight,
2748 bias=cpu_bias,
2749 momentum=momentum, eps=eps)
2750 y = torch.nn.functional.instance_norm(x, running_mean, running_var,
2751 weight=weight,
2752 bias=bias,
2753 momentum=momentum, eps=eps)
2754
2755 else:
2756
2757 instancenorm_op = None
2758 mps_instancenorm_op = None
2759
Thomas4935b592022-11-23 02:18:03 +00002760 if (len(shape) == 3):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002761 instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
2762 eps=eps,
2763 momentum=momentum,
2764 affine=wts,
2765 track_running_stats=track_running_stats,
2766 device='cpu')
2767 mps_instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
2768 eps=eps,
2769 momentum=momentum,
2770 affine=wts,
2771 track_running_stats=track_running_stats,
2772 device='mps')
Thomas4935b592022-11-23 02:18:03 +00002773 elif (len(shape) == 4):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002774 instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
2775 eps=eps,
2776 momentum=momentum,
2777 affine=wts,
2778 track_running_stats=track_running_stats,
2779 device='cpu')
2780 mps_instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
2781 eps=eps,
2782 momentum=momentum,
2783 affine=wts,
2784 track_running_stats=track_running_stats,
2785 device='mps')
Thomas4935b592022-11-23 02:18:03 +00002786 elif (len(shape) == 5):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002787 instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
2788 eps=eps,
2789 momentum=momentum,
2790 affine=wts,
2791 track_running_stats=track_running_stats,
2792 device='cpu')
2793 mps_instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
2794 eps=eps,
2795 momentum=momentum,
2796 affine=wts,
2797 track_running_stats=track_running_stats,
2798 device='mps')
2799
Thomas4935b592022-11-23 02:18:03 +00002800 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002801 instancenorm_op.running_mean = cpu_running_mean
2802 instancenorm_op.running_var = cpu_running_var
2803 mps_instancenorm_op.running_mean = running_mean
2804 mps_instancenorm_op.running_var = running_var
Thomas4935b592022-11-23 02:18:03 +00002805 if (wts):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002806 instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
2807 instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
2808 mps_instancenorm_op.weight = torch.nn.Parameter(weight)
2809 mps_instancenorm_op.bias = torch.nn.Parameter(bias)
2810
2811 ref_y = instancenorm_op(cpu_x)
2812 y = mps_instancenorm_op(x)
2813
2814 self.assertEqual(y, ref_y)
Thomas4935b592022-11-23 02:18:03 +00002815 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002816 self.assertEqual(running_mean, cpu_running_mean)
2817 self.assertEqual(running_var, cpu_running_var)
2818 else:
2819 self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
2820 self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)
2821
2822 cpu_grad = torch.randn(ref_y.shape)
2823 grad = cpu_grad.to('mps')
2824 ref_y.backward(gradient=cpu_grad)
2825 y.backward(gradient=grad)
2826
2827 self.assertEqual(x.grad, cpu_x.grad)
Thomas4935b592022-11-23 02:18:03 +00002828 if (wts):
2829 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002830 self.assertEqual(weight.grad, cpu_weight.grad)
2831 self.assertEqual(bias.grad, cpu_bias.grad)
2832 else:
2833 self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
2834 self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)
2835
2836 for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
2837 for test_module in [False, True]:
2838 for track_running_stats in [True, False]:
2839 for channels_last in [False]:
Thomas4935b592022-11-23 02:18:03 +00002840 if (channels_last and len(shape) != 4):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002841 continue
2842 # Running stats must be tracked in eval mode
Thomas4935b592022-11-23 02:18:03 +00002843 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002844 helper(shape, eps=0, momentum=1, channels_last=channels_last,
2845 track_running_stats=track_running_stats, test_module=test_module)
2846 helper(shape, channels_last=channels_last,
2847 track_running_stats=track_running_stats, test_module=test_module)
2848 helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
2849 track_running_stats=track_running_stats, test_module=test_module)
2850 helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
2851 track_running_stats=track_running_stats, test_module=test_module)
2852 helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
2853 track_running_stats=track_running_stats, test_module=test_module)
2854 helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
2855 track_running_stats=track_running_stats, test_module=test_module)
2856 helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
2857 track_running_stats=track_running_stats, test_module=test_module)
2858 helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
2859 track_running_stats=track_running_stats, test_module=test_module)
2860 helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
2861 track_running_stats=track_running_stats, test_module=test_module)
2862 helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
2863 track_running_stats=track_running_stats, test_module=test_module)
2864
igm50303176262023-09-20 02:18:24 +00002865 def test_weight_norm(self):
Nikita Shulga27458cc2024-06-14 11:23:27 -07002866 def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim):
Nikita Shulga9035fff2024-06-14 11:23:30 -07002867 cpu_norm = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim)
2868 norm = torch.nn.utils.parametrizations.weight_norm(model, dim=dim)
Nikita Shulga27458cc2024-06-14 11:23:27 -07002869
2870 cpu_out = cpu_norm(cpu_x)
2871 out = norm(x)
2872
2873 self.assertEqual(cpu_out, out)
2874
2875 cpu_grad = torch.randn(cpu_out.shape)
2876 grad = cpu_grad.to('mps')
2877 cpu_out.backward(gradient=cpu_grad)
2878 out.backward(gradient=grad)
2879
Nikita Shulga9035fff2024-06-14 11:23:30 -07002880 self.assertEqual(cpu_model.parametrizations.weight.original0.grad, model.parametrizations.weight.original0.grad)
2881 self.assertEqual(cpu_model.parametrizations.weight.original1.grad, model.parametrizations.weight.original1.grad)
Nikita Shulga27458cc2024-06-14 11:23:27 -07002882
2883 self.assertEqual(x.grad, cpu_x.grad)
2884
igm50303176262023-09-20 02:18:24 +00002885 def helper(dim, layer='linear', dtype=torch.float32):
2886 # linear layer
2887 if layer == 'linear':
2888 cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
2889 x = cpu_x.detach().clone().to('mps').requires_grad_()
2890
2891 cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
2892 weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2893
2894 cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
2895 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2896
2897 cpu_linear = torch.nn.Linear(5, 10, device='cpu')
2898 linear = torch.nn.Linear(5, 10, device='mps')
2899
2900 with torch.no_grad():
2901 cpu_linear.weight.copy_(cpu_weight)
2902 cpu_linear.bias.copy_(cpu_bias)
2903 linear.weight.copy_(weight)
2904 linear.bias.copy_(bias)
Nikita Shulga27458cc2024-06-14 11:23:27 -07002905 validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim)
igm50303176262023-09-20 02:18:24 +00002906
2907 # conv layer
2908 if layer == 'conv':
2909 cpu_x = torch.randn((3, 5, 5), device='cpu', dtype=dtype, requires_grad=True)
2910 x = cpu_x.detach().clone().to('mps').requires_grad_()
2911
2912 cpu_conv = torch.nn.Conv2d(3, 3, 3, device='cpu')
2913 conv = torch.nn.Conv2d(3, 3, 3, device='mps')
2914
2915 with torch.no_grad():
2916 conv.weight.copy_(cpu_conv.weight)
2917 conv.bias.copy_(cpu_conv.bias)
2918
Nikita Shulga27458cc2024-06-14 11:23:27 -07002919 validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)
igm50303176262023-09-20 02:18:24 +00002920
Nikita Shulga27458cc2024-06-14 11:23:27 -07002921 # conv3d layer
Lucas Steuernagel2e517b22023-12-15 23:05:01 +00002922 if layer == 'conv3d':
2923 cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True)
2924 x = cpu_x.detach().clone().to('mps').requires_grad_()
2925
2926 cpu_conv = torch.nn.Conv3d(3, 3, 3, device='cpu')
2927 conv = torch.nn.Conv3d(3, 3, 3, device='mps')
2928
2929 with torch.no_grad():
2930 conv.weight.copy_(cpu_conv.weight)
2931 conv.bias.copy_(cpu_conv.bias)
2932
Nikita Shulga27458cc2024-06-14 11:23:27 -07002933 validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)
igm50303176262023-09-20 02:18:24 +00002934
2935 helper(0, layer='linear')
2936 helper(1, layer='linear')
2937 helper(-1, layer='linear')
2938
2939 helper(0, layer='conv')
2940 helper(1, layer='conv')
2941 helper(2, layer='conv')
2942 helper(3, layer='conv')
2943 helper(-1, layer='conv')
2944
Lucas Steuernagel2e517b22023-12-15 23:05:01 +00002945 if product_version >= 13.2:
2946 # Conv3d is only available from macOS 13.2 onwards
2947 helper(0, layer='conv3d')
2948 helper(1, layer='conv3d')
2949 helper(2, layer='conv3d')
2950 helper(3, layer='conv3d')
2951 helper(4, layer='conv3d')
2952 helper(-1, layer='conv3d')
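# For reference: weight_norm reparametrizes w = g * v / ||v|| along `dim`;
# the parametrization stores the magnitude g as
# parametrizations.weight.original0 and the direction v as
# parametrizations.weight.original1, which is what the gradient comparison in
# validate_weight_norm_equality addresses.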
2953
Kulin Sethe011a8e2022-05-13 18:28:53 +00002954 # Test conv2d
2955 def test_conv2d_unit(self):
2956 def helper(input_shape, wt_shape,
2957 stride=1, padding=0,
2958 dilation=1, groups=1,
2959 bias_shape=None):
2960
2961 cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
2962 x = cpu_x.detach().clone().to('mps').requires_grad_()
2963
2964 cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
2965 wt = cpu_wt.detach().clone().to('mps').requires_grad_()
2966
2967 cpu_bias = None
2968 bias = None
2969
Thomas4935b592022-11-23 02:18:03 +00002970 if (bias_shape is not None):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002971 cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
2972 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2973
2974 y = torch.nn.functional.conv2d(x, wt, bias=bias, stride=stride,
2975 padding=padding, dilation=dilation, groups=groups)
2976 ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, stride=stride,
2977 padding=padding, dilation=dilation, groups=groups)
2978
2979 cpu_grad = torch.ones_like(ref_y)
2980 grad = cpu_grad.to('mps')
2981
2982 y.backward(gradient=grad)
2983 ref_y.backward(gradient=cpu_grad)
2984
2985 self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
2986 self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
2987 self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
Thomas4935b592022-11-23 02:18:03 +00002988 if (bias_shape is not None):
Kulin Seth3d833212022-05-20 03:18:09 +00002989 self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)
Kulin Sethe011a8e2022-05-13 18:28:53 +00002990
2991 N = 1
2992 C_in = 3
2993 C_out = 64
2994 H = 64
2995 W = 64
2996 kH = 4
2997 kW = 4
2998 stride = 2
2999 padding = 1
3000
3001 helper((N, C_in, H, W), (C_out, C_in, kH, kW), stride=stride, padding=padding)
3002
3003 N = 4
3004 C_in = 16
3005 H = 32
3006 W = 32
3007
3008 C_out = 8
3009 kH = 3
3010 kW = 3
3011
3012 for groups in [1, 2, 4]:
3013 helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)
3014 helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)
3015
3016 helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)
3017 helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)
3018
3019 helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)
3020 helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)
3021
3022 helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
3023 kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
3024 helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
3025 kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
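# For reference, the spatial output size exercised above follows the standard
# convolution formula
#   H_out = (H + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
# e.g. the first case (H=64, k=4, stride=2, padding=1) yields H_out = 32.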
3026
3027 # Test conv transpose 2d
3028 def test_conv_transpose2d(self):
3029 def helper(input_shape, wt_shape,
3030 stride=1, padding=0,
3031 output_padding=0,
3032 dilation=1, groups=1,
3033 bias_shape=None):
3034
3035 cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
3036 x = cpu_x.detach().clone().to('mps').requires_grad_()
3037
3038 cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
3039 wt = cpu_wt.detach().clone().to('mps').requires_grad_()
3040
3041 cpu_bias = None
3042 bias = None
3043
Thomas4935b592022-11-23 02:18:03 +00003044 if (bias_shape is not None):
Kulin Sethe011a8e2022-05-13 18:28:53 +00003045 cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
3046 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
3047
3048 y = torch.nn.functional.conv_transpose2d(
3049 x, wt, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
3050 ref_y = torch.nn.functional.conv_transpose2d(
3051 cpu_x, cpu_wt, bias=cpu_bias, stride=stride, padding=padding,
3052 output_padding=output_padding, groups=groups, dilation=dilation)
3053
3054 cpu_grad = torch.randn(ref_y.shape)
3055 grad = cpu_grad.to('mps')
3056
3057 y.backward(gradient=grad)
3058 ref_y.backward(gradient=cpu_grad)
3059
3060 self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
3061 self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
3062 self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
3063
Thomas4935b592022-11-23 02:18:03 +00003064 # if (bias_shape is not None):
Kulin Sethe011a8e2022-05-13 18:28:53 +00003065 # print(cpu_bias.grad)
3066 # print(bias.grad.to('cpu'))
3067 # self.assertEqual(bias.grad, cpu_bias.grad)
3068
3069 N = 4
Alban Desmaisonbde246f2022-05-30 10:36:31 -04003070 C_in = 2
Kulin Sethe011a8e2022-05-13 18:28:53 +00003071 H = 32
3072 W = 32
3073
3074 C_out = 8
3075 groups = 1
3076 kH = 3
3077 kW = 3
3078
3079 for stride in [1, 2, 3]:
3080 for padding in [0, 1, 2]:
3081 for output_padding in [0, 1, 2]:
3082 for dilation in [1, 2]:
Thomas4935b592022-11-23 02:18:03 +00003083 if (output_padding >= stride or output_padding >= dilation):
Kulin Sethe011a8e2022-05-13 18:28:53 +00003084 continue
3085 helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
3086 padding=padding, output_padding=output_padding, dilation=dilation)
3087 helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
3088 padding=padding, output_padding=output_padding, dilation=dilation)
3089
3090 helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
3091 padding=padding, output_padding=output_padding, dilation=dilation)
3092 helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
3093 padding=padding, output_padding=output_padding, dilation=dilation)
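# For reference: conv_transpose2d produces
#   H_out = (H - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
# and output_padding must stay below stride or dilation, which is why the
# loop above skips the invalid combinations.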
3094
3095 # Test sigmoid
3096 def test_sigmoid(self):
3097 def helper(shape):
3098
3099 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3100 x = cpu_x.detach().clone().to('mps').requires_grad_()
3101
3102 sigmoid_op = torch.nn.Sigmoid()
3103
3104 y = sigmoid_op(x)
3105 ref_y = sigmoid_op(cpu_x)
3106
3107 cpu_grad = torch.ones_like(ref_y)
3108 grad = cpu_grad.to('mps')
3109
3110 y.backward(gradient=grad)
3111 ref_y.backward(gradient=cpu_grad)
3112
3113 self.assertEqual(y, ref_y)
3114 self.assertEqual(x.grad, cpu_x.grad)
3115
3116 helper((2, 3, 4, 5))
3117 helper((2, 3, 4))
3118 helper((2, 8, 4, 5))
3119
3120 # Test tanh
3121 def test_tanh(self):
3122 def helper(shape):
3123
3124 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3125 x = cpu_x.detach().clone().to('mps').requires_grad_()
3126
3127 tanh_op = torch.nn.Tanh()
3128
3129 y = tanh_op(x)
3130 ref_y = tanh_op(cpu_x)
3131
3132 cpu_grad = torch.ones_like(ref_y)
3133 grad = cpu_grad.to('mps')
3134
3135 y.backward(gradient=grad)
3136 ref_y.backward(gradient=cpu_grad)
3137
3138 self.assertEqual(y, ref_y)
3139 self.assertEqual(x.grad, cpu_x.grad)
3140
3141 helper((2, 3, 4, 5))
3142 helper((2, 3, 4))
3143 helper((2, 8, 4, 5))
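# For reference: with the all-ones upstream gradient used in both tests above,
# the expected input gradients are sigmoid(x) * (1 - sigmoid(x)) for Sigmoid
# and 1 - tanh(x) ** 2 for Tanh, so the x.grad comparison doubles as a check
# of the MPS backward kernels against the CPU reference.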
3144
3145 def test_threshold(self):
3146 def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
3147 m = nn.Threshold(threshold=threshold, value=value, inplace=inplace)
3148
3149 input_cpu = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
3150 input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)
3151
3152 output_cpu = m(input_cpu)
3153 output_mps = m(input_mps)
3154
3155 cpu_grad = torch.ones_like(output_cpu)
3156 mps_grad = cpu_grad.to('mps')
3157
3158 self.assertEqual(output_cpu, output_mps)
3159
3160 if requires_grad:
3161 output_cpu.backward(gradient=cpu_grad)
3162 output_mps.backward(gradient=mps_grad)
3163
3164 self.assertEqual(input_cpu.grad, input_mps.grad)
3165
3166 helper(threshold=0.1, value=20, num_elems=2)
3167 helper(threshold=-0.1, value=10, num_elems=10)
3168 helper(threshold=0.5, value=-15, num_elems=100)
3169 helper(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False)
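# For reference: nn.Threshold replaces every element x <= threshold with
# `value` and passes larger elements through unchanged, so the gradient is 1
# where x > threshold and 0 elsewhere.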
3170
3171 # Test pow
3172 def test_pow(self):
3173 def helper(shape):
Li-Huai (Allan) Linf33180f2023-02-28 16:11:15 +00003174 # aten::pow.Tensor_Tensor
Kulin Sethe011a8e2022-05-13 18:28:53 +00003175 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3176 x = cpu_x.detach().clone().to('mps')
3177 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3178 y = cpu_y.detach().clone().to('mps')
3179 z = torch.pow(x, y)
3180 ref_z = torch.pow(cpu_x, cpu_y)
3181
3182 self.assertEqual(z, ref_z)
3183
Li-Huai (Allan) Linf33180f2023-02-28 16:11:15 +00003184 # aten::pow.Tensor_Scalar
Kulin Sethe011a8e2022-05-13 18:28:53 +00003185 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3186 x = cpu_x.detach().clone().to('mps')
3187 exp = random.random()
3188 z = torch.pow(x, exp)
3189 ref_z = torch.pow(cpu_x, exp)
3190
3191 self.assertEqual(z, ref_z)
3192
Li-Huai (Allan) Linf33180f2023-02-28 16:11:15 +00003193 # aten::pow.Scalar
3194 x = random.random()
3195 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3196 y = cpu_y.detach().clone().to('mps')
3197 z = torch.pow(x, y)
3198 ref_z = torch.pow(x, cpu_y)
3199
3200 self.assertEqual(z, ref_z)
3201
Kulin Sethe011a8e2022-05-13 18:28:53 +00003202 helper((2, 8, 4, 5))
3203
3204 # Test addcmul
3205 def test_addcmul(self):
Nikita Shulga769cc8a2023-03-07 04:19:30 +00003206 def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None):
3207 def rand_helper(dtype):
3208 if dtype.is_floating_point:
3209 return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
3210 return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003211
Nikita Shulga769cc8a2023-03-07 04:19:30 +00003212 cpu_x = rand_helper(xtype)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003213 x = cpu_x.detach().clone().to('mps')
3214
Nikita Shulga769cc8a2023-03-07 04:19:30 +00003215 cpu_y = rand_helper(ytype if ytype is not None else xtype)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003216 y = cpu_y.detach().clone().to('mps')
3217
Nikita Shulga769cc8a2023-03-07 04:19:30 +00003218 cpu_z = rand_helper(ztype if ztype is not None else xtype)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003219 z = cpu_z.detach().clone().to('mps')
3220
3221 y = torch.addcmul(x, y, z, value=value)
3222 ref_y = torch.addcmul(cpu_x, cpu_y, cpu_z, value=value)
3223
3224 self.assertEqual(y, ref_y)
3225
3226 helper((2, 3, 4, 5), 0.1)
3227 helper((2, 8, 4, 5), 0.1)
3228 helper((2, 3, 4, 5), 0.2)
3229 helper((2, 8, 4, 5), 0.2)
Nikita Shulga769cc8a2023-03-07 04:19:30 +00003230 # Integral types
3231 helper((2, 2), 1.0, xtype=torch.int32)
3232 helper((2, 2), 2.0, xtype=torch.int16)
3233
3234 # Mixed types
3235 helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32)
3236 helper((3, 2), 1.0, ytype=torch.float16)
3237 helper((2, 3), 1.0, ztype=torch.float16)
3238 helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8)
3239 helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8)
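# For reference: addcmul computes input + value * tensor1 * tensor2
# elementwise; the mixed-dtype cases above rely on the usual type-promotion
# rules for the product.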
Kulin Sethe011a8e2022-05-13 18:28:53 +00003240
3241 # Test addcdiv
3242 def test_addcdiv(self):
3243 def helper(shape, value):
3244 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3245 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3246 # clamp to avoid division by 0
3247 cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False).clamp_min_(0.1)
3248 cpu_out = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3249
3250 mps_x = cpu_x.detach().clone().to('mps')
3251 mps_y = cpu_y.detach().clone().to('mps')
3252 mps_z = cpu_z.detach().clone().to('mps')
3253 mps_out = cpu_out.detach().clone().to('mps')
3254
3255 result_div_mps = torch.addcdiv(mps_x, mps_y, mps_z, value=value)
3256 result_div_cpu = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value)
3257 self.assertEqual(result_div_mps, result_div_cpu)
3258 # test .out variant
3259 self.assertEqual(torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value), result_div_cpu)
3260
3261 helper((2, 3, 4, 5), 0.1)
3262 helper((2, 8, 4, 5), 0.2)
3263 helper((2, 3, 4, 5), 1.0) # value of 1 should be ignored internally
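# For reference: addcdiv computes input + value * tensor1 / tensor2
# elementwise, hence the clamp above keeping the denominator away from zero.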
3264
Joona Havukainena5ba9b22024-06-06 16:09:18 +00003265 def test_addcdiv_transpose(self):
3266 # Regression test for issue https://github.com/pytorch/pytorch/issues/118115
3267 # Test all contiguity (contiguous/transposed) combinations of the input tensors
3268
3269 def helper(shape, value):
3270 shape_t = shape[::-1]
3271 for i in range(2):
3272 for j in range(2):
3273 for k in range(2):
3274 x = torch.rand(shape, device="cpu") if i == 0 else torch.rand(shape_t, device="cpu").t()
3275 y = torch.rand(shape, device="cpu") if j == 0 else torch.rand(shape_t, device="cpu").t()
3276 z = torch.rand(shape, device="cpu") if k == 0 else torch.rand(shape_t, device="cpu").t()
3277
3278 x_mps = x.detach().clone().to(device="mps")
3279 y_mps = y.detach().clone().to(device="mps")
3280 z_mps = z.detach().clone().to(device="mps")
3281
3282 result_cpu = x.addcdiv_(y, z, value=value)
3283 result_mps = x_mps.addcdiv(y_mps, z_mps, value=value)
3284 result_mps_out = result_cpu.detach().clone().to('mps')
3285 torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value)
3286
3287 self.assertEqual(result_cpu, result_mps)
3288 self.assertEqual(result_cpu, result_mps_out)
3289
3290 helper((2, 3), 1.0)
3291 helper((2, 3), 0.2)
3292 helper((100, 300), 1.0)
3293 helper((100, 300), 0.2)
3294
Ramin Azarmehraa62b3e2022-05-31 19:15:45 +00003295 def test_buffer_size_match(self):
3296 # this test shouldn't cause any crash
3297 size = 16
3298 cpu_A = torch.rand(size, device='cpu')
3299 cpu_F = torch.rand(size, size, size, device='cpu')
3300
3301 mps_A = cpu_A.to('mps')
3302 mps_F = cpu_F.to('mps')
3303 self.assertEqual(cpu_A @ cpu_F, mps_A @ mps_F)
3304
Kulin Sethe011a8e2022-05-13 18:28:53 +00003305 def test_transpose_inplace(self):
3306 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
3307 cpu_x = torch.tensor(values, device='cpu')
3308 mps_x = torch.tensor(values, device='mps')
3309
3310 cpu_x.transpose_(0, 1)
3311 mps_x.transpose_(0, 1)
3312 self.assertEqual(cpu_x, mps_x.to('cpu'))
3313
Kulin Seth4858c562022-06-02 06:17:19 +00003314 def test_expand_cpu_to_mps_copy(self):
3315 # https://github.com/pytorch/pytorch/issues/78642
3316
3317 x = torch.tensor(1).expand([10]).to("mps")
3318 x_cpu = torch.tensor(1).expand([10])
3319
3320 self.assertEqual(x_cpu, x.cpu())
3321
Denis Vieriu0a677f22023-01-10 22:45:48 +00003322 def test_cpu_to_strided_mps_copy(self):
3323 # https://github.com/pytorch/pytorch/issues/86975
3324
3325 a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
3326 b1 = torch.Tensor([-1, -1])
3327 a1[1:, 1] = b1
3328
3329 a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
3330 b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
3331 a2[1:, 1] = b2
3332
3333 self.assertEqual(a1, a2)
3334
Denis Vieriue3ac1092023-02-07 16:20:08 +00003335 def test_view_slice_reshape(self):
3336 x = torch.randn([1, 4, 4], device="mps")
3337 y = x[0, :1, 1:]
3338
3339 x_cpu = x.to("cpu")
3340 y_cpu = x_cpu[0, :1, 1:]
3341
3342 r = y + 1
3343 r_cpu = y_cpu + 1
3344 self.assertEqual(r, r_cpu)
3345
3346 def test_slice_reshape(self):
3347 x = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
3348 x_cpu = x.detach().clone().to("cpu")
3349
3350 x = x[:, 3:].view(2, 3, 4, 1)
3351 x_cpu = x_cpu[:, 3:].view(2, 3, 4, 1)
3352 self.assertEqual(x, x_cpu)
3353
3354 x = x + 2
3355 x_cpu = x_cpu + 2
3356 self.assertEqual(x, x_cpu)
3357
Denis Vieriu304a9542023-03-03 08:08:31 +00003358 def test_reshape_storage_offset(self):
3359 # https://github.com/pytorch/pytorch/issues/95883
3360 B = 4
3361 T = 1
3362
3363 lin_cpu = nn.Linear(10, 256)
3364 lin_mps = nn.Linear(10, 256, device="mps")
3365
3366 # Use the same weights and bias as the ones from the cpu
3367 lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
3368 lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()
3369
3370 x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
3371 x_cpu = x_mps.detach().clone().cpu().requires_grad_()
3372 x_mps = lin_mps(x_mps)
3373 x_cpu = lin_cpu(x_cpu)
3374
3375 self.assertEqual(x_mps.shape, (B, T, 256))
3376 self.assertEqual(x_cpu.shape, (B, T, 256))
3377
3378 cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
3379 cls_token_cpu = cls_token_mps.detach().clone().cpu()
3380 x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
3381 x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)
3382
3383 x_mps = x_mps.transpose(0, 1)
3384 x_cpu = x_cpu.transpose(0, 1)
3385
3386 target_mps = torch.rand_like(x_mps)
3387 target_cpu = target_mps.detach().clone().cpu()
3388 loss_mps = F.mse_loss(x_mps, target_mps)
3389 loss_cpu = F.mse_loss(x_cpu, target_cpu)
3390 self.assertEqual(loss_mps, loss_cpu)
3391
3392 loss_mps.backward()
3393 loss_cpu.backward()
3394 self.assertEqual(x_mps.grad, x_cpu.grad)
3395
Li-Huai (Allan) Lin88a659e2023-11-08 16:19:38 -08003396 def test_stack_storage_offset(self):
Denis Vieriu304a9542023-03-03 08:08:31 +00003397 # https://github.com/pytorch/pytorch/issues/87856
3398 x_cpu = torch.tensor([[1, 2]])
3399 x_mps = x_cpu.detach().clone().to("mps")
3400
3401 y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
3402 y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
3403
3404 self.assertEqual(y_cpu, y_mps)
3405
3406 t_mps = torch.tensor([1, 2, 3, 4], device="mps")
3407 t_cpu = t_mps.detach().cpu()
3408
3409 x_mps = t_mps[2:]
3410 y_mps = t_mps[:2]
3411
3412 x_cpu = t_cpu[2:]
3413 y_cpu = t_cpu[:2]
3414
3415 res_mps = torch.stack((y_mps, x_mps), dim=-1)
3416 res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
3417
3418 self.assertEqual(res_mps, res_cpu)
3419
3420 def test_unsafe_chunk(self):
3421 # https://github.com/pytorch/pytorch/issues/91065
3422 a = torch.rand(5, dtype=torch.float32, device="cpu")
3423 ret = a.unsafe_chunk(4, 0)
3424 y = ret[0] * ret[2]
3425 a_mps = a.to("mps")
3426 ret_mps = a_mps.unsafe_chunk(4, 0)
3427 y_mps = ret_mps[0] * ret_mps[2]
3428 self.assertEqual(y, y_mps)
3429
Ramin Azarmehr9511b9f2023-02-18 16:29:01 +00003430 def test_slice_casting(self):
3431 # generate random binary numbers
3432 cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)
3433 mps_in = cpu_in.detach().clone().to("mps")
3434 # check copy_cast(uint8 -> bool) on tensors with storage offset
3435 cpu_out = cpu_in[:, :, 11 : 12, :12].to(torch.bool)
3436 mps_out = mps_in[:, :, 11 : 12, :12].to(torch.bool)
3437 self.assertEqual(cpu_out, mps_out)
3438
Denis Vieriue3ac1092023-02-07 16:20:08 +00003439 def test_slice_reshape_contg_view(self):
3441
3442 x_mps = torch.randn(1, 4800, 2, device="mps")
3443 x_cpu = x_mps.detach().clone().cpu()
3444
3445 r_mps = x_mps + 2
3446 r_cpu = x_cpu + 2
3447
3448 self.assertEqual(r_mps, r_cpu)
3449
Denis Vieriu86efa102023-02-23 17:26:10 +00003450 def test_contiguous_slice_2d(self):
3451 def helper(shape):
3452 for i in range(0, shape[0]):
3453 for j in range(0, shape[1]):
3454 t_mps = torch.randn(shape, device="mps")
3455 t_cpu = t_mps.detach().clone().cpu()
3456
3457 y_mps = t_mps[i:, :j]
3458 y_cpu = t_cpu[i:, :j]
3459 self.assertEqual(y_mps + 1, y_cpu + 1)
3460
3461 y_mps = t_mps[i:, j]
3462 y_cpu = t_cpu[i:, j]
3463 self.assertEqual(y_mps + 1, y_cpu + 1)
3464
3465 y_mps = t_mps[i, :j]
3466 y_cpu = t_cpu[i, :j]
3467 self.assertEqual(y_mps + 1, y_cpu + 1)
3468
3469 y_mps = t_mps[:i, :j]
3470 y_cpu = t_cpu[:i, :j]
3471 self.assertEqual(y_mps + 1, y_cpu + 1)
3472
3473 y_mps = t_mps[:i, j]
3474 y_cpu = t_cpu[:i, j]
3475 self.assertEqual(y_mps + 1, y_cpu + 1)
3476
3477 y_mps = t_mps[:i, j:]
3478 y_cpu = t_cpu[:i, j:]
3479 self.assertEqual(y_mps + 1, y_cpu + 1)
3480
3481 l = []
3482 for N in range(1, 3):
3483 l.append(N)
3484 for C in range(1, 3):
3485 l.append(C)
3486 helper(l)
3487 for D in range(1, 3):
3488 l.append(D)
3489 helper(l)
3490 for H in range(1, 3):
3491 l.append(H)
3492 helper(l)
3493 for W in range(1, 3):
3494 l.append(W)
3495 helper(l)
3496 l.pop()
3497 l.pop()
3498 l.pop()
3499 l.pop()
3500 l.pop()
3501
3502 helper([9, 15, 4])
3503 helper([9, 3, 2])
3504 helper([3, 4, 18, 22])
3505 helper([3, 4, 18, 22, 150])
3506
Denis Vieriue5a959a2023-03-01 16:16:49 +00003507 def test_contiguous_slice_3d(self):
3508 x = torch.randn(2, 3, 3, device="mps")
3509 x_cpu = x.detach().clone().cpu()
3510 x = x[:1]
3511 x_cpu = x_cpu[:1]
3512 out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2]
3513 out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2]
3514 self.assertEqual(out, out_cpu)
3515
Denis Vieriub71c7102022-12-08 17:59:55 +00003516 def test_view_slice(self):
3517 # https://github.com/pytorch/pytorch/issues/83995
3518 NUM_SAMPLES = 60
3519 s = (0, 1)
3520
3521 X = torch.rand(8000, 3, dtype=torch.float32, device='cpu')
3522 X_mps = X.detach().clone().to("mps")
3523
3524 idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s))
3525 pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1]))
3526 idx_mps = idx.to("mps")
3527 pts_mps = pts.to("mps")
3528 pts[:, s] = idx
3529 pts_mps[:, s] = idx_mps
3530
3531 actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float)
3532 actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")
3533
3534 for i in range(NUM_SAMPLES):
3535 for j in range(X.shape[1]):
3536 actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j]
3537 actual_pts[i, j] = X[pts[i, j], j]
3538 self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j])
3539
Denis Vieriudbf96162023-01-02 16:31:27 +00003540 def test_slice_scatter(self):
3541 shape = (4, 4)
3542 tensor = torch.randint(10, shape, device="mps")
3543 tensor_before = tensor.clone()
3544 torch.empty(shape[0], shape[1] * 2, device="mps")[:, ::2].copy_(tensor)
3545 torch.testing.assert_close(tensor, tensor_before)
Denis Vieriub71c7102022-12-08 17:59:55 +00003546
Kulin Sethe011a8e2022-05-13 18:28:53 +00003547 def test_slice(self):
3548 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
3549 cpu_x = torch.tensor(values, device='cpu')
3550 mps_x = torch.tensor(values, device='mps', dtype=torch.float)
3551
3552 cpu_slice1 = cpu_x[:2, :]
3553 mps_slice1 = mps_x[:2, :]
Kulin Sethe011a8e2022-05-13 18:28:53 +00003554 self.assertEqual(cpu_slice1, mps_slice1)
3555
3556 cpu_slice2 = cpu_x[:, :1]
3557 mps_slice2 = mps_x[:, :1]
Kulin Sethe011a8e2022-05-13 18:28:53 +00003558 self.assertEqual(cpu_slice2, mps_slice2)
3559
3560 cpu_slice3 = cpu_x[1:2, :]
3561 mps_slice3 = mps_x[1:2, :]
3562 self.assertEqual(cpu_slice3, mps_slice3.to('cpu'))
3563
3564 cpu_slice4 = cpu_x[1, :]
3565 mps_slice4 = mps_x[1, :].to('cpu')
3566 self.assertEqual(cpu_slice4, mps_slice4)
3567
Denis Vieriua6b75bb2022-08-22 17:05:53 +00003568 def test_scalar_from_slice_unary(self):
3569 # https://github.com/pytorch/pytorch/issues/82543
3570 tensor_list = torch.tensor([1.0, 1.2], device="mps")
3571
3572 for scalar in tensor_list:
3573 r_mps = torch.ceil(scalar)
3574 r_cpu = torch.ceil(scalar.to("cpu"))
3575 self.assertEqual(r_mps.cpu(), r_cpu)
3576
3577 def test_scalar_from_slice_binary(self):
3578 # https://github.com/pytorch/pytorch/issues/82543
3579 def helper(binary_op):
3580 tensor_list = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps")
3581
3582 for scalar in tensor_list:
3583 r_mps = binary_op(scalar, 1.0)
3584 r_cpu = binary_op(scalar.cpu(), 1.0)
3585 self.assertEqual(r_mps.cpu(), r_cpu)
3586 helper(torch.sub)
3587 helper(torch.add)
3588 helper(torch.not_equal)
3589 helper(torch.eq)
3590
Kulin Sethd63db522022-05-28 14:41:56 +00003591 def test_slice_contiguous_view(self):
3592 # https://github.com/pytorch/pytorch/issues/77750
3593
3594 def helper(operator):
3595 t_mps = torch.tensor([1, 2, 3, 4], device="mps")
3596 t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")
3597
3598 # contiguous view
3599 x_mps = t_mps[2:] # 3, 4
3600 y_mps = t_mps[:2] # 1, 2
3601
3602 x_cpu = t_cpu[2:]
3603 y_cpu = t_cpu[:2]
3604
3605 res_mps = res_cpu = None
3606 if operator == "<=":
3607 res_mps = x_mps <= y_mps
3608 res_cpu = x_cpu <= y_cpu
Li-Huai (Allan) Lin0a9c6082023-02-17 18:44:20 +00003609 elif operator == "<":
Kulin Sethd63db522022-05-28 14:41:56 +00003610 res_mps = x_mps < y_mps
3611 res_cpu = x_cpu < y_cpu
Li-Huai (Allan) Lin0a9c6082023-02-17 18:44:20 +00003612 elif operator == ">=":
Kulin Sethd63db522022-05-28 14:41:56 +00003613 res_mps = x_mps >= y_mps
3614 res_cpu = x_cpu >= y_cpu
Li-Huai (Allan) Lin0a9c6082023-02-17 18:44:20 +00003615 elif operator == ">":
Kulin Sethd63db522022-05-28 14:41:56 +00003616 res_mps = x_mps > y_mps
3617 res_cpu = x_cpu > y_cpu
Li-Huai (Allan) Lin0a9c6082023-02-17 18:44:20 +00003618 elif operator == "==":
Kulin Sethd63db522022-05-28 14:41:56 +00003619 res_mps = x_mps == y_mps
3620 res_cpu = x_cpu == y_cpu
Li-Huai (Allan) Lin0a9c6082023-02-17 18:44:20 +00003621 elif operator == "!=":
Kulin Sethd63db522022-05-28 14:41:56 +00003622 res_mps = x_mps != y_mps
3623 res_cpu = x_cpu != y_cpu
Li-Huai (Allan) Lin0a9c6082023-02-17 18:44:20 +00003624 elif operator == "stack":
3625 res_mps = torch.stack((y_mps, x_mps), dim=-1)
3626 res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
Kulin Sethd63db522022-05-28 14:41:56 +00003627
3628 self.assertEqual(res_mps, res_cpu)
3629
Li-Huai (Allan) Lin0a9c6082023-02-17 18:44:20 +00003630 for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
Kulin Sethd63db522022-05-28 14:41:56 +00003631 helper(op)
3632
Denis Vieriube327ec2022-09-30 18:51:43 +00003633 def test_slice_of_slice(self):
3634 x = torch.tensor([0.5, 0.5], device="cpu")
3635 x_mps = torch.tensor([0.5, 0.5], device="mps")
3636
3637 tensor = x[1][None]
3638 tensor_mps = x_mps[1][None]
3639
3640 res = tensor.ne(0)
3641 res_mps = tensor_mps.ne(0)
3642
3643 self.assertEqual(res, res_mps)
3644
Kulin Sethd63db522022-05-28 14:41:56 +00003645 def test_index_storage_offset(self):
3646 # https://github.com/pytorch/pytorch/issues/78107
3647
3648 a = torch.tensor([8.2670e-01, -1.0293e+00])
3649 b_cpu = a[0]
3650 c_cpu = a[1]
3651
3652 # both 'b' and 'c' are views of 'a'
3653 # 'b' has a storage offset of 0, while 'c' has a storage offset of 1
3654 # when copying from 'cpu' to 'mps', 'c' will have a storage_offset of 1 which needs to be taken into account,
3655 # otherwise it ends up with the same value as 'b'
3656 b = b_cpu.to('mps')
3657 c = c_cpu.to('mps')
3658
3659 res_mps = b > c
3660 res_cpu = b_cpu > c_cpu
3661 self.assertEqual(res_mps, res_cpu)
3662
3663 res_mps = c > b
3664 res_cpu = c_cpu > b_cpu
3665 self.assertEqual(res_mps, res_cpu)
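# For reference (informal sketch): indexing a 1-D tensor returns a view whose
# storage_offset() equals the index, e.g. a[0].storage_offset() == 0 and
# a[1].storage_offset() == 1; the CPU -> MPS copy must honor that offset when
# reading from the shared underlying storage.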
3666
Kulin Sethe011a8e2022-05-13 18:28:53 +00003667 def test_flatten(self):
3668 values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
3669 cpu_x = torch.tensor(values, device='cpu')
3670 mps_x = torch.tensor(values, device='mps')
3671
3672 cpu_flatten1 = cpu_x.flatten()
3673 mps_flatten1 = mps_x.flatten().to('cpu')
3674 self.assertEqual(cpu_flatten1, mps_flatten1)
3675
3676 cpu_flatten2 = cpu_x.flatten(start_dim=1)
3677 mps_flatten2 = mps_x.flatten(start_dim=1).to('cpu')
3678 self.assertEqual(cpu_flatten2, mps_flatten2)
3679
3680 cpu_flatten3 = cpu_x.flatten(end_dim=1)
3681 mps_flatten3 = mps_x.flatten(end_dim=1).to('cpu')
3682 self.assertEqual(cpu_flatten3, mps_flatten3)
3683
3684 # Test repeat
3685 def test_repeat(self):
3686 def helper(shape, repeats):
3687
3688 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3689 x = cpu_x.detach().clone().to('mps').requires_grad_()
3690
3691 y = x.repeat(repeats)
3692 ref_y = cpu_x.repeat(repeats)
3693
3694 cpu_grad = torch.randn(ref_y.shape)
3695 grad = cpu_grad.to('mps')
3696
3697 y.backward(gradient=grad)
3698 ref_y.backward(gradient=cpu_grad)
3699
3700 self.assertEqual(y, ref_y)
3701 self.assertEqual(x.grad, cpu_x.grad)
3702
3703 helper((2, 3, 4, 5), (2, 3, 4, 5))
3704 helper((2, 3, 4), (4, 3, 2, 5, 7, 2))
3705 helper((3, 4, 5), (2, 3, 4, 5))
3706 helper((3, 4, 5), (2, 2, 2))
3707
Henry Chengfe0c7fb2023-02-12 08:43:52 +00003708 def test_torch_repeat_interleave(self, device="mps"):
3709 y = torch.tensor([[1, 2], [3, 4]], device=device)
3710 # exercise single argument function signature
3711 temp = y.repeat_interleave(2)
3712 self.assertEqual(torch.Size([8]), temp.size())
3713
3714 for dtype in [torch.int, torch.long]:
3715 lengths = torch.tensor([1, 2], dtype=dtype, device="mps")
3716 output_size = torch.sum(lengths)
3717 a = torch.repeat_interleave(
3718 y,
3719 lengths,
3720 dim=0,
3721 )
3722 self.assertEqual(a.dtype, y.dtype)
3723 self.assertEqual(a.size(), torch.Size([3, 2]))
3724
3725 a_with_output = torch.repeat_interleave(
3726 y,
3727 lengths,
3728 dim=0,
3729 output_size=output_size,
3730 )
3731 self.assertEqual(a_with_output.dtype, y.dtype)
3732 self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
3733
3734 def test_repeat_interleave(self, device="mps"):
3735 x = torch.tensor([0, 1, 2, 3], device=device)
BJ Hargravedc52ba22023-04-12 19:23:04 +00003736 expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
3737 # Prior to macos 13.3, input of dtype=torch.int64 returns dtype=torch.int32
3738 self.assertEqual(torch.repeat_interleave(x), expected, exact_dtype=product_version >= 13.3)
Henry Chengfe0c7fb2023-02-12 08:43:52 +00003739
3740 with self.assertRaises(RuntimeError):
3741 torch.repeat_interleave(torch.arange(4, device=device).reshape(2, 2))
3742
3743 with self.assertRaises(RuntimeError):
3744 torch.repeat_interleave(torch.arange(4.0, device=device))
3745
3746 with self.assertRaises(RuntimeError):
3747 torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4], device=device))
3748
3749 y = torch.tensor([[1, 2], [3, 4]], device=device)
3750
3751 y1_v1 = torch.repeat_interleave(y, 2)
3752 y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device))
3753 y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device))
3754 y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
3755 self.assertEqual(y1_v1, y1_expect)
3756 self.assertEqual(y1_v2, y1_expect)
3757 self.assertEqual(y1_v3, y1_expect)
3758
3759 y2 = torch.repeat_interleave(y, 3, dim=1)
3760 y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
3761 [3, 3, 3, 4, 4, 4]], device=device)
3762 self.assertEqual(y2, y2_expect)
3763
3764 y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
3765 y3_expect = torch.tensor([[1, 2],
3766 [3, 4],
3767 [3, 4]], device=device)
3768 self.assertEqual(y3, y3_expect)
3769
3770 with self.assertRaises(RuntimeError):
3771 torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0)
3772
3773 with self.assertRaises(RuntimeError):
3774 torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0)
3775
3776 # test zero sized dimension
3777 x = torch.zeros((5, 0), device=device)
3778 y = torch.repeat_interleave(x, repeats=3, dim=1)
3779 self.assertEqual(y, x.new_zeros(5, 0, device=device))
3780
3781 x = torch.tensor([], dtype=torch.int64, device=device)
3782 y = torch.repeat_interleave(x, x)
3783 self.assertEqual(y, x)
3784
3785 def test_repeat_interleave_simple(self):
3786 def helper(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None):
3787 x = torch.randn(shape, dtype=dtype, device="mps")
3788 x_cpu = x.detach().clone().cpu()
3789
3790 num_repeats_cpu = num_repeats.detach().clone().cpu()
3791
3792 repeats = torch.repeat_interleave(x, num_repeats, dim)
3793 repeats_cpu = torch.repeat_interleave(x_cpu, num_repeats_cpu, dim)
3794
3795 self.assertEqual(repeats, repeats_cpu)
3796 helper(shape=3, num_repeats=torch.tensor([100], device="mps"))
3797 helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
3798 helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
3799 helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
3800 helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)
3801
    def test_count_nonzero(self):
        def helper(dtype):
            n = [
                [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
                [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
            ]
            cpu_x = torch.tensor(n, dtype=dtype)
            mps_x = torch.tensor(n, dtype=dtype).to('mps')

            # All non-zeros
            self.assertEqual(
                torch.count_nonzero(cpu_x),
                torch.count_nonzero(mps_x)
            )

            # dim=1
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=1),
                torch.count_nonzero(mps_x, dim=1)
            )

            # dim=(0, 1)
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=(0, 1)),
                torch.count_nonzero(mps_x, dim=(0, 1))
            )
        helper(torch.int32)
        helper(torch.int64)
        helper(torch.float16)
        helper(torch.float32)

    def _test_module_empty_input(self, module, inp, check_size=True):
        inp.requires_grad_(True)
        out = module(inp)
        gO = torch.rand_like(out)
        out.backward(gO)
        if check_size:
            self.assertEqual(out.size(), inp.size())
        for p in module.parameters():
            if p.requires_grad:
                self.assertEqual(p.grad, torch.zeros_like(p.grad))
        self.assertEqual(inp.grad, torch.zeros_like(inp))

    # Test dtype casting, with and without simultaneous device change
    def test_to(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        self.assertEqual(cpu_x.int(), mps_x.int().cpu())
        self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
        self.assertEqual(cpu_x.float(), mps_x.float().cpu())

        self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
                         torch.tensor(1, dtype=torch.int32))
        self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
                         torch.tensor(1, dtype=torch.int32))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
                         torch.tensor(1.0))
        self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
                         torch.tensor(4, dtype=torch.int32))
        self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
                         torch.tensor(4, dtype=torch.int32))
        self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                         torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))
        # Cast int8 and uint8 to float and compare results
        # See https://github.com/pytorch/pytorch/issues/80009 for more details
        cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8)
        cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.int8)
        for x_cpu in [cpu_byte, cpu_char]:
            x_mps = x_cpu.to('mps')
            self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32))


    def test_setitem_scalar(self) -> None:
        device = 'mps'
        for dtype in [torch.int32, torch.float32, torch.int64]:
            for i in range(3, 6):
                for j in range(3, 6):
                    t = torch.zeros(i, j, dtype=dtype, device=device)
                    self.assertEqual(t.sum(), 0)
                    t[1, 1] = 1
                    t[2, 1] = j
                    t[1, 2] = i
                    self.assertEqual(t[1, 1], 1)
                    self.assertEqual(t[1, 2], i)
                    self.assertEqual(t[2, 1], j)
                    self.assertEqual(t.sum(), 1 + i + j)

    def test_stride_of_strides(self) -> None:
        x = torch.rand(32, 1, device='mps')
        y = x.as_strided(size=(32, 2), stride=(1, 0))
        # Casting the stride of a strided tensor to CPU used to crash with a "buffer is not large enough." assert
        # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
        z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
        self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)

    def test_type_casting(self):
        # https://github.com/pytorch/pytorch/issues/81567
        def helper(data, to_dtype):
            a_cpu = torch.tensor(data)
            a_mps = a_cpu.to(torch.device('mps'))

            res_cpu = a_cpu.type(to_dtype)
            res_mps = a_mps.type(to_dtype)
            self.assertEqual(res_cpu, res_mps)

        helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor)

    def test_to_casting(self):
        # https://github.com/pytorch/pytorch/issues/81567
        def helper(data, to_dtype):
            a_cpu = torch.tensor(data)
            a_mps = a_cpu.to(torch.device('mps'))

            res_cpu = a_cpu.to(to_dtype)
            res_mps = a_mps.to(to_dtype)
            self.assertEqual(res_cpu, res_mps)

        helper([9.0, 3.0, 5.0, 4.0], torch.int64)
        helper([9.0, 3.0, 5.0, 4.0], torch.float)
        helper([9.0, 3.0, 5.0, 4.0], torch.int32)
        helper([9.0, 3.0, 5.0, 4.0], torch.short)
        helper([9.0, 3.0, 5.0, 4.0], torch.half)
        helper([9.0, 3.0, 5.0, 4.0], torch.int8)
        helper([9.0, 3.0, 5.0, 4.0], torch.uint8)

    def test_storage_offset_greater_than_src_nbytes(self):
        # https://github.com/pytorch/pytorch/issues/80844
        n_tensors = 100
        n_tensor_elems = 784
        elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)

        tensor_list = []
        for i in range(0, n_tensors - 1):
            # create a list of contiguous view tensors (view tensor created by the slice op)
            t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
            tensor_list.append(t)

        for i in range(0, n_tensors - 1):
            t = tensor_list[i].view(1, n_tensor_elems)
            t_mps = t.to("mps")
            self.assertEqual(t, t_mps.cpu(), f"i={i}")

    # See https://github.com/pytorch/pytorch/issues/82427
    # and https://github.com/pytorch/pytorch/issues/83692
    def test_full_bugs(self):
        # Test should not crash
        x = torch.full((3, 3), True, device='mps')
        # torch.full should work for uint8
        y_mps = torch.full((2, 2), 247, device='mps', dtype=torch.uint8)
        y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
        self.assertEqual(y_mps, y_cpu)

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    # See https://github.com/pytorch/pytorch/issues/84995
    def test_div_bugs(self):
        for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
            if dtype != torch.int64:
                x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
                y = torch.div(x, 101, rounding_mode=mode)
                self.assertEqual(y.sum(), 0)

    # See https://github.com/pytorch/pytorch/issues/82663
    def test_bool_expand(self):
        x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
        y = torch.tensor([0, 1], dtype=torch.bool, device='mps')
        self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))

    # Empty unary op should return tensor of the same size
    def test_empty_neg(self):
        x = torch.tensor([[]], device='mps')
        y = -x
        self.assertEqual(x, y)

    def _test_unique_scalar_empty(self, dtype, device, f):
        # test scalar
        x = torch.tensor(0, dtype=dtype, device=device)
        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
        expected_unique = torch.tensor([0], dtype=dtype, device=device)
        expected_inverse = torch.tensor(0, device=device)
        expected_counts = torch.tensor([1], device=device)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)

        # test zero sized tensor
        x = torch.zeros((0, 0, 3), dtype=dtype, device=device)
        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
        expected_unique = torch.tensor([], dtype=dtype, device=device)
        expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device)
        expected_counts = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)

    def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
        def ensure_tuple(x):
            if isinstance(x, torch.Tensor):
                return (x,)
            return x

        for return_inverse in [True, False]:
            for return_counts in [True, False]:
                # test with expected
                ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                self.assertEqual(expected_unique, ret[0])
                if return_inverse:
                    self.assertEqual(expected_inverse, ret[1])
                if return_counts:
                    count_index = 1 + int(return_inverse)
                    self.assertEqual(expected_counts, ret[count_index])

                # tests per-element unique on a higher rank tensor.
                y = x.view(additional_shape)
                y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True)
                self.assertEqual(expected_unique, y_unique)
                self.assertEqual(expected_inverse.view(additional_shape), y_inverse)
                self.assertEqual(expected_counts, y_counts)

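    # The helpers above rely on torch.unique's tuple layout: unique values first,
    # then the optional inverse indices, then the optional counts. For example,
    # torch.unique(torch.tensor([1, 2, 2]), return_inverse=True, return_counts=True)
    # returns (tensor([1, 2]), tensor([0, 1, 1]), tensor([1, 2])).
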
    def test_unique_all_dtypes(self, device="mps"):
        def helper(dtype):
            def ensure_tuple(x):
                if isinstance(x, torch.Tensor):
                    return (x,)
                return x

            if dtype is torch.bool:
                x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
                expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
                expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
                expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
            else:
                x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
                expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
                expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
                expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)

            # test sorted unique
            fs = (
                lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
                lambda x, **kwargs: x.unique(sorted=True, **kwargs),
            )
            x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x)
            xs = (x, x_sliced)
            for f, x in product(fs, xs):
                self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
                self._test_unique_scalar_empty(dtype, device, f)

            # test unsorted unique
            fs = (
                lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
                lambda x, **kwargs: x.unique(sorted=False, **kwargs)
            )
            for f, x in product(fs, xs):
                self._test_unique_scalar_empty(dtype, device, f)
                for return_inverse, return_counts in product((True, False), repeat=2):
                    ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
                    self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                    x_list = x.tolist()
                    x_unique_list = ret[0].tolist()
                    self.assertEqual(expected_unique.tolist(), sorted(x_unique_list))
                    if return_inverse:
                        x_inverse_list = ret[1].tolist()
                        for i, j in enumerate(x_inverse_list):
                            self.assertEqual(x_list[i], x_unique_list[j])
                    if return_counts:
                        count_index = 1 + int(return_inverse)
                        x_counts_list = ret[count_index].tolist()
                        for i, j in zip(x_unique_list, x_counts_list):
                            count = 0
                            for k in x_list:
                                if k == i:
                                    count += 1
                            self.assertEqual(j, count)
        [helper(dtype) for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]]

    def test_unique(self):
        def helper(x, return_inverse, return_counts):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.unique(x, return_inverse=return_inverse, return_counts=return_counts)
            result_cpu = torch.unique(cpu_x, return_inverse=return_inverse, return_counts=return_counts)

            self.assertEqual(result, result_cpu)
        helper(torch.tensor([1, 2, 4, 2, 1]), False, False)
        helper(torch.randint(3, (10, )), False, False)
        helper(torch.randint(3, (10, )), True, False)
        helper(torch.randint(3, (10, )), False, True)
        helper(torch.randint(3, (10, )), True, True)
        helper(torch.randint(3, (1, )), True, True)
        helper(torch.randint(3, (0, )), True, True)
        # Regression test for https://github.com/pytorch/pytorch/issues/104879
        x = torch.arange(2, device="mps")
        self.assertEqual(x.reshape(1, 1, 2).unique(), x)

    def test_unique_consecutive(self):
        def helper(x, dim, return_inverse, return_counts):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.unique_consecutive(x, dim=dim, return_inverse=return_inverse, return_counts=return_counts)
            result_cpu = torch.unique_consecutive(cpu_x, dim=dim, return_inverse=return_inverse, return_counts=return_counts)

            self.assertEqual(result, result_cpu)
        helper(torch.tensor([1, 2, 4, 2, 1]), 0, False, False)
        helper(torch.randint(3, (10, )), 0, False, False)
        helper(torch.randint(3, (10, )), 0, True, False)
        helper(torch.randint(3, (10, )), 0, False, True)
        helper(torch.randint(3, (10, )), 0, True, True)
        helper(torch.randint(3, (10, )), 0, True, True)
        helper(torch.randint(3, (1, )), 0, True, True)
        helper(torch.randint(3, (0, )), 0, True, True)

        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, False, False)
        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, True, True)
        helper(torch.randint(2, (20, 2)), 0, True, True)
        helper(torch.randint(2, (1, 2)), 0, True, True)
        helper(torch.randint(2, (0, 2)), 0, True, True)

        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, False, False)
        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, True, True)
        helper(torch.randint(2, (2, 20)), 1, True, True)
        helper(torch.randint(2, (2, 1)), 1, True, True)
        helper(torch.randint(2, (2, 0)), 1, True, True)

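    # Note that unlike torch.unique, unique_consecutive only collapses adjacent runs
    # of equal values: torch.unique_consecutive(torch.tensor([1, 2, 4, 2, 1])) returns
    # tensor([1, 2, 4, 2, 1]) unchanged, since no neighboring elements repeat.
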
    # See https://github.com/pytorch/pytorch/issues/85675
    def test_cat_non_contiguous(self):
        def rotate_subset(data, dim):
            x1 = data[:, :, :2, :]
            x2 = data[:, :, 2:, :]
            self.assertFalse(x1.is_contiguous())
            self.assertFalse(x2.is_contiguous())
            return torch.concat((x1, x2), dim=dim)
        for dtype in MPS_DTYPES:
            if dtype == torch.bool:
                continue
            data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
            data = data.to(memory_format=torch.channels_last)
            mps_data = data.to("mps")
            self.assertEqual(data, mps_data)
            for dim in range(data.dim()):
                cpu_result = rotate_subset(data, dim)
                mps_result = rotate_subset(mps_data, dim)
                self.assertEqual(cpu_result, mps_result.to("cpu"))
                # TODO: enable memory format test
                # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())

    # See https://github.com/pytorch/pytorch/issues/85967
    def test_from_numpy_non_contiguous(self):
        a = np.arange(9).reshape(3, 3)[:, :2]
        t_cpu = torch.tensor(a, device="cpu")
        t_mps = torch.tensor(a, device="mps")
        self.assertEqual(t_cpu, t_mps.to("cpu"))

    # See https://github.com/pytorch/pytorch/issues/86954
    def test_copy_non_contiguous(self):
        x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
        self.assertFalse(x.is_contiguous())
        y = x.to('mps')
        self.assertFalse(y.is_contiguous())
        self.assertEqual(x, y.to('cpu'))

        x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
        y = x.to('mps')
        self.assertEqual(x, y.to('cpu'))

        x = torch.full((4, 4, 4, 4), 13, device="cpu")
        y = torch.full((4, 4, 4, 4), 13, device="mps")
        z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
        x.permute(3, 2, 1, 0)[1::, ::2] = z
        # As y is on MPS and z on CPU, this dispatches to a copy operator
        y.permute(3, 2, 1, 0)[1::, ::2] = z
        self.assertEqual(x, y.to('cpu'))

    # See https://github.com/pytorch/pytorch/issues/95417
    def test_copy_storage_offset(self):
        x_cpu = torch.zeros(5, device="cpu", dtype=torch.float32)
        x_mps = torch.zeros(5, device="mps", dtype=torch.float32)
        update_cpu = torch.tensor([1, 1], device="cpu", dtype=torch.int64)
        update_mps = torch.tensor([1, 1], device="mps", dtype=torch.int64)
        x_cpu[2:4] = update_cpu
        x_mps[2:4] = update_mps  # implicit type casting and copy
        self.assertEqual(x_cpu, x_mps)

        x_cpu[2:4] = update_mps  # implicit device moving and copy
        self.assertEqual(x_cpu, x_mps)

    def test_copy_broadcasting(self):
        def helper(src_shape, dst_shape, src_dtype, dst_dtype):
            cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)
            cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
            cpu_result = cpu_dst.copy_(cpu_src)
            mps_src = cpu_src.to("mps")
            mps_dst = cpu_dst.to("mps")
            mps_result = mps_dst.copy_(mps_src)
            self.assertEqual(cpu_result, mps_result)

        test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]

        for (src_dtype, dst_dtype) in itertools.product(test_dtypes, test_dtypes):
            helper((2, 1), (2, 3), src_dtype, dst_dtype)
            helper((2, 1), (2, 2), src_dtype, dst_dtype)
            helper((3, 1, 4, 1), (3, 4, 4, 5), src_dtype, dst_dtype)
            helper((3,), (2, 3), src_dtype, dst_dtype)
            helper((2,), (2, 2), src_dtype, dst_dtype)
            helper((4, 1, 5), (3, 4, 4, 5), src_dtype, dst_dtype)
            helper((4, 1, 5), (4, 0, 5), src_dtype, dst_dtype)
            helper((1, 5), (4, 0, 5), src_dtype, dst_dtype)
            helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype)
            helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype)
        # Regression test for https://github.com/pytorch/pytorch/issues/107867
        self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)

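    # copy_ broadcasts the source into the destination shape, e.g.
    # torch.zeros(2, 3).copy_(torch.tensor([[1.], [2.]])) yields
    # tensor([[1., 1., 1.], [2., 2., 2.]]); the helper above checks the MPS copy
    # kernel against this CPU behavior for every (src_dtype, dst_dtype) pair.
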
    # See https://github.com/pytorch/pytorch/pull/84742
    # and https://github.com/pytorch/pytorch/pull/78319
    def test_binops_dtype_precedence(self):
        # Test dtype precedence (casting order) in binary operations by comparing to CPU result
        # Example values for all dtypes supported on the MPS backend
        sample_vals = {
            torch.bool: [False, True],
            torch.int16: [-15, 0, 1, 10],
            torch.int32: [-376, 0, 1, 13],
            torch.int64: [-8, 0, 1, 77],
            torch.float16: [-234.5, 0.0, 1.0, 2.0],
            torch.float32: [-1.0, 0.0, 0.1, 111.99],
        }
        # Test all combinations of dtypes, operations, dimensionality
        for dtype1, dtype2, binop in itertools.product(
                sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul', 'div']):
            # bool minus bool is generally unsupported, so skip
            if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool):
                continue
            full_shape = (10,)
            for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
                # print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})')
                # print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                #       (torch.tensor(val2, dtype=dtype2, device='mps')))
                # print(getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                #       (torch.tensor(val2, dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                    (torch.tensor(val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                    (torch.tensor(val2, dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
                    (torch.tensor([val2], dtype=dtype2, device='mps')),
                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
                    (torch.tensor([val2], dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                    (torch.tensor([val2], dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                    (torch.tensor([val2], dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
                    (torch.tensor(val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
                    (torch.tensor(val2, dtype=dtype2, device='cpu')))
                # Test tensors created with torch.full
                x1 = torch.full(full_shape, val1, dtype=dtype1, device='mps')
                y1 = torch.tensor(val2, dtype=dtype2, device='mps')
                x2 = torch.full(full_shape, val1, dtype=dtype1, device='cpu')
                y2 = torch.tensor(val2, dtype=dtype2, device='cpu')
                self.assertEqual(getattr(x1, binop)(y1), getattr(x2, binop)(y2))
                x3 = torch.tensor(val1, dtype=dtype1, device='mps')
                y3 = torch.full(full_shape, val2, dtype=dtype2, device='mps')
                x4 = torch.tensor(val1, dtype=dtype1, device='cpu')
                y4 = torch.full(full_shape, val2, dtype=dtype2, device='cpu')
                self.assertEqual(getattr(x3, binop)(y3), getattr(x4, binop)(y4))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                    (torch.full(full_shape, val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                    (torch.full(full_shape, val2, dtype=dtype2, device='cpu')))

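    # The expected casting order is PyTorch's standard type promotion, which can be
    # queried directly: torch.result_type(torch.tensor([1], dtype=torch.int16),
    # torch.tensor([1.0], dtype=torch.float16)) is torch.float16.
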
    def test_nansum(self):
        def helper(dtype, noncontiguous, dim):
            zero_cpu = torch.zeros((), dtype=dtype)

            # Randomly scale the values
            scale = random.randint(10, 100)
            x_cpu: torch.Tensor = make_tensor(
                (5, 5), dtype=dtype, device='cpu',
                low=-scale, high=scale, noncontiguous=noncontiguous)

            if dtype.is_floating_point:
                nan_mask_cpu = x_cpu < (0.2 * scale)
                x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
                x_cpu[nan_mask_cpu] = np.nan
            else:
                x_no_nan_cpu = x_cpu

            x_mps = x_cpu.to('mps')
            actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
            expect_out_cpu = torch.empty(0, dtype=dtype)
            dim_kwargs = {"dim": dim} if dim is not None else {}
            expect = torch.sum(x_no_nan_cpu, **dim_kwargs)

            actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
            # Sanity check on CPU
            self.assertEqual(expect, actual_cpu)

            # Test MPS
            actual_mps = torch.nansum(x_mps, **dim_kwargs)
            # Test out= variant
            torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
            torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
            self.assertEqual(expect, actual_mps)
            self.assertEqual(expect_out_cpu, actual_out_mps)

        args = itertools.product(
            (torch.float16, torch.float32, torch.int32, torch.int64),  # dtype
            (True, False),  # noncontiguous
            (0, 1, None),  # dim
        )

        for dtype, noncontiguous, dim in args:
            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
                helper(dtype, noncontiguous, dim)

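    # nansum treats NaN entries as zero instead of propagating them,
    # e.g. torch.nansum(torch.tensor([1.0, float('nan'), 2.0])) -> tensor(3.),
    # which is why the reference above sums x_no_nan_cpu rather than x_cpu.
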
    def test_cumsum_all_dtypes(self):
        def helper(dtype):
            t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
            t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

            a = t.cumsum(0, dtype=dtype)
            a_cpu = t_cpu.cumsum(0, dtype=dtype)

            self.assertEqual(a.cpu(), a_cpu)
        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

        try:
            helper(torch.int64)
        except Exception as e:
            e_string = str(e)
            self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." +
                             " Support has been added in macOS 13.3")

    def test_cumsum_bool(self):
        a = torch.ones(2**16, dtype=torch.bool)
        t_cpu = a.cumsum(0)
        t_mps = a.to("mps").cumsum(0)

        self.assertEqual(t_cpu, t_mps)

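    # cumsum on a bool input promotes to int64; the 2**16-element input above would
    # overflow a 16-bit accumulator (the running sum reaches 65536), which is
    # presumably the failure mode this regression test guards against.
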
    def test_cumsum_minus_one_axis(self):
        def helper(dtype):
            # Test with axis -1
            cpu_x = None
            if dtype == torch.float32:
                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
            else:
                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = cpu_x.cumsum(-1)
            y = x.cumsum(-1)

            self.assertEqual(y, cpu_y)

        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

    def test_cumprod_all_dtypes(self):
        def helper(dtype):
            t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
            t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

            a = t.cumprod(0, dtype=dtype)
            a_cpu = t_cpu.cumprod(0, dtype=dtype)

            self.assertEqual(a.cpu(), a_cpu)
        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

        try:
            helper(torch.int64)
        except Exception as e:
            e_string = str(e)
            self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input."
                             + " Support has been added in macOS 13.3")

    def test_cumprod_minus_one_axis(self):
        def helper(dtype):
            # Test with axis -1
            cpu_x = None
            if dtype == torch.float32:
                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
            else:
                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = cpu_x.cumprod(-1)
            y = x.cumprod(-1)

            self.assertEqual(y, cpu_y)

        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

    def test_median_int16(self):
        def helper(shape, dtype):
            cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
            x = cpu_x.detach().clone().to('mps')

            median_result = torch.median(x)
            median_result_cpu = torch.median(cpu_x)
            self.assertEqual(median_result, median_result_cpu)

        helper((2, 8, 4, 5), torch.int16)

    def test_activation_checkpoint_does_not_error(self):
        from torch.utils.checkpoint import checkpoint

        for use_reentrant in (True, False):
            a = torch.tensor(1., device="mps", requires_grad=True)

            def fn(x):
                return x.sin().cos().exp()

            out = checkpoint(fn, a, use_reentrant=use_reentrant)
            out.backward()

    def test_as_strided(self):
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        values_1 = [[1.0, 1.0], [1.0, 1.0]]
        cpu_x = torch.tensor(values, device='cpu')
        ones1 = torch.tensor(values_1, device='mps')
        x = cpu_x.detach().clone().to('mps').requires_grad_()
        strided_cpu = torch.as_strided(cpu_x, (2, 2), (1, 2))
        strided_mps = torch.as_strided(x, (2, 2), (1, 2))
        self.assertEqual(strided_mps, strided_cpu)
        strided_cpu_out = strided_cpu + ones1.to('cpu')
        strided_mps_out = strided_mps + ones1
        self.assertEqual(strided_cpu_out, strided_mps_out)

        # test with storage offsets
        cpu_x = torch.rand(3, 3, device='cpu')
        mps_x = cpu_x.to('mps')
        strided_cpu1 = torch.as_strided(cpu_x, (2, 2), (1, 2), 0)
        strided_mps1 = torch.as_strided(mps_x, (2, 2), (1, 2), 0)
        strided_cpu2 = torch.as_strided(cpu_x, (2, 2), (1, 2), 1)
        strided_mps2 = torch.as_strided(mps_x, (2, 2), (1, 2), 1)
        strided_cpu_out = strided_cpu1 - strided_cpu2
        strided_mps_out = strided_mps1 - strided_mps2
        self.assertEqual(strided_cpu_out, strided_mps_out)

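    # For the sizes above, as_strided(size=(2, 2), stride=(1, 2)) reads element
    # [i][j] from flat index i * 1 + j * 2 (plus the storage offset), so the 3x3
    # input [[1, 2, 3], [4, 5, 6], [7, 8, 9]] views as [[1., 3.], [2., 4.]].
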
    def test_unfold(self):
        x = torch.arange(1., 8)
        x_mps = torch.arange(1., 8, device="mps")

        y = x.unfold(0, 2, 1)
        y_mps = x_mps.unfold(0, 2, 1)

        self.assertEqual(y, y_mps)

    def test_unfold_all_devices_and_dtypes(self):
        supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
        for dt in supported_dtypes:
            x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
            self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)

    def test_unfold_scalars(self):
        x = torch.tensor(0.5, device="mps")
        # unfold on a 0-dimensional tensor should always return a 1-dimensional
        # tensor of shape [size] (i.e., the second parameter to unfold)

        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
        self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))

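    # unfold(dim, size, step) returns sliding windows along dim with
    # (len - size) // step + 1 windows; for the arange input above, x.unfold(0, 2, 1)
    # is [[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.], [6., 7.]].
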
    def test_bincount_simple(self):
        input = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
        input_cpu = input.to("cpu")
        weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
        weights_cpu = weights.to("cpu")

        x = torch.bincount(input)
        x_cpu = torch.bincount(input_cpu)
        self.assertEqual(x, x_cpu)

        y = input.bincount(weights)
        y_cpu = input_cpu.bincount(weights_cpu)
        self.assertEqual(y, y_cpu)

    def test_bincount_reduction(self):
        device = "mps"
        # negative input throws
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
            torch.bincount(torch.tensor([1, -1], device=device, dtype=torch.int32))
        # n-d input, with n > 1 throws
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
            torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
        # minlength < 0 throws
        with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
            torch.bincount(torch.tensor([1, 3], device=device),
                           torch.tensor([.2, .2], device=device),
                           minlength=-1)
        # n-d weights, with n > 1 throws
        with self.assertRaisesRegex(RuntimeError, '1-d'):
            torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                           torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float))
        # input and weights dim mismatch
        with self.assertRaisesRegex(RuntimeError, 'same length'):
            torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                           torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float))
        # 1-d input with no elements and default minlength
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
                         torch.zeros(0, dtype=torch.long, device=device))
        # 1-d input with no elements and specified minlength
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
                         torch.zeros(10, dtype=torch.long, device=device))

        # test tensor method without weights
        long_counts = torch.tensor(
            [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
        self.assertEqual(
            torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
            long_counts)
        # test avoiding overflow for uint8 (#76979)
        count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
        count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
        self.assertEqual(count_uint8, count_int16)
        # test minlength functionality
        int_counts = torch.bincount(
            torch.tensor([1, 1, 1, 1], device=device, dtype=torch.int32), minlength=5)
        self.assertEqual(
            torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
            int_counts)
        # test weights
        byte_counts = torch.bincount(
            torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
            torch.tensor([.1, .2, .3, .4, .5], device=device))
        self.assertEqual(
            torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
        byte_counts = torch.bincount(
            torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
        self.assertEqual(
            torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), byte_counts)
        # test non-contiguous inputs and weights
        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device, dtype=torch.int32)
        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
        for i in [0, 1]:
            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
        # inputs are non-contiguous but weights are contiguous
        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
        # inputs and weights are non-contiguous
        self.assertEqual(
            inputs[:, 1].bincount(weights[:, 1]),
            torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
        # weights are non-contiguous but inputs are contiguous
        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
                         torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))

        # test bincount on non-contiguous slices
        all0s = torch.zeros((32, 2), dtype=torch.int32, device=device)
        self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))

        all1s = torch.ones((32, 2), dtype=torch.int32, device=device)
        self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))

        # test large number of bins - global memory use
        big_exp = torch.zeros(100, device=device)
        big_exp[-1] = 50.0
        big_w = torch.tensor([.5] * 100, device=device)
        big_out = torch.tensor([99] * 100, device=device, dtype=torch.int32).bincount(big_w)
        self.assertEqual(big_exp, big_out)
        # test large input size
        big_exp = torch.zeros(2, device=device, dtype=torch.int64)
        big_exp[1] = 10
        big_out = torch.ones(10, dtype=torch.int8, device=device).bincount()
        self.assertEqual(big_exp, big_out)

    def test_bincount(self):
        device = "mps"
        input_size = (5000,)
        w = torch.randn(input_size, dtype=torch.float, device=device)
        w_cpu = w.cpu()

        t = torch.randint(50, input_size, dtype=torch.int8, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.randint(500, input_size, dtype=torch.int32, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.randint(2000, input_size, dtype=torch.int32, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.zeros([10], dtype=torch.int32, device=device)
        t[0] = 35488
        counted = t.bincount(minlength=65536)
        self.assertEqual(torch.sum(counted), 10)

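    # bincount output has max(input) + 1 entries (or minlength, if larger); with
    # weights it accumulates the weights per bin instead of raw counts, e.g.
    # torch.bincount(torch.tensor([0, 1, 1]), torch.tensor([0.5, 1.0, 2.0]))
    # -> tensor([0.5000, 3.0000]).
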
    def test_sum_backward(self):
        def helper(n, c):
            values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
            cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_sum = torch.sum(x)
            all_sum_cpu = torch.sum(cpu_x)

            all_sum.backward()
            all_sum_cpu.backward()
            self.assertEqual(all_sum, all_sum_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper(3, 3)

    # L1 loss
    def test_l1_loss(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.L1Loss(reduction=reduction)

            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 2 just to make the grad_output > 1 in backward pass
                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify if changes in shape would cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')

    # Mean Squared Error
    def test_mse_loss(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.MSELoss(reduction=reduction)

            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 2 just to make the grad_output > 1 in backward pass
                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify if changes in shape would cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')

    def test_mse_loss_strided_output(self):
        # https://github.com/pytorch/pytorch/issues/124621
        lf = nn.MSELoss(reduction='none')
        model_cpu = nn.Sequential(
            nn.Conv1d(3, 3, 1),
        )
        model_mps = copy.deepcopy(model_cpu).to("mps")

        x = torch.randn(128, 10, 3)
        x = x.permute(0, 2, 1)

        x_mps = x.detach().clone().to("mps").permute(0, 2, 1)
        x_mps = x_mps.permute(0, 2, 1)

        y = model_cpu(x)
        y_mps = model_mps(x_mps)

        y = y.permute(0, 2, 1)[:, :5, :]
        y_mps = y_mps.permute(0, 2, 1)[:, :5, :]

        y_hat = torch.randn(128, 5, 3)
        y_hat_mps = y_hat.detach().clone().to("mps")

        loss = lf(y, y_hat)
        loss_mps = lf(y_mps, y_hat_mps)
        self.assertEqual(loss, loss_mps)

    # Binary Cross Entropy
    def test_bce_loss_simple(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.BCELoss(reduction=reduction)

            # input and target must be within [0..1]
            input_t = np.random.random_sample(size=shape).astype(np.float32)
            target_t = np.random.random_sample(size=shape).astype(np.float32)
            inputCPU = torch.tensor(input_t, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.tensor(target_t, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 0.6 just to have the grad_output != 1
                outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify if changes in shape would cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')
        helper([1, 1, 32, 32], 'mean')

    def test_bce_loss_always_nonnegative(self):
        target = torch.ones(5, device='mps')
        input = torch.ones(5, device='mps')
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

        target = torch.zeros(5, device='mps')
        input = torch.zeros(5, device='mps')
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

    def test_bce_loss_size_mismatch(self):
        bceloss = nn.BCELoss()
        a = torch.rand(25, device='mps')
        b = torch.rand(25, 1, device='mps')
        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
            bceloss(a, b)

    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
        x_size = 1024
        y_size = 256
        target = torch.rand(x_size, y_size, device='mps')

        for reduction in ['none', 'mean', 'sum']:
            output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
            output_logits = output_sig.clone().detach()

            output_sig.requires_grad = True
            output_logits.requires_grad = True
            weight = torch.rand(y_size, device='mps')

            loss_sig = nn.BCELoss(weight, reduction=reduction)(
                torch.sigmoid(output_sig), target
            )
            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
                output_logits, target
            )

            self.assertEqual(loss_logits, loss_sig)

            if reduction == 'none':
                grad = torch.rand(x_size, y_size, device='mps')
                loss_sig.backward(grad)
                loss_logits.backward(grad)
            else:
                loss_sig.backward()
                loss_logits.backward()

            self.assertEqual(output_sig.grad, output_logits.grad)

    def test_bce_with_logits_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True, device='mps')
        target = torch.zeros(3, 1, device='mps')
        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
        self.assertEqual(output.grad, expected_grad)

    def test_bce_with_logits_broadcasts_weights(self):
        target = torch.rand(16, 4, device='mps')
        output = torch.rand(16, 4, device='mps') - 0.5

        weight = torch.rand(4, device='mps')
        out1 = nn.BCEWithLogitsLoss(weight)(output, target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCEWithLogitsLoss(weight)(output, target)

        self.assertEqual(out1, out2)

        weight = torch.rand(16, 1, device='mps')
        out1 = nn.BCEWithLogitsLoss(weight)(output, target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCEWithLogitsLoss(weight)(output, target)

        self.assertEqual(out1, out2)

    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
        target = torch.rand(64, 4, device='mps')
        output = torch.rand(64, 4, device='mps') - 0.5
        pos_weight = torch.ones(64, 4, device='mps')

        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))

    def test_bce_with_logits_broadcasts_pos_weights(self):
        target = torch.rand(64, 4, device='mps')
        output = torch.rand(64, 4, device='mps') - 0.5
        pos_weight = torch.rand(4, device='mps')
        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)

        pos_weight1 = pos_weight.expand(1, 4)
        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)

        pos_weight2 = pos_weight.expand(64, 4)
        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)

        self.assertEqual(out1, out2)
        self.assertEqual(out1, out3)

    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True, device='mps')
        target = torch.zeros(3, 1, device='mps')
        pos_weight = torch.ones(3, 1, device='mps')
        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
        grad = output.grad
        self.assertEqual(grad, expected_grad)

    def test_bce_with_logits_stability(self):
        output = torch.tensor([0., -120.], device='mps')
        target = torch.tensor([0., 1.], device='mps')
        pos_weight = torch.tensor([1., 1.], device='mps')

        out1 = nn.BCEWithLogitsLoss()(output, target)
        self.assertTrue(torch.isfinite(out1).all().item())

        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
        self.assertTrue(torch.isfinite(out2).all().item())

    def test_bce_loss_broadcasts_weights(self):
        sigmoid = nn.Sigmoid()
        target = torch.rand(16, 4, device='mps')
        output = torch.rand(16, 4, device='mps') - 0.5

        weight = torch.rand(4, device='mps')
        out1 = nn.BCELoss(weight)(sigmoid(output), target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCELoss(weight)(sigmoid(output), target)

        self.assertEqual(out1, out2)

        weight = torch.rand(16, 1, device='mps')
        out1 = nn.BCELoss(weight)(sigmoid(output), target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCELoss(weight)(sigmoid(output), target)

        self.assertEqual(out1, out2)

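    # BCEWithLogitsLoss fuses the sigmoid into the binary cross entropy using a
    # numerically stable log-sum-exp formulation, which is why the stability test
    # above stays finite at a logit of -120, where a separate float32 sigmoid()
    # would underflow to 0 and log(0) would give -inf.
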
    def test_cross_entropy_loss(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/116095
        loss = nn.CrossEntropyLoss()
        pred = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
        target = torch.ones(3, dtype=torch.long, device='mps')
        output = loss(pred, target)
        output.backward()

    def test_log_softmax(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
        mps_x = torch.tensor(values, device='mps', requires_grad=True)

        cpu_log_softmax = F.log_softmax(cpu_x, dim=0)
        mps_log_softmax = F.log_softmax(mps_x, dim=0)
        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))

        cpu_grad = torch.ones_like(cpu_log_softmax)
        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')

        cpu_log_softmax.backward(gradient=cpu_grad)
        mps_log_softmax.backward(gradient=mps_grad)

        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))

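    # A numerically stable log_softmax is computed in the shifted form
    # x - max(x) - log(sum(exp(x - max(x)))), so the extreme magnitudes in the next
    # test should not overflow the way a naive log(softmax(x)) would.
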
    def test_log_softmax_large_numbers(self):
        values = [
            [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0],
            [-10.0, -100.0, -1000.0, -10000.0, -100000.0, -1000000.0]
        ]
        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
        mps_x = torch.tensor(values, device='mps', requires_grad=True)

        cpu_log_softmax = F.log_softmax(cpu_x, dim=-1)
        mps_log_softmax = F.log_softmax(mps_x, dim=-1)
        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))

        cpu_grad = torch.ones_like(cpu_log_softmax)
        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')

        cpu_log_softmax.backward(gradient=cpu_grad)
        mps_log_softmax.backward(gradient=mps_grad)

        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))

    def test_eq(self):
        values1 = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        values2 = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]
        mps_x = torch.tensor(values1, device='mps')
        mps_y = torch.tensor(values2, device='mps')
        cpu_x = torch.tensor(values1, device='cpu')
        cpu_y = torch.tensor(values2, device='cpu')
        result_mps = torch.eq(mps_x, mps_y)
        result_cpu = torch.eq(cpu_x, cpu_y)

        self.assertEqual(result_cpu, result_mps.to('cpu'))

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    def test_signed_vs_unsigned_comparison(self):
        cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8)
        mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
        # in the comparison of signed vs. unsigned we should always cast to unsigned
        self.assertEqual(cpu_x == -1, mps_x == -1)
        self.assertEqual(cpu_x > -1, mps_x > -1)
        self.assertEqual(cpu_x < -1, mps_x < -1)

Kulin Sethe011a8e2022-05-13 18:28:53 +00004948 def test_eq_int64(self):
4949 values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
4950 values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]
4951 mps_x = torch.tensor(values1, device='mps')
4952 mps_y = torch.tensor(values2, device='mps')
4953 cpu_x = torch.tensor(values1, device='cpu')
4954 cpu_y = torch.tensor(values2, device='cpu')
4955 result_mps = torch.eq(mps_x, mps_y)
4956 result_cpu = torch.eq(cpu_x, cpu_y)
4957
4958 self.assertEqual(result_cpu, result_mps.to('cpu'))
4959
4960 def test_ne(self):
4961 def helper(shape):
4962 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
4963 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
4964 mps_x = cpu_x.detach().clone().to('mps')
4965 mps_y = cpu_y.detach().clone().to('mps')
4966 result_mps = torch.ne(mps_x, mps_y)
4967 result_cpu = torch.ne(cpu_x, cpu_y)
4968
4969 self.assertEqual(result_cpu, result_mps.to('cpu'))
4970
4971 helper((2, 3, 4, 5))
4972
4973 def test_ne_scalar(self):
4974 def helper(shape):
4975 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
4976 mps_x = cpu_x.detach().clone().to('mps')
4977 result_mps = torch.ne(mps_x, 0.0)
4978 result_cpu = torch.ne(cpu_x, 0.0)
4979
4980 self.assertEqual(result_cpu, result_mps.to('cpu'))
4981
4982 helper((2, 3, 4, 5))
4983
4984 def test_lt(self):
4985 def helper(shape):
4986 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
4987 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
4988 mps_x = cpu_x.detach().clone().to('mps')
4989 mps_y = cpu_y.detach().clone().to('mps')
4990 result_mps = torch.lt(mps_x, mps_y)
4991 result_cpu = torch.lt(cpu_x, cpu_y)
4992
4993 self.assertEqual(result_cpu, result_mps.to('cpu'))
4994
4995 helper((2, 3, 4, 5))
4996
4997 def test_lt_scalar(self):
4998 def helper(shape):
4999 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5000 mps_x = cpu_x.detach().clone().to('mps')
5001 result_mps = torch.lt(mps_x, 0.0)
5002 result_cpu = torch.lt(cpu_x, 0.0)
5003
5004 self.assertEqual(result_cpu, result_mps.to('cpu'))
5005
5006 helper((2, 3, 4, 5))
5007
5008 def test_le(self):
5009 def helper(shape):
5010 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5011 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5012 mps_x = cpu_x.detach().clone().to('mps')
5013 mps_y = cpu_y.detach().clone().to('mps')
5014 result_mps = torch.le(mps_x, mps_y)
5015 result_cpu = torch.le(cpu_x, cpu_y)
5016
5017 self.assertEqual(result_cpu, result_mps.to('cpu'))
5018
5019 helper((2, 3, 4, 5))
5020
5021 def test_le_scalar(self):
5022 def helper(shape):
5023 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5024 mps_x = cpu_x.detach().clone().to('mps')
5025 result_mps = torch.le(mps_x, 0.0)
5026 result_cpu = torch.le(cpu_x, 0.0)
5027
5028 self.assertEqual(result_cpu, result_mps.to('cpu'))
5029
5030 helper((2, 3, 4, 5))
5031
5032 def test_ge(self):
5033 def helper(shape):
5034 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5035 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5036 mps_x = cpu_x.detach().clone().to('mps')
5037 mps_y = cpu_y.detach().clone().to('mps')
5038 result_mps = torch.ge(mps_x, mps_y)
5039 result_cpu = torch.ge(cpu_x, cpu_y)
5040
5041 self.assertEqual(result_cpu, result_mps.to('cpu'))
5042
5043 helper((2, 3, 4, 5))
5044
5045 def test_ge_scalar(self):
5046 def helper(shape):
5047 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5048 mps_x = cpu_x.detach().clone().to('mps')
5049 result_mps = torch.ge(mps_x, 0.0)
5050 result_cpu = torch.ge(cpu_x, 0.0)
5051
5052 self.assertEqual(result_cpu, result_mps.to('cpu'))
5053
5054 helper((2, 3, 4, 5))
5055
5056 def test_gt(self):
5057 def helper(shape):
5058 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5059 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5060 mps_x = cpu_x.detach().clone().to('mps')
5061 mps_y = cpu_y.detach().clone().to('mps')
5062 result_mps = torch.gt(mps_x, mps_y)
5063 result_cpu = torch.gt(cpu_x, cpu_y)
5064
5065 self.assertEqual(result_cpu, result_mps.to('cpu'))
5066
5067 helper((2, 3, 4, 5))
5068
5069 def test_gt_scalar(self):
5070 def helper(shape):
5071 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5072 mps_x = cpu_x.detach().clone().to('mps')
5073 result_mps = torch.gt(mps_x, 0.0)
5074 result_cpu = torch.gt(cpu_x, 0.0)
5075
5076 self.assertEqual(result_cpu, result_mps.to('cpu'))
5077
5078 helper((2, 3, 4, 5))
5079
    def test_argmax(self):
        # https://github.com/pytorch/pytorch/issues/98191
        cpu_tensor = torch.tensor([[0, 1], [2, 1], [1, 0]])
        res_cpu = torch.argmax(cpu_tensor, dim=1)

        mps_tensor = cpu_tensor.to(torch.device('mps'))
        res_mps = torch.argmax(mps_tensor, dim=1)
        self.assertEqual(res_cpu, res_mps)

        # https://github.com/pytorch/pytorch/issues/92311
        mps_tensor = torch.randn(10, 2, device='mps', dtype=torch.float32)
        cpu_tensor = mps_tensor.detach().clone().cpu()

        res_mps = torch.argmax(mps_tensor, dim=1)
        res_cpu = torch.argmax(cpu_tensor, dim=1)
        self.assertEqual(res_cpu, res_mps)

    # Test forward argmin argmax
    def test_argmin_argmax(self):
        def helper(n, c, h, w, reduction_type, dtype=torch.float32):
            if reduction_type == "max":
                arg_reduction_fn = torch.argmax
            else:
                arg_reduction_fn = torch.argmin

            cpu_x = None
            x = None
            if (dtype not in [torch.float32, torch.bool]):
                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif (dtype == torch.bool):
                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

            y = arg_reduction_fn(x)
            ref_y = arg_reduction_fn(cpu_x)
            self.assertEqual(y, ref_y)

            y_0 = arg_reduction_fn(x, dim=0)
            refy_0 = arg_reduction_fn(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)

            y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
            refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)

            y_1 = arg_reduction_fn(x, dim=1)
            refy_1 = arg_reduction_fn(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)

            y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
            refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
            self.assertEqual(y_1dim, refy_1dim)

            y_2 = arg_reduction_fn(x, dim=2)
            refy_2 = arg_reduction_fn(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)

            y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
            refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)

            y_3 = arg_reduction_fn(x, dim=3)
            refy_3 = arg_reduction_fn(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)

            y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
            refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)

        helper(2, 8, 4, 4, "max", torch.float32)
        helper(2, 8, 4, 4, "max", torch.int32)
        helper(2, 8, 4, 4, "max", torch.float16)
        helper(2, 8, 4, 4, "max", torch.int64)
        helper(2, 8, 4, 4, "min", torch.float32)
        helper(2, 8, 4, 4, "min", torch.int32)
        helper(2, 8, 4, 4, "min", torch.float16)
        helper(2, 8, 4, 4, "min", torch.int64)

Denis Vieriud0dd8982023-03-02 12:44:59 +00005162 @unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
5163 def test_reduction_sum_max_long_val(self):
5164 x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")
5165 x_cpu = x_mps.detach().clone().cpu()
5166
5167 res_mps = torch.sum(x_mps)
5168 res_cpu = torch.sum(x_cpu)
5169 self.assertEqual(res_mps, res_cpu)
5170
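# The values above sum past the int64 range; the check effectively pins the
# MPS result to the CPU's wrap-around answer, so an accumulation done in a
# narrower type (e.g. int32) would be caught here.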
Kulin Sethe011a8e2022-05-13 18:28:53 +00005171 # Test forward max
5172 # Note: gradients are not tested here
5173 def test_max_el(self):
5174 def helper(n, c, h, w, dtype=torch.float32):
5175
Thomas4935b592022-11-23 02:18:03 +00005176 if (dtype not in [torch.float32, torch.bool]):
Kulin Sethe011a8e2022-05-13 18:28:53 +00005177 cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5178 x = cpu_x.detach().clone().to('mps')
5179 elif (dtype == torch.bool):
5180 cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5181 x = cpu_x.detach().clone().to('mps')
5182 else:
5183 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
5184 x = cpu_x.detach().clone().to('mps')
5185
5186 ref_y = torch.max(cpu_x)
5187 y = torch.max(x)
5188 self.assertEqual(y, ref_y)
5189
5190 for dim in [0, 1, 2, 3]:
5191 for keepdim in [True, False]:
5192 y, idx = torch.max(x, dim=dim, keepdim=keepdim)
5193 refy, refidx = torch.max(cpu_x, dim=dim, keepdim=keepdim)
5194 self.assertEqual(y, refy)
5195 self.assertEqual(idx, refidx)
5196
5197 y_0 = torch.ones(c, h, w, device='mps', dtype=dtype)
5198 idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
5199 torch.max(x, dim=0, out=(y_0, idx_0))
5200 refy_0, refidx_0 = torch.max(cpu_x, dim=0)
5201 self.assertEqual(y_0, refy_0)
5202 self.assertEqual(idx_0, refidx_0)
5203
5204 y_0dim = torch.ones(1, c, h, w, device='mps', dtype=dtype)
5205 idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
5206 torch.max(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
5207 refy_0dim, refidx_0dim = torch.max(cpu_x, dim=0, keepdim=True)
5208 self.assertEqual(y_0dim, refy_0dim)
5209 self.assertEqual(idx_0dim, refidx_0dim)
5210
5211 y_1 = torch.ones(n, h, w, device='mps', dtype=dtype)
5212 idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
5213 torch.max(x, dim=1, out=(y_1, idx_1))
5214 refy_1, refidx_1 = torch.max(cpu_x, dim=1)
5215 self.assertEqual(y_1, refy_1)
5216 self.assertEqual(idx_1, refidx_1)
5217
5218 y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=dtype)
5219 idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
5220 torch.max(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
5221 refy_1dim, refidx_1dim = torch.max(cpu_x, dim=1, keepdim=True)
5222 self.assertEqual(y_1dim, refy_1dim)
5223 self.assertEqual(idx_1dim, refidx_1dim)
5224
5225 y_2 = torch.ones(n, c, w, device='mps', dtype=dtype)
5226 idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
5227 torch.max(x, dim=2, out=(y_2, idx_2))
5228 refy_2, refidx_2 = torch.max(cpu_x, dim=2)
5229 self.assertEqual(y_2, refy_2)
5230 self.assertEqual(idx_2, refidx_2)
5231
5232 y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=dtype)
5233 idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
5234 torch.max(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
5235 refy_2dim, refidx_2dim = torch.max(cpu_x, dim=2, keepdim=True)
5236 self.assertEqual(y_2dim, refy_2dim)
5237 self.assertEqual(idx_2dim, refidx_2dim)
5238
5239 y_3 = torch.ones(n, c, h, device='mps', dtype=dtype)
5240 idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
5241 torch.max(x, dim=3, out=(y_3, idx_3))
5242 refy_3, refidx_3 = torch.max(cpu_x, dim=3)
5243 self.assertEqual(y_3, refy_3)
5244 self.assertEqual(idx_3, refidx_3)
5245
5246 y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=dtype)
5247 idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
5248 torch.max(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
5249 refy_3dim, refidx_3dim = torch.max(cpu_x, dim=3, keepdim=True)
5250 self.assertEqual(y_3dim, refy_3dim)
5251 self.assertEqual(idx_3dim, refidx_3dim)
5252
5253 helper(2, 8, 4, 5, torch.float32)
5254 helper(2, 8, 4, 5, torch.int32)
5255 # helper(2, 8, 4, 5, torch.int64)
5256
Raman kumarfd0efb02022-11-18 02:53:39 +00005257 def test_median(self):
5258 def helper_dtype_int32(n1, n2, n3):
5259 cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32)
5260 mps_x = cpu_x.detach().clone().to('mps')
5261
5262 result_cpu = torch.median(cpu_x)
5263 result_mps = torch.median(mps_x)
5264
5265 self.assertEqual(result_cpu, result_mps)
5266
5267 for dim in [0, 1, 2]:
5268 for keepdim in [True, False]:
5269 y, idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
5270 refy, refidx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
5271 self.assertEqual(y, refy)
5272 self.assertEqual(idx, refidx)
5273
5274 def helper_dtype_float32(n1, n2, n3):
5275 cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32)
5276 mps_x = cpu_x.detach().clone().to('mps')
5277
5278 result_cpu = torch.median(cpu_x)
5279 result_mps = torch.median(mps_x)
5280
5281 self.assertEqual(result_cpu, result_mps)
5282
5283 for dim in [0, 1, 2]:
5284 for keepdim in [True, False]:
5285 y, idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
5286 refy, refidx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
5287 self.assertEqual(y, refy)
5288 self.assertEqual(idx, refidx)
5289
5290 helper_dtype_int32(10, 10, 10) # even number of elements
5291 helper_dtype_int32(3, 3, 3) # odd number of elements
5292 helper_dtype_int32(1, 1, 1)
5293 helper_dtype_int32(1, 2, 3)
5294 helper_dtype_float32(10, 10, 10)
5295 helper_dtype_float32(3, 3, 3)
5296 helper_dtype_float32(1, 1, 1)
5297
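# For an even number of elements torch.median returns the lower of the two
# middle values rather than their mean, so an exact assertEqual against the
# CPU result is well-defined for both the int32 and float32 helpers above.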
Kulin Sethe011a8e2022-05-13 18:28:53 +00005298 def test_any(self):
5299 def helper(shape):
5300 input_xs = []
5301 prod = 1
5302
5303 for i in range(len(shape)):
5304 prod *= shape[i]
5305 input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
5306 input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
5307 input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
5308 input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
5309 input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
5310 input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
5311 input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
5312 input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
5313 input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
5314 input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())
5315
5316 for i, cpu_x in enumerate(input_xs):
5317 x = cpu_x.detach().clone().to('mps')
5318 y = torch.any(x)
5319 ref_y = torch.any(cpu_x)
5320 self.assertEqual(y, ref_y)
5321
5322 y_0 = torch.any(x, dim=0)
5323 refy_0 = torch.any(cpu_x, dim=0)
5324 self.assertEqual(y_0, refy_0)
5325
5326 y_0dim = torch.any(x, dim=0, keepdim=True)
5327 refy_0dim = torch.any(cpu_x, dim=0, keepdim=True)
5328 self.assertEqual(y_0dim, refy_0dim)
5329
5334 y_1 = torch.any(x, dim=1)
5335 refy_1 = torch.any(cpu_x, dim=1)
5336 self.assertEqual(y_1, refy_1)
5337
5338 y_1dim = torch.any(x, dim=1, keepdim=True)
5339 refy_1dim = torch.any(cpu_x, dim=1, keepdim=True)
5340 self.assertEqual(y_1dim, refy_1dim)
5341
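# The dim=2/dim=3 checks below assume a 4-D input; the helper calls at the
# bottom of this test only pass 2-D and 4-D shapes, so a 3-D shape (which
# would fail at dim=3) never reaches this branch.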
5342 if (len(shape) > 2):
5343 y_2 = torch.any(x, dim=2)
5344 refy_2 = torch.any(cpu_x, dim=2)
5345 self.assertEqual(y_2, refy_2)
5346
5347 y_2dim = torch.any(x, dim=2, keepdim=True)
5348 refy_2dim = torch.any(cpu_x, dim=2, keepdim=True)
5349 self.assertEqual(y_2dim, refy_2dim)
5350
5351 y_3 = torch.any(x, dim=3)
5352 refy_3 = torch.any(cpu_x, dim=3)
5353 self.assertEqual(y_3, refy_3)
5354
5355 y_3dim = torch.any(x, dim=3, keepdim=True)
5356 refy_3dim = torch.any(cpu_x, dim=3, keepdim=True)
5357 self.assertEqual(y_3dim, refy_3dim)
5358 helper((1, 1, 1, 1))
5359 helper((1, 1, 3, 3))
5360 helper((7, 13))
5361 helper((2, 8, 4, 5))
5362
Kulin Sethe20c94b2023-05-05 22:57:06 +00005363 @unittest.skip("Test is crashing")
5364 def test_reduction_ops_5D(self):
5365 def helper(fn, dim):
5366 x_cpu = fn(torch.zeros(1, 1, 1, 1, 1), dim=dim)
5367 x_mps = fn(torch.zeros(1, 1, 1, 1, 1, device="mps"), dim=dim)
5368 self.assertEqual(x_cpu, x_mps.to('cpu'))
5369 for fn in [torch.any]:
5370 for dim in range(0, 4):
5371 helper(fn, dim)
5372
Kulin Sethe011a8e2022-05-13 18:28:53 +00005373 def test_all(self):
5374 def helper(shape):
5375 input_xs = []
5376 prod = 1
5377
5378 for i in range(len(shape)):
5379 prod *= shape[i]
5380 input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
5381 input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
5382 input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
5383 input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
5384 input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
5385 input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
5386 input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
5387 input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
5388 input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
5389 input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())
5390
5391 for i, cpu_x in enumerate(input_xs):
5392 x = cpu_x.detach().clone().to('mps')
5393 y = torch.all(x)
5394 ref_y = torch.all(cpu_x)
5395 self.assertEqual(y, ref_y)
5396
5397 y_0 = torch.all(x, dim=0)
5398 refy_0 = torch.all(cpu_x, dim=0)
5399 self.assertEqual(y_0, refy_0)
5400
5401 y_0dim = torch.all(x, dim=0, keepdim=True)
5402 refy_0dim = torch.all(cpu_x, dim=0, keepdim=True)
5403 self.assertEqual(y_0dim, refy_0dim)
5404
5409 y_1 = torch.all(x, dim=1)
5410 refy_1 = torch.all(cpu_x, dim=1)
5411 self.assertEqual(y_1, refy_1)
5412
5413 y_1dim = torch.all(x, dim=1, keepdim=True)
5414 refy_1dim = torch.all(cpu_x, dim=1, keepdim=True)
5415 self.assertEqual(y_1dim, refy_1dim)
5416 if (len(shape) > 2):
5417 y_2 = torch.all(x, dim=2)
5418 refy_2 = torch.all(cpu_x, dim=2)
5419 self.assertEqual(y_2, refy_2)
5420
5421 y_2dim = torch.all(x, dim=2, keepdim=True)
5422 refy_2dim = torch.all(cpu_x, dim=2, keepdim=True)
5423 self.assertEqual(y_2dim, refy_2dim)
5424
5425 y_3 = torch.all(x, dim=3)
5426 refy_3 = torch.all(cpu_x, dim=3)
5427 self.assertEqual(y_3, refy_3)
5428
5429 y_3dim = torch.all(x, dim=3, keepdim=True)
5430 refy_3dim = torch.all(cpu_x, dim=3, keepdim=True)
5431 self.assertEqual(y_3dim, refy_3dim)
5432
5433 helper((1, 1, 1, 1))
5434 helper((1, 1, 3, 3))
5435 helper((7, 13))
5436 helper((2, 8, 4, 5))
David Radley17250972023-07-14 17:42:51 +00005437 x_cpu = torch.tensor([], dtype=torch.bool)
5438 x_mps = x_cpu.to("mps")
5439 self.assertEqual(x_cpu.all(), x_mps.all().cpu())
Kulin Sethe011a8e2022-05-13 18:28:53 +00005440
5441 # Test forward min
5442 def test_min_el(self):
5443 def helper(n, c, h, w):
5444 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5445 x = cpu_x.detach().clone().to('mps')
5446
5447 y = torch.min(x)
5448 ref_y = torch.min(cpu_x)
5449 self.assertEqual(y, ref_y)
5450
5451 y_0, idx_0 = torch.min(x, dim=0)
5452 refy_0, refidx_0 = torch.min(cpu_x, dim=0)
5453 self.assertEqual(y_0, refy_0)
5454 self.assertEqual(idx_0, refidx_0)
5455
5456 y_0 = torch.ones(c, h, w, device='mps', dtype=torch.float)
5457 idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
5458 torch.min(x, dim=0, out=(y_0, idx_0))
5459 refy_0, refidx_0 = torch.min(cpu_x, dim=0)
5460 self.assertEqual(y_0, refy_0)
5461 self.assertEqual(idx_0, refidx_0)
5462
5463 y_0dim, idx_0dim = torch.min(x, dim=0, keepdim=True)
5464 refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
5465 self.assertEqual(y_0dim, refy_0dim)
5466 self.assertEqual(idx_0dim, refidx_0dim)
5467
5468 y_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.float)
5469 idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
5470 torch.min(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
5471 refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
5472 self.assertEqual(y_0dim, refy_0dim)
5473 self.assertEqual(idx_0dim, refidx_0dim)
5474
5475 y_1, idx_1 = torch.min(x, dim=1)
5476 refy_1, refidx_1 = torch.min(cpu_x, dim=1)
5477 self.assertEqual(y_1, refy_1)
5478 self.assertEqual(idx_1, refidx_1)
5479
5480 y_1 = torch.ones(n, h, w, device='mps', dtype=torch.float)
5481 idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
5482 torch.min(x, dim=1, out=(y_1, idx_1))
5483 refy_1, refidx_1 = torch.min(cpu_x, dim=1)
5484 self.assertEqual(y_1, refy_1)
5485 self.assertEqual(idx_1, refidx_1)
5486
5487 y_1dim, idx_1dim = torch.min(x, dim=1, keepdim=True)
5488 refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
5489 self.assertEqual(y_1dim, refy_1dim)
5490 self.assertEqual(idx_1dim, refidx_1dim)
5491
5492 y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.float)
5493 idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
5494 torch.min(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
5495 refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
5496 self.assertEqual(y_1dim, refy_1dim)
5497 self.assertEqual(idx_1dim, refidx_1dim)
5498
5499 y_2, idx_2 = torch.min(x, dim=2)
5500 refy_2, refidx_2 = torch.min(cpu_x, dim=2)
5501 self.assertEqual(y_2, refy_2)
5502 self.assertEqual(idx_2, refidx_2)
5503
5504 y_2 = torch.ones(n, c, w, device='mps', dtype=torch.float)
5505 idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
5506 torch.min(x, dim=2, out=(y_2, idx_2))
5507 refy_2, refidx_2 = torch.min(cpu_x, dim=2)
5508 self.assertEqual(y_2, refy_2)
5509 self.assertEqual(idx_2, refidx_2)
5510
5511 y_2dim, idx_2dim = torch.min(x, dim=2, keepdim=True)
5512 refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
5513 self.assertEqual(y_2dim, refy_2dim)
5514 self.assertEqual(idx_2dim, refidx_2dim)
5515
5516 y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.float)
5517 idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
5518 torch.min(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
5519 refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
5520 self.assertEqual(y_2dim, refy_2dim)
5521 self.assertEqual(idx_2dim, refidx_2dim)
5522
5523 y_3, idx_3 = torch.min(x, dim=3)
5524 refy_3, refidx_3 = torch.min(cpu_x, dim=3)
5525 self.assertEqual(y_3, refy_3)
5526 self.assertEqual(idx_3, refidx_3)
5527
5528 y_3 = torch.ones(n, c, h, device='mps', dtype=torch.float)
5529 idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
5530 torch.min(x, dim=3, out=(y_3, idx_3))
5531 refy_3, refidx_3 = torch.min(cpu_x, dim=3)
5532 self.assertEqual(y_3, refy_3)
5533 self.assertEqual(idx_3, refidx_3)
5534
5535 y_3dim, idx_3dim = torch.min(x, dim=3, keepdim=True)
5536 refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
5537 self.assertEqual(y_3dim, refy_3dim)
5538 self.assertEqual(idx_3dim, refidx_3dim)
5539
5540 y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.float)
5541 idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
5542 torch.min(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
5543 refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
5544 self.assertEqual(y_3dim, refy_3dim)
5545 self.assertEqual(idx_3dim, refidx_3dim)
5546
5547 helper(2, 8, 4, 5)
5548
5549 # Test forward sum
5550 def test_sum(self):
5551 def helper(n, c, h, w, dtype=torch.float32):
5552 cpu_x = None
5553 x = None
Thomas4935b592022-11-23 02:18:03 +00005554 if (dtype not in [torch.float32, torch.bool]):
Kulin Sethe011a8e2022-05-13 18:28:53 +00005555 cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5556 x = cpu_x.detach().clone().to('mps')
5557 elif (dtype == torch.bool):
5558 cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5559 x = cpu_x.detach().clone().to('mps')
5560 else:
5561 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
5562 x = cpu_x.detach().clone().to('mps').requires_grad_()
5563
5564 all_sum = torch.sum(x)
5565 all_sum_cpu = torch.sum(cpu_x)
5566
5567 self.assertEqual(all_sum, all_sum_cpu)
5568
5569 nil_dim_sum = torch.sum(x, dim=[])
5570 nil_dim_sum_cpu = torch.sum(cpu_x, dim=[])
5571
5572 self.assertEqual(nil_dim_sum, nil_dim_sum_cpu)
5573
5574 nil_dim_sum_keepdim = torch.sum(x, dim=[], keepdim=True)
5575 nil_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[], keepdim=True)
5576
5577 self.assertEqual(nil_dim_sum_keepdim, nil_dim_sum_cpu_keepdim)
5578
5579 zero_dim_sum = torch.sum(x, dim=[0])
5580 zero_dim_sum_cpu = torch.sum(cpu_x, dim=[0])
5581
5582 self.assertEqual(zero_dim_sum, zero_dim_sum_cpu)
5583
5584 zero_dim_sum_keepdim = torch.sum(x, dim=[0], keepdim=True)
5585 zero_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0], keepdim=True)
5586
5587 self.assertEqual(zero_dim_sum_keepdim, zero_dim_sum_cpu_keepdim)
5588
5589 zero_one_dim_sum = torch.sum(x, dim=[0, 1])
5590 zero_one_dim_sum_cpu = torch.sum(cpu_x, dim=[0, 1])
5591
5592 self.assertEqual(zero_one_dim_sum, zero_one_dim_sum_cpu)
5593
5594 zero_one_dim_sum_keepdim = torch.sum(x, dim=[0, 1], keepdim=True)
5595 zero_one_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0, 1], keepdim=True)
5596
5597 self.assertEqual(zero_one_dim_sum_keepdim, zero_one_dim_sum_cpu_keepdim)
5598
5599 two_three_dim_sum = torch.sum(x, dim=[2, 3])
5600 two_three_dim_sum_cpu = torch.sum(cpu_x, dim=[2, 3])
5601
5602 self.assertEqual(two_three_dim_sum, two_three_dim_sum_cpu)
5603
5604 two_three_keepdim_sum = torch.sum(x, dim=[2, 3], keepdim=True)
5605 two_three_dim_keepsum_cpu = torch.sum(cpu_x, dim=[2, 3], keepdim=True)
5606
5607 self.assertEqual(two_three_keepdim_sum, two_three_dim_keepsum_cpu)
5608
5609 helper(2, 8, 4, 5)
5610 helper(2, 8, 4, 5, dtype=torch.int32)
5611 helper(2, 8, 4, 5, dtype=torch.int64)
5612 helper(2, 8, 4, 5, dtype=torch.bool)
5613
5614 # Test forward prod
5615 def test_prod(self):
5616 def helper(shape, dtype=torch.float32):
5617 cpu_x = None
5618 x = None
Thomas4935b592022-11-23 02:18:03 +00005619 if (dtype not in [torch.float32, torch.bool]):
Kulin Sethe011a8e2022-05-13 18:28:53 +00005620 cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
5621 x = cpu_x.detach().clone().to('mps')
5622 elif (dtype == torch.bool):
5623 cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
5624 x = cpu_x.detach().clone().to('mps')
5625 else:
5626 cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
5627 x = cpu_x.detach().clone().to('mps').requires_grad_()
5628
5629 all_prod = torch.prod(x)
5630 all_prod_cpu = torch.prod(cpu_x)
5631
5632 self.assertEqual(all_prod, all_prod_cpu)
5633
5634 for dim in range(len(shape)):
5635 dim_prod = torch.prod(x, dim=dim)
5636 dim_prod_cpu = torch.prod(cpu_x, dim=dim)
5637
5638 self.assertEqual(dim_prod, dim_prod_cpu)
5639
5640 dim_prod_keepdim = torch.prod(x, dim=dim, keepdim=True)
5641 dim_prod_cpu_keepdim = torch.prod(cpu_x, dim=dim, keepdim=True)
5642
5643 self.assertEqual(dim_prod_keepdim, dim_prod_cpu_keepdim)
5644
5645 for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
5646 helper((2, 3), dtype)
5647
5648 # Test forward mean
5649 def test_mean(self):
5650 def helper(n, c, h, w):
5651 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
5652 x = cpu_x.detach().clone().to('mps').requires_grad_()
5653
5654 all_mean = torch.mean(x)
5655 all_mean_cpu = torch.mean(cpu_x)
5656
5657 self.assertEqual(all_mean, all_mean_cpu)
5658
5659 nil_dim_mean = torch.mean(x, dim=[])
5660 nil_dim_mean_cpu = torch.mean(cpu_x, dim=[])
5661
5662 self.assertEqual(nil_dim_mean, nil_dim_mean_cpu)
5663
5664 nil_dim_mean_keepdim = torch.mean(x, dim=[], keepdim=True)
5665 nil_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[], keepdim=True)
5666
5667 self.assertEqual(nil_dim_mean_keepdim, nil_dim_mean_cpu_keepdim)
5668
5669 zero_dim_mean = torch.mean(x, dim=[0])
5670 zero_dim_mean_cpu = torch.mean(cpu_x, dim=[0])
5671
5672 self.assertEqual(zero_dim_mean, zero_dim_mean_cpu)
5673
5674 zero_dim_mean_keepdim = torch.mean(x, dim=[0], keepdim=True)
5675 zero_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0], keepdim=True)
5676
5677 self.assertEqual(zero_dim_mean_keepdim, zero_dim_mean_cpu_keepdim)
5678
5679 zero_one_dim_mean = torch.mean(x, dim=[0, 1])
5680 zero_one_dim_mean_cpu = torch.mean(cpu_x, dim=[0, 1])
5681
5682 self.assertEqual(zero_one_dim_mean, zero_one_dim_mean_cpu)
5683
5684 zero_one_dim_mean_keepdim = torch.mean(x, dim=[0, 1], keepdim=True)
5685 zero_one_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0, 1], keepdim=True)
5686
5687 self.assertEqual(zero_one_dim_mean_keepdim, zero_one_dim_mean_cpu_keepdim)
5688
5689 two_three_dim_mean = torch.mean(x, dim=[2, 3])
5690 two_three_dim_mean_cpu = torch.mean(cpu_x, dim=[2, 3])
5691
5692 self.assertEqual(two_three_dim_mean, two_three_dim_mean_cpu)
5693
5694 two_three_keepdim_mean = torch.mean(x, dim=[2, 3], keepdim=True)
5695 two_three_dim_keepmean_cpu = torch.mean(cpu_x, dim=[2, 3], keepdim=True)
5696
5697 self.assertEqual(two_three_keepdim_mean, two_three_dim_keepmean_cpu)
5698
5699 helper(2, 8, 4, 5)
5700
5701 # Test std
5702 def test_std(self):
5703 def helper(shape):
5704 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
5705 x = cpu_x.detach().clone().to('mps')
5706
5707 all_std = torch.std(x, unbiased=False)
5708 all_std_cpu = torch.std(cpu_x, unbiased=False)
5709
5710 self.assertEqual(all_std, all_std_cpu)
5711
5712 nil_dim_std = torch.std(x, dim=[], unbiased=False)
5713 nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=False)
5714
5715 self.assertEqual(nil_dim_std, nil_dim_std_cpu)
5716
5717 nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=False)
5718 nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=False)
5719
5720 self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)
5721
5722 zero_dim_std = torch.std(x, dim=[0], unbiased=False)
5723 zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=False)
5724
5725 self.assertEqual(zero_dim_std, zero_dim_std_cpu)
5726
5727 zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=False)
5728 zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=False)
5729
5730 self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)
5731
5732 zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=False)
5733 zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=False)
5734
5735 self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)
5736
5737 zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=False)
5738 zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
5739
5740 self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)
5741
5742 two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=False)
5743 two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=False)
5744
5745 self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)
5746
5747 two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=False)
5748 two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
5749
5750 self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
5751
5752 all_std = torch.std(x, unbiased=True)
5753 all_std_cpu = torch.std(cpu_x, unbiased=True)
5754
5755 self.assertEqual(all_std, all_std_cpu)
5756
5757 nil_dim_std = torch.std(x, dim=[], unbiased=True)
5758 nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=True)
5759
5760 self.assertEqual(nil_dim_std, nil_dim_std_cpu)
5761
5762 nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=True)
5763 nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=True)
5764
5765 self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)
5766
5767 zero_dim_std = torch.std(x, dim=[0], unbiased=True)
5768 zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=True)
5769
5770 self.assertEqual(zero_dim_std, zero_dim_std_cpu)
5771
5772 zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=True)
5773 zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=True)
5774
5775 self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)
5776
5777 zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=True)
5778 zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=True)
5779
5780 self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)
5781
5782 zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=True)
5783 zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
5784
5785 self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)
5786
5787 two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=True)
5788 two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=True)
5789
5790 self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)
5791
5792 two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=True)
5793 two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
5794
5795 self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
5796
5797 helper((4, 5, 6, 7))
qqaatwae6f07e2022-06-30 12:56:55 +00005798 # verify that a change in input shape does not break graph caching
5799 helper((9, 5, 6, 7))
Kulin Sethe011a8e2022-05-13 18:28:53 +00005800
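# unbiased=False divides by N while unbiased=True applies Bessel's
# correction (N - 1); the blocks above run both variants over the same
# shapes so each MPS reduction path is compared against CPU.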
5801 # Test var
Abhishek Pathakf0570352022-09-25 19:03:58 +00005802 def test_var_simple(self):
5803 def helper():
5804
5805 shape = [2, 3, 4, 5]
5806
Kulin Sethe011a8e2022-05-13 18:28:53 +00005807 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
5808 x = cpu_x.detach().clone().to('mps')
5809
Abhishek Pathakf0570352022-09-25 19:03:58 +00005810 for unbiased in [False, True]:
5811 for keepdim in [False, True]:
Kulin Sethe011a8e2022-05-13 18:28:53 +00005812
Abhishek Pathakf0570352022-09-25 19:03:58 +00005813 zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
5814 zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005815
Abhishek Pathakf0570352022-09-25 19:03:58 +00005816 self.assertEqual(zero_dim_var, zero_dim_var_cpu)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005817
Abhishek Pathakf0570352022-09-25 19:03:58 +00005818 all_var = torch.var(x, unbiased=unbiased)
5819 all_var_cpu = torch.var(cpu_x, unbiased=unbiased)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005820
Abhishek Pathakf0570352022-09-25 19:03:58 +00005821 self.assertEqual(all_var, all_var_cpu)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005822
Abhishek Pathakf0570352022-09-25 19:03:58 +00005823 nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
5824 nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005825
Abhishek Pathakf0570352022-09-25 19:03:58 +00005826 self.assertEqual(nil_dim_var, nil_dim_var_cpu)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005827
Abhishek Pathakf0570352022-09-25 19:03:58 +00005828 zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
5829 zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005830
Abhishek Pathakf0570352022-09-25 19:03:58 +00005831 self.assertEqual(zero_dim_var, zero_dim_var_cpu)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005832
Abhishek Pathakf0570352022-09-25 19:03:58 +00005833 zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
5834 zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005835
Abhishek Pathakf0570352022-09-25 19:03:58 +00005836 self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005837
Abhishek Pathakf0570352022-09-25 19:03:58 +00005838 two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
5839 two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005840
Abhishek Pathakf0570352022-09-25 19:03:58 +00005841 self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005842
Abhishek Pathakf0570352022-09-25 19:03:58 +00005843 helper()
Kulin Sethe011a8e2022-05-13 18:28:53 +00005844
Abhishek Pathak074dc742022-06-18 00:14:05 +00005845 # Test forward amax
5846 def test_amax(self):
5847 def helper(shape, dim, keepdim):
5848 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
5849 x = cpu_x.detach().clone().to('mps').requires_grad_()
5850
5851 result = torch.amax(x, dim=dim, keepdim=keepdim)
5852 result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)
5853
5854 cpu_grad = torch.randn(result_cpu.shape)
5855 grad = cpu_grad.to('mps')
5856
5857 result_cpu.backward(gradient=cpu_grad)
5858 result.backward(gradient=grad)
5859
5860 self.assertEqual(result, result_cpu)
5861 self.assertEqual(x.grad, cpu_x.grad)
5862
5863 for dim in ([], [0], [0, 1], [2, 3]):
5864 for keepdim in [False, True]:
5865 helper((2, 8, 4, 5), dim, keepdim)
5866
5867 # Test forward amin
5868 def test_amin(self):
5869 def helper(shape, dim, keepdim):
5870 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
5871 x = cpu_x.detach().clone().to('mps').requires_grad_()
5872
5873 result = torch.amin(x, dim=dim, keepdim=keepdim)
5874 result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)
5875
5876 cpu_grad = torch.randn(result_cpu.shape)
5877 grad = cpu_grad.to('mps')
5878
5879 result_cpu.backward(gradient=cpu_grad)
5880 result.backward(gradient=grad)
5881
5882 self.assertEqual(result, result_cpu)
5883 self.assertEqual(x.grad, cpu_x.grad)
5884
5885 for dim in ([], [0], [0, 1], [2, 3]):
5886 for keepdim in [False, True]:
5887 helper((2, 8, 4, 5), dim, keepdim)
5888
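# Unlike max/min with dim, amax/amin return no indices, and their documented
# backward distributes the gradient evenly across tied extrema; the x.grad
# comparison above is what pins that behavior down on MPS.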
Kulin Sethe011a8e2022-05-13 18:28:53 +00005889 # Test minimum and maximum
5890 def test_minimum_maximum(self):
5891 def helper(n, c, h, w):
5892 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5893 cpu_y = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5894 mps_x = cpu_x.detach().clone().to('mps')
5895 mps_y = cpu_y.detach().clone().to('mps')
5896
5897 minimum_result_cpu = torch.minimum(cpu_x, cpu_y)
5898 minimum_result_mps = torch.minimum(mps_x, mps_y)
5899 self.assertEqual(minimum_result_cpu, minimum_result_mps)
5900
5901 maximum_result_cpu = torch.maximum(cpu_x, cpu_y)
5902 maximum_result_mps = torch.maximum(mps_x, mps_y)
5903 self.assertEqual(maximum_result_cpu, maximum_result_mps)
5904
5905 helper(1, 1, 4, 5)
5906
Denis Vieriud1a2aa12023-08-02 02:51:34 +00005907 def test_clamp_fp16_fp32(self):
5908 cpu_x = torch.randn(10, device='cpu', dtype=torch.float, requires_grad=False)
5909 x = cpu_x.detach().clone().to('mps')
5910
5911 dtype = torch.float16
5912
5913 clamp_min_vals_mps = torch.ones(10, device="mps").to(dtype)
5914 clamp_max_vals_mps = torch.ones(10, device="mps").to(dtype) * 10
5915 clamp_result_mps = torch.clamp(x, clamp_min_vals_mps, clamp_max_vals_mps)
5916
5917 clamp_min_vals_cpu = torch.ones(10, device="cpu").to(dtype)
5918 clamp_max_vals_cpu = torch.ones(10, device="cpu").to(dtype) * 10
5919 clamp_result_cpu = torch.clamp(cpu_x, clamp_min_vals_cpu, clamp_max_vals_cpu)
5920
5921 self.assertEqual(clamp_result_mps, clamp_result_cpu)
5922
Roger Lam40acc842024-03-18 19:38:15 +00005923 def test_clamp_nan(self):
5924 t_mps = torch.tensor([torch.nan, 1, 2], device="mps")
5925 t_cpu = torch.tensor([torch.nan, 1, 2], device="cpu")
5926
5927 clamp_min_max_mps = torch.clamp(t_mps, min=-100, max=100)
5928 clamp_min_max_cpu = torch.clamp(t_cpu, min=-100, max=100)
5929
5930 self.assertEqual(clamp_min_max_mps, clamp_min_max_cpu)
5931
5932 clamp_min_mps = torch.clamp(t_mps, min=-100)
5933 clamp_min_cpu = torch.clamp(t_cpu, min=-100)
5934
5935 self.assertEqual(clamp_min_mps, clamp_min_cpu)
5936
5937 clamp_max_mps = torch.clamp(t_mps, max=100)
5938 clamp_max_cpu = torch.clamp(t_cpu, max=100)
5939
5940 self.assertEqual(clamp_max_mps, clamp_max_cpu)
5941
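# On CPU, clamp propagates NaN rather than clamping it, so matching the CPU
# reference above doubles as a NaN-propagation check for the min+max,
# min-only and max-only variants.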
Kulin Sethe011a8e2022-05-13 18:28:53 +00005942 # Test clamp_min
5943 def test_clamp_min(self):
5944 def helper(n, c, h, w):
5945 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5946 x = cpu_x.detach().clone().to('mps')
5947
5948 cpu_min_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5949 min_t = cpu_min_t.detach().clone().to('mps')
5950
5951 clamp_min_result = torch.clamp_min(x, min=5.0)
5952 clamp_min_result_cpu = torch.clamp_min(cpu_x, min=5.0)
5953
5954 self.assertEqual(clamp_min_result, clamp_min_result_cpu)
5955
5956 clamp_min_t_result = torch.clamp_min(x, min=min_t)
5957 clamp_min_t_result_cpu = torch.clamp_min(cpu_x, min=cpu_min_t)
5958
5959 self.assertEqual(clamp_min_t_result, clamp_min_t_result_cpu)
5960
5961 helper(2, 8, 4, 5)
5962
5963 # Test clamp_max
5964
5965 def test_clamp_max(self):
5966 def helper(n, c, h, w):
5967 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5968 x = cpu_x.detach().clone().to('mps')
5969
5970 cpu_max_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5971 max_t = cpu_max_t.detach().clone().to('mps')
5972
5973 clamp_max_result = torch.clamp_max(x, max=100.0)
5974 clamp_max_result_cpu = torch.clamp_max(cpu_x, max=100.0)
5975
5976 self.assertEqual(clamp_max_result, clamp_max_result_cpu)
5977
5978 clamp_max_t_result = torch.clamp_max(x, max=max_t)
5979 clamp_max_t_result_cpu = torch.clamp_max(cpu_x, max=cpu_max_t)
5980
5981 self.assertEqual(clamp_max_t_result, clamp_max_t_result_cpu)
5982
5983 helper(2, 8, 4, 5)
5984
5985 # Test clamp
5986 def test_clamp(self):
5987 def helper(n, c, h, w):
5988 import numpy as np
5989 upper_bound = 1000
5990 half_upper_bound = upper_bound / 2
5991
5992 # x=[0..1000)
5993 x_arr = upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
5994 cpu_x = torch.tensor(x_arr, device='cpu', dtype=torch.float, requires_grad=False)
5995 x = cpu_x.detach().clone().to('mps')
5996
5997 # min=[0..500)
5998 min_arr = half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
5999 cpu_min_t = torch.tensor(min_arr, device='cpu', dtype=torch.float, requires_grad=False)
6000 min_t = cpu_min_t.detach().clone().to('mps')
6001
6002 # max=[500..1000), to ensure each max is greater than the corresponding min
6003 max_arr = (half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)) + half_upper_bound
6004 cpu_max_t = torch.tensor(max_arr, device='cpu', dtype=torch.float, requires_grad=False)
6005 max_t = cpu_max_t.detach().clone().to('mps')
6006
6007 # [200..600]: just an arbitrary range between [0..1000]
6008 clamp_result = torch.clamp(x, min=200.0, max=600.0)
6009 clamp_result_cpu = torch.clamp(cpu_x, min=200.0, max=600.0)
6010 self.assertEqual(clamp_result, clamp_result_cpu)
6011
6012 # test optional scalar refs and cached graph keys by passing only max
6013 clamp_opt_result = torch.clamp(x, max=600.0)
6014 clamp_opt_result_cpu = torch.clamp(cpu_x, max=600.0)
6015 self.assertEqual(clamp_opt_result, clamp_opt_result_cpu)
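# MPS compiles and caches a graph per op/argument pattern, so the max-only
# call above must not be served the min+max graph cached by the earlier
# call; that cache-key collision is what this check guards against.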
6016
6017 clamp_t_result = torch.clamp(x, min=min_t, max=max_t)
6018 clamp_t_result_cpu = torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t)
6019 self.assertEqual(clamp_t_result, clamp_t_result_cpu)
6020
6021 # test optional tensor refs and cached graph keys by passing only max
6022 clamp_topt_result = torch.clamp(x, max=max_t)
6023 clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
6024 self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)
6025
Li-Huai (Allan) Lind4d086c2023-08-04 09:32:09 +00006026 # test strided x
6027 clamp_result = torch.clamp(x.movedim(0, -1), min=200.0, max=600.0)
6028 clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=200.0, max=600.0)
6029 self.assertEqual(clamp_result, clamp_result_cpu)
6030
6031 # test strided x, min_t, max_t
6032 clamp_result = torch.clamp(x.movedim(0, -1), min=min_t.movedim(0, -1), max=max_t.movedim(0, -1))
6033 clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=cpu_min_t.movedim(0, -1), max=cpu_max_t.movedim(0, -1))
6034 self.assertEqual(clamp_result, clamp_result_cpu)
6035
6036 # test strided min_t, max_t
6037 clamp_result = torch.clamp(
6038 x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
6039 min=min_t.movedim(0, -1),
6040 max=max_t.movedim(0, -1)
6041 )
6042 clamp_result_cpu = torch.clamp(
6043 cpu_x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
6044 min=cpu_min_t.movedim(0, -1),
6045 max=cpu_max_t.movedim(0, -1)
6046 )
6047 self.assertEqual(clamp_result, clamp_result_cpu)
6048
Kulin Sethe011a8e2022-05-13 18:28:53 +00006049 # test inplace clamping
6050 x.clamp_(min=200.0, max=600.0)
6051 cpu_x.clamp_(min=200.0, max=600.0)
6052 self.assertEqual(cpu_x, x)
6053
6054 helper(2, 8, 4, 5)
6055
6056 def test_divmode(self):
6057 def helper(shape, rounding_mode):
Abhishek Pathakbccc26f2022-09-10 03:10:04 +00006058 for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
Kulin Seth5d9d8c62023-03-01 20:52:28 +00006059 if not ((rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64) or
6060 (rounding_mode is not None and "trunc" in rounding_mode and dtype == torch.float16)):
Kulin Seth299ada92023-02-10 00:10:08 +00006061 cpu_x = None
6062 cpu_y = None
6063 if (dtype in [torch.float32, torch.float16]):
6064 cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
6065 cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
6066 else:
6067 cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
6068 cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
Abhishek Pathakbccc26f2022-09-10 03:10:04 +00006069
Kulin Seth299ada92023-02-10 00:10:08 +00006070 mps_x = cpu_x.detach().clone().to('mps')
6071 # note: the integer operands above are drawn from [-10, -1], so integer division by zero cannot occur
6072 mps_y = cpu_y.detach().clone().to('mps')
Kulin Sethe011a8e2022-05-13 18:28:53 +00006073
Kulin Seth299ada92023-02-10 00:10:08 +00006074 if (rounding_mode == "floor_divide"):
6075 result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
6076 result_div_mps = torch.floor_divide(mps_x, mps_y)
6077 self.assertEqual(result_div_mps, result_div_cpu)
6078 else:
6079 result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
6080 result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
6081 self.assertEqual(result_div_mps, result_div_cpu)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006082
Kulin Setha6347f52022-06-07 18:22:10 +00006083 helper((2, 8, 4, 5), None)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006084 helper((2, 8, 4, 5), "floor")
6085 helper((2, 8, 4, 5), "trunc")
Ramin Azarmehrb63f0312022-12-20 17:02:29 +00006086 helper((2, 8, 4, 5), "floor_divide")
Kulin Sethe011a8e2022-05-13 18:28:53 +00006087
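# Note: the guard in helper() skips floor/floor_divide with int64 and trunc
# with float16, combinations that presumably did not match the CPU results
# on MPS when this test was written.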
6088 def test_rounding(self):
6089 def helper(shape):
6090 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6091 mps_x = cpu_x.detach().clone().to('mps')
6092
6093 result_floor_cpu = torch.floor(cpu_x)
6094 result_floor_mps = torch.floor(mps_x)
6095 self.assertEqual(result_floor_mps, result_floor_cpu)
6096
6097 result_ceil_cpu = torch.ceil(cpu_x)
6098 result_ceil_mps = torch.ceil(mps_x)
6099 self.assertEqual(result_ceil_mps, result_ceil_cpu)
6100
6101 result_trunc_cpu = torch.trunc(cpu_x)
6102 result_trunc_mps = torch.trunc(mps_x)
6103 self.assertEqual(result_trunc_mps, result_trunc_cpu)
6104
6105 result_round_cpu = torch.round(cpu_x)
6106 result_round_mps = torch.round(mps_x)
6107 self.assertEqual(result_round_mps, result_round_cpu)
6108
6109 helper((2, 6, 3, 5))
6110 helper((2, 8, 4, 5))
6111
Denis Vieriucedb7e32023-02-14 01:06:49 +00006112 def test_remainder(self):
6113 res_cpu = torch.remainder(
6114 torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="cpu"), torch.tensor(2, device="cpu", dtype=torch.int32))
6115 res_mps = torch.remainder(
6116 torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="mps"), torch.tensor(2, device="mps", dtype=torch.int32))
6117 self.assertEqual(res_cpu, res_mps)
6118
6119 res_cpu = torch.remainder(
6120 torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="cpu"), -1.5)
6121 res_mps = torch.remainder(
6122 torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5)
6123 self.assertEqual(res_cpu, res_mps)
6124
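# torch.remainder takes the sign of the divisor (Python-style modulo), so
# the negative operands above are exactly the cases that would expose a
# C-style fmod implementation on the MPS side.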
Kulin Sethe011a8e2022-05-13 18:28:53 +00006125 def test_expand(self):
6126 def helper(n, c):
6127 values = [[1.0], [4.0], [7.0]]
6128 cpu_x = torch.tensor(values, device='cpu')
6129 x = cpu_x.detach().clone().to('mps')
6130
6131 strided_cpu = torch.as_strided(cpu_x, (3, 4), (1, 0))
6132 strided_mps = torch.as_strided(x, (3, 4), (1, 0))
6133
Kulin Sethe011a8e2022-05-13 18:28:53 +00006134 self.assertEqual(strided_mps, strided_cpu)
6135
6136 helper(3, 1)
6137
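# expand() is implemented as a stride-0 view along the broadcast dimension,
# which is what the (1, 0) strides passed to as_strided above emulate.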
Kulin Seth0fe11582023-02-10 15:22:59 +00006138 def test_im2col(self):
6139 def helper(x):
6140 return torch.nn.functional.unfold(x, kernel_size=(10, 15), dilation=2, padding=5, stride=3)
6141 x_cpu = torch.rand(1, 1, 200, 100)
6142 x = x_cpu.detach().clone().to('mps')
6143 self.assertEqual(helper(x_cpu), helper(x))
6144
Kulin Sethe011a8e2022-05-13 18:28:53 +00006145 def test_select(self):
6146 def helper(n, c):
6147 cpu_x = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
6148 x = cpu_x.detach().clone().to('mps').requires_grad_()
6149
6150 strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1))
6151 strided_mps = torch.as_strided(x, (3, 1), (3, 1))
6152 self.assertEqual(strided_mps, strided_cpu)
6153
6154 strided_cpu = torch.as_strided(cpu_x, (1, 3), (3, 1))
6155 strided_mps = torch.as_strided(x, (1, 3), (3, 1))
6156 self.assertEqual(strided_mps, strided_cpu)
6157
6158 strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1), storage_offset=1)
6159 strided_mps = torch.as_strided(x, (3, 1), (3, 1), storage_offset=1)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006160
6161 self.assertEqual(strided_mps, strided_cpu)
6162
6163 helper(3, 3)
6164
Kulin Seth18587cb2023-02-13 01:03:22 +00006165 def test_sort(self):
6166 for SIZE in (4, 2049):
6167 device = 'mps'
6168 x = torch.rand(4, SIZE, device=device)
6169 res1val, res1ind = torch.sort(x)
6170
6171 res2val = torch.tensor((), device=device)
6172 res2ind = torch.tensor((), device=device, dtype=torch.long)
6173 torch.sort(x, out=(res2val, res2ind))
6174 self.assertEqual(res1val, res2val, atol=0, rtol=0)
6175 self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
6176 self.assertEqual(torch.argsort(x), res1ind)
6177 self.assertEqual(x.argsort(), res1ind)
6178
6179 self.assertEqual(
6180 torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0],
6181 torch.tensor((10, 20, 30, 40, 50), device=device),
6182 atol=0, rtol=0
6183 )
6184
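# Note: SIZE=2049 presumably pushes past a single-threadgroup limit in the
# MPS sort kernel, exercising a multi-pass path that SIZE=4 does not.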
Kulin Sethe011a8e2022-05-13 18:28:53 +00006185 def test_upsample_nearest2d(self):
Denis Vieriua2afc652023-02-17 05:07:22 +00006186 def helper(N, C, H, W, memory_format):
Kulin Sethe011a8e2022-05-13 18:28:53 +00006187 inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
Denis Vieriua2afc652023-02-17 05:07:22 +00006188 requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006189 inputCPU.retain_grad()
Alban Desmaisonbde246f2022-05-30 10:36:31 -04006190 inputMPS = inputCPU.detach().to('mps').requires_grad_()
Kulin Sethe011a8e2022-05-13 18:28:53 +00006191
Alban Desmaisonbde246f2022-05-30 10:36:31 -04006192 values = [1, 2, 5, 10, 40]
Kulin Sethe011a8e2022-05-13 18:28:53 +00006193
Alban Desmaisonbde246f2022-05-30 10:36:31 -04006194 for i in values:
6195 for j in values:
Kulin Sethe011a8e2022-05-13 18:28:53 +00006196 upsample_nearest2d = nn.UpsamplingNearest2d(scale_factor=(i, j))
6197
6198 outputCPU = upsample_nearest2d(inputCPU)
6199 outputMPS = upsample_nearest2d(inputMPS)
6200
6201 self.assertEqual(outputCPU, outputMPS)
6202 upsample_nearest2d = nn.UpsamplingNearest2d((i * H, j * W))
6203
6204 outputCPU = upsample_nearest2d(inputCPU)
6205 outputMPS = upsample_nearest2d(inputMPS)
6206
6207 self.assertEqual(outputCPU, outputMPS)
6208
6209 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
6210 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
6211
6212 self.assertEqual(inputCPU.grad, inputMPS.grad)
6213
Denis Vieriua2afc652023-02-17 05:07:22 +00006214 for memory_format in [torch.channels_last, torch.contiguous_format]:
6215 helper(1, 1, 4, 4, memory_format=memory_format)
6216 helper(7, 5, 3, 2, memory_format=memory_format)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006217
6218 def test_upsample_bilinear2d(self):
6219 def helper(N, C, H, W):
6220 inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
6221 requires_grad=True).reshape(N, C, H, W)
6222 inputCPU.retain_grad()
6223 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
6224
Alban Desmaisonbde246f2022-05-30 10:36:31 -04006225 values = [1, 2, 5, 10, 40]
Kulin Sethe011a8e2022-05-13 18:28:53 +00006226
Alban Desmaisonbde246f2022-05-30 10:36:31 -04006227 for i in values:
6228 for j in values:
Kulin Sethe011a8e2022-05-13 18:28:53 +00006229 upsample_bilinear2d = nn.UpsamplingBilinear2d(scale_factor=(i, j))
6230
6231 outputCPU = upsample_bilinear2d(inputCPU)
6232 outputMPS = upsample_bilinear2d(inputMPS)
6233
6234 self.assertEqual(outputCPU, outputMPS)
6235
6236 upsample_bilinear2d = nn.UpsamplingBilinear2d((i * H, j * W))
6237
6238 outputCPU = upsample_bilinear2d(inputCPU)
6239 outputMPS = upsample_bilinear2d(inputMPS)
6240
6241 self.assertEqual(outputCPU, outputMPS)
6242
6243 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
6244 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
6245
6246 self.assertEqual(inputCPU.grad, inputMPS.grad)
6247
6248 helper(1, 1, 4, 4)
6249 helper(7, 5, 3, 2)
6250
Ramin Azarmehrb44d4672023-01-05 00:48:51 +00006251 def test_interpolate(self):
6252 def helper(shape, output_size, scales, mode, align_corners=False):
6253 inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6254 inputCPU.retain_grad()
6255 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
Kulin Seth067c8062022-07-13 21:39:50 +00006256
Ramin Azarmehrb44d4672023-01-05 00:48:51 +00006257 # align_corners is used for 2D interpolation only
6258 if (align_corners and len(shape) > 3 and mode == 'bilinear'):
6259 if scales is not None:
6260 outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode, align_corners=align_corners)
6261 outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode, align_corners=align_corners)
6262 else:
6263 outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode, align_corners=align_corners)
6264 outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode, align_corners=align_corners)
6265 elif scales is not None:
6266 outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode)
6267 outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode)
6268 else:
6269 outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode)
6270 outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode)
Kulin Seth067c8062022-07-13 21:39:50 +00006271
6272 self.assertEqual(outputCPU, outputMPS)
6273
Ramin Azarmehrb44d4672023-01-05 00:48:51 +00006274 # backward pass (chose 0.6 just to have the grad_output != 1)
6275 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
6276 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
6277 self.assertEqual(inputCPU.grad, inputMPS.grad)
6278
6279 # 1D interpolation
6280 for mode in ['nearest', 'nearest-exact']:
6281 helper([2, 3, 4], [3], None, mode) # downsample with size
6282 helper([2, 3, 4], [6], None, mode) # upsample with size
6283 helper([2, 3, 4], None, [0.6], mode) # downsample with scale factor
6284 helper([2, 3, 4], None, [1.7], mode) # upsample with scale factor
6285 # 2D interpolation
6286 for mode in ['nearest', 'nearest-exact', 'bilinear']:
6287 helper([2, 3, 4, 5], [3, 4], None, mode) # downsample_nearest with size
6288 helper([2, 3, 4, 5], [6, 7], None, mode) # upsample_nearest with size
6289 helper([2, 3, 4, 5], None, [0.6, 0.7], mode) # downsample_nearest with scale factor
6290 helper([2, 3, 4, 5], None, [1.4, 1.7], mode) # upsample_nearest with scale factor
6291 # align_corners=True
6292 helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
6293 helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)
Kulin Seth067c8062022-07-13 21:39:50 +00006294
Kulin Sethe011a8e2022-05-13 18:28:53 +00006295 # Test concat forward
6296 def test_cat1(self):
6297 def helper(shape_x, shape_y, shape_z):
6298 cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
6299 x = cpu_x.detach().clone().to('mps')
6300
6301 cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
6302 y = cpu_y.detach().clone().to('mps')
6303
6304 cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
6305 z = cpu_z.detach().clone().to('mps')
6306
6307 cat = torch.cat([x, y, z], dim=1)
6308 cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)
6309
6310 self.assertEqual(cat, cat_cpu)
6311
6312 helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
Abhishek Pathakd7210e62022-07-20 16:31:44 +00006313 helper([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
6314 helper([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5])
6315 helper([2, 2, 6, 5], [0], [2, 5, 6, 5])
6316 helper([0], [2, 3, 6, 5], [2, 5, 6, 5])
6317 helper([2, 3, 4, 5], [2, 5, 4, 5], [0])
6318 helper([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5])
6319 helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
6320 helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5])
Kulin Sethe011a8e2022-05-13 18:28:53 +00006321
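# The [0] and zero-sized-dimension cases above rely on cat's special case
# that a 1-D empty tensor may be concatenated with tensors of any shape;
# MPS must skip such inputs the same way the CPU implementation does.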
Kulin Sethe011a8e2022-05-13 18:28:53 +00006322 # Test stack forward
6323 def test_stack(self):
6324 # All shapes must be the same
Denis Vieriue3b98ba2022-07-14 22:00:57 +00006325 def helper(shape, dtype=torch.float32):
Kulin Sethe011a8e2022-05-13 18:28:53 +00006326
Denis Vieriue3b98ba2022-07-14 22:00:57 +00006327 x, cpu_x = None, None
6328 y, cpu_y = None, None
6329 z, cpu_z = None, None
Kulin Sethe011a8e2022-05-13 18:28:53 +00006330
Thomas4935b592022-11-23 02:18:03 +00006331 if (dtype not in [torch.float32, torch.bool]):
Denis Vieriue3b98ba2022-07-14 22:00:57 +00006332 cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6333 x = cpu_x.detach().clone().to('mps')
6334 cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6335 y = cpu_y.detach().clone().to('mps')
6336 cpu_z = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6337 z = cpu_z.detach().clone().to('mps')
6338 elif (dtype == torch.bool):
6339 cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6340 x = cpu_x.detach().clone().to('mps')
6341 cpu_y = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6342 y = cpu_y.detach().clone().to('mps')
6343 cpu_z = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6344 z = cpu_z.detach().clone().to('mps')
6345 else:
6346 cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6347 x = cpu_x.detach().clone().to('mps').requires_grad_()
6348 cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6349 y = cpu_y.detach().clone().to('mps').requires_grad_()
6350 cpu_z = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6351 z = cpu_z.detach().clone().to('mps').requires_grad_()
Kulin Sethe011a8e2022-05-13 18:28:53 +00006352
6353 stack = torch.stack([x, y, z], dim=1)
6354 stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
6355
6356 self.assertEqual(stack, stack_cpu)
6357
6358 helper([2, 8, 4, 5])
Denis Vieriue3b98ba2022-07-14 22:00:57 +00006359 helper([2, 8, 4, 5], dtype=torch.float16)
6360 helper([2, 8, 4, 5], dtype=torch.int32)
6361 helper([2, 8, 4, 5], dtype=torch.int64)
6362 helper([2, 8, 4, 5], dtype=torch.bool)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006363 # Empty test - Currently failing! Empty tensor not handled!
6364 # helper([0, 2, 4, 5])
6365
    # Test abs
    def test_abs(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            abs_result = torch.abs(x)
            abs_result_cpu = torch.abs(cpu_x)

            self.assertEqual(abs_result, abs_result_cpu)

        helper((2, 8, 4, 5))

    def test_log(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_result = torch.log(x)
            log_result_cpu = torch.log(cpu_x)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))

    def test_log_ten(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_ten_result = torch.log10(x)
            log_ten_result_cpu = torch.log10(cpu_x)

            self.assertEqual(log_ten_result, log_ten_result_cpu)

        helper((2, 8, 4, 5))

    def test_log_two(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_two_result = torch.log2(x)
            log_two_result_cpu = torch.log2(cpu_x)

            self.assertEqual(log_two_result, log_two_result_cpu)

        helper((2, 8, 4, 5))

    def test_log1p(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_result = torch.log1p(x)
            log_result_cpu = torch.log1p(cpu_x)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))

    def test_logaddexp(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            log_result = torch.logaddexp(x, y)
            log_result_cpu = torch.logaddexp(cpu_x, cpu_y)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))

    def test_logaddexp2(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            log_result = torch.logaddexp2(x, y)
            log_result_cpu = torch.logaddexp2(cpu_x, cpu_y)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))

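    # For reference, the identities logaddexp/logaddexp2 rely on, computed this
    # way to avoid overflowing the naive log(exp(a) + exp(b)). Illustrative
    # sketch only; `a` and `b` are hypothetical tensors, not used by the tests:
    #   logaddexp(a, b)  = max(a, b) + log1p(exp(-|a - b|))
    #   logaddexp2(a, b) = max(a, b) + log2(1 + 2 ** -|a - b|)
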
    # Test concat forward
    def test_cat2(self):

        def helper1(shape_x, shape_y, shape_z, shape_w):
            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
            z = cpu_z.detach().clone().to('mps')

            cpu_w = torch.randn(shape_w, device='cpu', dtype=torch.float, requires_grad=False)
            w = cpu_w.detach().clone().to('mps')

            cat = torch.cat([x, y, z, w], dim=1)
            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z, cpu_w], dim=1)

            self.assertEqual(cat, cat_cpu)

        def helper(shape_x, shape_y, shape_z):
            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
            z = cpu_z.detach().clone().to('mps')

            cat = torch.cat([x, y, z], dim=1)
            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)

            self.assertEqual(cat, cat_cpu)

        helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
        helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
        # Empty test - Currently failing! Empty tensor not handled!
        # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
    # Test isnan
    def test_isnan(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            nan_index = [random.randrange(0, shape[0])]
            # make a selected row nan
            cpu_x.index_put_(indices=[torch.tensor(nan_index)], values=torch.tensor(float('nan')))
            x = cpu_x.detach().clone().to('mps')

            isnan_result = torch.isnan(x)
            isnan_result_cpu = torch.isnan(cpu_x)

            self.assertEqual(isnan_result, isnan_result_cpu)

        helper((8, 2, 4, 5))
    # Test reciprocal
    def test_reciprocal(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            reciprocal_result = torch.reciprocal(x)
            reciprocal_result_cpu = torch.reciprocal(cpu_x)

            cpu_grad = torch.ones_like(reciprocal_result_cpu)
            grad = cpu_grad.to('mps')

            reciprocal_result.backward(gradient=grad)
            reciprocal_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(reciprocal_result, reciprocal_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 8, 4, 5))

    # Test sqrt
    def test_sqrt(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            sqrt_result = torch.sqrt(x)
            sqrt_result_cpu = torch.sqrt(cpu_x)

            cpu_grad = torch.ones_like(sqrt_result_cpu)
            grad = cpu_grad.to('mps')

            sqrt_result.backward(gradient=grad)
            sqrt_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(sqrt_result, sqrt_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 8, 4, 5))
    # Test selu, elu, celu
    def test_elu(self):
        def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()

            x = cpu_x.detach().clone().to('mps').requires_grad_(True)
            for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]:
                elu_result = activation_func(x)
                elu_result_cpu = activation_func(cpu_x)

                cpu_grad = torch.randn(elu_result_cpu.shape)
                grad = cpu_grad.to('mps')

                elu_result.backward(gradient=grad)
                elu_result_cpu.backward(gradient=cpu_grad)

                self.assertEqual(elu_result, elu_result_cpu)
                self.assertEqual(x.grad, cpu_x.grad)

        # Sweep both memory formats and a range of alphas
        for memory_format in [torch.channels_last, torch.contiguous_format]:
            for shape in [(2, 8, 4, 5)]:
                for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
                    helper(shape, alpha, memory_format)

    def test_elu_strided_output(self):
        # https://github.com/pytorch/pytorch/issues/124834
        elu_input = torch.randn(1, 1024, 500)
        alpha = float(1)
        inplace = False

        elu_input_noncontiguous = elu_input.transpose(1, 2)
        self.assertEqual(
            F.elu(elu_input_noncontiguous.to('cpu'), alpha, inplace),
            F.elu(elu_input_noncontiguous.to('mps'), alpha, inplace)
        )

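    # For context on the three activations exercised above (formulas per the
    # PyTorch docs):
    #   ELU(x)  = x if x > 0 else alpha * (exp(x) - 1)
    #   CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    #   SELU(x) = scale * (max(0, x) + min(0, alpha * (exp(x) - 1)))
    #             with fixed alpha ~= 1.6733 and scale ~= 1.0507
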
    # Test glu
    def test_glu(self):
        def helper(shape, dim=0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            for activation_func in [torch.nn.GLU(dim=dim)]:
                glu_result = activation_func(x)
                glu_result_cpu = activation_func(cpu_x)

                cpu_grad = torch.randn(glu_result_cpu.shape)
                grad = cpu_grad.to('mps')

                glu_result.backward(gradient=grad)
                glu_result_cpu.backward(gradient=cpu_grad)

                self.assertEqual(glu_result, glu_result_cpu)
                self.assertEqual(x.grad, cpu_x.grad)

        for shape in [[4], (2, 4), (2, 8, 4, 6)]:
            for dim in range(len(shape)):
                helper(shape, dim)
    # Test softplus
    def test_softplus(self):
        def helper(shape, beta, threshold, dtype):
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
            softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x)

            cpu_grad = torch.randn(softplus_result.shape)
            grad = cpu_grad.to('mps')

            softplus_result.backward(gradient=grad)
            softplus_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(softplus_result, softplus_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        # Test empty shape too
        for shape, beta, threshold, dtype in product(
            [(), (2, 3), (10, 10), (2, 3, 4, 5)],
            [0.5, 1, 2, 3, 4],
            [0.5, 20, 30, 40, 50],
            [torch.float16, torch.float32]
        ):
            helper(shape, beta, threshold, dtype)

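    # For context: Softplus(x) = (1 / beta) * log(1 + exp(beta * x)); the
    # implementation reverts to the identity (returns x) once beta * x exceeds
    # `threshold`, which is why the test sweeps both parameters.
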
    # Test silu

    def test_silu(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            silu_result = torch.nn.SiLU()(x)
            silu_result_cpu = torch.nn.SiLU()(cpu_x)

            cpu_grad = torch.randn(silu_result_cpu.shape)
            grad = cpu_grad.to('mps')

            silu_result.backward(gradient=grad)
            silu_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(silu_result, silu_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        # Test empty shape too
        for shape in [[], (2, 3), (2, 8, 4, 5)]:
            helper(shape)

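    # For context: SiLU(x) = x * sigmoid(x). A minimal CPU reference the test
    # could be checked against (illustrative only, not used here):
    #   ref = cpu_x * torch.sigmoid(cpu_x)
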
    def test_cast_mps_to_cpu(self):
        def helper(src_dtype, dst_dtype):
            input = torch.rand((1, 3, 128, 128), dtype=src_dtype)
            input_cast_mps = input.to('mps')
            input_cast_cpu = input_cast_mps.to('cpu', dtype=dst_dtype)

            # needs to match the initial Tensor
            self.assertEqual(input_cast_cpu, input.to(dtype=dst_dtype))
        helper(torch.half, torch.float)
        helper(torch.float, torch.half)

    def test_cast_mps_to_mps(self):
        def helper(src_dtype, dst_dtype):
            input_cpu = torch.rand((1, 3, 128, 128), dtype=src_dtype)
            input_mps = input_cpu.to('mps')
            output_mps = input_mps.to(dtype=dst_dtype)
            output_cpu = input_cpu.to(dtype=dst_dtype)
            self.assertEqual(output_mps.cpu(), output_cpu)
        helper(torch.half, torch.float)
        helper(torch.float, torch.half)
        helper(torch.half, torch.long)
        helper(torch.float, torch.int)
    def test_avg_pool2d_count_include_pad(self):
        cpu_x = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()
        pool = torch.nn.AvgPool2d(kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), ceil_mode=True, count_include_pad=True)
        ref_y = pool(cpu_x)
        y = pool(x)
        self.assertEqual(y, ref_y)
        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        ref_y.backward(gradient=cpu_grad)
        y.backward(gradient=grad)
        self.assertEqual(x.grad, cpu_x.grad)
    # Test adaptive avg pool2d - when the input size is a multiple of output size
    # Not testing for channels last right now
    def test_adaptive_avg_pool2d_simple(self):
        def helper(input_shape, out_shape, channels_last):
            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
            if (channels_last):
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            avg_result = torch.nn.AdaptiveAvgPool2d(out_shape)(x)
            avg_result_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x)

            cpu_grad = torch.randn(avg_result_cpu.shape)
            grad = cpu_grad.to('mps')

            avg_result.backward(gradient=grad)
            avg_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(avg_result, avg_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 2, 4, 4), (2, 2), False)
        helper((2, 2, 9, 9), (3, 3), False)
        helper((2, 2, 9, 9), (9, 9), False)
        helper((2, 2, 16, 16), (2, 2), False)
        helper((2, 2, 16, 16), (2, 16), False)

        helper((2, 16, 16), (4, 4), False)

        # Output shape larger than input shape

        helper((2, 2, 4, 4), (8, 8), False)
        helper((2, 2, 2, 2), (4, 4), False)
        helper((2, 2, 3, 3), (9, 9), False)
        helper((2, 2, 2, 2), (16, 16), False)
        helper((2, 2, 2, 16), (16, 16), False)

        helper((2, 4, 4), (16, 16), False)

        # Output size that is not a multiple of the input size is expected to fail
        try:
            helper((2, 2, 3, 3), (7, 7), False)
        except Exception:
            pass

    # Test adaptive max pool2d - when the input size is a multiple of output size
    # Not testing for channels last right now
    def test_adaptive_max_pool2d_simple(self):
        def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
            cpu_x = None
            if (dtype in [torch.float16, torch.float32]):
                cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
            else:
                cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
            if (channels_last):
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            max_result, max_indices = None, None
            max_result_cpu, max_indices_cpu = None, None

            if (return_indices):
                max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
                max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
            else:
                max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
                max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)

            cpu_grad = torch.randn(max_result_cpu.shape)
            grad = cpu_grad.to('mps')

            max_result.backward(gradient=grad)
            max_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(max_result, max_result_cpu)
            if (return_indices):
                self.assertEqual(max_indices, max_indices_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dtype in [torch.float32]:
            for return_indices in [False, True]:
                helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
                helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
                helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
                helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
                helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
                helper((2, 16, 16), (4, 4), return_indices, dtype)

    def test_gelu_simple(self):
        def helper(shape, dtype=torch.float, contiguous=True):
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
            x = cpu_x.detach().clone().to('mps')

            if not contiguous and (0 not in shape and len(shape) >= 2):
                # Transposing will make the tensor non-contiguous
                cpu_x = cpu_x.transpose(0, 1)
                x = x.transpose(0, 1)
                assert not x.is_contiguous()

            cpu_x.requires_grad_()
            x.requires_grad_()

            gelu_result = torch.nn.GELU()(x)
            # half-precision GELU is not supported on CPU, so compute the reference in float
            gelu_result_cpu = torch.nn.GELU()(cpu_x.to(torch.float))

            cpu_grad = torch.ones_like(gelu_result_cpu)
            grad = cpu_grad.to('mps')

            gelu_result.backward(gradient=grad)
            gelu_result_cpu.backward(gradient=cpu_grad)

            atol = 1e-5 if dtype == torch.float else 1e-2
            rtol = 1e-3 if dtype == torch.float else 1e-2
            self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)

            assert x.grad is not None  # Check that the grad is well-populated
            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

        # Test empty shape too
        for dtype in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    helper(shape, dtype, contiguous)
        # Test that gelu would raise an assert for integral types
        for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))

    def test_mish_simple(self):
        def helper(shape, dtype=torch.float, contiguous=True):
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
            x = cpu_x.detach().clone().to('mps')

            if not contiguous and (0 not in shape and len(shape) >= 2):
                # Transposing will make the tensor non-contiguous
                cpu_x = cpu_x.transpose(0, 1)
                x = x.transpose(0, 1)
                assert not x.is_contiguous()

            cpu_x.requires_grad_()
            x.requires_grad_()

            mish_result = torch.nn.Mish()(x)
            mish_result_cpu = torch.nn.Mish()(cpu_x)

            cpu_grad = torch.ones_like(mish_result_cpu)
            grad = cpu_grad.to('mps')

            mish_result.backward(gradient=grad)
            mish_result_cpu.backward(gradient=cpu_grad)

            atol = 1e-5 if dtype == torch.float else 1e-2
            rtol = 1e-3 if dtype == torch.float else 1e-2
            self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)

            assert x.grad is not None  # Check that the grad is well-populated
            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

        # Test empty shape too
        for dtype in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    helper(shape, dtype, contiguous)

    def test_gelu(self):
        def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
            numpy_dtype = {
                torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double
            }[dtype]
            devices = ['cpu']
            devices += ['mps']

            # Reference GELU; currently unused below and would need scipy.stats,
            # hence the noqa
            def _gelu_ref(X):
                return X * stats.norm.cdf(X)  # noqa: F821

            # As written, this only verifies that a strided slice of a random
            # tensor round-trips between each device and numpy; _gelu_ref is
            # not exercised
            for d in devices:
                X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]
                res = X
                ref = (X.to(numpy_dtype).cpu().detach().numpy())
                self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)

        for n in [1, 5, 10]:
            for m in [1, 5, 10]:
                _test_gelu(n, m, torch.float32, True)
                _test_gelu(n, m, torch.float32, False)

        # Test multi threaded
        num_threads = torch.get_num_threads()
        torch.set_num_threads(4)
        try:
            _test_gelu(32, 32, torch.float32, False)
        finally:
            torch.set_num_threads(num_threads)

    def test_gelu_tanh(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            x = cpu_x.detach().clone().to('mps')

            gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
            gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
            self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)

        helper((2, 8, 4, 5))

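    # For context: with approximate='tanh', GELU is computed as
    #   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
    # instead of the exact 0.5 * x * (1 + erf(x / sqrt(2))).
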
    # Test hardtanh
    def test_hardtanh(self):
        def helper(shape, min_val, max_val, inplace=False):
            cpu_x = None
            x = None

            if (not inplace):
                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()
            else:
                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')

            hardtanh_result = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x)
            hardtanh_result_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x)

            self.assertEqual(hardtanh_result, hardtanh_result_cpu)

            if (not inplace):
                cpu_grad = torch.randn(hardtanh_result_cpu.shape)
                grad = cpu_grad.to('mps')
                hardtanh_result.backward(gradient=grad)
                hardtanh_result_cpu.backward(gradient=cpu_grad)
                self.assertEqual(x.grad, cpu_x.grad)

        # Test empty shape too
        for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
            for min_val, max_val in zip([-1, -2, 3], [1, -1, 4]):
                helper(shape, min_val, max_val)
                helper(shape, min_val, max_val, inplace=True)

    def test_hardswish(self):
        def helper(shape, inplace=False, requires_grad=True):
            m = nn.Hardswish(inplace=inplace)

            input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad)
            input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

            if inplace and requires_grad:  # check that both raise runtime error
                self.assertRaises(RuntimeError, lambda: m(input_cpu))
                self.assertRaises(RuntimeError, lambda: m(input_mps))
                return

            output_cpu = m(input_cpu)
            output_mps = m(input_mps)

            cpu_grad = torch.ones_like(output_cpu)
            mps_grad = cpu_grad.to('mps')

            self.assertEqual(output_cpu, output_mps)

            if requires_grad:
                output_cpu.backward(gradient=cpu_grad)
                output_mps.backward(gradient=mps_grad)

                self.assertEqual(input_cpu.grad, input_mps.grad)

        for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
            helper(shape, inplace=False, requires_grad=False)
            helper(shape, inplace=True, requires_grad=False)
            helper(shape, inplace=False, requires_grad=True)
            helper(shape, inplace=True, requires_grad=True)

    def test_transpose_2D(self):
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        values1 = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')
        mps_x1 = torch.tensor(values1, device='mps')

        cpu_transpose = torch.transpose(cpu_x, 0, 1)
        mps_transpose = torch.transpose(mps_x, 0, 1)
        self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))

    def test_transpose_3D(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
        mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
        self.assertEqual(cpu_transpose1, mps_transpose1)

        cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
        mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
        self.assertEqual(cpu_transpose2, mps_transpose2)

        cpu_transpose3 = torch.transpose(cpu_x, 1, 2)
        mps_transpose3 = torch.transpose(mps_x, 1, 2).to('cpu')
        self.assertEqual(cpu_transpose3, mps_transpose3)

    def test_transpose_4D(self):
        values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
                  [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
        mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
        self.assertEqual(cpu_transpose1, mps_transpose1)

        cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
        mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
        self.assertEqual(cpu_transpose2, mps_transpose2)

        cpu_transpose3 = torch.transpose(cpu_x, 0, 3)
        mps_transpose3 = torch.transpose(mps_x, 0, 3).to('cpu')
        self.assertEqual(cpu_transpose3, mps_transpose3)

        cpu_transpose4 = torch.transpose(cpu_x, 3, 1)
        mps_transpose4 = torch.transpose(mps_x, 3, 1).to('cpu')
        self.assertEqual(cpu_transpose4, mps_transpose4)

        cpu_transpose5 = torch.transpose(cpu_x, 3, 2)
        mps_transpose5 = torch.transpose(mps_x, 3, 2).to('cpu')
        self.assertEqual(cpu_transpose5, mps_transpose5)

        cpu_transpose6 = torch.transpose(cpu_x, 1, 2)
        mps_transpose6 = torch.transpose(mps_x, 1, 2).to('cpu')
        self.assertEqual(cpu_transpose6, mps_transpose6)

    # Test sign
    def test_sign(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            sign_result = torch.sign(x)
            sign_result_cpu = torch.sign(cpu_x)

            cpu_grad = torch.ones_like(sign_result_cpu)
            grad = cpu_grad.to('mps')

            sign_result.backward(gradient=grad)
            sign_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(sign_result, sign_result_cpu)

        helper((2, 8, 4, 5))

    def test_signbit(self):
        def helper(shape, dtype):
            cpu_x = torch.randn(shape, device='cpu').to(dtype)
            x = cpu_x.clone().to('mps')

            signbit_result = torch.signbit(x)
            signbit_result_cpu = torch.signbit(cpu_x)

            self.assertEqual(signbit_result, signbit_result_cpu)

        helper((2, 8, 4, 5), torch.int)
        helper((2, 8, 4, 5), torch.float)
        helper((2, 8, 4, 5), torch.int64)

    # Test neg
    def test_neg(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            neg_result = torch.neg(x)
            neg_result_cpu = torch.neg(cpu_x)

            cpu_grad = torch.ones_like(neg_result_cpu)
            grad = cpu_grad.to('mps')

            neg_result.backward(gradient=grad)
            neg_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(neg_result, neg_result_cpu)

        helper((2, 8, 4, 5))

    def test_neg_strided_input(self):
        # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337
        x = torch.arange(18.0, device='mps').reshape(2, 3, 3)
        y = x.permute(1, 0, 2)[..., 1]
        z = y + y.neg()
        self.assertEqual(z.abs().max().item(), 0.0)

    # Test index add
    def test_index_add(self):
        def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):
            cpu_x = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False)
            source = cpu_source.detach().clone().to('mps')

            idx_result = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha)
            idx_result_cpu = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha)
            self.assertEqual(idx_result, idx_result_cpu)

        helper((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5)
        helper((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0)
        helper((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5)
        helper((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0)
        helper((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4)
        helper((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0)
        # test result dim=1
        helper((2,), 0, [1], (1,), 6.0)
        helper(2, 0, 1, 1, 6)
        # test float16
        helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16)

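    # For reference, along dim=0 index_add computes, per the PyTorch docs:
    #   x[idx[i], ...] += alpha * source[i, ...]
    # so `source` must match `x` everywhere except `dim`, where it must match
    # len(index) - the shape pairs passed to helper() above follow that rule.
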
    def test_index_64bit(self):
        """ Test that index operations work for 4Gb+ tensors """
        if product_version < 14.0:
            raise unittest.SkipTest("Sonoma is needed for large tensors, see https://github.com/pytorch/pytorch/issues/84039")
        # Cleanup memory
        gc.collect()
        torch.mps.empty_cache()
        # Check that index operations work for 4+GB tensors
        x = torch.rand(16000, 67120, device="mps")
        self.assertGreater(x.element_size() * x.numel(), 2**32)
        idx = torch.arange(0, 2, device="mps")
        x_sampled = x[:, idx]
        self.assertEqual(x[:, 0], x_sampled[:, 0])
        # Reclaim memory after running the tests
        del x
        gc.collect()
        torch.mps.empty_cache()

    def test_mm_large(self):
        """ Test that MM works for matrices with index larger than 32K """
        x = torch.rand(10, 1, device="mps")
        y = torch.rand(1, 32769, device="mps")
        # This used to crash with:
        # error: subRange.start (24576) is not less than length of dimension[0] (16384)
        # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
        self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)

        def compare_mm(m, n, k, dtype=torch.float):
            x = torch.rand(m, n, device="mps", dtype=dtype)
            y = torch.rand(n, k, device="mps", dtype=dtype)
            z = torch.mm(x, y).cpu()
            z_cpu = torch.mm(x.cpu(), y.cpu())
            self.assertEqual(z, z_cpu)

        # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal
        compare_mm(1024, 1, 32769)
        # one more time, but with dimensions inverted
        # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
        compare_mm(32769, 1, 1025)

        if product_version >= 14.0:
            # Test bfloat16 mm
            compare_mm(1024, 1, 32769, torch.bfloat16)

    @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test")
    @unittest.skipIf(product_version < 14.0, "Can't allocate 4Gb tensor on MacOS 13")
    def test_copy_large(self):
        """ Test that copy of 4Gb+ tensors works """
        x = torch.ones((2**30 + 11,), dtype=torch.float32)
        y = x.to(device="mps")
        self.assertTrue(torch.all(y == torch.tensor(1.0, device="mps")))
        del y
        del x

    # Test flip
    def test_flip(self):
        def helper(shape, dims):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            flip_result = torch.flip(x, dims=dims)
            flip_result_cpu = torch.flip(cpu_x, dims=dims)

            self.assertEqual(flip_result, flip_result_cpu)

        helper((2, 8, 4, 5), [0])
        helper((8, 8, 4, 5), [0, 1])
        helper((2, 8, 4, 5), (0, 1, 2, 3))
        helper((2, 3, 3), (-1,))
        # empty dims
        helper((2, 8, 4, 5), [])
        # input.numel() == 1
        helper((1,), (0,))
        # input.numel() == 0
        helper((0,), (0,))
        # none of the dims needs to be flipped
        helper((1, 3), [0])

    # Test index select
    def test_index_select(self):
        def helper(shape, dim, index, idx_dtype=torch.int32):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            idx_result = torch.index_select(x, dim=dim, index=idx)
            idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)

            self.assertEqual(idx_result, idx_result_cpu)

        helper((2, 8, 4, 5), 0, [1])
        helper((8, 8, 4, 5), 0, [0, 3, 2, 7, 6])
        helper((2, 8, 4, 5), 1, [0, 3, 2, 7, 6])
        helper((2, 8, 4, 5), 2, [3, 0, 1])
        helper((2, 8, 4, 5), 3, [2, 3, 0])
        helper((2, 3, 3), -1, [1, 2])
        helper((), 0, [0])
        helper((5,), 0, [])

    def test_index_select_scalar(self):
        def helper(value, dim, index, idx_dtype=torch.int32):
            cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            idx_result = torch.index_select(x, dim=dim, index=idx)
            idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)

            self.assertEqual(idx_result, idx_result_cpu)

        helper(22, 0, [0])
        with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
            helper(22, 0, [])

    def test_embedding_dense_backward(self):
        def helper(n, d, m, idx):
            embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
            embedding_weight = embeddingMPS.weight.detach().cpu()
            W_MPS = torch.randn((m, d), requires_grad=True, device='mps')
            idx_MPS = torch.tensor(idx, device='mps')
            a_MPS = embeddingMPS.weight.clone() @ W_MPS.t()  # weight must be cloned for this to be differentiable
            a_MPS.retain_grad()
            b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t()  # modifies weight in-place
            b_MPS.retain_grad()
            out_MPS = (a_MPS.unsqueeze(0) + b_MPS)
            loss_MPS = out_MPS.sigmoid().prod()
            loss_MPS.backward()

            embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=embedding_weight)
            W_CPU = W_MPS.to('cpu')
            idx_CPU = torch.tensor(idx)
            a_CPU = embeddingCPU.weight.clone() @ W_CPU.t()  # weight must be cloned for this to be differentiable
            a_CPU.retain_grad()
            b_CPU = embeddingCPU(idx_CPU) @ W_CPU.t()  # modifies weight in-place
            b_CPU.retain_grad()
            out_CPU = (a_CPU.unsqueeze(0) + b_CPU)
            loss_CPU = out_CPU.sigmoid().prod()
            loss_CPU.backward()

            self.assertEqual(b_CPU.grad, b_MPS.grad)
            self.assertEqual(a_CPU.grad, a_MPS.grad)

        helper(3, 5, 7, [0, 1, 2])
        helper(3, 6, 7, [0, 1, 2])  # verify if changes in shape would cause cached graph lookup problems
        helper(3, 5, 7, 2)  # test scalar index

    # Test pytorch gather
    def test_gather(self):
        def helper(shape, dim, idx_shape, idx_dtype=torch.int64):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            # Indices should be taken from range of axis along which gathering is done
            idx_np = np.random.randint(0, shape[dim], idx_shape)

            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            gather_result = torch.gather(x, dim=dim, index=idx)
            gather_result_cpu = torch.gather(cpu_x, dim=dim, index=cpu_idx)

            cpu_grad = torch.randn(idx_shape, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')
            gather_result.backward(gradient=grad)
            gather_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(gather_result, gather_result_cpu)
            self.assertEqual(cpu_x.grad, x.grad)

        helper((6, 3, 3), 0, (3, 3, 3))
        helper((2, 3, 3, 3), 0, (10, 3, 3, 3))
        helper((2, 8, 4, 5), 0, (10, 8, 4, 5))
        helper((2, 8, 4, 5), 0, (10, 6, 3, 2))
        helper((8, 8, 4, 5), 0, (6, 8, 4, 5))
        helper((8, 8, 4, 5), 0, (6, 7, 2, 3))
        helper((2, 8, 4, 5), 1, (2, 5, 3, 4))
        helper((2, 8, 4, 5), 2, (1, 8, 10, 3))
        helper((2, 8, 4, 5), 3, (2, 5, 3, 12))

    # Test pytorch gather for scalar input
    def test_gather_scalar(self):
        idx_dtype = torch.int64
        cpu_x = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        idx_np = [0]

        cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        gather_result = torch.gather(x, dim=0, index=idx)
        gather_result_cpu = torch.gather(cpu_x, dim=0, index=cpu_idx)

        cpu_grad = torch.randn([1], device='cpu', dtype=torch.float)
        grad = cpu_grad.to('mps')
        gather_result.backward(gradient=grad)
        gather_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(gather_result, gather_result_cpu)
        self.assertEqual(cpu_x.grad, x.grad)

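    # For reference, gather with dim=0 computes, per the PyTorch docs:
    #   out[i][j][k] = input[index[i][j][k]][j][k]
    # so the output always has the shape of `index`, which both tests above
    # verify against the CPU implementation.
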
    # Test pytorch scatter_add and scatter
    def test_scatter_add(self):
        def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
            src = cpu_src.detach().clone().to('mps').requires_grad_()

            # Indices should be taken from range of axis along which gathering is done
            idx_np = None
            if (do_add):
                idx_np = np.random.randint(0, shape[dim], idx_shape)
            else:
                idx_np = np.array([[0, 1, 2],
                                   [1, 2, 3],
                                   [2, 3, 4],
                                   [3, 4, 5],
                                   [4, 5, 6]])

            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            scatter_result = None
            scatter_result_cpu = None

            if (do_add):
                scatter_result = torch.scatter_add(x, dim=dim, index=idx, src=src)
                scatter_result_cpu = torch.scatter_add(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)
            else:
                scatter_result = torch.scatter(x, dim=dim, index=idx, src=src)
                scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)

            cpu_grad = None
            grad = None

            if (idx_shape == src_shape):
                cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
                grad = cpu_grad.to('mps')
                scatter_result.backward(gradient=grad)
                scatter_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(scatter_result, scatter_result_cpu)
            if (idx_shape == src_shape):
                self.assertEqual(cpu_x.grad, x.grad)
                self.assertEqual(cpu_src.grad, src.grad)

        helper((2, 3), 0, (5, 3), (5, 3))
        helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
        helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
        helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2))
        helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2))
        helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5))

        helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5))
        helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2))
        helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3))
        helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3))

        helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8))
        helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6))
        helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6))

        # Test scatter src
        helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
        helper((10, 3), 0, (5, 3), (5, 8), do_add=False)

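    # For reference, scatter_add with dim=0 computes, per the PyTorch docs:
    #   self[index[i][j]][j] += src[i][j]
    # (plain scatter assigns instead of accumulating); backward is only
    # exercised above when idx_shape == src_shape, where gradients for both
    # `x` and `src` are well-defined.
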
    # Test pytorch scatter_add and scatter for scalar input
    def test_scatter_add_scalar(self):
        def helper(idx_dtype=torch.int64, do_add=True):
            cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
            src = cpu_src.detach().clone().to('mps').requires_grad_()

            # Indices should be taken from range of axis along which gathering is done
            idx_np = [0]

            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            scatter_result = None
            scatter_result_cpu = None

            if (do_add):
                scatter_result = torch.scatter_add(x, dim=0, index=idx, src=src)
                scatter_result_cpu = torch.scatter_add(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
            else:
                scatter_result = torch.scatter(x, dim=0, index=idx, src=src)
                scatter_result_cpu = torch.scatter(cpu_x, dim=0, index=cpu_idx, src=cpu_src)

            cpu_grad = None
            grad = None

            cpu_grad = torch.tensor(1.2, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')
            scatter_result.backward(gradient=grad)
            scatter_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(scatter_result, scatter_result_cpu)
            self.assertEqual(cpu_x.grad, x.grad)
            self.assertEqual(cpu_src.grad, src.grad)

        helper()
        helper(do_add=False)

    # Test pytorch scatter_reduce
    def test_scatter_reduce(self):
        def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
            src = cpu_src.detach().clone().to('mps').requires_grad_()

            # Indices should be taken from range of axis along which gathering is done
            idx_np = np.random.randint(0, shape[dim], idx_shape)

            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            scatter_result = torch.scatter(x, dim=dim, index=idx, src=src, reduce=reduce_str)
            scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src, reduce=reduce_str)

            self.assertEqual(scatter_result, scatter_result_cpu)

        # for reduce in ["sum", "prod", "amax", "amin"]:
        for reduce_type in ["add", "multiply"]:
            helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce_type)
            helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce_type)

            helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce_type)
            helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce_type)

            helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce_type)
            helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce_type)
            helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce_type)

    def test_is_nonzero(self):
        self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))
        self.assertTrue(torch.is_nonzero(torch.tensor([1.5]).to('mps')))
        self.assertFalse(torch.is_nonzero(torch.tensor([False]).to('mps')))
        self.assertTrue(torch.is_nonzero(torch.tensor([3]).to('mps')))

    # Test triu
    def test_triu(self):
        def helper(shape, diag=0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            triu_result = torch.triu(x, diag)
            triu_result_cpu = torch.triu(cpu_x, diag)

            cpu_grad = torch.randn(triu_result_cpu.shape)
            grad = cpu_grad.to('mps')

            triu_result.backward(gradient=grad)
            triu_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(triu_result, triu_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 8, 4, 5))
        helper((2, 8, 4, 5), diag=1)
        helper((2, 8, 4, 5), diag=2)
        helper((2, 8, 4, 5), diag=3)
        helper((2, 8, 4, 5), diag=-1)
        helper((2, 8, 4, 5), diag=-2)
        helper((2, 8, 4, 5), diag=-3)

    # Test inverse
    def test_inverse(self):
        def helper(n):
            cpu_input = torch.randn(n, n, device='cpu')
            mps_input = cpu_input.to('mps')

            cpu_result = torch.linalg.inv(cpu_input)
            mps_result = torch.linalg.inv(mps_input)
            self.assertEqual(cpu_result, mps_result)

        helper(2)
        helper(6)
        helper(3)
        helper(8)

    # Test tril
    def test_tril(self):
        def helper(shape, diag=0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            tril_result = torch.tril(x, diag)
            tril_result_cpu = torch.tril(cpu_x, diag)

            cpu_grad = torch.randn(tril_result_cpu.shape)
            grad = cpu_grad.to('mps')

            tril_result.backward(gradient=grad)
            tril_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(tril_result, tril_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 8, 4, 5))
        helper((2, 8, 4, 5), diag=1)
        helper((2, 8, 4, 5), diag=2)
        helper((2, 8, 4, 5), diag=3)
        helper((2, 8, 4, 5), diag=-1)
        helper((2, 8, 4, 5), diag=-2)
        helper((2, 8, 4, 5), diag=-3)

    # test eye
    def test_eye(self):
        def helper(n, m, dtype):
            cpu_result = None
            result = None

            if (n == m):
                cpu_result = torch.eye(n, dtype=dtype, device='cpu')
                result = torch.eye(n, dtype=dtype, device='mps')
            else:
                cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
                result = torch.eye(n, m, dtype=dtype, device='mps')

            self.assertEqual(result, cpu_result)

        for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]:
            helper(2, 2, dtype)
            helper(2, 3, dtype)
            helper(0, 2, dtype)
            helper(0, 0, dtype)
            helper(3, 8, dtype)
            helper(8, 3, dtype)

    # Test diag
    def test_diag(self):
        def helper(shape, diag=0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            diag_result = torch.diag(x, diag)
            diag_result_cpu = torch.diag(cpu_x, diag)

            # cpu_grad = torch.randn(diag_result_cpu.shape)
            # grad = cpu_grad.to('mps')

            # diag_result.backward(gradient=grad)
            # diag_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(diag_result, diag_result_cpu)
            # self.assertEqual(x.grad, cpu_x.grad)

        for shape in [(5, 5), (5, 6), (6, 5), (5,), (6,)]:
            for diag in [0, 1, 2, 3, 4, -1, -2, -3, -4]:
                helper(shape, diag=diag)

    # Test linspace
    def test_linspace(self):
        def helper(start, end, steps, dtype=torch.float32):
            cpu_result = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
            result = torch.linspace(start, end, steps, dtype=dtype, device='mps')
            self.assertEqual(cpu_result, result)

        for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
            helper(2, 5, 10, dtype)
            helper(2, 2, 10, dtype)
            helper(5, 2, 10, dtype)
            helper(2, 2, 0, dtype)

    # Test arange
    def test_arange(self):
        self.assertEqual(np.arange(10), torch.arange(10, device='mps'))
        self.assertEqual(np.arange(7, 1, -1), torch.arange(7, 1, -1, device='mps'))
        self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps'))
        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps'))

    def test_arange_empty(self):
        out_mps = torch.tensor([], device="mps")
        out_cpu = torch.tensor([], device="cpu")

        y_mps = torch.arange(0, 0, 1, out=out_mps)
        y_cpu = torch.arange(0, 0, 1, out=out_cpu)
        self.assertEqual(y_mps, y_cpu)

    # Test range
    def test_range(self):
        self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
        self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
        self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps'))

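    # For context: torch.linspace includes the endpoint, torch.arange excludes
    # it, and the deprecated torch.range includes it - e.g. torch.range(0, 10)
    # yields 11 values while torch.arange(0, 10) yields 10, matching the numpy
    # references used in the tests above.
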
    # Test softmax
    def test_softmax(self):
        def helper(shape, dim, channels_last=False):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            if (channels_last):
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            softmax_result = torch.nn.functional.softmax(x, dim=dim)
            softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)

            # Currently NOT testing backward for channels last
            cpu_grad = None
            grad = None

            if (not channels_last):
                cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
                grad = cpu_grad.to('mps')

                softmax_result.backward(gradient=grad)
                softmax_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(softmax_result, softmax_result_cpu)
            if (not channels_last):
                self.assertEqual(x.grad, cpu_x.grad)

        def helper2(dim):
            cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            softmax_result = torch.nn.functional.softmax(x, dim=dim)
            softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)

            cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')

            softmax_result.backward(gradient=grad)
            softmax_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(softmax_result, softmax_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper2(0)

        for channels_last in [False]:
            for shape in [(2, 4, 8, 5), (3, 4, 6, 7, 2)]:
                if (len(shape) != 4 and channels_last):
                    continue
                for dim in [0, 1, 2, 3, -1, -2, -3]:
                    helper(shape, dim, channels_last)

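    # For context: softmax(x)_i = exp(x_i) / sum_j exp(x_j) along `dim`;
    # implementations typically subtract max(x) first for numerical stability.
    # A minimal CPU reference (illustrative only; `t` is a hypothetical tensor):
    #   ref = (t - t.max(dim, keepdim=True).values).exp()
    #   ref = ref / ref.sum(dim, keepdim=True)
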
    def test_nan_to_num(self):
        inputCPU = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
        outputCPU = torch.nan_to_num(inputCPU, nan=2.0, posinf=1.0, neginf=-1.0)
        outputMPS = torch.nan_to_num(inputMPS, nan=2.0, posinf=1.0, neginf=-1.0)
        self.assertEqual(outputMPS, outputCPU)

    # Test where
    def test_where(self):
        def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):

            cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
            cond = cpu_cond.detach().clone().to('mps')

            cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
            y = cpu_y.detach().clone().to('mps').requires_grad_()

            cpu_out = torch.where(cpu_cond, cpu_x, cpu_y)
            out = torch.where(cond, x, y)

            cpu_grad = torch.randn(cpu_out.shape)
            grad = cpu_grad.to('mps')

            cpu_out.backward(gradient=cpu_grad)
            out.backward(gradient=grad)

            self.assertEqual(out, cpu_out)
            self.assertEqual(x.grad, cpu_x.grad)
            self.assertEqual(y.grad, cpu_y.grad)

        for shape in ([(0, 3), [], (2, 3), (9,)]):
            helper(shape, shape, shape)

        helper((2, 3, 1), (2, 3, 4), (2, 1, 4))
        helper((2, 1, 1), (2, 3, 4), (1, 3, 4))
        helper((1, 1, 1), (1, 1, 4), (2, 3, 1))
        helper([], (1, 1, 4), (2, 3, 1))
        helper([], (2, 3, 4), [])
        helper((5, 2, 3), (2, 3), (2, 3))
        helper((2, 3), (5, 2, 3), (2, 3))
        helper((2, 3), (2, 3), (5, 2, 3))
        helper((2, 3), (5, 2, 3), (6, 5, 2, 3))
        # Test that the output is correctly resized
        # TODO: Remove me when out OpInfo testing is enabled on MPS
        output = torch.tensor(0.0, device="mps")
        cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
        inp = torch.rand(3, 3, device="mps")
        other = torch.rand(3, 3, device="mps")
        out = torch.where(cond, inp, other, out=output)
        self.assertEqual(id(out), id(output))
        self.assertEqual(out.shape, (3, 3))

7726 # Test normal
7727 def test_normal(self):
7728 def helper(shape, mean=0.0, std=1.0):
Kulin Sethe011a8e2022-05-13 18:28:53 +00007729 mps_out = torch.normal(mean, std, shape, device='mps')
7730
Kulin Sethe011a8e2022-05-13 18:28:53 +00007731 mean_array = np.ones(shape)
7732 mean_array *= mean
7733 cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False)
7734 mean_tensor = cpu_mean_tensor.detach().clone().to('mps')
7735
7736 std_array = np.ones(shape)
7737 std_array *= std
7738 cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
7739 std_tensor = cpu_std_tensor.detach().clone().to('mps')
7740
qqaatwe1b15b72022-06-28 15:19:39 +00007741 # test out
Kulin Sethe011a8e2022-05-13 18:28:53 +00007742 mps_out = torch.zeros(shape, device='mps')
7743 torch.normal(mean_tensor, std, out=mps_out)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007744
7745 mps_out = torch.zeros(shape, device='mps')
7746 torch.normal(mean, std_tensor, out=mps_out)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007747
7748 mps_out = torch.zeros(shape, device='mps')
7749 torch.normal(mean_tensor, std_tensor, out=mps_out)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007750
qqaatwe1b15b72022-06-28 15:19:39 +00007751 # test without out
7752 mps_out = torch.normal(mean_tensor, std)
7753 self.assertEqual(mps_out.size(), mean_tensor.size())
7754
7755 mps_out = torch.normal(mean, std_tensor)
7756 self.assertEqual(mps_out.size(), std_tensor.size())
7757
7758 inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
7759 mps_out = torch.normal(mean_tensor, std_tensor)
7760 self.assertEqual(mps_out.size(), inferred_shape)
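            # torch.broadcast_shapes applies standard broadcasting rules, e.g.
            # torch.broadcast_shapes((100, 1), (1, 100)) == torch.Size([100, 100])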
7761
Kulin Sethe011a8e2022-05-13 18:28:53 +00007762 helper((2, 3, 4, 5, 6))
7763 helper((100, 100), 2.5, 1.2)
7764
7765 def test_bernoulli(self):
Ramin Azarmehra4cc6392022-09-30 22:40:50 +00007766 shape = (10, 10)
7767 all_ones = torch.ones(shape, device='mps')
7768 all_zeros = torch.zeros(shape, device='mps')
Kulin Sethe011a8e2022-05-13 18:28:53 +00007769
Ramin Azarmehra4cc6392022-09-30 22:40:50 +00007770 prob_tensor = all_ones * 0.5
7771 # probability of drawing "1" is 0.5
7772 mps_out = torch.bernoulli(prob_tensor)
        # We can't reliably check the mean and std.
7774 # Just make sure we don't return constant values
7775 self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
7776 self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007777
Ramin Azarmehra4cc6392022-09-30 22:40:50 +00007778 # probability of drawing "1" is 0
7779 mps_out = torch.bernoulli(all_zeros)
7780 self.assertEqual(mps_out, all_zeros)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007781
Ramin Azarmehra4cc6392022-09-30 22:40:50 +00007782 # probability of drawing "1" is 1
7783 mps_out = torch.bernoulli(all_ones)
7784 self.assertEqual(mps_out, all_ones)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007785
Nikita Shulgab7bf9532023-05-11 23:52:38 +00007786 # Check it works for different dtypes
7787 for dtype in [torch.float16, torch.int8, torch.int16, torch.int32, torch.int64]:
7788 mps_out = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5)
7789 # Check that output is not all zeros or ones
7790 if product_version > 13.0:
7791 uniq = mps_out.unique()
Nikita Shulga9e089db2023-05-13 01:19:08 +00007792 self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
Nikita Shulgab7bf9532023-05-11 23:52:38 +00007793 else:
7794 self.assertEqual(mps_out.min().item(), 0.)
7795 self.assertEqual(mps_out.max().item(), 1.)
7796
Ramin Azarmehr688e3512023-01-03 16:01:19 +00007797 def test_mps_generator(self):
7798 # explicit manual seeding by creating an MPS Generator
7799 g_mps = torch.Generator(device='mps')
7800 g_mps.manual_seed(999)
7801 mps_x = torch.randn(5, device='mps', generator=g_mps)
7802 g_mps.manual_seed(999)
7803 mps_y = torch.randn(5, device='mps', generator=g_mps)
7804 # seed values were the same, so the random tensor contents should match
7805 self.assertEqual(mps_x, mps_y)
7806 # save generator's state to restore it later
7807 g_state = g_mps.get_state()
7808
7809 # generate random numbers without seeding
7810 mps_x = torch.randn(5, device='mps', generator=g_mps)
7811 # in this case, the random results must differ from the last generated random results
7812 self.assertNotEqual(mps_x, mps_y)
7813
7814 # restore the previously saved state, and the results should match again
7815 g_mps.set_state(g_state)
7816 mps_x = torch.randn(5, device='mps', generator=g_mps)
7817 self.assertEqual(mps_x, mps_y)
7818
Ramin Azarmehrbdd8f512023-02-12 21:22:28 +00007819 def test_default_mps_generator(self):
7820 # manual seeding on the "default" MPS generator using
7821 # the global torch.manual_seed()
7822 torch.manual_seed(230)
7823 mps_x = torch.randn(5, device='mps')
7824 # manual seeding using torch.mps.manual_seed()
7825 # which should set the "default" MPS generator
7826 # like the global torch.manual_seed()
7827 torch.mps.manual_seed(230)
7828 mps_y = torch.randn(5, device='mps')
7829 # seed values were the same, so the random tensor contents should match
7830 self.assertEqual(mps_x, mps_y)
7831
7832 # save the default generator's state to restore it later
7833 g_state = torch.mps.get_rng_state()
7834
7835 # generate random numbers without seeding
7836 mps_x = torch.randn(5, device='mps')
7837 # in this case, the random results must differ from the last generated random results
7838 self.assertNotEqual(mps_x, mps_y)
7839
7840 # restore the previously saved state, and the results should match again
7841 torch.mps.set_rng_state(g_state)
7842 mps_x = torch.randn(5, device='mps')
7843 self.assertEqual(mps_x, mps_y)
7844
7845 def test_device_synchronize(self):
7846 # just running some ops each followed by a synchronize to wait for
7847 # MPS stream to finish running each of them
7848 net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
7849 .to(device='mps', dtype=torch.float)
7850
7851 x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
7852 torch.mps.synchronize()
7853 x = net1(x)
7854 torch.mps.synchronize()
7855 x.backward(torch.randn_like(x))
7856 torch.mps.synchronize()
7857
Li-Huai (Allan) Lin77766532023-03-30 07:24:58 +00007858 @unittest.expectedFailure
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00007859 def test_mps_allocator_module(self):
7860 # first garbage collect and empty the cached blocks
7861 gc.collect()
7862 torch.mps.empty_cache()
7863 # measure memory allocations from MPSAllocator
7864 current_alloc_before = torch.mps.current_allocated_memory()
7865 # after garbage collection and emptying the cache the
7866 # current_allocated_memory must be zero
7867 self.assertTrue(current_alloc_before == 0)
7868 # measure total memory allocations from Metal driver
7869 driver_alloc_before = torch.mps.driver_allocated_memory()
7870 # allocate a new 8 MB tensor to force allocation of a new Metal Heap
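        # Note (assumption about the allocator design): driver_allocated_memory also
        # counts cached heaps retained by the MPSAllocator, so it is expected to be
        # >= current_allocated_memory at any point in time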
7871 x = torch.ones(1024 * 1024 * 8, device="mps")
7872 # get memory allocations after allocating tensor x
7873 current_alloc_after = torch.mps.current_allocated_memory()
7874 driver_alloc_after = torch.mps.driver_allocated_memory()
7875 # current and driver memory allocations must have
7876 # grown at this point
7877 self.assertTrue(current_alloc_after > current_alloc_before)
7878 self.assertTrue(driver_alloc_after > driver_alloc_before)
7879
Kulin Seth8df56af2024-06-12 16:03:57 +00007880 def test_mps_allocator_stats(self):
7881 max_memory = torch.mps.recommended_max_memory()
        print(f"Recommended Max Memory: {max_memory / 1024 ** 3} GB")
7883 self.assertTrue(max_memory > 0)
7884
    # To verify this test, open the Xcode Instruments "Metal System Trace" or "Logging"
    # tool, press record, run this Python test, then press stop. Next, expand
    # os_signposts->PyTorchMPS and check whether events or intervals are logged,
    # as in this example:
    # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
7890 def test_mps_profiler_module(self):
7891 with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
7892 # just running some ops to capture the OS Signposts traces for profiling
7893 net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
7894 .to(device='mps', dtype=torch.float)
7895 x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
7896 x = net1(x)
7897
7898 torch.mps.profiler.start(mode="interval", wait_until_completed=True)
7899 # just running some ops to capture the OS Signposts traces for profiling
7900 x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
7901 x = net1(x)
7902 torch.mps.profiler.stop()
7903
Ramin Azarmehrcdfd0ea2023-08-08 03:45:45 +00007904 def test_mps_event_module(self):
7905 startEvent = torch.mps.Event(enable_timing=True)
7906 startEvent.record()
7907 net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
7908 .to(device='mps', dtype=torch.float)
7909 x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
7910 x = net1(x)
7911 endEvent = torch.mps.Event(enable_timing=True)
7912 endEvent.record()
7913 elapsedTime = startEvent.elapsed_time(endEvent)
7914 self.assertTrue(elapsedTime > 0.0)
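        # Note: assuming torch.mps.Event mirrors the torch.cuda.Event API, an event
        # can also block the CPU until all work recorded before it has finished, e.g.:
        #   endEvent.synchronize()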
7915
Denis Vieriude7ec2d2023-05-25 23:32:29 +00007916 def test_jit_save_load(self):
7917 m = torch.nn.Module()
7918 m.x = torch.rand(3, 3, device='mps')
7919 buffer = io.BytesIO()
7920 torch.jit.save(torch.jit.script(m), buffer)
7921 buffer.seek(0)
7922 n = torch.jit.load(buffer)
7923 self.assertEqual(n.x, m.x)
7924
Nikita Shulga29cde002023-04-05 21:24:45 +00007925 # Test random_, random_.to and random_.from
Kulin Sethe011a8e2022-05-13 18:28:53 +00007926 def test_random(self):
7927 def helper(shape, low, high, dtype=torch.int32):
7928
Kulin Sethe011a8e2022-05-13 18:28:53 +00007929 mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')
7930
            # We can't reliably check the mean and std.
7932 # Just make sure we don't return constant values
Nikita Shulga29cde002023-04-05 21:24:45 +00007933 self.assertNotEqual(mps_out.float().mean().item(), 0.)
7934 self.assertNotEqual(mps_out.float().std().item(), 0.)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007935
7936 helper([100, 100], 0, 10)
7937 helper([100, 100], 23, 89)
7938 helper([100, 100], 23, 89, dtype=torch.float32)
7939 helper([100, 100], 23, 89, dtype=torch.int64)
7940 helper([100, 100], 0, 2, dtype=torch.bool)
7941
Nikita Shulga29cde002023-04-05 21:24:45 +00007942 # Test random_
7943 for dtype in [torch.bool, torch.int8, torch.uint8, torch.int32, torch.float16, torch.float32]:
7944 x = torch.empty(10, 10, dtype=dtype, device='mps')
7945 x.random_()
7946 self.assertNotEqual(x.max().item(), 0)
7947
Kulin Seth83239352022-06-10 13:16:21 +00007948 # Test exponential
7949 def test_exponential(self):
7950 def helper(shape, lamda, dtype=torch.float32):
7951
7952 mps_out = torch.zeros(shape, device='mps', dtype=dtype)
7953 mps_out.exponential_(lamda)
7954
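            # For Exponential(lamda): mean = 1 / lamda and variance = 1 / lamda**2;
            # with 10000 samples the estimates should fall well within a 20% tolerance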
            self.assertEqual(mps_out.to('cpu').float().mean(), 1 / lamda, rtol=0.2, atol=0.2)
            self.assertEqual(mps_out.to('cpu').float().std() ** 2, 1 / (lamda ** 2), rtol=0.2, atol=0.2)
7957
7958 for dtype in [torch.float32, torch.float16]:
7959 helper([100, 100], 2, dtype)
7960 helper([100, 100], 1, dtype)
7961 helper([100, 100], 3, dtype)
7962 helper([100, 100], 0.5, dtype)
7963
7964 def test_exponential_1(self):
7965 rate = torch.randn(5, 5).abs().requires_grad_()
7966 rate_1d = torch.randn(1).abs().requires_grad_()
7967 self.assertEqual(Exponential(rate).sample().size(), (5, 5))
7968 self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
7969 self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
7970 self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
7971 self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
7972 self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
7973
Kulin Sethe011a8e2022-05-13 18:28:53 +00007974 # Test add
Li-Huai (Allan) Lin2f66b572023-03-07 17:17:53 +00007975 def test_add_sub(self):
7976 def helper(shape, alpha, op_name, inplace):
7977 if op_name == "add":
7978 op = torch.Tensor.add_ if inplace else torch.add
7979 elif op_name == "sub":
7980 op = torch.Tensor.sub_ if inplace else torch.sub
7981
Kulin Setha6347f52022-06-07 18:22:10 +00007982 for dtype in [torch.float16, torch.float32]:
7983 cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
7984 mps_x = cpu_x.detach().clone().to('mps')
Kulin Sethe011a8e2022-05-13 18:28:53 +00007985
Kulin Setha6347f52022-06-07 18:22:10 +00007986 cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
7987 mps_y = cpu_y.detach().clone().to('mps')
Kulin Sethe011a8e2022-05-13 18:28:53 +00007988
Li-Huai (Allan) Lin2f66b572023-03-07 17:17:53 +00007989 cpu_out = op(cpu_x, cpu_y, alpha=alpha)
7990 mps_out = op(mps_x, mps_y, alpha=alpha)
Kulin Setha6347f52022-06-07 18:22:10 +00007991 # fp16 isn't accurate when alpha is passed
7992 # TODO: remove or fix 'tol' when we fix problems with fp16
Li-Huai (Allan) Lin2f66b572023-03-07 17:17:53 +00007993 tol = 2e-3 if dtype is torch.float16 else None
Kulin Setha6347f52022-06-07 18:22:10 +00007994 self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol)
Li-Huai (Allan) Lin2f66b572023-03-07 17:17:53 +00007995 if not (cpu_y.shape != () and inplace): # in-place output cannot be broadcasted.
7996 # create a scalar tensor
7997 cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
7998 mps_s = cpu_s.detach().clone().to('mps')
7999 # primary tensor is scalar
8000 self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y))
Kulin Setha6347f52022-06-07 18:22:10 +00008001 # create a scalar tensor
8002 cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
8003 mps_s = cpu_s.detach().clone().to('mps')
Kulin Setha6347f52022-06-07 18:22:10 +00008004 # secondary tensor is scalar
Li-Huai (Allan) Lin2f66b572023-03-07 17:17:53 +00008005 self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008006
Li-Huai (Allan) Lin2f66b572023-03-07 17:17:53 +00008007
8008 for op_name, inplace in product(["add", "sub"], [True, False]):
8009 helper((), 0.0, op_name, inplace)
8010 helper((2, 8, 4, 5), 0.0, op_name, inplace)
8011 helper((2, 8, 4, 5), 0.1, op_name, inplace)
8012 helper((2, 8, 4, 5), 1.0, op_name, inplace)
8013 helper((2, 8, 3, 5), 0.1, op_name, inplace)
8014 helper((2, 8, 3, 5), 0.2, op_name, inplace)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008015
8016 # Test add
8017 def test_add_scalars(self):
Kulin Setha6347f52022-06-07 18:22:10 +00008018 def helper(alpha):
8019 for dtype in [torch.float16, torch.float32]:
8020 cpu_x = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
8021 x = cpu_x.detach().clone().to('mps')
Kulin Sethe011a8e2022-05-13 18:28:53 +00008022
Kulin Setha6347f52022-06-07 18:22:10 +00008023 cpu_y = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False)
8024 y = cpu_y.detach().clone().to('mps')
Kulin Sethe011a8e2022-05-13 18:28:53 +00008025
Kulin Setha6347f52022-06-07 18:22:10 +00008026 cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha)
8027 out = torch.add(x, y, alpha=alpha)
8028 # fp16 isn't accurate when alpha is passed
8029 tol = 1e-3 if dtype is torch.float16 else None
8030 self.assertEqual(out, cpu_out, rtol=tol, atol=tol)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008031
Kulin Setha6347f52022-06-07 18:22:10 +00008032 helper(1.0)
8033 helper(0.0)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008034 helper(0.1)
8035 helper(0.2)
8036
Nikita Shulga06f874e2022-06-25 02:21:34 +00008037 # Test int32 tensor + int64 scalar add
8038 # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
8039 x = torch.ones(4, dtype=torch.int32, device='mps')
8040 self.assertEqual(x + 1, torch.full((4,), 2, dtype=torch.int32, device='mps'))
PyTorch MergeBotcba96362022-12-02 21:36:13 +00008041 self.assertTrue(torch.equal(x + 1.5, torch.full((4,), 2.5, device='mps')))
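        # Python scalars are weakly typed in promotion: an int keeps the tensor's
        # int32 dtype, while a float promotes the result to the default float dtype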
Nikita Shulga06f874e2022-06-25 02:21:34 +00008042
Kulin Seth50f7b402022-06-09 17:33:06 +00008043 def test_types_binary_op(self):
8044 # Float * Bool
8045 cpu_x = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([True, False, True, False, True], device="cpu")
8046 mps_x = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([True, False, True, False, True], device="mps")
8047 self.assertEqual(cpu_x, mps_x)
8048 # Float * Int64
8049 cpu_y = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([1, 0, 1, 0, 1], device="cpu")
8050 mps_y = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([1, 0, 1, 0, 1], device="mps")
8051 self.assertEqual(cpu_y, mps_y)
8052
Kulin Setha6347f52022-06-07 18:22:10 +00008053 def test_unary_ops(self):
8054 def helper(shape, op):
8055 for dtypef in [torch.float32]:
8056 cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8057 mps_x = cpu_x.detach().clone().to('mps')
8058 self.assertEqual(op(cpu_x), op(mps_x))
8059
8060 for dtypei in [torch.int32, torch.int16]:
8061 cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=dtypei, requires_grad=False)
8062 mps_x = cpu_x.to('mps')
8063 self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4)
Peter Pham74dfdc52023-12-15 23:14:03 +00008064 # test slice
8065 for dtypef in [torch.float32]:
8066 cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8067 mps_x = cpu_x.detach().clone().to('mps')
8068 cpu_slice = cpu_x[:, ::2, :, :]
8069 mps_slice = mps_x[:, ::2, :, :]
8070 self.assertEqual(op(cpu_slice), op(mps_slice))
8071 # test view
8072 for dtypef in [torch.float32]:
8073 cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8074 mps_x = cpu_x.detach().clone().to('mps')
                # create a view of the tensor by merging the last two dimensions
8076 combined_dim = shape[-1] * shape[-2]
8077 reshaped_dims = list(shape[:-2]) + [combined_dim]
8078 cpu_view = cpu_x.view(*reshaped_dims)
8079 mps_view = mps_x.view(*reshaped_dims)
8080 self.assertEqual(op(cpu_view), op(mps_view))
Kulin Setha6347f52022-06-07 18:22:10 +00008081
8082 helper((2, 8, 4, 5), torch.exp)
8083 helper((2, 8, 3, 5), torch.exp2)
arnaudstiegler16e35bd2022-10-26 17:45:46 +00008084 helper((2, 8, 3, 5), torch.expm1)
Kulin Setha6347f52022-06-07 18:22:10 +00008085 helper((2, 8, 3, 5), torch.log)
8086 helper((2, 8, 3, 5), torch.cos)
Peter Phambba06ad2023-07-23 01:36:43 +00008087 helper((2, 8, 3, 5), torch.erfinv)
8088
Kulin Setha6347f52022-06-07 18:22:10 +00008089
Peter Stefekd2c24ec2023-07-19 03:56:35 +00008090 def test_non_dense_in_storage_unary_ops(self):
8091 def helper(op):
8092 for dtypef in [torch.float32]:
8093 cpu_x = torch.randn(100, device='cpu', dtype=dtypef, requires_grad=False)
8094 mps_x = cpu_x.detach().clone().to('mps')
8095 self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]))
8096
8097 for dtypei in [torch.int32, torch.int16, torch.int8]:
8098 cpu_x = torch.randint(127, device='cpu', size=(100,), dtype=dtypei, requires_grad=False)
8099 mps_x = cpu_x.to('mps')
8100 self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]), rtol=1e-4, atol=1e-4)
8101
8102 helper(torch.exp)
8103 helper(torch.exp2)
8104 helper(torch.expm1)
8105 helper(torch.log)
8106 helper(torch.cos)
8107
Li-Huai (Allan) Lin538114d2023-11-14 22:03:21 +00008108 def test_unary_ops_storage_offset_strided(self):
8109 def helper(shape, op, inplace, dtype=torch.float32):
8110 # test in-place with storage_offset
8111 cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
8112 mps_x = cpu_x.detach().clone().to('mps')
8113 y = op(mps_x[1])
8114 cpu_y = op(cpu_x[1])
8115 self.assertEqual(y, cpu_y)
8116
8117
8118 # See https://github.com/pytorch/pytorch/issues/100764
8119 if not inplace:
8120 cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
8121 mps_x = cpu_x.detach().clone().to('mps')
8122 cpu_y = torch.empty(shape, device='cpu', dtype=dtype).t()
8123 mps_y = cpu_y.detach().clone().to('mps')
8124 op(cpu_x, out=cpu_y)
8125 op(mps_x, out=mps_y)
8126 self.assertEqual(mps_y, cpu_y)
8127
8128
8129 helper((5, 5), torch.exp, False)
8130 helper((5, 5), torch.cos, False)
8131 helper((5, 5), torch.neg, False)
8132 helper((5, 5), torch.tanh, False)
8133 helper((5, 5), torch.tanh_, True)
8134
Kulin Sethe011a8e2022-05-13 18:28:53 +00008135 def test_atan2(self):
8136 def helper(shape):
8137 input_cpu = torch.randn(shape)
8138 input_mps = input_cpu.detach().clone().to("mps")
8139
8140 other_cpu = torch.randn(shape)
8141 other_mps = other_cpu.detach().clone().to("mps")
8142
8143 atan2_cpu = torch.atan2(input_cpu, other_cpu)
8144 atan2_mps = torch.atan2(input_mps, other_mps)
8145
8146 self.assertEqual(atan2_cpu, atan2_mps.to("cpu"))
8147
8148 helper(4)
8149 helper(10000)
8150 helper((10000, 40))
8151
Kulin Seth6a842e32022-10-03 21:05:30 +00008152 def test_multinomial(self):
8153 # Test with num_dist = 1
8154 def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
8155 cpu_prob_tensor = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
8156 prob_tensor = cpu_prob_tensor.detach().clone().to('mps')
8157
8158 mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement)
            if (not replacement):
                # without replacement, each category must be drawn exactly once
                sorted_out = mps_out.to('cpu').sort(dim=-1).values
                self.assertEqual(sorted_out, torch.arange(sorted_out.shape[-1]).expand_as(sorted_out))
            else:
                # Compare empirical moments with the theoretical values
                self.assertEqual(mps_out.to('cpu').float().mean(), compare_mean, rtol=0.1, atol=0.1)
                self.assertEqual(mps_out.to('cpu').float().std() ** 2, compare_var, rtol=0.25, atol=0.25)
8165
8166 # TODO: Add tests for data types
8167 helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)
8168 helper(np.array([[.2, .2, .2, .2, .2]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8169 helper(np.array([[1, 1, 1, 1, 1]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8170 helper(np.array([1, 1, 1, 1, 1]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8171 helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008172
Nikita Shulga10a1efb2023-02-05 18:21:29 +00008173 def test_cumsum_dim_check(self):
8174 x = torch.rand((3, 3), device="mps")
8175 self.assertEqual(x.cumsum(1), x.cumsum(-1))
8176 self.assertEqual(x.cumsum(0), x.cumsum(-2))
8177 self.assertRaises(IndexError, lambda: x.cumsum(2))
8178 self.assertRaises(IndexError, lambda: x.cumsum(-3))
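        # dim is wrapped like an index: for a 2-D tensor the valid range is [-2, 1]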
8179
Peter Stefek97e50552023-08-01 21:51:16 +00008180 def test_cumprod_dim_check(self):
8181 x = torch.rand((3, 3), device="mps")
8182 self.assertEqual(x.cumprod(1), x.cumprod(-1))
8183 self.assertEqual(x.cumprod(0), x.cumprod(-2))
8184 self.assertRaises(IndexError, lambda: x.cumprod(2))
8185 self.assertRaises(IndexError, lambda: x.cumprod(-3))
8186
Li-Huai (Allan) Lin88a659e2023-11-08 16:19:38 -08008187class TestLogical(TestCaseMPS):
8188 def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
8189 return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
8190
8191 def test_logical_not(self):
8192 def helper(x):
8193 cpu_x = x
8194 x = cpu_x.detach().clone().to('mps')
8195
8196 result = torch.logical_not(x)
8197 result_cpu = torch.logical_not(cpu_x)
8198
8199 self.assertEqual(result, result_cpu)
8200
8201 helper(self._wrap_tensor([1, 1, 0, 0]))
8202 helper(self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True))
8203 helper(self._wrap_tensor([True, True, False, False]))
8204 helper(self._wrap_tensor(1))
8205 helper(self._wrap_tensor(0))
8206 helper(self._wrap_tensor(True))
8207 helper(self._wrap_tensor(False))
8208
8209 def test_logical_and(self):
8210 def helper(x, other):
8211 cpu_x = x
8212 x = cpu_x.detach().clone().to('mps')
8213
8214 cpu_other = other
8215 other = cpu_other.detach().clone().to('mps')
8216
8217 result = torch.logical_and(x, other)
8218 result_cpu = torch.logical_and(cpu_x, cpu_other)
8219 self.assertEqual(result, result_cpu)
8220
8221 helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8222 helper(
8223 self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8224 self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8225 )
8226 helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8227 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8228 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8229 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8230 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8231
8232 def test_logical_or(self):
8233 def helper(x, other):
8234 cpu_x = x
8235 x = cpu_x.detach().clone().to('mps')
8236
8237 cpu_other = other
8238 other = cpu_other.detach().clone().to('mps')
8239
8240 result = torch.logical_or(x, other)
8241 result_cpu = torch.logical_or(cpu_x, cpu_other)
8242
8243 self.assertEqual(result, result_cpu)
8244
8245 helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8246 helper(
8247 self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8248 self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8249 )
8250 helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8251 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8252 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8253 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8254 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8255
8256 def test_logical_xor(self):
8257 def helper(x, other):
8258 cpu_x = x
8259 x = cpu_x.detach().clone().to('mps')
8260
8261 cpu_other = other
8262 other = cpu_other.detach().clone().to('mps')
8263
8264 result = torch.logical_xor(x, other)
8265 result_cpu = torch.logical_xor(cpu_x, cpu_other)
8266
8267 self.assertEqual(result, result_cpu)
8268
8269 helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8270 helper(
8271 self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8272 self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8273 )
8274 helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8275 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8276 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8277 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8278 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8279
8280 def test_min_max(self):
8281 def helper(dtype):
8282 for _ in range(10):
8283 if dtype == torch.float32 or dtype == torch.float16:
8284 x = torch.randn((30, 15), device='mps', dtype=dtype)
8285 else:
8286 x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
8287 x_cpu = x.to("cpu")
8288
8289 y = x.max()
8290 y_cpu = x_cpu.max()
8291 self.assertEqual(y, y_cpu)
8292
8293 z = x.min()
8294 z_cpu = x_cpu.min()
8295 self.assertEqual(z, z_cpu)
8296
8297 [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
8298
Joona Havukainenc451d102024-05-01 23:14:05 +00008299 def test_isin(self):
8300 def helper(dtype):
8301 shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),
8302 ([5], [10]), ([0], [5]), ([5], [0])]
8303 for shape_tuple in shapes:
8304 for inverted in [True, False]:
8305 if dtype.is_floating_point:
8306 # Half is not supported for CPU isin. Compute reference in FP32
8307 A = torch.randn(size=shape_tuple[0], device='cpu', dtype=torch.float32)
8308 B = torch.randn(size=shape_tuple[1], device='cpu', dtype=torch.float32)
8309 else:
8310 A = torch.randint(0, 100, size=shape_tuple[0], device='cpu', dtype=dtype)
8311 B = torch.randint(0, 100, size=shape_tuple[1], device='cpu', dtype=dtype)
8312
8313 A_mps = A.clone().detach().to('mps')
8314 B_mps = B.clone().detach().to('mps')
8315
8316 cpu_ref = torch.isin(A, B, invert=inverted)
jhavukainend28868c2024-05-20 20:23:53 +00008317 if dtype in [torch.float16, torch.bfloat16]:
                        cpu_ref = cpu_ref.type(dtype)
8319
8320 mps_out = torch.isin(A_mps, B_mps, invert=inverted)
8321 self.assertEqual(mps_out, cpu_ref)
8322
jhavukainend28868c2024-05-20 20:23:53 +00008323 dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]
8324 if product_version < 14.0:
8325 # Int types expected to fail on MacOS < 14.0
8326 dtypes = [torch.float32, torch.float16, torch.bfloat16]
Joona Havukainenc451d102024-05-01 23:14:05 +00008327
jhavukainend28868c2024-05-20 20:23:53 +00008328 [helper(dtype) for dtype in dtypes]
8329
Joona Havukainenc451d102024-05-01 23:14:05 +00008330 def test_isin_asserts(self):
8331 A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
8332 B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)
8333 with self.assertRaisesRegex(RuntimeError, 'Expected elements.dtype()*'):
8334 out = torch.isin(A, B)
8335
8336
8337 C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
8338 D = torch.randn(size=[1, 4], device='cpu', dtype=torch.float32)
8339 with self.assertRaisesRegex(RuntimeError, 'Expected elements.is_mps()*'):
8340 out = torch.isin(C, D)
8341
Li-Huai (Allan) Lin88a659e2023-11-08 16:19:38 -08008342class TestSmoothL1Loss(TestCaseMPS):
8343
8344 def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
8345 # CPU
8346 input_cpu = torch.randn(4, 7, requires_grad=requires_grad)
8347 target_cpu = torch.randn(4, 7)
8348
8349 # MPS
8350 input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
8351 target_mps = target_cpu.detach().clone().to('mps')
8352
8353 smooth_l1_loss_cpu = F.smooth_l1_loss(input_cpu, target_cpu, beta=1.0, reduction=reduction)
8354 smooth_l1_loss_mps = F.smooth_l1_loss(input_mps, target_mps, beta=1.0, reduction=reduction)
8355
8356 self.assertEqual(smooth_l1_loss_cpu, smooth_l1_loss_mps)
8357
8358 if requires_grad:
8359 smooth_l1_loss_cpu.backward()
8360 smooth_l1_loss_mps.backward()
8361 self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))
8362
8363 return smooth_l1_loss_cpu, smooth_l1_loss_mps
8364
8365 def test_smooth_l1_loss_reduction_none(self):
8366 self._smooth_l1_loss_helper(reduction="none")
8367
8368 def test_smooth_l1_loss_reduction_mean(self):
8369 self._smooth_l1_loss_helper(reduction="mean")
8370
8371 def test_smooth_l1_loss_reduction_sum(self):
8372 self._smooth_l1_loss_helper(reduction="sum")
8373
8374 def test_smooth_l1_loss_reduction_mean_backward(self):
8375 self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)
8376
    def test_smooth_l1_loss_reduction_sum_backward(self):
8378 self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)
8379
8380class TestNLLLoss(TestCaseMPS):
8381 def test_nll_loss_mismatched_batch(self, device='mps'):
8382 x = torch.randn((10, 3), requires_grad=True, device=device)
8383 # t should have size (10,)
8384 t = torch.zeros((3,), dtype=torch.int64, device=device)
8385 with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
8386 F.nll_loss(x, t)
8387
8388 def test_nll_loss_out_of_bounds_ignore_index(self):
8389
8390 def test_nll_loss_out_of_bounds_ignore_index_helper(device):
8391 output = []
8392 x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8393 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8394 t1 = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
8395 t2 = torch.tensor([0, 1, 1, 0, -100, 2], dtype=torch.int64, device=device)
8396 for reduction in ['mean', 'none']:
8397 # out of bound ignore_index
8398 output.append(F.nll_loss(x, t1, ignore_index=255, reduction=reduction))
8399 # default ignore_index
8400 output.append(F.nll_loss(x, t2, reduction=reduction))
8401 return output
8402
8403 output_cpu = test_nll_loss_out_of_bounds_ignore_index_helper(device='cpu')
8404 output_mps = test_nll_loss_out_of_bounds_ignore_index_helper(device='mps')
8405
8406 for cpu, mps in zip(output_cpu, output_mps):
8407 self.assertEqual(cpu, mps)
8408
8409 def test_nll_loss_invalid_target_dim(self):
8410
8411 def _test_nll_loss_invalid_target_dim(device):
8412 output = []
8413 x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8414 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8415 t = torch.zeros((6, 2), dtype=torch.int64, device=device)
8416 with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
8417 F.nll_loss(x, t)
8418
8419 _test_nll_loss_invalid_target_dim(device='cpu')
8420 _test_nll_loss_invalid_target_dim(device='mps')
8421
8422 def test_nll_loss_invalid_weights(self):
8423
8424 def _test_nll_loss_invalid_weights(device):
8425 x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8426 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8427 t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
8428 invalid_weights = [
8429 torch.zeros(4, device=device),
8430 torch.zeros((1, 3), device=device),
8431 ]
8432 msg = "weight tensor should be defined either for all 3 classes or no classes"
8433 for weight in invalid_weights:
8434 with self.assertRaisesRegex(RuntimeError, msg):
8435 F.nll_loss(x, t, weight=weight)
8436
8437 _test_nll_loss_invalid_weights(device='cpu')
8438 _test_nll_loss_invalid_weights(device='mps')
8439
8440 def _nll_loss_helper(self, input_size, reduction, expected):
8441
8442 # CPU
8443 input = torch.rand(input_size, requires_grad=True, device='cpu')
8444 num_channels = input_size[1]
8445 target_size = (input_size[0], ) + tuple(input_size[2:])
8446 target = torch.randint(num_channels, target_size, device='cpu')
8447 weights = torch.randn(num_channels)
8448
8449 # MPS
8450 input_mps = input.detach().clone().to('mps').requires_grad_()
8451 target_mps = target.detach().clone().to('mps')
8452 weights_mps = weights.to("mps")
8453
8454 output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
8455 output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
8456 self.assertEqual(output_cpu, output_mps.to('cpu'))
8457
8458 output_cpu.sum().backward()
8459 output_mps.sum().backward()
8460 self.assertEqual(input.grad, input_mps.grad.to('cpu'))
8461
8462 def _nll_loss_1d_helper(self, input_size, reduction):
8463
8464 # CPU
8465 input = torch.rand(input_size, requires_grad=True, device='cpu')
8466 num_channels = input_size[0]
8467 target = torch.randint(num_channels, [], device='cpu')
8468
8469 # MPS
8470 input_mps = input.detach().clone().to('mps').requires_grad_()
8471 target_mps = target.detach().clone().to('mps')
8472
8473 output_cpu = F.nll_loss(input, target, reduction=reduction)
8474 output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
8475 self.assertEqual(output_cpu, output_mps.to('cpu'))
8476
8477 output_cpu.sum().backward()
8478 output_mps.sum().backward()
8479 self.assertEqual(input.grad, input_mps.grad.to('cpu'))
8480
8481 def test_nll_loss_1d(self, device='cpu'):
8482 self._nll_loss_1d_helper([10], "none")
8483 self._nll_loss_1d_helper([10], "mean")
8484 self._nll_loss_1d_helper([10], "sum")
8485
8486 def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
8487 self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
8488 self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))
8489 self._nll_loss_helper([2, 3, 1, 7], "none", torch.empty([2, 1, 7], device=device))
8490 self._nll_loss_helper([2, 3, 5, 1], "none", torch.empty([2, 5, 1], device=device))
8491 self._nll_loss_helper([2, 3, 5, 7, 1], "none", torch.empty([2, 5, 7, 1], device=device))
8492
8493 def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
8494 nan = torch.tensor(float('nan'), device=device)
8495 self._nll_loss_helper([1, 3], "mean", nan)
8496 self._nll_loss_helper([1, 3, 5, 7], "mean", nan)
8497 self._nll_loss_helper([2, 3, 1, 7], "mean", nan)
8498 self._nll_loss_helper([2, 3, 5, 1], "mean", nan)
8499 self._nll_loss_helper([2, 3, 5, 7, 1], "mean", nan)
8500
8501 def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
8502 zero = torch.tensor(0, device=device)
8503 self._nll_loss_helper([1, 3], "sum", zero)
8504 self._nll_loss_helper([1, 3, 5, 7], "sum", zero)
8505 self._nll_loss_helper([2, 3, 1, 7], "sum", zero)
8506 self._nll_loss_helper([2, 3, 5, 1], "sum", zero)
8507 self._nll_loss_helper([2, 3, 5, 7, 1], "sum", zero)
8508
8509 def test_nll_loss_byte_target_matches_long(self, device='cpu'):
8510 N, C = 10, 4
8511 input = torch.randn(N, C, device=device, requires_grad=True)
8512 target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
8513
8514 def compute_result_and_gradient(reduction, target_dtype):
8515 result, grad = {}, {}
8516 for dev in ['cpu', 'mps']:
8517 input_dev = input.to(dev)
8518 input_ = input_dev.detach()
8519 input_.requires_grad_()
8520
8521 target_dev = target.to(dev)
8522
8523 prob = F.log_softmax(input_, dim=-1)
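                # NLLLoss expects log-probabilities, hence log_softmax rather than softmax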
8524 loss = nn.NLLLoss(reduction=reduction)
8525 result[dev] = loss(prob, target_dev.to(target_dtype))
8526 result[dev].sum().backward()
8527 grad[dev] = input_.grad
8528
8529 return result, grad
8530
8531 for reduction in ["none", "mean", "sum"]:
8532 result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
8533 result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)
8534
8535 self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
8536 self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])
Soof Golane4fe11e2023-02-09 10:42:48 +00008537
8538class TestTopK(TestCase):
8539 def _test_topk(self, shape, largest):
8540 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
8541 x = cpu_x.detach().clone().to('mps')
8542 if isinstance(shape, tuple):
8543 for curr_dim, dim_size in enumerate(shape):
8544 for k in range(1, dim_size + 1):
8545 topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
8546 topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
8547 self.assertEqual(topk_values, topk_values_cpu)
8548 self.assertEqual(topk_indices, topk_indices_cpu)
8549 else:
8550 for k in range(1, shape):
8551 topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest)
8552 topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest)
8553 self.assertEqual(topk_values, topk_values_cpu)
8554 self.assertEqual(topk_indices, topk_indices_cpu)
8555
8556 def test_topk(self):
8557 largest_vals = [True, False]
8558 shapes = [
8559 # Zero Element Tensors
8560 0,
8561 (1, 0),
8562 (0, 1),
8563 (1, 0, 1),
8564 # Multiple Element Tensors
8565 1,
8566 2,
8567 (5, 1),
8568 (1, 5),
8569 (5, 9, 7, 4),
8570 ]
8571
8572 for shape in shapes:
8573 for largest_val in largest_vals:
8574 with self.subTest(shape=shape, largest_val=largest_val):
8575 self._test_topk(shape, largest_val)
8576
Kulin Sethe011a8e2022-05-13 18:28:53 +00008577class TestNNMPS(NNTestCase):
8578
8579 def _create_basic_net(self):
8580 class Layer(nn.Module):
8581 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008582 super().__init__()
Kulin Sethe011a8e2022-05-13 18:28:53 +00008583 self.layer_dummy_param = Parameter(torch.empty(3, 5))
Jason Lubc880282023-08-08 15:27:34 +00008584 self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))
Kulin Sethe011a8e2022-05-13 18:28:53 +00008585
8586 class Net(nn.Module):
8587 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008588 super().__init__()
Kulin Sethe011a8e2022-05-13 18:28:53 +00008589 self.l1 = Layer()
8590 self.dummy_param = Parameter(torch.empty(3, 5))
Jason Lubc880282023-08-08 15:27:34 +00008591 self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))
Kulin Sethe011a8e2022-05-13 18:28:53 +00008592
8593 l = Layer()
8594 n = Net()
8595 s = nn.Sequential(n, n)
8596
8597 return l, n, s
8598
8599 def test_requires_grad_(self):
8600 m = self._create_basic_net()[-1]
8601 assert len(list(m.buffers())) > 0, 'invalid test'
        assert all(not b.requires_grad for b in m.buffers()), 'invalid test'
        assert len(list(m.parameters())) > 0, 'invalid test'
        assert all(p.requires_grad for p in m.parameters()), 'invalid test'
8605 for requires_grad in (False, True):
8606 self.assertIs(m.requires_grad_(requires_grad), m)
8607 for p in m.parameters():
8608 self.assertEqual(p.requires_grad, requires_grad)
8609 for b in m.buffers():
8610 self.assertFalse(b.requires_grad)
8611
8612 def test_module_backcompat(self):
8613 from torch.serialization import SourceChangeWarning
8614 path = download_file('https://download.pytorch.org/test_data/linear.pt')
8615 with warnings.catch_warnings():
8616 warnings.simplefilter('ignore', SourceChangeWarning)
8617 m = torch.load(path)
8618 input = torch.randn(2, 3, dtype=torch.float)
8619 self.assertEqual(m(input).size(), (2, 5))
8620
8621 def test_conv_backcompat(self):
8622 from torch.serialization import SourceChangeWarning
8623 # This file was generated by running on PyTorch 1.0.1 on Python 2:
8624 #
8625 # import torch
8626 # from torch import nn
8627 # m = nn.Conv2d(1, 1, 1)
8628 # torch.save(m, 'legacy_conv2d.pt')
8629 #
8630 # NB: This Pickle also contains some Unicode data!
8631 path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
8632 with warnings.catch_warnings():
8633 warnings.simplefilter('ignore', SourceChangeWarning)
8634 m = torch.load(path, encoding='utf-8')
8635 input = torch.randn((1, 1, 1, 1), dtype=torch.float)
8636 self.assertEqual(m(input).size(), (1, 1, 1, 1))
8637
Kulin Seth017b0ae2022-05-31 02:09:03 +00008638 def test_conv_expand(self):
8639 device = 'mps'
8640 input_ = torch.rand(2, 3, 16, 16, device=device)
8641 kernel = torch.rand(1, 1, 3, 11, device=device)
8642 tmp_kernel = kernel.expand(-1, 3, -1, -1)
8643 output = F.conv2d(input_, tmp_kernel, groups=1, padding=0, stride=1)
8644
8645 # The test should not crash
8646 def test_permute(self):
PumeTufc1c0cd2022-11-18 07:24:33 +00008647 M_cpu = torch.randn(5, 5)
8648 M_mps = M_cpu.to('mps')
8649
8650 output_cpu = M_cpu.permute(1, 0)
8651 output_mps = M_mps.permute(1, 0)
8652
8653 self.assertEqual(output_cpu, output_mps)
8654 self.assertEqual(output_cpu.size(), output_mps.size())
Kulin Seth017b0ae2022-05-31 02:09:03 +00008655
8656 # Printing of non_contiguous should not crash
8657 def test_print_non_contiguous(self):
8658 print(torch.ones(100, 100, device='mps').nonzero())
8659 print(torch.ones(100, 100, device='mps').nonzero().contiguous())
8660
Kulin Sethe011a8e2022-05-13 18:28:53 +00008661 def test_zero_grad(self):
8662 i = torch.randn(2, 5, requires_grad=True)
8663 module = nn.Linear(5, 5)
8664 for p in module.parameters():
8665 p.requires_grad = False
8666 module.zero_grad()
8667
8668 module.weight.requires_grad = True
8669 module.zero_grad()
8670 self.assertIsNone(module.weight.grad) # uninitialized grad
8671
8672 module(i).sum().backward()
8673 self.assertIsNotNone(module.weight.grad)
8674 self.assertGreater(module.weight.grad.data.abs().sum(), 0)
8675 module.zero_grad()
Jane Xub90496e2023-01-25 19:47:57 +00008676 self.assertIsNone(module.weight.grad)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008677
8678 module.bias.requires_grad = True
8679 module.zero_grad()
Jane Xub90496e2023-01-25 19:47:57 +00008680 self.assertIsNone(module.weight.grad)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008681 self.assertIsNone(module.bias.grad)
8682 module(i).sum().backward()
8683 self.assertIsNotNone(module.weight.grad)
8684 self.assertIsNotNone(module.bias.grad)
8685 self.assertGreater(module.weight.grad.data.abs().sum(), 0)
8686 self.assertGreater(module.bias.grad.data.abs().sum(), 0)
Jane Xub90496e2023-01-25 19:47:57 +00008687
8688 # Force set to zeros.
8689 module.zero_grad(set_to_none=False)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008690 self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
8691 self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
8692
Jane Xub90496e2023-01-25 19:47:57 +00008693 module.zero_grad()
Kulin Sethe011a8e2022-05-13 18:28:53 +00008694 self.assertIsNone(module.weight.grad)
Jane Xub90496e2023-01-25 19:47:57 +00008695 self.assertIsNone(module.bias.grad)
8696
Kulin Sethe011a8e2022-05-13 18:28:53 +00008697
8698 def test_no_grad(self):
8699 for dtype in [torch.bfloat16, torch.float, torch.double]:
8700 module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
8701 input = torch.randn(1, 2, 10, 10).to(dtype)
8702 x = input
8703 y = input.clone()
8704
8705 output = module(x)
8706 self.assertTrue(output.requires_grad)
8707 output.backward(torch.ones(1, 5, 10, 10))
8708
8709 with torch.no_grad():
8710 output2 = module(y)
8711 self.assertFalse(output2.requires_grad)
8712 self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))
8713
8714 def test_invalid_conv1d(self):
8715 for dtype in [torch.bfloat16, torch.float, torch.double]:
8716 module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
8717 input = torch.randn(1, 3, 4).to(dtype)
8718 with self.assertRaisesRegex(RuntimeError,
8719 r'Calculated padded input size per channel: \(4\). ' +
8720 r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'):
8721 module(input)
8722
8723 # Negative stride check
8724 module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
8725 input = torch.randn(1, 3, 4).to(dtype)
8726 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8727 module(input)
8728
8729 def test_conv2d_discontiguous_weight(self):
8730 # Test for https://github.com/pytorch/pytorch/issues/55781
8731 x = torch.ones(64, 16, 16, 16)
8732 weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
8733 self.assertFalse(weight.is_contiguous())
8734 y = torch.nn.functional.conv2d(x, weight, None)
8735 if torch.backends.mkldnn.is_available():
8736 # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
8737 with torch.backends.mkldnn.flags(enabled=False):
8738 y_ = torch.nn.functional.conv2d(x, weight, None)
8739 self.assertEqual(y, y_)
8740 self.assertEqual(y.sum(), 4186112.)
8741
8742 def test_invalid_conv2d(self):
8743 for dtype in [torch.bfloat16, torch.float, torch.double]:
8744 module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
8745 input = torch.empty(1, 1, 4, 4).to(dtype)
8746 self.assertRaises(RuntimeError, lambda: module(input))
8747
8748 module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
8749 input = torch.randn(1, 3, 1, 1)
8750 with self.assertRaisesRegex(RuntimeError,
8751 r'Calculated padded input size per channel: \(1 x 1\). ' +
8752 r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'):
8753 module(input)
8754
8755 # Negative stride check
8756 module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
8757 input = torch.randn(1, 3, 4, 4).to(dtype)
8758 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8759 module(input)
8760
8761 # Zero stride check
8762 module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
8763 input = torch.randn(1, 3, 4, 4).to(dtype)
8764 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8765 module(input)
8766
Nikita Shulgafa799132022-10-06 15:38:57 +00008767 # Input and weights on different devices
8768 self.assertRaisesRegex(RuntimeError,
8769 'must be on the same device',
8770 lambda: torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps')))
8771 self.assertRaisesRegex(RuntimeError,
8772 'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same',
8773 lambda: torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3)))
8774
8775
Kulin Sethe011a8e2022-05-13 18:28:53 +00008776 def test_conv2d_valid_padding(self, device='mps'):
8777 # Test F.conv2d padding='valid' is the same as no padding
8778 x = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
8779 y = torch.rand(1, 1, 1, 4, device=device).to(torch.float)
8780
8781 expect = F.conv2d(x, y)
8782 actual = F.conv2d(x, y, padding='valid')
8783 self.assertEqual(expect.to('cpu'), actual.to('cpu'))
8784
Nikita Shulga265d6aa2023-11-10 04:29:33 +00008785 def test_conv2d_backward_collision(self):
8786 # Test for https://github.com/pytorch/pytorch/issues/112998
8787 x = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True)
8788 m1 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps")
8789 m2 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps")
8790 y1, y2 = m1(x), m2(x)
8791 self.assertEqual(y1.shape, y2.shape)
8792 y1.sum().backward()
8793 # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
8794 y2.sum().backward()
8795
Lucas Steuernagel2e517b22023-12-15 23:05:01 +00008796 @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
8797 def test_conv3d_backward_collision(self):
8798 # Conv3D is only available from MacOS 13.2 onwards
8799 x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
8800 m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
8801 m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
8802 y1, y2 = m1(x), m2(x)
8803 self.assertEqual(y1.shape, y2.shape)
8804 y1.sum().backward()
8805 # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
8806 y2.sum().backward()
Nikita Shulga265d6aa2023-11-10 04:29:33 +00008807
Kulin Seth4858c562022-06-02 06:17:19 +00008808 def test_gemm_permute_transpose(self):
8809 batch_size = 32
8810 n = 20
8811 hidden = 768
8812 num_attention_heads = 12
8813 attention_head_size = hidden // num_attention_heads
8814
8815 def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
8816 new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
8817 x = x.view(new_x_shape)
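            # (batch, seq, hidden) -> (batch, heads, seq, head_size), the usual
            # multi-head attention layout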
8818 return x.permute(0, 2, 1, 3)
8819
8820 def attention2(key, *, workaround=False, device):
8821 key = transpose_for_scores(key)
8822 res = key.transpose(-1, -2)
8823 return res
8824
8825 A = torch.randn(batch_size, n, hidden)
8826 A_mps = A.detach().clone().to("mps")
8827
8828 r1 = attention2(A, device="cpu")
8829 r2 = attention2(A_mps, device="mps")
8830
8831 r2_cpu = r2.to("cpu")
8832 self.assertEqual(r1, r2_cpu)
8833
Nikita Shulgafd3a7262022-12-21 21:35:54 -08008834 def test_group_norm_backward(self, device='mps'):
8835 # See https://github.com/pytorch/pytorch/issues/88331 for more detail
8836 shape = [1, 4, 16, 16]
8837 x = torch.full(shape, 7.0, device=device)
8838
8839 target = torch.ones((1, 3, 128, 128), device=device)
8840
8841 conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
8842 conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
8843 norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device)
8844
8845 with torch.enable_grad():
8846 x = x.detach().requires_grad_()
8847 out = 5.5 * x
8848 out = conv_in(out)
8849 out = out + norm(out)
8850 out = out + norm(out)
8851 out = out + norm(out)
8852 out = F.interpolate(out, scale_factor=8.0, mode="nearest")
8853 out = norm(out)
8854 out = conv_out(out)
8855
8856 loss = (out - target).norm(dim=-1).sum()
8857 grad = -torch.autograd.grad(loss, x)[0]
8858 self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd')
8859
8860
Kulin Sethe011a8e2022-05-13 18:28:53 +00008861 # def test_conv2d_same_padding(self, device='mps'):
8862 # x = torch.rand(1, 1, 10, 11, device=device)
8863 # y = torch.rand(1, 1, 4, 5, device=device)
8864 # expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
8865 # actual = F.conv2d(x, y, padding='same')
8866 # self.assertEqual(expect.to('cpu'), actual.to('cpu'))
8867
8868 # # With dilation
8869 # y = torch.rand(1, 1, 3, 4, device=device)
8870 # expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
8871 # actual = F.conv2d(x, y, padding='same', dilation=2)
8872 # self.assertEqual(expect, actual)
8873
8874 # # Dilation with asymmetric padding
8875 # y = torch.rand(1, 1, 4, 4, device=device)
8876 # expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
8877 # actual = F.conv2d(x, y, padding='same', dilation=3)
8878 # self.assertEqual(expect, actual)
8879
8880
Li-Huai (Allan) Lin38e14402023-11-08 16:19:38 -08008881class TestPad(TestCaseMPS):
8882 def test_constant_pad(self):
8883 m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
8884 input_cpu = torch.randn(1, 16, 16, 16)
8885 input_mps = input_cpu.detach().clone().to("mps")
8886 r_cpu = m(input_cpu)
8887 r_mps = m(input_mps)
8888 self.assertEqual(r_cpu, r_mps.to("cpu"))
8889
8890 # Arbitrary input dimensions
8891 pad = (1, 1, 0, 0, 0, 0)
8892 value = 3.5
8893 input_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
8894 input_mps = input_cpu.detach().clone().to("mps")
8895 r_cpu = F.pad(input_cpu, pad=pad, value=value)
8896 r_mps = F.pad(input_mps, pad=pad, value=value)
8897 self.assertEqual(r_cpu, r_mps.to("cpu"))
8898
8899 def test_circular_pad(self):
8900 # https://github.com/pytorch/pytorch/issues/80856
8901 k_cpu = torch.ones(3, 3, 9, 9)
8902 k_mps = k_cpu.detach().clone().to("mps")
8903
8904 x_cpu = torch.rand(1, 3, 32, 32)
8905 x_mps = x_cpu.detach().clone().to("mps")
8906
8907 x_pad_cpu = F.pad(x_cpu, (2, 2, 2, 2), mode='circular')
8908 x_pad_mps = F.pad(x_mps, (2, 2, 2, 2), mode='circular')
8909
8910 y_cpu = F.conv2d(x_pad_cpu, k_cpu)
8911 y_mps = F.conv2d(x_pad_mps, k_mps)
8912
8913 self.assertEqual(y_cpu, y_mps.cpu())
8914
8915 def test_constant_pad_4d_warning(self):
8916 inputCPU = torch.rand((1, 2, 2, 2, 1, 1))
8917 inputMPS = inputCPU.detach().clone().to('mps')
8918 outputCPU = F.pad(inputCPU, [0, 0, 0, 0, 0, 0, 1, 0])
8919 outputMPS = F.pad(inputMPS, [0, 0, 0, 0, 0, 0, 1, 0])
8920 self.assertEqual(outputCPU, outputMPS)
8921
8922 def test_pad(self):
8923 def helper(shape, padding, op, value=0):
8924 inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
8925 inputCPU.retain_grad()
8926 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
8927
8928 if (op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]):
8929 padCriteria = op(padding, value)
8930 else:
8931 padCriteria = op(padding)
8932 outputCPU = padCriteria(inputCPU)
8933 outputMPS = padCriteria(inputMPS)
8934 self.assertEqual(outputCPU, outputMPS)
8935
8936 # backward pass (chose 0.6 just to have the grad_output != 1)
8937 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
8938 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
8939 self.assertEqual(inputCPU.grad, inputMPS.grad)
8940

        # 1D Padding
        helper((2, 4, 3), 2, nn.ReflectionPad1d)
        # verify that a change in the input shape doesn't break graph caching
        helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
        # Replication 1D
        helper((2, 1, 6), 3, nn.ReplicationPad1d)
        # Constant Pad 1D
        helper((2, 3, 4), 2, nn.ConstantPad1d)
        # Constant Pad 1D with single dimension input
        helper((16), (1, 2), nn.ConstantPad1d)

        # 2D Padding
        helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
        # verify that a change in the input shape doesn't break graph caching
        helper((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
        # this should make the padding (2, 2, 2, 2)
        helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
        # verify that a change in the padding shape doesn't break graph caching
        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
        # Constant Pad 2D
        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
        # input size < pad size
        helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d)
        # pad dims < input dims
        helper((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d)
        # pad dims == input dims
        helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
        # input.numel() == 0 but output.numel() > 0
        helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
        # pad dims < input dims - 2
        helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)

        # 3D Padding
        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
        # verify that a change in the padding shape doesn't break graph caching
        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
        # case where input_d == pad_front/back for ReplicationPad3d
        helper((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d)
        # Constant Pad 3D
        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
        # input size < pad size
        helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
        # check the workaround for the right-padding bug in Monterey
        helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d)
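
    # A minimal sketch (illustrative, not exercised above) of the equivalence
    # the module-based helper relies on: each nn.*Pad module is a thin wrapper
    # over F.pad with the matching mode, using the same last-dim-first padding
    # tuple convention.
    def _example_pad_module_vs_functional(self):
        x = torch.randn(2, 3, 4, device="mps")
        self.assertEqual(nn.ConstantPad1d((1, 2), 0.)(x), F.pad(x, (1, 2), value=0.))
        self.assertEqual(nn.ReflectionPad1d(1)(x), F.pad(x, (1, 1), mode='reflect'))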

    def test_constant_pad_nd_preserves_memory_format(self):
        nchw_tensor = torch.rand((1, 2, 5, 3))
        nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
        self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format))

        nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last)
        nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
        self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))


class TestLinalgMPS(TestCaseMPS):
    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
        dtype = t.dtype
        numpy_dtype = dtype
        alpha = 1.2 if alpha is None else alpha
        beta = 0.8 if beta is None else beta
        res1 = f(t, m, v, alpha=alpha, beta=beta)
        res2 = torch.full_like(res1, math.nan)
        if transpose_out:
            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
        f(t, m, v, alpha=alpha, beta=beta, out=res2)
        res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
        if beta != 0:
            res3 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
        res3 = torch.from_numpy(res3).to(dtype)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)

    def test_addmm(self, device="mps", dtype=torch.float32):
        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(torch.addmm, M, m1, m2)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)

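    # The beta=0 case above checks a documented contract: when beta == 0, the
    # input matrix is ignored entirely, so NaNs in it must not propagate. A
    # minimal standalone sketch of that property (not run as part of CI):
    def _example_addmm_beta_zero(self):
        M = torch.full((2, 2), math.nan, device="mps")
        out = torch.addmm(M, torch.ones(2, 3, device="mps"), torch.ones(3, 2, device="mps"), beta=0)
        assert not out.isnan().any()
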
    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
        dtype = t.dtype
        numpy_dtype = dtype
        alpha = 1.2 if alpha is None else alpha
        beta = 0.8 if beta is None else beta
        res1 = f(t, m, v, alpha=alpha, beta=beta)
        res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy())
        if beta != 0:
            res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
        res2 = torch.from_numpy(res2).to(dtype)
        self.assertEqual(res1, res2)

    def test_addr(self, device="mps", dtype=torch.float32):
        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = torch.randn(10, device=device).to(dtype)
        m2 = torch.randn(25, device=device).to(dtype)
        self._test_addr(torch.addr, M, m1, m2)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = torch.randn(10, device=device).to(dtype)
        m2 = torch.randn(25, device=device).to(dtype)
        self._test_addr(torch.addr, M, m1, m2, beta=0)

    def test_matrix_rank(self, device="mps", dtype=torch.float32):
        matrix_rank = torch.linalg.matrix_rank

        def run_test(shape0, shape1, batch):
            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
            rank_a = matrix_rank(a)

            self.assertEqual(rank_a, matrix_rank(a.mH))
            aaH = torch.matmul(a, a.mH)
            rank_aaH = matrix_rank(aaH)
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
            self.assertEqual(rank_aaH, rank_aaH_hermitian)
            aHa = torch.matmul(a.mH, a)
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))

            # check against NumPy
            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))

            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))

            # hermitian flag for NumPy was added in 1.14.0
            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
                self.assertEqual(rank_aaH_hermitian,
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
                self.assertEqual(matrix_rank(aaH, 0.01, True),
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))

            # check out= variant
            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
            ans = matrix_rank(a, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, rank_a)

        shapes = (3, 13)
        batches = ((), (0, ), (4, ), (3, 5, ))
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
            # tolerate only the NotImplementedError raised for the missing
            # downstream op; anything else should fail the test
            # TODO: remove this once aten::_linalg_svd is implemented for MPS
            try:
                run_test(shape0, shape1, batch)
            except NotImplementedError as e:
                with self.assertRaisesRegex(
                        NotImplementedError,
                        "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
                    raise e

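    # For completeness, a sketch of the user-facing workaround for such missing
    # ops: running with PYTORCH_ENABLE_MPS_FALLBACK=1 (set before torch is
    # imported, hence the subprocess) makes unimplemented MPS ops fall back to
    # the CPU. The op chosen here is illustrative, and this is not a test.
    def _example_mps_cpu_fallback(self):
        env = os.environ.copy()
        env["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        code = "import torch; print(torch.linalg.svd(torch.rand(3, 3, device='mps')))"
        subprocess.run([sys.executable, "-c", code], env=env, check=False)
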
    def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test_main(A, hermitian):
            # Testing against definition for pseudo-inverses
            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
            np_A = A.cpu().numpy()
            np_A_pinv = A_pinv.cpu().numpy()
            if A.numel() > 0:
                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision)
                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision)
                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
            else:
                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))

            # Check out= variant
            out = torch.empty_like(A_pinv)
            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, A_pinv)

        def run_test_numpy(A, hermitian):
            # Check against NumPy output
            # Test float rcond, and specific value for each matrix
            rconds = [float(torch.rand(1)), ]
            # Test different types of rcond tensor
            for rcond_type in MPS_DTYPES:
                rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
            # Test broadcasting of rcond
            if A.ndim > 2:
                rconds.append(torch.rand(A.shape[-3], device=device))
            for rcond in rconds:
                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
                self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision)
                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
                self.assertEqual(actual, expected, atol=precision, rtol=precision)

        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
            A = torch.randn(*sizes, dtype=dtype, device=device)
            hermitian = False
            run_test_main(A, hermitian)
            run_test_numpy(A, hermitian)

        # Check hermitian = True
        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
            hermitian = True
            # tolerate only the NotImplementedError raised for the missing
            # downstream op; anything else should fail the test
            # TODO: remove this once aten::_linalg_eigh is implemented for MPS
            try:
                run_test_main(A, hermitian)
            except NotImplementedError as e:
                with self.assertRaisesRegex(
                        NotImplementedError,
                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
                    raise e
            try:
                run_test_numpy(A, hermitian)
            except NotImplementedError as e:
                with self.assertRaisesRegex(
                        NotImplementedError,
                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
                    raise e

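    # For reference, the four Moore-Penrose conditions that run_test_main
    # encodes above, written out directly (A_pinv plays the role of A+); a
    # documentation sketch rather than an additional test:
    def _example_penrose_conditions(self, A, A_pinv):
        return (torch.allclose(A @ A_pinv @ A, A, atol=1e-4),
                torch.allclose(A_pinv @ A @ A_pinv, A_pinv, atol=1e-4),
                torch.allclose((A @ A_pinv).mH, A @ A_pinv, atol=1e-4),
                torch.allclose((A_pinv @ A).mH, A_pinv @ A, atol=1e-4))
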
    @parametrize("m", [1, 32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test__int4_mm(self, m, k, n):
        q_group = 32
        inner_k_tiles = 2

        torch.manual_seed(1)
        a_f32 = torch.rand((m, k), device="mps")
        b_f32 = torch.rand((k, n), device="mps")

        def convert_weight_to_int4pack(b):
            b_int32, b_scales_and_zeros = _group_quantize_tensor(
                b, n_bit=4, q_group_size=q_group
            )
            b_int4pack = torch._convert_weight_to_int4pack(
                b_int32.cpu(), inner_k_tiles
            ).to(device="mps")

            return b_int4pack, b_scales_and_zeros

        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
            return torch._weight_int4pack_mm(
                a, b_int4pack, q_group, b_scales_and_zeros
            )

        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32)

        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
            a = a_f32.to(dtype=dtype)
            b = b_f32.to(dtype=dtype)
            b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype)
            ref = torch.mm(a, b)
            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertTrue(mean_err < 0.05)

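    # A rough sketch of what the grouped 4-bit scheme above amounts to: every
    # q_group consecutive weights along k share one scale and zero point, so
    # dequantization is roughly (q - zero) * scale per group. The unpacked
    # (n, k) layout assumed here is illustrative; it is not the packed format
    # the private kernel actually consumes.
    def _example_int4_group_dequant(self, q, scale, zero, q_group=32):
        k = q.shape[-1]
        qg = q.float().view(q.shape[0], k // q_group, q_group)
        return ((qg - zero.unsqueeze(-1)) * scale.unsqueeze(-1)).view(q.shape[0], k)
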
    @parametrize("m", [1, 32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [32, 64])
    def test__int8_mm(self, m, k, n):
        torch.manual_seed(1)
        a_f32 = torch.rand((m, k), device="mps")
        b_f32 = torch.rand((n, k), device="mps")

        def convert_weight_to_int8pack(b):
            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
                b, -128, 127, torch.int8
            )
            return b_int8pack, b_scales

        def weight_int8pack_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)

        b_int8pack, b_scales_f32 = convert_weight_to_int8pack(b_f32)
        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
            a = a_f32.to(dtype=dtype)
            b = b_f32.to(dtype=dtype)
            b_scales = b_scales_f32.to(dtype=dtype)
            res = weight_int8pack_mm(a, b_int8pack, b_scales)
            ref = torch.mm(a, b.transpose(0, 1))

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertTrue(mean_err < 0.05)
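
    # Illustrative sketch of the per-channel int8 scheme exercised above: each
    # output channel of b carries a single scale, so the matmul can run on the
    # quantized weights and be rescaled per column afterwards (symmetric
    # quantization assumed).
    def _example_int8_rescale(self, a, b_int8, b_scales):
        return (a @ b_int8.t().to(a.dtype)) * b_scales.to(a.dtype)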


class TestGatherScatter(TestCaseMPS):
    def test_slicing_with_step(self):
        # Slicing with step
        # https://github.com/pytorch/pytorch/issues/78886
        x_mps = torch.zeros(10, dtype=torch.float32, device="mps")
        x_mps[::2] = 1.0

        x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu")
        x_cpu[::2] = 1.0

        self.assertEqual(x_cpu, x_mps)

    def test_cast_gather_scatter(self):
        for _ in range(50):
            input = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8)
            with torch.no_grad():
                s = torch.tensor(input, dtype=torch.uint8, device="mps").unsqueeze(0)
                s_cpu = torch.tensor(input, dtype=torch.uint8, device="cpu").unsqueeze(0)
                s = s.long()
                s_cpu = s_cpu.long()
                self.assertEqual(s.cpu(), s_cpu)

                s = s.float()
                s_cpu = s_cpu.float()
                self.assertEqual(s.cpu(), s_cpu)

                s /= 255
                s_cpu /= 255
                self.assertEqual(s.cpu(), s_cpu)

    def test_slicing_replace_column(self):
        # https://github.com/pytorch/pytorch/issues/78074
        def _helper(tensor_data):
            x_cpu = torch.tensor(tensor_data)
            x_mps = x_cpu.to('mps')

            x_cpu[:, 0] = 7
            x_mps[:, 0] = 7

            self.assertEqual(x_cpu, x_mps)

        _helper([[1, 2, 3], [4, 5, 6]])
        _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

    def test_inplace_scatter(self):
        # https://github.com/pytorch/pytorch/issues/79672
        a_mps = torch.ones((2, 2),).to(torch.device("mps"))
        b_mps = torch.ones((2, 2),).to(torch.device("mps"))

        a_cpu = torch.ones((2, 2),).to(torch.device("cpu"))
        b_cpu = torch.ones((2, 2),).to(torch.device("cpu"))

        a_mps[:, 0] += b_mps[:, 0]
        a_cpu[:, 0] += b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)

        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)

# These tests were taken from test/test_view_ops.py.
# They are a subset of those tests, since currently only this subset works on MPS.
# This whole class will be removed once generic device testing is added; no tests
# have been added here beyond what is already part of test_view_ops.py.
class TestViewOpsMPS(TestCaseMPS):
    exact_dtype = True

    def test_permute_slicing(self):
        # test the fix for crash reported in
        # https://github.com/pytorch/pytorch/issues/94190
        cpu_x = (torch.randn([3, 2, 2]).float())
        mps_x = cpu_x.detach().clone().to('mps')
        cpu_out = cpu_x.permute((2, 0, 1)) * 2.0
        mps_out = mps_x.permute((2, 0, 1)) * 2.0
        # this print caused a crash prior to fix PR#94259
        print(torch.zeros_like(mps_out))
        # test the fix for fill_scalar_mps() mentioned in issue #94190
        self.assertEqual(torch.zeros_like(cpu_out), torch.zeros_like(mps_out))
        self.assertEqual(cpu_x[:, 1, :].fill_(1), mps_x[:, 1, :].fill_(1))

    def is_view_of(self, base, other):
        if (not other._is_view() or
                other is base or
                other._base is not base or
                base.device != other.device):
            return False
        # Note: only validates storage on native device types
        # because some accelerators, like XLA, do not expose storage
        if base.device.type == 'mps':
            if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                return False

        return True

    # Returns true if v1 and v2 are views of the same base
    def is_view_of_same_base(self, v1, v2):
        if (not v1._is_view() or v1 is v2):
            return False
        return self.is_view_of(v1._base, v2)
    # Performs transpose if contiguous=True, else returns the input tensor as is
    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
        if contiguous:
            return x
        else:
            return x.transpose(dim0, dim1)

    def test_diagonal_view(self, device="mps"):
        t = torch.ones((5, 5), device=device)
        v = torch.diagonal(t)
        self.assertTrue(self.is_view_of(t, v))

        v[0] = 0
        self.assertEqual(t[0, 0], v[0])

        t = torch.ones((3, 3, 3), device="mps")
        v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0, 1], v[0, 0])

    def test_select_view(self, device="mps") -> None:
        t = torch.ones((5, 5), device=device)
        v = t.select(0, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[0] = 0
        self.assertEqual(t[2, 0], v[0])

    def test_unbind_view(self, device="mps") -> None:
        t = torch.zeros((5, 5), device=device)
        tup = torch.unbind(t)

        for idx, v in enumerate(tup):
            self.assertTrue(self.is_view_of(t, v))

            v[0] = idx + 1
            self.assertEqual(t[idx, 0], v[0])

    def test_expand_view(self, device="mps") -> None:
        t = torch.ones((5, 1), device=device)
        v = t.expand(5, 5)
        self.assertTrue(self.is_view_of(t, v))

        v[2, 2] = 0
        self.assertEqual(t[2, 0], v[2, 2])

    def test_expand_as_view(self, device="mps"):
        t = torch.ones((5, 1), device=device)
        e = torch.empty((5, 5), device=device)
        v = t.expand_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[2, 2] = 0
        self.assertEqual(t[2, 0], v[2, 2])

    def test_narrow_view(self, device="mps"):
        t = torch.ones((5, 5), device=device)
        v = torch.narrow(t, 1, 2, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 2], v[0, 0])

    def test_permute_view(self, device="mps") -> None:
        t = torch.ones((5, 5), device=device)
        v = t.permute(1, 0)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_transpose_view(self, device="mps"):
        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
            t = torch.ones((5, 5), device=device)
            v = fn(t, 0, 1)
            self.assertTrue(self.is_view_of(t, v))

            v[0, 1] = 0
            self.assertEqual(t[1, 0], v[0, 1])

    def test_transpose_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.swapdims_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.swapaxes_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.transpose_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_t_view(self, device="mps"):
        t = torch.ones((5, 5), device=device)
        v = t.t()
        self.assertTrue(self.is_view_of(t, v))

        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_inplace_view_add(self):
        # https://github.com/pytorch/pytorch/issues/96153
        t_mps = torch.ones((2, 6,), device='mps')[1].reshape(2, 3)
        t_cpu = torch.ones((2, 6,), device='cpu')[1].reshape(2, 3)
        t_mps = t_mps + 1
        t_cpu = t_cpu + 1
        self.assertEqual(t_mps, t_cpu)

    def test_t_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.t_()
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_T_view(self, device="mps"):
        for op in ("T", "H", "mT", "mH"):
            t = torch.ones((5, 5), device=device)
            v = getattr(t, op)
            self.assertTrue(self.is_view_of(t, v))

            v[0, 1] = 0
            self.assertEqual(t[1, 0], v[0, 1])

    def test_unfold_view(self, device="mps"):
        t = torch.ones(10, device=device)
        v = t.unfold(0, 3, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[1, 0] = 0
        self.assertEqual(t[2], v[1, 0])

    def test_squeeze_view(self, device="mps"):
        t = torch.ones(5, 1, 5, device=device)
        v = torch.squeeze(t)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertTrue(t is v._base)

    def test_squeeze_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.squeeze_()
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertTrue(t is v._base)

    def test_unsqueeze_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.unsqueeze(t, 1)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0, 1] = 0
        self.assertEqual(t[0, 1], v[0, 0, 1])

    def test_unsqueeze_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.unsqueeze_(1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 0, 1] = 0
        self.assertEqual(t[0, 1], v[0, 0, 1])

    def test_as_strided_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.as_strided(t, (25,), (1,))
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_as_strided_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.as_strided_((25,), (1,))
        self.assertTrue(self.is_view_of(t, v))
        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_view_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view(25)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_view_as_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        e = torch.empty((25,))
        v = t.view_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_contiguous_self(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        s = t.contiguous()
        self.assertTrue(s is t)

    def test_contiguous_nonview(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        nv = t.t().contiguous()
        self.assertTrue(not self.is_view_of(t, nv))

        nv[0, 0] = 0
        self.assertNotEqual(t[0, 0], nv[0, 0])

    def test_reshape_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.reshape(t, (25,))
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_reshape_as_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        e = torch.empty((25,), device=device)
        v = t.reshape_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_reshape_nonview(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        nv = torch.reshape(t.t(), (25,))
        self.assertTrue(not self.is_view_of(t, nv))

        nv[6] = 0
        self.assertNotEqual(t[1, 1], nv[6])

    def test_flatten_view(self, device="mps"):
        def test_writes_propagate(t, v):
            idx_t = (0,) * t.ndim
            idx_v = (0,) * v.ndim
            v[idx_v] = 0
            self.assertEqual(t[idx_t], v[idx_v])

        t = torch.ones(1, 2, 3, 4, device=device)
        v = t.flatten()
        self.assertTrue(self.is_view_of(t, v))
        test_writes_propagate(t, v)

        # zero-dimensional tensor
        t = torch.tensor(1, device=device)
        v = t.flatten()
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of(t, v))

        t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
        v = t.flatten(0, 1)
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of_same_base(t, v))

        # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
        t = torch.ones(720, device=device) \
            .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
        #               [--1--|---2---|-3-] [--1--|----2---|-3-]
        v1 = t.flatten(0, 1)
        v2 = v1.flatten(1, 3)
        v3 = v2.flatten(2, 2)
        test_writes_propagate(t, v1)
        self.assertTrue(self.is_view_of_same_base(t, v1))
        test_writes_propagate(t, v2)
        self.assertTrue(self.is_view_of_same_base(t, v2))
        test_writes_propagate(t, v3)
        self.assertTrue(self.is_view_of_same_base(t, v3))

    def test_flatten_nonview(self, device="mps"):
        def assert_is_nonview(t, nv):
            idx_t = (0,) * t.ndim
            idx_nv = (0,) * nv.ndim
            self.assertTrue(not nv._is_view())
            nv[idx_nv] = 0
            self.assertNotEqual(t[idx_t], nv[idx_nv])
        t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
        nv = t.flatten(1, 3)
        assert_is_nonview(t, nv)

        t = torch.ones(2, 2, device=device).T
        nv = t.flatten()
        assert_is_nonview(t, nv)

        # flatten returns the original object if start_dim=end_dim
        t = torch.ones(2, 2, device=device)
        nv = t.flatten(1, 1)
        self.assertTrue(t is nv)

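    # Flatten can only be a view when the merged dimensions are already laid
    # out contiguously relative to each other, i.e. stride[i] == stride[i + 1]
    # * size[i + 1] for every adjacent pair in the merged range. A hand-rolled
    # sketch of that check (ignoring size-0/size-1 edge cases):
    def _example_can_flatten_as_view(self, t, start_dim, end_dim):
        return all(t.stride(i) == t.stride(i + 1) * t.size(i + 1)
                   for i in range(start_dim, end_dim))
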
    def test_basic_indexing_slice_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[:2, :3]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0], v[0, 0])

    def test_basic_indexing_ellipses_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[..., :2]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0], v[0, 0])

    def test_basic_indexing_newaxis_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[None, :2, 3]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 3], v[0, 0])

    def test_chunk_view(self, device="mps"):
        t = torch.zeros(3, 3, device=device)
        l = torch.chunk(t, 3)

        for idx, v in enumerate(l):
            self.assertTrue(self.is_view_of(t, v))

            v[0, 0] = idx + 1
            self.assertEqual(t[idx, 0], v[0, 0])

    def test_split_view(self, device="mps"):
        t = torch.zeros(3, 3, device=device)
        l = torch.split(t, [1, 1, 1])

        for idx, v in enumerate(l):
            self.assertTrue(self.is_view_of(t, v))

            v[0, 0] = idx + 1
            self.assertEqual(t[idx, 0], v[0, 0])

    def test_movedim_view(self, device="mps"):
        def run_test(device, op):
            t = torch.zeros(3, 3, device=device)
            out = op(t)

            self.assertTrue(self.is_view_of(t, out))

            # Randomly change values in output
            # and verify that original is changed
            # as well.
            for _ in range(3):
                idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
                out[idx_1, idx_2] = random.random()
                self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])

        for fn in [torch.movedim, torch.moveaxis]:
            op = partial(fn, source=(0, 1), destination=(1, 0))
            run_test(device, op)

            op = partial(fn, source=0, destination=1)
            run_test(device, op)

    # Testing that the generated view_copy kernel and its derivative are implemented correctly
    def test_view_copy(self, device="mps"):
        a = torch.randn(4, device=device, requires_grad=True)
        a_ref = a.clone().detach().requires_grad_()
        a_view = a_ref.view(2, 2)
        a_view_copy = torch.view_copy(a, (2, 2))

        # view_copy ops don't preserve view relationship
        self.assertTrue(self.is_view_of(a_ref, a_view))
        self.assertFalse(self.is_view_of(a, a_view_copy))

        a_view_copy.sum().backward()
        a_view.sum().backward()

        # forward and backward give the same shape + result
        self.assertEqual(a_view_copy, a_view)
        self.assertEqual(a.grad, a_ref.grad)

    def test_view_copy_out(self, device="mps"):
        a = torch.randn(2, 2, device=device)
        out = torch.empty(2, device=device)

        torch.diagonal_copy(a, out=out)
        expected = torch.diagonal_copy(a)

        self.assertEqual(expected, out)

        a = torch.randn(4, device=device)
        out1 = torch.empty(2, device=device)
        out2 = torch.empty(2, device=device)

        torch.split_copy(a, 2, out=(out1, out2))
        expected1, expected2 = torch.split_copy(a, 2)

        self.assertEqual(expected1, out1)
        self.assertEqual(expected2, out2)

    def test_detached_view_copy(self, device="mps"):
        # https://github.com/pytorch/pytorch/issues/86052
        x = torch.arange(2)
        # .detach() makes y not a view, but a contiguous tensor
        # with a non-zero storage offset
        y = x[1].detach()
        z = y.to(device)
        self.assertEqual(y, z.cpu())

    def test_empty_reshape(self, device="mps"):
        x = torch.randn(0, 6, device=device)
        self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
        # should be viewable -- i.e. data_ptr is the same.
        self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr())

        # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
        self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))

    def test_expand(self, device="mps"):
        tensor = torch.rand(1, 8, 1, device=device)
        tensor2 = torch.rand(5, device=device)
        template = torch.rand(4, 8, 5, device=device)
        target = template.size()
        self.assertEqual(tensor.expand_as(template).size(), target)
        self.assertEqual(tensor.expand(4, 8, 5).size(), target)
        self.assertEqual(tensor.expand(target).size(), target)
        self.assertEqual(tensor2.expand_as(template).size(), target)
        self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
        self.assertEqual(tensor2.expand(target).size(), target)

        # test double expand
        self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))

        # test non-contiguous
        noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
        self.assertFalse(noncontig.is_contiguous())
        self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

        # make sure it's compatible with unsqueeze
        expanded = tensor2.expand(1, 1, 5)
        unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
        self.assertEqual(expanded, unsqueezed)
        self.assertEqual(expanded.stride(), unsqueezed.stride())

        # test -1 as target size
        self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
        self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))

        # test expanding empty to empty
        self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))

    def test_view_empty(self, device="mps"):
        x = torch.randn(0, 6, device=device)
        self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)

    def test_reshape(self, device="mps"):
        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
        self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
        self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
        self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))

        y = torch.randn(4, 4, 4, device=device)[:, 0, :]
        # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
        if device != "meta":
            self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
        self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
        self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

        s = torch.randn((), device=device)
        self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
        self.assertEqual(s.reshape(-1).shape, (1,))
        self.assertRaises(RuntimeError, lambda: s.reshape(2))

        empty = torch.tensor([], device=device)
        self.assertEqual(empty, empty.reshape(-1))
        self.assertEqual(empty, empty.reshape([0]))
        # TODO: fix these once we have multi-dimensional empty tensors
        self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
        self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
        self.assertRaises(RuntimeError, lambda: empty.reshape(1))

        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
        self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)))

    def test_narrow(self, device="mps"):
        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
        self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
        self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
        self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
        self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
        self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
        self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
        self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))

    def test_narrow_tensor(self, device="mps"):
        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor(0.), 1)
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor([0]), 1)
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor([0, 1]), 1)

    def test_t(self, device="mps"):
        # Test 0D tensors
        x = torch.randn(())
        self.assertEqual(x, x.t())
        x = x.to_sparse()
        self.assertEqual(x, x.t())

        # Test 1D tensors
        x = torch.arange(4)
        self.assertEqual(x, x.t())
        x = x.to_sparse()
        self.assertEqual(x, x.t())

        # Test 2D tensors
        x = torch.rand((2, 2))
        self.assertEqual(x.t(), x.transpose(0, 1))
        x = x.to_sparse()
        self.assertEqual(x.t(), x.transpose(0, 1))

        # Test 3D tensor
        x = torch.rand((2, 2, 2))
        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
            x.t()
        x = x.to_sparse()
        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
            x.t()

    def test_split(self, device="mps"):
        tensor = torch.rand(7, 4)
        split_size = 3
        dim = 0
        target_sizes = ([3, 4], [3, 4], [1, 4])
        splits = tensor.split(split_size, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

        # Variable sections split
        tensor = torch.randn(20, 10)
        dim = 0
        split_sizes = [5, 5, 10]
        target_sizes = ([[5, 10], [5, 10], [10, 10]])
        splits = tensor.split(split_sizes, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

        split_sizes = [2, 2, 6]
        target_sizes = ([20, 2], [20, 2], [20, 6])
        dim = 1
        splits = tensor.split(split_sizes, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

    def test_chunk(self, device="mps"):
        tensor = torch.rand(4, 7)
        num_chunks = 3
        dim = 1
        target_sizes = ([4, 3], [4, 3], [4, 1])
        splits = tensor.chunk(num_chunks, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split,
                             atol=0, rtol=0)
            start = start + target_size[dim]

        # Invalid chunk sizes
        error_regex = 'chunk expects.*greater than 0'
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(0)
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(-2)

    def test_unsqueeze(self, device="mps") -> None:
        x = torch.randn(2, 3, 4)
        y = x.unsqueeze(1)
        self.assertEqual(y, x.view(2, 1, 3, 4))
        y = x.clone().unsqueeze_(2)
        self.assertEqual(y, x.view(2, 3, 1, 4))

        x = x[:, 1]
        self.assertFalse(x.is_contiguous())
        y = x.unsqueeze(1)
        self.assertEqual(y, x.contiguous().view(2, 1, 4))
        y = x.clone().unsqueeze_(2)
        self.assertEqual(y, x.contiguous().view(2, 4, 1))

    # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
    def test_big_transpose(self, device="mps"):
        t = torch.rand(456, 789, device=device)
        t1 = t.t().contiguous()
        t2 = torch.from_numpy(t.cpu().numpy().transpose())
        self.assertEqual(t1, t2)

    def test_T(self, device="mps"):
        a = torch.randn(2, 3, 4, device=device)
        t1 = a.T
        t2 = a.permute(2, 1, 0)
        self.assertEqual(t2, t1)
        b = torch.randn(10, device=device)
        self.assertEqual(b, b.T)

    def test_transposes(self, device="mps", dtype=torch.float32):
        for op in ("T", "H", "mT", "mH", "adjoint"):
            shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
            for shape in shapes:
                a = make_tensor(shape, device=device, dtype=dtype)
                t1 = getattr(a, op)
                if op == "adjoint":
                    t1 = t1()
                t2 = a
                if a.ndim != 0:
                    t2 = t2.transpose(-2, -1)
                if op[-1] == "H" or op == "adjoint":
                    t2 = t2.conj()
                self.assertEqual(t2, t1)

    def test_transposes_errors(self, device="mps", dtype=torch.float32):
        for op in ("H", "mT", "mH", "adjoint"):
            shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
            for shape in shapes:
                a = make_tensor(shape, device=device, dtype=dtype)
                with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                    t1 = getattr(a, op)
                    if op == "adjoint":
                        t1 = t1()

    def test_python_types(self, device="mps"):
        a1 = torch.randn((1, 2), device=device, dtype=torch.float32)
        a2 = torch.randn((1, 2), device=device, dtype=torch.float32)
        self.assertEqual(a1.dtype, a2.dtype)

        b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
        b2 = torch.arange(10, 20, dtype=int, device=device)
        self.assertEqual(b1.dtype, b2.dtype)

        c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
        c2 = torch.tensor([True, False], dtype=bool, device=device)
        self.assertEqual(c1.dtype, c2.dtype)

    # TODO: is resize best put in test_view_ops?
    def test_resize_as_preserves_strides(self, device="mps"):
        x = torch.empty(2, 3).t()
        old_strides = x.stride()
        x.resize_as_(x)
        self.assertEqual(x.stride(), old_strides)

    def test_memory_format_resize_as(self, device="mps"):
        def test_helper(shape, memory_format, device="mps"):
            xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
            flat = torch.randn(xc.numel(), device=device)
            flat.resize_as_(xc, memory_format=torch.preserve_format)
            self.assertTrue(flat.is_contiguous(memory_format=memory_format))

        test_helper((10, 3, 32, 32), torch.channels_last, device="mps")
        test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps")

    def test_memory_format_resize_(self, device="mps"):
        def test_helper(shape, numel, memory_format, device="mps"):
            flat = torch.randn(numel, device=device)
            flat.resize_(shape, memory_format=memory_format)
            self.assertTrue(flat.is_contiguous(memory_format=memory_format))

        test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps")
        test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps")

    # TODO: OpInfo this
    def _test_atleast(self, device, torch_fn):
        # 0-dim
        s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)

        gradcheck(lambda x: torch_fn(x), s)
        gradgradcheck(lambda x: torch_fn(x), s)

        # 1-dim
        a = torch.rand(4, dtype=torch.double, requires_grad=True)

        gradcheck(lambda x: torch_fn(x), a)
        gradgradcheck(lambda x: torch_fn(x), a)

        # 2,3,4-dim
        b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
        c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
        d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)

        input_tuple = (s, a, b, c, d)
        gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
        gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)

    def test_atleast_gradient(self, device="mps"):
        self._test_atleast(device, torch.atleast_1d)
        self._test_atleast(device, torch.atleast_2d)
        self._test_atleast(device, torch.atleast_3d)

    def test_view(self, device="mps"):
        tensor = torch.rand(15, device=device)
        template = torch.rand(3, 5, device=device)
        empty = torch.empty(0, device=device)
        target = template.size()
        self.assertEqual(tensor.view_as(template).size(), target)
        self.assertEqual(tensor.view(3, 5).size(), target)
        self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
        self.assertEqual(tensor.view(-1, 5).size(), target)
        self.assertEqual(tensor.view(3, -1).size(), target)
        tensor_view = tensor.view(5, 3)
        tensor_view.fill_(random.uniform(0, 1))
        self.assertEqual(empty.view_as(empty), empty)
        self.assertEqual(empty.view(0), empty)
        self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
        self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)

        # test size inference with empty tensors
        self.assertEqual(empty.view(-1).size(), torch.Size([0]))
        self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))

        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
            empty.view(-1, 0)

        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
            empty.view(3, 0, -1, 0)

        self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
        self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
        self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))

    def test_contiguous(self, device="mps"):
        x = torch.randn(1, 16, 5, 5, device=device)
        self.assertTrue(x.is_contiguous())
        stride = list(x.stride())
        stride[0] = 20
        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
        x.set_(x.storage(), 0, x.size(), stride)
        self.assertTrue(x.is_contiguous())

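    # Why the stride change above is harmless: is_contiguous() ignores the
    # stride of any size-1 dimension, since no element's address can depend on
    # it. A hand-rolled sketch of that rule (ignoring size-0 edge cases):
    def _example_contiguous_rule(self, t):
        expected = 1
        for size, stride in zip(reversed(t.size()), reversed(t.stride())):
            if size != 1:
                if stride != expected:
                    return False
                expected *= size
        return True
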
    def test_resize_mps_dtypes(self, device="mps"):
        shape = (2, 2)
        for dt in MPS_DTYPES:
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            x.resize_(shape)
            self.assertEqual(shape, x.shape)

    def test_resize_as_mps_dtypes(self, device="mps"):
        for dt in MPS_DTYPES:
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
            x.resize_as_(y)
            self.assertEqual(y.shape, x.shape)

    def test_resize_overflow(self, device="mps"):
        x = torch.empty((), dtype=torch.float64)
        with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
            x.resize_([2, 4, 2**29, 2**29])
        with self.assertRaisesRegex(RuntimeError, 'overflow'):
            x.resize_([8, 8, 2**29, 2**29])

    def test_view_all_dtypes_and_devices(self, device="mps"):
        for dt in (torch.float, torch.bool):
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            self.assertEqual(x.view(6).shape, [6])

class TestConvolutionMPS(TestCaseMPS):
    def test_conv1d_all_strides_paddings(self):
        # https://github.com/pytorch/pytorch/issues/82921
        def helper(stride, padding):
            y_cpu = torch.randn(1, 57, 40)
            conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)

            y_gpu = y_cpu.to(device='mps')
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu())
        for stride in range(1, 4):
            for padding in range(1, 4):
                helper(stride, padding)

    def test_conv1d_channels_last(self):
        # https://github.com/pytorch/pytorch/issues/81557
        model_cpu = torch.nn.Conv1d(1, 128, 3)
        a_cpu = torch.arange((128 * 176), dtype=torch.float32)
        a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1)
        out_cpu = model_cpu(a_cpu)

        a_mps = a_cpu.detach().clone().to("mps")
        model_mps = model_cpu.to("mps")
        out_mps = model_mps(a_mps)

        self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)

    def test_conv_transpose_1d_all_strides(self):
        # https://github.com/pytorch/pytorch/issues/82711
        def helper(stride):
            y_cpu = torch.ones(1, 1, 2)
            deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1)
            deconv_cpu.weight.data = torch.ones(1, 1, 2)
            deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps')
            x_cpu = deconv_cpu(y_cpu)

            y_gpu = y_cpu.to(device='mps')
            x_gpu = deconv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu())
        for stride in [1, 2, 3]:
            helper(stride)

    def test_conv_transpose_1d_nn_functional(self):
        # https://github.com/pytorch/pytorch/issues/82563
        tin = torch.rand((1, 512, 1245), dtype=torch.float32)
        tparams = torch.rand((512, 256, 16), dtype=torch.float32)
        tbias = torch.rand((256), dtype=torch.float32)

        device = 'cpu'
        tcpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)

        device = 'mps'
        tgpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)

        self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)

    def test_conv_backward_1d_channels_last(self):
        def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
            # https://github.com/pytorch/pytorch/issues/84511
            conv_cpu = torch.nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
            conv_mps = torch.nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
            conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
            conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)

            data = torch.rand(shape, dtype=torch.float32)
            x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
            x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
            res_cpu = conv_cpu(x_cpu)
            res_mps = conv_mps(x_mps)
            self.assertEqual(res_cpu, res_mps)
            res_cpu = res_cpu.sum().backward()
            res_mps = res_mps.sum().backward()

            self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
            self.assertEqual(x_cpu.grad, x_mps.grad)

        helper(shape=(1, 176, 1))
        helper(shape=(2, 12, 1))
        helper(shape=(3, 176, 1))
        helper(shape=(4, 376, 1))
        helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
        helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)

    def test_conv1d_contiguous(self):
        model_cpu = torch.nn.Conv1d(1, 128, 3)
        a_cpu = torch.ones(128, 1, 176)
        out_cpu = model_cpu(a_cpu)

        a_mps = a_cpu.detach().clone().to("mps")
        model_mps = model_cpu.to("mps")
        out_mps = model_mps(a_mps)

        self.assertEqual(out_cpu.shape, out_mps.shape)
        self.assertEqual(out_cpu, out_mps.cpu())

    def test_conv2d_all_strides_paddings(self):
        # https://github.com/pytorch/pytorch/issues/83180
        def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
            x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
            x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()

            if permute_data:
                # NOTE: permute() is not in-place, so as written these calls are
                # no-ops; assigning the result back would actually change the layout.
                x_cpu.permute(0, 2, 3, 1)
                x_mps.permute(0, 2, 3, 1)

            for strideX in range(1, 4):
                for strideY in range(1, 4):
                    conv_cpu = torch.nn.Conv2d(
                        in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
                    conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()

                    conv_mps = torch.nn.Conv2d(
                        in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
                    conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
                    conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()

                    res_cpu = conv_cpu(x_cpu)
                    res_mps = conv_mps(x_mps)
                    self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)

                    res_cpu = res_cpu.sum().backward()
                    res_mps = res_mps.sum().backward()
                    self.assertEqual(res_cpu, res_mps, rtol=2.6e-05, atol=2e-04)
                    self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                    self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                    self.assertEqual(x_cpu.grad, x_mps.grad)

        for mem_format_input in [torch.contiguous_format, torch.channels_last]:
            for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
                for permute_data in [True, False]:
                    helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
                    helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
                    helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)

    def test_conv_transpose_2d_strided(self):
        def helper(m_cpu, memory_format):
            m_mps = copy.deepcopy(m_cpu).requires_grad_()
            m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_()
            m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_()

            input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
            input_mps = input_cpu.detach().clone().to("mps")

            output_cpu = m_cpu(input_cpu)
            output_mps = m_mps(input_mps)
            self.assertEqual(output_cpu, output_mps)

        for mem_format_input in [torch.contiguous_format, torch.channels_last]:
            # With square kernels and equal stride
            helper(nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_(), mem_format_input)

            # non-square kernels and unequal stride and with padding
            helper(nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_(), mem_format_input)

    def test_conv_transpose_2d_specified_output(self):
        input_cpu = torch.randn(1, 16, 12, 12)
        input_mps = input_cpu.detach().clone().to("mps")

        downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
        downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
        downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
        upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        h_cpu = downsample_cpu(input_cpu)
        h_mps = downsample_mps(input_mps)
        self.assertEqual(h_cpu, h_mps)

        size_cpu = h_cpu.size()
        size_mps = h_mps.size()
        self.assertEqual(size_cpu, size_mps)

        output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size())
        output_mps = upsample_mps(h_mps, output_size=input_mps.size())
        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

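    # Why output_size is needed above: with stride > 1, several input heights or
    # widths downsample to the same size, so the transposed convolution alone
    # cannot know which one to restore. Per dimension (dilation=1 assumed):
    #   out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
    # A sketch of that arithmetic (pure shape math, nothing MPS-specific):
    def _example_conv_transpose_out_size(self, n_in, stride=2, padding=1, kernel_size=3, output_padding=0):
        return (n_in - 1) * stride - 2 * padding + kernel_size + output_padding
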
    def test_conv2d_single_stride(self):
        y_cpu = torch.randn(2, 2, 3, 6)
        y_gpu = y_cpu.to(device='mps')
        for stride in range(1, 4):
            conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)

    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
    def test_conv3d_single_stride(self):
        # Conv3d is only available from macOS 13.2 onwards
        y_cpu = torch.randn(2, 2, 3, 6)
        y_gpu = y_cpu.to(device='mps')
        for stride in range(1, 4):
            conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)

Denis Vieriu5b8e4852023-02-09 02:25:46 +000010344 def test_grid_sample(self):
10345 def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
10346 def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
10347 for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
10348 # grid_dim_contig_order specifies the dimension order that can
10349 # make grid to be contiguous.
10350 # i.e., grid.permute(grid_dim_contig_order) is contiguous.
10351 # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
10352 # initialized with contiguous tensor of shape [N, 2, H, W]
10353 # and permuted to [N, H, W, 2] afterwards.
10354 grid_shape = [N, H, W, 2]
10355 grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
10356 grid_fwd_permute = [None, None, None, None]
10357 for i, d in enumerate(grid_dim_contig_order):
10358 grid_fwd_permute[d] = i
10359
10360 def get_grid(device='cpu', data=None):
10361 if data is not None:
10362 assert list(data.shape) == grid_shape
10363 data = data.permute(grid_dim_contig_order).to(device)
10364 else:
10365 data = torch.randn(grid_init_shape, device=device)
10366 grid = data.permute(grid_fwd_permute)
10367 assert grid.permute(grid_dim_contig_order).is_contiguous()
10368 return grid
10369
10370 input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
10371 grid_cpu = get_grid().requires_grad_()
10372 out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
10373 align_corners=align_corners)
10374 self.assertTrue(out_cpu.size() == torch.Size([N, C, H, W]))
10375
10376 gradients = torch.randn_like(out_cpu)
10377 out_cpu.backward(gradients)
10378
10379
10380 # Compare against unvectorized CPU fallback
10381
10382 # NOTE [ grid_sample CPU fallback ]
10383 # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
10384 # 32-bit floats. So we also have a fallback that is used only for float tensors
10385 # requiring 64-bit indexing. That requires too much memory to run on CI, so we
10386 # also export the fallback and test it here to ensure feature parity with
10387 # the vectorized version.
10388 input_fallback = input_cpu.float().detach_().requires_grad_()
10389 grid_fallback = grid_cpu.float().detach_().requires_grad_()
10390 out_fallback = torch._grid_sampler_2d_cpu_fallback(
10391 input_fallback, grid_fallback,
10392 F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
10393 F.GRID_SAMPLE_PADDING_MODES[padding_mode],
10394 align_corners)
10395 self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)
10396
10397 out_fallback.backward(gradients.float())
10398 if input_requires_grad:
10399 self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
10400 self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)
10401
10402 input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad)
10403 grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
10404 out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
10405 self.assertEqual(out_cpu, out_mps)
10406 out_mps.backward(gradients.to("mps"))
10407 if input_requires_grad:
10408 self.assertEqual(input_cpu.grad, input_mps.grad)
10409 self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0)
10410
10411                # check that inputs with zero strides (from expand) don't error out
10412 base_input = torch.randn(N, C, 1, IW)
10413 input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad)
10414 out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
10415 align_corners=align_corners)
10416
10417 input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
10418 out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
10419 self.assertEqual(out_cpu, out_mps)
10420
10421 # test same size output
10422 test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)
10423
10424 # test larger output
10425 N = random.randint(2, 8)
10426 C = random.randint(2, 8)
10427 IH = random.randint(2, 8)
10428 IW = random.randint(2, 8)
10429 H = random.randint(IH + 1, 12)
10430 W = random.randint(IW + 1, 12)
10431 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
10432
10433 # test smaller output
10434 N = random.randint(2, 8)
10435 C = random.randint(2, 8)
10436 IH = random.randint(2, 8)
10437 IW = random.randint(2, 8)
10438 H = random.randint(2, IH)
10439 W = random.randint(2, IW)
10440 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
10441
10442            # test 1x1 input
10443 N = random.randint(2, 8)
10444 C = random.randint(2, 8)
10445 IH = 1
10446 IW = 1
10447 H = random.randint(2, 5)
10448 W = random.randint(2, 5)
10449 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
10450
10451 # testing empty grid
10452 N = random.randint(2, 8)
10453 C = random.randint(2, 8)
10454 IH = random.randint(2, 8)
10455 IW = random.randint(2, 8)
10456 W = random.randint(3, IW + 2)
10457 test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)
10458
10459 # testing empty channel
10460 N = random.randint(2, 8)
10461 IH = random.randint(2, 8)
10462 IW = random.randint(2, 8)
10463 H = random.randint(3, IH + 2)
10464 W = random.randint(3, IW + 2)
10465 test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)
10466
10467 # testing empty batch
10468 C = random.randint(2, 8)
10469 IH = random.randint(2, 8)
10470 IW = random.randint(2, 8)
10471 H = random.randint(3, IH + 2)
10472 W = random.randint(3, IW + 2)
10473 test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)
10474
10475 for mode in ('bilinear', 'nearest'):
10476 for padding_mode in ('zeros', 'reflection'):
10477 for align_corners in (True, False):
10478 # test known input
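                    # grid coordinates are normalized to [-1, 1]; values outside
                    # that range exercise the padding_mode behavior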
10479 input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
10480 grid = torch.tensor(
10481 [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
10482 [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
10483 if mode == 'bilinear':
10484 if padding_mode == 'zeros':
10485 if align_corners:
10486 groundtruth = torch.tensor(
10487 [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
10488 [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
10489 else:
10490 groundtruth = torch.tensor(
10491 [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
10492 [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
10493 elif padding_mode == 'border':
10494 if align_corners:
10495 groundtruth = torch.tensor(
10496 [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
10497 [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
10498 else:
10499 groundtruth = torch.tensor(
10500 [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
10501 [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
10502 elif padding_mode == 'reflection':
10503 if align_corners:
10504 groundtruth = torch.tensor(
10505 [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
10506 [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
10507 else:
10508 groundtruth = torch.tensor(
10509 [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
10510 [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
10511 else:
Justin Chu73e14552023-07-19 07:40:18 -070010512 raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
Denis Vieriu5b8e4852023-02-09 02:25:46 +000010513 elif mode == 'nearest':
10514 if padding_mode == 'zeros':
10515 if align_corners:
10516 groundtruth = torch.tensor(
10517 [[0., 8., 5., 7., 9.],
10518 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
10519 else:
10520 groundtruth = torch.tensor(
10521 [[0., 8., 5., 7., 0.],
10522 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
10523 elif padding_mode == 'border':
10524 if align_corners:
10525 groundtruth = torch.tensor(
10526 [[1., 8., 5., 7., 9.],
10527 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
10528 else:
10529 groundtruth = torch.tensor(
10530 [[1., 8., 5., 7., 9.],
10531 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
10532 elif padding_mode == 'reflection':
10533 if align_corners:
10534 groundtruth = torch.tensor(
10535 [[1., 8., 5., 7., 9.],
10536 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
10537 else:
10538 groundtruth = torch.tensor(
10539 [[1., 8., 5., 7., 9.],
10540 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
10541 else:
Justin Chu73e14552023-07-19 07:40:18 -070010542 raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
Denis Vieriu5b8e4852023-02-09 02:25:46 +000010543 elif mode == 'bicubic':
10544 if padding_mode == 'zeros':
10545 if align_corners:
10546 groundtruth = torch.tensor(
10547 [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
10548 [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
10549 else:
10550 groundtruth = torch.tensor(
10551 [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
10552 [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
10553 elif padding_mode == 'border':
10554 if align_corners:
10555 groundtruth = torch.tensor(
10556 [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
10557 [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
10558 else:
10559 groundtruth = torch.tensor(
10560 [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
10561 [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
10562 elif padding_mode == 'reflection':
10563 if align_corners:
10564 groundtruth = torch.tensor(
10565 [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
10566 [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
10567 else:
10568 groundtruth = torch.tensor(
10569 [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
10570 [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
10571 else:
Justin Chu73e14552023-07-19 07:40:18 -070010572 raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
Denis Vieriu5b8e4852023-02-09 02:25:46 +000010573
10574 else:
Justin Chu73e14552023-07-19 07:40:18 -070010575 raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
Denis Vieriu5b8e4852023-02-09 02:25:46 +000010576 output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
10577 align_corners=align_corners)
10578 self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
Aaron Gokaslan660e8062023-08-22 23:16:35 +000010579 msg=f"groundtruth comparison failed for mode={mode}, "
10580 f"padding_mode={padding_mode}")
Denis Vieriu5b8e4852023-02-09 02:25:46 +000010581
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +000010582class TestAdvancedIndexing(TestCaseMPS):
Kulin Sethce7177f2022-08-18 06:03:16 +000010583 supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
Denis Vieriuce4f1872022-09-28 00:47:52 +000010584 supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
Kulin Sethce7177f2022-08-18 06:03:16 +000010585
Denis Vieriu38de9812023-01-04 00:02:24 +000010586 def test_nonzero_no_warning(self):
10587 device = "mps"
10588 t = torch.randn((2, 2), device=device)
10589 with warnings.catch_warnings(record=True) as w:
10590 warnings.simplefilter("always")
10591 torch.nonzero(t)
10592 t.nonzero()
10593 self.assertEqual(len(w), 0)
10594
10595 def test_nonzero(self):
10596 def helper(dtype):
10597 device = "mps"
10598 shapes = [
10599 torch.Size((12,)),
10600 torch.Size((12, 1)),
10601 torch.Size((1, 12)),
10602 torch.Size((6, 2)),
10603 torch.Size((3, 2, 2)),
10604 torch.Size((5, 5, 5)),
10605 ]
10606
10607 def gen_nontrivial_input(shape, dtype, device):
10608 if dtype != torch.bfloat16:
10609 return torch.randint(2, shape, device=device, dtype=dtype)
10610 else:
10611                    # randint does not work for bfloat16 on Windows; generate float and cast instead
10612 return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
10613
10614 for shape in shapes:
10615 tensor = gen_nontrivial_input(shape, dtype, device)
10616 dst1 = torch.nonzero(tensor, as_tuple=False)
10617 dst2 = tensor.nonzero(as_tuple=False)
10618 dst3 = torch.empty([], dtype=torch.long, device=device)
10619 dst3 = dst3.resize_(0)
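                # torch.nonzero(out=...) resizes the provided tensor to the
                # result shape, so start from an empty long tensor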
10620 torch.nonzero(tensor, out=dst3)
10621 np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy()
10622 np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
10623 self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
10624 self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
10625 self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
10626 tup1 = torch.nonzero(tensor, as_tuple=True)
10627 tup2 = tensor.nonzero(as_tuple=True)
10628 tup1 = torch.stack(tup1).t().cpu()
10629 tup2 = torch.stack(tup2).t().cpu()
10630 self.assertEqual(tup1, np_result, atol=0, rtol=0)
10631 self.assertEqual(tup2, np_result, atol=0, rtol=0)
10632 [helper(dtype) for dtype in self.supported_dtypes]
10633
10634 def test_nonzero_astuple_out(self):
10635 device = "mps"
10636 t = torch.randn((3, 3, 3), device=device)
10637 out = torch.empty([], dtype=torch.long, device=device)
10638 out = out.resize_(0)
10639
10640 with self.assertRaises(RuntimeError):
10641 torch.nonzero(t, as_tuple=True, out=out)
10642
10643 self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))
10644
10645 # Verifies that JIT script cannot handle the as_tuple kwarg
10646 # See Issue https://github.com/pytorch/pytorch/issues/45499.
10647 def _foo(t):
10648 tuple_result = torch.nonzero(t, as_tuple=True)
10649 nontuple_result = torch.nonzero(t, as_tuple=False)
10650 out = torch.empty_like(nontuple_result)
10651 torch.nonzero(t, as_tuple=False, out=out)
10652 return tuple_result, nontuple_result, out
10653
10654 with self.assertRaises(RuntimeError):
10655 scripted_foo = torch.jit.script(_foo)
10656
10657 # Verifies that JIT tracing works fine
10658 traced_foo = torch.jit.trace(_foo, t)
10659 traced_tuple, traced_nontuple, traced_out = traced_foo(t)
10660 expected_tuple = torch.nonzero(t, as_tuple=True)
10661 expected_nontuple = torch.nonzero(t)
10662
10663 self.assertEqual(traced_tuple, expected_tuple)
10664 self.assertEqual(traced_nontuple, expected_nontuple)
10665 self.assertEqual(traced_out, expected_nontuple)
10666
10667 def test_nonzero_discontiguous(self):
10668 device = "mps"
10669 shape = (4, 4)
10670 tensor = torch.randint(2, shape, device=device)
10671 tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor)
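        # tensor_nc holds the same values as tensor but is non-contiguous:
        # it views every other column of a buffer twice as wide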
10672 dst1 = tensor.nonzero(as_tuple=False)
10673 dst2 = tensor_nc.nonzero(as_tuple=False)
10674 self.assertEqual(dst1, dst2, atol=0, rtol=0)
10675 dst3 = torch.empty_like(dst1)
10676 data_ptr = dst3.data_ptr()
10677 # expect dst3 storage to be reused
10678 torch.nonzero(tensor, out=dst3)
10679 self.assertEqual(data_ptr, dst3.data_ptr())
10680 self.assertEqual(dst1, dst3, atol=0, rtol=0)
10681 # discontiguous out
10682 dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
10683 data_ptr = dst4.data_ptr()
10684 strides = dst4.stride()
10685 torch.nonzero(tensor, out=dst4)
10686 self.assertEqual(data_ptr, dst4.data_ptr())
10687 self.assertEqual(dst1, dst4, atol=0, rtol=0)
10688 self.assertEqual(strides, dst4.stride())
10689
10690 def test_nonzero_non_diff(self):
10691 device = "mps"
10692 x = torch.randn(10, requires_grad=True)
10693 nz = x.nonzero()
10694 self.assertFalse(nz.requires_grad)
10695
Nikita Shulga916183a2023-09-13 19:28:47 +000010696 def test_nonzero_multi_threading(self):
10697 # Test that MPS does not crash if nonzero called concurrently
10698 # See https://github.com/pytorch/pytorch/issues/100285
10699 x = torch.rand(3, 3, device="mps")
10700 t1 = threading.Thread(target=torch.nonzero, args=(x,))
10701 t2 = threading.Thread(target=torch.nonzero, args=(x,))
10702 t1.start()
10703 t2.start()
10704
Denis Vieriu6a14fcb2022-09-29 23:23:00 +000010705 def test_masked_select(self):
10706 x = torch.randn(3, 4)
10707 x_mps = x.to("mps")
10708 mask = x.ge(0.5)
10709 mask_mps = x_mps.ge(0.5)
10710
10711 res = torch.masked_select(x, mask)
10712 res_mps = torch.masked_select(x_mps, mask_mps)
10713
10714 self.assertEqual(res, res_mps)
10715
Kulin Sethce7177f2022-08-18 06:03:16 +000010716 # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
Denis Vieriuce4f1872022-09-28 00:47:52 +000010717 def test_indexing_get(self):
Kulin Sethce7177f2022-08-18 06:03:16 +000010718 def helper(dtype):
10719 x_cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
10720 x_mps = x_cpu.detach().clone().to("mps")
10721
10722 y_cpu = x_cpu[[0, 1, 2], [0, 1, 0]]
10723 y_mps = x_mps[[0, 1, 2], [0, 1, 0]]
10724 self.assertEqual(y_cpu, y_mps, str(dtype))
10725 [helper(dtype) for dtype in self.supported_dtypes]
10726
10727 def test_indexing_select_corners(self):
10728 def helper(dtype):
10729 x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
10730 x_mps = x_cpu.detach().clone().to("mps")
10731
10732 rows_cpu = torch.tensor([[0, 0], [3, 3]])
10733 rows_mps = rows_cpu.detach().clone().to("mps")
10734
10735 cols_cpu = torch.tensor([[0, 2], [0, 2]])
10736 cols_mps = cols_cpu.detach().clone().to("mps")
10737
10738 res_cpu = x_cpu[rows_cpu, cols_cpu]
10739 res_mps = x_mps[rows_mps, cols_mps]
10740
10741 self.assertEqual(res_cpu, res_mps, str(dtype))
10742 [helper(dtype) for dtype in self.supported_dtypes]
10743
10744 # FIXME: uint8 fails for this testcase, needs further debugging
10745 def test_slicing_using_advanced_index_for_column(self):
10746 def helper(dtype):
10747 x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
10748 x_mps = x_cpu.detach().clone().to("mps")
10749
10750 z_cpu = x_cpu[1:4, 1:3]
10751 z_mps = x_mps[1:4, 1:3]
10752 self.assertEqual(z_cpu, z_mps, str(dtype))
10753
10754 # using advanced index for column
10755 y_cpu = x_cpu[1:4, [1, 2]]
10756 y_mps = x_mps[1:4, [1, 2]]
10757 self.assertEqual(y_cpu, y_mps, str(dtype))
10758 # FIXME: use supported_dtypes once uint8 is fixed
10759 [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]
10760
Li-Huai (Allan) Lindb8abde2023-04-01 16:15:08 +000010761 def test_boolean_array_indexing(self):
10762 def helper(dtype):
10763 x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
10764 x_mps = x_cpu.detach().clone().to("mps")
Kulin Sethce7177f2022-08-18 06:03:16 +000010765
Li-Huai (Allan) Lindb8abde2023-04-01 16:15:08 +000010766 res_cpu = x_cpu[x_cpu > 5]
10767 res_mps = x_mps[x_mps > 5]
Kulin Sethce7177f2022-08-18 06:03:16 +000010768
Li-Huai (Allan) Lindb8abde2023-04-01 16:15:08 +000010769 self.assertEqual(res_cpu, res_mps, str(dtype))
10770 for dtype in self.supported_dtypes:
10771            # MPS supports binary ops with uint8 natively starting from macOS 13.0
10772 if product_version < 13.0 and dtype == torch.uint8:
10773 continue
10774 helper(dtype)
Denis Vieriuce4f1872022-09-28 00:47:52 +000010775
10776 def test_advanced_indexing_3D_get(self):
10777 def helper(x_cpu):
10778 x_mps = x_cpu.detach().clone().to("mps")
10779 self.assertEqual(x_cpu[[1, 2], 3, :], x_mps[[1, 2], 3, :])
10780 self.assertEqual(x_cpu[[0, 2], :, :], x_mps[[0, 2], :, :])
10781 self.assertEqual(x_cpu[:, [1, 0], [1]], x_mps[:, [1, 0], [1]])
10782
10783 x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
10784 [0.5, 0.6, 0.7, 0.8],
10785 [0.9, 1.0, 1.1, 1.2],
10786 [1.3, 1.4, 1.5, 1.6]],
10787
10788 [[2.0, 2.1, 2.2, 2.3],
10789 [2.4, 2.5, 2.6, 2.7],
10790 [2.8, 2.9, 3.0, 3.1],
10791 [3.2, 3.3, 3.4, 3.5]],
10792
10793 [[4.0, 4.1, 4.2, 4.3],
10794 [4.4, 4.5, 4.6, 4.7],
10795 [4.8, 4.9, 5.0, 5.1],
10796 [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
10797 helper(x_cpu)
10798 for idx in range(len(self.supported_np_dtypes)):
10799 # torch.randn / torch.rand don't work with all dtypes
10800            # Generate input data for all dtypes with NumPy, then move to torch
10801 input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
10802 inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
10803
10804 helper(inputCPU)
10805
10806 def test_advanced_indexing_3D_put(self):
10807 def helper(x_cpu):
10808 dtype = x_cpu.dtype
10809 x_mps = x_cpu.detach().clone().to("mps")
10810
10811 out_tensor_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")
10812 out_tensor_cpu_view = out_tensor_cpu[1:]
10813
10814 out_tensor_mps = torch.tensor([88, 99], dtype=dtype, device="mps")
10815 out_tensor_mps_view = out_tensor_mps[1:]
10816
10817 x_cpu[[1, 2], 3, :] = out_tensor_cpu_view
10818 x_mps[[1, 2], 3, :] = out_tensor_mps_view
10819 self.assertEqual(x_cpu, x_mps)
10820
10821 x_cpu[[0, 2], :, :] = out_tensor_cpu_view
10822 x_mps[[0, 2], :, :] = out_tensor_mps_view
10823 self.assertEqual(x_cpu, x_mps)
10824
10825 x_cpu[:, [1, 0], [1]] = out_tensor_cpu_view
10826 x_mps[:, [1, 0], [1]] = out_tensor_mps_view
10827 self.assertEqual(x_cpu, x_mps)
10828
10829 x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
10830 [0.5, 0.6, 0.7, 0.8],
10831 [0.9, 1.0, 1.1, 1.2],
10832 [1.3, 1.4, 1.5, 1.6]],
10833
10834 [[2.0, 2.1, 2.2, 2.3],
10835 [2.4, 2.5, 2.6, 2.7],
10836 [2.8, 2.9, 3.0, 3.1],
10837 [3.2, 3.3, 3.4, 3.5]],
10838
10839 [[4.0, 4.1, 4.2, 4.3],
10840 [4.4, 4.5, 4.6, 4.7],
10841 [4.8, 4.9, 5.0, 5.1],
10842 [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
10843 helper(x_cpu)
10844 for idx in range(len(self.supported_np_dtypes)):
10845 # torch.randn / torch.rand don't work with all dtypes
10846            # Generate input data for all dtypes with NumPy, then move to torch
10847 input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
10848 inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
10849
10850 helper(inputCPU)
10851
10852 def test_index_put_with_view_indices(self):
10853 def helper(dtype):
10854 target_cpu = torch.zeros([5, 3], device="cpu", dtype=dtype)
10855 target_mps = torch.zeros([5, 3], device="mps", dtype=dtype)
10856
10857 indices_cpu = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cpu")
10858 indices_mps = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="mps")
10859
10860 value_cpu = torch.ones(indices_cpu.shape[0], device="cpu", dtype=dtype)
10861 value_mps = torch.ones(indices_mps.shape[0], device="mps", dtype=dtype)
10862
10863 target_cpu.index_put_(tuple(indices_cpu.t()), value_cpu, accumulate=True)
10864 target_mps.index_put_(tuple(indices_mps.t()), value_mps, accumulate=True)
10865
10866 self.assertEqual(target_cpu, target_mps)
10867
10868 [helper(dtype) for dtype in [torch.int32, torch.float]]
10869
10870 # tests from 'test_indexing.py'
10871 def test_advancedindex_big(self, device="mps"):
10872 reference = torch.arange(0, 123344, dtype=torch.int, device=device)
10873
10874 self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
10875 torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))
10876
10877 def test_set_item_to_scalar_tensor(self, device="mps"):
10878 m = random.randint(1, 10)
10879 n = random.randint(1, 10)
10880 z = torch.randn([m, n], device=device)
10881 a = 1.0
10882 w = torch.tensor(a, requires_grad=True, device=device)
10883 z[:, 0] = w
10884 z.sum().backward()
10885 self.assertEqual(w.grad, m * a)
10886
10887 def test_single_int(self, device="mps"):
10888 v = torch.randn(5, 7, 3, device=device)
10889 self.assertEqual(v[4].shape, (7, 3))
10890
10891 def test_multiple_int(self, device="mps"):
10892 v = torch.randn(5, 7, 3, device=device)
10893 self.assertEqual(v[4].shape, (7, 3))
10894 self.assertEqual(v[4, :, 1].shape, (7,))
10895
10896 def test_none(self, device="mps"):
10897 v = torch.randn(5, 7, 3, device=device)
10898 self.assertEqual(v[None].shape, (1, 5, 7, 3))
10899 self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
10900 self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
10901 self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
10902
10903 def test_step(self, device="mps"):
10904 v = torch.arange(10, device=device)
10905 self.assertEqual(v[::1], v)
10906 self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
10907 self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
10908 self.assertEqual(v[::11].tolist(), [0])
10909 self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
10910
10911 def test_step_assignment(self, device="mps"):
10912 v = torch.zeros(4, 4, device=device)
10913 v[0, 1::2] = torch.tensor([3., 4.], device=device)
10914 self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
10915 self.assertEqual(v[1:].sum(), 0)
10916
Kulin Sethce7177f2022-08-18 06:03:16 +000010917 def test_bool_indices(self, device="mps"):
10918 v = torch.randn(5, 7, 3, device=device)
10919 boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
10920 self.assertEqual(v[boolIndices].shape, (3, 7, 3))
10921 self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
10922
10923 v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
10924 boolIndices = torch.tensor([True, False, False], dtype=torch.bool, device=device)
10925 uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
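        # indexing with a uint8 tensor is deprecated in favor of bool; each of
        # the two v[uint8Indices] lookups below should emit one warning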
10926 with warnings.catch_warnings(record=True) as w:
10927 self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
10928 self.assertEqual(v[boolIndices], v[uint8Indices])
10929 self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
10930 self.assertEqual(len(w), 2)
10931
Denis Vieriu71ec2612023-02-15 06:09:56 +000010932 @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
Denis Vieriuce4f1872022-09-28 00:47:52 +000010933 def test_bool_indices_accumulate(self, device="mps"):
10934 mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
10935 mask = mask > 0
10936 y = torch.ones(size=(10, 10), device=device)
10937 y.index_put_((mask, ), y[mask], accumulate=True)
10938 self.assertEqual(y, torch.ones(size=(10, 10), device=device))
10939
Kulin Sethce7177f2022-08-18 06:03:16 +000010940 def test_multiple_bool_indices(self, device="mps"):
10941 v = torch.randn(5, 7, 3, device=device)
10942 # note: these broadcast together and are transposed to the first dim
10943 mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
10944 mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
10945 self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
10946
Kulin Sethce7177f2022-08-18 06:03:16 +000010947 def test_byte_mask(self, device="mps"):
10948 v = torch.randn(5, 7, 3, device=device)
10949 mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
10950 with warnings.catch_warnings(record=True) as w:
10951 self.assertEqual(v[mask].shape, (3, 7, 3))
10952 self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
10953 self.assertEqual(len(w), 2)
10954
10955 v = torch.tensor([1.], device=device)
10956 self.assertEqual(v[v == 0], torch.tensor([], device=device))
10957
Denis Vieriuce4f1872022-09-28 00:47:52 +000010958 def test_byte_mask_accumulate(self, device="mps"):
10959 mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
10960 y = torch.ones(size=(10, 10), device=device)
10961 with warnings.catch_warnings(record=True) as w:
10962 warnings.simplefilter("always")
10963 y.index_put_((mask, ), y[mask], accumulate=True)
10964 self.assertEqual(y, torch.ones(size=(10, 10), device=device))
10965 self.assertEqual(len(w), 2)
10966
10967 def test_index_put_accumulate_expanded_values(self, device="mps"):
10968 t = torch.zeros((5, 2))
10969 t_dev = t.to(device)
10970 indices = [
10971 torch.tensor([0, 1, 2, 3]),
10972 torch.tensor([1, ]),
10973 ]
10974 indices_dev = [i.to(device) for i in indices]
10975 values0d = torch.tensor(1.0)
10976 values1d = torch.tensor([1.0, ])
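        # 0-d and 1-d values broadcast across all indexed elements when accumulating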
10977
10978 out_mps = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
10979 out_cpu = t.index_put_(indices, values0d, accumulate=True)
10980 self.assertEqual(out_mps.cpu(), out_cpu)
10981
10982 out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
10983 out_cpu = t.index_put_(indices, values1d, accumulate=True)
10984 self.assertEqual(out_mps.cpu(), out_cpu)
10985
10986 t = torch.zeros(4, 3, 2)
10987 t_dev = t.to(device)
10988
10989 indices = [
10990 torch.tensor([0, ]),
10991 torch.arange(3)[:, None],
10992 torch.arange(2)[None, :],
10993 ]
10994 indices_dev = [i.to(device) for i in indices]
10995 values1d = torch.tensor([-1.0, -2.0])
10996 values2d = torch.tensor([[-1.0, -2.0], ])
10997
10998 out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
10999 out_cpu = t.index_put_(indices, values1d, accumulate=True)
11000 self.assertEqual(out_mps.cpu(), out_cpu)
11001
11002 out_mps = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
11003 out_cpu = t.index_put_(indices, values2d, accumulate=True)
11004 self.assertEqual(out_mps.cpu(), out_cpu)
11005
11006 def test_index_put_accumulate_non_contiguous(self, device="mps"):
11007 t = torch.zeros((5, 2, 2))
11008 t_dev = t.to(device)
11009 t1 = t_dev[:, 0, :]
11010 t2 = t[:, 0, :]
11011 self.assertTrue(not t1.is_contiguous())
11012 self.assertTrue(not t2.is_contiguous())
11013
11014 indices = [torch.tensor([0, 1]), ]
11015 indices_dev = [i.to(device) for i in indices]
11016 value = torch.randn(2, 2)
11017 out_mps = t1.index_put_(indices_dev, value.to(device), accumulate=True)
11018 out_cpu = t2.index_put_(indices, value, accumulate=True)
11019 self.assertTrue(not t1.is_contiguous())
11020 self.assertTrue(not t2.is_contiguous())
11021
11022 self.assertEqual(out_mps.cpu(), out_cpu)
11023
11024 def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
11025 # TODO: replace with a better solution.
11026        # Currently, torchscript is used here to put None into the indices.
11027        # On the C++ side this yields indices as a list of 2 optional tensors:
11028        # the first is null and the second is a valid tensor.
11029 @torch.jit.script
11030 def func(x, i, v):
11031 idx = [None, i]
11032 x.index_put_(idx, v, accumulate=True)
11033 return x
11034
11035 n = 4
11036 t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
11037 t_dev = t.to(device)
11038 indices = torch.tensor([1, 0])
11039 indices_dev = indices.to(device)
11040 value0d = torch.tensor(10.0)
11041 value1d = torch.tensor([1.0, 2.0])
11042
11043 out_mps = func(t_dev, indices_dev, value0d.to("mps"))
11044 out_cpu = func(t, indices, value0d)
11045 self.assertEqual(out_mps.cpu(), out_cpu)
11046
11047 out_mps = func(t_dev, indices_dev, value1d.to("mps"))
11048 out_cpu = func(t, indices, value1d)
11049 self.assertEqual(out_mps.cpu(), out_cpu)
11050
11051 def test_index_put_accumulate_duplicate_indices(self, device="mps"):
11052 for i in range(1, 128):
11053 # generate indices by random walk, this will create indices with
11054 # lots of duplicates interleaved with each other
11055 delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)
11056
Nikita Shulga657f2e12022-11-04 01:22:41 +000011057 indices = delta.cumsum(0).long().to("mps")
Denis Vieriuce4f1872022-09-28 00:47:52 +000011058
11059            # abs for int64 is not supported on mps; fall back to 'cpu' to compute it
Denis Vieriu6a14fcb2022-09-29 23:23:00 +000011060 input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device)
Denis Vieriuce4f1872022-09-28 00:47:52 +000011061 values = torch.randn(indices.size(0), device=device)
11062 output = input.index_put((indices,), values, accumulate=True)
11063
11064 input_list = input.tolist()
11065 indices_list = indices.tolist()
11066 values_list = values.tolist()
11067 for i, v in zip(indices_list, values_list):
11068 input_list[i] += v
11069
11070 self.assertEqual(output, input_list)
11071
Li-Huai (Allan) Lin3b6a7f42023-05-08 00:57:29 +000011072 def test_index_put_deterministic(self, device="mps"):
11073 def helper(dtype, accumulate, deterministic, num_tests=128):
11074 acc_expected = torch.tensor([233, 187, 360], device=device, dtype=dtype)
11075 non_acc_expected = torch.tensor([38, 37, 39], device=device, dtype=dtype)
11076 t_idx = torch.tensor(
11077 [0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 2,
11078 0, 0, 2, 1, 2, 1, 0, 0, 2, 0, 2, 1, 1, 2, 2, 0, 2, 1, 0, 2]
11079 )
11080 for _ in range(num_tests):
11081 try:
11082 torch.use_deterministic_algorithms(deterministic)
11083 t = torch.zeros(3, dtype=dtype, device=device)
11084 t.index_put_((t_idx,), torch.arange(len(t_idx), device=device, dtype=dtype), accumulate=accumulate)
11085 if accumulate:
11086 self.assertEqual(t, acc_expected)
11087 else:
11088 self.assertEqual(t, non_acc_expected)
11089 finally:
11090 torch.use_deterministic_algorithms(False)
11091
11092 for accumulate, deterministic in product((False, True), (False, True)):
11093 dtype = torch.float if accumulate else torch.long
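            # with accumulate=False and duplicate indices, the final value depends
            # on scatter order; without deterministic mode the fixed expectation
            # is allowed to mismatch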
11094 if not accumulate and not deterministic:
11095 with self.assertRaisesRegex(AssertionError, "Tensor-likes are not equal!"):
11096 helper(dtype, accumulate, deterministic)
11097 else:
11098 helper(dtype, accumulate, deterministic)
11099
Denis Vieriuce4f1872022-09-28 00:47:52 +000011100 def test_multiple_byte_mask(self, device="mps"):
11101 v = torch.randn(5, 7, 3, device=device)
11102 # note: these broadcast together and are transposed to the first dim
11103 mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
11104 mask2 = torch.ByteTensor([1, 1, 1]).to(device)
11105 with warnings.catch_warnings(record=True) as w:
11106 warnings.simplefilter("always")
11107 self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
11108 self.assertEqual(len(w), 2)
11109
11110 def test_byte_mask2d(self, device="mps"):
11111 v = torch.randn(5, 7, 3, device=device)
11112 c = torch.randn(5, 7, device=device)
11113 num_ones = (c > 0).sum()
11114 r = v[c > 0]
11115 self.assertEqual(r.shape, (num_ones, 3))
11116
Li-Huai (Allan) Lindb8abde2023-04-01 16:15:08 +000011117 def test_jit_indexing(self, device="mps"):
11118 def fn1(x):
11119 x[x < 50] = 1.0
11120 return x
Denis Vieriuce4f1872022-09-28 00:47:52 +000011121
Li-Huai (Allan) Lindb8abde2023-04-01 16:15:08 +000011122 def fn2(x):
11123 x[0:50] = 1.0
11124 return x
Denis Vieriuce4f1872022-09-28 00:47:52 +000011125
Li-Huai (Allan) Lindb8abde2023-04-01 16:15:08 +000011126 scripted_fn1 = torch.jit.script(fn1)
11127 scripted_fn2 = torch.jit.script(fn2)
11128 data = torch.arange(100, device=device, dtype=torch.float)
11129 out = scripted_fn1(data.detach().clone())
11130 ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
11131 self.assertEqual(out, ref)
11132 out = scripted_fn2(data.detach().clone())
11133 self.assertEqual(out, ref)
Denis Vieriuce4f1872022-09-28 00:47:52 +000011134
11135 def test_int_indices(self, device="mps"):
11136 v = torch.randn(5, 7, 3, device=device)
11137 self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
11138 self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
11139 self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
11140
11141 def test_index_put_src_datatype(self):
11142 def helper(device, dtype):
11143 src = torch.ones(3, 2, 4, device=device, dtype=dtype)
11144 vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
11145 indices = (torch.tensor([0, 2, 1]),)
11146 res = src.index_put_(indices, vals, accumulate=True)
11147 self.assertEqual(res.shape, src.shape)
11148 [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]
11149
Denis Vieriu71ec2612023-02-15 06:09:56 +000011150 @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
Denis Vieriuce4f1872022-09-28 00:47:52 +000011151 def test_index_src_datatype(self):
11152 def helper(device, dtype):
11153 orig_dtype = dtype
11154 if dtype is torch.bool:
11155 dtype = torch.uint8
11156
11157 src = torch.ones(3, 2, 4, device=device, dtype=dtype)
11158 if orig_dtype is torch.bool:
11159 src = src == 1
11160 # test index
11161 res = src[[0, 2, 1], :, :]
11162 self.assertEqual(res.shape, src.shape)
11163 # test index_put, no accum
11164 src[[0, 2, 1], :, :] = res
11165 self.assertEqual(res.shape, src.shape)
11166 [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.float16, torch.long, torch.bool]]
11167
Kulin Sethce7177f2022-08-18 06:03:16 +000011168 def test_int_indices2d(self, device="mps"):
11169 # From the NumPy indexing example
11170 x = torch.arange(0, 12, device=device).view(4, 3)
11171 rows = torch.tensor([[0, 0], [3, 3]], device=device)
11172 columns = torch.tensor([[0, 2], [0, 2]], device=device)
11173 self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
11174
11175 def test_int_indices_broadcast(self, device="mps"):
11176 # From the NumPy indexing example
11177 x = torch.arange(0, 12, device=device).view(4, 3)
11178 rows = torch.tensor([0, 3], device=device)
11179 columns = torch.tensor([0, 2], device=device)
11180 result = x[rows[:, None], columns]
11181 self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
11182
Denis Vieriuce4f1872022-09-28 00:47:52 +000011183 def test_empty_index(self, device="mps"):
11184 x = torch.arange(0, 12, device=device).view(4, 3)
11185 idx = torch.tensor([], dtype=torch.long, device=device)
11186 self.assertEqual(x[idx].numel(), 0)
11187
11188 # empty assignment should have no effect but not throw an exception
11189 y = x.clone()
11190 y[idx] = -1
11191 self.assertEqual(x, y)
11192
11193 mask = torch.zeros(4, 3, device=device).bool()
11194 y[mask] = -1
11195 self.assertEqual(x, y)
11196
Kulin Sethce7177f2022-08-18 06:03:16 +000011197 def test_empty_ndim_index(self, device="mps"):
11198 x = torch.randn(5, device=device)
11199 self.assertEqual(torch.empty(0, 2, device=device), x[torch.empty(0, 2, dtype=torch.int64, device=device)])
11200
11201 x = torch.randn(2, 3, 4, 5, device=device)
11202 self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
11203 x[:, torch.empty(0, 6, dtype=torch.int64, device=device)])
11204
11205 x = torch.empty(10, 0, device=device)
11206 self.assertEqual(x[[1, 2]].shape, (2, 0))
11207 self.assertEqual(x[[], []].shape, (0,))
11208 with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
11209 x[:, [0, 1]]
11210
11211 def test_empty_ndim_index_bool(self, device="mps"):
11212 x = torch.randn(5, device=device)
11213 self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])
11214
Denis Vieriuce4f1872022-09-28 00:47:52 +000011215 def test_empty_slice(self, device="mps"):
11216 x = torch.randn(2, 3, 4, 5, device=device)
11217 y = x[:, :, :, 1]
11218 z = y[:, 1:1, :]
11219 self.assertEqual((2, 0, 4), z.shape)
11220 # this isn't technically necessary, but matches NumPy stride calculations.
11221 self.assertEqual((60, 20, 5), z.stride())
11222 self.assertTrue(z.is_contiguous())
11223
Kulin Sethce7177f2022-08-18 06:03:16 +000011224 def test_index_getitem_copy_bools_slices(self, device="mps"):
11225 true = torch.tensor(1, dtype=torch.uint8, device=device)
11226 false = torch.tensor(0, dtype=torch.uint8, device=device)
11227
11228 tensors = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]
11229
11230 for a in tensors:
11231 self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
11232 self.assertEqual(torch.empty(0, *a.shape), a[False])
11233 self.assertNotEqual(a.data_ptr(), a[true].data_ptr())
11234 self.assertEqual(torch.empty(0, *a.shape), a[false])
11235 self.assertEqual(a.data_ptr(), a[None].data_ptr())
11236 self.assertEqual(a.data_ptr(), a[...].data_ptr())
11237
Denis Vieriuce4f1872022-09-28 00:47:52 +000011238 def test_index_setitem_bools_slices(self, device="mps"):
11239 true = torch.tensor(1, dtype=torch.uint8, device=device)
11240 false = torch.tensor(0, dtype=torch.uint8, device=device)
11241
11242 tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
11243
11244 for a in tensors:
11245            # prefix with 1, 1 to ensure compatibility with numpy, which cuts off leading 1s
11246 # (some of these ops already prefix a 1 to the size)
11247 neg_ones = torch.ones_like(a) * -1
11248 neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
11249 a[True] = neg_ones_expanded
11250 self.assertEqual(a, neg_ones)
11251 a[False] = 5
11252 self.assertEqual(a, neg_ones)
11253 a[true] = neg_ones_expanded * 2
11254 self.assertEqual(a, neg_ones * 2)
11255 a[false] = 5
11256 self.assertEqual(a, neg_ones * 2)
11257 a[None] = neg_ones_expanded * 3
11258 self.assertEqual(a, neg_ones * 3)
11259 a[...] = neg_ones_expanded * 4
11260 self.assertEqual(a, neg_ones * 4)
11261 if a.dim() == 0:
11262 with self.assertRaises(IndexError):
11263 a[:] = neg_ones_expanded * 5
11264
Kulin Sethce7177f2022-08-18 06:03:16 +000011265 def test_index_scalar_with_bool_mask(self, device="mps"):
11266 a = torch.tensor(1, device=device)
11267 uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
11268 boolMask = torch.tensor(True, dtype=torch.bool, device=device)
11269 self.assertEqual(a[uintMask], a[boolMask])
11270 self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
11271
11272 a = torch.tensor(True, dtype=torch.bool, device=device)
11273 self.assertEqual(a[uintMask], a[boolMask])
11274 self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
11275
Denis Vieriuce4f1872022-09-28 00:47:52 +000011276 def test_setitem_expansion_error(self, device="mps"):
11277 true = torch.tensor(True, device=device)
11278 a = torch.randn(2, 3, device=device)
11279 # check prefix with non-1s doesn't work
11280 a_expanded = a.expand(torch.Size([5, 1]) + a.size())
11281 # NumPy: ValueError
11282 with self.assertRaises(RuntimeError):
11283 a[True] = a_expanded
11284 with self.assertRaises(RuntimeError):
11285 a[true] = a_expanded
11286
Kulin Sethce7177f2022-08-18 06:03:16 +000011287 def test_getitem_scalars(self, device="mps"):
11288 zero = torch.tensor(0, dtype=torch.int64, device=device)
11289 one = torch.tensor(1, dtype=torch.int64, device=device)
11290
11291 # non-scalar indexed with scalars
11292 a = torch.randn(2, 3, device=device)
11293 self.assertEqual(a[0], a[zero])
11294 self.assertEqual(a[0][1], a[zero][one])
11295 self.assertEqual(a[0, 1], a[zero, one])
11296 self.assertEqual(a[0, one], a[zero, 1])
11297
11298 # indexing by a scalar should slice (not copy)
11299 self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
11300 self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
11301 self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())
11302
11303 # scalar indexed with scalar
11304 r = torch.randn((), device=device)
11305 with self.assertRaises(IndexError):
11306 r[:]
11307 with self.assertRaises(IndexError):
11308 r[zero]
11309 self.assertEqual(r, r[...])
11310
Denis Vieriuce4f1872022-09-28 00:47:52 +000011311 def test_setitem_scalars(self, device="mps"):
11312 zero = torch.tensor(0, dtype=torch.int64)
11313
11314 # non-scalar indexed with scalars
11315 a = torch.randn(2, 3, device=device)
11316 a_set_with_number = a.clone()
11317 a_set_with_scalar = a.clone()
11318 b = torch.randn(3, device=device)
11319
11320 a_set_with_number[0] = b
11321 a_set_with_scalar[zero] = b
11322 self.assertEqual(a_set_with_number, a_set_with_scalar)
11323 a[1, zero] = 7.7
11324 self.assertEqual(7.7, a[1, 0])
11325
11326 # scalar indexed with scalars
11327 r = torch.randn((), device=device)
11328 with self.assertRaises(IndexError):
11329 r[:] = 8.8
11330 with self.assertRaises(IndexError):
11331 r[zero] = 8.8
11332 r[...] = 9.9
11333 self.assertEqual(9.9, r)
11334
11335 def test_basic_advanced_combined(self, device="mps"):
11336 # From the NumPy indexing example
11337 x = torch.arange(0, 12, device=device).view(4, 3)
11338 self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
11339 self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
11340
11341 # Check that it is a copy
11342 unmodified = x.clone()
11343 x[1:2, [1, 2]].zero_()
11344 self.assertEqual(x, unmodified)
11345
11346 # But assignment should modify the original
11347 unmodified = x.clone()
11348 x[1:2, [1, 2]] = 0
11349 self.assertNotEqual(x, unmodified)
11350
11351 def test_int_assignment(self, device="mps"):
11352 x = torch.arange(0, 4, device=device).view(2, 2)
11353 x[1] = 5
11354 self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
11355
11356 x = torch.arange(0, 4, device=device).view(2, 2)
11357 x[1] = torch.arange(5, 7, device=device)
11358 self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
11359
11360 def test_byte_tensor_assignment(self, device="mps"):
11361 x = torch.arange(0., 16, device=device).view(4, 4)
11362 b = torch.ByteTensor([True, False, True, False]).to(device)
11363 value = torch.tensor([3., 4., 5., 6.], device=device)
11364
11365 with warnings.catch_warnings(record=True) as w:
11366 x[b] = value
11367 self.assertEqual(len(w), 1)
11368
11369 self.assertEqual(x[0], value)
11370 self.assertEqual(x[1], torch.arange(4., 8, device=device))
11371 self.assertEqual(x[2], value)
11372 self.assertEqual(x[3], torch.arange(12., 16, device=device))
11373
Kulin Sethce7177f2022-08-18 06:03:16 +000011374 def test_variable_slicing(self, device="mps"):
11375 x = torch.arange(0, 16, device=device).view(4, 4)
11376 indices = torch.IntTensor([0, 1]).to(device)
11377 i, j = indices
11378 self.assertEqual(x[i:j], x[0:1])
11379
11380 def test_ellipsis_tensor(self, device="mps"):
11381 x = torch.arange(0, 9, device=device).view(3, 3)
11382 idx = torch.tensor([0, 2], device=device)
11383 self.assertEqual(x[..., idx].tolist(), [[0, 2],
11384 [3, 5],
11385 [6, 8]])
11386 self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2],
11387 [6, 7, 8]])
11388
11389 def test_invalid_index(self, device="mps"):
11390 x = torch.arange(0, 16, device=device).view(4, 4)
11391 self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
11392
Denis Vieriuce4f1872022-09-28 00:47:52 +000011393 def test_out_of_bound_index(self, device="mps"):
11394 x = torch.arange(0, 100, device=device).view(2, 5, 10)
11395 self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5])
11396 self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5])
11397 self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10',
11398 lambda: x[0, 1, 15])
11399 self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10',
11400 lambda: x[:, :, 12])
11401
11402 def test_zero_dim_index(self, device="mps"):
11403 x = torch.tensor(10, device=device)
11404 self.assertEqual(x, x.item())
11405
11406 def runner():
11407 print(x[0])
11408 return x[0]
11409
11410 self.assertRaisesRegex(IndexError, 'invalid index', runner)
11411
11412 def test_cpu_indices(self, device="mps"):
11413 idx = torch.tensor([0, 1])
11414 b = torch.zeros(2, device=device)
11415 x = torch.ones(10, device=device)
11416 x[idx] = b # index_put_
11417 ref = torch.ones(10, device=device)
11418 ref[:2] = 0
11419 self.assertEqual(x, ref, atol=0, rtol=0)
11420 out = x[idx] # index
11421 self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
11422
Nikita Shulga5944a532024-04-27 02:58:05 +000011423 def test_nextafter(self, device="mps"):
11424 for dtype in [torch.float16, torch.float32]:
11425 x = torch.tensor([1, -1, 0, 0, 2, -2], device=device, dtype=dtype)
11426 y = torch.tensor([2, -2, -1, 1, -3, 3], device=device, dtype=dtype)
11427 na = torch.nextafter(x, y)
11428 na_cpu = torch.nextafter(x.cpu(), y.cpu())
11429 na_ge_x_mps = na.cpu() > x.cpu()
11430 # greater is broken on MPS, see https://github.com/pytorch/pytorch/issues/125051
11431 na_ge_x_cpu = na_cpu > x.cpu()
11432 self.assertEqual(na_ge_x_mps, na_ge_x_cpu)
11433
11434
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +000011435class TestRNNMPS(TestCaseMPS):
alexdremov78da3152023-03-05 00:19:51 +000011436 def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
11437 seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
11438 rnn = nn.LSTM(
11439 input_size=input_size,
11440 hidden_size=hidden_size,
11441 num_layers=num_layers,
11442 bias=bias,
11443 bidirectional=bidirectional,
11444 batch_first=batch_first,
11445 device="cpu"
11446 )
11447 bidirectional_mul = 2 if bidirectional else 1
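        # hidden and cell states have shape (num_layers * num_directions, batch_size, hidden_size)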
Kulin Sethe011a8e2022-05-13 18:28:53 +000011448
alexdremov78da3152023-03-05 00:19:51 +000011449 if batch_first:
11450 input = torch.randn(batch_size, seq_len, input_size, device="cpu", dtype=dtype, requires_grad=backward)
11451 hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11452 requires_grad=backward)
11453 cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11454 requires_grad=backward)
11455 else:
11456 input = torch.randn(seq_len, batch_size, input_size, device="cpu", dtype=dtype, requires_grad=backward)
11457 hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11458 requires_grad=backward)
11459 cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11460 requires_grad=backward)
Kulin Sethe011a8e2022-05-13 18:28:53 +000011461
alexdremov78da3152023-03-05 00:19:51 +000011462 cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
11463
11464 rnn = rnn.to(device)
11465 input = input.to(device)
11466 hx = hx.to(device)
11467 cx = cx.to(device)
11468 output, (hn, cn) = rnn(input, (hx, cx))
11469
11470 self.assertEqual(cpu_output, output)
11471 self.assertEqual(cpu_hn, hn)
11472 self.assertEqual(cpu_cn, cn)
11473
alexdremov62eb7a22023-03-16 15:53:52 +000011474 def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
alexdremovb9e95152023-02-23 17:32:42 +000011475 rnn = rnn.to(device)
alexdremov78da3152023-03-05 00:19:51 +000011476 inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
Alban Desmaison02551a02022-05-28 12:39:10 -040011477
alexdremov62eb7a22023-03-16 15:53:52 +000011478 output, (hx_out, cx_out) = rnn(inp, (hx, cx))
11479            assert output_grad_presented or states_grad_presented, "At least one output must be used"
11480
11481 f = 0
11482 if output_grad_presented:
11483 f = f + 3 * output.sum()
11484 if states_grad_presented:
11485 f = f + (hx_out * cx_out).sum()
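            # f is a scalar objective built from the selected outputs, so
            # autograd.grad yields gradients for the parameters, input and states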
qqaatwb0b24b42022-07-07 07:18:00 +000011486
alexdremov78da3152023-03-05 00:19:51 +000011487 param_names, params = zip(*rnn.named_parameters())
11488 param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
qqaatwb0b24b42022-07-07 07:18:00 +000011489
alexdremov78da3152023-03-05 00:19:51 +000011490 input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
11491 return output, param_grads, input_grad, hx_grad, cx_grad
qqaatwb0b24b42022-07-07 07:18:00 +000011492
alexdremov78da3152023-03-05 00:19:51 +000011493 if backward:
alexdremov62eb7a22023-03-16 15:53:52 +000011494 grad_cases = [
11495 dict(output_grad_presented=True, states_grad_presented=True),
11496 dict(output_grad_presented=False, states_grad_presented=True),
11497 dict(output_grad_presented=True, states_grad_presented=False),
11498 ]
alexdremov78da3152023-03-05 00:19:51 +000011499
alexdremov62eb7a22023-03-16 15:53:52 +000011500 for grad_case in grad_cases:
11501 cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
11502 get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
11503 mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
11504 get_backward_results(rnn, device, input, hx, cx, **grad_case)
11505
11506 self.assertEqual(cpu_hx_grad, mps_hx_grad)
11507 self.assertEqual(cpu_cx_grad, mps_cx_grad)
11508 self.assertEqual(cpu_output, mps_output)
11509 self.assertEqual(cpu_input_grad, mps_input_grad)
11510 for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
11511 self.assertEqual(cpu_weight_grad, mps_weight_grad,
11512 f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
alexdremov78da3152023-03-05 00:19:51 +000011513
11514 LSTM_TEST_CASES = [
11515 dict(), # default
11516 dict(batch_first=True),
11517 dict(bias=False),
11518 dict(bidirectional=True),
11519 dict(batch_first=True, bias=False),
11520 dict(bidirectional=True, bias=False),
11521 dict(bidirectional=True, batch_first=True),
11522 dict(bidirectional=True, batch_first=True, bias=False)
11523 ]
11524
11525 def test_lstm_forward(self, device="mps", dtype=torch.float32):
Li-Huai (Allan) Lina87f3f62023-03-10 03:10:49 +000011526 for num_layers in [1, 2, 5]:
alexdremov78da3152023-03-05 00:19:51 +000011527 for test_options in self.LSTM_TEST_CASES:
11528 self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)
qqaatwb0b24b42022-07-07 07:18:00 +000011529
Nikita Shulga4e29e802024-05-09 13:43:12 +000011530 # Broke on MacOS-14.4 (but works on 14.2), see https://github.com/pytorch/pytorch/issues/125803
11531 @xfailIfMacOS14_4Plus
alexdremovb9e95152023-02-23 17:32:42 +000011532 def test_lstm_backward(self, device="mps", dtype=torch.float32):
Li-Huai (Allan) Lina87f3f62023-03-10 03:10:49 +000011533 for num_layers in [1, 2, 5]:
alexdremov78da3152023-03-05 00:19:51 +000011534 for test_options in self.LSTM_TEST_CASES:
11535 self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)
alexdremovb9e95152023-02-23 17:32:42 +000011536
Kulin Seth54ebf252023-02-15 16:10:40 +000011537 def test_RNN_cell_no_broadcasting(self):
11538 def test(cell_module, input, hx, input_size, hidden_size):
11539 cell = cell_module(input_size, hidden_size, device='mps')
11540 self.assertRaises(RuntimeError, lambda: cell(input, hx))
11541
11542 def test_all(hidden_size, bad_hx, good_hx, input_size, input):
11543 test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
11544 test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
11545 test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
11546 test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)
11547
11548 hidden_size = 20
11549 input_size = 10
11550 input = torch.randn(3, input_size, device='mps')
11551 bad_hx = torch.randn(1, hidden_size, device='mps')
11552 good_hx = torch.randn(3, hidden_size, device='mps')
11553
11554 # Test hidden/input batch size broadcasting
11555 test_all(hidden_size, bad_hx, good_hx, input_size, input)
11556
11557 # Test hx's hidden_size vs module's hidden_size broadcasting
11558 bad_hx = torch.randn(3, 1)
11559 test_all(hidden_size, bad_hx, good_hx, input_size, input)
11560
11561 # Test input's input_size vs module's input_size broadcasting
11562 bad_input = torch.randn(3, 1)
11563 test_all(hidden_size, good_hx, good_hx, input_size, bad_input)
11564
11565 def test_LSTM_cell(self):
11566 # this is just a smoke test; these modules are implemented through
11567 # autograd so no Jacobian test is needed
11568 for bias in (True, False):
11569 input = torch.randn(3, 10, device='mps')
11570 hx = torch.randn(3, 20, device='mps')
11571 cx = torch.randn(3, 20, device='mps')
11572 lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
11573 for _ in range(6):
11574 hx, cx = lstm(input, (hx, cx))
11575
11576 (hx + cx).sum().backward()
11577
11578 def test_LSTM_cell_forward_input_size(self):
11579 input = torch.randn(3, 11, device='mps')
11580 hx = torch.randn(3, 20, device='mps')
11581 cx = torch.randn(3, 20, device='mps')
11582 lstm = nn.LSTMCell(10, 20, device='mps')
11583 self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
11584
11585 def test_LSTM_cell_forward_hidden_size(self):
11586 input = torch.randn(3, 10, device='mps')
11587 hx = torch.randn(3, 21, device='mps')
11588 cx = torch.randn(3, 20, device='mps')
11589 lstm = nn.LSTMCell(10, 20, device='mps')
11590 self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
11591 self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
11592
11593
class TestFallbackWarning(TestCase):
    # TODO: Remove once test_testing.py is running on MPS devices
    def test_no_warning_on_import(self):
        out = subprocess.check_output(
            [sys.executable, "-W", "all", "-c", "import torch"],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
        self.assertEqual(out, "")

    def _get_not_implemented_op(self):
        # This can be changed once we actually implement 'lcm'
        # Should return fn, args, kwargs, string_version
        return (torch.lcm,
                [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
                "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")

    def test_error_on_not_implemented(self):
        fn, args, kwargs, _ = self._get_not_implemented_op()

        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
            fn(*args, **kwargs)

    def test_warn_on_not_implemented_with_fallback(self):
        _, _, _, op = self._get_not_implemented_op()
        script = f"""
import os
# MUST happen before pytorch's import
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import warnings

with warnings.catch_warnings(record=True) as w:
    import torch

if len(w) > 0:
    print(w)
    exit(1)

# This should run just fine and emit a warning about perf
with warnings.catch_warnings(record=True) as w:
    {op}

if len(w) != 1:
    print(w)
    exit(2)

"""
        try:
            subprocess.check_output(
                [sys.executable, '-W', 'all', '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            if e.returncode == 1:
                self.fail("There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set. " +
                          e.output.decode("utf-8"))
            elif e.returncode == 2:
                self.fail("There wasn't exactly one warning when running a not-implemented op with "
                          f"PYTORCH_ENABLE_MPS_FALLBACK set. {e.output}")
            else:
                self.fail("Running a not-implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
                          e.output.decode("utf-8"))


class TestNoRegression(TestCase):
    def test_assert_close(self):
        a = torch.ones(1, device="mps")
        b = torch.zeros(1, device="mps")
        inf = a / b
        nan = b / b

        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            torch.testing.assert_close(a, inf)

        # TODO: The NaN test is failing when all the tests in test_mps are run
        # together but passes when run separately. There seems to be memory
        # corruption which needs to be fixed for this test to be enabled.
        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
        #     torch.testing.assert_close(a, nan)

    def test_double_error(self):
        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
            a = torch.ones(2, dtype=torch.float64, device="mps")

        a = torch.ones(2, device="mps")
        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
            a = a.double()

    def test_legacy_constructor(self):
        a = torch.ones(2, device="mps")

        b = a.new(1)

    def test_serialization_map_location(self):

        # Ensures that cpu Tensors can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2)
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="mps")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")

        # Ensures that mps Tensors can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f)

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")

        # Ensures that mps Tensors can be loaded on cpu
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="cpu")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "cpu")

        # Ensures that `mps:0` Tensors can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps:0")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="mps:0")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")


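# Dtypes the MPS backend does not support (64-bit float/complex, plus cfloat and
# bfloat16 here) are pruned from the full dtype list; complex64 is selectively
# added back by individual test decorators below.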
MPS_DTYPES = get_all_dtypes()
for t in [torch.double, torch.cdouble, torch.cfloat, torch.bfloat16]:
    del MPS_DTYPES[MPS_DTYPES.index(t)]

MPS_GRAD_DTYPES = [torch.float32, torch.float16]


class TestConsistency(TestCaseMPS):
    # TODO: This is only used while some ops are being added.
    # This list should contain all ops and dtypes eventually.
    # This can be generated automatically in the `new_mps_allowlist.txt` file
    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`.
    # You most likely do NOT want to modify this manually.

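    # Ops whose float16 results are compared against CPU with loosened
    # tolerances (see _compute_tolerances below).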
    FP16_LOW_PRECISION_LIST = {
        'add', 'sub', 'div', 'addcdiv',
        '__rdiv__', '__rmul__',
        'nn.functional.huber_loss',
        'true_divide', 'kron',
        'gradient', 'var', 'std', 'std_mean', 'ldexp',
        'linalg.vector_norm', 'lerp',
        'addr', 'var_mean',
        'var_mean_unbiased',
        'acosh', 'asinh', 'asin',
        'masked.std',
        'nn.functional.normalize',
        'nn.functional.triplet_margin_loss',
        'nn.functional.triplet_margin_with_distance_loss',
        'nn.functional.batch_norm',
        'nn.functional.instance_norm',
        'round', 'xlogy', 'addcmul',
        'nn.functional.cross_entropy',
        'nn.functional.binary_cross_entropy',
        'nn.functional.nll_loss',
        'nn.functional.max_pool2d',
        'nn.functional.gelu',
        'nn.functional.glu',
        '_native_batch_norm_legit',
        '_batch_norm_with_update',
        'native_batch_norm',
        'softmax',
        '_softmax_backward_data',
        'log_softmax',
        'masked.softmax',
        'masked.log_softmax',
        'masked.softmin',
        'nn.functional.kl_div',
        'nn.functional.softmin',
        'cross', 'linalg.cross',
        'prod', 'masked.prod',
        'nextafter',
        'native_layer_norm',
        'nn.functional.layer_norm',
        'nn.functional.interpolate',
        'nn.functional.upsample_bilinear',
        'nn.functional.upsample_nearest',

        # for macOS 12
        'masked.normalize', 'masked.sum', 'masked.var',
        'outer',
        'sum_to_size', 'sum',
        'mul',
        'nansum', 'nanmean',
        'norm',
    }

    FP32_LOW_PRECISION_LIST = {
        # conv2d and conv_transpose2d results have a very small
        # difference compared to CPU/CUDA, so we use lower precision on FP32
        'nn.functional.conv2d',
        'nn.functional.conv_transpose2d',
        'matmul', '__rmatmul__',
        'linalg.multi_dot',
        'addbmm',
    }

    def _compute_tolerances(self, op, dtype):
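        # Returns an (atol, rtol) pair for comparing `op` at `dtype`;
        # (None, None) means assertEqual's dtype-based defaults apply.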
        if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype in [torch.float32, torch.complex64]:
            return (1e-4, 3e-5)

        if op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
            return (1e-2, 1e-2)

        if op.name in ['nn.functional.conv_transpose1d',
                       'nn.functional.conv_transpose2d',
                       'nn.functional.conv_transpose3d',
                       '__rmatmul__', 'addbmm', 'addmv',
                       'baddbmm', 'cov', 'matmul', 'mv'] and dtype == torch.float16:
            return (5e-2, 5e-2)
        if op.name == "masked.mean":
            return (7e-4, 2e-3)
        if op.name == "native_layer_norm":
            return (1e-4, 1.3e-5)
        if op.name in ["pow", "__rpow__"] and product_version < 13.3:
            # The result of pow(9, 8) comes out as 43046716, whereas it should be 43046721;
            # fixed in macOS 13.3+
            return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6)
        if op.name == "nn.functional.interpolate":
            return (1e-3, 1e-4)
        if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
            # TODO: Investigate why this is needed
            # See https://github.com/pytorch/pytorch/issues/120237
            return (3e-5, 3e-5)
        return (None, None)

    # Used for accept mode only
    NEW_ALLOW_LIST = defaultdict(list)
    NEW_ALLOW_LIST_GRAD = defaultdict(list)

    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES + [torch.complex64])
    def test_output_match(self, device, dtype, op):
        self.assertEqual(device, "cpu")
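        # This suite is instantiated for "cpu" only (see the bottom of this file);
        # each sample is mirrored onto the MPS device below and the two runs compared.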

        def get_samples():
            return op.sample_inputs(device, dtype, requires_grad=(dtype.is_floating_point or dtype.is_complex))
        cpu_samples = get_samples()

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
                atol = 1.0
                rtol = 0.0

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)


    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
    def test_output_grad_match(self, device, dtype, op):
        self.assertEqual(device, "cpu")

        def get_samples():
            return op.sample_inputs(device, dtype, requires_grad=(dtype.is_floating_point or dtype.is_complex))
        cpu_samples = get_samples()

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            forward_failed = False
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

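            # With sorted=False the order of unique's output is not guaranteed,
            # so an elementwise CPU vs. MPS comparison would be meaningless.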
            if op.name == "unique" and cpu_kwargs["sorted"] is False:
                continue

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name in ["renorm", "norm", "linalg.norm"] and dtype == torch.float16:
                atol = 7e-4
                rtol = 1.5e-3

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

            #
            # Backward check
            #
            if forward_failed:
                # We would've failed immediately anyway, but this error is clearer.
                # We raise instead of continuing so the backward pass is not reported as passing.
                raise RuntimeError("Forward pass already failed")

            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

            def req_grad(t):
                return isinstance(t, torch.Tensor) and t.requires_grad

            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
            diff_cpu_arg = tuple(t for t in pytree.tree_leaves((cpu_args, cpu_kwargs)) if req_grad(t))
            diff_mps_arg = tuple(t for t in pytree.tree_leaves((mps_args, mps_kwargs)) if req_grad(t))
            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))

            if len(diff_cpu_out) == 0:
                continue
            # rand_like does not work with certain dtypes, so cast to double and cast back
            cpu_grad_outputs = tuple(torch.rand_like(t, dtype=torch.double).to(dtype=t.dtype) for t in diff_cpu_out)
            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)

            # Compare computed gradients with cpu given a random grad_output vector.
            # Sometimes when the derivative is 0 we just don't bother creating the graph;
            # allow_unused is needed in those cases.
            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)


class TestErrorInputs(TestCase):
    _ignore_not_implemented_error = True

    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
    def test_error_inputs(self, device, op):
        self.assertEqual(device, "mps:0")
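        # Unlike TestConsistency, this suite is instantiated directly on the MPS
        # device (see instantiate_device_type_tests at the bottom of this file).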

        mps_samples = op.error_inputs(device)

        for mps_sample in mps_samples:
            mps_sample_input = mps_sample.sample_input
            error_type = mps_sample.error_type
            error_regex = mps_sample.error_regex

            mps_args = [mps_sample_input.input] + list(mps_sample_input.args)
            mps_kwargs = mps_sample_input.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
                mps_args[1] = mps_args[1].cpu()

            with self.assertRaisesRegex(error_type, error_regex):
                op(*mps_args, **mps_kwargs)

class TestComplex(TestCase):
    def test_tensor_scalar_binops(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/119088
        def to_cpu(x):
            return x.cpu() if isinstance(x, torch.Tensor) else x

        # Allocate tensors on mps
        with torch.device("mps"):
            inputs = [torch.rand(2, dtype=dtype) for dtype in [torch.float, torch.half, torch.cfloat]]
        self.assertTrue(all(x.device.type == "mps" for x in inputs))
        # Add scalars
        inputs.extend([7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)])

        # Iterate over all pairs of types (int, float, complex, half) and ops (excluding div)
        for x, y in itertools.product(inputs, inputs):
            for op_name in ["__add__", "__sub__", "__mul__"]:
                x_cpu, y_cpu = map(to_cpu, (x, y))
                res = getattr(x, op_name)(y)
                res_cpu = getattr(x_cpu, op_name)(y_cpu)
                self.assertEqual(to_cpu(res), res_cpu, f"{op_name}({x}, {y}) produces different results {res} vs {res_cpu}")


# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if IS_CI:
            err_msg = (
                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries. "
                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
            )
            # Assure no opinfo entry has dynamic_dtypes
            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
            for op in filtered_ops:
                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
                err_msg += "\n" + fmt_str

            assert len(filtered_ops) == 0, err_msg

    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
    # MPS still requires some fairly heavy special casing in the test framework.
    # When MPS becomes more consistent, this can probably be merged with that test using
    # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
    @suppress_warnings
    # MPS only supports float32
    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
    def test_numpy_ref_mps(self, device, dtype, op):
        # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
        # does not support float64 Tensors.
        # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
        # get patched up and this workaround removed.
        broken_on_ref_inputs = op.name in ['clamp', 'where']
        inputs = op.reference_inputs(device, dtype) if not broken_on_ref_inputs else op.sample_inputs(device, dtype)
        for sample_input in inputs:
            self.compare_with_reference(op, op.ref, sample_input)

    @dtypes(*get_all_dtypes())
    def test_tensor_creation(self, device, dtype):
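        # Creating a tensor should raise TypeError for dtypes the MPS backend
        # cannot represent, and otherwise match the equivalent CPU tensor.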
        def ones(device):
            return torch.ones((2, 2), dtype=dtype, device=device)
        if dtype not in MPS_DTYPES + ([torch.bfloat16, torch.complex64] if product_version > 14.0 else [torch.complex64]):
            with self.assertRaises(TypeError):
                ones(device)
        else:
            mps_tensor = ones(device)
            cpu_tensor = ones("cpu")
            self.assertEqual(mps_tensor.cpu(), cpu_tensor)


# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
# to achieve this.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")

if __name__ == "__main__":
    run_tests()