blob: 67f3c784b9e4a4636d12120f9bc0b999bd3a1c26 [file] [log] [blame]
Kulin Sethe011a8e2022-05-13 18:28:53 +00001# -*- coding: utf-8 -*-
2# Owner(s): ["module: mps"]
3
Denis Vieriu71ec2612023-02-15 06:09:56 +00004import platform
Kulin Sethe011a8e2022-05-13 18:28:53 +00005import sys
6import math
7import random
8import unittest
9import warnings
Kulin Seth3d833212022-05-20 03:18:09 +000010import subprocess
Alban Desmaison0a651a22022-06-14 17:54:30 +000011import tempfile
Kulin Seth3d833212022-05-20 03:18:09 +000012import os
Kulin Seth31d4b6f2022-08-17 00:26:41 +000013import copy
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +000014import gc
Kulin Sethe011a8e2022-05-13 18:28:53 +000015import torch
16import torch.nn as nn
17import torch.nn.functional as F
Kulin Seth978304f2022-05-14 13:33:16 +000018import itertools
Kulin Seth76cff182022-07-04 06:41:39 +000019from collections import defaultdict
Xuehai Panb005ec62023-02-14 09:14:10 +000020from torch import inf
Kulin Sethe011a8e2022-05-13 18:28:53 +000021from torch.nn import Parameter
Alex620dbc42022-10-21 19:03:00 +000022from torch.testing._internal import opinfo
Kulin Seth76cff182022-07-04 06:41:39 +000023from torch.testing._internal.common_utils import \
Catherine Leeeea07332023-03-07 18:30:27 +000024 (gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, NoTest,
Kulin Seth2bb022e2023-03-08 08:41:21 +000025 TEST_WITH_UBSAN, skipIfSlowGradcheckEnv, TEST_WITH_ASAN, suppress_warnings)
Kulin Sethb744e1c2022-07-01 15:10:56 +000026from torch.testing import make_tensor
Nikita Shulga1a6cf6e2022-09-14 23:40:20 +000027from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
Kulin Sethe011a8e2022-05-13 18:28:53 +000028import torch.backends.mps
Kulin Seth83239352022-06-10 13:16:21 +000029from torch.distributions import Uniform, Exponential
Kulin Sethb744e1c2022-07-01 15:10:56 +000030from functools import partial
PyTorch MergeBotb1943e02022-06-30 16:37:11 +000031
Alex620dbc42022-10-21 19:03:00 +000032from torch.testing._internal.common_methods_invocations import (
33 op_db,
Nikita Shulgafd8367a2023-02-27 15:01:01 +000034 DecorateInfo,
Alex620dbc42022-10-21 19:03:00 +000035 UnaryUfuncInfo,
36 ReductionOpInfo,
37 SpectralFuncInfo,
38 BinaryUfuncInfo,
39)
Nikita Shulga436993d2023-03-04 01:29:07 +000040from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests
Kulin Sethe011a8e2022-05-13 18:28:53 +000041from torch.testing._internal.common_nn import NNTestCase
42import numpy as np
43import torch
soulitzerbfdfeec2022-08-31 17:53:32 -040044import torch.utils._pytree as pytree
Kulin Sethfc596642023-01-04 22:15:13 +000045from itertools import product
Kulin Sethe011a8e2022-05-13 18:28:53 +000046
Alex620dbc42022-10-21 19:03:00 +000047
# Mirrors the selection in `test_ops.py` so that `test_numpy_ref` can be
# duplicated here: keep every op_db entry that has a reference implementation
# (`op.ref`) and is not one of the unary/reduction/spectral/binary OpInfo kinds.
_ref_test_ops = tuple(
    op
    for op in op_db
    if not isinstance(
        op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
    )
    and op.ref is not None
)
58
def mps_ops_grad_modifier(ops):
    """Attach MPS gradient-test expectations to each op and yield it.

    For every entry in ``ops`` the lookup key is ``op.name + op.variant_test_name``.
    Whenever that key is present in one of the tables below (and the table's
    OS-version condition holds), a ``DecorateInfo`` limited to the listed dtypes
    is appended to ``op.decorators``. Ops are mutated in place and yielded one
    by one, so this is a lazy generator.

    Args:
        ops: iterable of objects exposing ``name``, ``variant_test_name`` and
            ``decorators`` attributes.

    Yields:
        The same op objects, in input order.
    """
    XFAILLIST_GRAD = {
        # Top 60
        # CPU: empty is returning all 0's and there is a mismatch with MPS
        # allocation (MacOS 13). According to
        # https://pytorch.org/docs/2.0/generated/torch.empty.html
        # PyTorch `empty`, Returns a tensor filled with uninitialized data.
        'empty': [torch.float16, torch.float32],

        # CPU Error: RuntimeError: "addmv_impl_cpu" not implemented for 'Half'
        'addr': [torch.float16],

        # Unimplemented ops
        '__getitem__': [torch.float16],
        'prod': [torch.float32],  # The operator 'aten::cumprod.out'
        'sgn': [torch.float16, torch.float32],
        '_segment_reduce': [torch.float16, torch.float32],
        'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
        'unfold': [torch.float16, torch.float32],
        'trace': [torch.float32],  # missing in place aten::index_fill_.int_Tensor
        'sparse.mmreduce': [torch.float32],  # csr not supported
        'unique_consecutive': [torch.float16, torch.float32],
        'special_modified_bessel_i0': [torch.float16, torch.float32],
        'scalar_tensor': [torch.float16, torch.float32],
        'cdist': [torch.float32],
        'masked.scatter': [torch.float16, torch.float32],

        # Correctness issues
        'atanh': [torch.float32],

        # Random output
        'exponential': [torch.float16, torch.float32],

        # CPU errors
        # derivative for aten::floor_divide is not implemented on CPU
        'floor_divide': [torch.float16, torch.float32],
        # derivative for aten::narrow_copy is not implemented on CPU
        'narrow_copy': [torch.float16, torch.float32],
        # RuntimeError: "log_vml_cpu" not implemented for 'Half'
        '__rpow__': [torch.float16],
        'pow': [torch.float16],
        # 'bool' object is not iterable
        'allclose': [torch.float16, torch.float32],
        'equal': [torch.float16, torch.float32],
        # "mse_backward_cpu_out" not implemented for 'Half'
        'nn.functional.mse_loss': [torch.float16],
        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
        'nn.functional.smooth_l1_loss': [torch.float16],
        # cpu error: grad requires non-empty inputs
        'randn': [torch.float16, torch.float32],
        'signal.windows.bartlett': [torch.float32],
        'signal.windows.blackman': [torch.float32],
        'signal.windows.cosine': [torch.float32],
        'signal.windows.exponential': [torch.float32],
        'signal.windows.gaussian': [torch.float32],
        'signal.windows.general_cosine': [torch.float32],
        'signal.windows.general_hamming': [torch.float32],
        'signal.windows.hamming': [torch.float32],
        'signal.windows.hann': [torch.float32],
        'signal.windows.kaiser': [torch.float32],
        'signal.windows.nuttall': [torch.float32],
        'empty_permuted': [torch.float16, torch.float32],
        'eye': [torch.float16, torch.float32],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],
    }

    MACOS_12_3_XFAILLIST_GRAD = {
        # Unsupported Border padding mode, forward pass succeeds as it falls back to cpu
        'grid_sampler_2d': [torch.float32],
        # Unimplemented
        'logaddexp2': [torch.float32],

        # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
        # fixed in macOS 13. We are not raising error.
        '__rpow__': [torch.float32],
        'pow': [torch.float32],
    }

    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'masked.softmin': [torch.float32],
        'masked.softmax': [torch.float32],
        'masked.log_softmax': [torch.float32],

        # Unsupported Border padding mode, forward pass succeeds as it falls back to cpu
        'grid_sampler_2d': [torch.float32],

        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch
        # between CPU and MPS. Running `msort` with stable `sort` passes.
        'msort': [torch.float16],

        # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
        # fixed in macOS 13. We are not raising error.
        'pow': [torch.float32],
        '__rpow__': [torch.float32],
    }

    XPASSLIST_GRAD = {
        'nn.functional.pairwise_distance': [torch.float16],
    }

    MACOS_13_3_XFAILLIST_GRAD = {
        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch
        # between CPU and MPS. Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    # Ordered (table, decorator, predicate) rules. The predicates are only
    # called after a key matched, so `product_version` (a module global bound
    # further down the file) and the MPS version query are evaluated lazily.
    rules = (
        (XFAILLIST_GRAD, unittest.expectedFailure,
         lambda: True),
        (XPASSLIST_GRAD, unittest.skip,
         lambda: True),
        (MACOS_12_3_XFAILLIST_GRAD, unittest.expectedFailure,
         lambda: not torch.backends.mps.is_macos13_or_newer()),
        (MACOS_BEFORE_13_3_XFAILLIST_GRAD, unittest.expectedFailure,
         lambda: torch.backends.mps.is_macos13_or_newer() and product_version < 13.3),
        (MACOS_13_3_XFAILLIST_GRAD, unittest.expectedFailure,
         lambda: product_version >= 13.3),
    )

    for op in ops:
        key = op.name + op.variant_test_name
        for table, marker, applies in rules:
            if key in table and applies():
                # Decorators may be stored as an immutable sequence or None;
                # switch to a fresh list before appending.
                op.decorators = [] if op.decorators is None else list(op.decorators)
                op.decorators.append(DecorateInfo(marker, dtypes=table[key]))
        yield op
204
def mps_ops_modifier(ops):
    """Attach MPS forward-test expected failures to each op and yield it.

    For every entry in ``ops`` the lookup key is ``op.name + op.variant_test_name``.
    Whenever that key is present in one of the tables below (and the table's
    OS-version condition holds), a ``DecorateInfo(unittest.expectedFailure, ...)``
    limited to the listed dtypes (``None`` meaning no dtype restriction is passed)
    is appended to ``op.decorators``. Ops are mutated in place and yielded
    lazily, in input order.

    Args:
        ops: iterable of objects exposing ``name``, ``variant_test_name`` and
            ``decorators`` attributes.

    Yields:
        The same op objects, in input order.
    """
    # Frequently repeated dtype groups. The lists are shared between table
    # entries and are only ever read, never mutated.
    INT_DTYPES = [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8]
    ALL_DTYPES = [torch.bool, torch.float16, torch.float32, torch.int16,
                  torch.int32, torch.int64, torch.uint8, torch.int8]

    # Expected failures on macOS 12.x: this table is applied only when
    # `torch.backends.mps.is_macos13_or_newer()` is False.
    # See also https://github.com/pytorch/pytorch/issues/85758
    MACOS_12_3_XFAILLIST = {
        # Top 60
        # expected failures
        # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
        # fixed in macOS 13.3. Currently error is not raised.
        'pow': [torch.int16, torch.int64, torch.uint8, torch.int8],
        # expected failures
        '__rpow__': [torch.uint8, torch.int8],

        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'cdist': [torch.float32],
        'tan': [torch.uint8, torch.float32],

        # Data type support starts from macOS 13
        'nn.functional.avg_pool1d': [torch.int64],
        'nn.functional.avg_pool2d': [torch.int64],
        'nn.functional.local_response_norm': [torch.int64],
        '__radd__': [torch.uint8],
        '__rdiv__': [torch.uint8],
        '__rmul__': [torch.uint8],
        'abs': [torch.uint8],
        'acos': [torch.uint8],
        'acosh': [torch.uint8],
        'add': [torch.uint8],
        'asin': [torch.uint8],
        'asinh': [torch.uint8],
        'atan': [torch.uint8],
        'atanh': [torch.uint8],
        'ceil': [torch.uint8],
        'corrcoef': [torch.uint8],
        'cos': [torch.uint8],
        'cosh': [torch.uint8],
        'cov': [torch.uint8],
        'cumulative_trapezoid': [torch.uint8],
        'deg2rad': [torch.uint8],
        'diff': [torch.uint8],
        'eq': [torch.uint8],
        'equal': [torch.uint8],
        'erf': [torch.uint8],
        'exp2': [torch.uint8],
        'exp': [torch.uint8],
        'expm1': [torch.uint8],
        'floor': [torch.uint8],
        'fmax': [torch.uint8],
        'fmin': [torch.uint8],
        'fmod': [torch.uint8],
        'ge': [torch.uint8],
        'gt': [torch.uint8],
        'isclose': [torch.uint8],
        'isnan': [torch.uint8],
        'kron': [torch.uint8],
        'le': [torch.uint8],
        'log10': [torch.uint8],
        'log1p': [torch.uint8],
        'log2': [torch.uint8],
        'log': [torch.uint8],
        'logical_and': [torch.uint8],
        'logical_or': [torch.uint8],
        'logical_xor': [torch.uint8],
        'logit': [torch.uint8],
        'lt': [torch.uint8],
        'masked.mean': [torch.uint8],
        'masked.std': [torch.uint8],
        'masked.var': [torch.uint8],
        'maximum': [torch.uint8],
        'minimum': [torch.uint8],
        'mul': [torch.uint8],
        'ne': [torch.uint8],
        'neg': [torch.uint8],
        'nn.functional.cosine_embedding_loss': [torch.uint8],
        'nn.functional.margin_ranking_loss': [torch.uint8],
        'nn.functional.poisson_nll_loss': [torch.uint8],
        'nn.functional.softsign': [torch.uint8],
        'nn.functional.tanhshrink': [torch.uint8],
        'nn.functional.triplet_margin_loss': [torch.uint8],
        'nn.functional.triplet_margin_with_distance_loss': [torch.uint8],
        'nn.functional.pairwise_distance': [torch.uint8, torch.float16],
        'outer': [torch.uint8],
        'rad2deg': [torch.uint8],
        'reciprocal': [torch.uint8],
        'remainder': [torch.uint8],
        'round': [torch.uint8],
        'rsqrt': [torch.uint8],
        'sigmoid': [torch.uint8],
        'sign': [torch.uint8],
        'signbit': [torch.uint8],
        'sin': [torch.uint8],
        'sinh': [torch.uint8],
        'special.ndtr': [torch.uint8],
        'sqrt': [torch.uint8],
        'sub': [torch.uint8],
        'tanh': [torch.uint8],
        'trapezoid': [torch.uint8],
        'trapz': [torch.uint8],
        'true_divide': [torch.uint8],
        'trunc': [torch.uint8],
        'xlogy': [torch.uint8],
        'minbinary': [torch.uint8],
        'maxbinary': [torch.uint8],
        'divtrunc_rounding': [torch.uint8],
        'divfloor_rounding': [torch.uint8],
        'divno_rounding_mode': [torch.uint8],
        'floor_divide': [torch.uint8],
        'ldexp': [torch.uint8],
        # square internally calls into power, and will type cast to int64, which supports starting from macOS 13
        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # fill tensors with uninitialized data, causing mismatch with CPU
        'empty_permuted': ALL_DTYPES,
        'empty': ALL_DTYPES,
        'dist': [torch.float16],  # cpu result off, showing inf values
    }

    MACOS_BEFORE_13_3_XFAILLIST = {
        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'tan': [torch.float32],
        'cdist': [torch.float32],

        # CPU Error: cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
        # Elements from index 30 and 5133 are both equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in
        # undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
        # Unsupported dtypes
        'cumsum': [torch.int64],
        'cumulative_trapezoid': [torch.int64],
        'masked.cumsum': [torch.int64],
    }

    MACOS_13_3_XFAILLIST = {
        # before macOS 13.3 it falls back to cpu and passes the forward pass
        'grid_sampler_2d': [torch.float32],  # Unsupported Border padding mode

        # Failure due to precision issue for fp16
        # on both cpu and mps there are test cases that might produce inf result
        # 'nn.functional.pairwise_distance': [torch.float16],

        # tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
        # Elements from index 30 and 5133 are both equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in
        # undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
    }

    # Those ops are not expected to work
    UNIMPLEMENTED_XFAILLIST = {
        # Failures due to lack of op implementation on MPS backend
        # NOTE(review): 'login' does not look like an op name - confirm whether this key is intended.
        'login': None,
        'log_sigmoid': None,
        'log_sigmoid_forward': None,
        'linalg.eig': None,
        'linalg.eigvals': None,
        'fft.fft': None,
        'fft.fft2': None,
        'fft.fftn': None,
        'fft.hfft': None,
        'fft.hfft2': None,
        'fft.hfftn': None,
        'fft.ifft': None,
        'fft.ifft2': None,
        'fft.ifftn': None,
        'fft.ihfft': None,
        'fft.ihfft2': None,
        'fft.ihfftn': None,
        'fft.irfft': None,
        'fft.irfft2': None,
        'fft.irfftn': None,
        'fft.rfft': None,
        'fft.rfft2': None,
        'fft.rfftn': None,
        'put': None,
        'stft': None,
        'nn.functional.conv_transpose3d': None,
        'rounddecimals_neg_3': None,
        'rounddecimals_3': None,
        'rounddecimals_0': None,
        '__rsub__': None,
        'aminmax': None,
        'angle': None,
        'bucketize': None,
        'cauchy_': None,
        'cauchy': None,
        'cholesky': None,
        'cholesky_inverse': None,
        'cholesky_solve': None,
        'cummax': None,
        'cummin': None,
        'cumprod': None,
        'digamma': None,
        'erfc': None,
        'erfinv': None,
        'frexp': None,
        'gcd': None,
        'geqrf': None,
        'nn.functional.grid_sample': None,  # Unsupported Border padding mode
        'heaviside': None,
        'histc': None,
        'histogram': None,
        'histogramdd': None,
        'i0': None,
        'igamma': None,
        'igammac': None,
        'index_copy': None,
        'index_fill': None,
        'index_reduce': None,
        'isin': None,
        'isneginf': None,
        'isposinf': None,
        'kthvalue': None,
        'lcm': None,
        'lerp': None,
        'lgamma': None,
        'linalg.cholesky': None,
        'linalg.cholesky_ex': None,
        'linalg.cond': None,
        'linalg.detsingular': None,
        'linalg.det': None,
        'linalg.eigh': None,
        'linalg.eigvalsh': None,
        'linalg.householder_product': None,
        'linalg.ldl_factor': None,
        'linalg.ldl_factor_ex': None,
        'linalg.ldl_solve': None,
        'linalg.lstsq': None,
        'linalg.lstsqgrad_oriented': None,
        'linalg.lu': None,
        'linalg.lu_factor': None,
        'linalg.lu_factor_ex': None,
        'linalg.lu_solve': None,
        'linalg.matrix_norm': [torch.float32],
        'linalg.norm': [torch.float32],
        'linalg.normsubgradients_at_zero': [torch.float32],
        'linalg.qr': None,
        'linalg.slogdet': None,
        'linalg.solve': None,
        'linalg.solve_ex': None,
        'linalg.svdvals': None,
        'linalg.tensorsolve': None,
        'linalg.vander': None,
        'linalg.vecdot': None,
        'logcumsumexp': None,
        'logdet': None,
        'lu': None,
        'lu_solve': None,
        'lu_unpack': None,
        'masked.cumprod': None,
        'masked.median': None,
        'matrix_exp': None,
        'mode': None,
        'mvlgamma': None,
        'mvlgammamvlgamma_p_1': None,
        'mvlgammamvlgamma_p_3': None,
        'mvlgammamvlgamma_p_5': None,
        'nanquantile': None,
        'nanmedian': None,
        'native_dropout_backward': None,
        'nextafter': None,
        'normnuc': None,
        'nn.functional.fractional_max_pool2d': None,
        'nn.functional.fractional_max_pool3d': None,
        'nn.functional.adaptive_avg_pool3d': None,
        'nn.functional.adaptive_max_pool3d': None,
        'nn.functional.interpolatearea': None,
        'nn.functional.interpolatebicubic': None,
        'nn.functional.interpolatelinear': None,
        'nn.functional.interpolatetrilinear': None,
        'nn.functional.max_unpool1dgrad': None,
        'nn.functional.max_unpool2dgrad': None,
        'nn.functional.max_unpool3dgrad': None,
        'nn.functional.avg_pool3d': None,
        'nn.functional.ctc_loss': None,
        'nn.functional.embedding_bag': None,
        'nn.functional.hardshrink': None,
        'nn.functional.max_pool3d': None,
        'nn.functional.max_unpool1d': None,
        'nn.functional.max_unpool2d': None,
        'nn.functional.max_unpool3d': None,
        'nn.functional.mish': None,
        'nn.functional.multi_margin_loss': None,
        'nn.functional.multilabel_margin_loss': None,
        'nn.functional.pdist': None,
        'nn.functional.rrelu': None,
        'nn.functional.softshrink': None,
        'nn.functional.norm': None,
        'ormqr': None,
        'pca_lowrank': None,
        'pinverse': None,
        'polar': None,
        'polygamma': None,
        'polygammapolygamma_n_0': None,
        'polygammapolygamma_n_1': None,
        'polygammapolygamma_n_2': None,
        'polygammapolygamma_n_3': None,
        'polygammapolygamma_n_4': None,
        'qr': None,
        'quantile': None,
        'renorm': None,
        'rsub': None,
        'scatter_reduceamax': None,
        'scatter_reduceamin': None,
        'scatter_reducemin': None,
        'scatter_reducemean': None,
        'scatter_reduceprod': None,
        'scatter_reducesum': None,
        'searchsorted': None,
        'segment_reduce': None,
        '_segment.reduce': None,
        'segment.reduce': None,
        'segment_reduce_offsets': None,
        '_segment_reduce_offsets': None,
        '_segment_reduce_lengths': None,
        '_segment_reducelengths': None,
        '_segment_reduceoffsets': None,
        'sinc': None,
        'sparse.mm': None,
        'sparse.mmreduce': None,
        'special.airy_ai': None,
        'special.bessel_j0': None,
        'special.bessel_j1': None,
        'special.bessel_y0': None,
        'special.bessel_y1': None,
        'special.chebyshev_polynomial_t': None,
        'special.chebyshev_polynomial_u': None,
        'special.entr': None,
        'special.erfcx': None,
        'special.hermite_polynomial_h': None,
        'special.hermite_polynomial_he': None,
        'special.i0e': None,
        'special.i1': None,
        'special.i1e': None,
        'special.laguerre_polynomial_l': None,
        'special.log_ndtr': None,
        'special.modified_bessel_i0': None,
        'special.modified_bessel_i1': None,
        'special.modified_bessel_k0': None,
        'special.modified_bessel_k1': None,
        'special.ndtri': None,
        'special.polygamma': None,
        'special.polygammaspecial_polygamma_n_0': None,
        'special.scaled_modified_bessel_k0': None,
        'special.scaled_modified_bessel_k1': None,
        'special.spherical_bessel_j0': None,
        'special.xlog1py': None,
        'special.zeta': None,
        'std_mean': None,
        'std_meanunbiased': None,
        'svd_lowrank': None,
        'symeig': None,
        'take': None,
        'to': None,
        'to_sparse': None,
        'unique': None,
        'vdot': None,
        'view_as_complex': None,
        'segment_reduce_': None,
        '_upsample_bilinear2d_aa': None,
        'geometric': None,
        'geometric_': None,
        'log_normal_': None,
        'log_normal': None,
        'bfloat16': None,
        'cdouble': None,
        'cfloat': None,
        'complex': None,
        'double': None,
        'chalf': None,
        'nn.functional.softminwith_dtype': None,
        'log_softmaxwith_dtype': None,
        'softmaxwith_dtype': None,
        'float_power': None,
        'full_like': None,
        'linalg.matrix_rank': None,
        'linalg.matrix_rankhermitian': None,
        'linalg.pinv': None,
        'linalg.pinvhermitian': None,

        # MPS: input sizes must be divisible by output sizes
        'nn.functional.adaptive_avg_pool1d': None,
        'nn.functional.adaptive_avg_pool2d': None,

        # Unsupported dtypes
        # bmm is not supported for integral types
        'nn.functional.bilinear': INT_DTYPES,
        # Cannot convert a MPS Tensor to float64 dtype. The tensors
        # input data is created with double in common_methods_invocations.py
        'nn.functional.batch_norm': [torch.float32],
        'ones_like': None,
        'zeros_like': None,

        # Convolution for integral types is not supported on MPS
        'nn.functional.conv1d': [torch.int64],
        'nn.functional.conv2d': [torch.int64],
        'nn.functional.conv_transpose1d': [torch.int64],
        'nn.functional.conv_transpose2d': [torch.int64],

        # Unsupported dtypes
        'dot': [torch.int64],
        'index_add': [torch.int64],
        'log1p': [torch.int64],
        'sigmoid': [torch.int64],
        'atan2': [torch.int64],

        # GEMM on MPS is not supported for integral types
        'nn.functional.linear': INT_DTYPES,
        '__rmatmul__': INT_DTYPES,
        'addmmdecomposed': INT_DTYPES,
        'addbmm': INT_DTYPES,
        'addmm': INT_DTYPES,
        'addmv': INT_DTYPES,
        'baddbmm': INT_DTYPES,
        'mm': INT_DTYPES,
        'bmm': INT_DTYPES,
        'einsum': INT_DTYPES,
        'inner': INT_DTYPES,
        'linalg.multi_dot': INT_DTYPES,
        'matmul': INT_DTYPES,
        'mat': INT_DTYPES,
        'mv': INT_DTYPES,
        'tensordot': INT_DTYPES,

        # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as
        # the MPS framework doesn't support float64
        'new_zeros': ALL_DTYPES,
        'new_ones': ALL_DTYPES,
        'new_full': ALL_DTYPES,
        # returned output on CPU is float64
        'bincount': INT_DTYPES,

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],
    }

    UNDEFINED_XFAILLIST = {
        # Top 60 operators
        # topk fails with duplicate indices
        'topk': INT_DTYPES,

        # Failures due to random output that they generate using
        # Philox engine causing mismatch with CPU results
        'multinomial': [torch.float32],  # random results
        'uniform': [torch.float16, torch.float32],
        'rand_like': [torch.float16, torch.float32],
        'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'randn_like': [torch.float16, torch.float32],
        'bernoulli': [torch.float32],
        'exponential': [torch.float16, torch.float32],
        'nn.functional.feature_alpha_dropoutwith_train': [torch.float32],
        'normal': [torch.float16, torch.float32],
        'normalin_place': [torch.float16, torch.float32],
        'normalnumber_mean': [torch.float16, torch.float32],
        'nn.functional.alpha_dropout': [torch.float32],
        'nn.functional.dropout': [torch.float32],
        'nn.functional.dropout2d': [torch.float32],
        'nn.functional.dropout3d': [torch.float32],

        # these fill tensors with uninitialized data, causing mismatch with CPU
        'new_empty': ALL_DTYPES,
        'empty_like': ALL_DTYPES,
        # 'empty': [torch.int8],
        'new_empty_strided': ALL_DTYPES,
        # duplicate indices are used in the testcase - undefined behaviour
        'index_put': None,
        # zero to negative integer powers are undefined
        '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64],
        'resize_': [torch.float16, torch.float32],
        'resize_as_': [torch.float16, torch.float32],

        # CPU Errors:
        # "addmv_impl_cpu" not implemented for 'Half'
        'addr': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # cpu result off, showing random values
        'as_stridedpartial_views': ALL_DTYPES,
        'as_strided_partial_views': ALL_DTYPES,

        # random results
        # mps vs cpu:
        # Mismatched elements: 40 / 96 (41.7%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        # cuda(2.0.0.dev20230301+cu117) vs cpu:
        # Mismatched elements: 56 / 96 (58.3%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        'nn.functional.scaled_dot_product_attention': [torch.float32],

        # Failures due to casting negative float to uint8 is undefined
        'byte': [torch.float16, torch.float32],
    }

    # Ordered (table, predicate) rules. The predicates are only called after a
    # key matched, so `product_version` (a module global bound further down the
    # file) and the MPS version query are evaluated lazily.
    rules = (
        (UNIMPLEMENTED_XFAILLIST,
         lambda: True),
        (UNDEFINED_XFAILLIST,
         lambda: True),
        (MACOS_BEFORE_13_3_XFAILLIST,
         lambda: torch.backends.mps.is_macos13_or_newer() and product_version < 13.3),
        (MACOS_13_3_XFAILLIST,
         lambda: product_version >= 13.3),
        (MACOS_12_3_XFAILLIST,
         lambda: not torch.backends.mps.is_macos13_or_newer()),
    )

    for op in ops:
        key = op.name + op.variant_test_name
        for table, applies in rules:
            if key in table and applies():
                # Decorators may be stored as an immutable sequence or None;
                # switch to a fresh list before appending.
                op.decorators = [] if op.decorators is None else list(op.decorators)
                op.decorators.append(DecorateInfo(unittest.expectedFailure, dtypes=table[key]))
        yield op
745
# Same logic as test_cuda.py: when the MPS backend is unavailable, announce it
# on stderr and rebind both test base classes to NoTest.
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = NNTestCase = NoTest  # noqa: F811
Kulin Sethe011a8e2022-05-13 18:28:53 +0000751
# macOS "major.minor" release as a float (e.g. '13.3.1' -> 13.3), used by the
# xfail tables to gate OS-specific expectations. On non-macOS hosts
# platform.mac_ver() reports an empty release string; float('') would raise
# ValueError at import time, so fall back to -1 there.
product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
753
# Opt-in flag for the MPS memory leak check (uses same code as CUDA):
# True only when PYTORCH_TEST_MPS_MEM_LEAK_CHECK is set to exactly '1'.
TEST_MPS_MEM_LEAK_CHECK = os.environ.get('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'
756
def skipMPSMemoryLeakCheckIf(condition):
    """Return a decorator that flags a test to skip the MPS memory leak check.

    The decorator sets ``fn._do_mps_memory_leak_check`` to ``not condition``
    unless the attribute is already present and falsy, in which case the
    function is left untouched. The decorated function itself is returned
    (it is not wrapped).
    """
    def _mark(test_fn):
        # An earlier opt-out wins: never turn the check back on.
        if not getattr(test_fn, '_do_mps_memory_leak_check', True):
            return test_fn
        test_fn._do_mps_memory_leak_check = not condition
        return test_fn
    return _mark
763
class MpsMemoryLeakCheck():
    """Context manager that detects MPS memory growth across the wrapped block.

    On entry it records the caching-allocator and driver allocated-memory
    counters. On a clean exit it compares them again: growth seen only by the
    caching allocator produces a warning, growth also seen by the driver
    raises RuntimeError.
    """

    def __init__(self, testcase, name=None):
        self.testcase = testcase
        # Fall back to the test id when no explicit label is supplied.
        self.name = name if name is not None else testcase.id()

    def __enter__(self):
        # Performs a gc only if required (i.e. some memory is still held).
        if torch.mps.current_allocated_memory() > 0:
            gc.collect()
            torch.mps.empty_cache()

        # Baseline caching allocator and driver statistics before the test runs.
        self.caching_allocator_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown.
        if exec_type is not None:
            return

        # Short-circuit: no increase in allocated memory means no possible leak.
        if torch.mps.current_allocated_memory() <= self.caching_allocator_before:
            return

        # Validate that the discrepancy persists after garbage collection and
        # is confirmed by the driver API.
        gc.collect()
        torch.mps.empty_cache()

        # Query memory multiple times to ensure the leak was not transient;
        # stop early as soon as neither counter is above its baseline.
        for _ in range(3):
            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
            driver_mem_allocated = torch.mps.driver_allocated_memory()

            caching_allocator_discrepancy = caching_allocator_mem_allocated > self.caching_allocator_before
            driver_discrepancy = driver_mem_allocated > self.driver_before

            if not caching_allocator_discrepancy and not driver_discrepancy:
                # Leak was a false positive, exit loop.
                break

        # Only growth reported by the caching allocator is acted upon.
        if not caching_allocator_discrepancy:
            return

        stats = (self.name, self.caching_allocator_before, caching_allocator_mem_allocated,
                 self.driver_before, driver_mem_allocated)
        details = ("Caching allocator allocated memory was {} and is now reported as {}. "
                   "MPS driver allocated memory was {} and is now {}.")

        if driver_discrepancy:
            # A caching allocator discrepancy validated by the driver API is a failure.
            raise RuntimeError(("MPS driver API confirmed a leak in {}! " + details).format(*stats))

        # Just raises a warning if the leak is not validated by the driver API.
        warnings.warn(("MPS caching allocator reports a memory leak not "
                       "verified by the driver API in {}! " + details).format(*stats))
837
# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
    """TestCase that can wrap each test method in an MPS memory leak check."""

    # Per-class/per-test switch; see skipMPSMemoryLeakCheckIf.
    _do_mps_memory_leak_check = True

    def __init__(self, method_name='runTest'):
        super().__init__(method_name)
        if getattr(self, method_name, None) is None:
            return
        # Wraps the tested method only when the global opt-in and the
        # per-test switch are both on.
        if TEST_MPS_MEM_LEAK_CHECK and self._do_mps_memory_leak_check:
            self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)

    def assertLeaksNoMpsTensors(self, name=None):
        """Return a leak-check context manager labelled ``name`` (default: this test's id)."""
        if name is None:
            name = self.id()
        return MpsMemoryLeakCheck(self, name)

    def wrap_with_mps_policy(self, method_name, policy):
        """Replace the bound method ``method_name`` on this instance with a policy-wrapped version."""
        original = getattr(self, method_name)
        wrapped = super().wrap_method_with_policy(original, policy)
        setattr(self, method_name, wrapped)

    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
    def wrap_with_mps_memory_check(self, method):
        """Wrap ``method`` so it always runs under the MPS leak check."""
        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)
862
class TestMemoryLeak(TestCaseMPS):
    def test_mps_memory_leak_detection(self):
        """The leak check stays silent for a clean function and raises for a leaking one."""
        retained = []

        @self.wrap_with_mps_memory_check
        def clean():
            return None

        # Trigger an intentional memory leak by keeping the tensor alive in `retained`.
        @self.wrap_with_mps_memory_check
        def leaking():
            # 8M elements, to force acquiring a new block and overcome
            # blocksize differences across platforms.
            num_elements = 1024 * 1024 * 8
            retained.append(torch.randn(num_elements, device=torch.device("mps")))

        clean()

        # A RuntimeError carrying the driver-confirmed message shows that the
        # memory leak detection worked.
        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
            leaking()
883
class MPSReluTest(TestCaseMPS):
    """ReLU checks against a NumPy reference, out-of-place and in-place."""

    def _npRelu(self, np_features):
        """NumPy reference: elementwise max(x, 0), cast back to the input dtype."""
        zeros = np.zeros(np_features.shape)
        return np.maximum(np_features, zeros).astype(np_features.dtype)

    def testNpRelu(self):
        # Sanity check of the NumPy reference itself.
        features = np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]])
        torch.testing.assert_close(
            np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
            self._npRelu(features))

    def _testRelu(self, np_features, device):
        reference = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor and move it to `device`.
        tensor_on_device = torch.from_numpy(np_features).to(device)
        relu = torch.nn.ReLU(inplace=False)
        result_cpu = relu(tensor_on_device).to("cpu")

        self.assertEqual(reference, result_cpu)

    def _testReluInPlace(self, np_features, device):
        # The reference must be computed before the in-place op runs.
        reference = self._npRelu(np_features)
        tensor_on_device = torch.from_numpy(np_features).to(device)
        relu = torch.nn.ReLU(inplace=True)
        result_cpu = relu(tensor_on_device).to("cpu")

        self.assertEqual(reference, result_cpu)
        # Inplace Relu modifies the initial input and it should match the output of Relu
        self.assertEqual(reference, tensor_on_device.to("cpu"))

    def testNumbersCPU(self):
        for np_dtype in [np.int32]:
            # Force execution on CPU even if a GPU kernel is available for the type.
            for check in (self._testRelu, self._testReluInPlace):
                # A fresh array per check: the in-place variant mutates its input.
                features = np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(np_dtype)
                check(features, device="cpu")

    def testNumbersGPU(self):
        for np_dtype in [np.float16, np.float32]:
            for check in (self._testRelu, self._testReluInPlace):
                features = np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(np_dtype)
                check(features, device="mps")
935
class MatmulTest(TestCaseMPS):
    """torch.matmul on MPS compared with CPU for several operand ranks."""

    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
        """Multiply two random MPS tensors (optionally expanded) and compare with the CPU result."""
        lhs_mps = torch.randn(shape_tensor_1, device="mps")
        if expand_tensor_1_shape:
            lhs_mps = lhs_mps.expand(expand_tensor_1_shape)

        rhs_mps = torch.randn(shape_tensor_2, device="mps")
        if expand_tensor_2_shape:
            rhs_mps = rhs_mps.expand(expand_tensor_2_shape)

        lhs_cpu = lhs_mps.to("cpu")
        rhs_cpu = rhs_mps.to("cpu")

        expected = torch.matmul(lhs_cpu, rhs_cpu)
        actual = torch.matmul(lhs_mps, rhs_mps)

        self.assertEqual(expected, actual.to("cpu"))

    def test_vector_x_vector(self):
        # uses `dot`
        self._helper(3, 3)

    def test_matrix_x_vector(self):
        # uses `addmv`
        self._helper((3, 4), 4)

    def test_batched_matrix_x_broadcasted_vector(self):
        self._helper((10, 3, 4), 4)

    def test_batched_matrix_x_batched_matrix(self):
        # uses `bmm.out`
        self._helper((10, 3, 4), (10, 4, 5))

    def test_batched_matrix_x_broadcasted_matrix(self):
        self._helper((10, 3, 4), (4, 5))
973
974
class MPSLeakyReluTest(TestCaseMPS):
    """LeakyReLU forward/backward on MPS compared with CPU."""

    def _npLeakyRelu(self, np_features, negative_slope=0.1):
        """NumPy reference: max(x, slope * x), cast back to the input dtype."""
        scaled = negative_slope * np_features
        return np.maximum(np_features, scaled).astype(np_features.dtype)

    def testNpLeakyRelu(self):
        features = np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]])
        torch.testing.assert_close(
            np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
                      [0.1, -0.03, 0.5, -0.07, 0.9]]),
            self._npLeakyRelu(features, negative_slope=0.1))

    def _testLeakyRelu(self, np_features, negative_slope, device):
        """Run LeakyReLU on CPU and MPS copies of ``np_features`` and compare.

        ``device`` is accepted but not used: the comparison is always CPU vs 'mps'.
        """
        x_cpu = torch.from_numpy(np_features).requires_grad_()
        x_mps = torch.from_numpy(np_features).to('mps').requires_grad_()
        leaky_relu = torch.nn.LeakyReLU(negative_slope)

        out_cpu = leaky_relu(x_cpu)
        out_mps = leaky_relu(x_mps)
        torch.testing.assert_close(out_cpu, out_mps.to('cpu'))

        # test backward pass with an all-ones upstream gradient
        upstream_cpu = torch.ones_like(out_cpu)
        upstream_mps = upstream_cpu.to('mps')
        out_cpu.backward(gradient=upstream_cpu)
        out_mps.backward(gradient=upstream_mps)
        torch.testing.assert_close(x_cpu.grad, x_mps.grad.to('cpu'))

    def testNumbersCPU(self):
        for np_dtype in [np.float32]:
            features = np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(np_dtype)
            self._testLeakyRelu(features, negative_slope=0.2, device="cpu")
1010
1011
class TestAvgPool(TestCaseMPS):
    """Average pooling checks built on unfold-based sum-pooling references."""

    def _sum_pool2d(self, x, kernel_size):
        """Sum over non-overlapping ``kernel_size`` windows (stride equals kernel size)."""
        patches = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(patches, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        """3D sum pooling assembled from 2D sum pooling over depth slabs."""
        # Because unfold does not support 3D sliding window we will split tensor to multiple tensors and calculate sum
        depth = kernel_size[0]
        pooled = []
        for slab in x.split(depth):
            # Drop a trailing slab that is shorter than the kernel depth.
            if slab.size(0) != depth:
                continue
            pooled.append(slab.sum(0))
        # sum_pool2d assumes tensor in (1, 1, n, m) view, so unsqueeze two times
        pooled = [self._sum_pool2d(plane.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for plane in pooled]
        flat = torch.cat(pooled)
        return flat.view(1, flat.numel())

    def _avg_pool2d(self, x, kernel_size):
        window_elems = reduce((lambda acc, dim: acc * dim), kernel_size)
        return self._sum_pool2d(x, kernel_size) / window_elems

    def _avg_pool3d(self, x, kernel_size):
        window_elems = reduce((lambda acc, dim: acc * dim), kernel_size)
        return self._sum_pool3d(x, kernel_size) / window_elems

    def test_avg_pool2d_with_zero_divisor(self):
        def pool_with_zero_divisor():
            return F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0)

        self.assertRaisesRegex(RuntimeError, "divisor must be not zero", pool_with_zero_divisor)

    def test_doubletensor_avg_pool2d_with_divisor(self):
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for kh, kw in itertools.product(range(1, n + 1), range(1, m + 1)):
            for divisor in [1, 7, kh * kw]:
                pooled = F.avg_pool2d(input[0], (kh, kw), divisor_override=divisor)
                actual = pooled.view(1, pooled.numel())
                expected = self._sum_pool2d(input, (kh, kw)) / divisor
                self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        pool_kwargs = dict(ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
                           padding=(0, 1), stride=2)
        # Same pooling on the CPU tensor first, then on its MPS copy; neither may yield NaN.
        y = torch.nn.functional.avg_pool2d(x, **pool_kwargs)
        self.assertFalse(torch.isnan(y).any())
        y = torch.nn.functional.avg_pool2d(x.to('mps'), **pool_kwargs)
        self.assertFalse(torch.isnan(y).any())
1060
1061
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00001062class TestMPS(TestCaseMPS):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001063 def test_exp(self, device="mps", dtype=torch.float):
1064 for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
1065 b = torch.arange(18, device="cpu") / 3 * math.pi
1066 a = torch.tensor(v, dtype=dtype, device="cpu") * b
1067 a = a.to(dtype).to("mps")
1068 self.compare_with_numpy(torch.exp, np.exp, a)
1069
1070 def test_exp1(self, device="mps", dtype=torch.float):
1071 input = torch.tensor([-0.1, 3.0, -0.9]).to('mps')
1072 output = torch.exp(input).to('cpu')
Kulin Sethe011a8e2022-05-13 18:28:53 +00001073
Denis Vieriu5d483922023-02-07 16:25:03 +00001074 def test_exp_strided_output(self):
1075 x = torch.rand((256, 10), device='mps')
1076 x_cpu = x.to("cpu")
1077
1078 x = x.permute(1, 0)
1079 x_cpu = x_cpu.permute(1, 0)
1080
1081 res = x.exp()
1082 res_cpu = x_cpu.exp()
1083 self.assertEqual(res, res_cpu)
1084
Kulin Sethe011a8e2022-05-13 18:28:53 +00001085 def _testLeakyRelu(self, np_features, negative_slope, device):
1086 cpu_x = torch.from_numpy(np_features).requires_grad_()
1087 mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
1088 relu_op = torch.nn.LeakyReLU(negative_slope)
1089
1090 cpu_leaky_relu = relu_op(cpu_x)
1091 mps_leaky_relu = relu_op(mps_x)
Philip Meierbc73aff2022-11-02 11:25:04 +01001092 torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001093
1094 # test backward pass
1095 cpu_grad = torch.ones_like(cpu_leaky_relu)
1096 mps_grad = cpu_grad.to('mps')
1097 cpu_leaky_relu.backward(gradient=cpu_grad)
1098 mps_leaky_relu.backward(gradient=mps_grad)
Philip Meierbc73aff2022-11-02 11:25:04 +01001099 torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001100
1101 def testNumbersGPU(self):
1102 for t in [np.float32]:
1103 self._testLeakyRelu(
1104 np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1105 negative_slope=0.1,
1106 device="mps")
1107
1108 def test_fill(self):
1109
1110 def helper(val, shape):
1111 tensor = torch.zeros(shape, device='mps')
1112 tensor_mps = tensor.fill_(val)
1113 tensor_mps = torch.tanh(tensor_mps)
1114
1115 tensor_0 = torch.zeros(shape, device='cpu')
1116 tensor_cpu = tensor_0.fill_(val)
1117 tensor_cpu = torch.tanh(tensor_cpu)
1118
1119 self.assertEqual(tensor_mps, tensor_cpu)
1120
1121 helper(0, [1024])
1122 helper(0.2, [2, 3])
1123
Li-Huai (Allan) Lin25ee6dd2023-02-18 16:19:15 +00001124 def test_fill_storage_offset(self):
1125 shape = [2, 10]
1126 val = 0.2
1127 tensor = torch.ones(shape, device="mps")
1128 tensor_mps = tensor[:][1].fill_(val)
1129 tensor_0 = torch.ones(shape, device="cpu")
1130 tensor_cpu = tensor_0[:][1].fill_(val)
1131
1132 self.assertEqual(tensor_mps, tensor_cpu)
1133
1134 shape = [1, 10]
1135 val = 0.0
1136 tensor = torch.ones(shape, device="mps")
1137 val_tensor_mps = torch.tensor(val, device="mps")
1138 tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
1139 tensor_0 = torch.ones(shape, device="cpu")
1140 val_tensor_cpu = torch.tensor(val, device="cpu")
1141 tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
1142
1143 self.assertEqual(tensor_mps, tensor_cpu)
1144
Denis Vieriu80394bb2023-01-04 02:20:50 +00001145 def test_cdist_large(self, device="mps"):
1146 for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1147 x = torch.randn(100, 10, device=device)
1148 y = torch.randn(100, 10, device=device)
1149 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1150 expected = self._brute_cdist(x, y, p=2)
1151 self.assertEqual(expected, actual)
1152
1153 def test_cdist_large_batch(self, device="mps"):
1154 for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1155 x = torch.randn(4, 3, 100, 10, device=device)
1156 y = torch.randn(4, 3, 100, 10, device=device)
1157 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1158 expected = self._brute_cdist(x, y, p=2)
1159 self.assertEqual(expected, actual)
1160
1161 def test_cdist_non_contiguous(self, device="mps"):
1162 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1163 x = torch.randn(5, 7, device=device).mT
1164 y = torch.randn(5, 3, device=device).mT
1165 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1166 expected = self._brute_cdist(x, y, p=2)
1167 self.assertFalse(x.is_contiguous())
1168 self.assertFalse(y.is_contiguous())
1169 self.assertEqual(expected, actual)
1170
1171 x = torch.randn(7, 5, device=device)
1172 y = torch.randn(5, 3, device=device).t()
1173 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1174 expected = self._brute_cdist(x, y, p=2)
1175 self.assertTrue(x.is_contiguous())
1176 self.assertFalse(y.is_contiguous())
1177 self.assertEqual(expected, actual)
1178
1179 x = torch.randn(5, 7, device=device).t()
1180 y = torch.randn(3, 5, device=device)
1181 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1182 expected = self._brute_cdist(x, y, p=2)
1183 self.assertFalse(x.is_contiguous())
1184 self.assertTrue(y.is_contiguous())
1185 self.assertEqual(expected, actual)
1186
1187 def test_cdist_non_contiguous_batch(self, device="mps"):
1188 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1189 x = torch.randn(4, 3, 2, 5, 7, device=device).mT
1190 y = torch.randn(4, 3, 2, 5, 3, device=device).mT
1191 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1192 expected = self._brute_cdist(x, y, p=2)
1193 self.assertFalse(x.is_contiguous())
1194 self.assertFalse(y.is_contiguous())
1195 self.assertEqual(expected, actual)
1196
1197 x = torch.randn(7, 2, 7, 5, device=device)
1198 y = torch.randn(7, 2, 5, 3, device=device).mT
1199 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1200 expected = self._brute_cdist(x, y, p=2)
1201 self.assertTrue(x.is_contiguous())
1202 self.assertFalse(y.is_contiguous())
1203 self.assertEqual(expected, actual)
1204
1205 x = torch.randn(4, 5, 7, device=device).mT
1206 y = torch.randn(4, 3, 5, device=device)
1207 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1208 expected = self._brute_cdist(x, y, p=2)
1209 self.assertFalse(x.is_contiguous())
1210 self.assertTrue(y.is_contiguous())
1211 self.assertEqual(expected, actual)
1212
1213 def test_cdist_euclidean_large(self, device="mps"):
1214 def _test_euclidean_large_cdist(sizex, sizey=None):
1215 if sizey is None:
1216 sizey = sizex
1217 x = torch.randn(sizex, device=device, dtype=torch.float)
1218 y = torch.randn(sizey, device=device, dtype=torch.float)
1219 eps = 1e-6
1220 # to avoid extremum
1221 x = x - (((x - y) < eps).float() * 2 * eps)
1222 x.requires_grad = True
1223 y.requires_grad = True
1224 dist = torch.cdist(x, y, p=2)
1225 # Do a backward pass to check that it is valid for large
1226 # matrices
1227 loss = dist.sum()
1228 loss.backward()
1229
1230 _test_euclidean_large_cdist((2000, 5))
1231
1232 def test_cdist_same_inputs(self, device="mps"):
1233 # Test to detect issues in cdist gradient calculation
1234 # When the distances are 0
1235 sizex = (1, 27, 32)
1236 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
1237 x = torch.randn(sizex, device=device, dtype=torch.float)
1238 dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
1239 y = x.clone()
1240 eps = 1e-6
1241 x.requires_grad = True
1242 d = torch.cdist(x, y)
1243 d.backward(dist_grad)
1244 # Check that the backward passs does not contain invalid
1245 # values such as nan or inf
1246 assert torch.isfinite(x.grad).all()
1247
1248
1249 def _brute_cdist(self, x, y, p=2):
1250 r1 = x.shape[-2]
1251 r2 = y.shape[-2]
1252 if r1 == 0 or r2 == 0:
1253 return torch.empty(r1, r2, device=x.device)
1254 return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
1255
1256 def test_cdist_norm(self, device="mps"):
1257 for r1 in [3, 4]:
1258 for m in [2, 3]:
1259 for r2 in [4, 6]:
1260 for p in [0, 1, 1.5, 2.5, float('inf')]:
1261 x = torch.randn(r1, m, device=device)
1262 y = torch.randn(r2, m, device=device)
1263 if p == 2:
1264 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1265 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1266 expected = self._brute_cdist(x, y, p=2)
1267 self.assertEqual(expected, actual, rtol=0, atol=0.02)
1268 else:
1269 actual = torch.cdist(x, y, p=p)
1270 expected = self._brute_cdist(x, y, p=p)
1271 self.assertEqual(expected, actual)
1272
1273 def test_cdist_norm_batch(self, device="mps"):
1274 for r1 in [3, 4]:
1275 for m in [2, 3]:
1276 for r2 in [4, 6]:
1277 for p in [0, 3, 1.5, 2.5, float('inf')]:
1278 x = torch.randn(2, 3, 6, r1, m, device=device)
1279 y = torch.randn(2, 3, 6, r2, m, device=device)
1280 if p == 2:
1281 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1282 actual = torch.cdist(x, y, p=2, compute_mode=cm)
1283 expected = self._brute_cdist(x, y, p=2)
1284 self.assertEqual(expected, actual, rtol=0, atol=0.02)
1285 else:
1286 actual = torch.cdist(x, y, p=p)
1287 expected = self._brute_cdist(x, y, p=p)
1288 self.assertEqual(expected, actual)
1289
Kulin Sethe011a8e2022-05-13 18:28:53 +00001290 def test_mm(self):
1291 B = torch.ones(5, 6).to("mps")
1292 C = torch.ones(6, 5).to("mps")
1293 D = torch.mm(B, C).cpu()
Philip Meierbc73aff2022-11-02 11:25:04 +01001294 torch.testing.assert_close(D, torch.full((5, 5), 6.0))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001295
Denis Vieriu1a0738f2023-01-05 14:48:34 +00001296 def test_linalg_cross(self):
1297 def helper(dtype):
1298 device = "mps"
1299 if dtype is torch.int32 or dtype is torch.int64:
1300 x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
1301 y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
1302 else:
1303 x = torch.rand(100, 3, 100, dtype=dtype, device=device)
1304 y = torch.rand(100, 3, 100, dtype=dtype, device=device)
1305 x_cpu = x.to("cpu")
1306 y_cpu = y.to("cpu")
1307 res1 = torch.linalg.cross(x, y, dim=1)
1308 res2 = torch.tensor((), dtype=dtype, device=device)
1309 res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
1310 res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
1311 torch.linalg.cross(x, y, dim=1, out=res2)
1312 torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
1313 self.assertEqual(res1, res2)
1314 self.assertEqual(res1, res1_cpu)
1315 self.assertEqual(res2, res2_cpu)
1316
1317 # test for broadcastable inputs
1318 if dtype is torch.int32 or dtype is torch.int64:
1319 x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device)
1320 y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device)
1321 else:
1322 x = torch.rand(1, 3, 2, dtype=dtype, device=device)
1323 y = torch.rand(4, 3, 1, dtype=dtype, device=device)
1324 x_cpu = x.to("cpu")
1325 y_cpu = y.to("cpu")
1326 res1 = torch.linalg.cross(x, y, dim=1)
1327 res2 = torch.tensor((), dtype=dtype, device=device)
1328 res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
1329 res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
1330 torch.linalg.cross(x, y, dim=1, out=res2)
1331 torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
1332 self.assertEqual(res1, res2)
1333 self.assertEqual(res1, res1_cpu)
1334 self.assertEqual(res2, res2_cpu)
1335 [helper(dtype) for dtype in [torch.int32, torch.int64, torch.float32]]
1336
1337 def test_cross(self):
1338 a = torch.randn(4, 3, device="mps")
1339 b = torch.randn(4, 3, device="mps")
1340 a_cpu = a.to("cpu")
1341 b_cpu = b.to("cpu")
1342 res = torch.cross(a, b, dim=1)
1343 res_cpu = torch.cross(a_cpu, b_cpu, dim=1)
1344 self.assertEqual(res, res_cpu)
1345
Kulin Sethe011a8e2022-05-13 18:28:53 +00001346 def test_addmm(self):
1347 A = torch.ones(5, 5).to("mps")
1348 B = torch.ones(5, 6).to("mps")
1349 C = torch.ones(6, 5).to("mps")
1350 D = torch.addmm(A, B, C).to("cpu")
Philip Meierbc73aff2022-11-02 11:25:04 +01001351 torch.testing.assert_close(D, torch.full((5, 5), 7.0))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001352
1353 def test_bmm(self):
1354 batch1_cpu = torch.randn(10, 3, 4)
1355 batch2_cpu = torch.randn(10, 4, 5)
1356
1357 batch1_mps = batch1_cpu.detach().clone().to("mps")
1358 batch2_mps = batch2_cpu.detach().clone().to("mps")
1359
1360 output_cpu = torch.bmm(batch1_cpu, batch2_cpu)
1361 output_mps = torch.bmm(batch1_mps, batch2_mps)
1362
1363 self.assertEqual(output_cpu, output_mps)
1364 self.assertEqual(output_cpu.size(), output_mps.size())
1365
Denis Vieriu507b8c32023-02-11 00:16:46 +00001366 def test_addr(self):
1367 A = torch.ones(5, 10).to("mps")
1368 B = torch.ones(5).to("mps")
1369 C = torch.ones(10).to("mps")
1370 D = torch.addr(A, B, C).to("cpu")
1371 torch.testing.assert_close(D, torch.full((5, 10), 2.0))
1372
PumeTufc1c0cd2022-11-18 07:24:33 +00001373 def test_trace(self):
1374 M_cpu = torch.randn(3, 3)
1375 M_mps = M_cpu.detach().clone().to("mps")
1376
1377 output_cpu = torch.trace(M_cpu)
1378 output_mps = torch.trace(M_mps)
1379
1380 self.assertEqual(output_cpu, output_mps)
1381 self.assertEqual(output_cpu.size(), output_mps.size())
1382
Kulin Sethe011a8e2022-05-13 18:28:53 +00001383 def test_addbmm(self):
1384 M_cpu = torch.randn(3, 5)
1385 batch1_cpu = torch.randn(10, 3, 4)
1386 batch2_cpu = torch.randn(10, 4, 5)
1387
1388 M_mps = M_cpu.detach().clone().to("mps")
1389 batch1_mps = batch1_cpu.detach().clone().to("mps")
1390 batch2_mps = batch2_cpu.detach().clone().to("mps")
1391
1392 output_cpu = torch.addbmm(M_cpu, batch1_cpu, batch2_cpu)
1393 output_mps = torch.addbmm(M_mps, batch1_mps, batch2_mps)
1394
1395 self.assertEqual(output_cpu, output_mps)
1396 self.assertEqual(output_cpu.size(), output_mps.size())
1397
1398 def test_baddbmm(self):
Kulin Seth3d833212022-05-20 03:18:09 +00001399 def helper(input_shape, batch1_shape, batch2_shape):
1400 M_cpu = torch.randn(input_shape)
1401 batch1_cpu = torch.randn(batch1_shape)
1402 batch2_cpu = torch.randn(batch2_shape)
1403 alpha = 1.2
1404 beta = 0.8
Kulin Sethe011a8e2022-05-13 18:28:53 +00001405
Kulin Seth3d833212022-05-20 03:18:09 +00001406 M_mps = M_cpu.detach().clone().to("mps")
1407 batch1_mps = batch1_cpu.detach().clone().to("mps")
1408 batch2_mps = batch2_cpu.detach().clone().to("mps")
Kulin Sethe011a8e2022-05-13 18:28:53 +00001409
Kulin Seth3d833212022-05-20 03:18:09 +00001410 output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
1411 output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001412
Kulin Seth3d833212022-05-20 03:18:09 +00001413 self.assertEqual(output_cpu, output_mps)
1414 self.assertEqual(output_cpu.size(), output_mps.size())
Kulin Sethd63db522022-05-28 14:41:56 +00001415
Kulin Seth3d833212022-05-20 03:18:09 +00001416 helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
1417 helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
1418 helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001419
1420 def test_local_scalar_dense_mps(self):
1421 x_cpu = torch.randn(1)
1422 y_mps = x_cpu.to("mps")
Philip Meierbc73aff2022-11-02 11:25:04 +01001423 torch.testing.assert_close(x_cpu.item(), y_mps.item())
Kulin Sethe011a8e2022-05-13 18:28:53 +00001424
Kulin Seth7ff6a002022-09-28 00:43:11 +00001425 def test_linear_1d_weight(self):
1426 device = 'cpu'
1427 projected = torch.rand([8]).to(device)
1428 x = torch.rand([1, 2, 2, 8]).to(device)
1429 x_mps = x.to('mps')
1430 projected_mps = projected.to('mps')
1431 linear = F.linear(x, projected)
1432 linear_mps = F.linear(x_mps, projected_mps)
1433
1434 self.assertEqual(linear, linear_mps)
1435
1436 projected = torch.rand([1, 8]).to(device)
1437 x = torch.rand([1, 2, 2, 8]).to(device)
1438 x_mps = x.to('mps')
1439 projected_mps = projected.to('mps')
1440 linear = F.linear(x, projected)
1441 linear_mps = F.linear(x_mps, projected_mps)
1442
1443 self.assertEqual(linear, linear_mps)
1444
    def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
        """Compare torch.nn.Linear between CPU and MPS on a random input of ``shape``.

        Builds one Linear module per device, copies the CPU parameters onto the
        MPS module, and checks that the forward outputs match. With
        ``backward_pass=True`` it also compares the first-order gradients of
        the input, weight and (if present) bias, then runs a second backward
        through those gradients and compares the results again (gradgrad).
        """
        cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
        mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)

        # Use the same weights and bias as the ones from the cpu
        mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")

        if bias:
            mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")

        # The random input is created once and moved to MPS; the CPU input is a copy of it.
        linear_mps_input = torch.randn(shape).to('mps')
        linear_cpu_input = linear_mps_input.detach().clone().to('cpu')

        if backward_pass:
            linear_mps_input = linear_mps_input.requires_grad_()
            linear_cpu_input = linear_cpu_input.requires_grad_()

        linear_cpu_output = cpu_linear(linear_cpu_input)
        linear_mps_output = mps_linear(linear_mps_input)

        self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu'))
        self.assertEqual(linear_cpu_output.size(), linear_mps_output.size())

        if backward_pass:
            # The upstream gradients themselves require grad so that the second
            # backward below can populate cpu_grad.grad / grad.grad.
            cpu_grad = torch.rand_like(linear_cpu_output, requires_grad=True)
            grad = cpu_grad.detach().to('mps').requires_grad_()

            # create_graph=True keeps the graph of the backward pass, which the
            # gradgrad section differentiates through.
            linear_cpu_output.backward(gradient=cpu_grad, create_graph=True)
            linear_mps_output.backward(gradient=grad, create_graph=True)

            self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size())
            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

            self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size())
            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
            if bias:
                self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size())
                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

            # test gradgrad
            # Random upstream gradients for the second backward, one per first-order gradient.
            x_grad_out = torch.rand_like(linear_cpu_input)
            x_grad_out_mps = x_grad_out.to("mps")
            w_grad_out = torch.rand_like(cpu_linear.weight)
            w_grad_out_mps = w_grad_out.to("mps")

            # Zero the stored first-order gradient values in place (through a
            # detached alias) before the second backward accumulates into them.
            linear_cpu_input.grad.detach().zero_()
            linear_mps_input.grad.detach().zero_()
            cpu_linear.weight.grad.detach().zero_()
            mps_linear.weight.grad.detach().zero_()
            if bias:
                b_grad_out = torch.rand_like(cpu_linear.bias)
                b_grad_out_mps = b_grad_out.to("mps")
                cpu_linear.bias.grad.detach().zero_()
                mps_linear.bias.grad.detach().zero_()

            # Second backward: differentiate each first-order gradient in turn.
            # retain_graph=True keeps the graph alive for the following calls.
            linear_cpu_input.grad.backward(x_grad_out, retain_graph=True)
            linear_mps_input.grad.backward(x_grad_out_mps, retain_graph=True)
            cpu_linear.weight.grad.backward(w_grad_out, retain_graph=True)
            mps_linear.weight.grad.backward(w_grad_out_mps, retain_graph=True)
            if bias:
                cpu_linear.bias.grad.backward(b_grad_out, retain_graph=True)
                mps_linear.bias.grad.backward(b_grad_out_mps, retain_graph=True)

            self.assertEqual(cpu_grad.grad, grad.grad)
            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad)
            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad)
            if bias:
                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad)
1513
Ramin Azarmehr0e3953f2022-07-04 02:06:14 +00001514 def test_linear1D(self):
1515 self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=False)
1516
1517 def test_linear1D_backward(self):
1518 self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=True)
1519
Kulin Sethe011a8e2022-05-13 18:28:53 +00001520 def test_linear2D(self):
1521 self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=False)
1522
1523 def test_linear2D_backward(self):
1524 self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=True)
1525
1526 def test_linear2D_no_bias(self):
1527 self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=False, backward_pass=False)
1528
1529 def test_linear2D_no_bias_backward(self):
1530 self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=False, backward_pass=True)
1531
1532 def test_linear3D(self):
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001533 self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=False)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001534
Nikita Shulga70508262022-05-25 16:23:10 +00001535 def test_linear3D_backward(self):
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001536 self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=True)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001537
1538 def test_linear3D_no_bias(self):
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001539 self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=False)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001540
1541 def test_linear3D_no_bias_backward(self):
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001542 self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=True)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001543
1544 def test_uniform(self):
1545 low = torch.zeros(5, 5, requires_grad=True)
1546 high = (torch.ones(5, 5) * 3).requires_grad_()
1547 low_1d = torch.zeros(1, requires_grad=True)
1548 high_1d = (torch.ones(1) * 3).requires_grad_()
1549 self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
1550 self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
Kulin Seth3d833212022-05-20 03:18:09 +00001551 self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
1552 self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
1553 self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001554
Kulin Seth3d833212022-05-20 03:18:09 +00001555 # Check log_prob computation when value outside range
1556 uniform = Uniform(low_1d, high_1d, validate_args=False)
1557 above_high = torch.tensor([4.0])
1558 below_low = torch.tensor([-1.0])
1559 self.assertEqual(uniform.log_prob(above_high).item(), -inf)
1560 self.assertEqual(uniform.log_prob(below_low).item(), -inf)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001561
Kulin Seth3d833212022-05-20 03:18:09 +00001562 # check cdf computation when value outside range
1563 self.assertEqual(uniform.cdf(below_low).item(), 0)
1564 self.assertEqual(uniform.cdf(above_high).item(), 1)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001565
Kulin Seth3d833212022-05-20 03:18:09 +00001566 state = torch.get_rng_state()
1567 rand = low.new(low.size()).uniform_()
1568 torch.set_rng_state(state)
1569 u = Uniform(low, high).rsample()
1570 u.backward(torch.ones_like(u))
1571 self.assertEqual(low.grad, 1 - rand)
1572 self.assertEqual(high.grad, rand)
1573 low.grad.zero_()
1574 high.grad.zero_()
Kulin Sethe011a8e2022-05-13 18:28:53 +00001575
Denis Vieriu53ef96f2023-01-06 22:49:04 +00001576 def test_randperm(self, device="mps"):
1577 rng_device = None
1578 for n in (5, 100, 50000, 100000):
1579 for dtype in (torch.long, torch.half, torch.float):
1580 if n > 2049 and dtype == torch.half: # Large n for torch.half will raise an exception, do not test here.
1581 continue
1582 if n > 256 and dtype == torch.bfloat16:
1583 continue
1584 with torch.random.fork_rng(devices=rng_device):
1585 res1 = torch.randperm(n, dtype=dtype, device=device)
1586 res2 = torch.empty(0, dtype=dtype, device=device)
1587 torch.randperm(n, out=res2, dtype=dtype, device=device)
1588 self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device))
1589
1590 # Default type is long
1591 for n in (100, 10000):
1592 self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)
1593
1594 # randperm of 0 elements is an empty tensor
1595 res1 = torch.randperm(0)
1596 res2 = torch.tensor(5, dtype=dtype, device=device)
1597 torch.randperm(0, out=res2)
1598 self.assertEqual(res1.numel(), 0)
1599 self.assertEqual(res2.numel(), 0)
1600
1601 # Test non-contiguous tensors
1602 for n in (4, 5, 6, 10, 20):
1603 non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
1604 self.assertFalse(non_contiguous_tensor.is_contiguous())
1605 with torch.random.fork_rng(devices=rng_device):
1606 res = torch.randperm(n, dtype=torch.long, device=device)
1607 torch.randperm(n, out=non_contiguous_tensor)
1608 self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device))
1609
Kulin Sethe011a8e2022-05-13 18:28:53 +00001610 # Test forward maxpool2d
1611 def test_max_pool2d(self):
1612 def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
1613
1614 cpu_x = None
Thomas4935b592022-11-23 02:18:03 +00001615 if (test_ties):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001616 cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True)
1617 else:
1618 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
1619 x = cpu_x.detach().clone().to('mps').requires_grad_()
1620
1621 pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
1622 ceil_mode=ceil_mode, return_indices=return_indices)
1623
Thomas4935b592022-11-23 02:18:03 +00001624 if (return_indices is False):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001625 y = pool(x)
1626 ref_y = pool(cpu_x)
1627
1628 cpu_grad = torch.ones_like(ref_y)
1629 grad = cpu_grad.to('mps')
1630
1631 y.backward(gradient=grad)
1632 ref_y.backward(gradient=cpu_grad)
1633
1634 self.assertEqual(y, ref_y)
1635 self.assertEqual(x.grad, cpu_x.grad)
1636 else:
1637 y, idx = pool(x)
1638 ref_y, ref_idx = pool(cpu_x)
1639
1640 cpu_grad = torch.ones_like(ref_y)
1641 grad = cpu_grad.to('mps')
1642
1643 y.backward(gradient=grad)
1644 ref_y.backward(gradient=cpu_grad)
1645
1646 self.assertEqual(y, ref_y)
1647 self.assertEqual(idx, ref_idx)
1648 self.assertEqual(x.grad, cpu_x.grad)
1649
1650 # Test with no batch dimension
1651 helper((8, 4, 4), ks=2)
1652 helper((2, 8, 4, 4), ks=2)
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001653 helper((1, 1000, 32, 32), ks=4)
1654 helper((1, 1000, 1, 4), ks=(1, 4)) # test for max_pool1d
Kulin Sethe011a8e2022-05-13 18:28:53 +00001655 # Test padding
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001656 helper((1, 1000, 32, 32), ks=4, padding=1)
1657 helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1)) # test for max_pool1d
Kulin Sethe011a8e2022-05-13 18:28:53 +00001658 # Test dilation
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001659 helper((1, 1000, 32, 32), ks=4, dilation=2)
1660 helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2)) # test for max_pool1d
Kulin Sethe011a8e2022-05-13 18:28:53 +00001661 # Test ceil mode
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001662 helper((1, 1000, 32, 32), ks=4, ceil_mode=True)
1663 helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True) # test for max_pool1d
Kulin Sethe011a8e2022-05-13 18:28:53 +00001664
1665 # Test return indices
1666 for test_ties in [False, True]:
1667 # Test with no batch dimension
1668 helper((8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
1669 helper((2, 8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001670 helper((1, 1000, 32, 32), ks=4, return_indices=True, test_ties=test_ties)
1671 helper((1, 1000, 1, 4), ks=(1, 4), return_indices=True, test_ties=test_ties) # test for max_pool1d
Kulin Sethe011a8e2022-05-13 18:28:53 +00001672 # Test padding
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001673 helper((1, 1000, 32, 32), ks=4, padding=1, return_indices=True, test_ties=test_ties)
1674 helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1),
Kulin Sethe011a8e2022-05-13 18:28:53 +00001675 return_indices=True, test_ties=test_ties) # test for max_pool1d
1676 # Test dilation
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001677 helper((1, 1000, 32, 32), ks=4, dilation=2, return_indices=True, test_ties=test_ties)
1678 helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2),
Kulin Sethe011a8e2022-05-13 18:28:53 +00001679 return_indices=True, test_ties=test_ties) # test for max_pool1d
1680 # Test ceil mode
Alban Desmaisonbde246f2022-05-30 10:36:31 -04001681 helper((1, 1000, 32, 32), ks=4, ceil_mode=True, return_indices=True, test_ties=test_ties)
1682 helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True,
Kulin Sethe011a8e2022-05-13 18:28:53 +00001683 return_indices=True, test_ties=test_ties) # test for max_pool1d
1684
1685 def test_adaptive_avg_pool2d_output_size_one(self):
1686 def helper(size, memory_format):
1687 x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
Kulin Seth3d833212022-05-20 03:18:09 +00001688 if memory_format == 'non_contiguous':
1689 x = x[::2, ::2, ::2, ::2]
1690 else:
1691 x = x.to(memory_format=memory_format)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001692
1693 net = torch.nn.AdaptiveAvgPool2d((1, 1))
1694 out = net(x)
1695 ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))
1696
1697 out.sum().backward() # make sure it doesn't crash
1698
1699 self.assertEqual(out, ref_out)
1700 if memory_format == torch.channels_last:
1701 self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
1702 c = out.size(1)
1703 self.assertEqual(out.stride(), [c, 1, c, c])
1704 else:
1705 self.assertTrue(out.is_contiguous())
1706 c = out.size(1)
1707 self.assertEqual(out.stride(), [c, 1, 1, 1])
1708
1709 helper((2, 3, 6, 6), torch.contiguous_format)
1710
Denis Vieriued1957d2023-03-01 01:36:36 +00001711 def test_masked_scatter(self):
1712 def helper(shape):
1713 x_mps = torch.randn(shape, device="mps")
1714 x_cpu = x_mps.detach().clone().cpu()
1715
1716 mask_mps = torch.rand(shape, device="mps") < 0.6
1717 mask_cpu = mask_mps.detach().clone().cpu()
1718
1719 y_mps = torch.randn(shape, device="mps")
1720 y_cpu = y_mps.detach().clone().cpu()
1721
1722 y_mps.masked_scatter_(mask_mps, x_mps)
1723 y_cpu.masked_scatter_(mask_cpu, x_cpu)
1724
1725 self.assertEqual(y_mps, y_cpu)
1726 helper([2, 5])
1727 helper([10, 10])
1728 helper([5, 10, 3])
1729 helper([10, 5, 10, 3])
1730 helper([10, 5, 10, 3, 20])
1731
Kulin Seth3d833212022-05-20 03:18:09 +00001732 def test_masked_fill(self):
1733 device = "mps"
1734 dtype = torch.float32
1735 mask_dtype = torch.bool
1736
1737 with warnings.catch_warnings(record=True) as w:
1738 warnings.simplefilter("always")
1739 num_dest = 10
1740 dst = torch.zeros(num_dest, dtype=dtype, device=device)
1741 mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
1742 val = random.random()
1743 dst2 = torch.zeros(num_dest, dtype=dtype)
1744 mask_cpu = mask.to("cpu")
1745
1746 dst.masked_fill_(mask, val)
1747 for i in range(num_dest):
1748 if mask_cpu[i]:
1749 dst2[i] = val
1750 self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
1751
1752 # test non-contiguous case
1753 dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
1754 dst2 = dst.contiguous()
1755 if dtype.is_complex:
1756 mask = dst.abs() > 0
1757 else:
1758 mask = dst > 0
1759 self.assertTrue(not dst.is_contiguous())
1760 self.assertTrue(dst2.is_contiguous())
1761 dst.masked_fill_(mask.to(mask_dtype), val)
1762 dst2.masked_fill_(mask.to(mask_dtype), val)
1763 self.assertEqual(dst, dst2, atol=0, rtol=0)
1764
1765 if mask_dtype == torch.uint8:
1766 self.assertEqual(len(w), 3)
1767
1768 warn = 'masked_fill_ received a mask with dtype torch.uint8,'
1769 for wi in w:
1770 self.assertEqual(str(wi.message)[0:52], str(warn))
1771 else:
1772 self.assertEqual(len(w), 0)
1773
1774 def test_nhwc_operation(self):
1775 def helper(shape, channels_last=False):
1776 import numpy as np
1777 np.random.seed(332)
1778 arr = (256 - 128) * np.random.random_sample(size=shape) + 128
1779 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
Thomas4935b592022-11-23 02:18:03 +00001780 if (channels_last):
Kulin Seth3d833212022-05-20 03:18:09 +00001781 cpu_x = cpu_x.to(memory_format=torch.channels_last)
1782 cpu_x.retain_grad()
1783 x = cpu_x.detach().clone().to('mps').requires_grad_()
1784
1785 # This passes
1786 self.assertEqual(x, cpu_x)
1787
1788 helper((2, 2, 2, 2), True)
1789
Kulin Sethe011a8e2022-05-13 18:28:53 +00001790 # Test forward batch norm
1791 def test_batch_norm(self):
1792 def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
1793 track_running_stats=True, test_module=False):
1794
1795 import numpy as np
1796 np.random.seed(332)
1797 arr = (256 - 128) * np.random.random_sample(size=shape) + 128
1798 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
Thomas4935b592022-11-23 02:18:03 +00001799 if (channels_last):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001800 cpu_x = cpu_x.to(memory_format=torch.channels_last)
1801 cpu_x.retain_grad()
1802 x = cpu_x.detach().clone().to('mps').requires_grad_()
1803
1804 mean_shape = [shape[1]]
1805 cpu_running_mean = None
1806 cpu_running_var = None
1807 running_mean = None
1808 running_var = None
Thomas4935b592022-11-23 02:18:03 +00001809 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001810 mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
1811 cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
1812 var_arr = 32 * np.random.random_sample(size=mean_shape)
1813 cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
1814 running_mean = cpu_running_mean.detach().clone().to('mps')
1815 running_var = cpu_running_var.detach().clone().to('mps')
1816
1817 weight = None
1818 cpu_weight = None
1819 bias = None
1820 cpu_bias = None
Thomas4935b592022-11-23 02:18:03 +00001821 if (wts):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001822 cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
1823 weight = cpu_weight.detach().clone().to('mps').requires_grad_()
1824 cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
1825 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
1826
1827 y = None
1828 ref_y = None
1829
Thomas4935b592022-11-23 02:18:03 +00001830 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001831 y = torch.nn.functional.batch_norm(x, running_mean, running_var,
1832 weight=weight,
1833 bias=bias,
1834 training=training,
1835 momentum=momentum, eps=eps)
1836 ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
1837 weight=cpu_weight,
1838 bias=cpu_bias,
1839 training=training,
1840 momentum=momentum, eps=eps)
1841
1842 else:
1843
1844 batchnorm_op = None
1845 mps_batchnorm_op = None
1846
Thomas4935b592022-11-23 02:18:03 +00001847 if (len(shape) == 3):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001848 batchnorm_op = torch.nn.BatchNorm1d(shape[1],
1849 eps=eps,
1850 momentum=momentum,
1851 affine=wts,
1852 track_running_stats=track_running_stats,
1853 device='cpu')
1854 mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1],
1855 eps=eps,
1856 momentum=momentum,
1857 affine=wts,
1858 track_running_stats=track_running_stats,
1859 device='mps')
Thomas4935b592022-11-23 02:18:03 +00001860 elif (len(shape) == 4):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001861 batchnorm_op = torch.nn.BatchNorm2d(shape[1],
1862 eps=eps,
1863 momentum=momentum,
1864 affine=wts,
1865 track_running_stats=track_running_stats,
1866 device='cpu')
1867 mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1],
1868 eps=eps,
1869 momentum=momentum,
1870 affine=wts,
1871 track_running_stats=track_running_stats,
1872 device='mps')
Thomas4935b592022-11-23 02:18:03 +00001873 elif (len(shape) == 5):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001874 batchnorm_op = torch.nn.BatchNorm3d(shape[1],
1875 eps=eps,
1876 momentum=momentum,
1877 affine=wts,
1878 track_running_stats=track_running_stats,
1879 device='cpu')
1880 mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1],
1881 eps=eps,
1882 momentum=momentum,
1883 affine=wts,
1884 track_running_stats=track_running_stats,
1885 device='mps')
1886
Thomas4935b592022-11-23 02:18:03 +00001887 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001888 batchnorm_op.running_mean = cpu_running_mean
1889 batchnorm_op.running_var = cpu_running_var
1890 mps_batchnorm_op.running_mean = running_mean
1891 mps_batchnorm_op.running_var = running_var
Thomas4935b592022-11-23 02:18:03 +00001892 if (wts):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001893 batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
1894 batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
1895 mps_batchnorm_op.weight = torch.nn.Parameter(weight)
1896 mps_batchnorm_op.bias = torch.nn.Parameter(bias)
1897
1898 ref_y = batchnorm_op(cpu_x)
1899 y = mps_batchnorm_op(x)
1900
1901 self.assertEqual(y, ref_y)
Thomas4935b592022-11-23 02:18:03 +00001902 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001903 self.assertEqual(running_mean, cpu_running_mean)
1904 self.assertEqual(running_var, cpu_running_var)
1905 else:
1906 self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
1907 self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)
1908
1909 cpu_grad = torch.randn(ref_y.shape)
1910 grad = cpu_grad.to('mps')
1911 ref_y.backward(gradient=cpu_grad)
1912 y.backward(gradient=grad)
1913
1914 self.assertEqual(x.grad, cpu_x.grad)
Thomas4935b592022-11-23 02:18:03 +00001915 if (wts):
1916 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001917 self.assertEqual(weight.grad, cpu_weight.grad)
1918 self.assertEqual(bias.grad, cpu_bias.grad)
1919 else:
1920 self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
1921 self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)
1922
1923 for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
1924 for test_module in [False, True]:
1925 for track_running_stats in [True, False]:
Kulin Seth3d833212022-05-20 03:18:09 +00001926 for channels_last in [False]:
Thomas4935b592022-11-23 02:18:03 +00001927 if (channels_last and len(shape) != 4):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001928 continue
1929 # Running stats must be tracked in eval mode
Thomas4935b592022-11-23 02:18:03 +00001930 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001931 helper(shape, eps=0, momentum=1, channels_last=channels_last,
1932 track_running_stats=track_running_stats, test_module=test_module)
1933 helper(shape, channels_last=channels_last,
1934 track_running_stats=track_running_stats, test_module=test_module)
1935 helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
1936 track_running_stats=track_running_stats, test_module=test_module)
1937 helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
1938 track_running_stats=track_running_stats, test_module=test_module)
1939 helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
1940 track_running_stats=track_running_stats, test_module=test_module)
1941 helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last,
1942 track_running_stats=track_running_stats, test_module=test_module)
1943 helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
1944 track_running_stats=track_running_stats, test_module=test_module)
1945 helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
1946 track_running_stats=track_running_stats, test_module=test_module)
1947 helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
1948 track_running_stats=track_running_stats, test_module=test_module)
1949 helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
1950 track_running_stats=track_running_stats, test_module=test_module)
1951
Denis Vieriu80394bb2023-01-04 02:20:50 +00001952 def test_norm(self):
1953 a = torch.arange(9, dtype=torch.float, device="mps") - 4
1954 b = a.reshape((3, 3))
1955
1956 a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
1957 b_cpu = a_cpu.reshape((3, 3))
1958
1959 res = torch.norm(a)
1960 res_cpu = torch.norm(a_cpu)
1961 self.assertEqual(res, res_cpu)
1962
1963 res = torch.norm(b)
1964 res_cpu = torch.norm(b_cpu)
1965 self.assertEqual(res, res_cpu)
1966
1967 res = torch.norm(a, float('inf'))
1968 res_cpu = torch.norm(a_cpu, float('inf'))
1969 self.assertEqual(res, res_cpu)
1970
1971 res = torch.norm(b, float('inf'))
1972 res_cpu = torch.norm(b_cpu, float('inf'))
1973 self.assertEqual(res, res_cpu)
1974
1975 c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
1976 c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]] , dtype=torch.float, device="cpu")
1977
1978 res = torch.norm(c, dim=0)
1979 res_cpu = torch.norm(c_cpu, dim=0)
1980 self.assertEqual(res, res_cpu)
1981
1982 res = torch.norm(c, dim=1)
1983 res_cpu = torch.norm(c_cpu, dim=1)
1984 self.assertEqual(res, res_cpu)
1985
1986 res = torch.norm(c, p=1, dim=1)
1987 res_cpu = torch.norm(c_cpu, p=1, dim=1)
1988 self.assertEqual(res, res_cpu)
1989
1990 d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
1991 d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
1992
1993 res = torch.norm(d, dim=(1, 2))
1994 res_cpu = torch.norm(d_cpu, dim=(1, 2))
1995 self.assertEqual(res, res_cpu)
1996
1997 res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
1998 res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
1999 self.assertEqual(res, res_cpu)
2000
Kulin Seth77b68852022-06-10 13:25:41 +00002001 def test_layer_norm(self):
2002 # TODO: Test non-contiguous
2003 def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):
2004 cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
2005 x = cpu_x.detach().clone().to('mps').requires_grad_()
2006
2007 cpu_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='cpu', dtype=dtype)
2008 mps_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='mps', dtype=dtype)
2009 cpu_wt = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
2010 wt = cpu_wt.detach().clone().to('mps').requires_grad_()
2011 cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
2012 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2013
Thomas4935b592022-11-23 02:18:03 +00002014 if (elementwise_affine):
Kulin Seth77b68852022-06-10 13:25:41 +00002015 cpu_op.weight = torch.nn.Parameter(cpu_wt)
2016 mps_op.weight = torch.nn.Parameter(wt)
2017 cpu_op.bias = torch.nn.Parameter(cpu_bias)
2018 mps_op.bias = torch.nn.Parameter(bias)
2019
2020 cpu_result = cpu_op(cpu_x)
2021 result = mps_op(x)
2022
2023 cpu_grad = torch.randn(cpu_result.shape)
2024 grad = cpu_grad.to('mps')
2025
2026 cpu_result.backward(cpu_grad)
2027 result.backward(grad)
2028
2029 self.assertEqual(result, cpu_result)
2030 self.assertEqual(x.grad, cpu_x.grad)
Thomas4935b592022-11-23 02:18:03 +00002031 if (elementwise_affine):
Kulin Seth77b68852022-06-10 13:25:41 +00002032 self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad)
2033 self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad)
2034
2035 for elementwise_affine in [True, False]:
2036 helper((2, 2, 2, 2), (2, 2), elementwise_affine=elementwise_affine)
2037 helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
2038 helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)
2039
Nikita Shulga075a4942023-03-09 22:09:10 +00002040 # Regression test for https://github.com/pytorch/pytorch/issues/96113
2041 torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16))
2042
Kulin Sethe011a8e2022-05-13 18:28:53 +00002043 def test_instance_norm(self):
2044 def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
2045
2046 import numpy as np
2047 np.random.seed(332)
2048 arr = (256 - 128) * np.random.random_sample(size=shape) + 128
2049 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
Thomas4935b592022-11-23 02:18:03 +00002050 if (channels_last):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002051 cpu_x = cpu_x.to(memory_format=torch.channels_last)
2052 cpu_x.retain_grad()
2053 x = cpu_x.detach().clone().to('mps').requires_grad_()
2054
2055 mean_shape = [shape[1]]
2056 cpu_running_mean = None
2057 cpu_running_var = None
2058 running_mean = None
2059 running_var = None
Thomas4935b592022-11-23 02:18:03 +00002060 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002061 mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
2062 cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
2063 var_arr = 32 * np.random.random_sample(size=mean_shape)
2064 cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
2065 running_mean = cpu_running_mean.detach().clone().to('mps')
2066 running_var = cpu_running_var.detach().clone().to('mps')
2067
2068 weight = None
2069 cpu_weight = None
2070 bias = None
2071 cpu_bias = None
Thomas4935b592022-11-23 02:18:03 +00002072 if (wts):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002073 cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2074 weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2075 cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2076 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2077
2078 y = None
2079 ref_y = None
2080
Thomas4935b592022-11-23 02:18:03 +00002081 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002082 ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
2083 weight=cpu_weight,
2084 bias=cpu_bias,
2085 momentum=momentum, eps=eps)
2086 y = torch.nn.functional.instance_norm(x, running_mean, running_var,
2087 weight=weight,
2088 bias=bias,
2089 momentum=momentum, eps=eps)
2090
2091 else:
2092
2093 instancenorm_op = None
2094 mps_instancenorm_op = None
2095
Thomas4935b592022-11-23 02:18:03 +00002096 if (len(shape) == 3):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002097 instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
2098 eps=eps,
2099 momentum=momentum,
2100 affine=wts,
2101 track_running_stats=track_running_stats,
2102 device='cpu')
2103 mps_instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
2104 eps=eps,
2105 momentum=momentum,
2106 affine=wts,
2107 track_running_stats=track_running_stats,
2108 device='mps')
Thomas4935b592022-11-23 02:18:03 +00002109 elif (len(shape) == 4):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002110 instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
2111 eps=eps,
2112 momentum=momentum,
2113 affine=wts,
2114 track_running_stats=track_running_stats,
2115 device='cpu')
2116 mps_instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
2117 eps=eps,
2118 momentum=momentum,
2119 affine=wts,
2120 track_running_stats=track_running_stats,
2121 device='mps')
Thomas4935b592022-11-23 02:18:03 +00002122 elif (len(shape) == 5):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002123 instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
2124 eps=eps,
2125 momentum=momentum,
2126 affine=wts,
2127 track_running_stats=track_running_stats,
2128 device='cpu')
2129 mps_instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
2130 eps=eps,
2131 momentum=momentum,
2132 affine=wts,
2133 track_running_stats=track_running_stats,
2134 device='mps')
2135
Thomas4935b592022-11-23 02:18:03 +00002136 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002137 instancenorm_op.running_mean = cpu_running_mean
2138 instancenorm_op.running_var = cpu_running_var
2139 mps_instancenorm_op.running_mean = running_mean
2140 mps_instancenorm_op.running_var = running_var
Thomas4935b592022-11-23 02:18:03 +00002141 if (wts):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002142 instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
2143 instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
2144 mps_instancenorm_op.weight = torch.nn.Parameter(weight)
2145 mps_instancenorm_op.bias = torch.nn.Parameter(bias)
2146
2147 ref_y = instancenorm_op(cpu_x)
2148 y = mps_instancenorm_op(x)
2149
2150 self.assertEqual(y, ref_y)
Thomas4935b592022-11-23 02:18:03 +00002151 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002152 self.assertEqual(running_mean, cpu_running_mean)
2153 self.assertEqual(running_var, cpu_running_var)
2154 else:
2155 self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
2156 self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)
2157
2158 cpu_grad = torch.randn(ref_y.shape)
2159 grad = cpu_grad.to('mps')
2160 ref_y.backward(gradient=cpu_grad)
2161 y.backward(gradient=grad)
2162
2163 self.assertEqual(x.grad, cpu_x.grad)
Thomas4935b592022-11-23 02:18:03 +00002164 if (wts):
2165 if (not test_module):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002166 self.assertEqual(weight.grad, cpu_weight.grad)
2167 self.assertEqual(bias.grad, cpu_bias.grad)
2168 else:
2169 self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
2170 self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)
2171
2172 for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
2173 for test_module in [False, True]:
2174 for track_running_stats in [True, False]:
2175 for channels_last in [False]:
Thomas4935b592022-11-23 02:18:03 +00002176 if (channels_last and len(shape) != 4):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002177 continue
2178 # Running stats must be tracked in eval mode
Thomas4935b592022-11-23 02:18:03 +00002179 if (track_running_stats):
Kulin Sethe011a8e2022-05-13 18:28:53 +00002180 helper(shape, eps=0, momentum=1, channels_last=channels_last,
2181 track_running_stats=track_running_stats, test_module=test_module)
2182 helper(shape, channels_last=channels_last,
2183 track_running_stats=track_running_stats, test_module=test_module)
2184 helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
2185 track_running_stats=track_running_stats, test_module=test_module)
2186 helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
2187 track_running_stats=track_running_stats, test_module=test_module)
2188 helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
2189 track_running_stats=track_running_stats, test_module=test_module)
2190 helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
2191 track_running_stats=track_running_stats, test_module=test_module)
2192 helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
2193 track_running_stats=track_running_stats, test_module=test_module)
2194 helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
2195 track_running_stats=track_running_stats, test_module=test_module)
2196 helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
2197 track_running_stats=track_running_stats, test_module=test_module)
2198 helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
2199 track_running_stats=track_running_stats, test_module=test_module)
2200
2201 # Test conv2d
def test_conv2d_unit(self):
    """Check conv2d forward and backward results on MPS against CPU.

    One strided/padded configuration is run once; the grouped
    configurations (with and without bias, two tensor sizes) are each
    run twice in a row.
    """
    def helper(input_shape, wt_shape,
               stride=1, padding=0,
               dilation=1, groups=1,
               bias_shape=None):
        def make_pair(shape):
            # CPU leaf tensor plus an independent MPS leaf holding the same values.
            ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            return ref, ref.detach().clone().to('mps').requires_grad_()

        cpu_x, x = make_pair(input_shape)
        cpu_wt, wt = make_pair(wt_shape)
        cpu_bias, bias = (None, None) if bias_shape is None else make_pair(bias_shape)

        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
        y = torch.nn.functional.conv2d(x, wt, bias=bias, **conv_kwargs)
        ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, **conv_kwargs)

        # Back-propagate an all-ones upstream gradient on both devices.
        cpu_grad = torch.ones_like(ref_y)
        grad = cpu_grad.to('mps')
        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
        self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
        if bias_shape is not None:
            self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)

    # Single strided + padded case: (N, C_in, H, W) input, (C_out, C_in, kH, kW) weight.
    helper((1, 3, 64, 64), (64, 3, 4, 4), stride=2, padding=1)

    N, C_in, H, W = 4, 16, 32, 32
    C_out, kH, kW = 8, 3, 3

    for groups in [1, 2, 4]:
        small = ((N, C_in, H, W), (C_out, C_in // groups, kH, kW))
        large = ((N, C_in * 2, H * 2, W * 2),
                 (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2))
        configs = (
            (small, None),
            (small, C_out),
            (large, None),
            (large, C_out * 2),
        )
        for (input_shape, wt_shape), bias_shape in configs:
            # Each configuration is executed twice back to back.
            for _ in range(2):
                helper(input_shape, wt_shape, bias_shape=bias_shape, groups=groups)
2273
2274 # Test conv transpose 2d
def test_conv_transpose2d(self):
    """Check conv_transpose2d output, input grad and weight grad on MPS against CPU.

    Sweeps stride, padding, output_padding and dilation, with and without
    a bias; each configuration is run twice in a row.
    """
    def helper(input_shape, wt_shape,
               stride=1, padding=0,
               output_padding=0,
               dilation=1, groups=1,
               bias_shape=None):
        def make_pair(shape):
            # CPU leaf tensor plus an independent MPS leaf holding the same values.
            ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            return ref, ref.detach().clone().to('mps').requires_grad_()

        cpu_x, x = make_pair(input_shape)
        cpu_wt, wt = make_pair(wt_shape)
        cpu_bias, bias = (None, None) if bias_shape is None else make_pair(bias_shape)

        op_kwargs = dict(stride=stride, padding=padding, output_padding=output_padding,
                         groups=groups, dilation=dilation)
        y = torch.nn.functional.conv_transpose2d(x, wt, bias=bias, **op_kwargs)
        ref_y = torch.nn.functional.conv_transpose2d(cpu_x, cpu_wt, bias=cpu_bias, **op_kwargs)

        # Random upstream gradient, shared by both devices.
        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
        self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
        # NOTE: bias.grad is not compared against cpu_bias.grad; that check was
        # disabled in the original test and is left disabled here.

    N, H, W = 4, 32, 32
    C_in, C_out = 2, 8
    kH = kW = 3

    sweep = itertools.product([1, 2, 3], [0, 1, 2], [0, 1, 2], [1, 2])
    for stride, padding, output_padding, dilation in sweep:
        # Only keep combinations where output_padding is below both stride and dilation.
        if not (output_padding < stride and output_padding < dilation):
            continue
        for bias_shape in (None, C_in):
            for _ in range(2):
                helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=bias_shape,
                       stride=stride, padding=padding,
                       output_padding=output_padding, dilation=dilation)
2341
2342 # Test sigmoid
def test_sigmoid(self):
    """Compare nn.Sigmoid forward values and input gradients between MPS and CPU."""
    def helper(shape):
        reference_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        device_in = reference_in.detach().clone().to('mps').requires_grad_()

        activation = torch.nn.Sigmoid()
        device_out = activation(device_in)
        reference_out = activation(reference_in)

        # Same all-ones upstream gradient on both devices.
        upstream = torch.ones_like(reference_out)
        device_out.backward(gradient=upstream.to('mps'))
        reference_out.backward(gradient=upstream)

        self.assertEqual(device_out, reference_out)
        self.assertEqual(device_in.grad, reference_in.grad)

    for shape in [(2, 3, 4, 5), (2, 3, 4), (2, 8, 4, 5)]:
        helper(shape)
2366
2367 # Test tanh
def test_tanh(self):
    """Compare nn.Tanh forward values and input gradients between MPS and CPU."""
    def helper(shape):
        reference_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        device_in = reference_in.detach().clone().to('mps').requires_grad_()

        activation = torch.nn.Tanh()
        device_out = activation(device_in)
        reference_out = activation(reference_in)

        # Same all-ones upstream gradient on both devices.
        upstream = torch.ones_like(reference_out)
        device_out.backward(gradient=upstream.to('mps'))
        reference_out.backward(gradient=upstream)

        self.assertEqual(device_out, reference_out)
        self.assertEqual(device_in.grad, reference_in.grad)

    for shape in [(2, 3, 4, 5), (2, 3, 4), (2, 8, 4, 5)]:
        helper(shape)
2391
def test_threshold(self):
    """Compare nn.Threshold outputs (and gradients, when tracked) between CPU and MPS."""
    def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
        module = nn.Threshold(threshold=threshold, value=value, inplace=inplace)

        input_cpu = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

        output_cpu = module(input_cpu)
        output_mps = module(input_mps)

        cpu_grad = torch.ones_like(output_cpu)
        mps_grad = cpu_grad.to('mps')

        self.assertEqual(output_cpu, output_mps)

        if not requires_grad:
            return
        output_cpu.backward(gradient=cpu_grad)
        output_mps.backward(gradient=mps_grad)
        self.assertEqual(input_cpu.grad, input_mps.grad)

    cases = [
        dict(threshold=0.1, value=20, num_elems=2),
        dict(threshold=-0.1, value=10, num_elems=10),
        dict(threshold=0.5, value=-15, num_elems=100),
        # The in-place variant is exercised without gradient tracking.
        dict(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False),
    ]
    for case in cases:
        helper(**case)
2417
2418 # Test pow
def test_pow(self):
    """Compare torch.pow on MPS and CPU for its three overloads."""
    def helper(shape):
        def rand_pair():
            # Random float CPU tensor and an MPS copy of it.
            cpu_t = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            return cpu_t, cpu_t.detach().clone().to('mps')

        # aten::pow.Tensor_Tensor
        cpu_base, mps_base = rand_pair()
        cpu_exponent, mps_exponent = rand_pair()
        self.assertEqual(torch.pow(mps_base, mps_exponent), torch.pow(cpu_base, cpu_exponent))

        # aten::pow.Tensor_Scalar
        cpu_base, mps_base = rand_pair()
        scalar_exponent = random.random()
        self.assertEqual(torch.pow(mps_base, scalar_exponent), torch.pow(cpu_base, scalar_exponent))

        # aten::pow.Scalar
        scalar_base = random.random()
        cpu_exponent, mps_exponent = rand_pair()
        self.assertEqual(torch.pow(scalar_base, mps_exponent), torch.pow(scalar_base, cpu_exponent))

    helper((2, 8, 4, 5))
2450
2451 # Test addcmul
def test_addcmul(self):
    """Compare torch.addcmul on MPS and CPU for float, integral and mixed dtypes."""
    def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None):
        def make_cpu(dtype):
            # randn is used for floating dtypes; other dtypes draw ints in [0, 10).
            if not dtype.is_floating_point:
                return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False)
            return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)

        # y and z fall back to x's dtype when not given explicitly.
        dtypes = (
            xtype,
            xtype if ytype is None else ytype,
            xtype if ztype is None else ztype,
        )
        cpu_args = [make_cpu(dtype) for dtype in dtypes]
        mps_args = [t.detach().clone().to('mps') for t in cpu_args]

        result = torch.addcmul(*mps_args, value=value)
        expected = torch.addcmul(*cpu_args, value=value)
        self.assertEqual(result, expected)

    for shape, value in [((2, 3, 4, 5), 0.1), ((2, 8, 4, 5), 0.1),
                         ((2, 3, 4, 5), 0.2), ((2, 8, 4, 5), 0.2)]:
        helper(shape, value)

    # Integral types
    helper((2, 2), 1.0, xtype=torch.int32)
    helper((2, 2), 2.0, xtype=torch.int16)

    # Mixed types
    mixed_cases = [
        ((2, 2), dict(xtype=torch.float16, ytype=torch.float32)),
        ((3, 2), dict(ytype=torch.float16)),
        ((2, 3), dict(ztype=torch.float16)),
        ((2, 2), dict(xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8)),
        ((2, 2), dict(ytype=torch.int16, ztype=torch.uint8)),
    ]
    for shape, dtype_kwargs in mixed_cases:
        helper(shape, 1.0, **dtype_kwargs)
Kulin Sethe011a8e2022-05-13 18:28:53 +00002487
2488 # Test addcdiv
def test_addcdiv(self):
    """Compare torch.addcdiv (functional and out= variants) on MPS against CPU."""
    def helper(shape, value):
        def rand_cpu():
            return torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)

        cpu_x = rand_cpu()
        cpu_y = rand_cpu()
        # clamp to avoid division by 0
        cpu_z = rand_cpu().clamp_min_(0.1)
        cpu_out = rand_cpu()

        mps_x, mps_y, mps_z, mps_out = (
            t.detach().clone().to('mps') for t in (cpu_x, cpu_y, cpu_z, cpu_out)
        )

        result_div_mps = torch.addcdiv(mps_x, mps_y, mps_z, value=value)
        result_div_cpu = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value)
        self.assertEqual(result_div_mps, result_div_cpu)

        # test .out variant
        result_out_mps = torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value)
        self.assertEqual(result_out_mps, result_div_cpu)

    helper((2, 3, 4, 5), 0.1)
    helper((2, 8, 4, 5), 0.2)
    helper((2, 3, 4, 5), 1.0)  # value of 1 should be ignored internally
2511
def test_buffer_size_match(self):
    """Multiply a 1-D vector with a 3-D tensor on MPS and compare with the CPU result."""
    # this test shouldn't cause any crash
    size = 16
    cpu_A = torch.rand(size, device='cpu')
    cpu_F = torch.rand(size, size, size, device='cpu')

    mps_A, mps_F = cpu_A.to('mps'), cpu_F.to('mps')

    expected = torch.matmul(cpu_A, cpu_F)
    actual = torch.matmul(mps_A, mps_F)
    self.assertEqual(expected, actual)
2521
def test_transpose_inplace(self):
    """In-place transpose_ of a 3x3 tensor must give the same result on MPS and CPU."""
    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    for tensor in (cpu_x, mps_x):
        tensor.transpose_(0, 1)

    self.assertEqual(cpu_x, mps_x.to('cpu'))
2530
def test_expand_cpu_to_mps_copy(self):
    """Copy an expanded CPU scalar to MPS and check the values survive the round trip."""
    # https://github.com/pytorch/pytorch/issues/78642
    expanded_on_mps = torch.tensor(1).expand([10]).to("mps")
    expanded_on_cpu = torch.tensor(1).expand([10])

    self.assertEqual(expanded_on_cpu, expanded_on_mps.cpu())
2538
def test_cpu_to_strided_mps_copy(self):
    """Assigning into a strided MPS slice must give the same result from a CPU or MPS source."""
    # https://github.com/pytorch/pytorch/issues/86975
    mps_device = torch.device("mps")

    # Source values live on the CPU.
    from_cpu = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(mps_device)
    cpu_source = torch.Tensor([-1, -1])
    from_cpu[1:, 1] = cpu_source

    # Source values already live on MPS.
    from_mps = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(mps_device)
    mps_source = torch.Tensor([-1, -1]).to(mps_device)
    from_mps[1:, 1] = mps_source

    self.assertEqual(from_cpu, from_mps)
2551
def test_view_slice_reshape(self):
    """Add a scalar to a sliced view of an MPS tensor and compare with the CPU result."""
    x = torch.randn([1, 4, 4], device="mps")
    x_cpu = x.to("cpu")

    view_mps = x[0, :1, 1:]
    view_cpu = x_cpu[0, :1, 1:]

    self.assertEqual(view_mps + 1, view_cpu + 1)
2562
def test_slice_reshape(self):
    """Slice then view() an MPS tensor; compare the view and view + 2 with CPU."""
    source_mps = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
    source_cpu = source_mps.detach().clone().to("cpu")

    # Take the second half along dim 1 and reinterpret it with a new shape.
    reshaped_mps = source_mps[:, 3:].view(2, 3, 4, 1)
    reshaped_cpu = source_cpu[:, 3:].view(2, 3, 4, 1)
    self.assertEqual(reshaped_mps, reshaped_cpu)

    shifted_mps = reshaped_mps + 2
    shifted_cpu = reshaped_cpu + 2
    self.assertEqual(shifted_mps, shifted_cpu)
2574
def test_reshape_storage_offset(self):
    """Run Linear -> cat -> transpose -> mse_loss on MPS and CPU and compare loss and grads.

    The leaf inputs are kept under their own names so the final check
    compares real gradients. Reading ``.grad`` from the re-bound non-leaf
    tensors would compare ``None`` with ``None``.
    """
    # https://github.com/pytorch/pytorch/issues/95883
    B = 4
    T = 1

    lin_cpu = nn.Linear(10, 256)
    lin_mps = nn.Linear(10, 256, device="mps")

    # Use the same weights and bias as the ones from the cpu
    lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
    lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()

    # Leaf inputs: these are the tensors whose .grad is populated by backward().
    inp_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
    inp_cpu = inp_mps.detach().clone().cpu().requires_grad_()
    x_mps = lin_mps(inp_mps)
    x_cpu = lin_cpu(inp_cpu)

    self.assertEqual(x_mps.shape, (B, T, 256))
    self.assertEqual(x_cpu.shape, (B, T, 256))

    cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
    cls_token_cpu = cls_token_mps.detach().clone().cpu()
    x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
    x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)

    x_mps = x_mps.transpose(0, 1)
    x_cpu = x_cpu.transpose(0, 1)

    target_mps = torch.rand_like(x_mps)
    target_cpu = target_mps.detach().clone().cpu()
    loss_mps = F.mse_loss(x_mps, target_mps)
    loss_cpu = F.mse_loss(x_cpu, target_cpu)
    self.assertEqual(loss_mps, loss_cpu)

    loss_mps.backward()
    loss_cpu.backward()

    # Guard against a vacuous comparison before checking the values.
    self.assertIsNotNone(inp_cpu.grad)
    self.assertIsNotNone(inp_mps.grad)
    self.assertEqual(inp_mps.grad, inp_cpu.grad)
2612
def test_stack(self):
    """torch.stack of sliced views must give the same result on MPS and CPU."""
    # https://github.com/pytorch/pytorch/issues/87856
    row_cpu = torch.tensor([[1, 2]])
    row_mps = row_cpu.detach().clone().to("mps")

    # Stack the first and last column views of a 1x2 tensor.
    stacked_cpu = torch.stack((row_cpu[:, :1], row_cpu[:, -1:]), dim=-1)
    stacked_mps = torch.stack((row_mps[:, :1], row_mps[:, -1:]), dim=-1)
    self.assertEqual(stacked_cpu, stacked_mps)

    t_mps = torch.tensor([1, 2, 3, 4], device="mps")
    t_cpu = t_mps.detach().cpu()

    # Stack the two halves of a 1-D tensor; the second half is a view with an offset.
    upper_mps, lower_mps = t_mps[2:], t_mps[:2]
    upper_cpu, lower_cpu = t_cpu[2:], t_cpu[:2]

    res_mps = torch.stack((lower_mps, upper_mps), dim=-1)
    res_cpu = torch.stack((lower_cpu, upper_cpu), dim=-1)
    self.assertEqual(res_mps, res_cpu)
2636
def test_unsafe_chunk(self):
    """Multiply two chunks returned by unsafe_chunk on MPS and compare with CPU."""
    # https://github.com/pytorch/pytorch/issues/91065
    source_cpu = torch.rand(5, dtype=torch.float32, device="cpu")
    chunks_cpu = source_cpu.unsafe_chunk(4, 0)
    product_cpu = chunks_cpu[0] * chunks_cpu[2]

    source_mps = source_cpu.to("mps")
    chunks_mps = source_mps.unsafe_chunk(4, 0)
    product_mps = chunks_mps[0] * chunks_mps[2]

    self.assertEqual(product_cpu, product_mps)
2646
def test_slice_casting(self):
    """Cast a uint8 slice to bool on MPS and compare with the CPU result."""
    # generate random binary numbers
    probabilities = torch.empty(1, 1, 128, 128).uniform_(0, 1)
    cpu_in = torch.bernoulli(probabilities).to(torch.uint8)
    mps_in = cpu_in.detach().clone().to("mps")

    # check copy_cast(uint8 -> bool) on tensors with storage offset
    window = (slice(None), slice(None), slice(11, 12), slice(None, 12))
    cpu_out = cpu_in[window].to(torch.bool)
    mps_out = mps_in[window].to(torch.bool)
    self.assertEqual(cpu_out, mps_out)
2655
def test_slice_reshape_contg_view(self):
    """Add a scalar to a (1, 4800, 2) MPS tensor and compare with the CPU result."""
    x_mps = torch.randn(1, 4800, 2, device="mps")
    x_cpu = x_mps.detach().clone().cpu()

    self.assertEqual(x_mps + 2, x_cpu + 2)
2666
def test_contiguous_slice_2d(self):
    """Slice the first two dims of MPS tensors in several ways and compare ``slice + 1`` with CPU."""
    def helper(shape):
        for i, j in itertools.product(range(shape[0]), range(shape[1])):
            # A fresh random tensor is drawn for every (i, j) pair.
            t_mps = torch.randn(shape, device="mps")
            t_cpu = t_mps.detach().clone().cpu()

            index_exprs = (
                (slice(i, None), slice(None, j)),  # [i:, :j]
                (slice(i, None), j),               # [i:, j]
                (i, slice(None, j)),               # [i, :j]
                (slice(None, i), slice(None, j)),  # [:i, :j]
                (slice(None, i), j),               # [:i, j]
                (slice(None, i), slice(j, None)),  # [:i, j:]
            )
            for index in index_exprs:
                self.assertEqual(t_mps[index] + 1, t_cpu[index] + 1)

    # Every shape of rank 2 to 5 whose dims are each 1 or 2, visited prefix-first.
    dims = range(1, 3)
    for N in dims:
        for C in dims:
            helper([N, C])
            for D in dims:
                helper([N, C, D])
                for H in dims:
                    helper([N, C, D, H])
                    for W in dims:
                        helper([N, C, D, H, W])

    for shape in ([9, 15, 4], [9, 3, 2], [3, 4, 18, 22], [3, 4, 18, 22, 150]):
        helper(shape)
2723
def test_contiguous_slice_3d(self):
    """Multiply two 1x1x1 slices of a sliced 3-D MPS tensor and compare with CPU."""
    full_mps = torch.randn(2, 3, 3, device="mps")
    full_cpu = full_mps.detach().clone().cpu()

    # Keep only the first entry along dim 0, then slice again inside it.
    head_mps = full_mps[:1]
    head_cpu = full_cpu[:1]

    out = head_mps[:, 0:1, 0:1] * head_mps[:, 1:2, 1:2]
    out_cpu = head_cpu[:, 0:1, 0:1] * head_cpu[:, 1:2, 1:2]
    self.assertEqual(out, out_cpu)
2732
def test_view_slice(self):
    """Assign into column slices of an index tensor and gather through it on MPS and CPU.

    ``X_mps`` is placed on the MPS device so the gather on the MPS side
    indexes an MPS tensor. Creating it with ``.to("cpu")`` would leave
    both sides reading from CPU storage.
    """
    # https://github.com/pytorch/pytorch/issues/83995
    NUM_SAMPLES = 60
    s = (0, 1)

    X = torch.rand(8000, 3, dtype=torch.float32, device='cpu')
    X_mps = X.detach().clone().to("mps")

    # One random row index repeated for each column listed in `s`.
    idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s))
    pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1]))
    idx_mps = idx.to("mps")
    pts_mps = pts.to("mps")
    # Overwrite the selected columns with the repeated index on both devices.
    pts[:, s] = idx
    pts_mps[:, s] = idx_mps

    actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float)
    actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")

    for i in range(NUM_SAMPLES):
        for j in range(X.shape[1]):
            actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j]
            actual_pts[i, j] = X[pts[i, j], j]
            self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j])
2756
def test_slice_scatter(self):
    """Copying a tensor into a strided slice of another MPS tensor must not modify the source."""
    rows, cols = shape = (4, 4)
    tensor = torch.randint(10, shape, device="mps")
    tensor_before = tensor.clone()

    # The destination is every second column of a buffer twice as wide.
    destination = torch.empty(rows, cols * 2, device="mps")
    destination[:, ::2].copy_(tensor)

    torch.testing.assert_close(tensor, tensor_before)
Denis Vieriub71c7102022-12-08 17:59:55 +00002763
def test_slice(self):
    """Basic row/column slices of a 3x3 tensor must match between MPS and CPU."""
    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps', dtype=torch.float)

    # [:2, :] and [:, :1] are compared without an explicit device transfer.
    for index in ((slice(None, 2), slice(None)), (slice(None), slice(None, 1))):
        self.assertEqual(cpu_x[index], mps_x[index])

    # [1:2, :] and [1, :] are moved to the CPU before the comparison.
    for index in ((slice(1, 2), slice(None)), (1, slice(None))):
        self.assertEqual(cpu_x[index], mps_x[index].to('cpu'))
2784
def test_scalar_from_slice_unary(self):
    """Apply torch.ceil to 0-dim elements iterated out of an MPS tensor and compare with CPU."""
    # https://github.com/pytorch/pytorch/issues/82543
    for element in torch.tensor([1.0, 1.2], device="mps"):
        ceil_mps = torch.ceil(element)
        ceil_cpu = torch.ceil(element.to("cpu"))
        self.assertEqual(ceil_mps.cpu(), ceil_cpu)
2793
def test_scalar_from_slice_binary(self):
    """Apply binary ops to 0-dim elements iterated out of an MPS tensor and compare with CPU."""
    # https://github.com/pytorch/pytorch/issues/82543
    for binary_op in (torch.sub, torch.add, torch.not_equal, torch.eq):
        elements = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps")
        for element in elements:
            result_mps = binary_op(element, 1.0)
            result_cpu = binary_op(element.cpu(), 1.0)
            self.assertEqual(result_mps.cpu(), result_cpu)
2807
def test_slice_contiguous_view(self):
    """Apply comparison ops and stack to contiguous slice views on MPS and CPU.

    Each operator name maps to a single callable that is applied to both
    devices. The ">" entry performs a strict greater-than rather than
    repeating ">=".
    """
    # https://github.com/pytorch/pytorch/issues/77750

    # `a` is the upper half [3, 4] and `b` the lower half [1, 2] of the source tensor.
    operations = {
        "<=": lambda a, b: a <= b,
        "<": lambda a, b: a < b,
        ">=": lambda a, b: a >= b,
        ">": lambda a, b: a > b,
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "stack": lambda a, b: torch.stack((b, a), dim=-1),
    }

    def helper(op_name):
        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
        t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")

        # contiguous view
        x_mps = t_mps[2:]  # 3, 4
        y_mps = t_mps[:2]  # 1, 2

        x_cpu = t_cpu[2:]
        y_cpu = t_cpu[:2]

        apply_op = operations[op_name]
        res_mps = apply_op(x_mps, y_mps)
        res_cpu = apply_op(x_cpu, y_cpu)

        self.assertEqual(res_mps, res_cpu)

    for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
        helper(op)
2849
def test_slice_of_slice(self):
    """Index an element, re-add a dim with [None], and compare ne(0) between MPS and CPU."""
    base_cpu = torch.tensor([0.5, 0.5], device="cpu")
    base_mps = torch.tensor([0.5, 0.5], device="mps")

    # Second element, viewed as a one-element 1-D tensor.
    element_cpu = base_cpu[1][None]
    element_mps = base_mps[1][None]

    self.assertEqual(element_cpu.ne(0), element_mps.ne(0))
2861
def test_index_storage_offset(self):
    """Copy indexed CPU views with different storage offsets to MPS and compare them."""
    # https://github.com/pytorch/pytorch/issues/78107

    source = torch.tensor([8.2670e-01, -1.0293e+00])
    b_cpu, c_cpu = source[0], source[1]

    # Both 'b' and 'c' are views of the same tensor: 'b' has a storage offset
    # of 0 and 'c' a storage offset of 1. The copy from 'cpu' to 'mps' has to
    # take that offset into account, otherwise 'c' ends up with the same
    # value as 'b'.
    b = b_cpu.to('mps')
    c = c_cpu.to('mps')

    self.assertEqual(b > c, b_cpu > c_cpu)
    self.assertEqual(c > b, c_cpu > b_cpu)
2883
def test_flatten(self):
    """flatten() with default, start_dim and end_dim arguments must match between MPS and CPU."""
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    for flatten_kwargs in ({}, {"start_dim": 1}, {"end_dim": 1}):
        expected = cpu_x.flatten(**flatten_kwargs)
        actual = mps_x.flatten(**flatten_kwargs).to('cpu')
        self.assertEqual(expected, actual)
2900
2901 # Test repeat
def test_repeat(self):
    """Compare Tensor.repeat forward values and input gradients between MPS and CPU."""
    def helper(shape, repeats):
        reference_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        device_in = reference_in.detach().clone().to('mps').requires_grad_()

        device_out = device_in.repeat(repeats)
        reference_out = reference_in.repeat(repeats)

        # Random upstream gradient, shared by both devices.
        upstream = torch.randn(reference_out.shape)
        device_out.backward(gradient=upstream.to('mps'))
        reference_out.backward(gradient=upstream)

        self.assertEqual(device_out, reference_out)
        self.assertEqual(device_in.grad, reference_in.grad)

    cases = [
        ((2, 3, 4, 5), (2, 3, 4, 5)),
        ((2, 3, 4), (4, 3, 2, 5, 7, 2)),
        ((3, 4, 5), (2, 3, 4, 5)),
        ((3, 4, 5), (2, 2, 2)),
    ]
    for shape, repeats in cases:
        helper(shape, repeats)
2924
def test_torch_repeat_interleave(self, device="mps"):
    """Check result dtype and size of repeat_interleave with int and long repeat counts.

    Args:
        device: device on which every tensor in the test is created.
    """
    y = torch.tensor([[1, 2], [3, 4]], device=device)
    # exercise single argument function signature
    temp = y.repeat_interleave(2)
    self.assertEqual(torch.Size([8]), temp.size())

    for dtype in [torch.int, torch.long]:
        # Create the repeat counts on `device` so they live next to `y`.
        lengths = torch.tensor([1, 2], dtype=dtype, device=device)
        output_size = torch.sum(lengths)
        a = torch.repeat_interleave(
            y,
            lengths,
            dim=0,
        )
        self.assertEqual(a.dtype, y.dtype)
        self.assertEqual(a.size(), torch.Size([3, 2]))

        # Same call with the total output length supplied up front.
        a_with_output = torch.repeat_interleave(
            y,
            lengths,
            dim=0,
            output_size=output_size,
        )
        self.assertEqual(a_with_output.dtype, y.dtype)
        self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
2950
def test_repeat_interleave(self, device="mps"):
    """Check repeat_interleave results, error cases and zero-sized inputs on ``device``."""
    counts = torch.tensor([0, 1, 2, 3], device=device)
    expected = torch.tensor([1, 2, 2, 3, 3, 3], dtype=torch.int32, device=device)
    self.assertEqual(torch.repeat_interleave(counts), expected)

    # Each of these single-argument inputs is expected to raise RuntimeError:
    # a 2-D tensor, a floating-point tensor, and a tensor with a negative entry.
    bad_input_factories = (
        lambda: torch.arange(4, device=device).reshape(2, 2),
        lambda: torch.arange(4.0, device=device),
        lambda: torch.tensor([1, 2, -1, 3, 4], device=device),
    )
    for make_bad_input in bad_input_factories:
        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(make_bad_input())

    y = torch.tensor([[1, 2], [3, 4]], device=device)

    # The same repeat count given as an int, a 0-dim tensor and a 1-element tensor.
    y1_v1 = torch.repeat_interleave(y, 2)
    y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device))
    y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device))
    y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
    for y1 in (y1_v1, y1_v2, y1_v3):
        self.assertEqual(y1, y1_expect)

    y2 = torch.repeat_interleave(y, 3, dim=1)
    y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
                              [3, 3, 3, 4, 4, 4]], device=device)
    self.assertEqual(y2, y2_expect)

    y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
    y3_expect = torch.tensor([[1, 2],
                              [3, 4],
                              [3, 4]], device=device)
    self.assertEqual(y3, y3_expect)

    # Repeat tensors whose shape does not fit dim 0 of `y` must be rejected.
    with self.assertRaises(RuntimeError):
        torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0)

    with self.assertRaises(RuntimeError):
        torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0)

    # test zero sized dimension
    empty_2d = torch.zeros((5, 0), device=device)
    repeated_2d = torch.repeat_interleave(empty_2d, repeats=3, dim=1)
    self.assertEqual(repeated_2d, empty_2d.new_zeros(5, 0, device=device))

    empty_1d = torch.tensor([], dtype=torch.int64, device=device)
    repeated_1d = torch.repeat_interleave(empty_1d, empty_1d)
    self.assertEqual(repeated_1d, empty_1d)
3000
def test_repeat_interleave_simple(self):
    """Compare repeat_interleave with tensor repeat counts between MPS and CPU."""
    def helper(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None):
        values_mps = torch.randn(shape, dtype=dtype, device="mps")
        values_cpu = values_mps.detach().clone().cpu()

        # The repeat counts arrive on MPS; mirror them on the CPU for the reference run.
        counts_cpu = num_repeats.detach().clone().cpu()

        result_mps = torch.repeat_interleave(values_mps, num_repeats, dim)
        result_cpu = torch.repeat_interleave(values_cpu, counts_cpu, dim)

        self.assertEqual(result_mps, result_cpu)

    helper(shape=3, num_repeats=torch.tensor([100], device="mps"))
    helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
    helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
    helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
    helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)
3017
def test_count_nonzero(self):
    """Compare torch.count_nonzero between MPS and CPU over several dtypes and dims."""
    data = [
        [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
        [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
    ]

    for dtype in (torch.int32, torch.int64, torch.float16, torch.float32):
        cpu_x = torch.tensor(data, dtype=dtype)
        mps_x = torch.tensor(data, dtype=dtype).to('mps')

        # dim=None counts over the whole tensor, then dim=1, then dim=(0, 1).
        for dim in (None, 1, (0, 1)):
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=dim),
                torch.count_nonzero(mps_x, dim=dim)
            )
3048
def _test_module_empty_input(self, module, inp, check_size=True):
    """Run ``module`` forward and backward on ``inp`` and assert all gradients are zero.

    Args:
        module: module to call; every parameter with ``requires_grad`` is checked.
        inp: input tensor; gradient tracking is switched on in place.
        check_size: when true, also assert the output size equals the input size.
    """
    inp.requires_grad_(True)
    out = module(inp)
    # Back-propagate a random upstream gradient shaped like the output.
    out.backward(torch.rand_like(out))

    if check_size:
        self.assertEqual(out.size(), inp.size())

    trainable = (p for p in module.parameters() if p.requires_grad)
    for param in trainable:
        self.assertEqual(param.grad, torch.zeros_like(param.grad))

    self.assertEqual(inp.grad, torch.zeros_like(inp))
3060
Lukas Hoeniga52bfe22022-05-24 20:09:45 +00003061 # Test dtype casting, with and without simultaneous device change
def test_to(self):
    """Check dtype casts on MPS, with and without a simultaneous device change.

    The final section casts both an unsigned (uint8) and a signed (int8)
    byte tensor to float32 and compares the MPS result with the CPU one.
    """
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    self.assertEqual(cpu_x.int(), mps_x.int().cpu())
    self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
    self.assertEqual(cpu_x.float(), mps_x.float().cpu())

    # Scalar casts performed on the MPS device.
    self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
                     torch.tensor(1, dtype=torch.int32))
    self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
                     torch.tensor(1, dtype=torch.int32))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
                     torch.tensor(1.0))

    # Casts combined with a device transfer in a single .to() call.
    self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
                     torch.tensor(4, dtype=torch.int32))
    self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
                     torch.tensor(4, dtype=torch.int32))
    self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                     torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))

    # Cast int8 and uint8 to float and compare results
    # See https://github.com/pytorch/pytorch/issues/80009 for more details
    cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8)
    # Signed values need the signed 8-bit dtype; uint8 cannot hold -60 / -120.
    cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.int8)
    for x_cpu in [cpu_byte, cpu_char]:
        x_mps = x_cpu.to('mps')
        self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32))
3092
Lukas Hoeniga52bfe22022-05-24 20:09:45 +00003093
# Write Python scalars into single elements of zero-filled MPS tensors and read them back.
def test_setitem_scalar(self) -> None:
    device = 'mps'
    for dtype in [torch.int32, torch.float32, torch.int64]:
        # Every (rows, cols) combination with both sizes in 3..5.
        for rows, cols in itertools.product(range(3, 6), repeat=2):
            t = torch.zeros(rows, cols, dtype=dtype, device=device)
            self.assertEqual(t.sum(), 0)

            t[1, 1] = 1
            t[2, 1] = cols
            t[1, 2] = rows

            written = (((1, 1), 1), ((1, 2), rows), ((2, 1), cols))
            for index, value in written:
                self.assertEqual(t[index], value)
            # Only the three written elements contribute to the total.
            self.assertEqual(t.sum(), 1 + rows + cols)
Nikita Shulga437ecfc2022-05-27 20:46:53 +00003108
# Copying a re-strided view of an already strided MPS tensor to the CPU used to crash
# with a "buffer is not large enough." assert.
# See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
def test_stride_of_strides(self) -> None:
    base = torch.rand(32, 1, device='mps')
    first_view = base.as_strided(size=(32, 2), stride=(1, 0))
    restrided_on_cpu = first_view.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
    # Reference: apply the final striding to a CPU copy of the base tensor.
    expected = base.to("cpu").as_strided(size=(32, 3), stride=(1, 0))
    self.assertEqual(expected, restrided_on_cpu)
3116
# Tensor.type() with legacy tensor types must give the same result on MPS as on CPU.
# https://github.com/pytorch/pytorch/issues/81567
def test_type_casting(self):
    data = [9.0, 3.0, 5.0, 4.0]
    tensor_types = (
        torch.LongTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
    )
    for tensor_type in tensor_types:
        a_cpu = torch.tensor(data)
        a_mps = a_cpu.to(torch.device('mps'))

        res_cpu = a_cpu.type(tensor_type)
        res_mps = a_mps.type(tensor_type)
        self.assertEqual(res_cpu, res_mps)
3134
# Tensor.to(dtype) must give the same result on MPS as on CPU.
# https://github.com/pytorch/pytorch/issues/81567
def test_to_casting(self):
    data = [9.0, 3.0, 5.0, 4.0]
    target_dtypes = (
        torch.int64,
        torch.float,
        torch.int32,
        torch.short,
        torch.half,
        torch.int8,
        torch.uint8,
    )
    for target_dtype in target_dtypes:
        a_cpu = torch.tensor(data)
        a_mps = a_cpu.to(torch.device('mps'))

        res_cpu = a_cpu.to(target_dtype)
        res_mps = a_mps.to(target_dtype)
        self.assertEqual(res_cpu, res_mps)
3152
# Copy contiguous slices of one large CPU tensor to MPS and back; each slice is a view
# whose storage offset grows with its index.
# https://github.com/pytorch/pytorch/issues/80844
def test_storage_offset_greater_than_src_nbytes(self):
    n_tensors = 100
    n_tensor_elems = 784
    elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)

    # create a list of contiguous view tensors (view tensor created by the slice op).
    # All n_tensors slices are covered, including the last one, which has the
    # largest storage offset.
    tensor_list = [
        elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
        for i in range(n_tensors)
    ]

    for i, view in enumerate(tensor_list):
        t = view.view(1, n_tensor_elems)
        t_mps = t.to("mps")
        self.assertEqual(t, t_mps.cpu(), f"i={i}")
Kulin Sethe011a8e2022-05-13 18:28:53 +00003169
# See https://github.com/pytorch/pytorch/issues/82427
# and https://github.com/pytorch/pytorch/issues/83692
def test_full_bugs(self):
    # Test should not crash: building a boolean tensor only has to complete.
    torch.full((3, 3), True, device='mps')

    # torch.full should work for uint8
    shape, fill_value = (2, 2), 247
    filled_mps = torch.full(shape, fill_value, device='mps', dtype=torch.uint8)
    filled_cpu = torch.full(shape, fill_value, device='cpu', dtype=torch.uint8)
    self.assertEqual(filled_mps, filled_cpu)
Nikita Shulgabdd0a4a2022-08-01 19:42:24 +00003179
# See https://github.com/pytorch/pytorch/issues/84995
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_div_bugs(self):
    rounding_modes = ['trunc', 'floor']
    for dtype, mode in itertools.product(integral_types(), rounding_modes):
        # int64 inputs are left out of this check.
        if dtype == torch.int64:
            continue
        numerators = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
        quotients = torch.div(numerators, 101, rounding_mode=mode)
        # Each of 1..10 divided by 101 rounds to 0 in either mode, so the sum is 0.
        self.assertEqual(quotients.sum(), 0)
Nikita Shulga1a6cf6e2022-09-14 23:40:20 +00003188
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):
    column = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
    row = torch.tensor([0, 1], dtype=torch.bool, device='mps')
    # Expanded to 2x2 the two tensors hold different values, so they must not compare equal.
    column_2x2 = column.expand(2, 2)
    row_2x2 = row.expand(2, 2)
    self.assertFalse(torch.equal(column_2x2, row_2x2))
Nikita Shulgadcf51882022-08-03 14:54:47 +00003194
# Empty unary op should return tensor of the same size
def test_empty_neg(self):
    empty = torch.tensor([[]], device='mps')
    negated = -empty
    self.assertEqual(empty, negated)
3200
# Check `f` (called with return_inverse=True and return_counts=True) on a 0-dim tensor
# and on a tensor that has no elements.
def _test_unique_scalar_empty(self, dtype, device, f):
    def check(sample, expected_outputs):
        unique, inverse, counts = f(sample, return_inverse=True, return_counts=True)
        actual_outputs = (unique, inverse, counts)
        for actual, expected in zip(actual_outputs, expected_outputs):
            self.assertEqual(actual, expected)

    # test scalar
    check(
        torch.tensor(0, dtype=dtype, device=device),
        (
            torch.tensor([0], dtype=dtype, device=device),
            torch.tensor(0, device=device),
            torch.tensor([1], device=device),
        ),
    )

    # test zero sized tensor
    check(
        torch.zeros((0, 0, 3), dtype=dtype, device=device),
        (
            torch.tensor([], dtype=dtype, device=device),
            torch.empty((0, 0, 3), dtype=torch.long, device=device),
            torch.tensor([], dtype=torch.long, device=device),
        ),
    )
3221
# Call `f` on `x` with every combination of return_inverse / return_counts and compare
# each returned piece with the expected tensors; then repeat the check on `x` viewed
# as `additional_shape`.
def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
    def as_tuple(result):
        # A bare tensor is wrapped so the result can always be indexed.
        return (result,) if isinstance(result, torch.Tensor) else result

    for return_inverse, return_counts in itertools.product([True, False], repeat=2):
        # test with expected
        outputs = as_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
        self.assertEqual(len(outputs), 1 + int(return_inverse) + int(return_counts))
        self.assertEqual(expected_unique, outputs[0])
        if return_inverse:
            self.assertEqual(expected_inverse, outputs[1])
        if return_counts:
            # Counts follow the inverse indices when both are requested.
            self.assertEqual(expected_counts, outputs[1 + int(return_inverse)])

        # tests per-element unique on a higher rank tensor.
        reshaped = x.view(additional_shape)
        r_unique, r_inverse, r_counts = f(reshaped, return_inverse=True, return_counts=True)
        self.assertEqual(expected_unique, r_unique)
        self.assertEqual(expected_inverse.view(additional_shape), r_inverse)
        self.assertEqual(expected_counts, r_counts)
3246
# Exercise sorted and unsorted unique, in function and method form, for several dtypes.
def test_unique_all_dtypes(self, device="mps"):
    def as_tuple(result):
        return (result,) if isinstance(result, torch.Tensor) else result

    def unique_variants(is_sorted):
        # Function form and method form of the same call.
        return (
            lambda x, **kwargs: torch.unique(x, sorted=is_sorted, **kwargs),
            lambda x, **kwargs: x.unique(sorted=is_sorted, **kwargs),
        )

    def check_dtype(dtype):
        if dtype is torch.bool:
            x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
            expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
            expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
            expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
        else:
            x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
            expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
            expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
            expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)

        # Same data copied into every second element of a buffer twice as long.
        x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x)
        samples = (x, x_sliced)

        # test sorted unique
        for f, sample in product(unique_variants(True), samples):
            self._test_unique_with_expects(device, dtype, f, sample, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
            self._test_unique_scalar_empty(dtype, device, f)

        # test unsorted unique
        for f, sample in product(unique_variants(False), samples):
            self._test_unique_scalar_empty(dtype, device, f)
            for return_inverse, return_counts in product((True, False), repeat=2):
                ret = as_tuple(f(sample, return_inverse=return_inverse, return_counts=return_counts))
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                sample_list = sample.tolist()
                unique_list = ret[0].tolist()
                # Output order is unspecified, so compare after sorting.
                self.assertEqual(expected_unique.tolist(), sorted(unique_list))
                if return_inverse:
                    # Every inverse index must map back to the original element.
                    for position, unique_index in enumerate(ret[1].tolist()):
                        self.assertEqual(sample_list[position], unique_list[unique_index])
                if return_counts:
                    counts_list = ret[1 + int(return_inverse)].tolist()
                    for value, reported in zip(unique_list, counts_list):
                        self.assertEqual(reported, sample_list.count(value))

    for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]:
        check_dtype(dtype)
3303
# torch.unique on MPS must match the CPU result for every flag combination.
def test_unique(self):
    def compare(cpu_x, return_inverse, return_counts):
        mps_x = cpu_x.detach().clone().to('mps')
        flags = dict(return_inverse=return_inverse, return_counts=return_counts)

        result = torch.unique(mps_x, **flags)
        result_cpu = torch.unique(cpu_x, **flags)

        self.assertEqual(result, result_cpu)

    compare(torch.tensor([1, 2, 4, 2, 1]), False, False)
    # A fresh random input for each flag combination.
    for flags in [(False, False), (True, False), (False, True), (True, True)]:
        compare(torch.randint(3, (10, )), *flags)
    # Single-element and empty inputs.
    for length in (1, 0):
        compare(torch.randint(3, (length, )), True, True)
3320
# torch.unique_consecutive on MPS must match the CPU result for 1-D and 2-D inputs.
def test_unique_consecutive(self):
    def compare(cpu_x, dim, return_inverse, return_counts):
        mps_x = cpu_x.detach().clone().to('mps')
        options = dict(dim=dim, return_inverse=return_inverse, return_counts=return_counts)

        result = torch.unique_consecutive(mps_x, **options)
        result_cpu = torch.unique_consecutive(cpu_x, **options)

        self.assertEqual(result, result_cpu)

    # 1-D inputs; the (True, True) combination is run on two separate random draws.
    compare(torch.tensor([1, 2, 4, 2, 1]), 0, False, False)
    flag_sequence = [(False, False), (True, False), (False, True), (True, True), (True, True)]
    for flags in flag_sequence:
        compare(torch.randint(3, (10, )), 0, *flags)
    for length in (1, 0):
        compare(torch.randint(3, (length, )), 0, True, True)

    # 2-D inputs, reduced along dim 0 and then along dim 1.
    rows = [[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]
    for dim in (0, 1):
        compare(torch.tensor(rows), dim, False, False)
        compare(torch.tensor(rows), dim, True, True)
        for extent in (20, 1, 0):
            # The varying extent always lies on the axis being reduced.
            shape = (extent, 2) if dim == 0 else (2, extent)
            compare(torch.randint(2, shape), dim, True, True)
3350
# See https://github.com/pytorch/pytorch/issues/85675
def test_cat_non_contiguous(self):
    def split_and_concat(tensor, dim):
        # Split along the third axis; both pieces are asserted to be non-contiguous.
        head = tensor[:, :, :2, :]
        tail = tensor[:, :, 2:, :]
        for piece in (head, tail):
            self.assertFalse(piece.is_contiguous())
        return torch.concat((head, tail), dim=dim)

    for dtype in (candidate for candidate in MPS_DTYPES if candidate != torch.bool):
        cpu_data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
        cpu_data = cpu_data.to(memory_format=torch.channels_last)
        mps_data = cpu_data.to("mps")
        self.assertEqual(cpu_data, mps_data)
        for dim in range(cpu_data.dim()):
            cpu_result = split_and_concat(cpu_data, dim)
            mps_result = split_and_concat(mps_data, dim)
            self.assertEqual(cpu_result, mps_result.to("cpu"))
            # TODO: enable memory format test
            # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
Nikita Shulga1367f242022-09-27 15:44:53 +00003372
# See https://github.com/pytorch/pytorch/issues/85967
def test_from_numpy_non_contiguous(self):
    # The first two columns of a 3x3 array.
    strided = np.arange(9).reshape(3, 3)[:, :2]
    on_cpu, on_mps = (torch.tensor(strided, device=device) for device in ("cpu", "mps"))
    self.assertEqual(on_cpu, on_mps.to("cpu"))
3379
# See https://github.com/pytorch/pytorch/issues/86954
def test_copy_non_contiguous(self):
    # A permuted tensor must stay non-contiguous after the move and round-trip unchanged.
    permuted = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
    self.assertFalse(permuted.is_contiguous())
    permuted_mps = permuted.to('mps')
    self.assertFalse(permuted_mps.is_contiguous())
    self.assertEqual(permuted, permuted_mps.to('cpu'))

    # Permuted and sliced source.
    sliced = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
    sliced_mps = sliced.to('mps')
    self.assertEqual(sliced, sliced_mps.to('cpu'))

    # Assign a strided CPU source into strided destinations on both devices.
    dest_cpu = torch.full((4, 4, 4, 4), 13, device="cpu")
    dest_mps = torch.full((4, 4, 4, 4), 13, device="mps")
    source = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
    dest_cpu.permute(3, 2, 1, 0)[1::, ::2] = source
    # As dest_mps is on MPS and source on CPU, this dispatches to a copy operator
    dest_mps.permute(3, 2, 1, 0)[1::, ::2] = source
    self.assertEqual(dest_cpu, dest_mps.to('cpu'))
3399
# See https://github.com/pytorch/pytorch/issues/95417
def test_copy_storage_offset(self):
    devices = ("cpu", "mps")
    targets = {dev: torch.zeros(5, device=dev, dtype=torch.float32) for dev in devices}
    updates = {dev: torch.tensor([1, 1], device=dev, dtype=torch.int64) for dev in devices}
    for dev in devices:
        # implicit type casting and copy (int64 source into a float32 slice)
        targets[dev][2:4] = updates[dev]
    self.assertEqual(targets["cpu"], targets["mps"])
3409
# See https://github.com/pytorch/pytorch/pull/84742
# and https://github.com/pytorch/pytorch/pull/78319
def test_binops_dtype_precedence(self):
    # Test dtype precedence (casting order) in binary operations by comparing to CPU result
    # Example values for all dtypes supported on the MPS backend
    sample_vals = {
        torch.bool: [False, True],
        torch.int16: [-15, 0, 1, 10],
        torch.int32: [-376, 0, 1, 13],
        torch.int64: [-8, 0, 1, 77],
        torch.float16: [-234.5, 0.0, 1.0, 2.0],
        torch.float32: [-1.0, 0.0, 0.1, 111.99],
    }
    full_shape = (10,)

    # Three ways of turning one sample value into an operand.
    def as_scalar(value, dtype, device):
        return torch.tensor(value, dtype=dtype, device=device)

    def as_vector(value, dtype, device):
        return torch.tensor([value], dtype=dtype, device=device)

    def as_full(value, dtype, device):
        return torch.full(full_shape, value, dtype=dtype, device=device)

    # (left operand form, right operand form), compared in this order for each value pair.
    operand_forms = (
        (as_scalar, as_scalar),
        (as_vector, as_vector),
        (as_scalar, as_vector),
        (as_vector, as_scalar),
        # Test tensors created with torch.full
        (as_full, as_scalar),
        (as_scalar, as_full),
        (as_scalar, as_full),  # this combination is compared a second time
    )

    # Test all combinations of dtypes, operations, dimensionality
    binops = ['add', 'sub', 'mul', 'div']
    for dtype1, dtype2, binop in itertools.product(sample_vals, sample_vals, binops):
        # Subtraction is skipped whenever either operand is bool.
        if binop == 'sub' and torch.bool in (dtype1, dtype2):
            continue
        for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
            for make_lhs, make_rhs in operand_forms:
                mps_result, cpu_result = (
                    getattr(make_lhs(val1, dtype1, device), binop)(make_rhs(val2, dtype2, device))
                    for device in ('mps', 'cpu')
                )
                self.assertEqual(mps_result, cpu_result)
Nikita Shulgaae62cf72022-10-21 14:10:05 +00003472
# torch.nansum on MPS (plain and out= variants) must match the CPU results.
def test_nansum(self):
    def check(dtype, noncontiguous, dim):
        zero_cpu = torch.zeros((), dtype=dtype)

        # Randomly scale the values
        scale = random.randint(10, 100)
        x_cpu: torch.Tensor = make_tensor(
            (5, 5), dtype=dtype, device='cpu',
            low=-scale, high=scale, noncontiguous=noncontiguous)

        # For floating dtypes, entries below 0.2 * scale become NaN in the input,
        # while the reference tensor holds zeros in those positions.
        x_no_nan_cpu = x_cpu
        if dtype.is_floating_point:
            nan_mask_cpu = x_cpu < (0.2 * scale)
            x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
            x_cpu[nan_mask_cpu] = np.nan

        x_mps = x_cpu.to('mps')
        dim_kwargs = {} if dim is None else {"dim": dim}
        expect = torch.sum(x_no_nan_cpu, **dim_kwargs)

        # Sanity check on CPU
        self.assertEqual(expect, torch.nansum(x_cpu, **dim_kwargs))

        # Test MPS
        actual_mps = torch.nansum(x_mps, **dim_kwargs)

        # Test out= variant
        actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
        expect_out_cpu = torch.empty(0, dtype=dtype)
        torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
        torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)

        self.assertEqual(expect, actual_mps)
        self.assertEqual(expect_out_cpu, actual_out_mps)

    dtypes = (torch.float16, torch.float32, torch.int32, torch.int64)
    layouts = (True, False)  # noncontiguous
    dims = (0, 1, None)

    for dtype, noncontiguous, dim in itertools.product(dtypes, layouts, dims):
        with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
            check(dtype, noncontiguous, dim)
3517
# cumsum with an explicit dtype must match between MPS and CPU.
def test_cumsum_all_dtypes(self):
    def check(dtype):
        ones = [1, 1, 1, 1]
        t_mps = torch.tensor(ones, device="mps", dtype=dtype)
        t_cpu = torch.tensor(ones, device="cpu")

        summed_mps = t_mps.cumsum(0, dtype=dtype)
        summed_cpu = t_cpu.cumsum(0, dtype=dtype)

        self.assertEqual(summed_mps.cpu(), summed_cpu)

    for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]:
        check(dtype)

    # int64 either passes or must fail with exactly this message.
    try:
        check(torch.int64)
    except Exception as e:
        self.assertEqual(str(e), "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3")
Denis Vieriu92d8c4b2023-02-10 17:40:29 +00003534
# cumsum along axis -1 must match between MPS and CPU.
def test_cumsum_minus_one_axis(self):
    def check(dtype):
        # Test with axis -1
        if dtype == torch.float32:
            cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
        else:
            # NOTE(review): this branch also creates float32 data, so `dtype` only
            # selects between randn and randint inputs - confirm whether
            # dtype=dtype was intended here.
            cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
        mps_x = cpu_x.detach().clone().to('mps')

        expected = cpu_x.cumsum(-1)
        actual = mps_x.cumsum(-1)

        self.assertEqual(actual, expected)

    for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]:
        check(dtype)
Nikita Shulgabdd0a4a2022-08-01 19:42:24 +00003551
# torch.median over a whole int16 tensor must match between MPS and CPU.
def test_median_int16(self):
    shape = (2, 8, 4, 5)
    cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=torch.int16)
    mps_x = cpu_x.detach().clone().to('mps')

    median_mps = torch.median(mps_x)
    median_cpu = torch.median(cpu_x)
    self.assertEqual(median_mps, median_cpu)
3562
# Compares logical ops and min/max reductions on MPS against the CPU results.
class TestLogical(TestCaseMPS):
    # Build a tensor from Python data with the given device, dtype and requires_grad.
    def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
        return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)

    # Operand pairs shared by the binary logical tests: integer, float (first operand
    # requiring grad) and bool vectors, then a vector against each 0-dim scalar.
    def _binary_operand_pairs(self):
        wrap = self._wrap_tensor
        pairs = [
            (wrap([1, 1, 0, 0]), wrap([1, 0, 0, 1])),
            (
                wrap([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
                wrap([1, 0, 0, 1], dtype=torch.float),
            ),
            (wrap([True, True, False, False]), wrap([True, False, False, True])),
        ]
        for scalar in (1, 0, True, False):
            pairs.append((wrap((1, 0, 1, 0)), wrap(scalar)))
        return pairs

    # Apply `op` to each operand pair on MPS and on CPU and compare the results.
    def _check_binary_op(self, op):
        for cpu_x, cpu_other in self._binary_operand_pairs():
            mps_x = cpu_x.detach().clone().to('mps')
            mps_other = cpu_other.detach().clone().to('mps')

            result = op(mps_x, mps_other)
            result_cpu = op(cpu_x, cpu_other)

            self.assertEqual(result, result_cpu)

    def test_logical_not(self):
        wrap = self._wrap_tensor
        inputs = [
            wrap([1, 1, 0, 0]),
            wrap([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
            wrap([True, True, False, False]),
            wrap(1),
            wrap(0),
            wrap(True),
            wrap(False),
        ]
        for cpu_x in inputs:
            mps_x = cpu_x.detach().clone().to('mps')

            result = torch.logical_not(mps_x)
            result_cpu = torch.logical_not(cpu_x)

            self.assertEqual(result, result_cpu)

    def test_logical_and(self):
        self._check_binary_op(torch.logical_and)

    def test_logical_or(self):
        self._check_binary_op(torch.logical_or)

    def test_logical_xor(self):
        self._check_binary_op(torch.logical_xor)

    # max() and min() over whole random tensors, ten draws per dtype.
    def test_min_max(self):
        floating = (torch.float32, torch.float16)
        dtypes = [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]
        for dtype in dtypes:
            for _ in range(10):
                if dtype in floating:
                    x = torch.randn((30, 15), device='mps', dtype=dtype)
                else:
                    x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
                x_cpu = x.to("cpu")

                for reduction in ("max", "min"):
                    on_mps = getattr(x, reduction)()
                    on_cpu = getattr(x_cpu, reduction)()
                    self.assertEqual(on_mps, on_cpu)
qqaatw5943aaa2022-06-29 02:44:35 +00003674
# Compares F.smooth_l1_loss (and optionally its input gradient) between CPU and MPS.
class TestSmoothL1Loss(TestCaseMPS):

    # Compute the loss on random (4, 7) inputs on both devices and compare; when
    # `requires_grad` is true, also backpropagate and compare the input gradients.
    # Returns the (cpu, mps) loss tensors.
    def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
        # CPU
        input_cpu = torch.randn(4, 7, requires_grad=requires_grad)
        target_cpu = torch.randn(4, 7)

        # MPS (the MPS input always has requires_grad enabled)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        loss_kwargs = dict(beta=1.0, reduction=reduction)
        loss_cpu = F.smooth_l1_loss(input_cpu, target_cpu, **loss_kwargs)
        loss_mps = F.smooth_l1_loss(input_mps, target_mps, **loss_kwargs)

        self.assertEqual(loss_cpu, loss_mps)

        if requires_grad:
            for loss in (loss_cpu, loss_mps):
                loss.backward()
            self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))

        return loss_cpu, loss_mps

    def test_smooth_l1_loss_reduction_none(self):
        self._smooth_l1_loss_helper(reduction="none")

    def test_smooth_l1_loss_reduction_mean(self):
        self._smooth_l1_loss_helper(reduction="mean")

    def test_smooth_l1_loss_reduction_sum(self):
        self._smooth_l1_loss_helper(reduction="sum")

    def test_smooth_l1_loss_reduction_mean_backward(self):
        self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)

    def test_smooth_l1_loss_reduction_mean_sum_backward(self):
        self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)
3712
3713
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00003714class TestNLLLoss(TestCaseMPS):
# nll_loss must raise ValueError when the target's batch size differs from the input's.
def test_nll_loss_mismatched_batch(self, device='mps'):
    log_probs = torch.randn((10, 3), requires_grad=True, device=device)
    # the target should have size (10,)
    short_target = torch.zeros((3,), dtype=torch.int64, device=device)
    with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
        F.nll_loss(log_probs, short_target)
3721
# A target equal to ignore_index (255, outside the 3 classes) must give the same
# losses on MPS as on CPU.
def test_nll_loss_out_of_bounds_ignore_index(self):

    def losses_on(device):
        rows = [[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]] * 2
        x = torch.tensor(rows, device=device)
        t = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
        return [
            F.nll_loss(x, t, ignore_index=255, reduction=reduction)
            for reduction in ['mean', 'none']
        ]

    output_cpu = losses_on('cpu')
    output_mps = losses_on('mps')

    for cpu, mps in zip(output_cpu, output_mps):
        self.assertEqual(cpu, mps.to('cpu'))
3738
# A 2-D target for a 2-D input must be rejected on both CPU and MPS.
def test_nll_loss_invalid_target_dim(self):
    rows = [[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]] * 2
    for device in ('cpu', 'mps'):
        x = torch.tensor(rows, device=device)
        bad_target = torch.zeros((6, 2), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
            F.nll_loss(x, bad_target)
3751
# Weight tensors that are not a 1-D tensor of length 3 must be rejected on both devices.
def test_nll_loss_invalid_weights(self):
    rows = [[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]] * 2
    msg = "weight tensor should be defined either for all 3 classes or no classes"

    for device in ('cpu', 'mps'):
        x = torch.tensor(rows, device=device)
        t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
        # One weight tensor with too many entries, one with an extra dimension.
        invalid_weights = (
            torch.zeros(4, device=device),
            torch.zeros((1, 3), device=device),
        )
        for weight in invalid_weights:
            with self.assertRaisesRegex(RuntimeError, msg):
                F.nll_loss(x, t, weight=weight)
3769
# Compare weighted nll_loss and its input gradient between CPU and MPS for a random
# input of `input_size`; dimension 1 of `input_size` is the number of classes.
# `expected` is accepted but not used by this helper.
def _nll_loss_helper(self, input_size, reduction, expected):

    # CPU
    input = torch.rand(input_size, requires_grad=True, device='cpu')
    num_channels = input_size[1]
    # The target drops the class dimension from the input shape.
    target_size = (input_size[0], ) + tuple(input_size[2:])
    target = torch.randint(num_channels, target_size, device='cpu')
    weights = torch.randn(num_channels)

    # MPS
    input_mps = input.detach().clone().to('mps').requires_grad_()
    target_mps = target.detach().clone().to('mps')
    weights_mps = weights.to("mps")

    output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
    output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
    self.assertEqual(output_cpu, output_mps.to('cpu'))

    # Summing first gives a scalar to backpropagate from for any reduction.
    for output in (output_cpu, output_mps):
        output.sum().backward()
    self.assertEqual(input.grad, input_mps.grad.to('cpu'))
3791
def _nll_loss_1d_helper(self, input_size, reduction):
    # Compares unweighted F.nll_loss between CPU and MPS for an input of shape
    # `input_size` and a 0-dim target; input_size[0] is the number of classes.
    n_classes = input_size[0]

    # CPU reference tensors
    ref_input = torch.rand(input_size, requires_grad=True, device='cpu')
    ref_target = torch.randint(n_classes, [], device='cpu')

    # MPS copies of the same data
    dev_input = ref_input.detach().clone().to('mps').requires_grad_()
    dev_target = ref_target.detach().clone().to('mps')

    ref_loss = F.nll_loss(ref_input, ref_target, reduction=reduction)
    dev_loss = F.nll_loss(dev_input, dev_target, reduction=reduction)
    self.assertEqual(ref_loss, dev_loss.to('cpu'))

    ref_loss.sum().backward()
    dev_loss.sum().backward()
    self.assertEqual(ref_input.grad, dev_input.grad.to('cpu'))
3810
def test_as_strided(self):
    # torch.as_strided views, and arithmetic on them, must match between CPU
    # and MPS, both without and with a storage offset.
    view_size, view_stride = (2, 2), (1, 2)

    ref = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], device='cpu')
    ones_mps = torch.tensor([[1.0, 1.0], [1.0, 1.0]], device='mps')
    dev = ref.detach().clone().to('mps').requires_grad_()

    ref_view = torch.as_strided(ref, view_size, view_stride)
    dev_view = torch.as_strided(dev, view_size, view_stride)
    self.assertEqual(dev_view, ref_view)

    ref_sum = ref_view + ones_mps.to('cpu')
    dev_sum = dev_view + ones_mps
    self.assertEqual(ref_sum, dev_sum)

    # test with storage offsets
    ref = torch.rand(3, 3, device='cpu')
    dev = ref.to('mps')
    ref_diff = (torch.as_strided(ref, view_size, view_stride, 0)
                - torch.as_strided(ref, view_size, view_stride, 1))
    dev_diff = (torch.as_strided(dev, view_size, view_stride, 0)
                - torch.as_strided(dev, view_size, view_stride, 1))
    self.assertEqual(ref_diff, dev_diff)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003834
def test_unfold(self):
    # Windows of size 2 with step 1 over arange(1., 8) must match between
    # the default device and MPS.
    ref_windows = torch.arange(1., 8).unfold(0, 2, 1)
    dev_windows = torch.arange(1., 8, device="mps").unfold(0, 2, 1)

    self.assertEqual(ref_windows, dev_windows)
3843
def test_unfold_all_devices_and_dtypes(self):
    # unfold on an empty (0, 1, 3, 0) MPS tensor must produce the expected
    # output shape for every dtype listed below.
    expected_shape = (0, 1, 1, 0, 3)
    dtypes_under_test = (torch.float32, torch.float16, torch.int64,
                         torch.int32, torch.int16, torch.uint8)
    for dtype in dtypes_under_test:
        empty = torch.empty((0, 1, 3, 0), dtype=dtype, device="mps")
        self.assertEqual(expected_shape, empty.unfold(2, 3, 2).shape)
3849
def test_unfold_scalars(self):
    # unfold on a 0-dimensional tensor should always return a 1-dimensional
    # tensor of shape [size] (i.e., the second parameter to unfold)
    scalar = torch.tensor(0.5, device="mps")

    # (size, step, expected result)
    cases = (
        (0, 1, torch.empty(0, device="mps")),
        (0, 2, torch.empty(0, device="mps")),
        (1, 1, torch.tensor([0.5], device="mps")),
    )
    for size, step, expected in cases:
        self.assertEqual(expected, scalar.unfold(0, size, step))
Kulin Sethe011a8e2022-05-13 18:28:53 +00003858
def test_bincount_simple(self):
    # bincount of a small random int32 tensor, with and without weights,
    # must give the same result on MPS and on CPU.
    values_mps = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
    values_cpu = values_mps.to("cpu")
    weights_mps = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
    weights_cpu = weights_mps.to("cpu")

    # unweighted, through the function form
    self.assertEqual(torch.bincount(values_mps), torch.bincount(values_cpu))

    # weighted, through the method form
    self.assertEqual(values_mps.bincount(weights_mps), values_cpu.bincount(weights_cpu))
3872
def test_bincount_reduction(self):
    # Exercises torch.bincount on MPS: argument validation, empty inputs,
    # minlength, weights, and non-contiguous inputs/weights.
    device = "mps"

    def on_device(data, dtype=None):
        # Shorthand for building a tensor on the device under test.
        return torch.tensor(data, dtype=dtype, device=device)

    # negative input throws
    with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
        torch.bincount(on_device([1, -1], torch.int32))
    # n-d input, with n > 1 throws
    with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
        torch.bincount(on_device([[1, 2], [3, 4]]))
    # minlength < 0 throws
    with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
        torch.bincount(on_device([1, 3]),
                       on_device([.2, .2]),
                       minlength=-1)
    # n-d weights, with n > 1 throws
    with self.assertRaisesRegex(RuntimeError, '1-d'):
        torch.bincount(on_device([1, 0], torch.int32),
                       on_device([[1., 0.3], [1., 0.3]], torch.float))
    # input and weights dim mismatch
    with self.assertRaisesRegex(RuntimeError, 'same length'):
        torch.bincount(on_device([1, 0], torch.int32),
                       on_device([1., 0.3, 0.5], torch.float))

    # 1-d input with no elements and default minlength
    self.assertEqual(torch.bincount(on_device([], torch.long)),
                     torch.zeros(0, dtype=torch.long, device=device))
    # 1-d input with no elements and specified minlength
    self.assertEqual(torch.bincount(on_device([], torch.long), minlength=10),
                     torch.zeros(10, dtype=torch.long, device=device))

    # test tensor method without weights (uint8 input, int64 counts expected)
    uint8_counts = on_device([0, 3, 2, 1, 3], torch.uint8).bincount()
    self.assertEqual(on_device([1, 1, 1, 2], torch.int64), uint8_counts)

    # test avoiding overflow for uint8 (#76979)
    count_uint8 = on_device([0, 1, 2, 3, 255], torch.uint8).bincount()
    count_int16 = on_device([0, 1, 2, 3, 255], torch.int16).bincount()
    self.assertEqual(count_uint8, count_int16)

    # test minlength functionality
    padded_counts = torch.bincount(on_device([1, 1, 1, 1], torch.int32), minlength=5)
    self.assertEqual(on_device([0, 4, 0, 0, 0], torch.int64), padded_counts)

    # test weights: default-dtype float weights, then int8 weights
    float_weighted = torch.bincount(
        on_device([0, 1, 1, 1, 4], torch.int32),
        on_device([.1, .2, .3, .4, .5]))
    self.assertEqual(on_device([0.1, 0.9, 0, 0, 0.5]), float_weighted)
    int8_weighted = torch.bincount(
        on_device([0, 1, 1, 1, 4], torch.int32),
        on_device([1, 2, 3, 4, 5], torch.int8))
    self.assertEqual(on_device([1, 9, 0, 0, 5], torch.int32), int8_weighted)

    # test non-contiguous inputs and weights
    inputs = on_device([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], torch.int32)
    weights = on_device([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]])
    for i in [0, 1]:
        # Use unittest assertions rather than bare `assert` so that these
        # preconditions are still checked when Python runs with -O.
        self.assertFalse(inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous")
        self.assertFalse(weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous")
    # inputs are non-contiguous, no weights
    self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
    # inputs and weights are non-contiguous
    self.assertEqual(
        inputs[:, 1].bincount(weights[:, 1]),
        torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
    # weights are non-contiguous but inputs are contiguous
    self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
                     torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))

    # test bincount on non-contiguous slices
    all0s = torch.zeros((32, 2), dtype=torch.int32, device=device)
    self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))

    all1s = torch.ones((32, 2), dtype=torch.int32, device=device)
    self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))

    # weighted count where all 100 inputs fall into the last of 100 bins
    big_exp = torch.zeros(100, device=device)
    big_exp[-1] = 50.0
    big_w = on_device([.5] * 100)
    big_out = on_device([99] * 100, torch.int32).bincount(big_w)
    self.assertEqual(big_exp, big_out)
    # unweighted count of ten int8 ones
    big_exp = torch.zeros(2, device=device, dtype=torch.int64)
    big_exp[1] = 10
    big_out = torch.ones(10, dtype=torch.int8, device=device).bincount()
    self.assertEqual(big_exp, big_out)
3962
def test_bincount(self):
    # Compares MPS bincount against CPU on 5000 random values for several
    # value ranges / dtypes, then checks a large minlength.
    device = "mps"
    input_size = (5000,)
    weights_mps = torch.randn(input_size, dtype=torch.float, device=device)
    weights_cpu = weights_mps.cpu()

    # (exclusive upper bound of the random values, dtype)
    for upper, dtype in ((50, torch.int8), (500, torch.int32), (2000, torch.int32)):
        samples = torch.randint(upper, input_size, dtype=dtype, device=device)
        self.assertEqual(samples.cpu().bincount(), samples.bincount())
        self.assertEqual(samples.cpu().bincount(weights_cpu), samples.bincount(weights_mps))

    # One large value plus nine zeros: all ten elements must still be counted.
    sparse = torch.zeros([10], dtype=torch.int32, device=device)
    sparse[0] = 35488
    counted = sparse.bincount(minlength=65536)
    self.assertEqual(torch.sum(counted), 10)
3985
def test_sum_backward(self):
    # torch.sum over a fixed 3x3 tensor: the value and the gradient of the
    # input must match between MPS and CPU.
    grid = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    ref = torch.tensor(grid, device='cpu', requires_grad=True)
    dev = ref.detach().clone().to('mps').requires_grad_()

    dev_total = torch.sum(dev)
    ref_total = torch.sum(ref)

    dev_total.backward()
    ref_total.backward()
    self.assertEqual(dev_total, ref_total)
    self.assertEqual(dev.grad, ref.grad)
4001
def test_nll_loss_1d(self, device='cpu'):
    # Runs the 1-D nll_loss CPU-vs-MPS comparison for every reduction mode.
    for reduction in ("none", "mean", "sum"):
        self._nll_loss_1d_helper([10], reduction)
4006
def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
    # Runs the nll_loss CPU-vs-MPS helper with reduction="none" for several
    # input shapes; each is paired with an uninitialised tensor of the listed
    # shape as the helper's `expected` argument.
    # (input shape, shape of the `expected` tensor)
    cases = (
        ([1, 3], [0]),
        ([3, 5, 7], [5, 7]),
        ([2, 3, 1, 7], [2, 1, 7]),
        ([2, 3, 5, 1], [2, 5, 1]),
        ([2, 3, 5, 7, 1], [2, 5, 7, 1]),
    )
    for input_shape, expected_shape in cases:
        self._nll_loss_helper(input_shape, "none", torch.empty(expected_shape, device=device))
4013
@unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
    # Runs the nll_loss CPU-vs-MPS helper with reduction="mean" for several
    # input shapes, passing a NaN scalar as the helper's `expected` argument.
    nan = torch.tensor(float('nan'), device=device)
    input_shapes = (
        [1, 3],
        [1, 3, 5, 7],
        [2, 3, 1, 7],
        [2, 3, 5, 1],
        [2, 3, 5, 7, 1],
    )
    for input_shape in input_shapes:
        self._nll_loss_helper(input_shape, "mean", nan)
4022
def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
    # Runs the nll_loss CPU-vs-MPS helper with reduction="sum" for several
    # input shapes, passing a zero scalar as the helper's `expected` argument.
    zero = torch.tensor(0, device=device)
    input_shapes = (
        [1, 3],
        [1, 3, 5, 7],
        [2, 3, 1, 7],
        [2, 3, 5, 1],
        [2, 3, 5, 7, 1],
    )
    for input_shape in input_shapes:
        self._nll_loss_helper(input_shape, "sum", zero)
4030
def test_nll_loss_byte_target_matches_long(self, device='cpu'):
    # NLLLoss must give the same loss and input gradient whether the target
    # is int64 or uint8, and MPS must agree with CPU.
    N, C = 10, 4
    logits = torch.randn(N, C, device=device, requires_grad=True)
    target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)

    def compute_result_and_gradient(reduction, target_dtype):
        # Returns ({device: loss}, {device: d(loss.sum())/d(input)}) for 'cpu' and 'mps'.
        result, grad = {}, {}
        for dev in ['cpu', 'mps']:
            # Fresh leaf per call so gradients never accumulate between runs.
            input_ = logits.to(dev).detach()
            input_.requires_grad_()

            target_dev = target.to(dev)

            prob = F.log_softmax(input_, dim=-1)
            loss = nn.NLLLoss(reduction=reduction)
            result[dev] = loss(prob, target_dev.to(target_dtype))
            result[dev].sum().backward()
            grad[dev] = input_.grad

        return result, grad

    for reduction in ["none", "mean", "sum"]:
        result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
        result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)

        # MPS vs CPU with the int64 target
        self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
        self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])

        # uint8 target vs int64 target on each device. These results were
        # previously computed but never compared.
        for dev in ['cpu', 'mps']:
            self.assertEqual(result_long[dev], result_byte[dev])
            self.assertEqual(grad_long[dev], grad_byte[dev])
4059
qqaatwff44bfa2022-06-24 17:18:30 +00004060 # L1 loss
def test_l1_loss(self):
    # torch.nn.L1Loss: forward must match between CPU and MPS for every case;
    # input gradients are compared for the 'sum' and 'mean' reductions.
    def check(shape, reduction):
        criterion = torch.nn.L1Loss(reduction=reduction)

        ref_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        ref_target = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev_input = ref_input.detach().clone().to('mps').requires_grad_()
        dev_target = ref_target.detach().clone().to('mps')

        # forward pass
        ref_loss = criterion(ref_input, ref_target)
        dev_loss = criterion(dev_input, dev_target)
        self.assertEqual(ref_loss, dev_loss)

        # backward pass (not run for reduction='none')
        if reduction == 'none':
            return
        # chose 2 just to make the grad_output > 1 in backward pass
        ref_loss.backward(gradient=torch.full_like(ref_loss, 2))
        dev_loss.backward(gradient=torch.full_like(dev_loss, 2))
        self.assertEqual(ref_input.grad, dev_input.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
    )
    for shape, reduction in cases:
        check(shape, reduction)
4088
Kulin Sethe011a8e2022-05-13 18:28:53 +00004089 # Mean Squared Error
def test_mse_loss(self):
    # torch.nn.MSELoss: forward must match between CPU and MPS for every case;
    # input gradients are compared for the 'sum' and 'mean' reductions.
    def check(shape, reduction):
        criterion = torch.nn.MSELoss(reduction=reduction)

        ref_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        ref_target = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev_input = ref_input.detach().clone().to('mps').requires_grad_()
        dev_target = ref_target.detach().clone().to('mps')

        # forward pass
        ref_loss = criterion(ref_input, ref_target)
        dev_loss = criterion(dev_input, dev_target)
        self.assertEqual(ref_loss, dev_loss)

        # backward pass (not run for reduction='none')
        if reduction == 'none':
            return
        # chose 2 just to make the grad_output > 1 in backward pass
        ref_loss.backward(gradient=torch.full_like(ref_loss, 2))
        dev_loss.backward(gradient=torch.full_like(dev_loss, 2))
        self.assertEqual(ref_input.grad, dev_input.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
    )
    for shape, reduction in cases:
        check(shape, reduction)
4117
4118 # Binary Cross Entropy
def test_bce_loss_simple(self):
    # torch.nn.BCELoss: forward must match between CPU and MPS for every case;
    # input gradients are compared for the 'sum' and 'mean' reductions.
    def check(shape, reduction):
        criterion = torch.nn.BCELoss(reduction=reduction)

        # input and target must be within [0..1]
        input_np = np.random.random_sample(size=shape).astype(np.float32)
        target_np = np.random.random_sample(size=shape).astype(np.float32)
        ref_input = torch.tensor(input_np, device='cpu', dtype=torch.float, requires_grad=True)
        ref_target = torch.tensor(target_np, device='cpu', dtype=torch.float, requires_grad=False)
        dev_input = ref_input.detach().clone().to('mps').requires_grad_()
        dev_target = ref_target.detach().clone().to('mps')

        # forward pass
        ref_loss = criterion(ref_input, ref_target)
        dev_loss = criterion(dev_input, dev_target)
        self.assertEqual(ref_loss, dev_loss)

        # backward pass (not run for reduction='none')
        if reduction == 'none':
            return
        # chose 0.6 just to have the grad_output != 1
        ref_loss.backward(gradient=torch.full_like(ref_loss, 0.6))
        dev_loss.backward(gradient=torch.full_like(dev_loss, 0.6))
        self.assertEqual(ref_input.grad, dev_input.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
        ([1, 1, 32, 32], 'mean'),
    )
    for shape, reduction in cases:
        check(shape, reduction)
4150
def test_bce_loss_always_nonnegative(self):
    # BCELoss on MPS must not go negative when input equals target at the
    # two extremes (all ones, all zeros).
    for make in (torch.ones, torch.zeros):
        target = make(5, device='mps')
        input = make(5, device='mps')
        negative_count = (nn.BCELoss()(input, target) < 0).sum()
        self.assertEqual(negative_count, 0)
4159
def test_bce_loss_size_mismatch(self):
    # BCELoss must raise ValueError when the input is (25,) and the target is (25, 1).
    criterion = nn.BCELoss()
    flat = torch.rand(25, device='mps')
    column = torch.rand(25, 1, device='mps')
    with self.assertRaisesRegex(ValueError, r'Using a target size \('):
        criterion(flat, column)
4166
def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
    # On MPS, BCEWithLogitsLoss(x) must equal BCELoss(sigmoid(x)) in both the
    # loss value and the gradient w.r.t. x, for every reduction mode.
    rows, cols = 1024, 256
    target = torch.rand(rows, cols, device='mps')

    for reduction in ['none', 'mean', 'sum']:
        # Two independent leaves holding identical values.
        sig_leaf = torch.rand(rows, cols, device='mps') - 0.5
        logit_leaf = sig_leaf.clone().detach()
        sig_leaf.requires_grad = True
        logit_leaf.requires_grad = True

        weight = torch.rand(cols, device='mps')

        via_sigmoid = nn.BCELoss(weight, reduction=reduction)(torch.sigmoid(sig_leaf), target)
        via_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(logit_leaf, target)

        self.assertEqual(via_logits, via_sigmoid)

        if reduction != 'none':
            via_sigmoid.backward()
            via_logits.backward()
        else:
            # Unreduced losses need an explicit upstream gradient.
            upstream = torch.rand(rows, cols, device='mps')
            via_sigmoid.backward(upstream)
            via_logits.backward(upstream)

        self.assertEqual(sig_leaf.grad, logit_leaf.grad)
4198
def test_bce_with_logits_has_correct_grad_at_zero(self):
    # With zero logits and zero targets, the summed BCEWithLogitsLoss must
    # have a gradient of 0.5 for every element.
    logits = torch.zeros(3, 1, requires_grad=True, device='mps')
    target = torch.zeros(3, 1, device='mps')

    criterion = nn.BCEWithLogitsLoss(reduction='sum')
    criterion(logits, target).backward()

    expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
    self.assertEqual(logits.grad, expected_grad)
4205
def test_bce_with_logits_broadcasts_weights(self):
    # A (4,) or (16, 1) weight must give the same BCEWithLogitsLoss as that
    # weight expanded to the full (16, 4) shape.
    target = torch.rand(16, 4, device='mps')
    output = torch.rand(16, 4, device='mps') - 0.5

    for weight_shape in ((4,), (16, 1)):
        weight = torch.rand(weight_shape, device='mps')
        broadcast_loss = nn.BCEWithLogitsLoss(weight)(output, target)

        full_weight = weight.expand(16, 4).contiguous()
        explicit_loss = nn.BCEWithLogitsLoss(full_weight)(output, target)

        self.assertEqual(broadcast_loss, explicit_loss)
4225
def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
    # A pos_weight of all ones must give the same loss as no pos_weight.
    target = torch.rand(64, 4, device='mps')
    output = torch.rand(64, 4, device='mps') - 0.5
    pos_weight = torch.ones(64, 4, device='mps')

    plain_loss = nn.BCEWithLogitsLoss()(output, target)
    unit_weight_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
    self.assertEqual(plain_loss, unit_weight_loss)
4233
def test_bce_with_logits_broadcasts_pos_weights(self):
    # A (4,) pos_weight must give the same loss as the same values expanded
    # to (1, 4) and to (64, 4).
    target = torch.rand(64, 4, device='mps')
    output = torch.rand(64, 4, device='mps') - 0.5
    pos_weight = torch.rand(4, device='mps')

    def loss_with(weights):
        return nn.BCEWithLogitsLoss(pos_weight=weights)(output, target)

    baseline = loss_with(pos_weight)
    row_expanded = loss_with(pos_weight.expand(1, 4))
    fully_expanded = loss_with(pos_weight.expand(64, 4))

    self.assertEqual(baseline, row_expanded)
    self.assertEqual(baseline, fully_expanded)
4248
def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
    # With zero logits, zero targets and a pos_weight of ones, the summed
    # loss must have a gradient of 0.5 for every element.
    logits = torch.zeros(3, 1, requires_grad=True, device='mps')
    target = torch.zeros(3, 1, device='mps')
    pos_weight = torch.ones(3, 1, device='mps')

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')
    criterion(logits, target).backward()

    expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
    self.assertEqual(logits.grad, expected_grad)
4257
def test_bce_with_logits_stability(self):
    # BCEWithLogitsLoss must stay finite for a logit of -120 with target 1,
    # with and without a pos_weight.
    output = torch.tensor([0., -120.], device='mps')
    target = torch.tensor([0., 1.], device='mps')
    pos_weight = torch.tensor([1., 1.], device='mps')

    plain_loss = nn.BCEWithLogitsLoss()(output, target)
    self.assertTrue(torch.isfinite(plain_loss).all().item())

    weighted_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
    self.assertTrue(torch.isfinite(weighted_loss).all().item())
4268
def test_bce_loss_broadcasts_weights(self):
    # A (4,) or (16, 1) weight must give the same BCELoss as that weight
    # expanded to the full (16, 4) shape.
    to_prob = nn.Sigmoid()
    target = torch.rand(16, 4, device='mps')
    output = torch.rand(16, 4, device='mps') - 0.5

    for weight_shape in ((4,), (16, 1)):
        weight = torch.rand(weight_shape, device='mps')
        broadcast_loss = nn.BCELoss(weight)(to_prob(output), target)

        full_weight = weight.expand(16, 4).contiguous()
        explicit_loss = nn.BCELoss(full_weight)(to_prob(output), target)

        self.assertEqual(broadcast_loss, explicit_loss)
Kulin Sethe011a8e2022-05-13 18:28:53 +00004289
def test_log_softmax(self):
    # F.log_softmax along dim 0 of a fixed 2x2x3 tensor: output and input
    # gradient must match between CPU and MPS.
    data = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    ref = torch.tensor(data, device='cpu', requires_grad=True)
    dev = torch.tensor(data, device='mps', requires_grad=True)

    ref_out = F.log_softmax(ref, dim=0)
    dev_out = F.log_softmax(dev, dim=0)
    self.assertEqual(ref_out, dev_out.to('cpu'))

    # Same all-ones upstream gradient on both devices.
    ref_upstream = torch.ones_like(ref_out)
    dev_upstream = torch.ones_like(ref_out).to('mps')
    ref_out.backward(gradient=ref_upstream)
    dev_out.backward(gradient=dev_upstream)

    self.assertEqual(ref.grad, dev.grad.to('cpu'))
4306
def test_log_softmax_large_numbers(self):
    # F.log_softmax along the last dim for inputs with magnitudes up to 1e6:
    # output and input gradient must match between CPU and MPS.
    data = [
        [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0],
        [-10.0, -100.0, -1000.0, -10000.0, -100000.0, -1000000.0]
    ]
    ref = torch.tensor(data, device='cpu', requires_grad=True)
    dev = torch.tensor(data, device='mps', requires_grad=True)

    ref_out = F.log_softmax(ref, dim=-1)
    dev_out = F.log_softmax(dev, dim=-1)
    self.assertEqual(ref_out, dev_out.to('cpu'))

    # Same all-ones upstream gradient on both devices.
    ref_upstream = torch.ones_like(ref_out)
    dev_upstream = torch.ones_like(ref_out).to('mps')
    ref_out.backward(gradient=ref_upstream)
    dev_out.backward(gradient=dev_upstream)

    self.assertEqual(ref.grad, dev.grad.to('cpu'))
4326
def test_eq(self):
    # torch.eq on two float tensors that differ in two positions must give
    # the same mask on MPS and CPU.
    lhs = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    rhs = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]

    dev_mask = torch.eq(torch.tensor(lhs, device='mps'), torch.tensor(rhs, device='mps'))
    ref_mask = torch.eq(torch.tensor(lhs, device='cpu'), torch.tensor(rhs, device='cpu'))

    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4338
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_signed_vs_unsigned_comparison(self):
    # Comparing a uint8 tensor with the Python scalar -1 must behave the same
    # on MPS as on CPU for ==, > and <.
    raw = (-1, 2, 3)
    ref = torch.tensor(raw, device='cpu', dtype=torch.uint8)
    dev = torch.tensor(raw, device='mps', dtype=torch.uint8)

    # in the comparison of signed vs. unsigned we should always cast to unsigned
    ref_eq, dev_eq = ref == -1, dev == -1
    self.assertEqual(ref_eq, dev_eq)
    ref_gt, dev_gt = ref > -1, dev > -1
    self.assertEqual(ref_gt, dev_gt)
    ref_lt, dev_lt = ref < -1, dev < -1
    self.assertEqual(ref_lt, dev_lt)
4347
def test_eq_int64(self):
    # torch.eq on two tensors built from Python ints that differ in two
    # positions must give the same mask on MPS and CPU.
    lhs = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
    rhs = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]

    dev_mask = torch.eq(torch.tensor(lhs, device='mps'), torch.tensor(rhs, device='mps'))
    ref_mask = torch.eq(torch.tensor(lhs, device='cpu'), torch.tensor(rhs, device='cpu'))

    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4359
def test_ne(self):
    # Elementwise torch.ne of two random float tensors: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref_a = torch.randn(shape, device='cpu', dtype=torch.float)
    ref_b = torch.randn(shape, device='cpu', dtype=torch.float)
    dev_a = ref_a.detach().clone().to('mps')
    dev_b = ref_b.detach().clone().to('mps')

    dev_mask = torch.ne(dev_a, dev_b)
    ref_mask = torch.ne(ref_a, ref_b)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4372
def test_ne_scalar(self):
    # torch.ne of a random float tensor against the scalar 0.0: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref = torch.randn(shape, device='cpu', dtype=torch.float)
    dev = ref.detach().clone().to('mps')

    dev_mask = torch.ne(dev, 0.0)
    ref_mask = torch.ne(ref, 0.0)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4383
def test_lt(self):
    # Elementwise torch.lt of two random float tensors: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref_a = torch.randn(shape, device='cpu', dtype=torch.float)
    ref_b = torch.randn(shape, device='cpu', dtype=torch.float)
    dev_a = ref_a.detach().clone().to('mps')
    dev_b = ref_b.detach().clone().to('mps')

    dev_mask = torch.lt(dev_a, dev_b)
    ref_mask = torch.lt(ref_a, ref_b)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4396
def test_lt_scalar(self):
    # torch.lt of a random float tensor against the scalar 0.0: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref = torch.randn(shape, device='cpu', dtype=torch.float)
    dev = ref.detach().clone().to('mps')

    dev_mask = torch.lt(dev, 0.0)
    ref_mask = torch.lt(ref, 0.0)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4407
def test_le(self):
    # Elementwise torch.le of two random float tensors: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref_a = torch.randn(shape, device='cpu', dtype=torch.float)
    ref_b = torch.randn(shape, device='cpu', dtype=torch.float)
    dev_a = ref_a.detach().clone().to('mps')
    dev_b = ref_b.detach().clone().to('mps')

    dev_mask = torch.le(dev_a, dev_b)
    ref_mask = torch.le(ref_a, ref_b)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4420
def test_le_scalar(self):
    # torch.le of a random float tensor against the scalar 0.0: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref = torch.randn(shape, device='cpu', dtype=torch.float)
    dev = ref.detach().clone().to('mps')

    dev_mask = torch.le(dev, 0.0)
    ref_mask = torch.le(ref, 0.0)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4431
def test_ge(self):
    # Elementwise torch.ge of two random float tensors: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref_a = torch.randn(shape, device='cpu', dtype=torch.float)
    ref_b = torch.randn(shape, device='cpu', dtype=torch.float)
    dev_a = ref_a.detach().clone().to('mps')
    dev_b = ref_b.detach().clone().to('mps')

    dev_mask = torch.ge(dev_a, dev_b)
    ref_mask = torch.ge(ref_a, ref_b)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4444
def test_ge_scalar(self):
    # torch.ge of a random float tensor against the scalar 0.0: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref = torch.randn(shape, device='cpu', dtype=torch.float)
    dev = ref.detach().clone().to('mps')

    dev_mask = torch.ge(dev, 0.0)
    ref_mask = torch.ge(ref, 0.0)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4455
def test_gt(self):
    # Elementwise torch.gt of two random float tensors: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref_a = torch.randn(shape, device='cpu', dtype=torch.float)
    ref_b = torch.randn(shape, device='cpu', dtype=torch.float)
    dev_a = ref_a.detach().clone().to('mps')
    dev_b = ref_b.detach().clone().to('mps')

    dev_mask = torch.gt(dev_a, dev_b)
    ref_mask = torch.gt(ref_a, ref_b)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4468
def test_gt_scalar(self):
    # torch.gt of a random float tensor against the scalar 0.0: MPS must match CPU.
    shape = (2, 3, 4, 5)
    ref = torch.randn(shape, device='cpu', dtype=torch.float)
    dev = ref.detach().clone().to('mps')

    dev_mask = torch.gt(dev, 0.0)
    ref_mask = torch.gt(ref, 0.0)
    self.assertEqual(ref_mask, dev_mask.to('cpu'))
4479
qqaatw2458b3c2022-07-07 00:04:49 +00004480 # Test forward argmin argmax
def test_argmin_argmax(self):
    # torch.argmax / torch.argmin on MPS must match CPU over the flattened
    # tensor and along each of the four dims, with and without keepdim.
    def check(n, c, h, w, reduction_type, dtype=torch.float32):
        arg_fn = torch.argmax if reduction_type == "max" else torch.argmin

        if dtype == torch.float32:
            ref = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
            dev = ref.detach().clone().to('mps').requires_grad_()
        else:
            # bool draws from {0, 1}; every other non-float32 dtype draws
            # integer values in [0, 50)
            upper = 2 if dtype == torch.bool else 50
            ref = torch.randint(upper, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
            dev = ref.detach().clone().to('mps')

        # reduction over the whole tensor
        dev_idx = arg_fn(dev)
        ref_idx = arg_fn(ref)
        self.assertEqual(dev_idx, ref_idx)

        # reduction along each dim, first dropping and then keeping that dim
        for dim in range(4):
            for keepdim in (False, True):
                dev_idx = arg_fn(dev, dim=dim, keepdim=keepdim)
                ref_idx = arg_fn(ref, dim=dim, keepdim=keepdim)
                self.assertEqual(dev_idx, ref_idx)

    for reduction_type in ("max", "min"):
        for dtype in (torch.float32, torch.int32, torch.float16, torch.int64):
            check(2, 8, 4, 4, reduction_type, dtype)
Kulin Sethe011a8e2022-05-13 18:28:53 +00004544
@unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
def test_reduction_sum_max_long_val(self):
    """Sum integers at or just below ``sys.maxsize`` on MPS and compare with the CPU sum."""
    offsets = (0, 10, 5, 18)
    x_mps = torch.tensor([sys.maxsize - off for off in offsets], device="mps")
    x_cpu = x_mps.detach().clone().cpu()

    mps_total = torch.sum(x_mps)
    cpu_total = torch.sum(x_cpu)
    self.assertEqual(mps_total, cpu_total)
4553
Kulin Sethe011a8e2022-05-13 18:28:53 +00004554 # Test forward max
4555 # Note - don't test grad now
def test_max_el(self):
    """Compare ``torch.max`` on MPS with the CPU result (forward only).

    Covers the full reduction, the ``(values, indices)`` overload along each
    dimension, and the same overload writing through ``out=``.
    """
    def helper(n, c, h, w, dtype=torch.float32):
        extents = (n, c, h, w)

        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
        else:
            # bool inputs are drawn from {0, 1}; all remaining dtypes from [0, 50).
            high = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(high, extents, device='cpu', dtype=dtype, requires_grad=False)
        # The MPS copy is a plain detached clone for every dtype.
        x = cpu_x.detach().clone().to('mps')

        ref_y = torch.max(cpu_x)
        y = torch.max(x)
        self.assertEqual(y, ref_y)

        # Returning overload along each axis.
        for dim in [0, 1, 2, 3]:
            for keepdim in [True, False]:
                vals, idxs = torch.max(x, dim=dim, keepdim=keepdim)
                ref_vals, ref_idxs = torch.max(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(vals, ref_vals)
                self.assertEqual(idxs, ref_idxs)

        # out= overload: results land in pre-allocated MPS tensors whose shape
        # has the reduced axis dropped (keepdim=False) or set to 1 (keepdim=True).
        for dim in range(len(extents)):
            for keepdim in (False, True):
                kept = (1,) if keepdim else ()
                out_shape = extents[:dim] + kept + extents[dim + 1:]
                out_vals = torch.ones(out_shape, device='mps', dtype=dtype)
                out_idxs = torch.ones(out_shape, device='mps', dtype=torch.int64)
                torch.max(x, dim=dim, keepdim=keepdim, out=(out_vals, out_idxs))
                ref_vals, ref_idxs = torch.max(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(out_vals, ref_vals)
                self.assertEqual(out_idxs, ref_idxs)

    for dtype in (torch.float32, torch.int32):
        helper(2, 8, 4, 5, dtype)
    # helper(2, 8, 4, 5, torch.int64)
4639
def test_median(self):
    """Compare ``torch.median`` (global and per-dimension) between CPU and MPS."""
    def compare(cpu_x):
        # Shared comparison body for both dtype helpers.
        mps_x = cpu_x.detach().clone().to('mps')

        result_cpu = torch.median(cpu_x)
        result_mps = torch.median(mps_x)
        self.assertEqual(result_cpu, result_mps)

        for dim in [0, 1, 2]:
            for keepdim in [True, False]:
                cpu_vals, cpu_idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
                mps_vals, mps_idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
                self.assertEqual(cpu_vals, mps_vals)
                self.assertEqual(cpu_idx, mps_idx)

    def helper_dtype_int32(n1, n2, n3):
        compare(torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32))

    def helper_dtype_float32(n1, n2, n3):
        compare(torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32))

    helper_dtype_int32(10, 10, 10)  # median at even place
    helper_dtype_int32(3, 3, 3)  # median at odd place
    helper_dtype_int32(1, 1, 1)
    helper_dtype_int32(1, 2, 3)
    helper_dtype_float32(10, 10, 10)
    helper_dtype_float32(3, 3, 3)
    helper_dtype_float32(1, 1, 1)
4680
def test_any(self):
    """Compare ``torch.any`` on MPS with the CPU result for float, int and bool inputs.

    For each input the full reduction is checked, followed by the reduction
    along every dimension the input has, with and without ``keepdim``.
    """
    def helper(shape):
        numel = 1
        for extent in shape:
            numel *= extent

        # Random, increasing, all-ones and all-zeros data in float, int and
        # bool flavours (bool has no random variant).
        input_xs = [
            torch.randn(numel, dtype=torch.float).reshape(shape),
            torch.arange(0, numel, dtype=torch.float).reshape(shape),
            torch.ones(numel, dtype=torch.float).reshape(shape),
            torch.zeros(numel, dtype=torch.float).reshape(shape),
            torch.arange(0, numel, dtype=torch.int).reshape(shape),
            torch.ones(numel, dtype=torch.int).reshape(shape),
            torch.zeros(numel, dtype=torch.int).reshape(shape),
            torch.arange(0, numel, dtype=torch.int).reshape(shape).bool(),
            torch.ones(numel, dtype=torch.int).reshape(shape).bool(),
            torch.zeros(numel, dtype=torch.int).reshape(shape).bool(),
        ]

        for cpu_x in input_xs:
            x = cpu_x.detach().clone().to('mps')

            y = torch.any(x)
            ref_y = torch.any(cpu_x)
            self.assertEqual(y, ref_y)

            # Only reduce along dimensions that exist, so inputs of any rank
            # are handled without a hard-coded rank check.
            for dim in range(len(shape)):
                for keepdim in (False, True):
                    y_dim = torch.any(x, dim=dim, keepdim=keepdim)
                    ref_y_dim = torch.any(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y_dim, ref_y_dim)

    helper((1, 1, 1, 1))
    helper((1, 1, 3, 3))
    helper((7, 13))
    helper((2, 8, 4, 5))
4745
def test_all(self):
    """Compare ``torch.all`` on MPS with the CPU result for float, int and bool inputs.

    For each input the full reduction is checked, followed by the reduction
    along every dimension the input has, with and without ``keepdim``.
    """
    def helper(shape):
        numel = 1
        for extent in shape:
            numel *= extent

        # Random, increasing, all-ones and all-zeros data in float, int and
        # bool flavours (bool has no random variant).
        input_xs = [
            torch.randn(numel, dtype=torch.float).reshape(shape),
            torch.arange(0, numel, dtype=torch.float).reshape(shape),
            torch.ones(numel, dtype=torch.float).reshape(shape),
            torch.zeros(numel, dtype=torch.float).reshape(shape),
            torch.arange(0, numel, dtype=torch.int).reshape(shape),
            torch.ones(numel, dtype=torch.int).reshape(shape),
            torch.zeros(numel, dtype=torch.int).reshape(shape),
            torch.arange(0, numel, dtype=torch.int).reshape(shape).bool(),
            torch.ones(numel, dtype=torch.int).reshape(shape).bool(),
            torch.zeros(numel, dtype=torch.int).reshape(shape).bool(),
        ]

        for cpu_x in input_xs:
            x = cpu_x.detach().clone().to('mps')

            y = torch.all(x)
            ref_y = torch.all(cpu_x)
            self.assertEqual(y, ref_y)

            # Only reduce along dimensions that exist, so inputs of any rank
            # are handled without a hard-coded rank check.
            for dim in range(len(shape)):
                for keepdim in (False, True):
                    y_dim = torch.all(x, dim=dim, keepdim=keepdim)
                    ref_y_dim = torch.all(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y_dim, ref_y_dim)

    helper((1, 1, 1, 1))
    helper((1, 1, 3, 3))
    helper((7, 13))
    helper((2, 8, 4, 5))
4810
4811 # Test forward min
def test_min_el(self):
    """Compare ``torch.min`` on MPS with the CPU result for float inputs.

    Covers the full reduction and, for each dimension, the
    ``(values, indices)`` overload both as a returning call and through
    ``out=``, with and without ``keepdim``.
    """
    def helper(n, c, h, w):
        extents = (n, c, h, w)
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        y = torch.min(x)
        ref_y = torch.min(cpu_x)
        self.assertEqual(y, ref_y)

        for dim in range(len(extents)):
            for keepdim in (False, True):
                # Returning overload.
                vals, idxs = torch.min(x, dim=dim, keepdim=keepdim)
                ref_vals, ref_idxs = torch.min(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(vals, ref_vals)
                self.assertEqual(idxs, ref_idxs)

                # out= overload: the pre-allocated MPS tensors have the reduced
                # axis dropped (keepdim=False) or set to 1 (keepdim=True).
                kept = (1,) if keepdim else ()
                out_shape = extents[:dim] + kept + extents[dim + 1:]
                out_vals = torch.ones(out_shape, device='mps', dtype=torch.float)
                out_idxs = torch.ones(out_shape, device='mps', dtype=torch.int64)
                torch.min(x, dim=dim, keepdim=keepdim, out=(out_vals, out_idxs))
                ref_vals, ref_idxs = torch.min(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(out_vals, ref_vals)
                self.assertEqual(out_idxs, ref_idxs)

    helper(2, 8, 4, 5)
4918
4919 # Test forward sum
def test_sum(self):
    """Compare ``torch.sum`` on MPS with the CPU result for several dtypes.

    The full reduction is checked first, then reductions over the dimension
    lists ``[]``, ``[0]``, ``[0, 1]`` and ``[2, 3]`` with and without ``keepdim``.
    """
    def helper(n, c, h, w, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            # bool inputs are drawn from {0, 1}; all remaining dtypes from [0, 50).
            high = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(high, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        total = torch.sum(x)
        total_cpu = torch.sum(cpu_x)
        self.assertEqual(total, total_cpu)

        for dims in ([], [0], [0, 1], [2, 3]):
            for keepdim in (False, True):
                partial_sum = torch.sum(x, dim=dims, keepdim=keepdim)
                partial_sum_cpu = torch.sum(cpu_x, dim=dims, keepdim=keepdim)
                self.assertEqual(partial_sum, partial_sum_cpu)

    helper(2, 8, 4, 5)
    for dtype in (torch.int32, torch.int64, torch.bool):
        helper(2, 8, 4, 5, dtype=dtype)
4983
4984 # Test forward prod
def test_prod(self):
    """Compare ``torch.prod`` on MPS with the CPU result for several dtypes.

    The full reduction and the reduction along each dimension (with and
    without ``keepdim``) are checked.
    """
    def helper(shape, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            if dtype == torch.bool:
                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
            else:
                # Values are drawn from [1, 6), so they are never zero.
                cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        product = torch.prod(x)
        product_cpu = torch.prod(cpu_x)
        self.assertEqual(product, product_cpu)

        for dim in range(len(shape)):
            for keepdim in (False, True):
                if keepdim:
                    partial_prod = torch.prod(x, dim=dim, keepdim=True)
                    partial_prod_cpu = torch.prod(cpu_x, dim=dim, keepdim=True)
                else:
                    partial_prod = torch.prod(x, dim=dim)
                    partial_prod_cpu = torch.prod(cpu_x, dim=dim)
                self.assertEqual(partial_prod, partial_prod_cpu)

    for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
        helper((2, 3), dtype)
5017
5018 # Test forward mean
def test_mean(self):
    """Compare ``torch.mean`` on MPS with the CPU result for float inputs.

    The full reduction is checked first, then reductions over the dimension
    lists ``[]``, ``[0]``, ``[0, 1]`` and ``[2, 3]`` with and without ``keepdim``.
    """
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        overall = torch.mean(x)
        overall_cpu = torch.mean(cpu_x)
        self.assertEqual(overall, overall_cpu)

        for dims in ([], [0], [0, 1], [2, 3]):
            for keepdim in (False, True):
                partial_mean = torch.mean(x, dim=dims, keepdim=keepdim)
                partial_mean_cpu = torch.mean(cpu_x, dim=dims, keepdim=keepdim)
                self.assertEqual(partial_mean, partial_mean_cpu)

    helper(2, 8, 4, 5)
5070
5071 # Test std
def test_std(self):
    """Compare ``torch.std`` on MPS with the CPU result, biased and unbiased.

    For each ``unbiased`` setting the full reduction is checked, then
    reductions over the dimension lists ``[]``, ``[0]``, ``[0, 1]`` and
    ``[2, 3]`` with and without ``keepdim``.
    """
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        for unbiased in (False, True):
            overall = torch.std(x, unbiased=unbiased)
            overall_cpu = torch.std(cpu_x, unbiased=unbiased)
            self.assertEqual(overall, overall_cpu)

            for dims in ([], [0], [0, 1], [2, 3]):
                for keepdim in (False, True):
                    partial_std = torch.std(x, dim=dims, keepdim=keepdim, unbiased=unbiased)
                    partial_std_cpu = torch.std(cpu_x, dim=dims, keepdim=keepdim, unbiased=unbiased)
                    self.assertEqual(partial_std, partial_std_cpu)

    helper((4, 5, 6, 7))
    # verify if a change in shape of input would cause problems with graph caching
    helper((9, 5, 6, 7))
Kulin Sethe011a8e2022-05-13 18:28:53 +00005170
5171 # Test var
def test_var_simple(self):
    """Compare variance on MPS with the CPU result for a fixed 4-D float input.

    Every combination of ``unbiased`` and ``keepdim`` is run through
    ``Tensor.var`` over the last dimension, the full ``torch.var`` reduction,
    and ``torch.var`` over several dimension lists.
    """
    def helper():
        shape = [2, 3, 4, 5]

        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        for unbiased in (False, True):
            for keepdim in (False, True):
                # Method form, reducing over the last dimension.
                last_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
                last_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
                self.assertEqual(last_dim_var, last_dim_var_cpu)

                # Full reduction; keepdim is not passed here.
                overall = torch.var(x, unbiased=unbiased)
                overall_cpu = torch.var(cpu_x, unbiased=unbiased)
                self.assertEqual(overall, overall_cpu)

                for dims in ([], [0], [0, -1], [2, 3]):
                    partial_var = torch.var(x, dim=dims, keepdim=keepdim, unbiased=unbiased)
                    partial_var_cpu = torch.var(cpu_x, dim=dims, keepdim=keepdim, unbiased=unbiased)
                    self.assertEqual(partial_var, partial_var_cpu)

    helper()
Kulin Sethe011a8e2022-05-13 18:28:53 +00005214
# Test forward and backward amax
def test_amax(self):
    """Compare ``torch.amax`` outputs and input gradients between MPS and CPU."""
    def helper(shape, dim, keepdim):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        mps_out = torch.amax(mps_input, dim=dim, keepdim=keepdim)
        cpu_out = torch.amax(cpu_input, dim=dim, keepdim=keepdim)

        # Use one random upstream gradient for both backends.
        upstream_cpu = torch.randn(cpu_out.shape)
        upstream_mps = upstream_cpu.to('mps')

        cpu_out.backward(gradient=upstream_cpu)
        mps_out.backward(gradient=upstream_mps)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(mps_input.grad, cpu_input.grad)

    for dims in ([], [0], [0, 1], [2, 3]):
        for keep in (False, True):
            helper((2, 8, 4, 5), dims, keep)
5236
# Test forward and backward amin
def test_amin(self):
    """Compare ``torch.amin`` outputs and input gradients between MPS and CPU."""
    def helper(shape, dim, keepdim):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        mps_out = torch.amin(mps_input, dim=dim, keepdim=keepdim)
        cpu_out = torch.amin(cpu_input, dim=dim, keepdim=keepdim)

        # Use one random upstream gradient for both backends.
        upstream_cpu = torch.randn(cpu_out.shape)
        upstream_mps = upstream_cpu.to('mps')

        cpu_out.backward(gradient=upstream_cpu)
        mps_out.backward(gradient=upstream_mps)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(mps_input.grad, cpu_input.grad)

    for dims in ([], [0], [0, 1], [2, 3]):
        for keep in (False, True):
            helper((2, 8, 4, 5), dims, keep)
5258
Kulin Sethe011a8e2022-05-13 18:28:53 +00005259 # Test minimum and maximum
def test_minimum_maximum(self):
    """Compare element-wise ``torch.minimum`` and ``torch.maximum`` between CPU and MPS."""
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        cpu_y = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        mps_x = cpu_x.detach().clone().to('mps')
        mps_y = cpu_y.detach().clone().to('mps')

        # Same procedure for both binary ops: CPU first, then MPS.
        for binary_op in (torch.minimum, torch.maximum):
            expected = binary_op(cpu_x, cpu_y)
            actual = binary_op(mps_x, mps_y)
            self.assertEqual(expected, actual)

    helper(1, 1, 4, 5)
5276
5277 # Test clamp_min
def test_clamp_min(self):
    # torch.clamp_min on MPS must match CPU for a scalar and for a tensor lower bound.
    def check(n, c, h, w):
        ref = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        ref_floor = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        dev_floor = ref_floor.detach().clone().to('mps')

        # Scalar bound.
        # NOTE(review): 5.0 is above typical randn samples, so nearly every
        # element is expected to be clamped to the bound.
        dev_scalar = torch.clamp_min(dev, min=5.0)
        ref_scalar = torch.clamp_min(ref, min=5.0)
        self.assertEqual(dev_scalar, ref_scalar)

        # Element-wise tensor bound.
        dev_tensor = torch.clamp_min(dev, min=dev_floor)
        ref_tensor = torch.clamp_min(ref, min=ref_floor)
        self.assertEqual(dev_tensor, ref_tensor)

    check(2, 8, 4, 5)
5297
5298 # Test clamp_max
5299
def test_clamp_max(self):
    # torch.clamp_max on MPS must match CPU for a scalar and for a tensor upper bound.
    def check(n, c, h, w):
        ref = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        ref_ceiling = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        dev_ceiling = ref_ceiling.detach().clone().to('mps')

        # Scalar bound.
        # NOTE(review): 100.0 is far above typical randn samples, so this case
        # most likely leaves the input unchanged.
        dev_scalar = torch.clamp_max(dev, max=100.0)
        ref_scalar = torch.clamp_max(ref, max=100.0)
        self.assertEqual(dev_scalar, ref_scalar)

        # Element-wise tensor bound.
        dev_tensor = torch.clamp_max(dev, max=dev_ceiling)
        ref_tensor = torch.clamp_max(ref, max=ref_ceiling)
        self.assertEqual(dev_tensor, ref_tensor)

    check(2, 8, 4, 5)
5319
5320 # Test clamp
def test_clamp(self):
    # torch.clamp / Tensor.clamp_ on MPS must match CPU for scalar bounds,
    # tensor bounds, and the "only max given" variants of both.
    def check(n, c, h, w):
        import numpy as np

        limit = 1000
        midpoint = limit / 2
        dims = (n, c, h, w)

        def on_both_devices(array):
            # Build the CPU tensor and an independent MPS copy of it.
            cpu_t = torch.tensor(array, device='cpu', dtype=torch.float, requires_grad=False)
            return cpu_t, cpu_t.detach().clone().to('mps')

        # Input values in [0, 1000).
        cpu_x, x = on_both_devices(limit * np.random.random_sample(size=dims).astype(np.float32))

        # Lower bounds in [0, 500).
        cpu_lo, lo = on_both_devices(midpoint * np.random.random_sample(size=dims).astype(np.float32))

        # Upper bounds in [500, 1000), so every max is above every min.
        cpu_hi, hi = on_both_devices(
            (midpoint * np.random.random_sample(size=dims).astype(np.float32)) + midpoint)

        # Scalar bounds; [200, 600] is just an arbitrary window inside [0, 1000].
        both_scalars = torch.clamp(x, min=200.0, max=600.0)
        both_scalars_cpu = torch.clamp(cpu_x, min=200.0, max=600.0)
        self.assertEqual(both_scalars, both_scalars_cpu)

        # Only max given: covers optional scalar refs and cached graph keys.
        max_scalar_only = torch.clamp(x, max=600.0)
        max_scalar_only_cpu = torch.clamp(cpu_x, max=600.0)
        self.assertEqual(max_scalar_only, max_scalar_only_cpu)

        # Tensor bounds.
        both_tensors = torch.clamp(x, min=lo, max=hi)
        both_tensors_cpu = torch.clamp(cpu_x, min=cpu_lo, max=cpu_hi)
        self.assertEqual(both_tensors, both_tensors_cpu)

        # Only max given: covers optional tensor refs and cached graph keys.
        max_tensor_only = torch.clamp(x, max=hi)
        max_tensor_only_cpu = torch.clamp(cpu_x, max=cpu_hi)
        self.assertEqual(max_tensor_only, max_tensor_only_cpu)

        # In-place variant.
        x.clamp_(min=200.0, max=600.0)
        cpu_x.clamp_(min=200.0, max=600.0)
        self.assertEqual(cpu_x, x)

    check(2, 8, 4, 5)
5367
def test_divmode(self):
    # Compare torch.div (for every rounding mode) and torch.floor_divide on
    # MPS against CPU for float32, float16, int32 and int64 operands.
    #
    # helper(shape, rounding_mode):
    #   shape          -- shape of both operands
    #   rounding_mode  -- None, "floor" or "trunc" (forwarded to torch.div), or
    #                     the sentinel "floor_divide" which selects torch.floor_divide
    def helper(shape, rounding_mode):
        for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
            # Two combinations are left out of the comparison: any floor-style
            # mode ("floor", "floor_divide") with int64, and "trunc" with float16.
            # NOTE(review): presumably these are unsupported or mismatch on MPS - confirm.
            floor_like = rounding_mode is not None and "floor" in rounding_mode
            trunc_like = rounding_mode is not None and "trunc" in rounding_mode
            if (floor_like and dtype == torch.int64) or (trunc_like and dtype == torch.float16):
                continue

            if dtype in [torch.float32, torch.float16]:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            else:
                # Integers are drawn from [-10, 0), so the divisor is never zero.
                cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
                cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)

            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')

            if rounding_mode == "floor_divide":
                result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
                result_div_mps = torch.floor_divide(mps_x, mps_y)
            else:
                result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
                result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
            self.assertEqual(result_div_mps, result_div_cpu)

    helper((2, 8, 4, 5), None)
    helper((2, 8, 4, 5), "floor")
    helper((2, 8, 4, 5), "trunc")
    helper((2, 8, 4, 5), "floor_divide")
Kulin Sethe011a8e2022-05-13 18:28:53 +00005399
def test_rounding(self):
    # floor, ceil, trunc and round on MPS must match CPU.
    def check(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        for rounding_op in (torch.floor, torch.ceil, torch.trunc, torch.round):
            expected = rounding_op(ref)
            actual = rounding_op(dev)
            self.assertEqual(actual, expected)

    for shape in ((2, 6, 3, 5), (2, 8, 4, 5)):
        check(shape)
5423
def test_remainder(self):
    # torch.remainder on MPS must match CPU for an int32 tensor divisor and
    # for a negative Python float divisor.
    def int_tensor_case(device):
        dividend = torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device=device)
        divisor = torch.tensor(2, device=device, dtype=torch.int32)
        return torch.remainder(dividend, divisor)

    def float_scalar_case(device):
        dividend = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device=device)
        return torch.remainder(dividend, -1.5)

    for run_case in (int_tensor_case, float_scalar_case):
        expected = run_case("cpu")
        actual = run_case("mps")
        self.assertEqual(expected, actual)
5436
def test_expand(self):
    # An as_strided view with stride 0 on the last dim (each of the three
    # values repeated four times) must read the same on MPS as on CPU.
    column_cpu = torch.tensor([[1.0], [4.0], [7.0]], device='cpu')
    column_mps = column_cpu.detach().clone().to('mps')

    view_size = (3, 4)
    view_stride = (1, 0)
    expanded_cpu = torch.as_strided(column_cpu, view_size, view_stride)
    expanded_mps = torch.as_strided(column_mps, view_size, view_stride)

    self.assertEqual(expanded_mps, expanded_cpu)
5449
def test_im2col(self):
    # torch.nn.functional.unfold with a dilated, padded, strided window must
    # give the same result on MPS as on CPU.
    unfold_args = dict(kernel_size=(10, 15), dilation=2, padding=5, stride=3)

    image_cpu = torch.rand(1, 1, 200, 100)
    image_mps = image_cpu.detach().clone().to('mps')

    expected = torch.nn.functional.unfold(image_cpu, **unfold_args)
    actual = torch.nn.functional.unfold(image_mps, **unfold_args)
    self.assertEqual(expected, actual)
5456
def test_select(self):
    # Column-like, row-like and offset as_strided views must read the same
    # on MPS as on CPU.
    def check(n, c):
        ref = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
        dev = ref.detach().clone().to('mps').requires_grad_()

        # (size, stride, extra keyword arguments) for each view under test.
        view_specs = [
            ((3, 1), (3, 1), {}),
            ((1, 3), (3, 1), {}),
            ((3, 1), (3, 1), {'storage_offset': 1}),
        ]
        for size, stride, extra in view_specs:
            view_cpu = torch.as_strided(ref, size, stride, **extra)
            view_mps = torch.as_strided(dev, size, stride, **extra)
            self.assertEqual(view_mps, view_cpu)

    check(3, 3)
5476
def test_topk(self):
    # torch.topk on MPS must return the same values and indices as CPU for
    # every dim of the input, every k in 1..size(dim), and both the largest
    # and the smallest selection.
    #
    # helper(shape): shape is either a tuple or a bare int (1-D input).
    def helper(shape):
        # Treat a bare int as a 1-D shape so that both forms share one code
        # path and k runs up to and including the full dimension size.
        dims = shape if isinstance(shape, tuple) else (shape,)

        cpu_x = torch.randn(dims, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        for largest_val in [True, False]:
            for curr_dim, dim_size in enumerate(dims):
                for k in range(1, dim_size + 1):
                    topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest_val)
                    topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest_val)
                    self.assertEqual(topk_values, topk_values_cpu)
                    self.assertEqual(topk_indices, topk_indices_cpu)

    helper(2)
    helper((5, 1))
    helper((1, 5))
    helper((5, 9, 7, 4))
    helper((50, 20, 7, 4))
Kulin Sethe011a8e2022-05-13 18:28:53 +00005502
def test_sort(self):
    # On MPS, torch.sort must agree with its out= variant and with argsort,
    # for a short and a long row length; a fixed descending vector must
    # come back ascending.
    for row_len in (4, 2049):
        device = 'mps'
        x = torch.rand(4, row_len, device=device)
        sorted_vals, sorted_idx = torch.sort(x)

        # out= variant writing into initially empty tensors.
        out_vals = torch.tensor((), device=device)
        out_idx = torch.tensor((), device=device, dtype=torch.long)
        torch.sort(x, out=(out_vals, out_idx))
        self.assertEqual(sorted_vals, out_vals, atol=0, rtol=0)
        self.assertEqual(sorted_idx, out_idx, atol=0, rtol=0)

        # Function and method forms of argsort.
        self.assertEqual(torch.argsort(x), sorted_idx)
        self.assertEqual(x.argsort(), sorted_idx)

        descending = torch.tensor((50, 40, 30, 20, 10), device=device)
        ascending = torch.tensor((10, 20, 30, 40, 50), device=device)
        self.assertEqual(torch.sort(descending)[0], ascending, atol=0, rtol=0)
5522
def test_upsample_nearest2d(self):
    # nn.UpsamplingNearest2d on MPS must match CPU, forward and backward,
    # when configured by scale factor and by explicit output size, for
    # both channels_last and contiguous inputs.
    def check(N, C, H, W, memory_format):
        cpu_in = torch.arange(N * C * H * W, device='cpu', dtype=torch.float, requires_grad=True)
        cpu_in = cpu_in.reshape(N, C, H, W).to(memory_format=memory_format)
        # cpu_in is derived from the arange tensor, so its grad must be retained explicitly.
        cpu_in.retain_grad()
        mps_in = cpu_in.detach().to('mps').requires_grad_()

        factors = [1, 2, 5, 10, 40]
        for fh, fw in itertools.product(factors, repeat=2):
            by_scale = nn.UpsamplingNearest2d(scale_factor=(fh, fw))
            cpu_out = by_scale(cpu_in)
            mps_out = by_scale(mps_in)
            self.assertEqual(cpu_out, mps_out)

            by_size = nn.UpsamplingNearest2d((fh * H, fw * W))
            cpu_out = by_size(cpu_in)
            mps_out = by_size(mps_in)
            self.assertEqual(cpu_out, mps_out)

            # Backward through the size-based result only; gradients
            # accumulate identically on both inputs across iterations.
            cpu_out.backward(gradient=torch.full_like(cpu_out, 0.3))
            mps_out.backward(gradient=torch.full_like(mps_out, 0.3))
            self.assertEqual(cpu_in.grad, mps_in.grad)

    for memory_format in [torch.channels_last, torch.contiguous_format]:
        for dims in ((1, 1, 4, 4), (7, 5, 3, 2)):
            check(*dims, memory_format=memory_format)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005555
def test_upsample_bilinear2d(self):
    # nn.UpsamplingBilinear2d on MPS must match CPU, forward and backward,
    # when configured by scale factor and by explicit output size.
    def check(N, C, H, W):
        cpu_in = torch.arange(N * C * H * W, device='cpu', dtype=torch.float, requires_grad=True)
        cpu_in = cpu_in.reshape(N, C, H, W)
        # cpu_in is derived from the arange tensor, so its grad must be retained explicitly.
        cpu_in.retain_grad()
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        factors = [1, 2, 5, 10, 40]
        for fh, fw in itertools.product(factors, repeat=2):
            by_scale = nn.UpsamplingBilinear2d(scale_factor=(fh, fw))
            cpu_out = by_scale(cpu_in)
            mps_out = by_scale(mps_in)
            self.assertEqual(cpu_out, mps_out)

            by_size = nn.UpsamplingBilinear2d((fh * H, fw * W))
            cpu_out = by_size(cpu_in)
            mps_out = by_size(mps_in)
            self.assertEqual(cpu_out, mps_out)

            # Backward through the size-based result only; gradients
            # accumulate identically on both inputs across iterations.
            cpu_out.backward(gradient=torch.full_like(cpu_out, 0.3))
            mps_out.backward(gradient=torch.full_like(mps_out, 0.3))
            self.assertEqual(cpu_in.grad, mps_in.grad)

    for dims in ((1, 1, 4, 4), (7, 5, 3, 2)):
        check(*dims)
5588
def test_interpolate(self):
    # nn.functional.interpolate on MPS must match CPU, forward and backward,
    # for 1D and 2D inputs, sized and scaled, in every mode listed below.
    def check(shape, output_size, scales, mode, align_corners=False):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        cpu_in.retain_grad()
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        # Assemble the keyword set once and reuse it for both devices.
        # A scale factor takes precedence over an explicit output size.
        interp_args = {'mode': mode}
        if scales is not None:
            interp_args['scale_factor'] = scales
        else:
            interp_args['size'] = output_size
        # align_corners is forwarded for bilinear 2D interpolation only.
        if align_corners is True and len(shape) > 3 and mode == 'bilinear':
            interp_args['align_corners'] = align_corners

        cpu_out = nn.functional.interpolate(cpu_in, **interp_args)
        mps_out = nn.functional.interpolate(mps_in, **interp_args)
        self.assertEqual(cpu_out, mps_out)

        # Backward pass; 0.6 is used only so that grad_output != 1.
        cpu_out.backward(gradient=torch.full_like(cpu_out, 0.6))
        mps_out.backward(gradient=torch.full_like(mps_out, 0.6))
        self.assertEqual(cpu_in.grad, mps_in.grad)

    # (output_size, scales) pairs: downsample/upsample by size, then by scale factor.
    cases_1d = [([3], None), ([6], None), (None, [0.6]), (None, [1.7])]
    cases_2d = [([3, 4], None), ([6, 7], None), (None, [0.6, 0.7]), (None, [1.4, 1.7])]

    # 1D interpolation
    for mode in ['nearest', 'nearest-exact']:
        for output_size, scales in cases_1d:
            check([2, 3, 4], output_size, scales, mode)
    # 2D interpolation
    for mode in ['nearest', 'nearest-exact', 'bilinear']:
        for output_size, scales in cases_2d:
            check([2, 3, 4, 5], output_size, scales, mode)
    # align_corners=True
    check([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
    check([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)
Kulin Seth067c8062022-07-13 21:39:50 +00005632
Kulin Sethe011a8e2022-05-13 18:28:53 +00005633 # Test concat forward
def test_cat1(self):
    # torch.cat along dim 1 of three tensors on MPS must match CPU,
    # including inputs that are empty in one or more dimensions.
    def check(shape_x, shape_y, shape_z):
        cpu_parts = []
        mps_parts = []
        for part_shape in (shape_x, shape_y, shape_z):
            cpu_part = torch.randn(part_shape, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_parts.append(cpu_part)
            mps_parts.append(cpu_part.detach().clone().to('mps'))

        joined_mps = torch.cat(mps_parts, dim=1)
        joined_cpu = torch.cat(cpu_parts, dim=1)
        self.assertEqual(joined_mps, joined_cpu)

    shape_triples = [
        ([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5]),
        ([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]),
        # zero-sized leading dim on every input
        ([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5]),
        # a 1-D empty tensor in each of the three positions
        ([2, 2, 6, 5], [0], [2, 5, 6, 5]),
        ([0], [2, 3, 6, 5], [2, 5, 6, 5]),
        ([2, 3, 4, 5], [2, 5, 4, 5], [0]),
        # inputs that are empty along the concatenation dim
        ([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5]),
        ([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]),
        ([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5]),
    ]
    for triple in shape_triples:
        check(*triple)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005659
def test_constant_pad(self):
    # Constant padding on MPS must match CPU for negative padding amounts
    # and for an input with many dimensions.
    negative_pad = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
    cpu_in = torch.randn(1, 16, 16, 16)
    mps_in = cpu_in.detach().clone().to("mps")
    cpu_out = negative_pad(cpu_in)
    mps_out = negative_pad(mps_in)
    self.assertEqual(cpu_out, mps_out.to("cpu"))

    # Arbitrary input dimensions
    pad_args = dict(pad=(1, 1, 0, 0, 0, 0), value=3.5)
    cpu_in = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
    mps_in = cpu_in.detach().clone().to("mps")
    cpu_out = F.pad(cpu_in, **pad_args)
    mps_out = F.pad(mps_in, **pad_args)
    self.assertEqual(cpu_out, mps_out.to("cpu"))
5676
def test_circular_pad(self):
    # Circular padding followed by conv2d must give the same result on MPS as on CPU.
    # https://github.com/pytorch/pytorch/issues/80856
    kernel_cpu = torch.ones(3, 3, 9, 9)
    kernel_mps = kernel_cpu.detach().clone().to("mps")

    image_cpu = torch.rand(1, 3, 32, 32)
    image_mps = image_cpu.detach().clone().to("mps")

    def pad_then_conv(image, kernel):
        wrapped = F.pad(image, (2, 2, 2, 2), mode='circular')
        return F.conv2d(wrapped, kernel)

    conv_cpu = pad_then_conv(image_cpu, kernel_cpu)
    conv_mps = pad_then_conv(image_mps, kernel_mps)

    self.assertEqual(conv_cpu, conv_mps.cpu())
5692
def test_constant_pad_4d_warning(self):
    # F.pad with an 8-element padding list on a 6-D input must match CPU.
    cpu_in = torch.rand((1, 2, 2, 2, 1, 1))
    mps_in = cpu_in.detach().clone().to('mps')

    padding = [0, 0, 0, 0, 0, 0, 1, 0]
    cpu_out = F.pad(cpu_in, padding)
    mps_out = F.pad(mps_in, padding)
    self.assertEqual(cpu_out, mps_out)
5699
def test_pad(self):
    # Reflection, replication and constant padding modules (1D/2D/3D) on MPS
    # must match CPU, forward and backward.
    constant_pad_modules = [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]

    def check(shape, padding, op, value=0):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        cpu_in.retain_grad()
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        # Only the constant-pad modules take a fill value.
        pad_module = op(padding, value) if op in constant_pad_modules else op(padding)
        cpu_out = pad_module(cpu_in)
        mps_out = pad_module(mps_in)
        self.assertEqual(cpu_out, mps_out)

        # Backward pass; 0.6 is used only so that grad_output != 1.
        cpu_out.backward(gradient=torch.full_like(cpu_out, 0.6))
        mps_out.backward(gradient=torch.full_like(mps_out, 0.6))
        self.assertEqual(cpu_in.grad, mps_in.grad)

    cases = [
        # ---- 1D padding ----
        ((2, 4, 3), 2, nn.ReflectionPad1d),
        # a change in input shape must not break graph caching
        ((2, 4, 4), (1, 3), nn.ReflectionPad1d),
        # Replication 1D
        ((2, 1, 6), 3, nn.ReplicationPad1d),
        # Constant Pad 1D
        ((2, 3, 4), 2, nn.ConstantPad1d),
        # Constant Pad 1D with single dimension input
        (16, (1, 2), nn.ConstantPad1d),

        # ---- 2D padding ----
        ((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d),
        # a change in input shape must not break graph caching
        ((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d),
        # this should make the padding (2, 2, 2, 2)
        ((2, 1, 6, 8), 2, nn.ReplicationPad2d),
        # a change in padding shape must not break graph caching
        ((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d),
        # Constant Pad 2D
        ((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d),
        # input size < pad size
        ((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d),
        # pad dims < input dims
        ((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d),
        # pad dims == input dims
        ((1, 3), (0, 2, 0, 1), nn.ConstantPad2d),
        # input.numel() == 0 but output.numel() > 0
        ((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d),
        # pad dims < input dims - 2
        ((1, 2, 3, 4), (1, 2), nn.ConstantPad2d),

        # ---- 3D padding ----
        ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d),
        # a change in padding module must not break graph caching
        ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d),
        # case where input_d == pad_front/back for ReplicationPad3d
        ((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d),
        # Constant Pad 3D
        ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d),
        # input size < pad size
        ((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d),
        # check the workaround for the right padding bug in Monterey
        ((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d),
    ]
    for shape, padding, op in cases:
        check(shape, padding, op)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005763
5764 # Test stack forward
def test_stack(self):
    # torch.stack along dim 1 of three equally shaped tensors on MPS must
    # match CPU for float32, float16, int32, int64 and bool inputs.
    def check(shape, dtype=torch.float32):
        def make_pair():
            # Returns (cpu_tensor, mps_copy) filled according to dtype:
            # float32 -> randn with requires_grad, bool -> values in {0, 1},
            # every other dtype -> integers in [0, 50).
            if dtype == torch.float32:
                cpu_t = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                return cpu_t, cpu_t.detach().clone().to('mps').requires_grad_()
            upper = 2 if dtype == torch.bool else 50
            cpu_t = torch.randint(upper, shape, device='cpu', dtype=dtype, requires_grad=False)
            return cpu_t, cpu_t.detach().clone().to('mps')

        cpu_x, x = make_pair()
        cpu_y, y = make_pair()
        cpu_z, z = make_pair()

        stacked_mps = torch.stack([x, y, z], dim=1)
        stacked_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
        self.assertEqual(stacked_mps, stacked_cpu)

    check([2, 8, 4, 5])
    for dtype in (torch.float16, torch.int32, torch.int64, torch.bool):
        check([2, 8, 4, 5], dtype=dtype)
    # Empty test - Currently failing! Empty tensor not handled!
    # check([0, 2, 4, 5])
5807
5808 # Test abs
def test_abs(self):
    # torch.abs on MPS must match CPU.
    def check(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.abs(dev)
        expected = torch.abs(ref)
        self.assertEqual(actual, expected)

    check((2, 8, 4, 5))
5820
def test_log(self):
    # torch.log on MPS must match CPU.
    # NOTE(review): randn yields negative samples, for which log is NaN; this
    # presumably relies on assertEqual treating matching NaNs as equal - confirm.
    def check(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.log(dev)
        expected = torch.log(ref)
        self.assertEqual(actual, expected)

    check((2, 8, 4, 5))
5832
def test_log_ten(self):
    # torch.log10 on MPS must match CPU.
    # NOTE(review): randn yields negative samples, for which log10 is NaN; this
    # presumably relies on assertEqual treating matching NaNs as equal - confirm.
    def check(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.log10(dev)
        expected = torch.log10(ref)
        self.assertEqual(actual, expected)

    check((2, 8, 4, 5))
5844
def test_log_two(self):
    # torch.log2 on MPS must match CPU.
    # NOTE(review): randn yields negative samples, for which log2 is NaN; this
    # presumably relies on assertEqual treating matching NaNs as equal - confirm.
    def check(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.log2(dev)
        expected = torch.log2(ref)
        self.assertEqual(actual, expected)

    check((2, 8, 4, 5))
5856
def test_log1p(self):
    # torch.log1p on MPS must match CPU.
    # NOTE(review): randn can yield samples below -1, for which log1p is NaN; this
    # presumably relies on assertEqual treating matching NaNs as equal - confirm.
    def check(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        actual = torch.log1p(dev)
        expected = torch.log1p(ref)
        self.assertEqual(actual, expected)

    check((2, 8, 4, 5))
5868
def test_logaddexp(self):
    # torch.logaddexp on MPS must match CPU.
    def check(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')

        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        actual = torch.logaddexp(lhs_mps, rhs_mps)
        expected = torch.logaddexp(lhs_cpu, rhs_cpu)
        self.assertEqual(actual, expected)

    check((2, 8, 4, 5))
5883
def test_logaddexp2(self):
    # torch.logaddexp2 on MPS must match CPU.
    def check(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')

        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        actual = torch.logaddexp2(lhs_mps, rhs_mps)
        expected = torch.logaddexp2(lhs_cpu, rhs_cpu)
        self.assertEqual(actual, expected)

    check((2, 8, 4, 5))
5898
5899 # Test concat forward
def test_cat2(self):
    # torch.cat along dim 1 on MPS must match CPU for three and for four
    # float inputs that differ only in the size of the concatenated dim.
    #
    # helper(*shapes): one shape per tensor to concatenate.
    def helper(*shapes):
        cpu_tensors = [
            torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            for shape in shapes
        ]
        mps_tensors = [t.detach().clone().to('mps') for t in cpu_tensors]

        cat = torch.cat(mps_tensors, dim=1)
        cat_cpu = torch.cat(cpu_tensors, dim=1)

        self.assertEqual(cat, cat_cpu)

    helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
    helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
    # Four-way concatenation.
    helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5], [2, 3, 4, 5])
    # Empty test - Currently failing! Empty tensor not handled!
    # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
5939
5940 # Test isnan
def test_isnan(self):
    # torch.isnan on MPS must match CPU for an input with NaN entries.
    def check(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        # Overwrite one randomly chosen slice along dim 0 with NaN.
        nan_row = [random.randrange(0, shape[0])]
        ref.index_put_(indices=[torch.tensor(nan_row)], values=torch.tensor(float('nan')))
        dev = ref.detach().clone().to('mps')

        actual = torch.isnan(dev)
        expected = torch.isnan(ref)
        self.assertEqual(actual, expected)

    check((8, 2, 4, 5))
5955
5956 # Test reciprocal
def test_reciprocal(self):
    # torch.reciprocal on MPS must match CPU, forward and backward.
    def check(shape):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = ref_in.detach().clone().to('mps').requires_grad_()

        mps_out = torch.reciprocal(mps_in)
        ref_out = torch.reciprocal(ref_in)

        # An all-ones upstream gradient on both devices.
        ref_grad = torch.ones_like(ref_out)
        mps_grad = ref_grad.to('mps')

        mps_out.backward(gradient=mps_grad)
        ref_out.backward(gradient=ref_grad)

        self.assertEqual(mps_out, ref_out)
        self.assertEqual(mps_in.grad, ref_in.grad)

    check((2, 8, 4, 5))
5975
5976 # Test sqrt
def test_sqrt(self):
    # torch.sqrt on MPS must match CPU, forward and backward.
    # NOTE(review): randn yields negative samples, for which sqrt is NaN; this
    # presumably relies on assertEqual treating matching NaNs as equal - confirm.
    def check(shape):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = ref_in.detach().clone().to('mps').requires_grad_()

        mps_out = torch.sqrt(mps_in)
        ref_out = torch.sqrt(ref_in)

        # An all-ones upstream gradient on both devices.
        ref_grad = torch.ones_like(ref_out)
        mps_grad = ref_grad.to('mps')

        mps_out.backward(gradient=mps_grad)
        ref_out.backward(gradient=ref_grad)

        self.assertEqual(mps_out, ref_out)
        self.assertEqual(mps_in.grad, ref_in.grad)

    check((2, 8, 4, 5))
5995
5996 # Test selu, elu, celu
def test_elu(self):
    # ELU, CELU and SELU: forward and backward on MPS must match CPU for
    # several alpha values, in both channels-last and contiguous layouts.
    def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
        cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()

        x = cpu_x.detach().clone().to('mps').requires_grad_(True)
        activations = [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]
        for activation in activations:
            out_mps = activation(x)
            out_cpu = activation(cpu_x)

            upstream_cpu = torch.randn(out_cpu.shape)
            upstream_mps = upstream_cpu.to('mps')

            # Gradients accumulate across activations on both devices alike,
            # so the running totals stay comparable.
            out_mps.backward(gradient=upstream_mps)
            out_cpu.backward(gradient=upstream_cpu)

            self.assertEqual(out_mps, out_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

    memory_formats = [torch.channels_last, torch.contiguous_format]
    shapes = [(2, 8, 4, 5)]
    alphas = [0.000001, 1.0, 2.3, 0.34, 23]
    for memory_format, shape, alpha in itertools.product(memory_formats, shapes, alphas):
        helper(shape, alpha, memory_format)
Kulin Setha6347f52022-06-07 18:22:10 +00006021
qqaatwc980fc32022-06-30 08:58:42 +00006022 # Test glu
def test_glu(self):
    # GLU forward and backward on MPS must match CPU along every dimension
    # of each tested shape.
    def helper(shape, dim=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        glu = torch.nn.GLU(dim=dim)
        out_mps = glu(x)
        out_cpu = glu(cpu_x)

        upstream_cpu = torch.randn(out_cpu.shape)
        upstream_mps = upstream_cpu.to('mps')

        out_mps.backward(gradient=upstream_mps)
        out_cpu.backward(gradient=upstream_cpu)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for shape in [[4], (2, 4), (2, 8, 4, 6)]:
        for dim, _ in enumerate(shape):
            helper(shape, dim)
6044
6045 # Test softplus
def test_softplus(self):
    # Softplus forward and backward on MPS must match CPU over a grid of
    # beta / threshold values; the shapes include a 0-dim (scalar) tensor.
    def helper(shape, beta=1, threshold=20):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        softplus = torch.nn.Softplus(beta=beta, threshold=threshold)
        out_mps = softplus(x)
        out_cpu = softplus(cpu_x)

        upstream_cpu = torch.randn(out_mps.shape)
        upstream_mps = upstream_cpu.to('mps')

        out_mps.backward(gradient=upstream_mps)
        out_cpu.backward(gradient=upstream_cpu)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    shapes = [(), (2, 3), (10, 10), (2, 3, 4, 5)]
    betas = [0.5, 1, 2, 3, 4]
    thresholds = [0.5, 20, 30, 40, 50]
    for shape, beta, threshold in itertools.product(shapes, betas, thresholds):
        helper(shape, beta, threshold)
Kulin Setha6347f52022-06-07 18:22:10 +00006068
Kulin Sethe011a8e2022-05-13 18:28:53 +00006069 # Test silu
6070
def test_silu(self):
    # SiLU forward and backward on MPS must match CPU; the shapes include
    # a 0-dim (scalar) tensor.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        silu = torch.nn.SiLU()
        out_mps = silu(x)
        out_cpu = silu(cpu_x)

        upstream_cpu = torch.randn(out_cpu.shape)
        upstream_mps = upstream_cpu.to('mps')

        out_mps.backward(gradient=upstream_mps)
        out_cpu.backward(gradient=upstream_cpu)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    shapes = [[], (2, 3), (2, 8, 4, 5)]
    for shape in shapes:
        helper(shape)
6091
def test_cast_mps_to_cpu(self):
    # Copying a tensor MPS -> CPU while also changing its dtype must give
    # the same values as converting the original CPU tensor directly.
    def helper(src_dtype, dst_dtype):
        source = torch.rand((1, 3, 128, 128), dtype=src_dtype)
        on_mps = source.to('mps')
        back_on_cpu = on_mps.to('cpu', dtype=dst_dtype)

        # needs to match the initial Tensor
        expected = source.to(dtype=dst_dtype)
        self.assertEqual(back_on_cpu, expected)

    for src_dtype, dst_dtype in [(torch.half, torch.float), (torch.float, torch.half)]:
        helper(src_dtype, dst_dtype)
6102
def test_cast_mps_to_mps(self):
    # A dtype conversion performed on the MPS device must produce the same
    # values as the same conversion performed on CPU.
    def helper(src_dtype, dst_dtype):
        source_cpu = torch.rand((1, 3, 128, 128), dtype=src_dtype)
        source_mps = source_cpu.to('mps')
        converted_mps = source_mps.to(dtype=dst_dtype)
        converted_cpu = source_cpu.to(dtype=dst_dtype)
        self.assertEqual(converted_mps.cpu(), converted_cpu)

    dtype_pairs = [
        (torch.half, torch.float),
        (torch.float, torch.half),
        (torch.half, torch.long),
        (torch.float, torch.int),
    ]
    for src_dtype, dst_dtype in dtype_pairs:
        helper(src_dtype, dst_dtype)
6114
def test_avg_pool2d_count_include_pad(self):
    # AvgPool2d with padding, ceil_mode=True and count_include_pad=True:
    # output and input gradient on MPS must match CPU.
    cpu_x = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True)
    x = cpu_x.detach().clone().to('mps').requires_grad_()

    pool = torch.nn.AvgPool2d(
        kernel_size=(3, 3),
        padding=(1, 1),
        stride=(1, 1),
        ceil_mode=True,
        count_include_pad=True,
    )
    expected = pool(cpu_x)
    actual = pool(x)
    self.assertEqual(actual, expected)

    upstream_cpu = torch.randn(expected.shape)
    upstream_mps = upstream_cpu.to('mps')
    expected.backward(gradient=upstream_cpu)
    actual.backward(gradient=upstream_mps)
    self.assertEqual(x.grad, cpu_x.grad)
6127
Kulin Sethe011a8e2022-05-13 18:28:53 +00006128 # Test adaptive avg pool2d - when the input size is a multiple of output size
6129 # Not testing for channels last right now
def test_adaptive_avg_pool2d_simple(self):
    # AdaptiveAvgPool2d forward and backward on MPS must match CPU, for
    # output sizes both smaller and larger than the input size.
    def helper(input_shape, out_shape, channels_last):
        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            # The converted tensor may no longer be a leaf; keep its grad.
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        avg_result = torch.nn.AdaptiveAvgPool2d(out_shape)(x)
        avg_result_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x)

        cpu_grad = torch.randn(avg_result_cpu.shape)
        grad = cpu_grad.to('mps')

        avg_result.backward(gradient=grad)
        avg_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(avg_result, avg_result_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    # Input size is a multiple of the output size
    helper((2, 2, 4, 4), (2, 2), False)
    helper((2, 2, 9, 9), (3, 3), False)
    helper((2, 2, 9, 9), (9, 9), False)
    helper((2, 2, 16, 16), (2, 2), False)
    helper((2, 2, 16, 16), (2, 16), False)

    helper((2, 16, 16), (4, 4), False)

    # Output shape larger than input shape
    helper((2, 2, 4, 4), (8, 8), False)
    helper((2, 2, 2, 2), (4, 4), False)
    helper((2, 2, 3, 3), (9, 9), False)
    helper((2, 2, 2, 2), (16, 16), False)
    helper((2, 2, 2, 16), (16, 16), False)

    helper((2, 4, 4), (16, 16), False)

    # Output size (7) is not a multiple of the input size (3).
    # NOTE(review): presumably the MPS backend rejects this configuration
    # with a RuntimeError - confirm. Only that error is tolerated; a result
    # mismatch (AssertionError) is no longer silently swallowed.
    try:
        helper((2, 2, 3, 3), (7, 7), False)
    except RuntimeError:
        pass
6172
# Test adaptive max pool2d - when the input size is a multiple of output size
6174 # Not testing for channels last right now
def test_adaptive_max_pool2d_simple(self):
    # AdaptiveMaxPool2d forward (values and, optionally, indices) and
    # backward on MPS must match CPU.
    def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
        if dtype in (torch.float16, torch.float32):
            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
        else:
            cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        pool = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)
        pooled_mps = pool(x)
        pooled_cpu = pool(cpu_x)

        # With return_indices the module yields a (values, indices) pair.
        max_indices, max_indices_cpu = None, None
        if return_indices:
            max_result, max_indices = pooled_mps
            max_result_cpu, max_indices_cpu = pooled_cpu
        else:
            max_result = pooled_mps
            max_result_cpu = pooled_cpu

        upstream_cpu = torch.randn(max_result_cpu.shape)
        upstream_mps = upstream_cpu.to('mps')

        max_result.backward(gradient=upstream_mps)
        max_result_cpu.backward(gradient=upstream_cpu)

        self.assertEqual(max_result, max_result_cpu)
        if return_indices:
            self.assertEqual(max_indices, max_indices_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    cases = [
        ((2, 2, 4, 4), (2, 2)),
        ((2, 2, 9, 9), (3, 3)),
        ((2, 2, 9, 9), (9, 9)),
        ((2, 2, 16, 16), (2, 2)),
        ((2, 2, 16, 16), (2, 16)),
        ((2, 16, 16), (4, 4)),
    ]
    for dtype in [torch.float32]:
        for return_indices in [False, True]:
            for input_shape, out_shape in cases:
                helper(input_shape, out_shape, return_indices, dtype)
6216
def test_gelu_simple(self):
    # GELU forward and backward on MPS must match CPU for float and half
    # inputs (including empty and 0-dim shapes), and integral inputs must
    # be rejected.
    def helper(shape, dtype=torch.float):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        gelu = torch.nn.GELU()
        out_mps = gelu(x)
        # GELU is not supported on CPU, so cast it to float
        out_cpu = gelu(cpu_x.to(torch.float))

        ones_cpu = torch.ones_like(out_cpu)
        ones_mps = ones_cpu.to('mps')

        out_mps.backward(gradient=ones_mps)
        out_cpu.backward(gradient=ones_cpu)

        # Looser tolerances for anything other than float32.
        if dtype == torch.float:
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 1e-2, 1e-2
        self.assertEqual(out_mps, out_cpu.to(dtype), atol=atol, rtol=rtol)
        self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

    shapes = [(0, 3), [], (2, 3), (2, 8, 4, 5)]
    for dtype, shape in itertools.product([torch.float, torch.half], shapes):
        helper(shape, dtype)

    # Test that gelu would raise an assert for integral types
    for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        with self.assertRaises(RuntimeError):
            torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps"))
Kulin Sethe011a8e2022-05-13 18:28:53 +00006244
def test_gelu(self):
    # Check gelu on CPU and MPS against the closed-form reference
    # x * Phi(x), for both contiguous and strided (every other column)
    # inputs, and once more with a non-default intra-op thread count.
    def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
        # dtype in which the reference value is evaluated on the CPU
        ref_dtype = {
            torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double
        }[dtype]
        devices = ['cpu', 'mps']

        def _gelu_ref(X):
            # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF.
            return X * 0.5 * (1.0 + torch.erf(X / math.sqrt(2.0)))

        for d in devices:
            X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)
            if not contiguous:
                # Strided view over every other column
                X = X[:, ::2]
            # Previously `res` was X itself, so GELU was never exercised.
            res = torch.nn.functional.gelu(X)
            ref = _gelu_ref(X.detach().cpu().to(ref_dtype))
            self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)

    for n in [1, 5, 10]:
        for m in [1, 5, 10]:
            _test_gelu(n, m, torch.float32, True)
            _test_gelu(n, m, torch.float32, False)

    # Test multi threaded
    num_threads = torch.get_num_threads()
    torch.set_num_threads(4)
    try:
        _test_gelu(32, 32, torch.float32, False)
    finally:
        # Always restore the thread count for the tests that follow.
        torch.set_num_threads(num_threads)
6274
def test_gelu_tanh(self):
    # gelu with approximate='tanh' on MPS must match the CPU result.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
        x = cpu_x.detach().clone().to('mps')

        out_mps = torch.nn.functional.gelu(x, approximate='tanh')
        out_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
        self.assertEqual(out_mps, out_cpu)

    helper((2, 8, 4, 5))
6285
Kulin Sethe011a8e2022-05-13 18:28:53 +00006286 # Test hardtanh
def test_hardtanh(self):
    # Hardtanh on MPS must match CPU, out-of-place (forward and backward)
    # and in-place (forward only); the shapes include empty and 0-dim ones.
    def helper(shape, min_val, max_val, inplace=False):
        # The in-place variant runs on tensors that do not require grad.
        track_grad = not inplace
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=track_grad)
        x = cpu_x.detach().clone().to('mps')
        if track_grad:
            x = x.requires_grad_()

        out_mps = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x)
        out_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x)

        self.assertEqual(out_mps, out_cpu)

        if not track_grad:
            return

        upstream_cpu = torch.randn(out_cpu.shape)
        upstream_mps = upstream_cpu.to('mps')
        out_mps.backward(gradient=upstream_mps)
        out_cpu.backward(gradient=upstream_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    bounds = [(-1, 1), (-2, -1), (3, 4)]
    for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
        for min_val, max_val in bounds:
            helper(shape, min_val, max_val)
            helper(shape, min_val, max_val, inplace=True)
6316
def test_hardswish(self):
    # Hardswish on MPS must match CPU for every combination of `inplace`
    # and `requires_grad`; in-place on a tensor that requires grad must
    # raise on both devices.
    def helper(shape, inplace=False, requires_grad=True):
        m = nn.Hardswish(inplace=inplace)

        input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

        if inplace and requires_grad:
            # check that both raise runtime error
            with self.assertRaises(RuntimeError):
                m(input_cpu)
            with self.assertRaises(RuntimeError):
                m(input_mps)
            return

        output_cpu = m(input_cpu)
        output_mps = m(input_mps)

        ones_cpu = torch.ones_like(output_cpu)
        ones_mps = ones_cpu.to('mps')

        self.assertEqual(output_cpu, output_mps)

        if not requires_grad:
            return

        output_cpu.backward(gradient=ones_cpu)
        output_mps.backward(gradient=ones_mps)

        self.assertEqual(input_cpu.grad, input_mps.grad)

    for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
        for requires_grad in (False, True):
            for inplace in (False, True):
                helper(shape, inplace=inplace, requires_grad=requires_grad)
6348
def test_transpose_2D(self):
    # Transposing a 3x3 matrix on MPS must match the CPU result.
    # (The unused all-ones tensor the test used to allocate on MPS was
    # dead code and has been removed.)
    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    cpu_transpose = torch.transpose(cpu_x, 0, 1)
    mps_transpose = torch.transpose(mps_x, 0, 1)
    self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))
6359
def test_transpose_3D(self):
    # Every pairwise transpose of a 2x2x3 tensor on MPS must match CPU.
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    for dim0, dim1 in [(0, 1), (0, 2), (1, 2)]:
        expected = torch.transpose(cpu_x, dim0, dim1)
        actual = torch.transpose(mps_x, dim0, dim1).to('cpu')
        self.assertEqual(expected, actual)
6376
6377
def test_transpose_4D(self):
    # Transposes of a 2x2x2x3 tensor on MPS must match CPU for each of the
    # dimension pairs listed below.
    values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
              [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    dim_pairs = [(0, 1), (0, 2), (0, 3), (3, 1), (3, 2), (1, 2)]
    for dim0, dim1 in dim_pairs:
        expected = torch.transpose(cpu_x, dim0, dim1)
        actual = torch.transpose(mps_x, dim0, dim1).to('cpu')
        self.assertEqual(expected, actual)
6407
Kulin Sethe011a8e2022-05-13 18:28:53 +00006408 # Test sign
def test_sign(self):
    # Forward and backward of torch.sign on MPS must match CPU.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        sign_result = torch.sign(x)
        sign_result_cpu = torch.sign(cpu_x)

        cpu_grad = torch.ones_like(sign_result_cpu)
        grad = cpu_grad.to('mps')

        sign_result.backward(gradient=grad)
        sign_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(sign_result, sign_result_cpu)
        # The backward pass was executed but its result was never checked;
        # compare the input gradients as the sibling unary-op tests do.
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
6426
def test_signbit(self):
    # torch.signbit on MPS must match CPU for int, float and int64 inputs.
    def helper(shape, dtype):
        # Draw from a normal distribution, then convert to the target dtype.
        cpu_x = torch.randn(shape, device='cpu').to(dtype)
        x = cpu_x.clone().to('mps')

        bits_mps = torch.signbit(x)
        bits_cpu = torch.signbit(cpu_x)

        self.assertEqual(bits_mps, bits_cpu)

    for dtype in (torch.int, torch.float, torch.int64):
        helper((2, 8, 4, 5), dtype)
6440
Kulin Sethe011a8e2022-05-13 18:28:53 +00006441 # Test neg
def test_neg(self):
    # Forward and backward of torch.neg on MPS must match CPU.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        neg_result = torch.neg(x)
        neg_result_cpu = torch.neg(cpu_x)

        cpu_grad = torch.ones_like(neg_result_cpu)
        grad = cpu_grad.to('mps')

        neg_result.backward(gradient=grad)
        neg_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(neg_result, neg_result_cpu)
        # The backward pass was executed but its result was never checked;
        # compare the input gradients as the sibling unary-op tests do.
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
6459
qqaatw1caa25e2022-07-14 23:40:00 +00006460 # Test index add
def test_index_add(self):
    # torch.index_add on MPS must match CPU for several dims, index lists,
    # alpha values and (for one case) float16 data.
    def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):
        cpu_x = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False)
        source = cpu_source.detach().clone().to('mps')

        actual = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha)
        expected = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha)
        self.assertEqual(actual, expected)

    # (shape, dim, index, source_shape, alpha)
    cases = [
        ((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5),
        ((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0),
        ((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5),
        ((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0),
        ((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4),
        ((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0),
        # test result dim=1
        ((2,), 0, [1], (1,), 6.0),
        (2, 0, 1, 1, 6),
    ]
    for case in cases:
        helper(*case)
    # test float16
    helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16)
qqaatw1caa25e2022-07-14 23:40:00 +00006487
qqaatwc4da23e2022-06-28 19:51:43 +00006488 # Test flip
def test_flip(self):
    # torch.flip on MPS must match CPU, including the degenerate cases
    # listed below.
    def helper(shape, dims):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        flipped_mps = torch.flip(x, dims=dims)
        flipped_cpu = torch.flip(cpu_x, dims=dims)

        self.assertEqual(flipped_mps, flipped_cpu)

    # (shape, dims)
    cases = [
        ((2, 8, 4, 5), [0]),
        ((8, 8, 4, 5), [0, 1]),
        ((2, 8, 4, 5), (0, 1, 2, 3)),
        ((2, 3, 3), (-1,)),
        # empty dims
        ((2, 8, 4, 5), []),
        # input.numel() == 1
        ((1,), (0,)),
        # input.numel() == 0
        ((0,), (0,)),
        # none of dims that needs to be flipped
        ((1, 3), [0]),
    ]
    for shape, dims in cases:
        helper(shape, dims)
qqaatwc4da23e2022-06-28 19:51:43 +00006511
Kulin Sethe011a8e2022-05-13 18:28:53 +00006512 # Test index select
def test_index_select(self):
    # torch.index_select on MPS must match CPU, including a 0-dim input
    # and an empty index list.
    def helper(shape, dim, index, idx_dtype=torch.int32):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        selected_mps = torch.index_select(x, dim=dim, index=idx)
        selected_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)

        self.assertEqual(selected_mps, selected_cpu)

    # (shape, dim, index)
    cases = [
        ((2, 8, 4, 5), 0, [1]),
        ((8, 8, 4, 5), 0, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 1, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 2, [3, 0, 1]),
        ((2, 8, 4, 5), 3, [2, 3, 0]),
        ((2, 3, 3), -1, [1, 2]),
        ((), 0, [0]),
        (5, 0, []),
    ]
    for shape, dim, index in cases:
        helper(shape, dim, index)
Li-Huai (Allan) Linccbdf492023-01-19 14:08:02 +00006534
def test_index_select_scalar(self):
    # torch.index_select on a 0-dim tensor: a single index must match CPU,
    # and an empty index list must raise.
    def helper(value, dim, index, idx_dtype=torch.int32):
        cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        selected_mps = torch.index_select(x, dim=dim, index=idx)
        selected_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)

        self.assertEqual(selected_mps, selected_cpu)

    helper(22, 0, [0])
    expected_message = "Index to scalar can have only 1 value"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        helper(22, 0, [])
Kulin Sethe011a8e2022-05-13 18:28:53 +00006551
def test_embedding_dense_backward(self):
    # Gradients flowing through an nn.Embedding built with max_norm on MPS
    # must match the same computation on CPU started from the same weights.
    def helper(n, d, m, idx):
        def forward_backward(embedding, proj, indices):
            # weight must be cloned for this to be differentiable
            a = embedding.weight.clone() @ proj.t()
            a.retain_grad()
            # modifies weight in-place
            b = embedding(indices) @ proj.t()
            b.retain_grad()
            out = a.unsqueeze(0) + b
            loss = out.sigmoid().prod()
            loss.backward()
            return a, b

        mps_embedding = nn.Embedding(n, d, max_norm=True, device='mps')
        # Snapshot the initial weights before the MPS lookup touches them,
        # so the CPU module starts from identical values.
        initial_weight = mps_embedding.weight.detach().cpu()
        mps_proj = torch.randn((m, d), requires_grad=True, device='mps')
        mps_indices = torch.tensor(idx, device='mps')
        a_mps, b_mps = forward_backward(mps_embedding, mps_proj, mps_indices)

        cpu_embedding = nn.Embedding(n, d, max_norm=True, _weight=initial_weight)
        cpu_proj = mps_proj.to('cpu')
        cpu_indices = torch.tensor(idx)
        a_cpu, b_cpu = forward_backward(cpu_embedding, cpu_proj, cpu_indices)

        self.assertEqual(b_cpu.grad, b_mps.grad)
        self.assertEqual(a_cpu.grad, a_mps.grad)

    helper(3, 5, 7, [0, 1, 2])
    helper(3, 5, 7, 2)  # test scalar index
Kulin Sethe011a8e2022-05-13 18:28:53 +00006582
6583 # Test pytorch gather
def test_gather(self):
    # torch.gather forward and backward on MPS must match CPU.
    def helper(shape, dim, idx_shape, idx_dtype=torch.int64):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        random_indices = np.random.randint(0, shape[dim], idx_shape)

        cpu_idx = torch.tensor(random_indices, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        gathered_mps = torch.gather(x, dim=dim, index=idx)
        gathered_cpu = torch.gather(cpu_x, dim=dim, index=cpu_idx)

        # The gather output has the shape of the index tensor.
        upstream_cpu = torch.randn(idx_shape, device='cpu', dtype=torch.float)
        upstream_mps = upstream_cpu.to('mps')
        gathered_mps.backward(gradient=upstream_mps)
        gathered_cpu.backward(gradient=upstream_cpu)

        self.assertEqual(gathered_mps, gathered_cpu)
        self.assertEqual(cpu_x.grad, x.grad)

    # (shape, dim, idx_shape)
    cases = [
        ((6, 3, 3), 0, (3, 3, 3)),
        ((2, 3, 3, 3), 0, (10, 3, 3, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5)),
        ((2, 8, 4, 5), 0, (10, 6, 3, 2)),
        ((8, 8, 4, 5), 0, (6, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (6, 7, 2, 3)),
        ((2, 8, 4, 5), 1, (2, 5, 3, 4)),
        ((2, 8, 4, 5), 2, (1, 8, 10, 3)),
        ((2, 8, 4, 5), 3, (2, 5, 3, 12)),
    ]
    for shape, dim, idx_shape in cases:
        helper(shape, dim, idx_shape)
6615
# Test pytorch gather for scalar (0-dim) input
def test_gather_scalar(self):
    # torch.gather on a 0-dim input: forward and backward on MPS must
    # match CPU.
    cpu_x = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
    x = cpu_x.detach().clone().to('mps').requires_grad_()

    cpu_idx = torch.tensor([0], device='cpu', dtype=torch.int64)
    idx = cpu_idx.detach().clone().to('mps')

    gathered_mps = torch.gather(x, dim=0, index=idx)
    gathered_cpu = torch.gather(cpu_x, dim=0, index=cpu_idx)

    upstream_cpu = torch.randn([1], device='cpu', dtype=torch.float)
    upstream_mps = upstream_cpu.to('mps')
    gathered_mps.backward(gradient=upstream_mps)
    gathered_cpu.backward(gradient=upstream_cpu)

    self.assertEqual(gathered_mps, gathered_cpu)
    self.assertEqual(cpu_x.grad, x.grad)
6637
Kulin Sethe011a8e2022-05-13 18:28:53 +00006638 # Test pytorch scatter_add and scatter
def test_scatter_add(self):
    # torch.scatter_add (do_add=True) and torch.scatter (do_add=False) on
    # MPS must match CPU. Gradients are compared only when the index and
    # source shapes are equal.
    def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
        src = cpu_src.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        if do_add:
            idx_np = np.random.randint(0, shape[dim], idx_shape)
        else:
            # Fixed 5x3 index table used for the plain scatter cases.
            idx_np = np.array([[0, 1, 2],
                               [1, 2, 3],
                               [2, 3, 4],
                               [3, 4, 5],
                               [4, 5, 6]])

        cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        scatter_fn = torch.scatter_add if do_add else torch.scatter
        scatter_result = scatter_fn(x, dim=dim, index=idx, src=src)
        scatter_result_cpu = scatter_fn(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)

        check_grads = idx_shape == src_shape
        if check_grads:
            upstream_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
            upstream_mps = upstream_cpu.to('mps')
            scatter_result.backward(gradient=upstream_mps)
            scatter_result_cpu.backward(gradient=upstream_cpu)

        self.assertEqual(scatter_result, scatter_result_cpu)
        if check_grads:
            self.assertEqual(cpu_x.grad, x.grad)
            self.assertEqual(cpu_src.grad, src.grad)

    # (shape, dim, idx_shape, src_shape) for scatter_add
    add_cases = [
        ((2, 3), 0, (5, 3), (5, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)),

        ((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)),
        ((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)),
        ((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)),

        ((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)),
        ((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)),
        ((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)),
    ]
    for shape, dim, idx_shape, src_shape in add_cases:
        helper(shape, dim, idx_shape, src_shape)

    # Test scatter src
    helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
    helper((10, 3), 0, (5, 3), (5, 8), do_add=False)
6704
Abhishek Pathak81b366a2022-09-30 00:24:16 +00006705 # Test pytorch scatter_add and scatter for scalar input
def test_scatter_add_scalar(self):
    # torch.scatter_add / torch.scatter with 0-dim input and source:
    # forward and backward on MPS must match CPU.
    def helper(idx_dtype=torch.int64, do_add=True):
        cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
        src = cpu_src.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        cpu_idx = torch.tensor([0], device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        scatter_fn = torch.scatter_add if do_add else torch.scatter
        scatter_result = scatter_fn(x, dim=0, index=idx, src=src)
        scatter_result_cpu = scatter_fn(cpu_x, dim=0, index=cpu_idx, src=cpu_src)

        upstream_cpu = torch.tensor(1.2, device='cpu', dtype=torch.float)
        upstream_mps = upstream_cpu.to('mps')
        scatter_result.backward(gradient=upstream_mps)
        scatter_result_cpu.backward(gradient=upstream_cpu)

        self.assertEqual(scatter_result, scatter_result_cpu)
        self.assertEqual(cpu_x.grad, x.grad)
        self.assertEqual(cpu_src.grad, src.grad)

    helper()
    helper(do_add=False)
6744
Kulin Sethe011a8e2022-05-13 18:28:53 +00006745 # Test pytorch scatter_reduce
def test_scatter_reduce(self):
    # Checks torch.scatter(..., reduce=...) on MPS against the CPU result
    # (forward only) for several shape / dim / index-shape combinations.
    def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
        src = cpu_src.detach().clone().to('mps').requires_grad_()

        # Random indices drawn from [0, shape[dim]), i.e. valid along `dim`.
        cpu_idx = torch.tensor(np.random.randint(0, shape[dim], idx_shape), device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        scatter_result = torch.scatter(x, dim=dim, index=idx, src=src, reduce=reduce_str)
        scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src, reduce=reduce_str)

        self.assertEqual(scatter_result, scatter_result_cpu)

    # (shape, dim, idx_shape, src_shape)
    cases = [
        ((2, 3), 0, (5, 3), (5, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)),

        ((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)),
        ((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)),
        ((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)),

        ((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)),
        ((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)),
        ((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)),
    ]

    # Only "add" and "multiply" are exercised here; "sum", "prod", "amax"
    # and "amin" are not.
    for reduce_type in ["add", "multiply"]:
        for shape, dim, idx_shape, src_shape in cases:
            helper(shape, dim, idx_shape, src_shape, reduce_str=reduce_type)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006782
6783 def test_is_nonzero(self):
6784 self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))
6785 self.assertTrue(torch.is_nonzero(torch.tensor([1.5]).to('mps')))
6786 self.assertFalse(torch.is_nonzero(torch.tensor([False]).to('mps')))
6787 self.assertTrue(torch.is_nonzero(torch.tensor([3]).to('mps')))
6788
6789 # Test triu
def test_triu(self):
    # Compares torch.triu forward output and input gradient between MPS
    # and CPU for a range of diagonal offsets.
    def helper(shape, diag=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        triu_result = torch.triu(x, diag)
        triu_result_cpu = torch.triu(cpu_x, diag)

        # Same upstream gradient on both devices.
        cpu_grad = torch.randn(triu_result_cpu.shape)
        triu_result.backward(gradient=cpu_grad.to('mps'))
        triu_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(triu_result, triu_result_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for diag in (0, 1, 2, 3, -1, -2, -3):
        helper((2, 8, 4, 5), diag=diag)
6814
Kulin Seth8ecb49b2022-12-19 22:00:07 +00006815 # Test inverse
def test_inverse(self):
    # torch.linalg.inv of a random n x n matrix must match between CPU and MPS.
    def helper(n):
        cpu_input = torch.randn(n, n, device='cpu')
        mps_input = cpu_input.to('mps')

        cpu_result = torch.linalg.inv(cpu_input)
        mps_result = torch.linalg.inv(mps_input)
        self.assertEqual(cpu_result, mps_result)

    for n in (2, 6, 3, 8):
        helper(n)
6829
Kulin Sethe011a8e2022-05-13 18:28:53 +00006830 # Test tril
def test_tril(self):
    # Compares torch.tril forward output and input gradient between MPS
    # and CPU for a range of diagonal offsets.
    def helper(shape, diag=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        tril_result = torch.tril(x, diag)
        tril_result_cpu = torch.tril(cpu_x, diag)

        # Same upstream gradient on both devices.
        cpu_grad = torch.randn(tril_result_cpu.shape)
        tril_result.backward(gradient=cpu_grad.to('mps'))
        tril_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(tril_result, tril_result_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for diag in (0, 1, 2, 3, -1, -2, -3):
        helper((2, 8, 4, 5), diag=diag)
6855
Kulin Seth8552acb2022-05-27 17:07:02 +00006856 # test eye
def test_eye(self):
    # Compares torch.eye on MPS against CPU for square, rectangular and
    # empty shapes across several dtypes.
    def helper(n, m, dtype):
        if n == m:
            # Square case: exercise the single-size overload torch.eye(n).
            cpu_result = torch.eye(n, dtype=dtype, device='cpu')
            result = torch.eye(n, dtype=dtype, device='mps')
        else:
            # Rectangular case: dtype must be forwarded here as well,
            # otherwise every non-square shape is only checked with the
            # default dtype no matter which dtype the caller asked for.
            cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
            result = torch.eye(n, m, dtype=dtype, device='mps')

        self.assertEqual(result, cpu_result)

    for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]:
        helper(2, 2, dtype)
        helper(2, 3, dtype)
        helper(0, 2, dtype)
        helper(0, 0, dtype)
        helper(3, 8, dtype)
        helper(8, 3, dtype)
6878
Kulin Sethe011a8e2022-05-13 18:28:53 +00006879 # Test diag
def test_diag(self):
    # Compares the forward result of torch.diag between MPS and CPU for
    # 1-D and 2-D inputs and positive / negative diagonal offsets.
    def helper(shape, diag=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        diag_result = torch.diag(x, diag)
        diag_result_cpu = torch.diag(cpu_x, diag)

        # Only the forward output is compared; no backward pass is run here.
        self.assertEqual(diag_result, diag_result_cpu)

    shapes = [(5, 5), (5, 6), (6, 5), (5,), (6,)]
    offsets = [0, 1, 2, 3, 4, -1, -2, -3, -4]
    for shape, diag in itertools.product(shapes, offsets):
        helper(shape, diag=diag)
6900
Kulin Setha3bdafe2022-06-01 13:47:14 +00006901 # Test linspace
def test_linspace(self):
    # torch.linspace on MPS is checked against numpy.linspace converted to
    # a tensor of the requested dtype.
    def helper(start, end, steps, dtype=torch.float32):
        expected = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
        actual = torch.linspace(start, end, steps, dtype=dtype, device='mps')
        self.assertEqual(expected, actual)

    # (start, end, steps): increasing, constant, decreasing and zero-step cases.
    ranges = [(2, 5, 10), (2, 2, 10), (5, 2, 10), (2, 2, 0)]
    for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
        for start, end, steps in ranges:
            helper(start, end, steps, dtype)
6913
# Test arange
def test_arange(self):
    # torch.arange on MPS must agree with numpy.arange for integer,
    # negative-step and fractional arguments.
    ascending = torch.arange(10, device='mps')
    self.assertEqual(np.arange(10), ascending)

    descending = torch.arange(7, 1, -1, device='mps')
    self.assertEqual(np.arange(7, 1, -1), descending)

    fractional_step = torch.arange(1, 2, .3, device='mps')
    self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), fractional_step)

    fractional_end = torch.arange(6.3, device='mps')
    self.assertEqual(np.arange(6.3, dtype=np.float32), fractional_end)
6920
def test_arange_empty(self):
    # An empty arange written into an `out=` tensor must give the same
    # result on MPS and CPU.
    results = {}
    for device in ("mps", "cpu"):
        out_tensor = torch.tensor([], device=device)
        results[device] = torch.arange(0, 0, 1, out=out_tensor)

    self.assertEqual(results["mps"], results["cpu"])
6928
# Test range
def test_range(self):
    # torch.range (end-inclusive) on MPS is checked against reference
    # numpy arrays.
    self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
    self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
    self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
    # Previously this line called torch.arange, so the fractional-end case
    # of torch.range was never exercised. range(0, 6.3) yields 0..6, the
    # same seven values as numpy.arange(6.3).
    self.assertEqual(np.arange(6.3, dtype=np.float32), torch.range(0, 6.3, device='mps'))
6935
Kulin Sethe011a8e2022-05-13 18:28:53 +00006936 # Test softmax
def test_softmax(self):
    # Compares softmax forward (and, for contiguous inputs, backward)
    # between MPS and CPU over every listed dim of 4-D and 5-D inputs,
    # plus a 0-dim input.
    def helper(shape, dim, channels_last=False):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        softmax_result = F.softmax(x, dim=dim)
        softmax_result_cpu = F.softmax(cpu_x, dim=dim)

        # Backward is not exercised for channels-last inputs.
        check_backward = not channels_last
        if check_backward:
            cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')

            softmax_result.backward(gradient=grad)
            softmax_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(softmax_result, softmax_result_cpu)
        if check_backward:
            self.assertEqual(x.grad, cpu_x.grad)

    def helper2(dim):
        # 0-dim input variant: forward and backward on a single value.
        cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        softmax_result = F.softmax(x, dim=dim)
        softmax_result_cpu = F.softmax(cpu_x, dim=dim)

        cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float)
        grad = cpu_grad.to('mps')

        softmax_result.backward(gradient=grad)
        softmax_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(softmax_result, softmax_result_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    helper2(0)

    # Only the contiguous layout is enabled in the list below.
    layouts = [False]
    shapes = [(2, 4, 8, 5), (3, 4, 6, 7, 2)]
    for channels_last, shape in itertools.product(layouts, shapes):
        # channels_last is only meaningful for 4-D inputs.
        if channels_last and len(shape) != 4:
            continue
        for dim in [0, 1, 2, 3, -1, -2, -3]:
            helper(shape, dim, channels_last)
6987
def test_nan_to_num(self):
    # nan / +inf / -inf replacement must match between CPU and MPS.
    replacements = dict(nan=2.0, posinf=1.0, neginf=-1.0)

    inputCPU = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
    inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

    outputCPU = torch.nan_to_num(inputCPU, **replacements)
    outputMPS = torch.nan_to_num(inputMPS, **replacements)
    self.assertEqual(outputMPS, outputCPU)
6994
Kulin Sethe011a8e2022-05-13 18:28:53 +00006995 # Test where
def test_where(self):
    # Compares torch.where forward output and the gradients of both value
    # operands between MPS and CPU, including broadcasting combinations.
    def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):
        cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
        cond = cpu_cond.detach().clone().to('mps')

        cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
        y = cpu_y.detach().clone().to('mps').requires_grad_()

        cpu_out = torch.where(cpu_cond, cpu_x, cpu_y)
        out = torch.where(cond, x, y)

        # Same upstream gradient on both devices.
        cpu_grad = torch.randn(cpu_out.shape)
        grad = cpu_grad.to('mps')

        cpu_out.backward(gradient=cpu_grad)
        out.backward(gradient=grad)

        self.assertEqual(out, cpu_out)
        self.assertEqual(x.grad, cpu_x.grad)
        self.assertEqual(y.grad, cpu_y.grad)

    # All three operands share one shape (including empty and 0-dim).
    for shape in [(0, 3), [], (2, 3), (9,)]:
        helper(shape, shape, shape)

    # (cond_shape, x_shape, y_shape) combinations that need broadcasting.
    broadcast_cases = [
        ((2, 3, 1), (2, 3, 4), (2, 1, 4)),
        ((2, 1, 1), (2, 3, 4), (1, 3, 4)),
        ((1, 1, 1), (1, 1, 4), (2, 3, 1)),
        ([], (1, 1, 4), (2, 3, 1)),
        ([], (2, 3, 4), []),
        ((5, 2, 3), (2, 3), (2, 3)),
        ((2, 3), (5, 2, 3), (2, 3)),
        ((2, 3), (2, 3), (5, 2, 3)),
        ((2, 3), (5, 2, 3), (6, 5, 2, 3)),
    ]
    for cond_shape, x_shape, y_shape in broadcast_cases:
        helper(cond_shape, x_shape, y_shape)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007033
7034 # Test normal
def test_normal(self):
    # Exercises the torch.normal overloads on MPS (scalar/tensor mean and
    # std, with and without `out=`); only output sizes are asserted.
    def helper(shape, mean=0.0, std=1.0):
        # scalar mean, scalar std, explicit size
        mps_out = torch.normal(mean, std, shape, device='mps')

        # Constant tensors filled with `mean` / `std`, mirrored onto MPS.
        mean_array = np.ones(shape) * mean
        cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False)
        mean_tensor = cpu_mean_tensor.detach().clone().to('mps')

        std_array = np.ones(shape) * std
        cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
        std_tensor = cpu_std_tensor.detach().clone().to('mps')

        # `out=` variants: each call writes into a fresh zero tensor.
        for mean_arg, std_arg in ((mean_tensor, std), (mean, std_tensor), (mean_tensor, std_tensor)):
            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_arg, std_arg, out=mps_out)

        # Variants without `out=`: the result takes the size of the tensor operand(s).
        mps_out = torch.normal(mean_tensor, std)
        self.assertEqual(mps_out.size(), mean_tensor.size())

        mps_out = torch.normal(mean, std_tensor)
        self.assertEqual(mps_out.size(), std_tensor.size())

        inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
        mps_out = torch.normal(mean_tensor, std_tensor)
        self.assertEqual(mps_out.size(), inferred_shape)

    helper((2, 3, 4, 5, 6))
    helper((100, 100), 2.5, 1.2)
7072
def test_bernoulli(self):
    # torch.bernoulli on MPS: p = 0.5 must not give a constant tensor,
    # p = 0 must give all zeros and p = 1 must give all ones.
    shape = (10, 10)
    all_ones = torch.ones(shape, device='mps')
    all_zeros = torch.zeros(shape, device='mps')

    # p = 0.5: the mean and variance cannot be checked reliably, so only
    # verify that the samples are not constant.
    half_prob = all_ones * 0.5
    samples = torch.bernoulli(half_prob)
    self.assertNotEqual(samples.to('cpu').mean(), 0.)
    self.assertNotEqual(samples.to('cpu').std() ** 2, 0.)

    # p = 0: every draw is 0.
    samples = torch.bernoulli(all_zeros)
    self.assertEqual(samples, all_zeros)

    # p = 1: every draw is 1.
    samples = torch.bernoulli(all_ones)
    self.assertEqual(samples, all_ones)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007093
def test_mps_generator(self):
    # Seeding and state save/restore of an explicitly created MPS generator.
    def draw(generator):
        return torch.randn(5, device='mps', generator=generator)

    g_mps = torch.Generator(device='mps')

    # Two draws that each follow the same manual seed must be identical.
    g_mps.manual_seed(999)
    mps_x = draw(g_mps)
    g_mps.manual_seed(999)
    mps_y = draw(g_mps)
    self.assertEqual(mps_x, mps_y)

    # Remember the generator state right after the second draw.
    g_state = g_mps.get_state()

    # Drawing again without re-seeding must produce different numbers.
    mps_x = draw(g_mps)
    self.assertNotEqual(mps_x, mps_y)

    # After restoring the saved state the next draw is compared against
    # the earlier result and must match it.
    g_mps.set_state(g_state)
    mps_x = draw(g_mps)
    self.assertEqual(mps_x, mps_y)
7115
def test_default_mps_generator(self):
    # Seeding and state save/restore of the default MPS generator.
    def draw():
        return torch.randn(5, device='mps')

    # Seed through the global torch.manual_seed() ...
    torch.manual_seed(230)
    mps_x = draw()
    # ... and through torch.mps.manual_seed(); with the same seed value
    # both draws must be identical.
    torch.mps.manual_seed(230)
    mps_y = draw()
    self.assertEqual(mps_x, mps_y)

    # Remember the default generator's state right after the second draw.
    g_state = torch.mps.get_rng_state()

    # Drawing again without re-seeding must produce different numbers.
    mps_x = draw()
    self.assertNotEqual(mps_x, mps_y)

    # After restoring the saved state the next draw is compared against
    # the earlier result and must match it.
    torch.mps.set_rng_state(g_state)
    mps_x = draw()
    self.assertEqual(mps_x, mps_y)
7141
def test_device_synchronize(self):
    # Smoke test: run a few ops, calling torch.mps.synchronize() after
    # each one. Nothing is asserted; the calls only have to succeed.
    conv = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    net1 = conv.to(device='mps', dtype=torch.float)

    x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
    torch.mps.synchronize()

    x = net1(x)
    torch.mps.synchronize()

    x.backward(torch.randn_like(x))
    torch.mps.synchronize()
7154
@unittest.expectedFailure
def test_mps_allocator_module(self):
    # Checks the torch.mps memory counters around a single large
    # allocation. The test is currently marked as an expected failure.

    # Start from a clean slate: collect garbage and drop cached blocks.
    gc.collect()
    torch.mps.empty_cache()

    # With nothing live and the cache emptied, the allocator is expected
    # to report zero bytes in use.
    allocated_before = torch.mps.current_allocated_memory()
    self.assertTrue(allocated_before == 0)

    # Total memory handed out by the Metal driver before the allocation.
    driver_before = torch.mps.driver_allocated_memory()

    # Allocate a large tensor (1024 * 1024 * 8 elements); `x` keeps it
    # alive while the counters are read again.
    x = torch.ones(1024 * 1024 * 8, device="mps")

    allocated_after = torch.mps.current_allocated_memory()
    driver_after = torch.mps.driver_allocated_memory()

    # Both counters must have grown.
    self.assertTrue(allocated_after > allocated_before)
    self.assertTrue(driver_after > driver_before)
7176
PyTorch MergeBotf152a792023-02-10 11:32:25 +00007177 # Test random_.to and random_.from
def test_random(self):
    # torch.randint on MPS for several ranges and dtypes.
    def helper(shape, low, high, dtype=torch.int32):
        mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')

        # The mean and std cannot be checked reliably, so only make sure
        # the output is not a constant tensor.
        cpu_mean = mps_out.to('cpu').float().mean()
        self.assertNotEqual(cpu_mean, 0.)
        cpu_std = mps_out.to('cpu').float().std()
        self.assertNotEqual(cpu_std, 0.)

    # (low, high, dtype)
    cases = [
        (0, 10, torch.int32),
        (23, 89, torch.int32),
        (23, 89, torch.float32),
        (23, 89, torch.int64),
        (0, 2, torch.bool),
    ]
    for low, high, dtype in cases:
        helper([100, 100], low, high, dtype=dtype)
7193
Kulin Seth83239352022-06-10 13:16:21 +00007194 # Test exponential
def test_exponential(self):
    # Fills an MPS tensor in place with exponential_(lamda) samples and
    # prints the sample mean / variance next to 1/lamda and 1/lamda**2.
    # Nothing is asserted.
    def helper(shape, lamda, dtype=torch.float32):
        mps_out = torch.zeros(shape, device='mps', dtype=dtype)
        mps_out.exponential_(lamda)

        sample_mean = mps_out.to('cpu').float().mean()
        print(sample_mean, 1 / lamda)
        sample_var = mps_out.to('cpu').float().std() ** 2
        print(sample_var, 1 / (lamda**2))

    for dtype in [torch.float32, torch.float16]:
        for lamda in (2, 1, 3, 0.5):
            helper([100, 100], lamda, dtype)
7209
def test_exponential_1(self):
    # Sample shapes produced by torch.distributions.Exponential for
    # tensor-valued and Python-float rates.
    rate = torch.randn(5, 5).abs().requires_grad_()
    rate_1d = torch.randn(1).abs().requires_grad_()

    # 5x5 rate tensor: the batch shape is appended to the sample shape.
    matrix_dist = Exponential(rate)
    self.assertEqual(matrix_dist.sample().size(), (5, 5))
    matrix_dist = Exponential(rate)
    self.assertEqual(matrix_dist.sample((7,)).size(), (7, 5, 5))

    # One-element rate tensor.
    vector_dist = Exponential(rate_1d)
    self.assertEqual(vector_dist.sample((1,)).size(), (1, 1))
    vector_dist = Exponential(rate_1d)
    self.assertEqual(vector_dist.sample().size(), (1,))

    # Python float rates.
    for scalar_rate in (0.2, 50.0):
        self.assertEqual(Exponential(scalar_rate).sample((1,)).size(), (1,))
7219
Kulin Sethe011a8e2022-05-13 18:28:53 +00007220 # Test add
def test_add_sub(self):
    # Compares add / sub (out-of-place and in-place, with `alpha`) between
    # CPU and MPS for fp16 and fp32, including tensor-with-0-dim-tensor
    # combinations.
    def helper(shape, alpha, op_name, inplace):
        # op_name -> (out-of-place function, in-place method)
        variants = {
            "add": (torch.add, torch.Tensor.add_),
            "sub": (torch.sub, torch.Tensor.sub_),
        }
        op = variants[op_name][1 if inplace else 0]

        def make_scalar_pair(dtype):
            # Fresh 0-dim tensor on CPU plus its MPS copy. A new pair is
            # built for each check because an in-place op mutates its
            # first argument.
            cpu_scalar = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
            return cpu_scalar, cpu_scalar.detach().clone().to('mps')

        for dtype in [torch.float16, torch.float32]:
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            mps_y = cpu_y.detach().clone().to('mps')

            cpu_out = op(cpu_x, cpu_y, alpha=alpha)
            mps_out = op(mps_x, mps_y, alpha=alpha)
            # fp16 isn't accurate when alpha is passed
            # TODO: remove or fix 'tol' when we fix problems with fp16
            tol = 2e-3 if dtype is torch.float16 else None
            self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol)

            # Primary tensor is 0-dim. Skipped for in-place ops on
            # non-0-dim operands, since an in-place output cannot be
            # broadcast.
            if cpu_y.shape == () or not inplace:
                cpu_s, mps_s = make_scalar_pair(dtype)
                self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y))

            # Secondary tensor is 0-dim.
            cpu_s, mps_s = make_scalar_pair(dtype)
            self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol)

    # (shape, alpha)
    cases = [
        ((), 0.0),
        ((2, 8, 4, 5), 0.0),
        ((2, 8, 4, 5), 0.1),
        ((2, 8, 4, 5), 1.0),
        ((2, 8, 3, 5), 0.1),
        ((2, 8, 3, 5), 0.2),
    ]
    for op_name, inplace in product(["add", "sub"], [True, False]):
        for shape, alpha in cases:
            helper(shape, alpha, op_name, inplace)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007261
7262 # Test add
def test_add_scalars(self):
    # torch.add with `alpha` on 0-dim tensors (fp16 / fp32), followed by
    # an int32-tensor-plus-Python-scalar regression check.
    def helper(alpha):
        for dtype in [torch.float16, torch.float32]:
            cpu_x = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha)
            out = torch.add(x, y, alpha=alpha)
            # fp16 isn't accurate when alpha is passed
            tol = None
            if dtype is torch.float16:
                tol = 1e-3
            self.assertEqual(out, cpu_out, rtol=tol, atol=tol)

    for alpha in (1.0, 0.0, 0.1, 0.2):
        helper(alpha)

    # Test int32 tensor + int64 scalar add
    # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
    x = torch.ones(4, dtype=torch.int32, device='mps')
    expected_int = torch.full((4,), 2, dtype=torch.int32, device='mps')
    self.assertEqual(x + 1, expected_int)
    expected_float = torch.full((4,), 2.5, device='mps')
    self.assertTrue(torch.equal(x + 1.5, expected_float))
Nikita Shulga06f874e2022-06-25 02:21:34 +00007288
def test_types_binary_op(self):
    # Multiplying a float32 tensor by a bool or int64 tensor must give the
    # same result on CPU and MPS.
    def masked_arange(mask_values, device):
        values = torch.arange(5, dtype=torch.float32, device=device)
        return values * torch.tensor(mask_values, device=device)

    # Float * Bool
    bool_mask = [True, False, True, False, True]
    cpu_x = masked_arange(bool_mask, "cpu")
    mps_x = masked_arange(bool_mask, "mps")
    self.assertEqual(cpu_x, mps_x)

    # Float * Int64
    int_mask = [1, 0, 1, 0, 1]
    cpu_y = masked_arange(int_mask, "cpu")
    mps_y = masked_arange(int_mask, "mps")
    self.assertEqual(cpu_y, mps_y)
7298
def test_unary_ops(self):
    # Applies an element-wise unary op to float and integer inputs on CPU
    # and MPS and compares the results.
    def helper(shape, op):
        # Floating-point inputs: default tolerances.
        for float_dtype in [torch.float32]:
            cpu_x = torch.randn(shape, device='cpu', dtype=float_dtype, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')
            self.assertEqual(op(cpu_x), op(mps_x))

        # Integer inputs in [0, 1000): compared with explicit tolerances.
        for int_dtype in [torch.int32, torch.int16]:
            cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=int_dtype, requires_grad=False)
            mps_x = cpu_x.to('mps')
            self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4)

    helper((2, 8, 4, 5), torch.exp)
    for op in (torch.exp2, torch.expm1, torch.log, torch.cos):
        helper((2, 8, 3, 5), op)
7316
def test_atan2(self):
    # torch.atan2 of two random tensors must match between CPU and MPS.
    def helper(shape):
        input_cpu = torch.randn(shape)
        input_mps = input_cpu.detach().clone().to("mps")

        other_cpu = torch.randn(shape)
        other_mps = other_cpu.detach().clone().to("mps")

        expected = torch.atan2(input_cpu, other_cpu)
        actual = torch.atan2(input_mps, other_mps)

        self.assertEqual(expected, actual.to("cpu"))

    for shape in (4, 10000, (10000, 40)):
        helper(shape)
7333
def test_multinomial(self):
    # Draws from torch.multinomial on MPS and prints the samples (or their
    # mean / variance next to the theoretical values). Nothing is asserted.
    def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
        cpu_prob_tensor = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
        prob_tensor = cpu_prob_tensor.detach().clone().to('mps')

        mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement)
        if replacement:
            # Compare "real" with theoretical values
            print(mps_out.to('cpu').float().mean(), compare_mean)
            print(mps_out.to('cpu').float().std() ** 2, compare_var)
        else:
            print(mps_out.to('cpu'))

    # TODO: Add tests for data types
    uniform_mean = (0 + 1 + 2 + 3 + 4) / 5
    uniform_var = (6 - 2 * 2)

    helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)
    helper(np.array([[.2, .2, .2, .2, .2]]), uniform_mean, uniform_var, 10000)
    helper(np.array([[1, 1, 1, 1, 1]]), uniform_mean, uniform_var, 10000)
    helper(np.array([1, 1, 1, 1, 1]), uniform_mean, uniform_var, 10000)
    helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007354
def test_cumsum_dim_check(self):
    # Negative dims must alias the matching positive dims, and dims outside
    # [-2, 1] must raise IndexError for a 2-D MPS tensor.
    x = torch.rand((3, 3), device="mps")

    for positive_dim, negative_dim in ((1, -1), (0, -2)):
        self.assertEqual(x.cumsum(positive_dim), x.cumsum(negative_dim))

    with self.assertRaises(IndexError):
        x.cumsum(2)
    with self.assertRaises(IndexError):
        x.cumsum(-3)
7361
Soof Golane4fe11e2023-02-09 10:42:48 +00007362
class TestTopK(TestCase):
    # Compares torch.topk values and indices between MPS and CPU.

    def _test_topk(self, shape, largest):
        # `shape` is either an int (1-D tensor of that length) or a tuple.
        # For every dimension and every k in 1..dim_size (inclusive) the
        # MPS result must equal the CPU result.
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        # Treat an int shape as a one-dimensional tuple so both kinds of
        # shape share one loop. The int branch used to iterate
        # range(1, shape), which skipped k == shape and tested nothing at
        # all for shape == 1.
        dim_sizes = shape if isinstance(shape, tuple) else (shape,)

        for curr_dim, dim_size in enumerate(dim_sizes):
            for k in range(1, dim_size + 1):
                topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
                self.assertEqual(topk_values, topk_values_cpu)
                self.assertEqual(topk_indices, topk_indices_cpu)

    def test_topk(self):
        # Runs _test_topk for both `largest` settings over empty and
        # non-empty shapes, each combination as its own subTest.
        largest_vals = [True, False]
        shapes = [
            # Zero Element Tensors
            0,
            (1, 0),
            (0, 1),
            (1, 0, 1),
            # Multiple Element Tensors
            1,
            2,
            (5, 1),
            (1, 5),
            (5, 9, 7, 4),
        ]

        for shape in shapes:
            for largest_val in largest_vals:
                with self.subTest(shape=shape, largest_val=largest_val):
                    self._test_topk(shape, largest_val)
7401
class TestNNMPS(NNTestCase):
    """``torch.nn`` module and functional tests run alongside the MPS backend tests."""

    def _create_basic_net(self):
        """Build small fixture modules with one parameter and one buffer each.

        Returns:
            Tuple ``(layer, net, sequential)`` where ``sequential`` contains
            the same ``net`` instance twice.
        """
        class Layer(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer_dummy_param = Parameter(torch.empty(3, 5))
                self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = Layer()
                self.dummy_param = Parameter(torch.empty(3, 5))
                self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))

        l = Layer()
        n = Net()
        s = nn.Sequential(n, n)

        return l, n, s

    def test_requires_grad_(self):
        m = self._create_basic_net()[-1]
        # Sanity-check the fixture first. unittest assertions are used so the
        # checks are not stripped under `python -O`.
        self.assertTrue(len(list(m.buffers())) > 0, 'invalid test')
        self.assertTrue(all(not b.requires_grad for b in m.buffers()), 'invalid test')
        self.assertTrue(len(list(m.parameters())) > 0, 'invalid test')
        self.assertTrue(all(p.requires_grad for p in m.parameters()), 'invalid test')
        for requires_grad in (False, True):
            # requires_grad_ returns the module itself and only touches parameters.
            self.assertIs(m.requires_grad_(requires_grad), m)
            for p in m.parameters():
                self.assertEqual(p.requires_grad, requires_grad)
            for b in m.buffers():
                self.assertFalse(b.requires_grad)

    def test_module_backcompat(self):
        from torch.serialization import SourceChangeWarning
        path = download_file('https://download.pytorch.org/test_data/linear.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            # NOTE(review): torch.load unpickles the downloaded file; this
            # relies on the URL above being a trusted test asset.
            m = torch.load(path)
        input = torch.randn(2, 3, dtype=torch.float)
        self.assertEqual(m(input).size(), (2, 5))

    def test_conv_backcompat(self):
        from torch.serialization import SourceChangeWarning
        # This file was generated by running on PyTorch 1.0.1 on Python 2:
        #
        #     import torch
        #     from torch import nn
        #     m = nn.Conv2d(1, 1, 1)
        #     torch.save(m, 'legacy_conv2d.pt')
        #
        # NB: This Pickle also contains some Unicode data!
        path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            m = torch.load(path, encoding='utf-8')
        input = torch.randn((1, 1, 1, 1), dtype=torch.float)
        self.assertEqual(m(input).size(), (1, 1, 1, 1))

    def test_conv_expand(self):
        device = 'mps'
        input_ = torch.rand(2, 3, 16, 16, device=device)
        kernel = torch.rand(1, 1, 3, 11, device=device)
        tmp_kernel = kernel.expand(-1, 3, -1, -1)
        # The test should not crash; the result itself is not inspected.
        F.conv2d(input_, tmp_kernel, groups=1, padding=0, stride=1)

    def test_permute(self):
        M_cpu = torch.randn(5, 5)
        M_mps = M_cpu.to('mps')

        output_cpu = M_cpu.permute(1, 0)
        output_mps = M_mps.permute(1, 0)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    # Printing of non_contiguous should not crash
    def test_print_non_contiguous(self):
        print(torch.ones(100, 100, device='mps').nonzero())
        print(torch.ones(100, 100, device='mps').nonzero().contiguous())

    def test_zero_grad(self):
        i = torch.randn(2, 5, requires_grad=True)
        module = nn.Linear(5, 5)
        for p in module.parameters():
            p.requires_grad = False
        module.zero_grad()

        module.weight.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)  # uninitialized grad

        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        module.zero_grad()
        self.assertIsNone(module.weight.grad)

        module.bias.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)
        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertIsNotNone(module.bias.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        self.assertGreater(module.bias.grad.data.abs().sum(), 0)

        # Force set to zeros.
        module.zero_grad(set_to_none=False)
        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())

        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)

    def test_no_grad(self):
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
            input = torch.randn(1, 2, 10, 10).to(dtype)
            x = input
            y = input.clone()

            output = module(x)
            self.assertTrue(output.requires_grad)
            output.backward(torch.ones(1, 5, 10, 10))

            # Under no_grad the output must be detached, so backward() raises.
            with torch.no_grad():
                output2 = module(y)
                self.assertFalse(output2.requires_grad)
                self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))

    def test_invalid_conv1d(self):
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError,
                                        r'Calculated padded input size per channel: \(4\). ' +
                                        r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

    def test_conv2d_discontiguous_weight(self):
        # Test for https://github.com/pytorch/pytorch/issues/55781
        x = torch.ones(64, 16, 16, 16)
        weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
        self.assertFalse(weight.is_contiguous())
        y = torch.nn.functional.conv2d(x, weight, None)
        if torch.backends.mkldnn.is_available():
            # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
            with torch.backends.mkldnn.flags(enabled=False):
                y_ = torch.nn.functional.conv2d(x, weight, None)
                self.assertEqual(y, y_)
        self.assertEqual(y.sum(), 4186112.)

    def test_invalid_conv2d(self):
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
            input = torch.empty(1, 1, 4, 4).to(dtype)
            self.assertRaises(RuntimeError, lambda: module(input))

            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
            input = torch.randn(1, 3, 1, 1)
            with self.assertRaisesRegex(RuntimeError,
                                        r'Calculated padded input size per channel: \(1 x 1\). ' +
                                        r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

            # Zero stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

            # Input and weights on different devices
            self.assertRaisesRegex(RuntimeError,
                                   'must be on the same device',
                                   lambda: torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps')))
            self.assertRaisesRegex(RuntimeError,
                                   'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same',
                                   lambda: torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3)))

    def test_conv2d_valid_padding(self, device='mps'):
        # Test F.conv2d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
        y = torch.rand(1, 1, 1, 4, device=device).to(torch.float)

        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding='valid')
        self.assertEqual(expect.to('cpu'), actual.to('cpu'))

    def test_gemm_permute_transpose(self):
        batch_size = 32
        n = 20
        hidden = 768
        num_attention_heads = 12
        attention_head_size = hidden // num_attention_heads

        def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
            # Split the last dim into (heads, head_size), then move heads ahead of the sequence dim.
            new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
            x = x.view(new_x_shape)
            return x.permute(0, 2, 1, 3)

        def attention2(key):
            # Permute followed by a transpose of the last two dims.
            key = transpose_for_scores(key)
            return key.transpose(-1, -2)

        A = torch.randn(batch_size, n, hidden)
        A_mps = A.detach().clone().to("mps")

        r1 = attention2(A)
        r2 = attention2(A_mps)

        r2_cpu = r2.to("cpu")
        self.assertEqual(r1, r2_cpu)

    def test_group_norm_backward(self, device='mps'):
        # See https://github.com/pytorch/pytorch/issues/88331 for more detail
        shape = [1, 4, 16, 16]
        x = torch.full(shape, 7.0, device=device)

        target = torch.ones((1, 3, 128, 128), device=device)

        conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
        conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
        norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device)

        with torch.enable_grad():
            x = x.detach().requires_grad_()
            out = 5.5 * x
            out = conv_in(out)
            out = out + norm(out)
            out = out + norm(out)
            out = out + norm(out)
            out = F.interpolate(out, scale_factor=8.0, mode="nearest")
            out = norm(out)
            out = conv_out(out)

            loss = (out - target).norm(dim=-1).sum()
            grad = -torch.autograd.grad(loss, x)[0]
            self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd')
7661
7662
Kulin Sethe011a8e2022-05-13 18:28:53 +00007663 # def test_conv2d_same_padding(self, device='mps'):
7664 # x = torch.rand(1, 1, 10, 11, device=device)
7665 # y = torch.rand(1, 1, 4, 5, device=device)
7666 # expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
7667 # actual = F.conv2d(x, y, padding='same')
7668 # self.assertEqual(expect.to('cpu'), actual.to('cpu'))
7669
7670 # # With dilation
7671 # y = torch.rand(1, 1, 3, 4, device=device)
7672 # expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
7673 # actual = F.conv2d(x, y, padding='same', dilation=2)
7674 # self.assertEqual(expect, actual)
7675
7676 # # Dilation with asymmetric padding
7677 # y = torch.rand(1, 1, 4, 4, device=device)
7678 # expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
7679 # actual = F.conv2d(x, y, padding='same', dilation=3)
7680 # self.assertEqual(expect, actual)
7681
7682
class TestConstantPadNd(TestCaseMPS):
    def test_preserves_memory_format(self):
        # constant_pad_nd must return a tensor in the same memory format as its input.
        pad, fill = [1, 2], 0.5

        contiguous_in = torch.rand((1, 2, 5, 3))
        contiguous_out = torch.constant_pad_nd(contiguous_in, pad, fill)
        self.assertTrue(contiguous_out.is_contiguous(memory_format=torch.contiguous_format))

        channels_last_in = contiguous_in.contiguous(memory_format=torch.channels_last)
        channels_last_out = torch.constant_pad_nd(channels_last_in, pad, fill)
        self.assertTrue(channels_last_out.is_contiguous(memory_format=torch.channels_last))
7692
7693
class TestLinalgMPS(TestCaseMPS):
    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
        """Check ``f`` (an addmm/addmv-style op) against its ``out=`` variant and a NumPy reference.

        ``alpha`` and ``beta`` default to 1.2 and 0.8. When ``transpose_out``
        is set, the ``out=`` tensor is given a transposed memory layout.
        """
        dtype = t.dtype
        if alpha is None:
            alpha = 1.2
        if beta is None:
            beta = 0.8

        functional = f(t, m, v, alpha=alpha, beta=beta)

        # Pre-fill the out tensor with NaN so stale values cannot pass the comparison.
        out_buf = torch.full_like(functional, math.nan)
        if transpose_out:
            out_buf = out_buf.t().clone(memory_format=torch.contiguous_format).t()
        f(t, m, v, alpha=alpha, beta=beta, out=out_buf)

        # NumPy reference: alpha * (m @ v) [+ beta * t]
        m_np = m.to(dtype).cpu().numpy()
        v_np = v.to(dtype).cpu().numpy()
        reference = alpha * (m_np @ v_np)
        if beta != 0:
            reference += (torch.mul(t, beta)).to(dtype).cpu().numpy()
        reference = torch.from_numpy(reference).to(dtype)

        self.assertEqual(functional, out_buf)
        self.assertEqual(functional, reference)

    def test_addmm(self, device="mps", dtype=torch.float32):
        def rand_mat(rows, cols):
            return torch.randn(rows, cols, device=device).to(dtype)

        def maybe_transpose(cond, mat):
            # Same values, transposed memory layout, when cond is set.
            if cond:
                return mat.t().clone(memory_format=torch.contiguous_format).t()
            return mat

        M = rand_mat(10, 25)
        m1 = rand_mat(10, 50)
        m2 = rand_mat(50, 25)
        self._test_addmm_addmv(torch.addmm, M, m1, m2)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = rand_mat(10, 50)
        m2 = rand_mat(50, 25)
        self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            M = maybe_transpose(t1, rand_mat(10, 25))
            m1 = maybe_transpose(t2, rand_mat(10, 50))
            m2 = maybe_transpose(t3, rand_mat(50, 25))
            self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)

    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
        """Check ``f`` (an addr-style op) against a NumPy outer-product reference.

        ``alpha`` and ``beta`` default to 1.2 and 0.8.
        """
        dtype = t.dtype
        if alpha is None:
            alpha = 1.2
        if beta is None:
            beta = 0.8

        actual = f(t, m, v, alpha=alpha, beta=beta)

        # NumPy reference: alpha * outer(m, v) [+ beta * t]
        m_np = m.to(dtype).cpu().numpy()
        v_np = v.to(dtype).cpu().numpy()
        reference = alpha * np.outer(m_np, v_np)
        if beta != 0:
            reference += (torch.mul(t, beta)).to(dtype).cpu().numpy()
        reference = torch.from_numpy(reference).to(dtype)

        self.assertEqual(actual, reference)

    def test_addr(self, device="mps", dtype=torch.float32):
        def rand_vec(length):
            return torch.randn(length, device=device).to(dtype)

        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = rand_vec(10)
        m2 = rand_vec(25)
        self._test_addr(torch.addr, M, m1, m2)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = rand_vec(10)
        m2 = rand_vec(25)
        self._test_addr(torch.addr, M, m1, m2, beta=0)
7759
class TestGatherScatter(TestCaseMPS):
    def test_slicing_with_step(self):
        # Strided slice assignment, see
        # https://github.com/pytorch/pytorch/issues/78886
        filled = {}
        for dev in ("mps", "cpu"):
            buf = torch.zeros(10, dtype=torch.float32, device=dev)
            buf[::2] = 1.0
            filled[dev] = buf

        self.assertEqual(filled["cpu"], filled["mps"])

    def test_cast_gather_scatter(self):
        # Cast chain uint8 -> long -> float -> in-place divide, compared with CPU at each step.
        for _ in range(50):
            raw = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8)
            with torch.no_grad():
                on_mps = torch.tensor(raw, dtype=torch.uint8, device="mps").unsqueeze(0)
                on_cpu = torch.tensor(raw, dtype=torch.uint8, device="cpu").unsqueeze(0)

                on_mps = on_mps.long()
                on_cpu = on_cpu.long()
                self.assertEqual(on_mps.cpu(), on_cpu)

                on_mps = on_mps.float()
                on_cpu = on_cpu.float()
                self.assertEqual(on_mps.cpu(), on_cpu)

                on_mps /= 255
                on_cpu /= 255
                self.assertEqual(on_mps.cpu(), on_cpu)

    def test_slicing_replace_column(self):
        # https://github.com/pytorch/pytorch/issues/78074
        def check_first_column_overwrite(rows):
            cpu_t = torch.tensor(rows)
            mps_t = cpu_t.to('mps')

            cpu_t[:, 0] = 7
            mps_t[:, 0] = 7

            self.assertEqual(cpu_t, mps_t)

        all_rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
        # Exercise matrices with 2, 3 and 4 rows.
        for row_count in (2, 3, 4):
            check_first_column_overwrite(all_rows[:row_count])

    def test_inplace_scatter(self):
        # https://github.com/pytorch/pytorch/issues/79672
        mps_dev = torch.device("mps")
        cpu_dev = torch.device("cpu")

        a_mps = torch.ones((2, 2),).to(mps_dev)
        b_mps = torch.ones((2, 2),).to(mps_dev)

        a_cpu = torch.ones((2, 2),).to(cpu_dev)
        b_cpu = torch.ones((2, 2),).to(cpu_dev)

        # Augmented assignment into a column slice.
        a_mps[:, 0] += b_mps[:, 0]
        a_cpu[:, 0] += b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)

        # The same update spelled as an explicit add followed by slice assignment.
        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)
7820
# These tests were taken from test/test_view_ops.py.
# They are a subset of those tests, as currently only this subset is working.
# This whole `class` will be removed when we add generic device testing. No
# additional tests have been added apart from what is part of test_view_ops.py.
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00007825class TestViewOpsMPS(TestCaseMPS):
Kulin Sethb744e1c2022-07-01 15:10:56 +00007826 exact_dtype = True
7827
Ramin Azarmehr36062dd2023-02-07 15:51:26 +00007828 def test_permute_slicing(self):
7829 # test the fix for crash reported in
7830 # https://github.com/pytorch/pytorch/issues/94190
7831 cpu_x = (torch.randn([3, 2, 2]).float())
7832 mps_x = cpu_x.detach().clone().to('mps')
7833 cpu_out = cpu_x.permute((2, 0, 1)) * 2.0
7834 mps_out = mps_x.permute((2, 0, 1)) * 2.0
7835 # this print caused a crash prior to fix PR#94259
7836 print(torch.zeros_like(mps_out))
Ramin Azarmehr4f691d22023-02-09 19:07:13 +00007837 # test the fix for fill_scalar_mps() mentioned in issue #94190
7838 self.assertEqual(torch.zeros_like(cpu_out), torch.zeros_like(mps_out))
7839 self.assertEqual(cpu_x[:, 1, :].fill_(1), mps_x[:, 1, :].fill_(1))
Ramin Azarmehr36062dd2023-02-07 15:51:26 +00007840
Kulin Sethb744e1c2022-07-01 15:10:56 +00007841 def is_view_of(self, base, other):
7842 if (not other._is_view() or
7843 other is base or
7844 other._base is not base or
7845 base.device != other.device):
7846 return False
7847 # Note: only validates storage on native device types
7848 # because some accelerators, like XLA, do not expose storage
Kulin Seth76cff182022-07-04 06:41:39 +00007849 if base.device.type == 'mps':
Kulin Sethb744e1c2022-07-01 15:10:56 +00007850 if base.storage().data_ptr() != other.storage().data_ptr():
7851 return False
7852
7853 return True
7854
7855 # Returns true if v1 and v2 are views of the same base
7856 def is_view_of_same_base(self, v1, v2):
7857 if (not v1._is_view() or v1 is v2):
7858 return False
7859 return self.is_view_of(v1._base, v2)
7860
7861 # Performs transpose if contiguous=True, else returns the input tensor as is
7862 def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
7863 if contiguous:
7864 return x
7865 else:
7866 return x.transpose(dim0, dim1)
7867
7868 def test_diagonal_view(self, device="mps"):
7869 t = torch.ones((5, 5), device=device)
7870 v = torch.diagonal(t)
7871 self.assertTrue(self.is_view_of(t, v))
7872
7873 v[0] = 0
7874 self.assertEqual(t[0, 0], v[0])
7875
7876 t = torch.ones((3, 3, 3), device="mps")
7877 v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
7878 self.assertTrue(self.is_view_of(t, v))
7879
7880 v[0, 0] = 0
7881 self.assertEqual(t[0, 0, 1], v[0, 0])
7882
7883 def test_select_view(self, device="mps") -> None:
7884 t = torch.ones((5, 5), device=device)
7885 v = t.select(0, 2)
7886 self.assertTrue(self.is_view_of(t, v))
7887
7888 v[0] = 0
7889 self.assertEqual(t[2, 0], v[0])
7890
7891 def test_unbind_view(self, device="mps") -> None:
7892 t = torch.zeros((5, 5), device=device)
7893 tup = torch.unbind(t)
7894
7895 for idx, v in enumerate(tup):
7896 self.assertTrue(self.is_view_of(t, v))
7897
7898 v[0] = idx + 1
7899 self.assertEqual(t[idx, 0], v[0])
7900
7901 def test_expand_view(self, device="mps") -> None:
7902 t = torch.ones((5, 1), device=device)
7903 v = t.expand(5, 5)
7904 self.assertTrue(self.is_view_of(t, v))
7905
7906 v[2, 2] = 0
7907 self.assertEqual(t[2, 0], v[2, 2])
7908
7909 def test_expand_as_view(self, device="mps"):
7910 t = torch.ones((5, 1), device=device)
7911 e = torch.empty((5, 5), device=device)
7912 v = t.expand_as(e)
7913 self.assertTrue(self.is_view_of(t, v))
7914
7915 v[2, 2] = 0
7916 self.assertEqual(t[2, 0], v[2, 2])
7917
7918 def test_narrow_view(self, device="mps"):
7919 t = torch.ones((5, 5), device=device)
7920 v = torch.narrow(t, 1, 2, 2)
7921 self.assertTrue(self.is_view_of(t, v))
7922
7923 v[0, 0] = 0
7924 self.assertEqual(t[0, 2], v[0, 0])
7925
7926 def test_permute_view(self, device="mps") -> None:
7927 t = torch.ones((5, 5), device=device)
7928 v = t.permute(1, 0)
7929 self.assertTrue(self.is_view_of(t, v))
7930
7931 v[0, 1] = 0
7932 self.assertEqual(t[1, 0], v[0, 1])
7933
7934 def test_transpose_view(self, device="mps"):
7935 for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
7936 t = torch.ones((5, 5), device=device)
7937 v = fn(t, 0, 1)
7938 self.assertTrue(self.is_view_of(t, v))
7939
7940 v[0, 1] = 0
7941 self.assertEqual(t[1, 0], v[0, 1])
7942
7943 def test_transpose_inplace_view(self, device="mps"):
7944 t = torch.ones(5, 5, device=device)
7945 v = t.view_as(t)
7946 v = v.swapdims_(0, 1)
7947 self.assertTrue(self.is_view_of(t, v))
7948 v[0, 1] = 0
7949 self.assertEqual(t[1, 0], v[0, 1])
7950
7951 t = torch.ones(5, 5, device=device)
7952 v = t.view_as(t)
7953 v = v.swapaxes_(0, 1)
7954 self.assertTrue(self.is_view_of(t, v))
7955 v[0, 1] = 0
7956 self.assertEqual(t[1, 0], v[0, 1])
7957
7958 t = torch.ones(5, 5, device=device)
7959 v = t.view_as(t)
7960 v = v.transpose_(0, 1)
7961 self.assertTrue(self.is_view_of(t, v))
7962 v[0, 1] = 0
7963 self.assertEqual(t[1, 0], v[0, 1])
7964
7965 def test_t_view(self, device="mps"):
7966 t = torch.ones((5, 5), device=device)
7967 v = t.t()
7968 self.assertTrue(self.is_view_of(t, v))
7969
7970 v[0, 1] = 0
7971 self.assertEqual(t[1, 0], v[0, 1])
7972
7973 def test_t_inplace_view(self, device="mps"):
7974 t = torch.ones(5, 5, device=device)
7975 v = t.view_as(t)
7976 v = v.t_()
7977 self.assertTrue(self.is_view_of(t, v))
7978 v[0, 1] = 0
7979 self.assertEqual(t[1, 0], v[0, 1])
7980
7981 def test_T_view(self, device="mps"):
7982 for op in ("T", "H", "mT", "mH"):
7983 t = torch.ones((5, 5), device=device)
7984 v = getattr(t, op)
7985 self.assertTrue(self.is_view_of(t, v))
7986
7987 v[0, 1] = 0
7988 self.assertEqual(t[1, 0], v[0, 1])
7989
Denis Vieriu4477a5b2022-12-22 21:21:00 +00007990 def test_unfold_view(self, device="mps"):
7991 t = torch.ones(10, device=device)
7992 v = t.unfold(0, 3, 2)
7993 self.assertTrue(self.is_view_of(t, v))
Kulin Sethb744e1c2022-07-01 15:10:56 +00007994
Denis Vieriu4477a5b2022-12-22 21:21:00 +00007995 v[1, 0] = 0
7996 self.assertEqual(t[2], v[1, 0])
Kulin Sethb744e1c2022-07-01 15:10:56 +00007997
7998 def test_squeeze_view(self, device="mps"):
7999 t = torch.ones(5, 1, 5, device=device)
8000 v = torch.squeeze(t)
8001 self.assertTrue(self.is_view_of(t, v))
8002 v[0, 1] = 0
Kulin Seth76cff182022-07-04 06:41:39 +00008003 self.assertTrue(t is v._base)
Kulin Sethb744e1c2022-07-01 15:10:56 +00008004
8005 def test_squeeze_inplace_view(self, device="mps"):
8006 t = torch.ones(5, 5, device=device)
8007 v = t.view_as(t)
8008 v = v.squeeze_()
8009 self.assertTrue(self.is_view_of(t, v))
8010 v[0, 1] = 0
Kulin Seth76cff182022-07-04 06:41:39 +00008011 self.assertTrue(t is v._base)
Kulin Sethb744e1c2022-07-01 15:10:56 +00008012
8013 def test_unsqueeze_view(self, device="mps"):
8014 t = torch.ones(5, 5, device=device)
8015 v = torch.unsqueeze(t, 1)
8016 self.assertTrue(self.is_view_of(t, v))
8017
8018 v[0, 0, 1] = 0
8019 self.assertEqual(t[0, 1], v[0, 0, 1])
8020
8021 def test_unsqueeze_inplace_view(self, device="mps"):
8022 t = torch.ones(5, 5, device=device)
8023 v = t.view_as(t)
8024 v = v.unsqueeze_(1)
8025 self.assertTrue(self.is_view_of(t, v))
8026 v[0, 0, 1] = 0
8027 self.assertEqual(t[0, 1], v[0, 0, 1])
8028
8029 def test_as_strided_view(self, device="mps"):
8030 t = torch.ones(5, 5, device=device)
8031 v = torch.as_strided(t, (25,), (1,))
8032 self.assertTrue(self.is_view_of(t, v))
8033
8034 v[6] = 0
8035 self.assertEqual(t[1, 1], v[6])
8036
8037 def test_as_strided_inplace_view(self, device="mps"):
8038 t = torch.ones(5, 5, device=device)
8039 v = t.view_as(t)
8040 v = v.as_strided_((25,), (1,))
8041 self.assertTrue(self.is_view_of(t, v))
8042 v[6] = 0
8043 self.assertEqual(t[1, 1], v[6])
8044
8045 def test_view_view(self, device="mps"):
8046 t = torch.ones(5, 5, device=device)
8047 v = t.view(25)
8048 self.assertTrue(self.is_view_of(t, v))
8049
8050 v[6] = 0
8051 self.assertEqual(t[1, 1], v[6])
8052
8053 def test_view_as_view(self, device="mps"):
8054 t = torch.ones(5, 5, device=device)
8055 e = torch.empty((25,))
8056 v = t.view_as(e)
8057 self.assertTrue(self.is_view_of(t, v))
8058
8059 v[6] = 0
8060 self.assertEqual(t[1, 1], v[6])
8061
8062 def test_contiguous_self(self, device="mps"):
8063 t = torch.ones(5, 5, device=device)
8064 s = t.contiguous()
8065 self.assertTrue(s is t)
8066
8067 def test_contiguous_nonview(self, device="mps"):
8068 t = torch.ones(5, 5, device=device)
8069 nv = t.t().contiguous()
8070 self.assertTrue(not self.is_view_of(t, nv))
8071
8072 nv[0, 0] = 0
8073 self.assertNotEqual(t[0, 0], nv[0, 0])
8074
8075 def test_reshape_view(self, device="mps"):
8076 t = torch.ones(5, 5, device=device)
8077 v = torch.reshape(t, (25,))
8078 self.assertTrue(self.is_view_of(t, v))
8079
8080 v[6] = 0
8081 self.assertEqual(t[1, 1], v[6])
8082
8083 def test_reshape_as_view(self, device="mps"):
8084 t = torch.ones(5, 5, device=device)
8085 e = torch.empty((25,), device=device)
8086 v = t.reshape_as(e)
8087 self.assertTrue(self.is_view_of(t, v))
8088
8089 v[6] = 0
8090 self.assertEqual(t[1, 1], v[6])
8091
8092 def test_reshape_nonview(self, device="mps"):
8093 t = torch.ones(5, 5, device=device)
8094 nv = torch.reshape(t.t(), (25,))
8095 self.assertTrue(not self.is_view_of(t, nv))
8096
8097 nv[6] = 0
8098 self.assertNotEqual(t[1, 1], nv[6])
8099
8100 def test_flatten_view(self, device="mps"):
8101 def test_writes_propagate(t, v):
8102 idx_t = (0,) * t.ndim
8103 idx_v = (0,) * v.ndim
8104 v[idx_v] = 0
8105 self.assertEqual(t[idx_t], v[idx_v])
8106
8107 t = torch.ones(1, 2, 3, 4, device=device)
8108 v = t.flatten()
8109 self.assertTrue(self.is_view_of(t, v))
8110 test_writes_propagate(t, v)
8111
8112 # zero-dimensional tensor
8113 t = torch.tensor(1, device=device)
8114 v = t.flatten()
8115 test_writes_propagate(t, v)
8116 self.assertTrue(self.is_view_of(t, v))
8117
8118 t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
8119 v = t.flatten(0, 1)
8120 test_writes_propagate(t, v)
8121 self.assertTrue(self.is_view_of_same_base(t, v))
8122
8123 # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
8124 t = torch.ones(720, device=device) \
8125 .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
8126 # [--1--|---2---|-3-] [--1--|----2---|-3-]
8127 v1 = t.flatten(0, 1)
8128 v2 = v1.flatten(1, 3)
8129 v3 = v2.flatten(2, 2)
8130 test_writes_propagate(t, v1)
8131 self.assertTrue(self.is_view_of_same_base(t, v1))
8132 test_writes_propagate(t, v2)
8133 self.assertTrue(self.is_view_of_same_base(t, v2))
8134 test_writes_propagate(t, v3)
8135 self.assertTrue(self.is_view_of_same_base(t, v3))
8136
8137 def test_flatten_nonview(self, device="mps"):
8138 def assert_is_nonview(t, nv):
8139 idx_t = (0,) * t.ndim
8140 idx_nv = (0,) * nv.ndim
8141 self.assertTrue(not nv._is_view())
8142 nv[idx_nv] = 0
8143 self.assertNotEqual(t[idx_t], nv[idx_nv])
8144 t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
8145 nv = t.flatten(1, 3)
8146 assert_is_nonview(t, nv)
8147
8148 t = torch.ones(2, 2, device=device).T
8149 nv = t.flatten()
8150 assert_is_nonview(t, nv)
8151
8152 # flatten returns the original object if start_dim=end_dim
8153 t = t = torch.ones(2, 2, device=device)
8154 nv = t.flatten(1, 1)
8155 self.assertTrue(t is nv)
8156
8157 def test_basic_indexing_slice_view(self, device="mps"):
8158 t = torch.ones(5, 5, device=device)
8159 v = t[:2, :3]
8160 self.assertTrue(self.is_view_of(t, v))
8161
8162 v[0, 0] = 0
8163 self.assertEqual(t[0, 0], v[0, 0])
8164
8165 def test_basic_indexing_ellipses_view(self, device="mps"):
8166 t = torch.ones(5, 5, device=device)
8167 v = t[..., :2]
8168 self.assertTrue(self.is_view_of(t, v))
8169
8170 v[0, 0] = 0
8171 self.assertEqual(t[0, 0], v[0, 0])
8172
8173 def test_basic_indexing_newaxis_view(self, device="mps"):
8174 t = torch.ones(5, 5, device=device)
8175 v = t[None, :2, 3]
8176 self.assertTrue(self.is_view_of(t, v))
8177
8178 v[0, 0] = 0
8179 self.assertEqual(t[0, 3], v[0, 0])
8180
8181 def test_chunk_view(self, device="mps"):
8182 t = torch.zeros(3, 3, device=device)
8183 l = torch.chunk(t, 3)
8184
8185 for idx, v in enumerate(l):
8186 self.assertTrue(self.is_view_of(t, v))
8187
8188 v[0, 0] = idx + 1
8189 self.assertEqual(t[idx, 0], v[0, 0])
8190
8191 def test_split_view(self, device="mps"):
8192 t = torch.zeros(3, 3, device=device)
8193 l = torch.split(t, [1, 1, 1])
8194
8195 for idx, v in enumerate(l):
8196 self.assertTrue(self.is_view_of(t, v))
8197
8198 v[0, 0] = idx + 1
8199 self.assertEqual(t[idx, 0], v[0, 0])
8200
8201 def test_movedim_view(self, device="mps"):
8202 def run_test(device, op):
8203 t = torch.zeros(3, 3, device=device)
8204 out = op(t)
8205
8206 self.assertTrue(self.is_view_of(t, out))
8207
8208 # Randomly change values in output
8209 # and verify that original is changed
8210 # as well.
8211 for _ in range(3):
8212 idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
8213 out[idx_1, idx_2] = random.random()
8214 self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])
8215
8216 for fn in [torch.movedim, torch.moveaxis]:
8217 op = partial(fn, source=(0, 1), destination=(1, 0))
8218 run_test(device, op)
8219
8220 op = partial(fn, source=0, destination=1)
8221 run_test(device, op)
8222
8223 # Testing that the generated view_copy kernel and its derivative are implemented correctly
8224 def test_view_copy(self, device="mps"):
8225 a = torch.randn(4, device=device, requires_grad=True)
8226 a_ref = a.clone().detach().requires_grad_()
8227 a_view = a_ref.view(2, 2)
8228 a_view_copy = torch.view_copy(a, (2, 2))
8229
8230 # view_copy ops don't preserve view relationship
8231 self.assertTrue(self.is_view_of(a_ref, a_view))
8232 self.assertFalse(self.is_view_of(a, a_view_copy))
8233
8234 a_view_copy.sum().backward()
8235 a_view.sum().backward()
8236
8237 # forward and backward give the same shape + result
8238 self.assertEqual(a_view_copy, a_view)
8239 self.assertEqual(a.grad, a_ref.grad)
8240
def test_view_copy_out(self, device="mps"):
    """The out= forms of diagonal_copy / split_copy must agree with the functional forms."""
    # Single-output op.
    mat = torch.randn(2, 2, device=device)
    diag_out = torch.empty(2, device=device)
    torch.diagonal_copy(mat, out=diag_out)
    diag_expected = torch.diagonal_copy(mat)
    self.assertEqual(diag_expected, diag_out)

    # Multi-output op: out= takes one tensor per produced piece.
    vec = torch.randn(4, device=device)
    first_out = torch.empty(2, device=device)
    second_out = torch.empty(2, device=device)
    torch.split_copy(vec, 2, out=(first_out, second_out))
    first_expected, second_expected = torch.split_copy(vec, 2)

    self.assertEqual(first_expected, first_out)
    self.assertEqual(second_expected, second_out)
8259
def test_detached_view_copy(self, device="mps"):
    """Regression test for https://github.com/pytorch/pytorch/issues/86052."""
    cpu_src = torch.arange(2)
    # After .detach() the element is no longer a view, but it is still a
    # contiguous tensor with a non-zero storage offset.
    elem = cpu_src[1].detach()
    on_device = elem.to(device)
    # Round-tripping through the device must give back the original value.
    self.assertEqual(elem, on_device.cpu())
8268
def test_empty_reshape(self, device="mps"):
    """Reshaping a zero-element tensor keeps the data pointer and rejects an ambiguous -1."""
    x = torch.randn(0, 6, device=device)
    new_shape = (1, 0, 6, 1, 1)
    self.assertEqual(new_shape, x.reshape(*new_shape).shape)
    # should be viewable -- i.e. data_ptr is the same.
    self.assertEqual(x.data_ptr(), x.reshape(*new_shape).data_ptr())

    # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
    with self.assertRaises(RuntimeError):
        x.reshape(0, -1)
8277
def test_expand(self, device="mps"):
    """Exercise Tensor.expand / expand_as on tensors created on `device`."""
    tensor = torch.rand(1, 8, 1, device=device)
    tensor2 = torch.rand(5, device=device)
    template = torch.rand(4, 8, 5, device=device)
    target = template.size()

    # Both sources must reach the template shape however the target is spelled.
    for src in (tensor, tensor2):
        self.assertEqual(src.expand_as(template).size(), target)
        self.assertEqual(src.expand(4, 8, 5).size(), target)
        self.assertEqual(src.expand(target).size(), target)

    # expanding an already expanded tensor
    expanded_twice = tensor2.expand(1, 5).expand(2, 2, 5)
    self.assertEqual(expanded_twice, tensor2.repeat(2, 2, 1))

    # non-contiguous input
    noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
    self.assertFalse(noncontig.is_contiguous())
    self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

    # adding leading singleton dims must agree with unsqueeze, strides included
    expanded = tensor2.expand(1, 1, 5)
    unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
    self.assertEqual(expanded, unsqueezed)
    self.assertEqual(expanded.stride(), unsqueezed.stride())

    # -1 as a target size
    self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
    with self.assertRaises(RuntimeError):
        tensor2.expand(-1, -1)

    # expanding empty to empty
    self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))
8310
def test_view_empty(self, device="mps"):
    """view() on a zero-element tensor must produce the requested shape."""
    empty = torch.randn(0, 6, device=device)
    wanted = (1, 0, 6, 1, 1)
    self.assertEqual(wanted, empty.view(*wanted).shape)
8314
def test_reshape(self, device="mps"):
    """Check when reshape / reshape_as alias their input and when they raise."""
    x = torch.randn(3, 3, device=device)
    # Contiguous input: the reshaped result shares the data pointer.
    for shape in ((-1,), (1, 9, 1)):
        self.assertEqual(x.data_ptr(), x.reshape(*shape).data_ptr())
    self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
    with self.assertRaises(RuntimeError):
        x.reshape(-1, -1)

    y = torch.randn(4, 4, 4, device=device)[:, 0, :]
    # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
    if device != "meta":
        self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
    self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
    self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

    # 0-dim tensor
    scalar = torch.randn((), device=device)
    self.assertEqual(scalar.data_ptr(), scalar.reshape(()).data_ptr())
    self.assertEqual(scalar.reshape(-1).shape, (1,))
    with self.assertRaises(RuntimeError):
        scalar.reshape(2)

    empty = torch.tensor([], device=device)
    self.assertEqual(empty, empty.reshape(-1))
    self.assertEqual(empty, empty.reshape([0]))
    # TODO: fix these once we have multi-dimensional empty tensors
    self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
    self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
    with self.assertRaises(RuntimeError):
        empty.reshape(1)

    # reshape_as: only the shape of the template tensor is used for the data_ptr checks.
    x = torch.randn(3, 3, device=device)
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
    with self.assertRaises(RuntimeError):
        x.reshape_as(torch.rand(10, device=device))
8346
def test_narrow(self, device="mps"):
    """Check Tensor.narrow on a tensor living on `device`.

    The source tensor used to be created on the default (CPU) device, so the
    `device` argument was ignored and the MPS backend was never exercised.
    It is now created on `device`; expected values stay on CPU and are
    compared across devices by assertEqual.
    """
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], device=device)
    # ((dim, start, length), expected rows) -- negative dims and starts included.
    cases = (
        ((0, 0, 1), [[0, 1, 2]]),
        ((0, 0, 2), [[0, 1, 2], [3, 4, 5]]),
        ((0, 1, 1), [[3, 4, 5]]),
        ((0, -1, 1), [[6, 7, 8]]),
        ((0, -2, 2), [[3, 4, 5], [6, 7, 8]]),
        ((0, -3, 3), [[0, 1, 2], [3, 4, 5], [6, 7, 8]]),
        ((-1, -1, 1), [[2], [5], [8]]),
        ((-2, -1, 1), [[6, 7, 8]]),
    )
    for (dim, start, length), expected in cases:
        self.assertEqual(x.narrow(dim, start, length), torch.tensor(expected))
8357
def test_narrow_tensor(self, device="mps"):
    """Check Tensor.narrow when `start` is given as a tensor.

    The source tensor is created on `device` (it used to be created on the
    default device, leaving the `device` argument unused). The `start`
    tensors stay on the default device, exactly as before.
    """
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], device=device)
    self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))

    # A floating point scalar, a 1-element vector and a longer vector are all rejected.
    for bad_start in (torch.tensor(0.), torch.tensor([0]), torch.tensor([0, 1])):
        with self.assertRaises(Exception):
            x.narrow(0, bad_start, 1)
8367
def test_t(self, device="mps"):
    """Check Tensor.t() for 0-D to 3-D inputs, both dense and sparse.

    NOTE(review): every tensor here is created on the default device and
    `device` is not used -- presumably because of the to_sparse() calls;
    confirm before moving this test onto `device`.
    """
    # 0D and 1D tensors: t() gives back an equal tensor
    for dense in (torch.randn(()), torch.arange(4)):
        self.assertEqual(dense, dense.t())
        sparse = dense.to_sparse()
        self.assertEqual(sparse, sparse.t())

    # 2D tensors: t() is transpose(0, 1)
    mat = torch.rand((2, 2))
    self.assertEqual(mat.t(), mat.transpose(0, 1))
    mat = mat.to_sparse()
    self.assertEqual(mat.t(), mat.transpose(0, 1))

    # 3D tensors are rejected, with a layout-specific message
    cube = torch.rand((2, 2, 2))
    with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
        cube.t()
    cube = cube.to_sparse()
    with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
        cube.t()
8394
def test_split(self, device="mps"):
    """Check Tensor.split with a fixed split size and with explicit section sizes.

    Changes from the previous version:
      * tensors are created on `device` (the argument used to be ignored, so
        the test only ever ran on CPU);
      * the three copy-pasted verification loops share one helper;
      * the number of returned pieces is asserted, because zip() alone would
        silently accept too few or too many pieces.
    """
    def check(tensor, splits, target_sizes, dim):
        # Each piece must have the expected size and equal the matching narrow() slice.
        self.assertEqual(len(splits), len(target_sizes))
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start += target_size[dim]

    # Fixed split size; the last piece holds the remainder.
    tensor = torch.rand(7, 4, device=device)
    check(tensor, tensor.split(3, 0), ([3, 4], [3, 4], [1, 4]), 0)

    # Variable sections split along dim 0
    tensor = torch.randn(20, 10, device=device)
    check(tensor, tensor.split([5, 5, 10], 0), ([5, 10], [5, 10], [10, 10]), 0)

    # Variable sections split along dim 1
    check(tensor, tensor.split([2, 2, 6], 1), ([20, 2], [20, 2], [20, 6]), 1)
8428
def test_chunk(self, device="mps"):
    """Check Tensor.chunk sizes/contents and its rejection of non-positive chunk counts.

    The tensor is created on `device` (the argument used to be ignored), and
    the number of returned chunks is asserted because zip() alone would
    silently accept a wrong count.
    """
    tensor = torch.rand(4, 7, device=device)
    num_chunks = 3
    dim = 1
    target_sizes = ([4, 3], [4, 3], [4, 1])
    splits = tensor.chunk(num_chunks, dim)
    self.assertEqual(len(splits), len(target_sizes))

    start = 0
    for target_size, split in zip(target_sizes, splits):
        self.assertEqual(split.size(), target_size)
        # Each chunk must equal the matching narrow() slice exactly.
        self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split,
                         atol=0, rtol=0)
        start += target_size[dim]

    # Invalid chunk sizes
    error_regex = 'chunk expects.*greater than 0'
    for bad_count in (0, -2):
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(bad_count)
8448
def test_unsqueeze(self, device="mps") -> None:
    """Check unsqueeze / unsqueeze_ on contiguous and non-contiguous tensors.

    The input is created on `device`; previously the argument was ignored
    and the test only ever ran on the default (CPU) device.
    """
    x = torch.randn(2, 3, 4, device=device)
    y = x.unsqueeze(1)
    self.assertEqual(y, x.view(2, 1, 3, 4))
    y = x.clone().unsqueeze_(2)
    self.assertEqual(y, x.view(2, 3, 1, 4))

    # Selecting along dim 1 yields a non-contiguous tensor of shape (2, 4).
    x = x[:, 1]
    self.assertFalse(x.is_contiguous())
    y = x.unsqueeze(1)
    self.assertEqual(y, x.contiguous().view(2, 1, 4))
    y = x.clone().unsqueeze_(2)
    self.assertEqual(y, x.contiguous().view(2, 4, 1))
8462
8463 # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
def test_big_transpose(self, device="mps"):
    """A materialised transpose on `device` must match NumPy's transpose of the same data."""
    src = torch.rand(456, 789, device=device)
    on_device = src.t().contiguous()
    via_numpy = torch.from_numpy(src.cpu().numpy().transpose())
    self.assertEqual(on_device, via_numpy)
8469
def test_T(self, device="mps"):
    """Tensor.T reverses all dims of a 3-D tensor and leaves a 1-D tensor unchanged."""
    cube = torch.randn(2, 3, 4, device=device)
    self.assertEqual(cube.permute(2, 1, 0), cube.T)

    vec = torch.randn(10, device=device)
    self.assertEqual(vec, vec.T)
Kulin Sethb744e1c2022-07-01 15:10:56 +00008477
def test_transposes(self, device="mps", dtype=torch.float32):
    """Compare T / H / mT / mH / adjoint() against an explicit transpose (+ conj)."""
    for op in ("T", "H", "mT", "mH", "adjoint"):
        # Only the matrix variants ("m*") and adjoint are also tried on a 3-D input.
        takes_3d = op.startswith("m") or op == "adjoint"
        shapes = ((2, 3), (2, 3, 4)) if takes_3d else ((2, 3),)
        for shape in shapes:
            a = make_tensor(shape, device=device, dtype=dtype)

            # adjoint is a method, the others are attributes.
            actual = getattr(a, op)
            if op == "adjoint":
                actual = actual()

            expected = a
            if a.ndim != 0:
                expected = expected.transpose(-2, -1)
            if op.endswith("H") or op == "adjoint":
                expected = expected.conj()
            self.assertEqual(expected, actual)
8492
def test_transposes_errors(self, device="mps", dtype=torch.float32):
    """H / mT / mH / adjoint() must raise for the non-matrix shapes tried here."""
    for op in ("H", "mT", "mH", "adjoint"):
        # 1-D input for every op; H is additionally tried on a 3-D input.
        bad_shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
        for shape in bad_shapes:
            a = make_tensor(shape, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                result = getattr(a, op)
                # adjoint is a method, so it only fails once it is called.
                if op == "adjoint":
                    result()
8502
def test_python_types(self, device="mps"):
    """Python `int` / `bool` passed as dtype must map to torch.int64 / torch.bool."""
    float_a = torch.randn((1, 2), device=device, dtype=torch.float32)
    float_b = torch.randn((1, 2), device=device, dtype=torch.float32)
    self.assertEqual(float_a.dtype, float_b.dtype)

    # torch.int64 vs. the builtin int
    int_torch = torch.arange(10, 20, dtype=torch.int64, device=device)
    int_builtin = torch.arange(10, 20, dtype=int, device=device)
    self.assertEqual(int_torch.dtype, int_builtin.dtype)

    # torch.bool vs. the builtin bool
    flags = [True, False]
    bool_torch = torch.tensor(flags, dtype=torch.bool, device=device)
    bool_builtin = torch.tensor(flags, dtype=bool, device=device)
    self.assertEqual(bool_torch.dtype, bool_builtin.dtype)
8515
8516 # TODO: is resize best put in test_view_ops?
def test_resize_as_preserves_strides(self, device="mps"):
    """resize_as_ with the tensor itself as template must leave its strides alone.

    NOTE(review): the tensor is created on the default device; `device` is unused.
    """
    transposed = torch.empty(2, 3).t()
    strides_before = transposed.stride()
    transposed.resize_as_(transposed)
    self.assertEqual(transposed.stride(), strides_before)
8522
def test_memory_format_resize_as(self, device="mps"):
    """resize_as_ with preserve_format must adopt the template's memory format."""
    def check(shape, memory_format, device="mps"):
        template = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
        flat = torch.randn(template.numel(), device=device)
        flat.resize_as_(template, memory_format=torch.preserve_format)
        self.assertTrue(flat.is_contiguous(memory_format=memory_format))

    cases = (
        ((10, 3, 32, 32), torch.channels_last),
        ((3, 10, 3, 32, 32), torch.channels_last_3d),
    )
    for shape, memory_format in cases:
        check(shape, memory_format, device="mps")
8532
def test_memory_format_resize_(self, device="mps"):
    """resize_ with an explicit memory_format must yield a tensor contiguous in that format."""
    def check(shape, numel, memory_format, device="mps"):
        flat = torch.randn(numel, device=device)
        flat.resize_(shape, memory_format=memory_format)
        self.assertTrue(flat.is_contiguous(memory_format=memory_format))

    cases = (
        ((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last),
        ((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d),
    )
    for shape, numel, memory_format in cases:
        check(shape, numel, memory_format, device="mps")
8541
8542 # TODO: OpInfo this
def _test_atleast(self, device, torch_fn):
    """Run gradcheck and gradgradcheck on `torch_fn` for 0-D to 4-D double inputs.

    NOTE(review): inputs are created on the default device as torch.double;
    `device` is accepted but not used here.
    """
    def check(inputs):
        gradcheck(torch_fn, inputs)
        gradgradcheck(torch_fn, inputs)

    # 0-dim
    s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)
    check(s)

    # 1-dim
    a = torch.rand(4, dtype=torch.double, requires_grad=True)
    check(a)

    # 2,3,4-dim
    b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
    c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
    d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)

    # All five inputs go to torch_fn together in one call.
    check((s, a, b, c, d))
8564
def test_atleast_gradient(self, device="mps"):
    """Gradient-check atleast_1d, atleast_2d and atleast_3d via _test_atleast."""
    for atleast_fn in (torch.atleast_1d, torch.atleast_2d, torch.atleast_3d):
        self._test_atleast(device, atleast_fn)
8569
def test_view(self, device="mps"):
    """Exercise Tensor.view / view_as, including -1 inference and empty tensors."""
    tensor = torch.rand(15, device=device)
    template = torch.rand(3, 5, device=device)
    empty = torch.empty(0, device=device)
    target = template.size()

    self.assertEqual(tensor.view_as(template).size(), target)
    # The same target shape spelled as ints, as a torch.Size and with -1 inference.
    for view_args in ((3, 5), (torch.Size([3, 5]),), (-1, 5), (3, -1)):
        self.assertEqual(tensor.view(*view_args).size(), target)

    tensor_view = tensor.view(5, 3)
    tensor_view.fill_(random.uniform(0, 1))

    self.assertEqual(empty.view_as(empty), empty)
    self.assertEqual(empty.view(0), empty)
    self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
    self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)

    # test size inference with empty tensors
    self.assertEqual(empty.view(-1).size(), torch.Size([0]))
    self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))

    # -1 cannot be inferred when another requested dim is already 0
    ambiguous_msg = r"because the unspecified dimension size -1 can be any value"
    for view_args in ((-1, 0), (3, 0, -1, 0)):
        with self.assertRaisesRegex(RuntimeError, ambiguous_msg):
            empty.view(*view_args)

    # shapes that do not fit 15 elements, or that use -1 twice
    for view_args in ((15, 0), (7, -1), (15, -1, -1)):
        with self.assertRaises(RuntimeError):
            tensor.view(*view_args)
8600
def test_contiguous(self, device="mps"):
    """A changed stride on a size-1 dimension must not make the tensor non-contiguous."""
    x = torch.randn(1, 16, 5, 5, device=device)
    self.assertTrue(x.is_contiguous())

    # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
    new_stride = list(x.stride())
    new_stride[0] = 20
    x.set_(x.storage(), 0, x.size(), new_stride)
    self.assertTrue(x.is_contiguous())
Kulin Sethb744e1c2022-07-01 15:10:56 +00008609
def test_resize_mps_dtypes(self, device="mps"):
    """resize_ from (3, 2) down to (2, 2) must work for every dtype in MPS_DTYPES."""
    new_shape = (2, 2)
    values = [[1, 2], [3, 4], [5, 6]]
    for dtype in MPS_DTYPES:
        x = torch.tensor(values, dtype=dtype, device=device)
        x.resize_(new_shape)
        self.assertEqual(new_shape, x.shape)
8616
def test_resize_as_mps_dtypes(self, device="mps"):
    """resize_as_ from (3, 2) to a (2, 3) template must work for every dtype in MPS_DTYPES."""
    for dtype in MPS_DTYPES:
        resized = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device)
        template = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype, device=device)
        resized.resize_as_(template)
        self.assertEqual(template.shape, resized.shape)
8623
def test_resize_overflow(self, device="mps"):
    """resize_ must raise for shapes whose storage size overflows.

    NOTE(review): the tensor is a float64 tensor on the default device;
    `device` is unused.
    """
    x = torch.empty((), dtype=torch.float64)
    huge = 2**29
    with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
        x.resize_([2, 4, huge, huge])
    with self.assertRaisesRegex(RuntimeError, 'overflow'):
        x.resize_([8, 8, huge, huge])
8630
def test_view_all_dtypes_and_devices(self, device="mps"):
    """Flattening a (3, 2) tensor with view(6) must work for float and bool."""
    values = [[1, 2], [3, 4], [5, 6]]
    for dtype in (torch.float, torch.bool):
        x = torch.tensor(values, dtype=dtype, device=device)
        self.assertEqual(x.view(6).shape, [6])
Kulin Sethe011a8e2022-05-13 18:28:53 +00008635
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00008636class TestConvolutionMPS(TestCaseMPS):
def test_conv1d_all_strides_paddings(self):
    """Conv1d forward on MPS must match CPU for every stride/padding pair in 1..3."""
    # https://github.com/pytorch/pytorch/issues/82921
    def compare(stride, padding):
        inp_cpu = torch.randn(1, 57, 40)
        conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False)
        # Deep copy so that both modules start from identical weights.
        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
        out_cpu = conv_cpu(inp_cpu)

        inp_mps = inp_cpu.to(device='mps')
        out_mps = conv_mps(inp_mps)
        self.assertEqual(out_cpu, out_mps.cpu())

    for stride, padding in itertools.product(range(1, 4), range(1, 4)):
        compare(stride, padding)
8651
8652
def test_conv1d_channels_last(self):
    """Conv1d on a permuted (non-contiguous) input must match between CPU and MPS."""
    # https://github.com/pytorch/pytorch/issues/81557
    conv = torch.nn.Conv1d(1, 128, 3)
    # Built as (128, 176, 1) and permuted to (128, 1, 176).
    inp_cpu = torch.arange((128 * 176), dtype=torch.float32).view(128, 176, 1).permute(0, 2, 1)
    out_cpu = conv(inp_cpu)

    inp_mps = inp_cpu.detach().clone().to("mps")
    # The CPU output was computed above, before the module is sent to MPS.
    conv_mps = conv.to("mps")
    out_mps = conv_mps(inp_mps)

    self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)
8665
def test_conv_transpose_1d_all_strides(self):
    """ConvTranspose1d forward on MPS must match CPU for strides 1, 2 and 3.

    Regression test for https://github.com/pytorch/pytorch/issues/82711.
    """
    def helper(stride):
        y_cpu = torch.ones(1, 1, 2)
        deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1)
        # The weight is replaced by an all-ones tensor of shape (1, 1, 2).
        deconv_cpu.weight.data = torch.ones(1, 1, 2)
        deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps')
        x_cpu = deconv_cpu(y_cpu)

        y_gpu = y_cpu.to(device='mps')
        x_gpu = deconv_gpu(y_gpu)
        self.assertEqual(x_cpu, x_gpu.cpu())

    # Plain loop: helper() is called for its assertions only, so collecting
    # its None results in a throw-away list (as before) served no purpose.
    for stride in (1, 2, 3):
        helper(stride)
8679
def test_conv_transpose_1d_nn_functional(self):
    """F.conv_transpose1d with stride=8, padding=4 must match between CPU and MPS."""
    # https://github.com/pytorch/pytorch/issues/82563
    tin = torch.rand((1, 512, 1245), dtype=torch.float32)
    tparams = torch.rand((512, 256, 16), dtype=torch.float32)
    tbias = torch.rand((256), dtype=torch.float32)

    def run_on(device):
        # Same call on either device; only the placement of the operands differs.
        return torch.nn.functional.conv_transpose1d(
            tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)

    tcpu = run_on('cpu')
    tgpu = run_on('mps')

    self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
8693
def test_conv_backward_1d_channels_last(self):
    """Conv1d forward and backward must match between CPU and MPS.

    Inputs are generated as `shape`, permuted with (0, 2, 1) and made
    contiguous before being fed to the convolution.
    """
    def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
        # https://github.com/pytorch/pytorch/issues/84511
        conv_cpu = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
        conv_mps = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
        # Give the MPS module the same parameters as the CPU one.
        conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
        conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)

        data = torch.rand(shape, dtype=torch.float32)
        x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
        x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
        res_cpu = conv_cpu(x_cpu)
        res_mps = conv_mps(x_mps)
        self.assertEqual(res_cpu, res_mps)

        # backward() returns None; the results used to be assigned back to
        # res_cpu / res_mps, which only replaced the outputs with None.
        res_cpu.sum().backward()
        res_mps.sum().backward()

        self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x_cpu.grad, x_mps.grad)

    helper(shape=(1, 176, 1))
    helper(shape=(2, 12, 1))
    helper(shape=(3, 176, 1))
    helper(shape=(4, 376, 1))
    helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
    helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)
Kulin Seth077db3d2022-09-20 06:19:40 +00008723
def test_conv1d_contiguous(self):
    """Conv1d on a contiguous all-ones input must match between CPU and MPS."""
    conv = torch.nn.Conv1d(1, 128, 3)
    inp_cpu = torch.ones(128, 1, 176)
    out_cpu = conv(inp_cpu)

    inp_mps = inp_cpu.detach().clone().to("mps")
    # The CPU output was computed above, before the module is sent to MPS.
    conv_mps = conv.to("mps")
    out_mps = conv_mps(inp_mps)

    self.assertEqual(out_cpu.shape, out_mps.shape)
    self.assertEqual(out_cpu, out_mps.cpu())
8735
def test_conv2d_all_strides_paddings(self):
    """Conv2d forward and backward must match between CPU and MPS.

    Every (strideX, strideY) pair in 1..3 is tried for each combination of
    input and weight memory format.
    Regression test for https://github.com/pytorch/pytorch/issues/83180.
    """
    def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
        # NOTE(review): `permute_data` has no effect. The previous code called
        # `x.permute(0, 2, 3, 1)` and discarded the result, which is a no-op.
        # Using the permuted tensor would put H in the channel dimension, which
        # does not match the in_channels of the convolutions built below, so the
        # flag is kept only so that the existing call sites keep working.
        x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
        x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()

        for strideX in range(1, 4):
            for strideY in range(1, 4):
                conv_cpu = torch.nn.Conv2d(
                    in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
                conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()

                # Give the MPS module the same parameters as the CPU one.
                conv_mps = torch.nn.Conv2d(
                    in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
                conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
                conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()

                res_cpu = conv_cpu(x_cpu)
                res_mps = conv_mps(x_mps)
                self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)

                # backward() returns None. The old code stored those Nones in
                # res_cpu / res_mps and then asserted None == None, which could
                # never fail; that vacuous assertion is removed.
                res_cpu.sum().backward()
                res_mps.sum().backward()
                self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                # The inputs are shared by all stride pairs, so their gradients
                # accumulate across iterations -- identically on both devices.
                self.assertEqual(x_cpu.grad, x_mps.grad)

    for mem_format_input in [torch.contiguous_format, torch.channels_last]:
        for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
            for permute_data in [True, False]:
                helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
                helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
                helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
8774
def test_conv_transpose_2d_strided(self):
    """ConvTranspose2d forward must match between CPU and MPS for both input memory formats."""
    def compare(module_cpu, memory_format):
        # Clone the module, then replace its parameter data with MPS copies.
        module_mps = copy.deepcopy(module_cpu).requires_grad_()
        module_mps.weight.data = module_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        module_mps.bias.data = module_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        inp_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
        inp_mps = inp_cpu.detach().clone().to("mps")

        self.assertEqual(module_cpu(inp_cpu), module_mps(inp_mps))

    for mem_format_input in [torch.contiguous_format, torch.channels_last]:
        # With square kernels and equal stride
        square = nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_()
        compare(square, mem_format_input)

        # non-square kernels and unequal stride and with padding
        rect = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_()
        compare(rect, mem_format_input)
8794
def test_conv_transpose_2d_specified_output(self):
    """Downsample with Conv2d, then upsample with ConvTranspose2d(output_size=...), CPU vs MPS."""
    def copy_params(dst, src):
        # Give the MPS module the same parameter values as its CPU twin.
        dst.weight.data = src.weight.data.detach().clone().to("mps").requires_grad_()
        dst.bias.data = src.bias.data.detach().clone().to("mps").requires_grad_()

    input_cpu = torch.randn(1, 16, 12, 12)
    input_mps = input_cpu.detach().clone().to("mps")

    downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
    downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
    copy_params(downsample_mps, downsample_cpu)

    upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
    upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
    copy_params(upsample_mps, upsample_cpu)

    # Downsampled activations must agree in value and in size.
    h_cpu = downsample_cpu(input_cpu)
    h_mps = downsample_mps(input_mps)
    self.assertEqual(h_cpu, h_mps)
    self.assertEqual(h_cpu.size(), h_mps.size())

    # Upsample with the original input size requested as output_size.
    output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size())
    output_mps = upsample_mps(h_mps, output_size=input_mps.size())
    self.assertEqual(output_cpu, output_mps)
    self.assertEqual(output_cpu.size(), output_mps.size())
Kulin Seth31d4b6f2022-08-17 00:26:41 +00008821
def test_conv2d_single_stride(self):
    """Conv2d with a single int stride (1..3) must match between CPU and MPS."""
    inp_cpu = torch.randn(2, 2, 3, 6)
    inp_mps = inp_cpu.to(device='mps')
    for stride in (1, 2, 3):
        conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride)
        # Deep copy so that both modules share the same initial parameters.
        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
        out_cpu = conv_cpu(inp_cpu)
        out_mps = conv_mps(inp_mps)
        self.assertEqual(out_cpu, out_mps.cpu(), rtol=1e-03, atol=1e-05)
8831
Denis Vieriu5b8e4852023-02-09 02:25:46 +00008832 def test_grid_sample(self):
8833 def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
8834 def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
8835 for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
8836 # grid_dim_contig_order specifies the dimension order that can
8837 # make grid to be contiguous.
8838 # i.e., grid.permute(grid_dim_contig_order) is contiguous.
8839 # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
8840 # initialized with contiguous tensor of shape [N, 2, H, W]
8841 # and permuted to [N, H, W, 2] afterwards.
8842 grid_shape = [N, H, W, 2]
8843 grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
8844 grid_fwd_permute = [None, None, None, None]
8845 for i, d in enumerate(grid_dim_contig_order):
8846 grid_fwd_permute[d] = i
8847
8848 def get_grid(device='cpu', data=None):
8849 if data is not None:
8850 assert list(data.shape) == grid_shape
8851 data = data.permute(grid_dim_contig_order).to(device)
8852 else:
8853 data = torch.randn(grid_init_shape, device=device)
8854 grid = data.permute(grid_fwd_permute)
8855 assert grid.permute(grid_dim_contig_order).is_contiguous()
8856 return grid
8857
8858 input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
8859 grid_cpu = get_grid().requires_grad_()
8860 out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
8861 align_corners=align_corners)
8862 self.assertTrue(out_cpu.size() == torch.Size([N, C, H, W]))
8863
8864 gradients = torch.randn_like(out_cpu)
8865 out_cpu.backward(gradients)
8866
8867
8868 # Compare against unvectorized CPU fallback
8869
8870 # NOTE [ grid_sample CPU fallback ]
8871 # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
8872 # 32-bit floats. So we also have a fallback that is used only for float tensors
8873 # requiring 64-bit indexing. That requires too much memory to run on CI, so we
8874 # also export the fallback and test it here to ensure feature parity with
8875 # the vectorized version.
8876 input_fallback = input_cpu.float().detach_().requires_grad_()
8877 grid_fallback = grid_cpu.float().detach_().requires_grad_()
8878 out_fallback = torch._grid_sampler_2d_cpu_fallback(
8879 input_fallback, grid_fallback,
8880 F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
8881 F.GRID_SAMPLE_PADDING_MODES[padding_mode],
8882 align_corners)
8883 self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)
8884
8885 out_fallback.backward(gradients.float())
8886 if input_requires_grad:
8887 self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
8888 self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)
8889
8890 input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad)
8891 grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
8892 out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
8893 self.assertEqual(out_cpu, out_mps)
8894 out_mps.backward(gradients.to("mps"))
8895 if input_requires_grad:
8896 self.assertEqual(input_cpu.grad, input_mps.grad)
8897 self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0)
8898
8899 # check that zero-dimensional input strides don't error out
8900 base_input = torch.randn(N, C, 1, IW)
8901 input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad)
8902 out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
8903 align_corners=align_corners)
8904
8905 input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
8906 out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
8907 self.assertEqual(out_cpu, out_mps)
8908
8909 # test same size output
8910 test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)
8911
8912 # test larger output
8913 N = random.randint(2, 8)
8914 C = random.randint(2, 8)
8915 IH = random.randint(2, 8)
8916 IW = random.randint(2, 8)
8917 H = random.randint(IH + 1, 12)
8918 W = random.randint(IW + 1, 12)
8919 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
8920
8921 # test smaller output
8922 N = random.randint(2, 8)
8923 C = random.randint(2, 8)
8924 IH = random.randint(2, 8)
8925 IW = random.randint(2, 8)
8926 H = random.randint(2, IH)
8927 W = random.randint(2, IW)
8928 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
8929
# test 1x1 input
8931 N = random.randint(2, 8)
8932 C = random.randint(2, 8)
8933 IH = 1
8934 IW = 1
8935 H = random.randint(2, 5)
8936 W = random.randint(2, 5)
8937 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
8938
8939 # testing empty grid
8940 N = random.randint(2, 8)
8941 C = random.randint(2, 8)
8942 IH = random.randint(2, 8)
8943 IW = random.randint(2, 8)
8944 W = random.randint(3, IW + 2)
8945 test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)
8946
8947 # testing empty channel
8948 N = random.randint(2, 8)
8949 IH = random.randint(2, 8)
8950 IW = random.randint(2, 8)
8951 H = random.randint(3, IH + 2)
8952 W = random.randint(3, IW + 2)
8953 test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)
8954
8955 # testing empty batch
8956 C = random.randint(2, 8)
8957 IH = random.randint(2, 8)
8958 IW = random.randint(2, 8)
8959 H = random.randint(3, IH + 2)
8960 W = random.randint(3, IW + 2)
8961 test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)
8962
8963 for mode in ('bilinear', 'nearest'):
8964 for padding_mode in ('zeros', 'reflection'):
8965 for align_corners in (True, False):
8966 # test known input
8967 input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
8968 grid = torch.tensor(
8969 [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
8970 [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
8971 if mode == 'bilinear':
8972 if padding_mode == 'zeros':
8973 if align_corners:
8974 groundtruth = torch.tensor(
8975 [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
8976 [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
8977 else:
8978 groundtruth = torch.tensor(
8979 [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
8980 [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
8981 elif padding_mode == 'border':
8982 if align_corners:
8983 groundtruth = torch.tensor(
8984 [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
8985 [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
8986 else:
8987 groundtruth = torch.tensor(
8988 [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
8989 [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
8990 elif padding_mode == 'reflection':
8991 if align_corners:
8992 groundtruth = torch.tensor(
8993 [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
8994 [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
8995 else:
8996 groundtruth = torch.tensor(
8997 [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
8998 [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
8999 else:
9000 raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode))
9001 elif mode == 'nearest':
9002 if padding_mode == 'zeros':
9003 if align_corners:
9004 groundtruth = torch.tensor(
9005 [[0., 8., 5., 7., 9.],
9006 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
9007 else:
9008 groundtruth = torch.tensor(
9009 [[0., 8., 5., 7., 0.],
9010 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
9011 elif padding_mode == 'border':
9012 if align_corners:
9013 groundtruth = torch.tensor(
9014 [[1., 8., 5., 7., 9.],
9015 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
9016 else:
9017 groundtruth = torch.tensor(
9018 [[1., 8., 5., 7., 9.],
9019 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
9020 elif padding_mode == 'reflection':
9021 if align_corners:
9022 groundtruth = torch.tensor(
9023 [[1., 8., 5., 7., 9.],
9024 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
9025 else:
9026 groundtruth = torch.tensor(
9027 [[1., 8., 5., 7., 9.],
9028 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
9029 else:
9030 raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode))
9031 elif mode == 'bicubic':
9032 if padding_mode == 'zeros':
9033 if align_corners:
9034 groundtruth = torch.tensor(
9035 [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
9036 [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
9037 else:
9038 groundtruth = torch.tensor(
9039 [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
9040 [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
9041 elif padding_mode == 'border':
9042 if align_corners:
9043 groundtruth = torch.tensor(
9044 [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
9045 [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
9046 else:
9047 groundtruth = torch.tensor(
9048 [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
9049 [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
9050 elif padding_mode == 'reflection':
9051 if align_corners:
9052 groundtruth = torch.tensor(
9053 [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
9054 [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
9055 else:
9056 groundtruth = torch.tensor(
9057 [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
9058 [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
9059 else:
9060 raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode))
9061
9062 else:
9063 raise AssertionError("missing groundtruth test for interpolation mode '{}'".format(mode))
9064 output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
9065 align_corners=align_corners)
9066 self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
9067 msg="groundtruth comparison failed for mode={}, "
9068 "padding_mode={}".format(mode, padding_mode))
9069
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00009070class TestAdvancedIndexing(TestCaseMPS):
Kulin Sethce7177f2022-08-18 06:03:16 +00009071 supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
Denis Vieriuce4f1872022-09-28 00:47:52 +00009072 supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
Kulin Sethce7177f2022-08-18 06:03:16 +00009073
def test_nonzero_no_warning(self):
    """Check that neither spelling of ``nonzero`` emits a warning on the MPS device."""
    sample = torch.randn((2, 2), device="mps")
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, so nothing is hidden by de-duplication filters.
        warnings.simplefilter("always")
        torch.nonzero(sample)
        sample.nonzero()
        self.assertEqual(len(caught), 0)
9082
def test_nonzero(self):
    """Cross-check every ``nonzero`` calling convention on MPS against NumPy.

    For each dtype in ``self.supported_dtypes`` and each of several 1-D to 3-D
    shapes, a random 0/1 tensor is created on the MPS device. The function
    form, the method form, the ``out=`` form and both ``as_tuple=True`` forms
    must all reproduce ``numpy.nonzero`` exactly (``atol=0, rtol=0``).
    """
    mps = "mps"
    candidate_shapes = [
        torch.Size(dims)
        for dims in ((12,), (12, 1), (1, 12), (6, 2), (3, 2, 2), (5, 5, 5))
    ]

    def random_binary(shape, dtype, device):
        # Original note: randint does not work for bfloat16 on Windows, so
        # that dtype is sampled as float and cast afterwards.
        if dtype == torch.bfloat16:
            return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
        return torch.randint(2, shape, device=device, dtype=dtype)

    def verify(dtype):
        for shape in candidate_shapes:
            sample = random_binary(shape, dtype, mps)

            via_function = torch.nonzero(sample, as_tuple=False)
            via_method = sample.nonzero(as_tuple=False)
            # Start from a zero-element long tensor that nonzero writes into.
            via_out = torch.empty([], dtype=torch.long, device=mps)
            via_out = via_out.resize_(0)
            torch.nonzero(sample, out=via_out)

            # bfloat16 tensors are converted to float before going to NumPy.
            if dtype != torch.bfloat16:
                as_numpy = sample.cpu().numpy()
            else:
                as_numpy = sample.float().cpu().numpy()
            expected = torch.from_numpy(np.stack(as_numpy.nonzero())).t()

            for actual in (via_function, via_method, via_out):
                self.assertEqual(actual.cpu(), expected, atol=0, rtol=0)

            tuple_from_function = torch.nonzero(sample, as_tuple=True)
            tuple_from_method = sample.nonzero(as_tuple=True)
            stacked_function = torch.stack(tuple_from_function).t().cpu()
            stacked_method = torch.stack(tuple_from_method).t().cpu()
            self.assertEqual(stacked_function, expected, atol=0, rtol=0)
            self.assertEqual(stacked_method, expected, atol=0, rtol=0)

    for dtype in self.supported_dtypes:
        verify(dtype)
9121
def test_nonzero_astuple_out(self):
    """Exercise ``nonzero`` with ``as_tuple``/``out`` on MPS, eagerly, scripted and traced."""
    mps = "mps"
    sample = torch.randn((3, 3, 3), device=mps)
    buffer = torch.empty([], dtype=torch.long, device=mps)
    buffer = buffer.resize_(0)

    # Requesting a tuple result together with ``out`` must be rejected.
    with self.assertRaises(RuntimeError):
        torch.nonzero(sample, as_tuple=True, out=buffer)

    explicit = torch.nonzero(sample, as_tuple=False, out=buffer)
    implicit = torch.nonzero(sample, out=buffer)
    self.assertEqual(explicit, implicit)

    # Verifies that JIT script cannot handle the as_tuple kwarg
    # See Issue https://github.com/pytorch/pytorch/issues/45499.
    def _all_variants(t):
        as_tuple = torch.nonzero(t, as_tuple=True)
        as_matrix = torch.nonzero(t, as_tuple=False)
        filled = torch.empty_like(as_matrix)
        torch.nonzero(t, as_tuple=False, out=filled)
        return as_tuple, as_matrix, filled

    with self.assertRaises(RuntimeError):
        torch.jit.script(_all_variants)

    # Tracing, unlike scripting, is expected to work.
    traced = torch.jit.trace(_all_variants, sample)
    traced_tuple, traced_matrix, traced_filled = traced(sample)
    reference_tuple = torch.nonzero(sample, as_tuple=True)
    reference_matrix = torch.nonzero(sample)

    self.assertEqual(traced_tuple, reference_tuple)
    self.assertEqual(traced_matrix, reference_matrix)
    self.assertEqual(traced_filled, reference_matrix)
9154
def test_nonzero_discontiguous(self):
    """Check ``nonzero`` on MPS with a strided input and with a strided ``out`` tensor."""
    mps = "mps"
    rows, cols = 4, 4
    dense = torch.randint(2, (rows, cols), device=mps)
    # Same values, but stored in every second column of a wider buffer.
    strided = torch.empty(rows, cols * 2, device=mps)[:, ::2].copy_(dense)

    from_dense = dense.nonzero(as_tuple=False)
    from_strided = strided.nonzero(as_tuple=False)
    self.assertEqual(from_dense, from_strided, atol=0, rtol=0)

    # A correctly sized ``out`` tensor: its storage is expected to be reused.
    reused = torch.empty_like(from_dense)
    pointer_before = reused.data_ptr()
    torch.nonzero(dense, out=reused)
    self.assertEqual(pointer_before, reused.data_ptr())
    self.assertEqual(from_dense, reused, atol=0, rtol=0)

    # A discontiguous ``out`` tensor: data pointer and strides must survive the call.
    strided_out = torch.empty(
        from_dense.size(0), from_dense.size(1) * 2, dtype=torch.long, device=mps
    )[:, ::2]
    pointer_before = strided_out.data_ptr()
    strides_before = strided_out.stride()
    torch.nonzero(dense, out=strided_out)
    self.assertEqual(pointer_before, strided_out.data_ptr())
    self.assertEqual(from_dense, strided_out, atol=0, rtol=0)
    self.assertEqual(strides_before, strided_out.stride())
9177
def test_nonzero_non_diff(self):
    """Check that ``nonzero`` of a grad-requiring MPS tensor does not itself require grad.

    The original assigned ``device = "mps"`` but never passed it on, so the
    tensor lived on the CPU and the MPS backend was not exercised at all.
    """
    device = "mps"
    x = torch.randn(10, requires_grad=True, device=device)
    nz = x.nonzero()
    self.assertFalse(nz.requires_grad)
9183
def test_masked_select(self):
    """Compare ``torch.masked_select`` on MPS with the CPU result for a ``>= 0.5`` mask."""
    cpu_values = torch.randn(3, 4)
    mps_values = cpu_values.to("mps")

    # Each mask is computed on the device of the tensor it will select from.
    cpu_mask = cpu_values.ge(0.5)
    mps_mask = mps_values.ge(0.5)

    expected = torch.masked_select(cpu_values, cpu_mask)
    actual = torch.masked_select(mps_values, mps_mask)

    self.assertEqual(expected, actual)
9194
Kulin Sethce7177f2022-08-18 06:03:16 +00009195 # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
def test_indexing_get(self):
    """Pairwise integer-list indexing on MPS must match the CPU for every supported dtype."""
    row_ids = [0, 1, 2]
    col_ids = [0, 1, 0]

    def check(dtype):
        cpu_matrix = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
        mps_matrix = cpu_matrix.detach().clone().to("mps")

        picked_cpu = cpu_matrix[row_ids, col_ids]
        picked_mps = mps_matrix[row_ids, col_ids]
        # The dtype is passed as the failure message to identify the failing case.
        self.assertEqual(picked_cpu, picked_mps, str(dtype))

    for dtype in self.supported_dtypes:
        check(dtype)
9205
def test_indexing_select_corners(self):
    """Select the four corners of a 4x3 matrix with 2-D index tensors; MPS must match CPU."""

    def to_mps(tensor):
        return tensor.detach().clone().to("mps")

    def check(dtype):
        cpu_matrix = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        mps_matrix = to_mps(cpu_matrix)

        cpu_rows = torch.tensor([[0, 0], [3, 3]])
        mps_rows = to_mps(cpu_rows)

        cpu_cols = torch.tensor([[0, 2], [0, 2]])
        mps_cols = to_mps(cpu_cols)

        corners_cpu = cpu_matrix[cpu_rows, cpu_cols]
        corners_mps = mps_matrix[mps_rows, mps_cols]

        self.assertEqual(corners_cpu, corners_mps, str(dtype))

    for dtype in self.supported_dtypes:
        check(dtype)
9222
9223 # FIXME: uint8 fails for this testcase, needs further debugging
def test_slicing_using_advanced_index_for_column(self):
    """Mix a row slice with a plain column slice and with a column index list; MPS must match CPU."""
    # FIXME: use supported_dtypes once uint8 is fixed
    dtypes_under_test = (torch.float32, torch.float16, torch.int64, torch.int32, torch.int16)

    def check(dtype):
        cpu_matrix = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        mps_matrix = cpu_matrix.detach().clone().to("mps")

        # Basic slicing on both axes.
        self.assertEqual(cpu_matrix[1:4, 1:3], mps_matrix[1:4, 1:3], str(dtype))

        # Same rows, but the columns are chosen by an advanced (list) index.
        columns = [1, 2]
        self.assertEqual(cpu_matrix[1:4, columns], mps_matrix[1:4, columns], str(dtype))

    for dtype in dtypes_under_test:
        check(dtype)
9239
9240 # FIXME: conditional indexing not working
9241 # def test_boolean_array_indexing_1(self):
9242 # def helper(dtype):
9243 # x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
9244 # x_mps = x_cpu.detach().clone().to("mps")
9245
9246 # res_cpu = x_cpu[x_cpu > 5]
9247 # res_mps = x_mps[x_mps > 5]
9248
9249 # print(res_cpu)
9250 # print(res_mps)
9251
9252 # self.assertEqual(res_cpu, res_mps, str(dtype))
9253 # [helper(dtype) for dtype in self.supported_dtypes]
9254
Denis Vieriuce4f1872022-09-28 00:47:52 +00009255
def test_advanced_indexing_3D_get(self):
    """Read a 3-D tensor through three advanced-index patterns; MPS must match CPU.

    Runs once on a hand-written float32 tensor and once per supported dtype on
    data generated with NumPy.
    """
    everything = slice(None)
    # Equivalent to x[[1, 2], 3, :], x[[0, 2], :, :] and x[:, [1, 0], [1]].
    index_patterns = (
        ([1, 2], 3, everything),
        ([0, 2], everything, everything),
        (everything, [1, 0], [1]),
    )

    def check(cpu_tensor):
        mps_tensor = cpu_tensor.detach().clone().to("mps")
        for pattern in index_patterns:
            self.assertEqual(cpu_tensor[pattern], mps_tensor[pattern])

    fixed_input = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                                 [0.5, 0.6, 0.7, 0.8],
                                 [0.9, 1.0, 1.1, 1.2],
                                 [1.3, 1.4, 1.5, 1.6]],

                                [[2.0, 2.1, 2.2, 2.3],
                                 [2.4, 2.5, 2.6, 2.7],
                                 [2.8, 2.9, 3.0, 3.1],
                                 [3.2, 3.3, 3.4, 3.5]],

                                [[4.0, 4.1, 4.2, 4.3],
                                 [4.4, 4.5, 4.6, 4.7],
                                 [4.8, 4.9, 5.0, 5.1],
                                 [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
    check(fixed_input)

    # torch.randn / torch.rand don't work with all dtypes, so the input data
    # is generated with NumPy and then moved to torch.
    for np_dtype, torch_dtype in zip(self.supported_np_dtypes, self.supported_dtypes):
        raw = np.random.random_sample(size=[3, 4, 4]).astype(np_dtype)
        check(torch.tensor(raw, device='cpu', dtype=torch_dtype))
9285
def test_advanced_indexing_3D_put(self):
    """Write into a 3-D tensor through three advanced-index patterns; MPS must match CPU.

    The written value is a one-element view (``[1:]``) of a two-element tensor.
    The three writes are applied cumulatively to the same pair of tensors.
    """
    everything = slice(None)
    # Equivalent to x[[1, 2], 3, :], x[[0, 2], :, :] and x[:, [1, 0], [1]].
    index_patterns = (
        ([1, 2], 3, everything),
        ([0, 2], everything, everything),
        (everything, [1, 0], [1]),
    )

    def check(cpu_tensor):
        dtype = cpu_tensor.dtype
        mps_tensor = cpu_tensor.detach().clone().to("mps")

        cpu_source = torch.tensor([88, 99], dtype=dtype, device="cpu")
        cpu_value = cpu_source[1:]

        mps_source = torch.tensor([88, 99], dtype=dtype, device="mps")
        mps_value = mps_source[1:]

        for pattern in index_patterns:
            cpu_tensor[pattern] = cpu_value
            mps_tensor[pattern] = mps_value
            self.assertEqual(cpu_tensor, mps_tensor)

    fixed_input = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                                 [0.5, 0.6, 0.7, 0.8],
                                 [0.9, 1.0, 1.1, 1.2],
                                 [1.3, 1.4, 1.5, 1.6]],

                                [[2.0, 2.1, 2.2, 2.3],
                                 [2.4, 2.5, 2.6, 2.7],
                                 [2.8, 2.9, 3.0, 3.1],
                                 [3.2, 3.3, 3.4, 3.5]],

                                [[4.0, 4.1, 4.2, 4.3],
                                 [4.4, 4.5, 4.6, 4.7],
                                 [4.8, 4.9, 5.0, 5.1],
                                 [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
    check(fixed_input)

    # torch.randn / torch.rand don't work with all dtypes, so the input data
    # is generated with NumPy and then moved to torch.
    for np_dtype, torch_dtype in zip(self.supported_np_dtypes, self.supported_dtypes):
        raw = np.random.random_sample(size=[3, 4, 4]).astype(np_dtype)
        check(torch.tensor(raw, device='cpu', dtype=torch_dtype))
9331
def test_index_put_with_view_indices(self):
    """Accumulating ``index_put_`` whose indices are rows of a transposed tensor; MPS must match CPU."""
    index_pairs = [[0, 1], [0, 1]]

    def check(dtype):
        cpu_target = torch.zeros([5, 3], device="cpu", dtype=dtype)
        mps_target = torch.zeros([5, 3], device="mps", dtype=dtype)

        cpu_indices = torch.tensor(index_pairs, dtype=torch.int64, device="cpu")
        mps_indices = torch.tensor(index_pairs, dtype=torch.int64, device="mps")

        cpu_ones = torch.ones(cpu_indices.shape[0], device="cpu", dtype=dtype)
        mps_ones = torch.ones(mps_indices.shape[0], device="mps", dtype=dtype)

        # tuple(x.t()) yields one index tensor per dimension, each a view of x.
        cpu_target.index_put_(tuple(cpu_indices.t()), cpu_ones, accumulate=True)
        mps_target.index_put_(tuple(mps_indices.t()), mps_ones, accumulate=True)

        self.assertEqual(cpu_target, mps_target)

    for dtype in (torch.int32, torch.float):
        check(dtype)
9349
9350 # tests from 'test_indexing.py'
def test_advancedindex_big(self, device="mps"):
    """Index a 123344-element ``arange`` at scattered positions, including the last element."""
    positions = [0, 123, 44488, 68807, 123343]
    source = torch.arange(0, 123344, dtype=torch.int, device=device)

    # Since source[i] == i, the gathered values equal the positions themselves.
    gathered = source[positions, ]
    self.assertEqual(gathered, torch.tensor(positions, dtype=torch.int))
9356
9357 def test_set_item_to_scalar_tensor(self, device="mps"):
9358 m = random.randint(1, 10)
9359 n = random.randint(1, 10)
9360 z = torch.randn([m, n], device=device)
9361 a = 1.0
9362 w = torch.tensor(a, requires_grad=True, device=device)
9363 z[:, 0] = w
9364 z.sum().backward()
9365 self.assertEqual(w.grad, m * a)
9366
9367 def test_single_int(self, device="mps"):
9368 v = torch.randn(5, 7, 3, device=device)
9369 self.assertEqual(v[4].shape, (7, 3))
9370
9371 def test_multiple_int(self, device="mps"):
9372 v = torch.randn(5, 7, 3, device=device)
9373 self.assertEqual(v[4].shape, (7, 3))
9374 self.assertEqual(v[4, :, 1].shape, (7,))
9375
9376 def test_none(self, device="mps"):
9377 v = torch.randn(5, 7, 3, device=device)
9378 self.assertEqual(v[None].shape, (1, 5, 7, 3))
9379 self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
9380 self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
9381 self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
9382
def test_step(self, device="mps"):
    """Strided slices of ``arange(10)``, including a step larger than the tensor."""
    sequence = torch.arange(10, device=device)
    # A step of one is the identity slice.
    self.assertEqual(sequence[::1], sequence)

    every_second = sequence[::2].tolist()
    every_third = sequence[::3].tolist()
    oversized_step = sequence[::11].tolist()
    bounded = sequence[1:6:2].tolist()
    self.assertEqual(every_second, [0, 2, 4, 6, 8])
    self.assertEqual(every_third, [0, 3, 6, 9])
    self.assertEqual(oversized_step, [0])
    self.assertEqual(bounded, [1, 3, 5])
9390
9391 def test_step_assignment(self, device="mps"):
9392 v = torch.zeros(4, 4, device=device)
9393 v[0, 1::2] = torch.tensor([3., 4.], device=device)
9394 self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
9395 self.assertEqual(v[1:].sum(), 0)
9396
def test_bool_indices(self, device="mps"):
    """Boolean-mask indexing on ``device``, plus the equivalent uint8-mask lookups.

    The second half counts the warnings recorded while indexing with a uint8
    mask. ``warnings.simplefilter("always")`` is installed first, as the other
    warning-counting tests in this class do, so the count does not depend on
    whatever warning filters happen to be active when the test runs.
    """
    v = torch.randn(5, 7, 3, device=device)
    boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
    self.assertEqual(v[boolIndices].shape, (3, 7, 3))
    self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))

    v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
    boolIndices = torch.tensor([True, False, False], dtype=torch.bool, device=device)
    uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # Two lookups below use the uint8 mask; presumably each accounts for
        # one of the two warnings expected at the end.
        self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
        self.assertEqual(v[boolIndices], v[uint8Indices])
        self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
        self.assertEqual(len(w), 2)
9411
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_bool_indices_accumulate(self, device="mps"):
    """Accumulating ``index_put_`` through an all-False bool mask must leave the tensor unchanged."""
    # Comparing an all-zero uint8 tensor against 0 yields an all-False bool mask.
    nothing_selected = torch.zeros(size=(10, ), dtype=torch.uint8, device=device) > 0
    target = torch.ones(size=(10, 10), device=device)
    target.index_put_((nothing_selected, ), target[nothing_selected], accumulate=True)
    self.assertEqual(target, torch.ones(size=(10, 10), device=device))
9419
Kulin Sethce7177f2022-08-18 06:03:16 +00009420 def test_multiple_bool_indices(self, device="mps"):
9421 v = torch.randn(5, 7, 3, device=device)
9422 # note: these broadcast together and are transposed to the first dim
9423 mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
9424 mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
9425 self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
9426
def test_byte_mask(self, device="mps"):
    """Index with a uint8 (byte) mask and count the warnings recorded while doing so.

    ``warnings.simplefilter("always")`` is installed before counting, as in
    ``test_byte_mask_accumulate`` and ``test_multiple_byte_mask``, so the
    expected count does not depend on the warning filters active at call time.
    """
    v = torch.randn(5, 7, 3, device=device)
    mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # Two lookups use the byte mask; two warnings are expected in total.
        self.assertEqual(v[mask].shape, (3, 7, 3))
        self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
        self.assertEqual(len(w), 2)

    # A mask that selects nothing yields an empty tensor.
    v = torch.tensor([1.], device=device)
    self.assertEqual(v[v == 0], torch.tensor([], device=device))
9437
def test_byte_mask_accumulate(self, device="mps"):
    """Accumulating ``index_put_`` through an all-zero uint8 mask: no change, two warnings."""
    byte_mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
    target = torch.ones(size=(10, 10), device=device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # The byte mask is used twice here: for the read and for the write.
        target.index_put_((byte_mask, ), target[byte_mask], accumulate=True)
        self.assertEqual(target, torch.ones(size=(10, 10), device=device))
        self.assertEqual(len(caught), 2)
9446
def test_index_put_accumulate_expanded_values(self, device="mps"):
    """Accumulating ``index_put_`` with values that must be expanded to the indexed shape.

    Each scenario applies several value tensors in turn to the same CPU/device
    pair of targets, comparing the two after every update.
    """

    def accumulate_and_compare(cpu_target, dev_target, cpu_indices, dev_indices, values):
        # The device update runs first, then the CPU reference.
        from_device = dev_target.index_put_(dev_indices, values.to(device), accumulate=True)
        from_cpu = cpu_target.index_put_(cpu_indices, values, accumulate=True)
        self.assertEqual(from_device.cpu(), from_cpu)

    # Scenario 1: a 2-D target with 0-d and 1-d values.
    cpu_target = torch.zeros((5, 2))
    dev_target = cpu_target.to(device)
    cpu_indices = [
        torch.tensor([0, 1, 2, 3]),
        torch.tensor([1, ]),
    ]
    dev_indices = [index.to(device) for index in cpu_indices]
    scenario_values = (torch.tensor(1.0), torch.tensor([1.0, ]))
    for values in scenario_values:
        accumulate_and_compare(cpu_target, dev_target, cpu_indices, dev_indices, values)

    # Scenario 2: a 3-D target with broadcasting indices and 1-d / 2-d values.
    cpu_target = torch.zeros(4, 3, 2)
    dev_target = cpu_target.to(device)

    cpu_indices = [
        torch.tensor([0, ]),
        torch.arange(3)[:, None],
        torch.arange(2)[None, :],
    ]
    dev_indices = [index.to(device) for index in cpu_indices]
    scenario_values = (torch.tensor([-1.0, -2.0]), torch.tensor([[-1.0, -2.0], ]))
    for values in scenario_values:
        accumulate_and_compare(cpu_target, dev_target, cpu_indices, dev_indices, values)
9485
def test_index_put_accumulate_non_contiguous(self, device="mps"):
    """Accumulating ``index_put_`` on a non-contiguous view; the view must stay non-contiguous."""
    cpu_base = torch.zeros((5, 2, 2))
    dev_base = cpu_base.to(device)
    # Fixing the middle dimension produces non-contiguous views.
    dev_view = dev_base[:, 0, :]
    cpu_view = cpu_base[:, 0, :]
    self.assertTrue(not dev_view.is_contiguous())
    self.assertTrue(not cpu_view.is_contiguous())

    cpu_indices = [torch.tensor([0, 1]), ]
    dev_indices = [index.to(device) for index in cpu_indices]
    update = torch.randn(2, 2)
    from_device = dev_view.index_put_(dev_indices, update.to(device), accumulate=True)
    from_cpu = cpu_view.index_put_(cpu_indices, update, accumulate=True)
    self.assertTrue(not dev_view.is_contiguous())
    self.assertTrue(not cpu_view.is_contiguous())

    self.assertEqual(from_device.cpu(), from_cpu)
9503
def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
    """Accumulating ``index_put_`` where the first index is ``None``; device must match CPU.

    The values are now moved to ``device`` rather than to a hard-coded
    ``"mps"``, so the ``device`` parameter is honoured throughout.
    """
    # TODO: replace with a better solution.
    # Currently, here using torchscript to put None into indices.
    # on C++ it gives indices as a list of 2 optional tensors: first is null and
    # the second is a valid tensor.
    @torch.jit.script
    def func(x, i, v):
        idx = [None, i]
        x.index_put_(idx, v, accumulate=True)
        return x

    n = 4
    t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
    t_dev = t.to(device)
    indices = torch.tensor([1, 0])
    indices_dev = indices.to(device)
    value0d = torch.tensor(10.0)
    value1d = torch.tensor([1.0, 2.0])

    # Both updates are applied cumulatively to the same pair of tensors.
    for value in (value0d, value1d):
        out_dev = func(t_dev, indices_dev, value.to(device))
        out_cpu = func(t, indices, value)
        self.assertEqual(out_dev.cpu(), out_cpu)
9530
def test_index_put_accumulate_duplicate_indices(self, device="mps"):
    """Accumulating ``index_put`` with heavily duplicated indices against a pure-Python reference.

    Changes from the original: the inner loop no longer reuses the outer loop
    variable ``i``, the local no longer shadows the builtin ``input``, and
    ``device`` is used instead of a hard-coded ``"mps"``.
    """
    for length in range(1, 128):
        # generate indices by random walk, this will create indices with
        # lots of duplicates interleaved with each other
        delta = torch.empty(length, dtype=torch.float32, device=device).uniform_(-1, 1)

        indices = delta.cumsum(0).long().to(device)

        # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it.
        # One more slot than the largest magnitude keeps every index, positive
        # or negative, inside the target.
        target_size = int(indices.cpu().abs().max()) + 1
        target = torch.randn(target_size, device=device)
        values = torch.randn(indices.size(0), device=device)
        output = target.index_put((indices,), values, accumulate=True)

        # Reference result: accumulate one update at a time in plain Python.
        expected = target.tolist()
        for position, addend in zip(indices.tolist(), values.tolist()):
            expected[position] += addend

        self.assertEqual(output, expected)
9551
def test_multiple_byte_mask(self, device="mps"):
    """Two uint8 masks separated by a slice: ``(3, 7)`` result and two recorded warnings."""
    volume = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    first_dim_mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    last_dim_mask = torch.ByteTensor([1, 1, 1]).to(device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        selected = volume[first_dim_mask, :, last_dim_mask]
        self.assertEqual(selected.shape, (3, 7))
        self.assertEqual(len(caught), 2)
9561
9562 def test_byte_mask2d(self, device="mps"):
9563 v = torch.randn(5, 7, 3, device=device)
9564 c = torch.randn(5, 7, device=device)
9565 num_ones = (c > 0).sum()
9566 r = v[c > 0]
9567 self.assertEqual(r.shape, (num_ones, 3))
9568
9569 # FIXME: conditional indexing not working
9570 # def test_jit_indexing(self, device="mps"):
9571 # def fn1(x):
9572 # x[x < 50] = 1.0
9573 # return x
9574
9575 # def fn2(x):
9576 # x[0:50] = 1.0
9577 # return x
9578
9579 # scripted_fn1 = torch.jit.script(fn1)
9580 # scripted_fn2 = torch.jit.script(fn2)
9581 # data = torch.arange(100, device=device, dtype=torch.float)
9582 # out = scripted_fn1(data.detach().clone())
9583 # ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
9584 # self.assertEqual(out, ref)
9585 # out = scripted_fn2(data.detach().clone())
9586 # self.assertEqual(out, ref)
9587
9588 def test_int_indices(self, device="mps"):
9589 v = torch.randn(5, 7, 3, device=device)
9590 self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
9591 self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
9592 self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
9593
def test_index_put_src_datatype(self):
    """Accumulating ``index_put_`` on MPS keeps the source shape for float and int32 inputs."""

    def check(device, dtype):
        source = torch.ones(3, 2, 4, device=device, dtype=dtype)
        updates = torch.ones(3, 2, 4, device=device, dtype=dtype)
        # The index tensor is created without a device argument.
        first_dim_order = (torch.tensor([0, 2, 1]),)
        result = source.index_put_(first_dim_order, updates, accumulate=True)
        self.assertEqual(result.shape, source.shape)

    for dtype in (torch.float, torch.int32):
        check(device="mps", dtype=dtype)
9602
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_index_src_datatype(self):
    """Index and non-accumulating index-assign on MPS for float, half, long and bool sources."""

    def check(device, dtype):
        wants_bool = dtype is torch.bool
        # A bool source is built as uint8 ones and then compared against 1.
        storage_dtype = torch.uint8 if wants_bool else dtype

        source = torch.ones(3, 2, 4, device=device, dtype=storage_dtype)
        if wants_bool:
            source = source == 1
        # test index
        reordered = source[[0, 2, 1], :, :]
        self.assertEqual(reordered.shape, source.shape)
        # test index_put, no accum
        source[[0, 2, 1], :, :] = reordered
        self.assertEqual(reordered.shape, source.shape)

    for dtype in (torch.float, torch.float16, torch.long, torch.bool):
        check(device="mps", dtype=dtype)
9620
Kulin Sethce7177f2022-08-18 06:03:16 +00009621 def test_int_indices2d(self, device="mps"):
9622 # From the NumPy indexing example
9623 x = torch.arange(0, 12, device=device).view(4, 3)
9624 rows = torch.tensor([[0, 0], [3, 3]], device=device)
9625 columns = torch.tensor([[0, 2], [0, 2]], device=device)
9626 self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
9627
9628 def test_int_indices_broadcast(self, device="mps"):
9629 # From the NumPy indexing example
9630 x = torch.arange(0, 12, device=device).view(4, 3)
9631 rows = torch.tensor([0, 3], device=device)
9632 columns = torch.tensor([0, 2], device=device)
9633 result = x[rows[:, None], columns]
9634 self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
9635
def test_empty_index(self, device="mps"):
    """Empty index tensors and all-False masks select nothing and assign nothing."""
    original = torch.arange(0, 12, device=device).view(4, 3)
    no_indices = torch.tensor([], dtype=torch.long, device=device)
    self.assertEqual(original[no_indices].numel(), 0)

    # empty assignment should have no effect but not throw an exception
    working_copy = original.clone()
    working_copy[no_indices] = -1
    self.assertEqual(original, working_copy)

    # Same expectation for a boolean mask that is False everywhere.
    all_false = torch.zeros(4, 3, device=device).bool()
    working_copy[all_false] = -1
    self.assertEqual(original, working_copy)
9649
def test_empty_ndim_index(self, device="mps"):
    """Multi-dimensional empty indices and indexing into tensors with a zero-sized dimension."""
    vector = torch.randn(5, device=device)
    empty_2d_index = torch.empty(0, 2, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(0, 2, device=device), vector[empty_2d_index])

    # The (0, 6) index replaces dimension 1 of a (2, 3, 4, 5) tensor.
    block = torch.randn(2, 3, 4, 5, device=device)
    empty_index = torch.empty(0, 6, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device), block[:, empty_index])

    zero_width = torch.empty(10, 0, device=device)
    self.assertEqual(zero_width[[1, 2]].shape, (2, 0))
    self.assertEqual(zero_width[[], []].shape, (0,))
    # Any concrete index into the zero-sized dimension is out of range.
    with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
        zero_width[:, [0, 1]]
9663
9664 def test_empty_ndim_index_bool(self, device="mps"):
9665 x = torch.randn(5, device=device)
9666 self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])
9667
Denis Vieriuce4f1872022-09-28 00:47:52 +00009668 def test_empty_slice(self, device="mps"):
9669 x = torch.randn(2, 3, 4, 5, device=device)
9670 y = x[:, :, :, 1]
9671 z = y[:, 1:1, :]
9672 self.assertEqual((2, 0, 4), z.shape)
9673 # this isn't technically necessary, but matches NumPy stride calculations.
9674 self.assertEqual((60, 20, 5), z.stride())
9675 self.assertTrue(z.is_contiguous())
9676
def test_index_getitem_copy_bools_slices(self, device="mps"):
    """Which scalar-like indices copy (bools, uint8 scalars) and which alias (``None``, ``...``)."""
    uint8_one = torch.tensor(1, dtype=torch.uint8, device=device)
    uint8_zero = torch.tensor(0, dtype=torch.uint8, device=device)

    samples = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]

    for sample in samples:
        address = sample.data_ptr()
        expected_empty = torch.empty(0, *sample.shape)

        # A true index (Python bool or uint8 scalar tensor) returns new storage;
        # a false one returns an empty tensor with a leading zero dimension.
        self.assertNotEqual(address, sample[True].data_ptr())
        self.assertEqual(expected_empty, sample[False])
        self.assertNotEqual(address, sample[uint8_one].data_ptr())
        self.assertEqual(expected_empty, sample[uint8_zero])

        # None and Ellipsis share the original storage.
        self.assertEqual(address, sample[None].data_ptr())
        self.assertEqual(address, sample[...].data_ptr())
9690
def test_index_setitem_bools_slices(self, device="mps"):
    """Assignment through bool, uint8-scalar, ``None`` and ``...`` indices.

    True-like indices overwrite the whole tensor, false-like indices leave it
    untouched, and a full slice on a 0-dim tensor raises ``IndexError``.
    """
    uint8_one = torch.tensor(1, dtype=torch.uint8, device=device)
    uint8_zero = torch.tensor(0, dtype=torch.uint8, device=device)

    samples = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]

    for sample in samples:
        # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
        # (some of these ops already prefix a 1 to the size)
        minus_ones = torch.ones_like(sample) * -1
        minus_ones_prefixed = minus_ones.unsqueeze(0).unsqueeze(0)

        # Python bools.
        sample[True] = minus_ones_prefixed
        self.assertEqual(sample, minus_ones)
        sample[False] = 5
        self.assertEqual(sample, minus_ones)

        # uint8 scalar tensors.
        sample[uint8_one] = minus_ones_prefixed * 2
        self.assertEqual(sample, minus_ones * 2)
        sample[uint8_zero] = 5
        self.assertEqual(sample, minus_ones * 2)

        # None and Ellipsis.
        sample[None] = minus_ones_prefixed * 3
        self.assertEqual(sample, minus_ones * 3)
        sample[...] = minus_ones_prefixed * 4
        self.assertEqual(sample, minus_ones * 4)

        if sample.dim() == 0:
            with self.assertRaises(IndexError):
                sample[:] = minus_ones_prefixed * 5
9717
def test_index_scalar_with_bool_mask(self, device="mps"):
    """Indexing a 0-dim tensor with a uint8 scalar mask must agree with a bool scalar mask."""
    uint8_mask = torch.tensor(True, dtype=torch.uint8, device=device)
    bool_mask = torch.tensor(True, dtype=torch.bool, device=device)

    # Checked for an integer scalar first, then for a bool scalar.
    int_scalar = torch.tensor(1, device=device)
    self.assertEqual(int_scalar[uint8_mask], int_scalar[bool_mask])
    self.assertEqual(int_scalar[uint8_mask].dtype, int_scalar[bool_mask].dtype)

    bool_scalar = torch.tensor(True, dtype=torch.bool, device=device)
    self.assertEqual(bool_scalar[uint8_mask], bool_scalar[bool_mask])
    self.assertEqual(bool_scalar[uint8_mask].dtype, bool_scalar[bool_mask].dtype)
9728
def test_setitem_expansion_error(self, device="mps"):
    """Assigning a value whose extra leading dims are not all 1 must raise ``RuntimeError``."""
    tensor_true = torch.tensor(True, device=device)
    dest = torch.randn(2, 3, device=device)

    # check prefix with non-1s doesn't work
    oversized_shape = torch.Size([5, 1]) + dest.size()
    oversized = dest.expand(oversized_shape)

    # NumPy: ValueError
    for index in (True, tensor_true):
        with self.assertRaises(RuntimeError):
            dest[index] = oversized
9739
def test_getitem_scalars(self, device="mps"):
    """Indexing with 0-dim integer tensors must behave like indexing with Python ints."""
    idx0 = torch.tensor(0, dtype=torch.int64, device=device)
    idx1 = torch.tensor(1, dtype=torch.int64, device=device)

    # non-scalar indexed with scalars
    mat = torch.randn(2, 3, device=device)
    self.assertEqual(mat[0], mat[idx0])
    self.assertEqual(mat[0][1], mat[idx0][idx1])
    self.assertEqual(mat[0, 1], mat[idx0, idx1])
    self.assertEqual(mat[0, idx1], mat[idx0, 1])

    # indexing by a scalar should slice (not copy)
    self.assertEqual(mat[0, 1].data_ptr(), mat[idx0, idx1].data_ptr())
    for narrowed_index in (idx1.int(), idx1.short()):
        self.assertEqual(mat[1].data_ptr(), mat[narrowed_index].data_ptr())

    # scalar indexed with scalar
    scalar = torch.randn((), device=device)
    for bad_index in (slice(None), idx0):
        with self.assertRaises(IndexError):
            scalar[bad_index]
    self.assertEqual(scalar, scalar[...])
9763
def test_setitem_scalars(self, device="mps"):
    """Assigning through a 0-dim integer tensor index must match assigning through a Python int."""
    # Note: this index tensor is created without an explicit device.
    idx0 = torch.tensor(0, dtype=torch.int64)

    # non-scalar indexed with scalars
    base = torch.randn(2, 3, device=device)
    via_int = base.clone()
    via_tensor = base.clone()
    row = torch.randn(3, device=device)

    via_int[0] = row
    via_tensor[idx0] = row
    self.assertEqual(via_int, via_tensor)

    base[1, idx0] = 7.7
    self.assertEqual(7.7, base[1, 0])

    # scalar indexed with scalars
    scalar = torch.randn((), device=device)
    for bad_index in (slice(None), idx0):
        with self.assertRaises(IndexError):
            scalar[bad_index] = 8.8
    scalar[...] = 9.9
    self.assertEqual(9.9, scalar)
9787
def test_basic_advanced_combined(self, device="mps"):
    """Mixing a slice with a list index: reads return a copy, writes hit the original."""
    # From the NumPy indexing example
    grid = torch.arange(0, 12, device=device).view(4, 3)
    sliced = grid[1:2, 1:3]
    self.assertEqual(sliced, grid[1:2, [1, 2]])
    self.assertEqual(grid[1:2, 1:3].tolist(), [[4, 5]])

    # Check that it is a copy
    snapshot = grid.clone()
    picked = grid[1:2, [1, 2]]
    picked.zero_()
    self.assertEqual(grid, snapshot)

    # But assignment should modify the original
    snapshot = grid.clone()
    grid[1:2, [1, 2]] = 0
    self.assertNotEqual(grid, snapshot)
9803
def test_int_assignment(self, device="mps"):
    """Assigning to an integer-indexed row accepts both a Python scalar and a tensor."""
    # (value assigned to row 1, expected full content afterwards)
    scalar_case = (5, [[0, 1], [5, 5]])

    grid = torch.arange(0, 4, device=device).view(2, 2)
    grid[1] = scalar_case[0]
    self.assertEqual(grid.tolist(), scalar_case[1])

    grid = torch.arange(0, 4, device=device).view(2, 2)
    new_row = torch.arange(5, 7, device=device)
    grid[1] = new_row
    self.assertEqual(grid.tolist(), [[0, 1], [5, 6]])
9812
def test_byte_tensor_assignment(self, device="mps"):
    """Assigning through a uint8 mask emits exactly one warning and writes only masked rows."""
    grid = torch.arange(0., 16, device=device).view(4, 4)
    row_mask = torch.ByteTensor([True, False, True, False]).to(device)
    new_row = torch.tensor([3., 4., 5., 6.], device=device)

    with warnings.catch_warnings(record=True) as caught:
        grid[row_mask] = new_row
        self.assertEqual(len(caught), 1)

    # Rows 0 and 2 were selected by the mask; rows 1 and 3 keep their arange values.
    expected_rows = (
        new_row,
        torch.arange(4., 8, device=device),
        new_row,
        torch.arange(12., 16, device=device),
    )
    for row_index, expected in enumerate(expected_rows):
        self.assertEqual(grid[row_index], expected)
9826
def test_variable_slicing(self, device="mps"):
    """Slice bounds given as elements of an int tensor must act like Python ints."""
    grid = torch.arange(0, 16, device=device).view(4, 4)
    bounds = torch.IntTensor([0, 1]).to(device)
    start, stop = bounds
    self.assertEqual(grid[start:stop], grid[0:1])
9832
def test_ellipsis_tensor(self, device="mps"):
    """Combine ``...`` with a tensor index on either side of it."""
    grid = torch.arange(0, 9, device=device).view(3, 3)
    picks = torch.tensor([0, 2], device=device)

    # Tensor index on the last dim selects columns 0 and 2.
    columns = grid[..., picks].tolist()
    self.assertEqual(columns, [[0, 2], [3, 5], [6, 8]])

    # Tensor index on the first dim selects rows 0 and 2.
    rows = grid[picks, ...].tolist()
    self.assertEqual(rows, [[0, 1, 2], [6, 7, 8]])
9841
def test_invalid_index(self, device="mps"):
    """String slice bounds must raise ``TypeError`` mentioning 'slice indices'."""
    grid = torch.arange(0, 16, device=device).view(4, 4)
    with self.assertRaisesRegex(TypeError, 'slice indices'):
        grid["0":"1"]
9845
def test_out_of_bound_index(self, device="mps"):
    """Out-of-range integer indices must raise ``IndexError`` naming the index, dim and size."""
    x = torch.arange(0, 100, device=device).view(2, 5, 10)

    # (expected message pattern, access that must raise)
    cases = (
        ('index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5]),
        ('index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5]),
        ('index 15 is out of bounds for dimension 2 with size 10', lambda: x[0, 1, 15]),
        ('index 12 is out of bounds for dimension 2 with size 10', lambda: x[:, :, 12]),
    )
    for pattern, access in cases:
        with self.assertRaisesRegex(IndexError, pattern):
            access()
9854
def test_zero_dim_index(self, device="mps"):
    """Integer-indexing a 0-dim tensor must raise ``IndexError('invalid index ...')``.

    Also checks that the 0-dim tensor compares equal to its ``.item()`` value.
    """
    x = torch.tensor(10, device=device)
    self.assertEqual(x, x.item())

    # The lookup itself is what must raise, so nothing is printed or returned.
    with self.assertRaisesRegex(IndexError, 'invalid index'):
        x[0]
9864
def test_cpu_indices(self, device="mps"):
    """An index tensor created without a device must work for both reads and writes on `device`."""
    cpu_index = torch.tensor([0, 1])
    zeros = torch.zeros(2, device=device)
    data = torch.ones(10, device=device)

    data[cpu_index] = zeros  # index_put_
    expected = torch.ones(10, device=device)
    expected[:2] = 0
    self.assertEqual(data, expected, atol=0, rtol=0)

    gathered = data[cpu_index]  # index
    self.assertEqual(gathered, torch.zeros(2, device=device), atol=0, rtol=0)
9875
class TestRNNMPS(TestCaseMPS):
    """LSTM and RNN-cell tests that compare MPS results against the CPU implementation."""

    def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
                     seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
        """Run one ``nn.LSTM`` configuration on CPU and on `device` and compare the results.

        The forward outputs and final states are always compared.  When
        `backward` is true, gradients w.r.t. the input, both initial states and
        every module parameter are compared as well, for three combinations of
        which outputs contribute to the scalar being differentiated.

        Args:
            num_layers, bidirectional, bias, batch_first, hidden_size, input_size:
                forwarded to the ``nn.LSTM`` constructor.
            dtype: dtype of the random input and state tensors.
            device: device compared against CPU.
            seq_len, batch_size: sizes of the random input.
            backward: also compare gradients when true.
        """
        rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            bidirectional=bidirectional,
            batch_first=batch_first,
            device="cpu"
        )
        bidirectional_mul = 2 if bidirectional else 1

        # Only the input layout depends on batch_first; the state shape is the same in both cases.
        input_shape = (batch_size, seq_len, input_size) if batch_first else (seq_len, batch_size, input_size)
        state_shape = (num_layers * bidirectional_mul, batch_size, hidden_size)

        inp = torch.randn(*input_shape, device="cpu", dtype=dtype, requires_grad=backward)
        hx = torch.randn(*state_shape, device="cpu", dtype=dtype, requires_grad=backward)
        cx = torch.randn(*state_shape, device="cpu", dtype=dtype, requires_grad=backward)

        cpu_output, (cpu_hn, cpu_cn) = rnn(inp, (hx, cx))

        rnn = rnn.to(device)
        inp = inp.to(device)
        hx = hx.to(device)
        cx = cx.to(device)
        output, (hn, cn) = rnn(inp, (hx, cx))

        self.assertEqual(cpu_output, output)
        self.assertEqual(cpu_hn, hn)
        self.assertEqual(cpu_cn, cn)

        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
            """Run forward on `device` and return the output plus parameter/input/state gradients."""
            rnn = rnn.to(device)
            inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)

            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"

            # Scalar to differentiate; built only from the requested outputs.
            f = 0
            if output_grad_presented:
                f = f + 3 * output.sum()
            if states_grad_presented:
                f = f + (hx_out * cx_out).sum()

            param_names, params = zip(*rnn.named_parameters())
            # Materialised as a list so the result can be iterated more than once.
            param_grads = list(zip(param_names, torch.autograd.grad(f, params, retain_graph=True)))

            input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
            return output, param_grads, input_grad, hx_grad, cx_grad

        if backward:
            grad_cases = [
                dict(output_grad_presented=True, states_grad_presented=True),
                dict(output_grad_presented=False, states_grad_presented=True),
                dict(output_grad_presented=True, states_grad_presented=False),
            ]

            for grad_case in grad_cases:
                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
                    get_backward_results(rnn, "cpu", inp, hx, cx, **grad_case)
                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
                    get_backward_results(rnn, device, inp, hx, cx, **grad_case)

                self.assertEqual(cpu_hx_grad, mps_hx_grad)
                self.assertEqual(cpu_cx_grad, mps_cx_grad)
                self.assertEqual(cpu_output, mps_output)
                self.assertEqual(cpu_input_grad, mps_input_grad)
                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")

    # Keyword overrides passed to `_lstm_helper` by the forward and backward tests.
    LSTM_TEST_CASES = [
        dict(),  # default
        dict(batch_first=True),
        dict(bias=False),
        dict(bidirectional=True),
        dict(batch_first=True, bias=False),
        dict(bidirectional=True, bias=False),
        dict(bidirectional=True, batch_first=True),
        dict(bidirectional=True, batch_first=True, bias=False)
    ]

    def test_lstm_forward(self, device="mps", dtype=torch.float32):
        """Compare LSTM forward results for 1, 2 and 5 layers over all `LSTM_TEST_CASES`."""
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)

    def test_lstm_backward(self, device="mps", dtype=torch.float32):
        """Compare LSTM forward and backward results for 1, 2 and 5 layers over all `LSTM_TEST_CASES`."""
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)

    def test_RNN_cell_no_broadcasting(self):
        """RNN/GRU/LSTM cells must raise ``RuntimeError`` on mismatched sizes instead of broadcasting."""
        def test(cell_module, inp, hx, input_size, hidden_size):
            cell = cell_module(input_size, hidden_size, device='mps')
            self.assertRaises(RuntimeError, lambda: cell(inp, hx))

        def test_all(hidden_size, bad_hx, good_hx, input_size, inp):
            test(nn.RNNCell, inp, bad_hx, input_size, hidden_size)
            test(nn.GRUCell, inp, bad_hx, input_size, hidden_size)
            test(nn.LSTMCell, inp, (bad_hx, good_hx), input_size, hidden_size)
            test(nn.LSTMCell, inp, (good_hx, bad_hx), input_size, hidden_size)

        hidden_size = 20
        input_size = 10
        inp = torch.randn(3, input_size, device='mps')
        bad_hx = torch.randn(1, hidden_size, device='mps')
        good_hx = torch.randn(3, hidden_size, device='mps')

        # Test hidden/input batch size broadcasting
        test_all(hidden_size, bad_hx, good_hx, input_size, inp)

        # Every "bad" tensor lives on MPS like the cell itself, so the expected
        # RuntimeError can only come from the size mismatch, not from a device mismatch.

        # Test hx's hidden_size vs module's hidden_size broadcasting
        bad_hx = torch.randn(3, 1, device='mps')
        test_all(hidden_size, bad_hx, good_hx, input_size, inp)

        # Test input's input_size vs module's input_size broadcasting
        bad_input = torch.randn(3, 1, device='mps')
        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)

    def test_LSTM_cell(self):
        """Smoke-test ``nn.LSTMCell`` forward (six chained steps) and backward, with and without bias."""
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        for bias in (True, False):
            inp = torch.randn(3, 10, device='mps')
            hx = torch.randn(3, 20, device='mps')
            cx = torch.randn(3, 20, device='mps')
            lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
            for _ in range(6):
                hx, cx = lstm(inp, (hx, cx))

            (hx + cx).sum().backward()

    def test_LSTM_cell_forward_input_size(self):
        """An input whose feature size (11) differs from the cell's input_size (10) must raise."""
        inp = torch.randn(3, 11, device='mps')
        hx = torch.randn(3, 20, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(inp, (hx, cx)))

    def test_LSTM_cell_forward_hidden_size(self):
        """A state whose size (21) differs from the cell's hidden_size (20) must raise, in either slot."""
        inp = torch.randn(3, 10, device='mps')
        hx = torch.randn(3, 21, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(inp, (hx, cx)))
        self.assertRaises(Exception, lambda: lstm(inp, (cx, hx)))
10031
10032
class TestFallbackWarning(TestCase):
    """Checks around ops that are not implemented for MPS and the CPU-fallback environment switch."""

    # TODO: Remove once test_testing.py is running on MPS devices
    def test_no_warning_on_import(self):
        """``import torch`` in a fresh interpreter with all warnings enabled must print nothing."""
        out = subprocess.check_output(
            [sys.executable, "-W", "all", "-c", "import torch"],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
        self.assertEqual(out, "")

    def _get_not_implemented_op(self):
        """Return ``(fn, args, kwargs, string_version)`` for an op that is not implemented on MPS.

        `args` is a tuple of positional arguments and `kwargs` a dict, so the
        op can be invoked as ``fn(*args, **kwargs)``; `string_version` is source
        text that calls the same op and can be embedded in a generated script.
        """
        # This can be changed once we actually implement `torch.histc`
        # Should return fn, args, kwargs, string_version
        return (torch.histc,
                # A real one-element tuple: unpacking a bare tensor with `*args`
                # would iterate over the tensor's elements instead.
                (torch.tensor([100], device='mps'),), {},
                "torch.histc(torch.tensor([4], device='mps', dtype=torch.float))")

    def test_error_on_not_implemented(self):
        """Calling the unimplemented op directly must raise ``NotImplementedError``."""
        fn, args, kwargs, _ = self._get_not_implemented_op()

        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
            fn(*args, **kwargs)

    def test_warn_on_not_implemented_with_fallback(self):
        """With PYTORCH_ENABLE_MPS_FALLBACK=1 the op must run and emit exactly one warning.

        The check runs in a subprocess because the environment variable must be
        set before torch is imported.  The script exits with 1 if importing
        torch warned, and with 2 if running the op did not produce exactly one
        warning.
        """
        _, _, _, op = self._get_not_implemented_op()
        script = f"""
import os
# MUST happen before pytorch's import
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import warnings

with warnings.catch_warnings(record=True) as w:
    import torch

if len(w) > 0:
    print(w)
    exit(1)

# This should run just fine and raise warning about perf
with warnings.catch_warnings(record=True) as w:
    {op}

if len(w) != 1:
    print(w)
    exit(2)

"""
        try:
            subprocess.check_output(
                [sys.executable, '-W', 'all', '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            # Decode once so every failure message shows readable text rather than a bytes repr.
            output = e.output.decode("utf-8")
            if e.returncode == 1:
                self.fail("There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set." +
                          output)
            elif e.returncode == 2:
                self.fail("There wasn't exactly one warning when running not implemented op with "
                          f"PYTORCH_ENABLE_MPS_FALLBACK set. {output}")
            else:
                self.fail("Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
                          output)
Kulin Sethe011a8e2022-05-13 18:28:53 +000010098
class TestNoRegression(TestCase):
    """Regression checks for MPS tensors: assert_close, float64 errors, legacy constructor, serialization."""

    def test_assert_close(self):
        """``assert_close`` must report a mismatch between a finite MPS tensor and inf."""
        one = torch.ones(1, device="mps")
        zero = torch.zeros(1, device="mps")
        inf = one / zero
        nan = zero / zero

        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            torch.testing.assert_close(one, inf)

        # TODO: The NaN test is failing when all the tests in test_mps are run
        # together but passes when run separately. There seems to be memory
        # corruption which needs to be fixed for this test to be enabled.
        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
        #     torch.testing.assert_close(one, nan)

    def test_double_error(self):
        """Creating or converting to a float64 MPS tensor must raise ``TypeError``."""
        float64_error = "the MPS framework doesn't support float64"

        with self.assertRaisesRegex(TypeError, float64_error):
            a = torch.ones(2, dtype=torch.float64, device="mps")

        a = torch.ones(2, device="mps")
        with self.assertRaisesRegex(TypeError, float64_error):
            a = a.double()

    def test_legacy_constructor(self):
        """The legacy ``Tensor.new`` constructor must work on an MPS tensor."""
        source = torch.ones(2, device="mps")

        created = source.new(1)

    def test_serialization_map_location(self):
        """Round-trip tensors through ``torch.save`` / ``torch.load`` across cpu and mps."""
        # (kwargs for creating the tensor, kwargs for torch.load, expected device type)
        scenarios = (
            # Ensures that cpu Tensor can be loaded on mps
            ({}, {"map_location": "mps"}, "mps"),
            # Ensures that mps Tensors can be loaded on mps
            ({"device": "mps"}, {}, "mps"),
            # Ensures that mps Tensors can be loaded on cpu
            ({"device": "mps"}, {"map_location": "cpu"}, "cpu"),
        )

        for rand_kwargs, load_kwargs, expected_device in scenarios:
            with tempfile.NamedTemporaryFile() as f:
                saved = torch.rand(2, **rand_kwargs)
                torch.save(saved, f)

                f.seek(0)
                loaded = torch.load(f, **load_kwargs)

                self.assertEqual(saved, loaded)
                self.assertEqual(loaded.device.type, expected_device)
10162
10163
# dtypes used to parametrize the MPS tests: everything from get_all_dtypes()
# minus the four dtypes removed below.
MPS_DTYPES = get_all_dtypes()
for t in (torch.double, torch.cdouble, torch.cfloat, torch.bfloat16):
    # list.remove raises ValueError if the dtype is unexpectedly absent
    MPS_DTYPES.remove(t)

# dtypes for which the gradient consistency test is run
MPS_GRAD_DTYPES = [torch.float32, torch.float16]
10169
soulitzerbfdfeec2022-08-31 17:53:32 -040010170
class TestConsistency(TestCaseMPS):
    """Compare OpInfo ops run on MPS against the same ops run on CPU.

    Both tests receive CPU samples (they assert ``device == "cpu"``), mirror
    every tensor of a sample onto "mps", run the op on both and compare the
    results with per-op tolerances.
    """
    # TODO: This is only used while some ops are being added.
    # This list should contain all ops and dtypes eventually
    # This can be generated automatically in the `new_mps_allowlist.txt` file
    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
    # You most likely do NOT want to modify this manually

    # Ops compared with atol=1e-2 / rtol=1e-2 when dtype is float16
    FP16_LOW_PRECISION_LIST = {
        'add', 'sub', 'div',
        '__rdiv__', '__rmul__',
        'nn.functional.huber_loss',
        'true_divide', 'kron',
        'gradient', 'var', 'std', 'ldexp',
        'linalg.vector_norm',
        'addr', 'var_mean',
        'var_mean_unbiased',

        # for macOS 12
        'masked.normalize', 'masked.sum', 'masked.var',
        'outer',
        'sum_to_size', 'sum',
        'mul',
        'nansum', 'nanmean',
        'norm',
    }

    # Ops compared with atol=1e-4 / rtol=3e-5 when dtype is float32
    FP32_LOW_PRECISION_LIST = {
        # conv2d and conv_transpose2d results have a very small
        # difference compared to CPU/CUDA, so we use lower precision on FP32
        'nn.functional.conv2d',
        'nn.functional.conv_transpose2d',
        'matmul', '__rmatmul__',
        'linalg.multi_dot',
        'addbmm',
    }

    # Used for accept mode only
    NEW_ALLOW_LIST = defaultdict(list)
    NEW_ALLOW_LIST_GRAD = defaultdict(list)

    @ops(mps_ops_modifier(op_db), allowed_dtypes=MPS_DTYPES)
    def test_output_match(self, device, dtype, op):
        """Forward-only check: the MPS output of `op` must match the CPU output for every sample."""
        self.assertEqual(device, "cpu")

        cpu_samples = op.sample_inputs(device, dtype, requires_grad=(dtype.is_floating_point or dtype.is_complex))

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            # Per-op tolerances; None falls back to assertEqual's defaults.
            if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype == torch.float32:
                atol = 1e-4
                rtol = 3e-5
            elif op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
                atol = 1e-2
                rtol = 1e-2
            elif op.name == "masked.mean":
                atol = 7e-4
                rtol = 2e-3
            elif op.name == "native_layer_norm":
                atol = 1e-4
                rtol = 1.3e-5
            elif op.name in ["pow", "__rpow__"]:
                atol = 1e-6
                rtol = 4e-6
            else:
                atol = None
                rtol = None

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

    @ops(mps_ops_grad_modifier(copy.deepcopy(op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
    def test_output_grad_match(self, device, dtype, op):
        """Forward and backward check: outputs and input gradients on MPS must match CPU.

        For every sample the forward outputs are compared first; any failure
        there propagates immediately.  Then a random ``grad_outputs`` vector is
        pushed through ``torch.autograd.grad`` on both devices and the resulting
        input gradients are compared with the same tolerances.
        """
        self.assertEqual(device, "cpu")

        def req_grad(t):
            return isinstance(t, torch.Tensor) and t.requires_grad

        cpu_samples = op.sample_inputs(device, dtype, requires_grad=(dtype.is_floating_point or dtype.is_complex))

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            # Per-op tolerances; None falls back to assertEqual's defaults.
            # nn.functional.conv2d / linalg.multi_dot (float32) and norm (float16) need no
            # dedicated branch: they are members of the two low-precision sets checked first.
            if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype == torch.float32:
                atol = 1e-4
                rtol = 3e-5
            elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
                atol = 1e-2
                rtol = 1e-2
            elif op.name == "masked.mean":
                atol = 7e-4
                rtol = 2e-3
            elif op.name == "native_layer_norm":
                atol = 1e-4
                rtol = 1.3e-5
            elif op.name == "unique" and cpu_kwargs.get("sorted") is False:
                # Samples requesting unsorted output are skipped entirely.
                continue
            else:
                atol = None
                rtol = None

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

            #
            # Backward check
            #
            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
            diff_cpu_arg = tuple(t for t in pytree.tree_flatten((cpu_args, cpu_kwargs))[0] if req_grad(t))
            diff_mps_arg = tuple(t for t in pytree.tree_flatten((mps_args, mps_kwargs))[0] if req_grad(t))
            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))

            # Nothing differentiable came out of this sample.
            if not diff_cpu_out:
                continue

            # rand_like does not work with certain dtypes, so cast to double and cast back
            cpu_grad_outputs = tuple(torch.rand_like(t.to(dtype=torch.double)).to(dtype=dtype) for t in diff_cpu_out)
            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)

            # Compare computed gradients with cpu given random grad_output vector
            # Sometimes when the derivative is 0, we just don't bother creating the graph
            # allow_unused is needed in those cases.
            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
Alex620dbc42022-10-21 19:03:00 +000010362
# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
    """Device-generic OpInfo tests that are instantiated for the MPS device."""
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # The dynamic-dtype audit only runs in CI.
        if not IS_CI:
            return

        # Assure no opinfo entry has dynamic_dtypes
        dynamic_ops = [op for op in op_db if opinfo.utils.is_dynamic_dtype_set(op)]

        err_msg = (
            "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
            "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
        )
        # One formatted line per offending op is appended to the message.
        err_msg += "".join("\n" + opinfo.utils.str_format_dynamic_dtype(op) for op in dynamic_ops)

        assert len(dynamic_ops) == 0, err_msg

    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
    # MPS still requires some fairly heavy special casing in the test framework.
    # When MPS becomes more consistent, this can probably be merged with that test using
    # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @suppress_warnings
    # MPS only supports float32
    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
    def test_numpy_ref_mps(self, device, dtype, op):
        """Compare `op` against its reference implementation ``op.ref`` for every input."""
        # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
        # does not support float64 Tensors.
        # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
        # get patched up and this workaround removed.
        if op.name in ['clamp', 'where']:
            inputs = op.sample_inputs(device, dtype)
        else:
            inputs = op.reference_inputs(device, dtype)

        for sample_input in inputs:
            self.compare_with_reference(op, op.ref, sample_input)

    @dtypes(*get_all_dtypes())
    def test_tensor_creation(self, device, dtype):
        """``torch.ones`` on `device` matches CPU for dtypes in MPS_DTYPES and raises ``TypeError`` otherwise."""
        shape = (2, 2)

        if dtype in MPS_DTYPES:
            on_device = torch.ones(shape, dtype=dtype, device=device)
            on_cpu = torch.ones(shape, dtype=dtype, device="cpu")
            self.assertEqual(on_device.cpu(), on_cpu)
        else:
            with self.assertRaises(TypeError):
                torch.ones(shape, dtype=dtype, device=device)
10415
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
# to achieve this.
# TestConsistency is instantiated for "cpu" only: its tests assert `device == "cpu"` and move
# each sample to "mps" themselves.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
# TestCommon is instantiated for "mps" only.
# NOTE(review): allow_mps=True presumably opts in to the mps device in the generic framework — confirm.
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")

# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()