blob: ad9316cac762030c08eb84e17446be660b6ee483 [file] [log] [blame]
Kulin Sethe011a8e2022-05-13 18:28:53 +00001# Owner(s): ["module: mps"]
2
Denis Vieriude7ec2d2023-05-25 23:32:29 +00003import io
Denis Vieriu71ec2612023-02-15 06:09:56 +00004import platform
Kulin Sethe011a8e2022-05-13 18:28:53 +00005import sys
6import math
7import random
8import unittest
9import warnings
Kulin Seth3d833212022-05-20 03:18:09 +000010import subprocess
Alban Desmaison0a651a22022-06-14 17:54:30 +000011import tempfile
Kulin Seth3d833212022-05-20 03:18:09 +000012import os
Kulin Seth31d4b6f2022-08-17 00:26:41 +000013import copy
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +000014import gc
Nikita Shulga916183a2023-09-13 19:28:47 +000015import threading
Kulin Sethe011a8e2022-05-13 18:28:53 +000016import torch
17import torch.nn as nn
18import torch.nn.functional as F
Kulin Seth978304f2022-05-14 13:33:16 +000019import itertools
Kulin Seth76cff182022-07-04 06:41:39 +000020from collections import defaultdict
Xuehai Panb005ec62023-02-14 09:14:10 +000021from torch import inf
ekamiti9e473fd2024-07-31 10:32:37 +000022from torch.nn import Buffer, Parameter
Alex620dbc42022-10-21 19:03:00 +000023from torch.testing._internal import opinfo
Kulin Seth76cff182022-07-04 06:41:39 +000024from torch.testing._internal.common_utils import \
Nikita Shulga30610252024-05-03 15:20:39 +000025 (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
Denis Vieriu861bdf92024-08-16 21:07:48 +000026 NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
Kulin Sethb744e1c2022-07-01 15:10:56 +000027from torch.testing import make_tensor
Nikita Shulga1a6cf6e2022-09-14 23:40:20 +000028from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
Kulin Sethe011a8e2022-05-13 18:28:53 +000029import torch.backends.mps
Kulin Seth83239352022-06-10 13:16:21 +000030from torch.distributions import Uniform, Exponential
Kulin Sethb744e1c2022-07-01 15:10:56 +000031from functools import partial
PyTorch MergeBotb1943e02022-06-30 16:37:11 +000032
Alex620dbc42022-10-21 19:03:00 +000033from torch.testing._internal.common_methods_invocations import (
34 op_db,
Nikita Shulgafd8367a2023-02-27 15:01:01 +000035 DecorateInfo,
Alex620dbc42022-10-21 19:03:00 +000036 UnaryUfuncInfo,
37 ReductionOpInfo,
38 SpectralFuncInfo,
39 BinaryUfuncInfo,
40)
Li-Huai (Allan) Lin71aea7f2023-04-12 17:19:08 +000041from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
Kulin Sethe011a8e2022-05-13 18:28:53 +000042from torch.testing._internal.common_nn import NNTestCase
Nikita Shulga4ff91132024-05-24 16:08:04 +000043from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
Kulin Sethe011a8e2022-05-13 18:28:53 +000044import numpy as np
45import torch
soulitzerbfdfeec2022-08-31 17:53:32 -040046import torch.utils._pytree as pytree
Kulin Sethfc596642023-01-04 22:15:13 +000047from itertools import product
Aaron Gokaslan6de28e92023-12-20 19:35:04 +000048import operator
Kulin Sethe011a8e2022-05-13 18:28:53 +000049
# Independent deep copies of the op database, so that decorators attached for
# one test suite are not visible to the other.
test_consistency_op_db = copy.deepcopy(op_db)
test_error_inputs_op_db = copy.deepcopy(op_db)

# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`.
# Keeps only ops that carry a reference implementation (`op.ref`) and are not
# one of the unary / reduction / spectral / binary ufunc OpInfo kinds.
_ref_test_ops = tuple(
    op
    for op in op_db
    if not isinstance(
        op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
    )
    and op.ref is not None
)
63
def xfailIf(condition):
    """Return a decorator that marks a test as an expected failure when `condition` is truthy.

    Args:
        condition: Evaluated once, when the decorator is applied. If truthy the
            decorated callable is passed through `unittest.expectedFailure`;
            otherwise it is returned untouched.

    Returns:
        A decorator taking the test callable and returning it, marked or not.
    """
    def wrapper(func):
        if condition:
            return unittest.expectedFailure(func)
        return func
    return wrapper
71
def xfailIfMacOS14_4Plus(func):
    """Mark `func` as an expected failure when `product_version` is greater than 14.3.

    `product_version` is a module-level value defined elsewhere in this file
    (hence the F821 suppression); it is read when the decorator is applied.
    Otherwise `func` is returned unchanged.
    """
    return unittest.expectedFailure(func) if product_version > 14.3 else func  # noqa: F821
74
def mps_ops_grad_modifier(ops):
    """Attach expected-failure / skip decorators to ops for the MPS gradient tests.

    Each op is looked up by `op.name + op.variant_test_name` in a series of
    tables mapping that key to a list of dtypes (or `None`, which is passed
    through to `DecorateInfo` as-is). For every table that contains the key and
    whose OS condition (if any) holds, a `DecorateInfo` is appended to
    `op.decorators`.

    Args:
        ops: Iterable of op-info objects exposing `name`, `variant_test_name`
            and `decorators`. Ops are modified in place.

    Yields:
        Each op from `ops`, in order, after its decorators have been updated.
    """
    XFAILLIST_GRAD = {

        # precision issues
        'special.polygammaspecial_polygamma_n_0': [torch.float16],
        'polygammapolygamma_n_0': [torch.float16],
        'nn.functional.binary_cross_entropy': [torch.float16],

        # Unimplemented ops
        '__getitem__': [torch.float16],
        '_segment_reduce': [torch.float16, torch.float32],
        '_chunk_cat': [torch.float16, torch.float32],
        'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
        'unfold': [torch.float16, torch.float32],
        'sparse.mmreduce': [torch.float32],  # csr not supported
        'unique_consecutive': [torch.float16, torch.float32],
        'special_modified_bessel_i0': [torch.float16, torch.float32],
        'scalar_tensor': [torch.float16, torch.float32],
        'cdist': [torch.float32],
        'masked.scatter': [torch.float16, torch.float32],
        'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
        'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
        'aminmax': [torch.float32, torch.float16],

        # Correctness issues
        'atanh': [torch.float32],

        # Random output
        'exponential': [torch.float16, torch.float32],

        # CPU errors
        # derivative for aten::nextafter is not implemented on CPU
        'nextafter': None,
        # derivative for aten::floor_divide is not implemented on CPU
        'floor_divide': [torch.float16, torch.float32],
        # derivative for aten::narrow_copy is not implemented on CPU
        'narrow_copy': [torch.float16, torch.float32],
        # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
        'histogramdd': [torch.float16, torch.float32],
        # derivative for aten::histogram is not implemented
        'histogram': [torch.float16, torch.float32],
        # 'bool' object is not iterable
        'allclose': [torch.float16, torch.float32],
        'equal': [torch.float16, torch.float32],
        # 'float' object is not iterable
        'item': [torch.float16, torch.float32],
        # "mse_backward_cpu_out" not implemented for 'Half'
        'nn.functional.mse_loss': [torch.float16],
        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
        'nn.functional.smooth_l1_loss': [torch.float16],
        # cpu error: grad requires non-empty inputs
        'randn': [torch.float16, torch.float32],
        'signal.windows.bartlett': [torch.float32],
        'signal.windows.blackman': [torch.float32],
        'signal.windows.cosine': [torch.float32],
        'signal.windows.exponential': [torch.float32],
        'signal.windows.gaussian': [torch.float32],
        'signal.windows.general_cosine': [torch.float32],
        'signal.windows.general_hamming': [torch.float32],
        'signal.windows.hamming': [torch.float32],
        'signal.windows.hann': [torch.float32],
        'signal.windows.kaiser': [torch.float32],
        'signal.windows.nuttall': [torch.float32],
        'eye': [torch.float16, torch.float32],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operation in backward pass
        '_unsafe_masked_index': [torch.float16],
        '_unsafe_masked_index_put_accumulate': [torch.float16],
    }

    MACOS_12_3_XFAILLIST_GRAD = {
        # Unsupported Border padding mode, forward pass success as fallback to cpu
        'grid_sampler_2d': [torch.float32],
        # Unimplemented
        'logaddexp2': [torch.float32],

    }

    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'masked.softmin': [torch.float32, torch.float16],
        'masked.softmax': [torch.float32, torch.float16],
        'masked.log_softmax': [torch.float32, torch.float16],

        # Unsupported Border padding mode, forward pass success as fallback to cpu
        'grid_sampler_2d': [torch.float32],

        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    SKIPLIST_GRAD = {
        'nn.functional.pairwise_distance': [torch.float16],
        # failed assertion `destination datatype must be fp32'
        'nn.functional.conv1d': [torch.float16],
        'nn.functional.conv2d': [torch.float16],
        'nn.functional.conv3d': [torch.float16],
        'nn.functional.conv_transpose1d': [torch.float16],
        'nn.functional.conv_transpose2d': [torch.float16],
        'nn.functional.conv_transpose3d': [torch.float16],
    }

    MACOS_13_3_XFAILLIST_GRAD = {
        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
        'linalg.matrix_rank': None,

        # Exception: Caused by sample input at index 3 on MPS
        'nn.functional.conv3d': [torch.float32],
    }

    def addDecorator(op, d) -> None:
        # Copy into a fresh list so a shared/immutable decorators sequence is never mutated.
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    # (table, decorator, condition) triples, applied in this exact order.
    # A condition of None means "always". Conditions are callables so that the
    # OS queries and `product_version` (module-level, defined elsewhere in this
    # file) are only evaluated for ops actually present in the matching table.
    rules = (
        (XFAILLIST_GRAD, unittest.expectedFailure, None),
        (SKIPLIST_GRAD, unittest.skip, None),
        (ON_MPS_XFAILLIST, unittest.expectedFailure, None),
        (
            MACOS_12_3_XFAILLIST_GRAD,
            unittest.expectedFailure,
            lambda: not torch.backends.mps.is_macos13_or_newer(),
        ),
        (
            MACOS_BEFORE_13_3_XFAILLIST_GRAD,
            unittest.expectedFailure,
            lambda: torch.backends.mps.is_macos13_or_newer() and product_version < 13.3,  # noqa: F821
        ),
        (
            MACOS_13_3_XFAILLIST_GRAD,
            unittest.expectedFailure,
            lambda: product_version >= 13.3,  # noqa: F821
        ),
    )

    for op in ops:
        key = op.name + op.variant_test_name
        for table, decorator, applies in rules:
            if key in table and (applies is None or applies()):
                addDecorator(op, DecorateInfo(decorator, dtypes=table[key]))
        yield op
240
241def mps_ops_modifier(ops):
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700242 # Supported complex OPS
Li-Huai (Allan) Lin293d3b82023-09-11 11:56:27 -0700243 SUPPORTED_COMPLEX_OPS = {
Nikita Shulgac7bb8422023-08-31 20:41:51 -0700244 '__radd__',
Nikita Shulga9b12a282023-09-01 20:52:15 -0600245 '__rmul__',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800246 '__getitem__',
Nikita Shulga0fd1fc12024-05-07 22:15:20 +0000247 'abs',
Nikita Shulgac7bb8422023-08-31 20:41:51 -0700248 'add',
Tom Ritchford23860452024-06-11 12:54:06 +0000249 'alias_copy',
Denis Vieriua40d6df2024-05-03 03:50:55 +0000250 'argwhere',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700251 'atleast_1d',
252 'atleast_2d',
253 'atleast_3d',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800254 'as_strided',
Tom Ritchfordedb45dc2024-06-12 15:12:58 +0000255 'as_strided_copy',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800256 'as_strided_scatter',
257 'broadcast_tensors',
258 'broadcast_to',
Nikita Shulga8d8fb972024-02-12 10:11:25 -0800259 'chalf',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800260 'cfloat',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800261 'chunk',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700262 'clone',
Nikita Shulga15ef52a2024-02-12 17:35:11 -0800263 'conj',
264 'conj_physical',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700265 'contiguous',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800266 'diag',
267 'diag_embed',
268 'diagflat',
269 'diagonal',
270 'diagonal_copy',
271 'diagonal_scatter',
272 'dsplit',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700273 'empty',
274 'empty_permuted',
275 'empty_strided',
276 'eye',
Nikita Shulga06787422024-06-11 15:37:03 -0700277 'exp',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800278 'expand',
279 'expand_as',
Tom Ritchford962f2482024-07-29 08:13:33 +0000280 'expand_copy',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700281 'flatten',
Li-Huai (Allan) Lin4b804da2023-10-23 20:48:11 -0700282 'fill',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700283 'full',
Nikita Shulga15ef52a2024-02-12 17:35:11 -0800284 'H',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800285 'hsplit',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700286 'imag',
Nikita Shulga4c70ab22024-03-25 16:57:35 +0000287 'index_select',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700288 'isfinite',
289 'isinf',
290 'isreal',
291 'item',
Nikita Shulga9b12a282023-09-01 20:52:15 -0600292 'kron',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800293 'linalg.diagonal',
294 'linalg.svd',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700295 'linspace',
296 'logspace',
Li-Huai (Allan) Lin293d3b82023-09-11 11:56:27 -0700297 'linspacetensor_overload',
298 'logspacetensor_overload',
Nikita Shulga15ef52a2024-02-12 17:35:11 -0800299 'mH',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800300 'mT',
301 'masked_scatter',
302 'masked_select',
303 'meshgridlist_of_tensors',
304 'meshgridvariadic_tensors',
305 'movedim',
Nikita Shulga9b12a282023-09-01 20:52:15 -0600306 'mul',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800307 'narrow',
308 'narrow_copy',
Nikita Shulga1d610112024-02-08 18:10:59 +0000309 'nn.functional.conv1d',
Nikita Shulga045309a2024-05-28 17:56:13 +0000310 'nn.functional.conv2d',
Nikita Shulga1d610112024-02-08 18:10:59 +0000311 'nn.functional.conv_transpose1d',
Nikita Shulga045309a2024-05-28 17:56:13 +0000312 'nn.functional.conv_transpose2d',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700313 'nn.functional.feature_alpha_dropoutwithout_train',
Nikita Shulga0fd1fc12024-05-07 22:15:20 +0000314 'nn.functional.padcircular',
Nikita Shulga06787422024-06-11 15:37:03 -0700315 'nn.functional.tanhshrink',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700316 'nn.functional.unfold',
Denis Vieriua40d6df2024-05-03 03:50:55 +0000317 'nonzero',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700318 'ones',
Nikita Shulga9b12a282023-09-01 20:52:15 -0600319 'outer',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800320 'permute',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700321 'positive',
322 'randn',
323 'ravel',
324 'real',
Nikita Shulga4c70ab22024-03-25 16:57:35 +0000325 'repeat_interleave',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700326 'reshape_as',
327 'reshape',
328 'resolve_conj',
329 'resolve_neg',
330 'scalar_tensor',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800331 'select',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700332 'sgn',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800333 'slice',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700334 'split',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800335 'split_with_sizes',
Yifu Wanga1280f02024-01-31 15:10:47 -0800336 'split_with_sizes_copy',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800337 'splitlist_args',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700338 'squeeze',
339 'squeezemultiple',
Nikita Shulgac7bb8422023-08-31 20:41:51 -0700340 'sub',
Nikita Shulga15ef52a2024-02-12 17:35:11 -0800341 'svd',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700342 't',
Tom Ritchford16247982024-07-21 21:52:27 +0000343 't_copy',
Nikita Shulga06787422024-06-11 15:37:03 -0700344 'tanh',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800345 'tensor_split',
346 'transpose',
347 'T',
348 'unbind',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700349 'unflatten',
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800350 'unfold',
351 'unfold_copy',
352 'unsafe_chunk',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700353 'unsafe_split',
354 'unsqueeze',
Tom Ritchfordbdf5a6d2024-07-29 17:32:06 +0000355 'unsqueeze_copy',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700356 'view_as',
357 'view_as_real',
358 'view',
Tom Ritchford500cbb52024-07-18 13:15:07 +0000359 'view_copy',
Nikita Shulga53a4ca42023-08-31 20:41:39 -0700360 'vsplit',
361 'zero_',
362 'zeros',
Li-Huai (Allan) Lin293d3b82023-09-11 11:56:27 -0700363 }
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000364
365 AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
366 '__rdiv__',
Nikita Shulga045309a2024-05-28 17:56:13 +0000367 '__rmatmul__',
Boyuan Feng35d3adb2024-03-08 21:48:08 +0000368 '_chunk_cat',
Isuru Fernandoe6bfa292024-06-24 22:15:14 +0000369 '_unsafe_masked_index',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000370 'acos',
371 'acosh',
Nikita Shulgaff0f79d2024-01-06 01:10:11 +0000372 'all',
Nikita Shulga1d610112024-02-08 18:10:59 +0000373 'allclose',
Nikita Shulgaff0f79d2024-01-06 01:10:11 +0000374 'any',
Nikita Shulga1d610112024-02-08 18:10:59 +0000375 'addcdiv',
376 'addcmul',
Nikita Shulga045309a2024-05-28 17:56:13 +0000377 'addmmdecomposed',
378 'addmv',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000379 'asin',
380 'atan',
381 'atanh',
Nikita Shulga4ee8aac2024-02-11 16:25:29 +0000382 'bfloat16',
Nikita Shulga045309a2024-05-28 17:56:13 +0000383 'bmm',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000384 'bool',
385 'cartesian_prod',
386 'cat',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000387 'char',
388 'column_stack',
389 'combinations',
Nikita Shulga045309a2024-05-28 17:56:13 +0000390 'corrcoef',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000391 'constant_pad_nd',
392 'cos',
393 'cosh',
394 'count_nonzero',
395 'diff',
Nikita Shulga1d610112024-02-08 18:10:59 +0000396 'div',
397 'divno_rounding_mode',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000398 'dot',
399 'dstack',
Nikita Shulga045309a2024-05-28 17:56:13 +0000400 'einsum',
Nikita Shulga1d610112024-02-08 18:10:59 +0000401 'eq',
402 'equal',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000403 'exp2',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000404 'expm1',
Nikita Shulga53bfae22024-02-20 08:53:12 -0800405 'fft.fft',
406 'fft.fft2',
407 'fft.fftn',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000408 'fft.fftshift',
Nikita Shulga53bfae22024-02-20 08:53:12 -0800409 'fft.ifft',
410 'fft.ifft2',
411 'fft.ifftn',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000412 'fft.ifftshift',
jhavukainen6a539e82024-05-22 21:48:49 +0000413 'fft.irfftn',
414 'fft.irfft2',
415 'fft.irfft',
416 'fft.hfftn',
417 'fft.hfft2',
418 'fft.hfft',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000419 'flip',
420 'fliplr',
421 'flipud',
422 'float',
Nikita Shulga1d610112024-02-08 18:10:59 +0000423 'gradient',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000424 'half',
425 'hstack',
Nikita Shulga045309a2024-05-28 17:56:13 +0000426 'inner',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000427 'int',
Nikita Shulga0fd1fc12024-05-07 22:15:20 +0000428 'isclose',
Nikita Shulga1d610112024-02-08 18:10:59 +0000429 'isnan',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000430 'ldexp',
Nikita Shulga045309a2024-05-28 17:56:13 +0000431 'linalg.multi_dot',
432 'linalg.pinv',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000433 'log10',
434 'log1p',
435 'log2',
436 'log',
Nikita Shulga1d610112024-02-08 18:10:59 +0000437 'logical_and',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000438 'logical_not',
Nikita Shulga1d610112024-02-08 18:10:59 +0000439 'logical_or',
440 'logical_xor',
Tobias Ringwald758d7872024-09-03 17:28:36 +0000441 'logsumexp',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000442 'long',
Nikita Shulga1d610112024-02-08 18:10:59 +0000443 'masked_fill',
444 'masked.mean',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000445 'masked.prod',
Nikita Shulga15ef52a2024-02-12 17:35:11 -0800446 'masked.std',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000447 'masked.sum',
Nikita Shulga15ef52a2024-02-12 17:35:11 -0800448 'masked.var',
Tobias Ringwald758d7872024-09-03 17:28:36 +0000449 'masked.logsumexp',
Nikita Shulga045309a2024-05-28 17:56:13 +0000450 'matmul',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000451 'mean',
Nikita Shulga045309a2024-05-28 17:56:13 +0000452 'mm',
453 'mv',
Nikita Shulga1d610112024-02-08 18:10:59 +0000454 'ne',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000455 'neg',
456 'nn.functional.padconstant',
457 'nn.functional.padreflect',
458 'nn.functional.padreplicate',
459 'nn.functional.pixel_shuffle',
460 'nn.functional.pixel_unshuffle',
Nikita Shulga0fd1fc12024-05-07 22:15:20 +0000461 'nn.functional.rms_norm',
462 'nn.functional.softsign',
Nikita Shulga045309a2024-05-28 17:56:13 +0000463 'pinverse',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000464 'prod',
465 'reciprocal',
466 'roll',
467 'rot90',
468 'rsqrt',
469 'short',
470 'sigmoid',
471 'sin',
472 'sinh',
473 'sqrt',
Nikita Shulga15ef52a2024-02-12 17:35:11 -0800474 'square',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000475 'stack',
Nikita Shulga53bfae22024-02-20 08:53:12 -0800476 'stft',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000477 'sum',
478 'sum_to_size',
479 'tan',
Nikita Shulga045309a2024-05-28 17:56:13 +0000480 'tensordot',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000481 'trace',
Nikita Shulga1d610112024-02-08 18:10:59 +0000482 'trapz',
483 'trapezoid',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000484 'tril',
485 'triu',
Nikita Shulga1d610112024-02-08 18:10:59 +0000486 'true_divide',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000487 'vstack',
488 'where',
Sun, Jiayi7be77652024-08-07 00:34:30 -0700489 'byte',
Nikita Shulgab0393eb2024-01-05 00:25:47 +0000490 }
Kulin Seth2bb022e2023-03-08 08:41:21 +0000491 # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758
492 MACOS_12_3_XFAILLIST = {
493 # Top 60
494 # expected failures
495 # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
496 # fixed in macOS 13.3. Currently error is not raised.
497 'pow': [torch.int16, torch.int64, torch.uint8, torch.int8],
498 # expected failures
499 '__rpow__': [torch.uint8, torch.int8],
500
501 # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
502 'cdist': [torch.float32],
503 'tan': [torch.uint8, torch.float32],
504
505 # Data type support starts from macOS 13
506 'nn.functional.avg_pool1d': [torch.int64],
507 'nn.functional.avg_pool2d': [torch.int64],
508 'nn.functional.local_response_norm': [torch.int64],
509 '__radd__': [torch.uint8],
510 '__rdiv__': [torch.uint8],
511 '__rmul__': [torch.uint8],
512 'abs': [torch.uint8],
513 'acos': [torch.uint8],
514 'acosh': [torch.uint8],
515 'add': [torch.uint8],
516 'asin': [torch.uint8],
517 'asinh': [torch.uint8],
518 'atan': [torch.uint8],
519 'atanh': [torch.uint8],
520 'ceil': [torch.uint8],
521 'corrcoef': [torch.uint8],
522 'cos': [torch.uint8],
523 'cosh': [torch.uint8],
524 'cov': [torch.uint8],
525 'cumulative_trapezoid': [torch.uint8],
526 'deg2rad': [torch.uint8],
527 'diff': [torch.uint8],
528 'eq': [torch.uint8],
529 'equal': [torch.uint8],
530 'erf': [torch.uint8],
531 'exp2': [torch.uint8],
532 'exp': [torch.uint8],
533 'expm1': [torch.uint8],
534 'floor': [torch.uint8],
535 'fmax': [torch.uint8],
536 'fmin': [torch.uint8],
537 'fmod': [torch.uint8],
538 'ge': [torch.uint8],
539 'gt': [torch.uint8],
540 'isclose': [torch.uint8],
541 'isnan': [torch.uint8],
542 'kron': [torch.uint8],
543 'le': [torch.uint8],
544 'log10': [torch.uint8],
545 'log1p': [torch.uint8],
546 'log2': [torch.uint8],
547 'log': [torch.uint8],
548 'logical_and': [torch.uint8],
549 'logical_or': [torch.uint8],
550 'logical_xor': [torch.uint8],
551 'logit': [torch.uint8],
552 'lt': [torch.uint8],
553 'masked.mean': [torch.uint8],
554 'masked.std': [torch.uint8],
555 'masked.var': [torch.uint8],
556 'maximum': [torch.uint8],
557 'minimum': [torch.uint8],
558 'mul': [torch.uint8],
559 'ne': [torch.uint8],
560 'neg': [torch.uint8],
561 'nn.functional.cosine_embedding_loss': [torch.uint8],
562 'nn.functional.margin_ranking_loss': [torch.uint8],
563 'nn.functional.poisson_nll_loss': [torch.uint8],
564 'nn.functional.softsign': [torch.uint8],
565 'nn.functional.tanhshrink': [torch.uint8],
566 'nn.functional.triplet_margin_loss': [torch.uint8],
567 'nn.functional.triplet_margin_with_distance_loss': [torch.uint8],
Denis Vieriu89baa1a2023-04-26 01:34:24 +0000568 'nn.functional.pairwise_distance': [torch.uint8],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000569 'outer': [torch.uint8],
570 'rad2deg': [torch.uint8],
571 'reciprocal': [torch.uint8],
572 'remainder': [torch.uint8],
573 'round': [torch.uint8],
574 'rsqrt': [torch.uint8],
575 'sigmoid': [torch.uint8],
576 'sign': [torch.uint8],
577 'signbit': [torch.uint8],
578 'sin': [torch.uint8],
579 'sinh': [torch.uint8],
580 'special.ndtr': [torch.uint8],
581 'sqrt': [torch.uint8],
582 'sub': [torch.uint8],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000583 'trapezoid': [torch.uint8],
584 'trapz': [torch.uint8],
585 'true_divide': [torch.uint8],
586 'trunc': [torch.uint8],
587 'xlogy': [torch.uint8],
588 'minbinary': [torch.uint8],
589 'maxbinary': [torch.uint8],
590 'divtrunc_rounding': [torch.uint8],
591 'divfloor_rounding': [torch.uint8],
592 'divno_rounding_mode': [torch.uint8],
593 'floor_divide': [torch.uint8],
594 'ldexp': [torch.uint8],
595 # square internally calls into power, and will type cast to int64, which supports starting from macOS 13
596 'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
597
598 # cpu not giving nan for x/0.0
Li-Huai (Allan) Lin9a7e2512024-06-18 19:59:50 +0000599 'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
vfdev-5d2a2a672023-10-06 10:01:15 +0000600
601 # inconsistency errors between cpu and mps, max seen atol is 2
602 'nn.functional.interpolatebilinear': [torch.uint8],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000603 }
604
605 MACOS_BEFORE_13_3_XFAILLIST = {
606 # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
607 'tan': [torch.float32],
608 'cdist': [torch.float32],
609
610 # CPU Error: cpu not giving nan for x/0.0
CaoE455241b2023-11-06 06:01:29 +0000611 'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000612
613 # test blow pass on macOS 12 as it falls back to cpu
614 # Argsort case using duplicate indices (undefined behaviour):
615 # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], devuce='cpu')
616 # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
617 # Elements from index 30 and 5133 are both equal.
618 # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
619 'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
620 # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
621 # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
622 'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
623 # Unsupported dtypes
624 'cumsum': [torch.int64],
Peter Stefek97e50552023-08-01 21:51:16 +0000625 'cumprod': [torch.int64],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000626 'cumulative_trapezoid': [torch.int64],
627 'masked.cumsum': [torch.int64],
Peter Stefek97e50552023-08-01 21:51:16 +0000628 'masked.cumprod': [torch.int64],
629 'linalg.vander': [torch.int64],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000630 }
631
Nikita Shulga87084642023-05-11 10:35:05 +0000632 MACOS_AFTER_13_1_XFAILLIST = {
633 # before macOS 13.2 it falls back to cpu and pass the forward pass
Kulin Seth2bb022e2023-03-08 08:41:21 +0000634 'grid_sampler_2d': [torch.float32], # Unsupported Border padding mode
vfdev-5d2a2a672023-10-06 10:01:15 +0000635 # inconsistency errors between cpu and mps, max seen atol is 2
636 'nn.functional.interpolatebilinear': [torch.uint8],
Nikita Shulga87084642023-05-11 10:35:05 +0000637 }
Kulin Seth2bb022e2023-03-08 08:41:21 +0000638
Nikita Shulga87084642023-05-11 10:35:05 +0000639 MACOS_13_3_XFAILLIST = {
Kulin Seth2bb022e2023-03-08 08:41:21 +0000640 # Failure due to precision issue for fp16
641 # on both cpu and mps there are test cases that might produce inf result
642 # 'nn.functional.pairwise_distance': [torch.float16],
643
644 # test blow pass on macOS 12 as it falls back to cpu
645 # Argsort case using duplicate indices (undefined behaviour):
646 # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], devuce='cpu')
647 # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
648 # Elements from index 30 and 5133 are both equal.
649 # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
650 'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
651 # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
652 # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
653 'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
654 }
655
Huy Do89921412024-06-05 14:44:00 +0000656 MACOS_BEFORE_14_4_XFAILLIST = {
657 # These ops work fine in 14.4 but fail in 14.2 or 13.x
658 'fft.hfft2': [torch.complex64],
659 }
660
Kulin Seth2bb022e2023-03-08 08:41:21 +0000661 # Those ops are not expected to work
662 UNIMPLEMENTED_XFAILLIST = {
663 # Failures due to lack of op implementation on MPS backend
664 'login': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000665 'linalg.eig': None,
666 'linalg.eigvals': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000667 'put': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000668 'nn.functional.conv_transpose3d': None,
669 'rounddecimals_neg_3': None,
670 'rounddecimals_3': None,
671 'rounddecimals_0': None,
672 '__rsub__': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000673 'angle': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000674 'cauchy_': None,
675 'cauchy': None,
676 'cholesky': None,
677 'cholesky_inverse': None,
678 'cholesky_solve': None,
679 'cummax': None,
680 'cummin': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000681 'erfc': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000682 'frexp': None,
683 'gcd': None,
684 'geqrf': None,
685 'nn.functional.grid_sample': None, # Unsupported Border padding mode
686 'heaviside': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000687 'i0': None,
688 'igamma': None,
689 'igammac': None,
690 'index_copy': None,
Pearu Petersond2b0c0a2024-04-17 14:30:26 +0300691 'index_reduceprod': None,
692 'index_reducemean': None,
693 'index_reduceamax': None,
694 'index_reduceamin': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000695 'isneginf': None,
696 'isposinf': None,
697 'kthvalue': None,
698 'lcm': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000699 'linalg.cholesky': None,
700 'linalg.cholesky_ex': None,
701 'linalg.cond': None,
702 'linalg.detsingular': None,
703 'linalg.det': None,
704 'linalg.eigh': None,
705 'linalg.eigvalsh': None,
706 'linalg.householder_product': None,
707 'linalg.ldl_factor': None,
708 'linalg.ldl_factor_ex': None,
709 'linalg.ldl_solve': None,
710 'linalg.lstsq': None,
711 'linalg.lstsqgrad_oriented': None,
712 'linalg.lu': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000713 'linalg.lu_factor_ex': None,
714 'linalg.lu_solve': None,
715 'linalg.matrix_norm': [torch.float32],
716 'linalg.norm': [torch.float32],
717 'linalg.normsubgradients_at_zero': [torch.float32],
718 'linalg.qr': None,
719 'linalg.slogdet': None,
720 'linalg.solve': None,
721 'linalg.solve_ex': None,
722 'linalg.svdvals': None,
723 'linalg.tensorsolve': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000724 'linalg.vecdot': None,
725 'logcumsumexp': None,
726 'logdet': None,
727 'lu': None,
728 'lu_solve': None,
729 'lu_unpack': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000730 'masked.median': None,
731 'matrix_exp': None,
732 'mode': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000733 'nanmedian': None,
734 'native_dropout_backward': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000735 'normnuc': None,
736 'nn.functional.fractional_max_pool2d': None,
737 'nn.functional.fractional_max_pool3d': None,
738 'nn.functional.adaptive_avg_pool3d': None,
739 'nn.functional.adaptive_max_pool3d': None,
740 'nn.functional.interpolatearea': None,
741 'nn.functional.interpolatebicubic': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000742 'nn.functional.interpolatetrilinear': None,
743 'nn.functional.max_unpool1dgrad': None,
744 'nn.functional.max_unpool2dgrad': None,
745 'nn.functional.max_unpool3dgrad': None,
746 'nn.functional.avg_pool3d': None,
747 'nn.functional.ctc_loss': None,
748 'nn.functional.embedding_bag': None,
749 'nn.functional.hardshrink': None,
750 'nn.functional.max_pool3d': None,
751 'nn.functional.max_unpool1d': None,
752 'nn.functional.max_unpool2d': None,
753 'nn.functional.max_unpool3d': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000754 'nn.functional.multi_margin_loss': None,
755 'nn.functional.multilabel_margin_loss': None,
756 'nn.functional.pdist': None,
757 'nn.functional.rrelu': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000758 'nn.functional.norm': None,
759 'ormqr': None,
760 'pca_lowrank': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000761 'qr': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000762 'rsub': None,
763 'scatter_reduceamax': None,
764 'scatter_reduceamin': None,
765 'scatter_reducemin': None,
766 'scatter_reducemean': None,
767 'scatter_reduceprod': None,
768 'scatter_reducesum': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000769 'segment_reduce': None,
770 '_segment.reduce': None,
771 'segment.reduce': None,
772 'segment_reduce_offsets': None,
773 '_segment_reduce_offsets': None,
774 '_segment_reduce_lengths': None,
775 '_segment_reducelengths': None,
776 '_segment_reduceoffsets': None,
777 'sinc': None,
778 'sparse.mm': None,
779 'sparse.mmreduce': None,
780 'special.airy_ai': None,
781 'special.bessel_j0': None,
782 'special.bessel_j1': None,
783 'special.bessel_y0': None,
784 'special.bessel_y1': None,
785 'special.chebyshev_polynomial_t': None,
786 'special.chebyshev_polynomial_u': None,
787 'special.entr': None,
788 'special.erfcx': None,
789 'special.hermite_polynomial_h': None,
790 'special.hermite_polynomial_he': None,
791 'special.i0e': None,
792 'special.i1': None,
793 'special.i1e': None,
794 'special.laguerre_polynomial_l': None,
795 'special.log_ndtr': None,
796 'special.modified_bessel_i0': None,
797 'special.modified_bessel_i1': None,
798 'special.modified_bessel_k0': None,
799 'special.modified_bessel_k1': None,
800 'special.ndtri': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000801 'special.scaled_modified_bessel_k0': None,
802 'special.scaled_modified_bessel_k1': None,
803 'special.spherical_bessel_j0': None,
804 'special.xlog1py': None,
805 'special.zeta': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000806 'svd_lowrank': None,
807 'symeig': None,
808 'take': None,
809 'to': None,
810 'to_sparse': None,
811 'unique': None,
812 'vdot': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000813 'segment_reduce_': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000814 '_upsample_bilinear2d_aa': None,
815 'geometric' : None,
816 'geometric_': None,
817 'log_normal_': None,
818 'log_normal': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000819 'cdouble': None,
Nikita Shulga9dda4b22023-12-18 15:39:11 -0800820 'double': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000821 'nn.functional.softminwith_dtype': None,
822 'log_softmaxwith_dtype': None,
823 'softmaxwith_dtype': None,
824 'float_power': None,
825 'full_like': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000826 'linalg.matrix_rankhermitian': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000827 'linalg.pinvhermitian': None,
Guang Yangc377a852023-04-11 05:13:36 +0000828 'nonzero_static': None,
Kulin Seth2bb022e2023-03-08 08:41:21 +0000829
830 # MPS: input sizes must be divisible by output sizes
831 'nn.functional.adaptive_avg_pool1d': None,
832 'nn.functional.adaptive_avg_pool2d': None,
833
834 # Unsupported dtypes
835 # bmm is not supported for integral types
836 'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000837 'ones_like': None,
838 'zeros_like': None,
839
840 # Convolution for integral types is not supported on MPS
841 'nn.functional.conv1d': [torch.int64],
842 'nn.functional.conv2d': [torch.int64],
Khushi Agrawalcff84872023-11-27 14:45:44 +0000843 'nn.functional.conv3d': [torch.int64],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000844 'nn.functional.conv_transpose1d': [torch.int64],
845 'nn.functional.conv_transpose2d': [torch.int64],
846
847 # Unsupported dtypes
848 'dot': [torch.int64],
CaoEa310cc82023-10-31 09:12:47 +0000849 'histc': [torch.float16],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000850 'index_add': [torch.int64],
851 'log1p': [torch.int64],
852 'sigmoid': [torch.int64],
853 'atan2': [torch.int64],
854
855 # GEMM on MPS is not supported for integral types
856 'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
857 '__rmatmul__': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
858 'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
859 'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
860 'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
861 'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
862 'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
863 'mm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
864 'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
865 'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
866 'inner': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
867 'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
868 'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
869 'mat': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
870 'mv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
871 'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
Kurt Mohler5292a922023-10-12 00:55:51 +0000872 'unravel_index': [torch.int32, torch.int64],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000873
874 # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as
875 # the MPS framework doesn't support float64
876 'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
877 'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
878 'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
879 # returned output on CPU is float64
880 'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
881
882 # trunc_tensor not working properly for float16
883 'divtrunc_rounding': [torch.float16],
884 'fmod': [torch.float16],
Sun, Jiayid56e1b22023-05-11 15:30:59 +0800885
886 # round not working properly for float16
887 'round': [torch.float16],
Isuru Fernandoe6bfa292024-06-24 22:15:14 +0000888
889 # atomic operations not supported
890 '_unsafe_masked_index_put_accumulate': [torch.bool, torch.int8, torch.uint8, torch.float16, torch.int16, torch.int64],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000891 }
892
Nikita Shulga53bfae22024-02-20 08:53:12 -0800893 if product_version < 14.0:
894 # FFT and BFloat16 support was added in MacOS 14
895 UNIMPLEMENTED_XFAILLIST.update({
896 'bfloat16': None,
897 'fft.fft': None,
898 'fft.fft2': None,
899 'fft.fftn': None,
900 'fft.hfft': None,
jhavukainen6a539e82024-05-22 21:48:49 +0000901 'fft.hfft2': None,
902 'fft.hfftn': None,
Nikita Shulga53bfae22024-02-20 08:53:12 -0800903 'fft.ifft': None,
904 'fft.ifft2': None,
905 'fft.ifftn': None,
906 'fft.ihfft': None,
907 'fft.ihfft2': None,
908 'fft.ihfftn': None,
909 'fft.irfft': None,
910 'fft.irfft2': None,
911 'fft.irfftn': None,
912 'fft.rfft': None,
913 'fft.rfft2': None,
914 'fft.rfftn': None,
915 'stft': None,
jhavukainend28868c2024-05-20 20:23:53 +0000916 # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers,
Joona Havukainenc451d102024-05-01 23:14:05 +0000917 # not reproducible in later OS. Added assert to op if used in < 14.0
jhavukainend28868c2024-05-20 20:23:53 +0000918 'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8],
Isuru Fernando5f912f42024-06-24 14:46:53 +0000919 'nn.functional.max_pool2d': [torch.uint8],
Nikita Shulga53bfae22024-02-20 08:53:12 -0800920 })
921
Denis Vieriu861bdf92024-08-16 21:07:48 +0000922 if product_version < 15.0:
923 UNIMPLEMENTED_XFAILLIST.update({
924 'quantile': None,
925 'nanquantile': None,
926 })
927
Kulin Seth2bb022e2023-03-08 08:41:21 +0000928 UNDEFINED_XFAILLIST = {
929 # Top 60 operators
930 # topk fails with duplicate indices
931 'topk': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
932
933 # Failures due to random output that they generate using
934 # Philox engine causing mismatch with CPU results
CaoEd1afb7d2023-10-19 19:05:09 -0700935 'multinomial': [torch.float16, torch.float32], # random results
Kulin Seth2bb022e2023-03-08 08:41:21 +0000936 'uniform': [torch.float16, torch.float32],
937 'rand_like': [torch.float16, torch.float32],
938 'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
939 'randn_like': [torch.float16, torch.float32],
CaoE8713a1a2023-10-11 23:54:31 -0700940 'bernoulli': [torch.float16, torch.float32],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000941 'exponential': [torch.float16, torch.float32],
CaoE8713a1a2023-10-11 23:54:31 -0700942 'nn.functional.feature_alpha_dropoutwith_train': [torch.float16, torch.float32],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000943 'normal': [torch.float16, torch.float32, torch.float16, torch.float32],
944 'normalin_place': [torch.float16, torch.float32],
945 'normalnumber_mean': [torch.float16, torch.float32],
CaoE8713a1a2023-10-11 23:54:31 -0700946 'nn.functional.alpha_dropout': [torch.float16, torch.float32],
947 'nn.functional.dropout': [torch.float16, torch.float32],
948 'nn.functional.dropout2d': [torch.float16, torch.float32],
949 'nn.functional.dropout3d': [torch.float16, torch.float32],
Cao E1c89ea72023-10-26 08:38:54 +0000950 # See https://github.com/pytorch/pytorch/issues/111479
951 'nn.functional.multi_head_attention_forward': [torch.float32, torch.float16],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000952
Kulin Seth2bb022e2023-03-08 08:41:21 +0000953 # duplicate indices are used in the testcase - undefined behaviour
954 'index_put': None,
955 # zero to negative integer powers are undefined
956 '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64],
957 'resize_': [torch.float16, torch.float32],
958 'resize_as_': [torch.float16, torch.float32],
959
960 # CPU Errors:
961 'addr': [torch.bool, torch.int16, torch.int32,
962 torch.int64, torch.uint8, torch.int8], # "addmv_impl_cpu" not implemented for 'Half'
963 'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
964 torch.int32, torch.int64, torch.uint8, torch.int8], # cpu result off, showing random values
965 'as_strided_partial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
966 torch.int32, torch.int64, torch.uint8, torch.int8], # cpu result off, showing random values
967
968 # random results
969 # mps vs cpu:
970 # Mismatched elements: 40 / 96 (41.7%)
971 # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
972 # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
973 # cuda(2.0.0.dev20230301+cu117) vs cpu:
974 # Mismatched elements: 56 / 96 (58.3%)
975 # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
976 # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
Cao E1c89ea72023-10-26 08:38:54 +0000977 'nn.functional.scaled_dot_product_attention': [torch.float32, torch.float16],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000978
CaoEa310cc82023-10-31 09:12:47 +0000979 # float output for float16 input on MPS
980 'logit': [torch.float16],
Kulin Seth2bb022e2023-03-08 08:41:21 +0000981 }
982
watarungurunnnd444a3b2024-02-05 15:36:55 +0000983 ON_MPS_XFAILLIST = {
984 # Failures due to lack of implementation of downstream functions on MPS backend
985 # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
986 'linalg.matrix_rank': None,
987 }
988
Li-Huai (Allan) Lin13da6582023-05-01 14:55:02 +0800989 EMPTY_OPS_SKIPLIST = {
990 # Fill tensors with uninitialized data, causing mismatch with CPU.
991 # They occasionally match, thus skipping them.
992 # See https://github.com/pytorch/pytorch/issues/100175
993 'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
994 'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16,
995 torch.int32, torch.int64, torch.uint8, torch.int8],
Khushi1aaf0392023-05-19 03:06:29 +0000996 'empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
Li-Huai (Allan) Lin13da6582023-05-01 14:55:02 +0800997 # CPU: empty is returning all 0's and there is a mismatch with MPS
998 # allocation (MacOS 13). According to
999 # https://pytorch.org/docs/2.0/generated/torch.empty.html
1000 'empty': [torch.bool, torch.float16, torch.float32, torch.int16,
1001 torch.int32, torch.int64, torch.uint8, torch.int8],
1002 'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
1003 'empty_permuted': [torch.bool, torch.float16, torch.float32, torch.int16,
1004 torch.int32, torch.int64, torch.uint8, torch.int8],
1005 }
1006
Peter Bell46e80ce2023-10-24 15:19:01 +01001007 SKIPLIST = {
mingfeimaa8acd6c2023-12-12 12:59:47 +00001008 # Unsupported
1009 # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
1010 'nn.functional.avg_pool2d': [torch.float16],
Huy Do89921412024-06-05 14:44:00 +00001011
1012 # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
1013 'nn.functional.conv3d': None,
Peter Bell46e80ce2023-10-24 15:19:01 +01001014 }
1015
Kulin Seth2bb022e2023-03-08 08:41:21 +00001016 def addDecorator(op, d) -> None:
1017 op.decorators = list(op.decorators) if op.decorators is not None else []
1018 op.decorators.append(d)
1019
1020 for op in ops:
1021 key = op.name + op.variant_test_name
Li-Huai (Allan) Lin13da6582023-05-01 14:55:02 +08001022 if key in EMPTY_OPS_SKIPLIST:
1023 addDecorator(op, DecorateInfo(
1024 unittest.skip("Skipping empty ops."),
1025 dtypes=EMPTY_OPS_SKIPLIST[key]))
Peter Bell46e80ce2023-10-24 15:19:01 +01001026 if key in SKIPLIST:
1027 addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
watarungurunnnd444a3b2024-02-05 15:36:55 +00001028 for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
Kulin Seth2bb022e2023-03-08 08:41:21 +00001029 if key in xfaillist:
1030 addDecorator(op, DecorateInfo(
1031 unittest.expectedFailure,
1032 dtypes=xfaillist[key]))
1033
Huy Do89921412024-06-05 14:44:00 +00001034 if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4):
1035 addDecorator(op, DecorateInfo(
1036 unittest.expectedFailure,
1037 dtypes=MACOS_BEFORE_14_4_XFAILLIST[key]))
1038
Kulin Seth2bb022e2023-03-08 08:41:21 +00001039 if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
1040 addDecorator(op, DecorateInfo(
1041 unittest.expectedFailure,
1042 dtypes=MACOS_BEFORE_13_3_XFAILLIST[key]))
1043
Nikita Shulga87084642023-05-11 10:35:05 +00001044 if key in MACOS_AFTER_13_1_XFAILLIST and torch.backends.mps.is_macos13_or_newer(2):
1045 addDecorator(op, DecorateInfo(
1046 unittest.expectedFailure,
1047 dtypes=MACOS_AFTER_13_1_XFAILLIST[key]))
1048
Kulin Seth2bb022e2023-03-08 08:41:21 +00001049 if key in MACOS_13_3_XFAILLIST and (product_version >= 13.3):
1050 addDecorator(op, DecorateInfo(
1051 unittest.expectedFailure,
1052 dtypes=MACOS_13_3_XFAILLIST[key]))
1053
1054 if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()):
1055 addDecorator(op, DecorateInfo(
1056 unittest.expectedFailure,
1057 dtypes=MACOS_12_3_XFAILLIST[key]))
Nikita Shulgab0393eb2024-01-05 00:25:47 +00001058
Nikita Shulga53a4ca42023-08-31 20:41:39 -07001059 # If ops is not supported for complex types, expect it to fail
Nikita Shulgab0393eb2024-01-05 00:25:47 +00001060 if key not in SUPPORTED_COMPLEX_OPS and (key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS or product_version < 14.0):
Nikita Shulga53a4ca42023-08-31 20:41:39 -07001061 addDecorator(op, DecorateInfo(unittest.expectedFailure, dtypes=[torch.complex32, torch.complex64]))
Peter Bell46e80ce2023-10-24 15:19:01 +01001062
Nikita Shulgafd8367a2023-02-27 15:01:01 +00001063 yield op
1064
def mps_ops_error_inputs_modifier(ops):
    """Yield the ops that define error inputs, flagging known MPS failures.

    Ops whose ``error_inputs_func`` is ``None`` are dropped from the stream.
    Every other op is yielded in input order; when its key (``name``
    immediately followed by ``variant_test_name``) is listed below, a
    ``DecorateInfo(unittest.expectedFailure)`` is appended to its decorators
    before it is yielded.
    """
    # Error input samples do not take a dtype argument, so a plain set of
    # keys is used here rather than a key -> dtypes mapping.
    expected_failures = {
        # Exceptions are not raised
        '__rmod__',
        '__rsub__',
        '__rpow__',
        'bernoulli',
        'clamp_max',
        'clamp_min',
        'masked_scatter',

        # unsupported float64 dtype
        'cat',
        'complex',
        'multinomial',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv3d',
        'gather',
        'scatter',
        'scatter_add',

        # unsupported complex dtypes
        'masked_fill',

        # MPS does not support tensor dimensions > 16
        'amax',
        'amin',
        'aminmax',

        # memory overlapping checks
        'index_select',

        # unimplemented
        'logcumsumexp',
    }

    def mark_expected_failure(op_info) -> None:
        # Build the decorator first, then copy any existing decorators into a
        # fresh list so the op ends up owning a mutable list.
        decorator = DecorateInfo(unittest.expectedFailure)
        current = op_info.decorators
        op_info.decorators = [] if current is None else list(current)
        op_info.decorators.append(decorator)

    for op_info in ops:
        if op_info.error_inputs_func is not None:
            if op_info.name + op_info.variant_test_name in expected_failures:
                mark_expected_failure(op_info)
            yield op_info
1114
# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    # Say so on stderr, then rebind both test base-class names to NoTest.
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = NNTestCase = NoTest  # noqa: F811
Kulin Sethe011a8e2022-05-13 18:28:53 +00001120
# macOS "major.minor" release as a float (e.g. 14.4). When platform.mac_ver()
# gives an empty release string the `or -1` fallback yields -1.0.
product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)

# Value printed by `sysctl -n hw.memsize`, parsed as an int.
try:
    total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))
except (OSError, subprocess.CalledProcessError, ValueError):
    # sysctl is missing, does not know the key, or printed something that is
    # not an integer. Keep the module importable and report 0 instead of
    # crashing at import time.
    total_memory = 0
Denis Vieriu71ec2612023-02-15 06:09:56 +00001123
# Determine whether to enable MPS memory leak check (uses same code as CUDA).
# Opt-in: only the exact value '1' switches it on; unset defaults to '0'.
TEST_MPS_MEM_LEAK_CHECK = '1' == os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', default='0')
1126
def skipMPSMemoryLeakCheckIf(condition):
    """Build a decorator that records whether the MPS leak check applies.

    The decorator stores ``not condition`` in the decorated object's
    ``_do_mps_memory_leak_check`` attribute, unless that attribute already
    exists and is falsy, in which case it is left untouched. The decorated
    object itself is returned, not a wrapper.
    """
    def decorator(target):
        still_enabled = getattr(target, '_do_mps_memory_leak_check', True)
        if still_enabled:
            target._do_mps_memory_leak_check = not condition
        return target

    return decorator
1133
class MpsMemoryLeakCheck:
    """Context manager that flags MPS memory still allocated after a test.

    On entry it records the caching-allocator and driver allocation figures
    reported by ``torch.mps``. On a clean exit it compares fresh readings
    with those baselines:

    * allocator and driver both above baseline -> ``RuntimeError``
    * only the allocator above baseline        -> ``warnings.warn``
    * otherwise                                -> nothing

    Nothing is checked when the body raised an exception.
    """

    def __init__(self, testcase, name=None):
        # Fall back to the test's own id when no explicit label is given.
        if name is None:
            name = testcase.id()
        self.name = name
        self.testcase = testcase

    def __enter__(self):
        # If anything is currently held, collect garbage and empty the cache
        # so the baseline below is not inflated by it.
        if torch.mps.current_allocated_memory() > 0:
            gc.collect()
            torch.mps.empty_cache()

        # Baseline allocator and driver statistics, taken before the test runs.
        self.caching_allocator_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown.
        if exec_type is not None:
            return

        # Fast path: allocator usage did not grow, so there is nothing to report.
        if torch.mps.current_allocated_memory() <= self.caching_allocator_before:
            return

        # Growth was seen. Make sure it survives garbage collection and a
        # cache flush before treating it as a leak.
        gc.collect()
        torch.mps.empty_cache()

        # Query memory multiple times to ensure the leak was not transient.
        # The readings from the last iteration executed are used below.
        for _ in range(3):
            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
            driver_mem_allocated = torch.mps.driver_allocated_memory()

            caching_allocator_discrepancy = caching_allocator_mem_allocated > self.caching_allocator_before
            driver_discrepancy = driver_mem_allocated > self.driver_before

            if not caching_allocator_discrepancy and not driver_discrepancy:
                # Leak was a false positive, stop sampling.
                break

        if not caching_allocator_discrepancy:
            return

        memory_report = (
            f"Caching allocator allocated memory was {self.caching_allocator_before} "
            f"and is now reported as {caching_allocator_mem_allocated}. "
            f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")

        if driver_discrepancy:
            # A caching allocator discrepancy validated by the driver API is a failure.
            raise RuntimeError(f"MPS driver API confirmed a leak in {self.name}! " + memory_report)

        # Only warn when the leak is not validated by the driver API.
        warnings.warn("MPS caching allocator reports a memory leak not "
                      f"verified by the driver API in {self.name}! " + memory_report)
1205
class TestAutocastMPS(TestCase):
    """Checks for ``torch.autocast`` on the MPS device."""

    def test_matmul_autocast(self):
        # Two chained matrix products run under autocast must come out as
        # float16 and agree with the same products computed outside autocast
        # (cast to float16 for the comparison).
        autocast_tensor_A = torch.rand((8, 8), device="mps")
        autocast_tensor_B = torch.rand((8, 8), device="mps")
        tensor_A = autocast_tensor_A.clone().detach()
        tensor_B = autocast_tensor_B.clone().detach()

        with torch.autocast(device_type="mps"):
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B)
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor)

        # Reference result, computed outside the autocast region.
        output_tensor = torch.mm(tensor_A, tensor_B)
        output_tensor = torch.mm(tensor_A, output_tensor)

        self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16")
        # The message is built from adjacent literals rather than a backslash
        # continuation, so no source indentation leaks into it.
        self.assertEqual(autocast_output_tensor,
                         output_tensor.to(torch.float16),
                         "Autocast & non-autocast tensors did not match, "
                         f"got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}")
1228
# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
    """TestCase that can run its test method inside an MPS memory leak check.

    When ``TEST_MPS_MEM_LEAK_CHECK`` is set, the selected test method is
    replaced at construction time by a version wrapped with
    ``assertLeaksNoMpsTensors``, unless the check has been switched off for
    the class or for that individual method.
    """

    # Class-wide switch for the leak check; True unless something overrides it.
    _do_mps_memory_leak_check = True

    def __init__(self, method_name='runTest'):
        super().__init__(method_name)
        test_method = getattr(self, method_name, None)
        if test_method is None or not TEST_MPS_MEM_LEAK_CHECK:
            return

        # Honour both the class-level switch and a per-method override.
        # skipMPSMemoryLeakCheckIf stores its flag on the object it decorates;
        # for a decorated test function that flag is readable through the
        # bound method, and undecorated methods default to True.
        check_enabled = self._do_mps_memory_leak_check and getattr(
            test_method, '_do_mps_memory_leak_check', True)
        if check_enabled:
            # Wraps the tested method so it runs under the MPS memory check.
            self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)

    def assertLeaksNoMpsTensors(self, name=None):
        """Return an ``MpsMemoryLeakCheck`` for this test.

        ``name`` labels the check in its messages; it defaults to ``self.id()``.
        """
        name = self.id() if name is None else name
        return MpsMemoryLeakCheck(self, name)

    def wrap_with_mps_policy(self, method_name, policy):
        """Replace the attribute ``method_name`` on this instance with its policy-wrapped form."""
        # NOTE(review): wrap_method_with_policy comes from the base TestCase
        # and is not visible here; presumably it runs the method inside the
        # context manager returned by `policy()` - confirm against the base class.
        test_method = getattr(self, method_name)
        setattr(self, method_name, super().wrap_method_with_policy(test_method, policy))

    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
    def wrap_with_mps_memory_check(self, method):
        """Return ``method`` wrapped with the leak-check policy, regardless of the env switch."""
        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)
1253
class TestMemoryLeak(TestCaseMPS):
    """Exercises the MPS leak detector and one known leak-free pattern."""

    def test_mps_memory_leak_detection(self):
        # Holds the tensor created by the leaking function so it stays alive
        # after that function returns.
        leaked_tensors = []

        def no_leak():
            pass

        # Trigger an intentional memory leak
        def leak_gpu0():
            # 1024 * 1024 * 8 elements: large enough to force acquiring a new
            # block and overcome blocksize differences across platforms
            leaked_tensors.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))

        checked_no_leak = self.wrap_with_mps_memory_check(no_leak)
        checked_leak = self.wrap_with_mps_memory_check(leak_gpu0)

        # The clean function must pass through the check without raising.
        checked_no_leak()

        # A RuntimeError carrying the driver-confirmed message shows that
        # leak detection works.
        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
            checked_leak()

    def test_copy_cast_no_leak(self):
        # Round-trip a tensor mps -> cpu (float32) -> mps (float16); the
        # results are discarded when the helper returns.
        def cast_round_trip(t):
            t = t.to(device='cpu', dtype=torch.float32)
            t = t.to(device='mps', dtype=torch.float16)

        sample = torch.randn(128, 128, device='mps', dtype=torch.float16)

        # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
        cast_round_trip(sample)
        torch.mps.empty_cache()
        driver_before = torch.mps.driver_allocated_memory()

        # Measured run: driver-allocated memory must be unchanged afterwards.
        cast_round_trip(sample)
        torch.mps.empty_cache()
        driver_after = torch.mps.driver_allocated_memory()

        self.assertEqual(driver_before, driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
Nikita Shulgab5dd37f2023-11-21 14:52:55 +00001290
alexdremovb60273b2023-09-06 09:11:39 +00001291
1292class TestPixelShuffle(TestCaseMPS):
1293 def test_pixel_shuffle_unshuffle(self):
1294 def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
1295 upscale_factor=None, is_contiguous=True):
1296
1297 def generate_input():
1298 # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
1299 channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
1300 height = random.randint(5, 10)
1301 width = random.randint(5, 10)
1302
1303 if num_input_dims == 1:
1304 input = torch.rand(channels, requires_grad=True, device='mps')
1305 assert is_contiguous
1306 elif num_input_dims == 2:
1307 input = torch.rand(width, height, requires_grad=True, device='mps').T
1308 if is_contiguous:
1309 input = input.contiguous()
1310 else:
1311 batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
1312 input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
1313 input = input.transpose(-1, -2)
1314 if is_contiguous:
1315 input = input.contiguous()
1316
1317 if not is_contiguous and len(input.reshape(-1)) > 0:
1318 assert not input.is_contiguous()
1319
1320 input = input.detach().clone()
1321 input.requires_grad = True
1322 return input
1323
1324 # Function to imperatively ensure pixels are shuffled to the correct locations.
1325 # Used to validate the batch operations in pixel_shuffle.
1326 def _verify_pixel_shuffle(input, output, upscale_factor):
1327 for c in range(output.size(-3)):
1328 for h in range(output.size(-2)):
1329 for w in range(output.size(-1)):
1330 height_idx = h // upscale_factor
1331 weight_idx = w // upscale_factor
1332 channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
1333 (c * upscale_factor ** 2)
1334 self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])
1335
1336 upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
1337 input = generate_input()
1338
1339 ps = nn.PixelShuffle(upscale_factor)
1340 pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
1341
1342 if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
1343 output = ps(input)
1344 _verify_pixel_shuffle(input, output, upscale_factor)
1345 output.backward(output.data)
1346 self.assertEqual(input.data, input.grad.data)
1347
1348 # Ensure unshuffle properly inverts shuffle.
1349 unshuffle_output = pus(output)
1350 self.assertEqual(input, unshuffle_output)
1351 else:
1352 self.assertRaises(RuntimeError, lambda: ps(input))
1353
1354 def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
1355 downscale_factor=None):
1356 downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
1357 channels = random.randint(1, 4)
1358 # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
1359 height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
1360 # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
1361 width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
1362
1363 if num_input_dims == 1:
1364 input = torch.rand(channels, requires_grad=True, device='mps')
1365 elif num_input_dims == 2:
1366 input = torch.rand(height, width, requires_grad=True, device='mps')
1367 else:
1368 batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
1369 input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')
1370
1371 pus = nn.PixelUnshuffle(downscale_factor)
1372 self.assertRaises(RuntimeError, lambda: pus(input))
1373
1374 def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
1375 # For 1D - 2D, this is an error case.
1376 # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
1377 is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
1378 for is_contiguous in is_contiguous_check:
1379 _test_pixel_shuffle_unshuffle_helper(
1380 num_input_dims=num_input_dims, is_contiguous=is_contiguous
1381 )
1382 _test_pixel_shuffle_unshuffle_helper(
1383 num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
1384 )
1385 _test_pixel_shuffle_unshuffle_helper(
1386 num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
1387 )
1388 _test_pixel_shuffle_unshuffle_helper(
1389 num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
1390 )
1391
1392 # Error cases for pixel_unshuffle.
1393 _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
1394 _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
1395 _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
1396 _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
1397
1398 def test_pixel_shuffle_unshuffle_1D():
1399 _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
1400
1401 def test_pixel_shuffle_unshuffle_2D():
1402 _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
1403
1404 def test_pixel_shuffle_unshuffle_3D():
1405 _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
1406
1407 def test_pixel_shuffle_unshuffle_4D():
1408 _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
1409
1410 def test_pixel_shuffle_unshuffle_5D():
1411 _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
1412
1413 test_pixel_shuffle_unshuffle_1D()
1414 test_pixel_shuffle_unshuffle_2D()
1415 test_pixel_shuffle_unshuffle_3D()
1416 test_pixel_shuffle_unshuffle_4D()
1417 test_pixel_shuffle_unshuffle_5D()
1418
class MPSReluTest(TestCaseMPS):
    """Checks ``torch.nn.ReLU`` output against a NumPy reference implementation."""

    def _npRelu(self, np_features):
        """NumPy reference: element-wise max with zero, cast back to the input dtype."""
        zeros = np.zeros(np_features.shape)
        return np.maximum(np_features, zeros).astype(np_features.dtype)

    def testNpRelu(self):
        # Sanity-check the NumPy reference itself on a hand-computed example.
        features = np.array([[-0.9, 0.7, -0.5, 0.3, -0.1],
                             [0.1, -0.3, 0.5, -0.7, 0.9]])
        hand_computed = np.array([[0., 0.7, 0.0, 0.3, 0.0],
                                  [0.1, 0.0, 0.5, 0.0, 0.9]])
        torch.testing.assert_close(hand_computed, self._npRelu(features))

    def _testRelu(self, np_features, device):
        """Run out-of-place ReLU on ``device`` and compare with the NumPy reference."""
        reference = self._npRelu(np_features)
        # Wrap the NumPy data in a tensor and place it on the requested device.
        features = torch.from_numpy(np_features).to(device)
        activation = torch.nn.ReLU(inplace=False)
        result_cpu = activation(features).to("cpu")

        self.assertEqual(reference, result_cpu)

    def _testReluInPlace(self, np_features, device):
        """Run in-place ReLU on ``device``; both the result and the input must match the reference."""
        reference = self._npRelu(np_features)
        # Wrap the NumPy data in a tensor and place it on the requested device.
        features = torch.from_numpy(np_features).to(device)
        activation = torch.nn.ReLU(inplace=True)
        result = activation(features)

        self.assertEqual(reference, result.to("cpu"))
        # The in-place variant overwrites its input, so the input tensor is
        # expected to hold the ReLU output as well.
        self.assertEqual(reference, features.to("cpu"))

    def testNumbersCPU(self):
        values = [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]
        for t in [np.int32]:
            # Force execution on CPU even if a GPU kernel is available for the type.
            # A fresh array is built for every check.
            self._testRelu(np.array(values).astype(t), device="cpu")
            self._testReluInPlace(np.array(values).astype(t), device="cpu")

    def testNumbersGPU(self):
        values = [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]
        for t in [np.float16, np.float32]:
            self._testRelu(np.array(values).astype(t), device="mps")
            self._testReluInPlace(np.array(values).astype(t), device="mps")
            # Empty inputs go through the same two checks.
            self._testRelu(np.array([]).astype(t), device="mps")
            self._testReluInPlace(np.array([]).astype(t), device="mps")
Kulin Sethe011a8e2022-05-13 18:28:53 +00001472
class MatmulTest(TestCaseMPS):
    """Compares ``torch.matmul`` on MPS with the CPU result for several operand shapes."""

    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
        """Multiply two random MPS tensors and check the product against CPU.

        Each operand is drawn with ``torch.randn`` on MPS and, when the matching
        ``expand_*`` shape is truthy, expanded to that shape before the multiply.
        """
        lhs_mps = torch.randn(shape_tensor_1, device="mps")
        if expand_tensor_1_shape:
            lhs_mps = lhs_mps.expand(expand_tensor_1_shape)

        rhs_mps = torch.randn(shape_tensor_2, device="mps")
        if expand_tensor_2_shape:
            rhs_mps = rhs_mps.expand(expand_tensor_2_shape)

        # The CPU operands are copies of the MPS ones, so both products see the same data.
        expected = torch.matmul(lhs_mps.to("cpu"), rhs_mps.to("cpu"))
        actual = torch.matmul(lhs_mps, rhs_mps)

        self.assertEqual(expected, actual.to("cpu"))

    def test_vector_x_vector(self):
        # uses `dot`
        self._helper(3, 3)

    def test_matrix_x_vector(self):
        # uses `addmv`
        self._helper((3, 4), 4)

    def test_batched_matrix_x_broadcasted_vector(self):
        self._helper((10, 3, 4), 4)

    def test_batched_matrix_x_batched_matrix(self):
        # uses `bmm.out`
        self._helper((10, 3, 4), (10, 4, 5))

    def test_batched_matrix_x_broadcasted_matrix(self):
        self._helper((10, 3, 4), (4, 5))
1510
1511
class MPSLeakyReluTest(TestCaseMPS):
    """Checks ``torch.nn.LeakyReLU`` forward and backward on MPS against CPU."""

    def _npLeakyRelu(self, np_features, negative_slope=0.1):
        """NumPy reference: ``max(x, negative_slope * x)``, cast back to the input dtype."""
        return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype)

    def testNpLeakyRelu(self):
        # Sanity-check the NumPy reference itself on a hand-computed example.
        torch.testing.assert_close(
            np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
                      [0.1, -0.03, 0.5, -0.07, 0.9]]),
            self._npLeakyRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1],
                          [0.1, -0.3, 0.5, -0.7, 0.9]]),
                negative_slope=0.1))

    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
        """Compare LeakyReLU output and input gradient between CPU and MPS.

        Args:
            shape: shape of the random input tensor.
            dtype: torch dtype of the input.
            negative_slope: slope passed to ``torch.nn.LeakyReLU``.
            contiguous: when False, and the shape has at least two dims and no
                zero-sized dim, the inputs are transposed to be non-contiguous.
        """
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')

        if not contiguous and not (0 in shape or len(shape) < 2):
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            mps_x = mps_x.transpose(0, 1)
            # unittest assertion rather than a bare `assert`, which is
            # removed when Python runs with -O.
            self.assertFalse(mps_x.is_contiguous())

        cpu_x.requires_grad_()
        mps_x.requires_grad_()

        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass
        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')

        mps_leaky_relu.backward(gradient=mps_grad)
        cpu_leaky_relu.backward(gradient=cpu_grad)

        # Check that the grad is well-populated
        self.assertIsNotNone(cpu_x.grad)
        self.assertEqual(cpu_x.grad, mps_x.grad)

    def testNumbersCPU(self):
        for t in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    self._testLeakyRelu(shape,
                                        dtype=t,
                                        negative_slope=0.2,
                                        contiguous=contiguous)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001563
class TestAvgPool(TestCaseMPS):
    """Average-pooling checks built on unfold-based sum-pooling reference helpers."""

    def _sum_pool2d(self, x, kernel_size):
        """Sum each non-overlapping ``kernel_size`` window of ``x`` via ``unfold``."""
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        """Sum-pool with a 3-element ``kernel_size``; returns a ``(1, numel)`` view."""
        # Because unfold does not support 3D sliding window we will split tensor to multiple tensors and calculate sum
        h = kernel_size[0]
        # Chunks shorter than ``h`` (a trailing partial split) are dropped.
        splited_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # sum_pool2d assumes tensor in (1, 1, n, m) view, so unsqueeze two times
        splited_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in splited_x]
        joined_x = torch.cat(splited_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        """Average-pool reference: window sums divided by the window element count."""
        # math.prod replaces reduce(operator.mul, ...): neither `reduce` nor
        # `operator` is imported in this file, so the old form could not run.
        size = math.prod(kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        """3D counterpart of ``_avg_pool2d``."""
        size = math.prod(kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_avg_pool2d_with_zero_divisor(self):
        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))

    def test_doubletensor_avg_pool2d_with_divisor(self):
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
        # Same call with the input moved to MPS.
        y = torch.nn.functional.avg_pool2d(
            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
Kulin Sethe011a8e2022-05-13 18:28:53 +00001612
1613
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00001614class TestMPS(TestCaseMPS):
Kulin Sethe011a8e2022-05-13 18:28:53 +00001615 def test_exp(self, device="mps", dtype=torch.float):
1616 for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
Nikita Shulga06787422024-06-11 15:37:03 -07001617 b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
1618 a = torch.tensor(v, dtype=dtype, device="mps") * b
Kulin Sethe011a8e2022-05-13 18:28:53 +00001619 self.compare_with_numpy(torch.exp, np.exp, a)
1620
def test_conv_raises_error(self, device='mps', dtype=torch.float):
    """Running a Conv1d with 65537 output channels on MPS must raise NotImplementedError."""
    layer = nn.Conv1d(1, 65537, 3, padding=1).to('mps')
    sample = torch.ones([1, 1, 3])

    with self.assertRaises(NotImplementedError):
        # The host-to-MPS transfer stays inside the guarded block.
        layer(sample.to("mps"))
1627
def test_triu_inf(self, device="mps", dtype=torch.float):
    """``torch.triu`` of an all ``-inf`` tensor must agree between CPU and MPS."""
    for diag in (-1, 0, 1):
        cpu_mask = torch.full((3, 6, 6), float("-inf"))
        mps_mask = cpu_mask.clone().detach().to('mps')

        expected = torch.triu(cpu_mask, diagonal=diag)
        actual = torch.triu(mps_mask, diagonal=diag)
        self.assertEqual(expected, actual)
1635
def test_exp1(self, device="mps", dtype=torch.float):
    """``torch.exp`` on ``device`` must match the CPU result to within 1e-8."""
    values = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
    on_device = torch.exp(values)
    on_cpu = torch.exp(values.cpu())
    # If exponentWithTensor: MPS call is used on M1 running 14.5 test will fail with
    # Mismatched elements: 3 / 4 (75.0%)
    # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
    # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
    self.assertEqual(on_device, on_cpu, atol=1e-8, rtol=1e-8)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001645
def test_exp_strided_output(self):
    """``exp`` of a permuted (strided) tensor must agree between MPS and CPU."""
    base_mps = torch.rand((256, 10), device='mps')
    base_cpu = base_mps.to("cpu")

    # Swap the two dims on both copies before applying exp.
    strided_mps = base_mps.permute(1, 0)
    strided_cpu = base_cpu.permute(1, 0)

    self.assertEqual(strided_mps.exp(), strided_cpu.exp())
1656
def _testLeakyRelu(self, np_features, negative_slope, device):
    """Compare LeakyReLU forward output and input gradients between CPU and MPS.

    ``device`` is accepted but not read; the second tensor is always placed on 'mps'.
    """
    x_cpu = torch.from_numpy(np_features).requires_grad_()
    x_mps = torch.from_numpy(np_features).to('mps').requires_grad_()
    leaky = torch.nn.LeakyReLU(negative_slope)

    out_cpu = leaky(x_cpu)
    out_mps = leaky(x_mps)
    torch.testing.assert_close(out_cpu, out_mps.to('cpu'))

    # test backward pass
    grad_cpu = torch.ones_like(out_cpu)
    grad_mps = grad_cpu.to('mps')
    out_cpu.backward(gradient=grad_cpu)
    out_mps.backward(gradient=grad_mps)
    torch.testing.assert_close(x_cpu.grad, x_mps.grad.to('cpu'))
Kulin Sethe011a8e2022-05-13 18:28:53 +00001672
def testNumbersGPU(self):
    """Run the LeakyReLU CPU/MPS comparison on a small float32 matrix."""
    values = [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]
    for t in [np.float32]:
        features = np.array(values).astype(t)
        self._testLeakyRelu(features, negative_slope=0.1, device="mps")
1679
def test_fill(self):
    """``fill_`` on an MPS tensor must give the same result as on a CPU tensor."""
    cases = [
        (0, [1024], torch.float32),
        (0.2, [2, 3], torch.float32),
        (0.2 + 0.5j, [2, 3], torch.complex64),
    ]
    for val, shape, dtype in cases:
        filled_mps = torch.zeros(shape, device='mps', dtype=dtype).fill_(val)
        filled_cpu = torch.zeros(shape, device='cpu', dtype=dtype).fill_(val)

        self.assertEqual(filled_mps, filled_cpu)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001694
def test_fill_storage_offset(self):
    """``fill_`` through an indexed view must update the same elements on MPS and CPU."""
    # Scalar fill value applied to row 1 of a 2 x 10 tensor.
    shape = [2, 10]
    val = 0.2
    base_mps = torch.ones(shape, device="mps")
    row_mps = base_mps[:][1].fill_(val)
    base_cpu = torch.ones(shape, device="cpu")
    row_cpu = base_cpu[:][1].fill_(val)

    self.assertEqual(row_mps, row_cpu)
    self.assertEqual(base_mps, base_cpu)

    # Tensor fill value applied to two different columns of a 1 x 10 tensor.
    shape = [1, 10]
    val = 0.0
    base_mps = torch.ones(shape, device="mps")
    fill_mps = torch.tensor(val, device="mps")
    col_mps = base_mps[:, 9].fill_(fill_mps)
    # Regression test for https://github.com/pytorch/pytorch/issues/114692
    base_mps[:, 5].fill_(fill_mps)

    base_cpu = torch.ones(shape, device="cpu")
    fill_cpu = torch.tensor(val, device="cpu")
    col_cpu = base_cpu[:, 9].fill_(fill_cpu)
    base_cpu[:, 5].fill_(fill_cpu)

    self.assertEqual(col_mps.to(device="cpu"), col_cpu)
    self.assertEqual(base_mps.to(device="cpu"), base_cpu)
Li-Huai (Allan) Lin25ee6dd2023-02-18 16:19:15 +00001720
def test_cdist_large(self, device="mps"):
    """``torch.cdist`` (p=2) on 100 x 10 inputs must match the brute-force norm for every compute mode."""
    compute_modes = (
        'use_mm_for_euclid_dist_if_necessary',
        'use_mm_for_euclid_dist',
        'donot_use_mm_for_euclid_dist',
    )
    for mode in compute_modes:
        lhs = torch.randn(100, 10, device=device)
        rhs = torch.randn(100, 10, device=device)
        actual = torch.cdist(lhs, rhs, p=2, compute_mode=mode)
        expected = self._brute_cdist(lhs, rhs, p=2)
        self.assertEqual(expected, actual)
1728
def test_cdist_large_batch(self, device="mps"):
    """Batched variant of the large cdist check, on 4 x 3 x 100 x 10 inputs."""
    compute_modes = (
        'use_mm_for_euclid_dist_if_necessary',
        'use_mm_for_euclid_dist',
        'donot_use_mm_for_euclid_dist',
    )
    for mode in compute_modes:
        lhs = torch.randn(4, 3, 100, 10, device=device)
        rhs = torch.randn(4, 3, 100, 10, device=device)
        actual = torch.cdist(lhs, rhs, p=2, compute_mode=mode)
        expected = self._brute_cdist(lhs, rhs, p=2)
        self.assertEqual(expected, actual)
1736
def test_cdist_non_contiguous(self, device="mps"):
    """``torch.cdist`` must match the brute-force norm when either operand is non-contiguous."""
    # Each case: (x factory, y factory, x expected contiguous, y expected contiguous).
    cases = [
        (lambda: torch.randn(5, 7, device=device).mT,
         lambda: torch.randn(5, 3, device=device).mT, False, False),
        (lambda: torch.randn(7, 5, device=device),
         lambda: torch.randn(5, 3, device=device).t(), True, False),
        (lambda: torch.randn(5, 7, device=device).t(),
         lambda: torch.randn(3, 5, device=device), False, True),
    ]
    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
        for make_x, make_y, x_contiguous, y_contiguous in cases:
            x = make_x()
            y = make_y()
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)

            # Confirm the layout each case is meant to exercise.
            check_x = self.assertTrue if x_contiguous else self.assertFalse
            check_y = self.assertTrue if y_contiguous else self.assertFalse
            check_x(x.is_contiguous())
            check_y(y.is_contiguous())
            self.assertEqual(expected, actual)
1762
def test_cdist_non_contiguous_batch(self, device="mps"):
    """Batched variant of the non-contiguous cdist check."""
    # Each case: (x factory, y factory, x expected contiguous, y expected contiguous).
    cases = [
        (lambda: torch.randn(4, 3, 2, 5, 7, device=device).mT,
         lambda: torch.randn(4, 3, 2, 5, 3, device=device).mT, False, False),
        (lambda: torch.randn(7, 2, 7, 5, device=device),
         lambda: torch.randn(7, 2, 5, 3, device=device).mT, True, False),
        (lambda: torch.randn(4, 5, 7, device=device).mT,
         lambda: torch.randn(4, 3, 5, device=device), False, True),
    ]
    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
        for make_x, make_y, x_contiguous, y_contiguous in cases:
            x = make_x()
            y = make_y()
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)

            # Confirm the layout each case is meant to exercise.
            check_x = self.assertTrue if x_contiguous else self.assertFalse
            check_y = self.assertTrue if y_contiguous else self.assertFalse
            check_x(x.is_contiguous())
            check_y(y.is_contiguous())
            self.assertEqual(expected, actual)
1788
def test_cdist_euclidean_large(self, device="mps"):
    """Forward and backward of p=2 ``torch.cdist`` on 2000 x 5 inputs must run without error."""
    size = (2000, 5)
    x = torch.randn(size, device=device, dtype=torch.float)
    y = torch.randn(size, device=device, dtype=torch.float)

    # to avoid extremum
    eps = 1e-6
    shift = ((x - y) < eps).float() * 2 * eps
    x = x - shift

    x.requires_grad = True
    y.requires_grad = True
    dist = torch.cdist(x, y, p=2)
    # Do a backward pass to check that it is valid for large
    # matrices
    dist.sum().backward()
1807
def test_cdist_same_inputs(self, device="mps"):
    """The cdist gradient must stay finite when both inputs are identical.

    Test to detect issues in cdist gradient calculation when the distances are 0.
    """
    sizex = (1, 27, 32)
    # NOTE(review): ``p`` is iterated but never forwarded to ``torch.cdist``,
    # so every pass uses the default norm. Behaviour is kept as-is here;
    # passing ``p=p`` would change which backward path runs — confirm MPS
    # supports it before changing.
    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
        x = torch.randn(sizex, device=device, dtype=torch.float)
        dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
        y = x.clone()
        x.requires_grad = True
        d = torch.cdist(x, y)
        d.backward(dist_grad)
        # Check that the backward pass does not contain invalid
        # values such as nan or inf. A unittest assertion is used because
        # a bare `assert` is removed when Python runs with -O.
        self.assertTrue(torch.isfinite(x.grad).all())
1823
1824
1825 def _brute_cdist(self, x, y, p=2):
1826 r1 = x.shape[-2]
1827 r2 = y.shape[-2]
1828 if r1 == 0 or r2 == 0:
1829 return torch.empty(r1, r2, device=x.device)
1830 return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
1831
def test_cdist_norm(self, device="mps"):
    """``torch.cdist`` must match the brute-force norm over a grid of row counts, widths and p values."""
    row_counts_x = [3, 4]
    widths = [2, 3]
    row_counts_y = [4, 6]
    norms = [0, 1, 1.5, 2.5, float('inf')]
    # product() iterates in the same order as the equivalent nested loops.
    for r1, m, r2, p in itertools.product(row_counts_x, widths, row_counts_y, norms):
        x = torch.randn(r1, m, device=device)
        y = torch.randn(r2, m, device=device)
        if p != 2:
            actual = torch.cdist(x, y, p=p)
            expected = self._brute_cdist(x, y, p=p)
            self.assertEqual(expected, actual)
            continue
        # NOTE(review): 2 is not in ``norms`` above, so this euclidean path
        # is currently never reached — confirm whether p=2 should be listed.
        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertEqual(expected, actual, rtol=0, atol=0.02)
1848
def test_cdist_norm_batch(self, device="mps"):
    """Batched variant of the cdist norm check, with leading dims 2 x 3 x 6."""
    row_counts_x = [3, 4]
    widths = [2, 3]
    row_counts_y = [4, 6]
    norms = [0, 3, 1.5, 2.5, float('inf')]
    # product() iterates in the same order as the equivalent nested loops.
    for r1, m, r2, p in itertools.product(row_counts_x, widths, row_counts_y, norms):
        x = torch.randn(2, 3, 6, r1, m, device=device)
        y = torch.randn(2, 3, 6, r2, m, device=device)
        if p != 2:
            actual = torch.cdist(x, y, p=p)
            expected = self._brute_cdist(x, y, p=p)
            self.assertEqual(expected, actual)
            continue
        # NOTE(review): 2 is not in ``norms`` above, so this euclidean path
        # is currently never reached — confirm whether p=2 should be listed.
        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertEqual(expected, actual, rtol=0, atol=0.02)
1865
def test_mm(self):
    """Ones(5, 6) @ ones(6, 5) on MPS must be a 5 x 5 matrix filled with 6."""
    lhs = torch.ones(5, 6).to("mps")
    rhs = torch.ones(6, 5).to("mps")
    product = torch.mm(lhs, rhs).cpu()
    expected = torch.full((5, 5), 6.0)
    torch.testing.assert_close(product, expected)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001871
def test_linalg_cross(self):
    """``torch.linalg.cross`` on MPS must match CPU, with and without ``out=``.

    Covers int32, int64 and float32, each on same-shape and on broadcastable
    operands.
    """
    device = "mps"

    def make_operand(shape, dtype):
        # Integer dtypes are drawn with randint; everything else with rand.
        if dtype is torch.int32 or dtype is torch.int64:
            return torch.randint(0, 99999, shape, dtype=dtype, device=device)
        return torch.rand(shape, dtype=dtype, device=device)

    def check(x_shape, y_shape, dtype):
        x = make_operand(x_shape, dtype)
        y = make_operand(y_shape, dtype)
        x_cpu = x.to("cpu")
        y_cpu = y.to("cpu")

        res1 = torch.linalg.cross(x, y, dim=1)
        res2 = torch.tensor((), dtype=dtype, device=device)
        res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
        res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
        # The ``out=`` variants write into the initially empty tensors.
        torch.linalg.cross(x, y, dim=1, out=res2)
        torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)

        self.assertEqual(res1, res2)
        self.assertEqual(res1, res1_cpu)
        self.assertEqual(res2, res2_cpu)

    # A plain loop replaces the list comprehension that was used only for
    # its side effects.
    for dtype in [torch.int32, torch.int64, torch.float32]:
        check((100, 3, 100), (100, 3, 100), dtype)
        # test for broadcastable inputs
        check((1, 3, 2), (4, 3, 1), dtype)
1912
def test_cross(self):
    """``torch.cross`` along dim 1 must agree between MPS and CPU."""
    lhs_mps = torch.randn(4, 3, device="mps")
    rhs_mps = torch.randn(4, 3, device="mps")
    lhs_cpu, rhs_cpu = lhs_mps.to("cpu"), rhs_mps.to("cpu")

    actual = torch.cross(lhs_mps, rhs_mps, dim=1)
    expected = torch.cross(lhs_cpu, rhs_cpu, dim=1)
    self.assertEqual(actual, expected)
1921
def test_addmm(self):
    """ones(5, 5) + ones(5, 6) @ ones(6, 5) on MPS must be a 5 x 5 matrix filled with 7."""
    bias = torch.ones(5, 5).to("mps")
    lhs = torch.ones(5, 6).to("mps")
    rhs = torch.ones(6, 5).to("mps")
    result = torch.addmm(bias, lhs, rhs).to("cpu")
    expected = torch.full((5, 5), 7.0)
    torch.testing.assert_close(result, expected)
Kulin Sethe011a8e2022-05-13 18:28:53 +00001928
def test_bmm(self):
    """``torch.bmm`` must give the same values and size on MPS as on CPU."""
    lhs_cpu = torch.randn(10, 3, 4)
    rhs_cpu = torch.randn(10, 4, 5)

    # Independent MPS copies of the CPU operands.
    lhs_mps = lhs_cpu.detach().clone().to("mps")
    rhs_mps = rhs_cpu.detach().clone().to("mps")

    expected = torch.bmm(lhs_cpu, rhs_cpu)
    actual = torch.bmm(lhs_mps, rhs_mps)

    self.assertEqual(expected, actual)
    self.assertEqual(expected.size(), actual.size())
1941
@xfailIf(product_version < 15.0)
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_large_bmm(self, dtype):
    """``torch.bmm`` on 11 x 20064 x 128 by 11 x 128 x 20064 operands must match CPU."""
    lhs = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
    rhs = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
    expected = torch.bmm(lhs.cpu(), rhs.cpu())
    actual = torch.bmm(lhs, rhs)

    # Using the low precision comparison for FP16
    if dtype == torch.float16:
        tol = 1e-2
    else:
        tol = None
    self.assertEqual(expected, actual, atol=tol, rtol=tol)
    self.assertEqual(expected.size(), actual.size())
1954
1955
def test_addr(self):
    """ones(5, 10) plus the outer product of two ones vectors on MPS must be all 2."""
    base = torch.ones(5, 10).to("mps")
    column = torch.ones(5).to("mps")
    row = torch.ones(10).to("mps")
    result = torch.addr(base, column, row).to("cpu")
    expected = torch.full((5, 10), 2.0)
    torch.testing.assert_close(result, expected)
1962
def test_trace(self):
    """``torch.trace`` of a random 3 x 3 matrix must agree between MPS and CPU."""
    matrix_cpu = torch.randn(3, 3)
    matrix_mps = matrix_cpu.detach().clone().to("mps")

    expected = torch.trace(matrix_cpu)
    actual = torch.trace(matrix_mps)

    self.assertEqual(expected, actual)
    self.assertEqual(expected.size(), actual.size())
1972
def test_addbmm(self):
    """``torch.addbmm`` must give the same values and size on MPS as on CPU."""
    cpu_args = (
        torch.randn(3, 5),
        torch.randn(10, 3, 4),
        torch.randn(10, 4, 5),
    )
    # Independent MPS copies of the three CPU operands, in the same order.
    mps_args = tuple(t.detach().clone().to("mps") for t in cpu_args)

    expected = torch.addbmm(*cpu_args)
    actual = torch.addbmm(*mps_args)

    self.assertEqual(expected, actual)
    self.assertEqual(expected.size(), actual.size())
1987
def test_baddbmm(self):
    """``torch.baddbmm`` with alpha=1.2 and beta=0.8 must agree between MPS and CPU."""
    alpha = 1.2
    beta = 0.8
    # (input shape, batch1 shape, batch2 shape)
    shape_cases = [
        ((3, 5), (10, 3, 4), (10, 4, 5)),
        ((10, 3, 5), (10, 3, 4), (10, 4, 5)),
        ((1, 77, 77), (8, 77, 64), (8, 64, 77)),
    ]
    for input_shape, batch1_shape, batch2_shape in shape_cases:
        bias_cpu = torch.randn(input_shape)
        lhs_cpu = torch.randn(batch1_shape)
        rhs_cpu = torch.randn(batch2_shape)

        bias_mps = bias_cpu.detach().clone().to("mps")
        lhs_mps = lhs_cpu.detach().clone().to("mps")
        rhs_mps = rhs_cpu.detach().clone().to("mps")

        expected = torch.baddbmm(bias_cpu, lhs_cpu, rhs_cpu, beta=beta, alpha=alpha)
        actual = torch.baddbmm(bias_mps, lhs_mps, rhs_mps, beta=beta, alpha=alpha)

        self.assertEqual(expected, actual)
        self.assertEqual(expected.size(), actual.size())
2008 helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))
Kulin Sethe011a8e2022-05-13 18:28:53 +00002009
def test_local_scalar_dense_mps(self):
    """``.item()`` on a one-element MPS tensor must return the same value as on CPU."""
    cpu_scalar = torch.randn(1)
    mps_scalar = cpu_scalar.to("mps")
    torch.testing.assert_close(cpu_scalar.item(), mps_scalar.item())
Kulin Sethe011a8e2022-05-13 18:28:53 +00002014
def test_linear_1d_weight(self):
    """``F.linear`` with a weight of shape [8] and then [1, 8] must agree between CPU and MPS."""
    device = 'cpu'
    for weight_shape in ([8], [1, 8]):
        weight = torch.rand(weight_shape).to(device)
        features = torch.rand([1, 2, 2, 8]).to(device)

        features_mps = features.to('mps')
        weight_mps = weight.to('mps')

        expected = F.linear(features, weight)
        actual = F.linear(features_mps, weight_mps)

        self.assertEqual(expected, actual)
2034
def test_linear_bias(self):
    """A Linear layer whose bias is replaced by a ``()`` or ``(2, 4)`` shaped parameter must agree between CPU and MPS."""
    for bias_shape in [(), (2, 4)]:
        features = torch.randn(2, 2, 2, 64, device="cpu")
        layer = torch.nn.Linear(64, 4, device="cpu")
        # Swap in a bias with the shape under test.
        layer.bias = torch.nn.Parameter(torch.randn(bias_shape, dtype=torch.float32, device="cpu"))
        expected = layer(features)

        features_mps = features.to("mps")
        # The return value is not needed; the same module object is reused below.
        layer.to("mps")
        actual = layer(features_mps)
        self.assertEqual(expected, actual)
2050
def test_linear_errors(self):
    """F.linear must raise for non-float MPS weights and for mixed CPU/MPS operands."""
    size = (3, 3)
    linear = torch.nn.functional.linear

    def rand_on(device):
        return torch.rand(size, device=device)

    # Unsupported dtypes
    with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
        float_input = rand_on('mps')
        int8_weight = torch.randint(-10, 10, size, dtype=torch.int8, device='mps')
        linear(float_input, int8_weight)

    # Mixed CPU<->MPS tensors: (input device, weight device, expected message)
    mixed_cases = (
        # Weights on wrong device
        ('mps', 'cpu', "argument weight is on cpu but expected on mps"),
        # Input on wrong device
        ('cpu', 'mps', "argument input is on cpu but expected on mps"),
    )
    for input_device, weight_device, message in mixed_cases:
        with self.assertRaisesRegex(RuntimeError, message):
            linear(rand_on(input_device), rand_on(weight_device))
2069
def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
    """Compare ``nn.Linear`` on CPU and MPS for one random input.

    Args:
        in_features: ``in_features`` passed to both layers.
        out_features: ``out_features`` passed to both layers.
        shape: shape of the random input tensor.
        bias: whether both layers are built with a bias term.
        backward_pass: when true, also compare the first-order gradients and
            the gradients obtained by back-propagating through them.
    """
    layer_kwargs = dict(in_features=in_features, out_features=out_features, bias=bias)
    ref_layer = torch.nn.Linear(device="cpu", **layer_kwargs)
    mps_layer = torch.nn.Linear(device="mps", **layer_kwargs)

    # Use the same weights and bias as the ones from the cpu
    mps_layer.weight.data = ref_layer.weight.data.detach().clone().to("mps")
    if bias:
        mps_layer.bias.data = ref_layer.bias.data.detach().clone().to("mps")

    x_mps = torch.randn(shape).to('mps')
    x_ref = x_mps.detach().clone().to('cpu')
    if backward_pass:
        x_mps = x_mps.requires_grad_()
        x_ref = x_ref.requires_grad_()

    out_ref = ref_layer(x_ref)
    out_mps = mps_layer(x_mps)

    self.assertEqual(out_ref, out_mps.to('cpu'))
    self.assertEqual(out_ref.size(), out_mps.size())

    if not backward_pass:
        return

    # First-order pass; create_graph=True keeps the gradients differentiable
    # so that they can be back-propagated through below.
    seed_ref = torch.rand_like(out_ref, requires_grad=True)
    seed_mps = seed_ref.detach().to('mps').requires_grad_()

    out_ref.backward(gradient=seed_ref, create_graph=True)
    out_mps.backward(gradient=seed_mps, create_graph=True)

    tolerance = dict(atol=8e-04, rtol=10.4e-05)
    self.assertEqual(x_ref.grad.size(), x_mps.grad.size())
    self.assertEqual(x_ref.grad, x_mps.grad.to("cpu"), **tolerance)

    self.assertEqual(ref_layer.weight.grad.size(), mps_layer.weight.grad.size())
    self.assertEqual(ref_layer.weight.grad, mps_layer.weight.grad.to("cpu"), **tolerance)
    if bias:
        self.assertEqual(ref_layer.bias.grad.size(), mps_layer.bias.grad.size())
        self.assertEqual(ref_layer.bias.grad, mps_layer.bias.grad.to("cpu"), **tolerance)

    # test gradgrad
    x_cotangent = torch.rand_like(x_ref)
    x_cotangent_mps = x_cotangent.to("mps")
    w_cotangent = torch.rand_like(ref_layer.weight)
    w_cotangent_mps = w_cotangent.to("mps")

    # Clear the accumulated values in place before the second pass.
    x_ref.grad.detach().zero_()
    x_mps.grad.detach().zero_()
    ref_layer.weight.grad.detach().zero_()
    mps_layer.weight.grad.detach().zero_()
    if bias:
        b_cotangent = torch.rand_like(ref_layer.bias)
        b_cotangent_mps = b_cotangent.to("mps")
        ref_layer.bias.grad.detach().zero_()
        mps_layer.bias.grad.detach().zero_()

    x_ref.grad.backward(x_cotangent, retain_graph=True)
    x_mps.grad.backward(x_cotangent_mps, retain_graph=True)
    ref_layer.weight.grad.backward(w_cotangent, retain_graph=True)
    mps_layer.weight.grad.backward(w_cotangent_mps, retain_graph=True)
    if bias:
        ref_layer.bias.grad.backward(b_cotangent, retain_graph=True)
        mps_layer.bias.grad.backward(b_cotangent_mps, retain_graph=True)

    self.assertEqual(seed_ref.grad, seed_mps.grad)
    self.assertEqual(x_ref.grad, x_mps.grad)
    self.assertEqual(ref_layer.weight.grad, mps_layer.weight.grad)
    if bias:
        self.assertEqual(ref_layer.bias.grad, mps_layer.bias.grad)
2138
def test_linear1D(self):
    """Forward-only CPU/MPS check of a biased Linear(2, 3) on a 1-D input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=[2],
        bias=True,
        backward_pass=False,
    )
2141
def test_linear1D_backward(self):
    """Forward and backward CPU/MPS check of a biased Linear(2, 3) on a 1-D input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=[2],
        bias=True,
        backward_pass=True,
    )
2144
def test_linear2D(self):
    """Forward-only CPU/MPS check of a biased Linear(2, 3) on a (4, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 2),
        bias=True,
        backward_pass=False,
    )
2147
def test_linear2D_backward(self):
    """Forward and backward CPU/MPS check of a biased Linear(2, 3) on a (4, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 2),
        bias=True,
        backward_pass=True,
    )
2150
def test_linear2D_no_bias(self):
    """Forward-only CPU/MPS check of a bias-free Linear(2, 3) on a (4, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 2),
        bias=False,
        backward_pass=False,
    )
2153
def test_linear2D_no_bias_backward(self):
    """Forward and backward CPU/MPS check of a bias-free Linear(2, 3) on a (4, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 2),
        bias=False,
        backward_pass=True,
    )
2156
def test_linear3D(self):
    """Forward-only CPU/MPS check of a biased Linear(2, 3) on a (4, 5, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 5, 2),
        bias=True,
        backward_pass=False,
    )
Kulin Sethe011a8e2022-05-13 18:28:53 +00002159
def test_linear3D_backward(self):
    """Forward and backward CPU/MPS check of a biased Linear(2, 3) on a (4, 5, 2) input."""
    self._linear_helper(
        in_features=2,
        out_features=3,
        shape=(4, 5, 2),
        bias=True,
        backward_pass=True,
    )
Kulin Sethe011a8e2022-05-13 18:28:53 +00002162
def test_linear3D_no_bias(self):
    """Forward-only CPU/MPS check of a bias-free Linear(2, 3) on a (4, 5, 2) input.

    The layer is built with ``bias=False`` so that this test really covers
    the no-bias path; with ``bias=True`` it would merely repeat
    ``test_linear3D``.
    """
    self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=False, backward_pass=False)
Kulin Sethe011a8e2022-05-13 18:28:53 +00002165
def test_linear3D_no_bias_backward(self):
    """Forward and backward CPU/MPS check of a bias-free Linear(2, 3) on a (4, 5, 2) input.

    The layer is built with ``bias=False`` so that this test really covers
    the no-bias path; with ``bias=True`` it would merely repeat
    ``test_linear3D_backward``.
    """
    self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=False, backward_pass=True)
Kulin Sethe011a8e2022-05-13 18:28:53 +00002168
def test_uniform(self):
    """Exercise ``Uniform``: sample shapes, out-of-range log_prob/cdf and rsample gradients."""
    lo = torch.zeros(5, 5, requires_grad=True)
    hi = (torch.ones(5, 5) * 3).requires_grad_()
    lo_1d = torch.zeros(1, requires_grad=True)
    hi_1d = (torch.ones(1) * 3).requires_grad_()

    # (bounds, sample_shape, expected size of the drawn sample)
    shape_cases = (
        ((lo, hi), (), (5, 5)),
        ((lo, hi), (7,), (7, 5, 5)),
        ((lo_1d, hi_1d), (), (1,)),
        ((lo_1d, hi_1d), (1,), (1, 1)),
        ((0.0, 1.0), (1,), (1,)),
    )
    for bounds, sample_shape, expected_size in shape_cases:
        drawn = Uniform(*bounds).sample(sample_shape)
        self.assertEqual(drawn.size(), expected_size)

    # Check log_prob computation when value outside range
    dist = Uniform(lo_1d, hi_1d, validate_args=False)
    too_high = torch.tensor([4.0])
    too_low = torch.tensor([-1.0])
    for outside in (too_high, too_low):
        self.assertEqual(dist.log_prob(outside).item(), -inf)

    # check cdf computation when value outside range
    self.assertEqual(dist.cdf(too_low).item(), 0)
    self.assertEqual(dist.cdf(too_high).item(), 1)

    # Draw the noise once by hand, then rewind the generator so that rsample()
    # starts from the same RNG state and the gradients can be checked against it.
    saved_state = torch.get_rng_state()
    noise = lo.new(lo.size()).uniform_()
    torch.set_rng_state(saved_state)
    sample = Uniform(lo, hi).rsample()
    sample.backward(torch.ones_like(sample))
    self.assertEqual(lo.grad, 1 - noise)
    self.assertEqual(hi.grad, noise)
    lo.grad.zero_()
    hi.grad.zero_()
Kulin Sethe011a8e2022-05-13 18:28:53 +00002200
def test_randperm(self, device="mps"):
    """Check that ``torch.randperm`` produces valid permutations on ``device``.

    Covers several sizes and dtypes (functional and ``out=`` variants), the
    default dtype, the zero-length case and a non-contiguous ``out=`` tensor.

    Args:
        device: device on which the permutations are generated.
    """
    rng_device = None

    def assert_is_permutation(perm, n):
        # A permutation of range(n), once sorted, is exactly arange(n).
        self.assertEqual(perm.cpu().sort().values.long(), torch.arange(n, device=device))

    for n in (5, 100, 50000, 100000):
        # NOTE: torch.bfloat16 is not in this tuple, so the former
        # "n > 256 and dtype == torch.bfloat16" guard could never fire and was dropped.
        for dtype in (torch.long, torch.half, torch.float):
            if n > 2049 and dtype == torch.half:  # Large n for torch.half will raise an exception, do not test here.
                continue
            with torch.random.fork_rng(devices=rng_device):
                res1 = torch.randperm(n, dtype=dtype, device=device)
            res2 = torch.empty(0, dtype=dtype, device=device)
            torch.randperm(n, out=res2, dtype=dtype, device=device)
            assert_is_permutation(res1, n)
            # The out= variant used to be computed but never verified.
            assert_is_permutation(res2, n)

    # Default type is long
    for n in (100, 10000):
        self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)

    # randperm of 0 elements is an empty tensor
    res1 = torch.randperm(0)
    # torch.float is the value the leaked loop variable `dtype` had at this point.
    res2 = torch.tensor(5, dtype=torch.float, device=device)
    torch.randperm(0, out=res2)
    self.assertEqual(res1.numel(), 0)
    self.assertEqual(res2.numel(), 0)

    # Test non-contiguous tensors
    for n in (4, 5, 6, 10, 20):
        non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
        self.assertFalse(non_contiguous_tensor.is_contiguous())
        with torch.random.fork_rng(devices=rng_device):
            res = torch.randperm(n, dtype=torch.long, device=device)
        # Only required not to crash; the contents written here are not checked.
        torch.randperm(n, out=non_contiguous_tensor)
        assert_is_permutation(res, n)
2234
Kulin Sethe011a8e2022-05-13 18:28:53 +00002235 # Test forward maxpool2d
def test_max_pool2d(self):
    """Compare ``MaxPool2d`` forward and backward results between CPU and MPS."""
    def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
        # An all-ones input makes every element of a window tie for the maximum.
        make_input = torch.ones if test_ties else torch.randn
        cpu_x = make_input(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
                                  ceil_mode=ceil_mode, return_indices=return_indices)

        with_indices = return_indices is not False
        if with_indices:
            y, idx = pool(x)
            ref_y, ref_idx = pool(cpu_x)
        else:
            y = pool(x)
            ref_y = pool(cpu_x)

        cpu_grad = torch.ones_like(ref_y)
        grad = cpu_grad.to('mps')

        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y)
        if with_indices:
            self.assertEqual(idx, ref_idx)
        self.assertEqual(x.grad, cpu_x.grad)

    # Each (1, 1000, 1, 4) entry is the max_pool1d-like counterpart of the entry before it.
    configs = (
        # Test with no batch dimension
        ((8, 4, 4), dict(ks=2)),
        ((2, 8, 4, 4), dict(ks=2)),
        ((1, 1000, 32, 32), dict(ks=4)),
        ((1, 1000, 1, 4), dict(ks=(1, 4))),
        # Test padding
        ((1, 1000, 32, 32), dict(ks=4, padding=1)),
        ((1, 1000, 1, 4), dict(ks=(1, 4), padding=(0, 1))),
        # Test dilation
        ((1, 1000, 32, 32), dict(ks=4, dilation=2)),
        ((1, 1000, 1, 4), dict(ks=(1, 4), padding=(0, 2))),
        # Test ceil mode
        ((1, 1000, 32, 32), dict(ks=4, ceil_mode=True)),
        ((1, 1000, 1, 4), dict(ks=(1, 4), ceil_mode=True)),
    )

    for shape, kwargs in configs:
        helper(shape, **kwargs)

    # Test return indices
    for test_ties in [False, True]:
        for shape, kwargs in configs:
            helper(shape, return_indices=True, test_ties=test_ties, **kwargs)
2309
def test_adaptive_avg_pool2d_output_size_one(self):
    """AdaptiveAvgPool2d((1, 1)) on MPS must equal the spatial mean and have the expected strides."""
    def helper(size, memory_format):
        x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
        # The string 'non_contiguous' selects a strided slice instead of a memory format.
        x = x[::2, ::2, ::2, ::2] if memory_format == 'non_contiguous' else x.to(memory_format=memory_format)

        pooled = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
        batch, channels = x.size(0), x.size(1)
        expected = x.contiguous().mean((-1, -2)).view((batch, channels, 1, 1))

        pooled.sum().backward()    # make sure it doesn't crash

        self.assertEqual(pooled, expected)
        c = pooled.size(1)
        if memory_format == torch.channels_last:
            self.assertTrue(pooled.is_contiguous(memory_format=torch.channels_last))
            self.assertEqual(pooled.stride(), [c, 1, c, c])
        else:
            self.assertTrue(pooled.is_contiguous())
            self.assertEqual(pooled.stride(), [c, 1, 1, 1])

    helper((2, 3, 6, 6), torch.contiguous_format)
2335
def test_masked_scatter(self):
    """In-place ``masked_scatter_`` must give the same result on MPS and CPU."""
    def check(shape):
        # Everything is generated on MPS and copied to CPU so both sides see the same data.
        source_mps = torch.randn(shape, device="mps")
        source_cpu = source_mps.detach().clone().cpu()

        mask_mps = torch.rand(shape, device="mps") < 0.6
        mask_cpu = mask_mps.detach().clone().cpu()

        target_mps = torch.randn(shape, device="mps")
        target_cpu = target_mps.detach().clone().cpu()

        target_mps.masked_scatter_(mask_mps, source_mps)
        target_cpu.masked_scatter_(mask_cpu, source_cpu)

        self.assertEqual(target_mps, target_cpu)

    for shape in ([2, 5], [10, 10], [5, 10, 3], [10, 5, 10, 3], [10, 5, 10, 3, 20]):
        check(shape)
2356
def test_masked_fill(self):
    """``masked_fill_`` on MPS must exactly match a CPU reference filled element by element."""
    num_dest = 10

    filled_mps = torch.zeros(num_dest, dtype=torch.float32, device="mps")
    mask = torch.randint(2, (num_dest,), dtype=torch.bool, device="mps")
    fill_value = random.random()
    expected = torch.zeros(num_dest, dtype=torch.float32)
    mask_cpu = mask.to("cpu")

    filled_mps.masked_fill_(mask, fill_value)
    # Build the reference by hand, one selected position at a time.
    for index, selected in enumerate(mask_cpu.tolist()):
        if selected:
            expected[index] = fill_value
    self.assertEqual(filled_mps.to("cpu"), expected, atol=0, rtol=0)
Kulin Seth3d833212022-05-20 03:18:09 +00002374
def test_masked_fill__non_contiguous(self):
    """``masked_fill_`` with an all-False mask must leave a transposed MPS tensor untouched."""
    shape = (3, 5)
    neg_inf = float("-inf")

    data_mps = torch.randn(shape, device="mps")
    data_cpu = data_mps.detach().clone().cpu()
    # The mask is all zeros, so no element is selected for filling.
    mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
    mask_cpu = mask_mps.detach().clone().cpu()

    view_mps = data_mps.T
    view_cpu = data_cpu.T

    view_mps.masked_fill_(mask_mps.T, neg_inf)
    view_cpu.masked_fill_(mask_cpu.T, neg_inf)

    self.assertEqual(view_mps, view_cpu)
    self.assertFalse((view_mps == neg_inf).any())
Kulin Seth3d833212022-05-20 03:18:09 +00002391
def test_nhwc_operation(self):
    """A channels-last CPU tensor copied to MPS must compare equal to its source."""
    import numpy as np

    def check(shape, channels_last=False):
        np.random.seed(332)
        # Values are drawn from [128, 256).
        values = (256 - 128) * np.random.random_sample(size=shape) + 128
        cpu_x = torch.tensor(values, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # This passes
        self.assertEqual(x, cpu_x)

    check((2, 2, 2, 2), True)
2407
Kulin Sethe011a8e2022-05-13 18:28:53 +00002408 # Test forward batch norm
def test_batch_norm(self):
    """Compare batch norm between CPU and MPS, forward and backward.

    Both the functional form and the ``BatchNorm1d/2d/3d`` modules are
    exercised, with and without running statistics and affine parameters.
    """
    import numpy as np

    # Module class selected by the rank of the input shape.
    module_by_rank = {3: torch.nn.BatchNorm1d, 4: torch.nn.BatchNorm2d, 5: torch.nn.BatchNorm3d}

    def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
               track_running_stats=True, test_module=False):
        np.random.seed(332)
        arr = (256 - 128) * np.random.random_sample(size=shape) + 128
        cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        stat_shape = [shape[1]]
        cpu_running_mean = cpu_running_var = None
        running_mean = running_var = None
        if track_running_stats:
            mean_arr = (240 - 140) * np.random.random_sample(size=stat_shape) + 140
            cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
            var_arr = 32 * np.random.random_sample(size=stat_shape)
            cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
            running_mean = cpu_running_mean.detach().clone().to('mps')
            running_var = cpu_running_var.detach().clone().to('mps')

        weight = cpu_weight = None
        bias = cpu_bias = None
        if wts:
            cpu_weight = torch.randn(stat_shape, device='cpu', dtype=torch.float, requires_grad=True)
            weight = cpu_weight.detach().clone().to('mps').requires_grad_()
            cpu_bias = torch.randn(stat_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        batchnorm_op = None
        mps_batchnorm_op = None
        if test_module:
            norm_cls = module_by_rank.get(len(shape))
            if norm_cls is not None:
                module_kwargs = dict(eps=eps, momentum=momentum, affine=wts,
                                     track_running_stats=track_running_stats)
                batchnorm_op = norm_cls(shape[1], device='cpu', **module_kwargs)
                mps_batchnorm_op = norm_cls(shape[1], device='mps', **module_kwargs)

            # Overwrite the freshly initialised statistics/parameters with the generated ones.
            if track_running_stats:
                batchnorm_op.running_mean = cpu_running_mean
                batchnorm_op.running_var = cpu_running_var
                mps_batchnorm_op.running_mean = running_mean
                mps_batchnorm_op.running_var = running_var
            if wts:
                batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
                batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
                mps_batchnorm_op.weight = torch.nn.Parameter(weight)
                mps_batchnorm_op.bias = torch.nn.Parameter(bias)

            ref_y = batchnorm_op(cpu_x)
            y = mps_batchnorm_op(x)
        else:
            functional_kwargs = dict(training=training, momentum=momentum, eps=eps)
            y = torch.nn.functional.batch_norm(x, running_mean, running_var,
                                               weight=weight, bias=bias, **functional_kwargs)
            ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                   weight=cpu_weight, bias=cpu_bias, **functional_kwargs)

        self.assertEqual(y, ref_y)
        if test_module:
            self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
            self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)
        else:
            self.assertEqual(running_mean, cpu_running_mean)
            self.assertEqual(running_var, cpu_running_var)

        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        ref_y.backward(gradient=cpu_grad)
        y.backward(gradient=grad)

        self.assertEqual(x.grad, cpu_x.grad)
        if wts:
            if test_module:
                self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
                self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)
            else:
                self.assertEqual(weight.grad, cpu_weight.grad)
                self.assertEqual(bias.grad, cpu_bias.grad)

    # (eps, momentum, wts) combinations run in eval mode and then in training mode.
    param_grid = ((1e-05, 0.1, False), (0, 1.0, False), (1, 1, True), (3, 0.67, True))
    shapes = [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]

    for shape, test_module, track_running_stats, channels_last in itertools.product(
            shapes, [False, True], [True, False], [False]):
        if channels_last and len(shape) != 4:
            continue
        shared = dict(channels_last=channels_last,
                      track_running_stats=track_running_stats,
                      test_module=test_module)
        # Running stats must be tracked in eval mode
        if track_running_stats:
            helper(shape, eps=0, momentum=1, **shared)
            helper(shape, **shared)
            training_modes = [False, True]
        else:
            training_modes = [True]
        for training in training_modes:
            for eps, momentum, wts in param_grid:
                helper(shape, eps=eps, momentum=momentum, wts=wts, training=training, **shared)
2569
def test_batch_norm_backward(self):
    """Backward through two stacked BatchNorm2d layers, the second one frozen, must not crash."""
    inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
    trainable = torch.nn.BatchNorm2d(8).to("mps")
    frozen = torch.nn.BatchNorm2d(8).to("mps")
    for param in (frozen.weight, frozen.bias):
        param.requires_grad = False
    outputs = frozen(trainable(inputs))
    # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
    outputs.sum().backward()
2579
Roy Hvaara43f78bf2024-08-20 18:24:48 +00002580 # Regression test for https://github.com/pytorch/pytorch/issues/133520
def test_batch_norm_slices(self):
    """BatchNorm2d applied to a slice of a larger tensor must match between CPU and MPS."""
    norms = {dev: nn.BatchNorm2d(100, affine=False, device=dev) for dev in ('cpu', 'mps')}

    full_cpu = torch.randn(100, 100, 35, 45).to('cpu')
    full_mps = full_cpu.to('mps')

    # Only the slice starting at index 5 along the first dimension is normalised.
    out_cpu = norms['cpu'](full_cpu[5:])
    out_mps = norms['mps'](full_mps[5:])

    self.assertEqual(out_cpu, out_mps)
2592
def test_layer_norm_backward(self):
    """Backward through two stacked LayerNorm layers, the second one frozen, must not crash."""
    inputs = torch.rand(4, 4, device="mps", requires_grad=True)
    trainable = torch.nn.LayerNorm(4).to("mps")
    frozen = torch.nn.LayerNorm(4).to("mps")
    for param in (frozen.weight, frozen.bias):
        param.requires_grad = False
    outputs = frozen(trainable(inputs))
    # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
    outputs.sum().backward()
2602
def test_norm(self):
    """``torch.norm`` must agree between MPS and CPU for several orders and dimensions."""
    def check(mps_tensor, cpu_tensor, *args, **kwargs):
        # Same call on both devices; the MPS result is computed first.
        res = torch.norm(mps_tensor, *args, **kwargs)
        res_cpu = torch.norm(cpu_tensor, *args, **kwargs)
        self.assertEqual(res, res_cpu)

    vec_mps = torch.arange(9, dtype=torch.float, device="mps") - 4
    mat_mps = vec_mps.reshape((3, 3))
    vec_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
    mat_cpu = vec_cpu.reshape((3, 3))

    # Default norm first, then the infinity norm, each on the vector and the matrix.
    for extra_args in ((), (float('inf'),)):
        check(vec_mps, vec_cpu, *extra_args)
        check(mat_mps, mat_cpu, *extra_args)

    rows = [[1, 2, 3], [-1, 1, 4]]
    c_mps = torch.tensor(rows, dtype=torch.float, device="mps")
    c_cpu = torch.tensor(rows, dtype=torch.float, device="cpu")
    for kwargs in (dict(dim=0), dict(dim=1), dict(p=1, dim=1)):
        check(c_mps, c_cpu, **kwargs)

    d_mps = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
    d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
    check(d_mps, d_cpu, dim=(1, 2))

    # Norm of each 2x2 slice taken separately, compared as a pair.
    per_slice_mps = tuple(torch.norm(d_mps[i, :, :]) for i in range(2))
    per_slice_cpu = tuple(torch.norm(d_cpu[i, :, :]) for i in range(2))
    self.assertEqual(per_slice_mps, per_slice_cpu)
2651
def test_linalg_vector_norm(self):
    """``torch.linalg.vector_norm`` must agree between MPS and CPU for ord=0 and ord=3.5."""
    def check(mps_tensor, cpu_tensor, **kwargs):
        res_mps = torch.linalg.vector_norm(mps_tensor, **kwargs)
        res_cpu = torch.linalg.vector_norm(cpu_tensor, **kwargs)
        self.assertEqual(res_mps, res_cpu)

    sparse_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps")
    sparse_cpu = sparse_mps.detach().clone().cpu()
    check(sparse_mps, sparse_cpu, ord=0)

    flat_mps = torch.arange(27, dtype=torch.float, device="mps") - 4
    flat_cpu = torch.arange(27, dtype=torch.float, device="cpu") - 4
    cube_mps = flat_mps.reshape(3, 3, 3)
    cube_cpu = flat_cpu.reshape(3, 3, 3)

    check(flat_mps, flat_cpu, ord=3.5)
    check(cube_mps, cube_cpu, ord=3.5)

    # Reduce over each single dimension of the 3-D tensor in turn.
    for dim in range(cube_mps.dim()):
        check(cube_mps, cube_cpu, ord=3.5, dim=dim)
2678
2679
def test_layer_norm(self):
    """Compare ``LayerNorm`` forward and backward results between CPU and MPS."""
    # TODO: Test non-contiguous
    def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):
        cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        norm_kwargs = dict(eps=eps, elementwise_affine=elementwise_affine, dtype=dtype)
        cpu_op = torch.nn.LayerNorm(normalized_shape, device='cpu', **norm_kwargs)
        mps_op = torch.nn.LayerNorm(normalized_shape, device='mps', **norm_kwargs)

        # Random affine parameters are always generated, but only installed when affine is enabled.
        cpu_wt = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
        wt = cpu_wt.detach().clone().to('mps').requires_grad_()
        cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
        bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        if elementwise_affine:
            cpu_op.weight = torch.nn.Parameter(cpu_wt)
            mps_op.weight = torch.nn.Parameter(wt)
            cpu_op.bias = torch.nn.Parameter(cpu_bias)
            mps_op.bias = torch.nn.Parameter(bias)

        cpu_result = cpu_op(cpu_x)
        result = mps_op(x)

        cpu_grad = torch.randn(cpu_result.shape)
        grad = cpu_grad.to('mps')

        cpu_result.backward(cpu_grad)
        result.backward(grad)

        self.assertEqual(result, cpu_result)
        self.assertEqual(x.grad, cpu_x.grad)
        if elementwise_affine:
            self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad)
            self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad)

    # (input shape, normalized shape) pairs, each run with and without affine parameters.
    shape_pairs = (
        ((2, 2, 2, 2), (2, 2)),
        ((2, 3, 4, 5), (4, 5)),
        ((2, 3, 4, 5, 6), (4, 5, 6)),
    )
    for elementwise_affine in [True, False]:
        for input_shape, normalized_shape in shape_pairs:
            helper(input_shape, normalized_shape, elementwise_affine=elementwise_affine)

    # Regression test for https://github.com/pytorch/pytorch/issues/96113
    float32_norm = torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")
    half_input = torch.randn(1, 2, 16).to("mps", dtype=torch.float16)
    float32_norm(half_input)
2721
@xfailIf(product_version < 14.0)
def test_ifft(self):
    # Round-trip check: a real FFT followed by its inverse must give back the input.
    # See: https://github.com/pytorch/pytorch/issues/124096
    num_samples = 64
    original = torch.rand(num_samples, device=torch.device("mps"))

    # Pass n explicitly so the inverse transform yields as many samples as went in.
    spectrum = torch.fft.rfft(original)
    reconstructed = torch.fft.irfft(spectrum, n=original.shape[0])

    self.assertEqual(reconstructed, original)
2734
# Regression test for https://github.com/pytorch/pytorch/issues/135223
def test_fftfreq(self):
    # The sample frequencies generated on MPS must agree with the CPU reference.
    num_bins = 10**4
    expected = torch.fft.fftfreq(num_bins, device='cpu')
    actual = torch.fft.fftfreq(num_bins, device='mps')
    self.assertEqual(expected, actual)
2740
def test_instance_norm(self):
    # Compares instance normalization on MPS against the CPU reference, through
    # torch.nn.functional.instance_norm (test_module=False) or through the
    # InstanceNorm{1,2,3}d modules (test_module=True). Checked: forward output,
    # running statistics, input gradient and, when wts=True, weight/bias gradients.
    def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
        import numpy as np

        # Fixed NumPy seed: every helper call draws the same input / running-stat values.
        np.random.seed(332)
        input_values = (256 - 128) * np.random.random_sample(size=shape) + 128
        cpu_x = torch.tensor(input_values, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # One statistic / affine coefficient per channel (dimension 1 of the input).
        stat_shape = [shape[1]]

        cpu_running_mean = cpu_running_var = None
        running_mean = running_var = None
        if track_running_stats:
            mean_values = (240 - 140) * np.random.random_sample(size=stat_shape) + 140
            cpu_running_mean = torch.tensor(mean_values, device='cpu', dtype=torch.float)
            var_values = 32 * np.random.random_sample(size=stat_shape)
            cpu_running_var = torch.tensor(var_values, device='cpu', dtype=torch.float)
            running_mean = cpu_running_mean.detach().clone().to('mps')
            running_var = cpu_running_var.detach().clone().to('mps')

        cpu_weight = cpu_bias = None
        weight = bias = None
        if wts:
            cpu_weight = torch.randn(stat_shape, device='cpu', dtype=torch.float, requires_grad=True)
            weight = cpu_weight.detach().clone().to('mps').requires_grad_()
            cpu_bias = torch.randn(stat_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        if test_module:
            # Pick the module flavour from the input rank (3 -> 1d, 4 -> 2d, 5 -> 3d).
            module_by_rank = {
                3: torch.nn.InstanceNorm1d,
                4: torch.nn.InstanceNorm2d,
                5: torch.nn.InstanceNorm3d,
            }
            module_cls = module_by_rank[len(shape)]
            module_kwargs = dict(eps=eps, momentum=momentum, affine=wts,
                                 track_running_stats=track_running_stats)
            instancenorm_op = module_cls(shape[1], device='cpu', **module_kwargs)
            mps_instancenorm_op = module_cls(shape[1], device='mps', **module_kwargs)

            # Replace the default buffers / parameters with the tensors generated above.
            if track_running_stats:
                instancenorm_op.running_mean = cpu_running_mean
                instancenorm_op.running_var = cpu_running_var
                mps_instancenorm_op.running_mean = running_mean
                mps_instancenorm_op.running_var = running_var
            if wts:
                instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
                instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
                mps_instancenorm_op.weight = torch.nn.Parameter(weight)
                mps_instancenorm_op.bias = torch.nn.Parameter(bias)

            ref_y = instancenorm_op(cpu_x)
            y = mps_instancenorm_op(x)
        else:
            ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                      weight=cpu_weight,
                                                      bias=cpu_bias,
                                                      momentum=momentum, eps=eps)
            y = torch.nn.functional.instance_norm(x, running_mean, running_var,
                                                  weight=weight,
                                                  bias=bias,
                                                  momentum=momentum, eps=eps)

        self.assertEqual(y, ref_y)
        if test_module:
            self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
            self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)
        else:
            self.assertEqual(running_mean, cpu_running_mean)
            self.assertEqual(running_var, cpu_running_var)

        # Push the same random upstream gradient through both graphs.
        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        ref_y.backward(gradient=cpu_grad)
        y.backward(gradient=grad)

        self.assertEqual(x.grad, cpu_x.grad)
        if wts:
            if test_module:
                self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
                self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)
            else:
                self.assertEqual(weight.grad, cpu_weight.grad)
                self.assertEqual(bias.grad, cpu_bias.grad)

    # (eps, momentum, wts) combinations run for every configuration.
    sweep = [
        dict(eps=1e-05, momentum=0.1, wts=False),
        dict(eps=0, momentum=1.0, wts=False),
        dict(eps=1, momentum=1, wts=True),
        dict(eps=3, momentum=0.67, wts=True),
    ]
    configurations = itertools.product(
        [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)],  # shape
        [False, True],                               # test_module
        [True, False],                               # track_running_stats
        [False],                                     # channels_last
    )
    for shape, test_module, track_running_stats, channels_last in configurations:
        if channels_last and len(shape) != 4:
            continue
        common = dict(channels_last=channels_last,
                      track_running_stats=track_running_stats,
                      test_module=test_module)
        # Running stats must be tracked in eval mode
        if track_running_stats:
            helper(shape, eps=0, momentum=1, **common)
            helper(shape, **common)
            for overrides in sweep:
                helper(shape, **overrides, **common)
        for overrides in sweep:
            helper(shape, **overrides, **common)
2898
def test_weight_norm(self):
    # Applies the weight-norm parametrization to matching CPU and MPS layers and
    # compares outputs, gradients of both parametrized components, and input gradients.
    def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim):
        apply_weight_norm = torch.nn.utils.parametrizations.weight_norm
        cpu_norm = apply_weight_norm(cpu_model, dim=dim)
        norm = apply_weight_norm(model, dim=dim)

        cpu_out = cpu_norm(cpu_x)
        out = norm(x)
        self.assertEqual(cpu_out, out)

        # Same random upstream gradient on both devices.
        cpu_grad = torch.randn(cpu_out.shape)
        grad = cpu_grad.to('mps')
        cpu_out.backward(gradient=cpu_grad)
        out.backward(gradient=grad)

        cpu_parts = cpu_model.parametrizations.weight
        mps_parts = model.parametrizations.weight
        self.assertEqual(cpu_parts.original0.grad, mps_parts.original0.grad)
        self.assertEqual(cpu_parts.original1.grad, mps_parts.original1.grad)

        self.assertEqual(x.grad, cpu_x.grad)

    def helper(dim, layer='linear', dtype=torch.float32):
        if layer == 'linear':
            cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
            weight = cpu_weight.detach().clone().to('mps').requires_grad_()

            cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            cpu_linear = torch.nn.Linear(5, 10, device='cpu')
            linear = torch.nn.Linear(5, 10, device='mps')

            # Overwrite the default initialisation so both layers start from identical values.
            with torch.no_grad():
                cpu_linear.weight.copy_(cpu_weight)
                cpu_linear.bias.copy_(cpu_bias)
                linear.weight.copy_(weight)
                linear.bias.copy_(bias)
            validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim)
            return

        # Convolution layers: (module class, unbatched input shape) per layer name.
        conv_setups = {
            'conv': (torch.nn.Conv2d, (3, 5, 5)),
            'conv3d': (torch.nn.Conv3d, (3, 5, 5, 4)),
        }
        if layer in conv_setups:
            conv_cls, input_shape = conv_setups[layer]
            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_conv = conv_cls(3, 3, 3, device='cpu')
            conv = conv_cls(3, 3, 3, device='mps')

            # The MPS layer takes over the CPU layer's randomly initialised parameters.
            with torch.no_grad():
                conv.weight.copy_(cpu_conv.weight)
                conv.bias.copy_(cpu_conv.bias)

            validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)

    for dim in (0, 1, -1):
        helper(dim, layer='linear')

    for dim in (0, 1, 2, 3, -1):
        helper(dim, layer='conv')

    # The Conv3d cases are only run when product_version is at least 13.2.
    if product_version >= 13.2:
        for dim in (0, 1, 2, 3, 4, -1):
            helper(dim, layer='conv3d')
2987
# Test conv2d
def test_conv2d_unit(self):
    # Forward result and input/weight/bias gradients of F.conv2d on MPS vs CPU.
    def helper(input_shape, wt_shape,
               stride=1, padding=0,
               dilation=1, groups=1,
               bias_shape=None):
        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
        wt = cpu_wt.detach().clone().to('mps').requires_grad_()

        has_bias = bias_shape is not None
        cpu_bias = bias = None
        if has_bias:
            cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        conv_args = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
        y = torch.nn.functional.conv2d(x, wt, bias=bias, **conv_args)
        ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, **conv_args)

        # All-ones upstream gradient on both devices.
        cpu_grad = torch.ones_like(ref_y)
        grad = cpu_grad.to('mps')

        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
        self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
        if has_bias:
            self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)

    # Single strided + padded case: input (N, C_in, H, W), weight (C_out, C_in, kH, kW).
    helper((1, 3, 64, 64), (64, 3, 4, 4), stride=2, padding=1)

    N, C_in, H, W = 4, 16, 32, 32
    C_out, kH, kW = 8, 3, 3

    for groups in [1, 2, 4]:
        base_shapes = ((N, C_in, H, W), (C_out, C_in // groups, kH, kW))
        doubled_shapes = ((N, C_in * 2, H * 2, W * 2),
                          (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2))

        # Every configuration is deliberately executed twice in a row.
        for _ in range(2):
            helper(*base_shapes, groups=groups)
        for _ in range(2):
            helper(*base_shapes, bias_shape=(C_out), groups=groups)
        for _ in range(2):
            helper(*doubled_shapes, groups=groups)
        for _ in range(2):
            helper(*doubled_shapes, bias_shape=(C_out * 2), groups=groups)
3060
# Test conv transpose 2d
def test_conv_transpose2d(self):
    # Forward result and input/weight gradients of F.conv_transpose2d on MPS vs CPU.
    def helper(input_shape, wt_shape,
               stride=1, padding=0,
               output_padding=0,
               dilation=1, groups=1,
               bias_shape=None):
        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
        wt = cpu_wt.detach().clone().to('mps').requires_grad_()

        cpu_bias = bias = None
        if bias_shape is not None:
            cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        conv_args = dict(stride=stride, padding=padding, output_padding=output_padding,
                         groups=groups, dilation=dilation)
        y = torch.nn.functional.conv_transpose2d(x, wt, bias=bias, **conv_args)
        ref_y = torch.nn.functional.conv_transpose2d(cpu_x, cpu_wt, bias=cpu_bias, **conv_args)

        # Same random upstream gradient on both devices.
        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')

        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
        self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)

        # TODO: the bias-gradient comparison (bias.grad vs cpu_bias.grad when
        # bias_shape is given) is currently disabled in this test.

    N = 4
    C_in = 2
    H = 32
    W = 32

    C_out = 8
    kH = 3
    kW = 3

    # The input has C_out channels and the weight is (C_out, C_in, kH, kW), so the
    # bias (one entry per output channel of the transposed conv) has C_in entries.
    input_shape = (N, C_out, H, W)
    wt_shape = (C_out, C_in, kH, kW)

    sweep = itertools.product([1, 2, 3], [0, 1, 2], [0, 1, 2], [1, 2])
    for stride, padding, output_padding, dilation in sweep:
        # Only combinations with output_padding below both stride and dilation are run.
        if output_padding < stride and output_padding < dilation:
            conv_args = dict(stride=stride, padding=padding,
                             output_padding=output_padding, dilation=dilation)
            # Each variant (without / with bias) is deliberately executed twice.
            for _ in range(2):
                helper(input_shape, wt_shape, **conv_args)
            for _ in range(2):
                helper(input_shape, wt_shape, bias_shape=(C_in), **conv_args)
3128
# Test sigmoid
def test_sigmoid(self):
    # Forward values and input gradients of nn.Sigmoid on MPS must match CPU.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        activation = torch.nn.Sigmoid()
        y = activation(x)
        ref_y = activation(cpu_x)

        # Back-propagate an all-ones upstream gradient through both results.
        cpu_grad = torch.ones_like(ref_y)
        y.backward(gradient=cpu_grad.to('mps'))
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y)
        self.assertEqual(x.grad, cpu_x.grad)

    for shape in [(2, 3, 4, 5), (2, 3, 4), (2, 8, 4, 5)]:
        helper(shape)
3153
# Test tanh
def test_tanh(self):
    # Forward values and input gradients of nn.Tanh on MPS must match CPU.
    def helper(shape):
        reference_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        device_input = reference_input.detach().clone().to('mps').requires_grad_()

        tanh_module = torch.nn.Tanh()
        device_output = tanh_module(device_input)
        reference_output = tanh_module(reference_input)

        # Upstream gradient of ones, mirrored onto the MPS device.
        reference_grad = torch.ones_like(reference_output)
        device_grad = reference_grad.to('mps')

        device_output.backward(gradient=device_grad)
        reference_output.backward(gradient=reference_grad)

        self.assertEqual(device_output, reference_output)
        self.assertEqual(device_input.grad, reference_input.grad)

    for shape in ((2, 3, 4, 5), (2, 3, 4), (2, 8, 4, 5)):
        helper(shape)
3178
def test_threshold(self):
    # nn.Threshold on MPS vs CPU: outputs are always compared, input gradients
    # only when requires_grad is set.
    def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
        m = nn.Threshold(threshold=threshold, value=value, inplace=inplace)

        input_cpu = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

        output_cpu = m(input_cpu)
        output_mps = m(input_mps)
        self.assertEqual(output_cpu, output_mps)

        if not requires_grad:
            return

        # All-ones upstream gradient on both devices.
        cpu_grad = torch.ones_like(output_cpu)
        output_cpu.backward(gradient=cpu_grad)
        output_mps.backward(gradient=cpu_grad.to('mps'))

        self.assertEqual(input_cpu.grad, input_mps.grad)

    helper(threshold=0.1, value=20, num_elems=2)
    helper(threshold=-0.1, value=10, num_elems=10)
    helper(threshold=0.5, value=-15, num_elems=100)
    # The in-place variant is run on inputs that do not require grad.
    helper(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False)
3204
# Test pow
def test_pow(self):
    # Exercises the three torch.pow overloads on MPS against CPU.
    def helper(shape):
        def random_pair():
            # Returns a CPU tensor together with its MPS copy.
            on_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            return on_cpu, on_cpu.detach().clone().to('mps')

        # aten::pow.Tensor_Tensor
        base_cpu, base_mps = random_pair()
        exponent_cpu, exponent_mps = random_pair()
        self.assertEqual(torch.pow(base_mps, exponent_mps), torch.pow(base_cpu, exponent_cpu))

        # aten::pow.Tensor_Scalar
        base_cpu, base_mps = random_pair()
        scalar_exponent = random.random()
        self.assertEqual(torch.pow(base_mps, scalar_exponent), torch.pow(base_cpu, scalar_exponent))

        # aten::pow.Scalar
        scalar_base = random.random()
        exponent_cpu, exponent_mps = random_pair()
        self.assertEqual(torch.pow(scalar_base, exponent_mps), torch.pow(scalar_base, exponent_cpu))

    helper((2, 8, 4, 5))
3237
# Test addcmul
def test_addcmul(self):
    # torch.addcmul on MPS vs CPU for floating, integral and mixed operand dtypes.
    def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None):
        def make_operand(dtype):
            # Integral dtypes draw from randint(10, ...); floating dtypes from randn.
            if not dtype.is_floating_point:
                return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False)
            return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)

        # ytype / ztype fall back to xtype when not given.
        operand_dtypes = (
            xtype,
            xtype if ytype is None else ytype,
            xtype if ztype is None else ztype,
        )
        cpu_operands = [make_operand(dtype) for dtype in operand_dtypes]
        mps_operands = [operand.detach().clone().to('mps') for operand in cpu_operands]

        result = torch.addcmul(*mps_operands, value=value)
        reference = torch.addcmul(*cpu_operands, value=value)

        self.assertEqual(result, reference)

    helper((2, 3, 4, 5), 0.1)
    helper((2, 8, 4, 5), 0.1)
    helper((2, 3, 4, 5), 0.2)
    helper((2, 8, 4, 5), 0.2)
    # Integral types
    helper((2, 2), 1.0, xtype=torch.int32)
    helper((2, 2), 2.0, xtype=torch.int16)

    # Mixed types
    helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32)
    helper((3, 2), 1.0, ytype=torch.float16)
    helper((2, 3), 1.0, ztype=torch.float16)
    helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8)
    helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003274
# Test addcdiv
def test_addcdiv(self):
    # torch.addcdiv on MPS vs CPU, for both the functional and the out= variant.
    def helper(shape, value):
        def randn_cpu():
            return torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)

        cpu_x = randn_cpu()
        cpu_y = randn_cpu()
        # clamp to avoid division by 0
        cpu_z = randn_cpu().clamp_min_(0.1)
        cpu_out = randn_cpu()

        mps_x, mps_y, mps_z, mps_out = (
            t.detach().clone().to('mps') for t in (cpu_x, cpu_y, cpu_z, cpu_out)
        )

        result_div_mps = torch.addcdiv(mps_x, mps_y, mps_z, value=value)
        result_div_cpu = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value)
        self.assertEqual(result_div_mps, result_div_cpu)

        # test .out variant
        out_result = torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value)
        self.assertEqual(out_result, result_div_cpu)

    helper((2, 3, 4, 5), 0.1)
    helper((2, 8, 4, 5), 0.2)
    helper((2, 3, 4, 5), 1.0)  # value of 1 should be ignored internally
3298
def test_addcdiv_transpose(self):
    # Regression test for issue https://github.com/pytorch/pytorch/issues/118115
    # Runs addcdiv over every mix of plain and transposed (via .t()) operands.

    def helper(shape, value):
        shape_t = shape[::-1]

        def make_operand(transposed):
            # The transposed variant is drawn with the reversed shape and then
            # flipped back with .t(), so both variants have the same logical shape.
            if transposed:
                return torch.rand(shape_t, device="cpu").t()
            return torch.rand(shape, device="cpu")

        for x_t, y_t, z_t in itertools.product((False, True), repeat=3):
            x = make_operand(x_t)
            y = make_operand(y_t)
            z = make_operand(z_t)

            x_mps, y_mps, z_mps = (t.detach().clone().to(device="mps") for t in (x, y, z))

            # CPU reference uses the in-place form; the MPS copies were taken beforehand.
            result_cpu = x.addcdiv_(y, z, value=value)
            result_mps = x_mps.addcdiv(y_mps, z_mps, value=value)

            # out= variant writing into a preallocated MPS tensor.
            result_mps_out = result_cpu.detach().clone().to('mps')
            torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value)

            self.assertEqual(result_cpu, result_mps)
            self.assertEqual(result_cpu, result_mps_out)

    for shape in ((2, 3), (100, 300)):
        for value in (1.0, 0.2):
            helper(shape, value)
3328
def test_buffer_size_match(self):
    # Multiplying a 1-D vector with a 3-D tensor must not crash on MPS and must
    # agree with the CPU result.
    dim = 16
    vector_cpu = torch.rand(dim, device='cpu')
    cube_cpu = torch.rand(dim, dim, dim, device='cpu')

    vector_mps, cube_mps = vector_cpu.to('mps'), cube_cpu.to('mps')

    expected = torch.matmul(vector_cpu, cube_cpu)
    actual = torch.matmul(vector_mps, cube_mps)
    self.assertEqual(expected, actual)
3338
def test_transpose_inplace(self):
    # In-place transpose of a 3x3 matrix must give the same result on MPS and CPU.
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    on_cpu = torch.tensor(rows, device='cpu')
    on_mps = torch.tensor(rows, device='mps')

    for matrix in (on_cpu, on_mps):
        matrix.transpose_(0, 1)

    self.assertEqual(on_cpu, on_mps.to('cpu'))
3347
def test_expand_cpu_to_mps_copy(self):
    # https://github.com/pytorch/pytorch/issues/78642
    # A scalar expanded on CPU and then moved to MPS must keep all ten values.
    moved = torch.tensor(1).expand([10]).to("mps")
    reference = torch.tensor(1).expand([10])

    self.assertEqual(reference, moved.cpu())
3355
def test_cpu_to_strided_mps_copy(self):
    # https://github.com/pytorch/pytorch/issues/86975
    # Writing into a strided column slice of an MPS tensor must give the same
    # result whether the source values live on CPU or on MPS.
    mps = torch.device("mps")

    def fill_column(source):
        target = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(mps)
        target[1:, 1] = source
        return target

    filled_from_cpu = fill_column(torch.Tensor([-1, -1]))
    filled_from_mps = fill_column(torch.Tensor([-1, -1]).to(mps))

    self.assertEqual(filled_from_cpu, filled_from_mps)
3368
def test_view_slice_reshape(self):
    # Adds a scalar to a sliced view taken on MPS and on a CPU copy of the same data.
    source_mps = torch.randn([1, 4, 4], device="mps")
    source_cpu = source_mps.to("cpu")

    view_mps = source_mps[0, :1, 1:]
    view_cpu = source_cpu[0, :1, 1:]

    self.assertEqual(view_mps + 1, view_cpu + 1)
3379
def test_slice_reshape(self):
    # Slice along dim 1, reinterpret with view(), then add a scalar; MPS must
    # match CPU both before and after the addition.
    base_mps = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
    base_cpu = base_mps.detach().clone().to("cpu")

    viewed_mps = base_mps[:, 3:].view(2, 3, 4, 1)
    viewed_cpu = base_cpu[:, 3:].view(2, 3, 4, 1)
    self.assertEqual(viewed_mps, viewed_cpu)

    shifted_mps = viewed_mps + 2
    shifted_cpu = viewed_cpu + 2
    self.assertEqual(shifted_mps, shifted_cpu)
3391
def test_reshape_storage_offset(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/95883
    # Runs Linear -> cat -> transpose -> mse_loss on MPS and on CPU with identical
    # weights and inputs, then compares the losses and the input gradients.
    B = 4
    T = 1

    lin_cpu = nn.Linear(10, 256)
    lin_mps = nn.Linear(10, 256, device="mps")

    # Use the same weights and bias as the ones from the cpu
    lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
    lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()

    # Keep handles on the leaf inputs. x_mps / x_cpu are rebound to intermediate
    # (non-leaf) tensors below, and non-leaf tensors do not retain `.grad`, so
    # comparing their `.grad` would only compare None with None.
    input_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
    input_cpu = input_mps.detach().clone().cpu().requires_grad_()
    x_mps = lin_mps(input_mps)
    x_cpu = lin_cpu(input_cpu)

    self.assertEqual(x_mps.shape, (B, T, 256))
    self.assertEqual(x_cpu.shape, (B, T, 256))

    cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
    cls_token_cpu = cls_token_mps.detach().clone().cpu()
    x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
    x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)

    x_mps = x_mps.transpose(0, 1)
    x_cpu = x_cpu.transpose(0, 1)

    target_mps = torch.rand_like(x_mps)
    target_cpu = target_mps.detach().clone().cpu()
    loss_mps = F.mse_loss(x_mps, target_mps)
    loss_cpu = F.mse_loss(x_cpu, target_cpu)
    self.assertEqual(loss_mps, loss_cpu)

    loss_mps.backward()
    loss_cpu.backward()

    # Gradients must actually exist on the leaves before they are compared.
    self.assertIsNotNone(input_cpu.grad)
    self.assertIsNotNone(input_mps.grad)
    self.assertEqual(input_mps.grad, input_cpu.grad)
3429
def test_stack_storage_offset(self):
    # https://github.com/pytorch/pytorch/issues/87856
    # torch.stack over slices of one tensor, some starting at a non-zero offset.
    pair_cpu = torch.tensor([[1, 2]])
    pair_mps = pair_cpu.detach().clone().to("mps")

    stacked_cpu = torch.stack((pair_cpu[:, :1], pair_cpu[:, -1:]), dim=-1)
    stacked_mps = torch.stack((pair_mps[:, :1], pair_mps[:, -1:]), dim=-1)
    self.assertEqual(stacked_cpu, stacked_mps)

    quad_mps = torch.tensor([1, 2, 3, 4], device="mps")
    quad_cpu = quad_mps.detach().cpu().detach()

    # Split into the two halves; the tail half starts at offset 2 in the original.
    tail_mps, head_mps = quad_mps[2:], quad_mps[:2]
    tail_cpu, head_cpu = quad_cpu[2:], quad_cpu[:2]

    res_mps = torch.stack((head_mps, tail_mps), dim=-1)
    res_cpu = torch.stack((head_cpu, tail_cpu), dim=-1)
    self.assertEqual(res_mps, res_cpu)
3453
def test_unsafe_chunk(self):
    # https://github.com/pytorch/pytorch/issues/91065
    # Multiplies two of the chunks returned by unsafe_chunk on CPU and on MPS.
    values_cpu = torch.rand(5, dtype=torch.float32, device="cpu")
    chunks_cpu = values_cpu.unsafe_chunk(4, 0)
    product_cpu = chunks_cpu[0] * chunks_cpu[2]

    values_mps = values_cpu.to("mps")
    chunks_mps = values_mps.unsafe_chunk(4, 0)
    product_mps = chunks_mps[0] * chunks_mps[2]

    self.assertEqual(product_cpu, product_mps)
3463
def test_slice_casting(self):
    # Random 0/1 values stored as uint8.
    probabilities = torch.empty(1, 1, 128, 128).uniform_(0, 1)
    cpu_in = torch.bernoulli(probabilities).to(torch.uint8)
    mps_in = cpu_in.detach().clone().to("mps")

    # Check the uint8 -> bool cast on a slice (the slice starts at row 11, so it
    # does not begin at the start of the tensor).
    cpu_out = cpu_in[:, :, 11:12, :12].to(torch.bool)
    mps_out = mps_in[:, :, 11:12, :12].to(torch.bool)
    self.assertEqual(cpu_out, mps_out)
3472
def test_slice_reshape_contg_view(self):
    # Adds a scalar to a (1, 4800, 2) tensor on MPS and to its CPU copy and
    # checks that both results agree.
    # (The redundant function-local `import torch` was dropped; the module-level
    # import is used instead.)
    x_mps = torch.randn(1, 4800, 2, device="mps")
    x_cpu = x_mps.detach().clone().cpu()

    r_mps = x_mps + 2
    r_cpu = x_cpu + 2

    self.assertEqual(r_mps, r_cpu)
3483
def test_contiguous_slice_2d(self):
    # For every (i, j) over the first two dimensions, slices a fresh random tensor
    # in six ways and checks that adding 1 to the MPS slice matches the CPU slice.
    def helper(shape):
        for i, j in itertools.product(range(shape[0]), range(shape[1])):
            t_mps = torch.randn(shape, device="mps")
            t_cpu = t_mps.detach().clone().cpu()

            # In order: [i:, :j], [i:, j], [i, :j], [:i, :j], [:i, j], [:i, j:]
            index_patterns = (
                (slice(i, None), slice(None, j)),
                (slice(i, None), j),
                (i, slice(None, j)),
                (slice(None, i), slice(None, j)),
                (slice(None, i), j),
                (slice(None, i), slice(j, None)),
            )
            for index in index_patterns:
                y_mps = t_mps[index]
                y_cpu = t_cpu[index]
                self.assertEqual(y_mps + 1, y_cpu + 1)

    def sweep(prefix):
        # Extends `prefix` by one more dimension of extent 1 or 2, runs the helper
        # on the extended shape, and recurses until the shape has five dimensions.
        for extent in range(1, 3):
            dims = prefix + [extent]
            helper(dims)
            if len(dims) < 5:
                sweep(dims)

    # Covers every shape of rank 2 to 5 whose extents are all 1 or 2.
    for leading in range(1, 3):
        sweep([leading])

    helper([9, 15, 4])
    helper([9, 3, 2])
    helper([3, 4, 18, 22])
    helper([3, 4, 18, 22, 150])
3540
def test_contiguous_slice_3d(self):
    """Multiply two one-element windows of a sliced 3-D tensor on MPS and CPU."""
    full_mps = torch.randn(2, 3, 3, device="mps")
    full_cpu = full_mps.detach().clone().cpu()

    def window_product(t):
        head = t[:1]
        return head[:, 0:1, 0:1] * head[:, 1:2, 1:2]

    self.assertEqual(window_product(full_mps), window_product(full_cpu))
3549
def test_view_slice(self):
    """Gather elements through index tensors on CPU and on MPS and compare.

    Regression test for https://github.com/pytorch/pytorch/issues/83995
    """
    NUM_SAMPLES = 60
    s = (0, 1)

    X = torch.rand(8000, 3, dtype=torch.float32, device='cpu')
    # The copy must live on the MPS device; with .to("cpu") the lookups
    # below would read a CPU tensor and never exercise MPS indexing.
    X_mps = X.detach().clone().to("mps")

    idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s))
    pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1]))
    idx_mps = idx.to("mps")
    pts_mps = pts.to("mps")
    # Overwrite columns 0 and 1 of every sample with the same random row index.
    pts[:, s] = idx
    pts_mps[:, s] = idx_mps

    actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float)
    actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")

    # Element-wise gather, compared one scalar at a time.
    for i in range(NUM_SAMPLES):
        for j in range(X.shape[1]):
            actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j]
            actual_pts[i, j] = X[pts[i, j], j]
            self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j])
3573
def test_slice_scatter(self):
    """Copying into a strided view of another tensor must leave the source intact."""
    shape = (4, 4)
    source = torch.randint(10, shape, device="mps")
    snapshot = source.clone()

    rows, cols = shape
    destination = torch.empty(rows, cols * 2, device="mps")
    destination[:, ::2].copy_(source)

    torch.testing.assert_close(source, snapshot)
Denis Vieriub71c7102022-12-08 17:59:55 +00003580
def test_slice(self):
    """Row/column slicing of a 3x3 matrix gives the same result on MPS as on CPU."""
    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps', dtype=torch.float)

    # (index, whether the MPS slice is moved to CPU before the comparison)
    cases = [
        ((slice(None, 2), slice(None)), False),   # [:2, :]
        ((slice(None), slice(None, 1)), False),   # [:, :1]
        ((slice(1, 2), slice(None)), True),       # [1:2, :]
        ((1, slice(None)), True),                 # [1, :]
    ]
    for index, move_to_cpu in cases:
        expected = cpu_x[index]
        actual = mps_x[index]
        if move_to_cpu:
            actual = actual.to('cpu')
        self.assertEqual(expected, actual)
3601
@parametrize("torch_type", arg_values=[torch.float16, torch.float32, torch.bfloat16])
def test_slice_view_api(self, torch_type: torch.dtype):
    """Chain view-producing ops and compare CPU and MPS results at every stage.

    Each case starts from a CPU tensor, applies up to three callables in
    sequence, and checks the values (plus the storage offset after the
    second stage) against the same chain run on an MPS copy.
    """

    def helper(x_tensor, y_func, z_func, r_func=None):
        x_mps = x_tensor.detach().clone().to("mps")

        first_cpu = y_func(x_tensor)
        first_mps = y_func(x_mps)
        self.assertEqual(first_cpu, first_mps)

        second_cpu = z_func(first_cpu)
        second_mps = z_func(first_mps)
        self.assertEqual(second_cpu, second_mps)
        self.assertEqual(second_cpu.storage_offset(), second_mps.storage_offset())

        if not r_func:
            return
        self.assertEqual(r_func(second_cpu), r_func(second_mps))

    # Skip bfloat16 before MacOS15
    if product_version < 15.0 and torch_type == torch.bfloat16:
        return

    def plus_one(t):
        return t + 1

    def second_row(t):
        return t[1]

    # Tests for previously encountered MPS bugs
    helper(torch.randn(4, 4, dtype=torch_type), second_row, lambda y: y.reshape(2, 2), plus_one)
    helper(
        torch.randn(2, 4, dtype=torch_type),
        second_row,
        lambda y: y + torch.ones(4, device=y.device),
    )
    helper(torch.randn(4, 6, dtype=torch_type), second_row, lambda y: y.reshape(3, 2).t(), plus_one)
    helper(
        torch.arange(4, dtype=torch_type).resize(1, 2, 2),
        lambda x: x.permute(2, 0, 1),
        plus_one,
    )
    helper(
        torch.randn(4, 8, dtype=torch_type),
        lambda x: x.transpose(0, 1).reshape(-1),
        lambda y: y[:2],
        plus_one,
    )
    helper(
        torch.randn(1, dtype=torch_type),
        lambda x: x.expand(2, 3),
        lambda y: y + torch.ones(2, 3, device=y.device),
    )
3658
def test_slice_reshape_contiguous(self):
    """Reshape a row slice; values and storage offset must match between CPU and MPS."""
    cpu_mat = torch.randn(4, 4)
    mps_mat = cpu_mat.detach().clone().to("mps")

    cpu_row, mps_row = cpu_mat[1], mps_mat[1]
    self.assertEqual(cpu_row, mps_row)

    cpu_square = cpu_row.reshape(2, 2)
    mps_square = mps_row.reshape(2, 2)
    self.assertEqual(cpu_square, mps_square)
    self.assertEqual(cpu_square.storage_offset(), mps_square.storage_offset())
3671
def test_scalar_from_slice_unary(self):
    """``torch.ceil`` on scalars obtained by iterating an MPS tensor matches CPU.

    Regression test for https://github.com/pytorch/pytorch/issues/82543
    """
    for element in torch.tensor([1.0, 1.2], device="mps"):
        on_mps = torch.ceil(element)
        on_cpu = torch.ceil(element.to("cpu"))
        self.assertEqual(on_mps.cpu(), on_cpu)
3680
def test_scalar_from_slice_binary(self):
    """Binary ops between an MPS scalar element and ``1.0`` match the CPU result.

    Regression test for https://github.com/pytorch/pytorch/issues/82543
    """
    for binary_op in (torch.sub, torch.add, torch.not_equal, torch.eq):
        # The input tensor is rebuilt for every operator.
        for element in torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps"):
            on_mps = binary_op(element, 1.0)
            on_cpu = binary_op(element.cpu(), 1.0)
            self.assertEqual(on_mps.cpu(), on_cpu)
3694
def test_slice_contiguous_view(self):
    """Apply comparisons and ``stack`` to two contiguous slices of one tensor.

    Regression test for https://github.com/pytorch/pytorch/issues/77750
    """
    # Maps the operator label to the computation on (x, y), where
    # x is t[2:] and y is t[:2].  ">" must use a strict comparison;
    # mapping it to ">=" would leave greater-than untested.
    operations = {
        "<=": lambda x, y: x <= y,
        "<": lambda x, y: x < y,
        ">=": lambda x, y: x >= y,
        ">": lambda x, y: x > y,
        "==": lambda x, y: x == y,
        "!=": lambda x, y: x != y,
        "stack": lambda x, y: torch.stack((y, x), dim=-1),
    }

    def helper(operator):
        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
        t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")

        # contiguous views
        x_mps = t_mps[2:]  # 3, 4
        y_mps = t_mps[:2]  # 1, 2

        x_cpu = t_cpu[2:]
        y_cpu = t_cpu[:2]

        # Unknown labels leave both results as None, as before.
        res_mps = res_cpu = None
        compute = operations.get(operator)
        if compute is not None:
            res_mps = compute(x_mps, y_mps)
            res_cpu = compute(x_cpu, y_cpu)

        self.assertEqual(res_mps, res_cpu)

    for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
        helper(op)
3736
def test_slice_of_slice(self):
    """``ne(0)`` on an element that was indexed and then unsqueezed matches CPU."""
    cpu_src = torch.tensor([0.5, 0.5], device="cpu")
    mps_src = torch.tensor([0.5, 0.5], device="mps")

    def second_element_ne_zero(t):
        return t[1][None].ne(0)

    self.assertEqual(second_element_ne_zero(cpu_src), second_element_ne_zero(mps_src))
3748
def test_index_storage_offset(self):
    """Copying an offset view from CPU to MPS must honour its storage offset.

    Regression test for https://github.com/pytorch/pytorch/issues/78107
    """
    base = torch.tensor([8.2670e-01, -1.0293e+00])
    first_cpu = base[0]
    second_cpu = base[1]

    # Both elements are views of 'base': the first has storage offset 0 and
    # the second has storage offset 1.  When copying from 'cpu' to 'mps' the
    # offset of the second view has to be taken into account, otherwise it
    # ends up with the same value as the first.
    first_mps = first_cpu.to('mps')
    second_mps = second_cpu.to('mps')

    self.assertEqual(first_mps > second_mps, first_cpu > second_cpu)
    self.assertEqual(second_mps > first_mps, second_cpu > first_cpu)
3770
def test_flatten(self):
    """``flatten`` with default, ``start_dim`` and ``end_dim`` arguments matches CPU."""
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    for flatten_kwargs in ({}, {"start_dim": 1}, {"end_dim": 1}):
        expected = cpu_x.flatten(**flatten_kwargs)
        actual = mps_x.flatten(**flatten_kwargs).to('cpu')
        self.assertEqual(expected, actual)
3787
def test_repeat(self):
    """``Tensor.repeat`` forward and backward results on MPS match CPU."""

    def helper(shape, repeats):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        mps_out = mps_input.repeat(repeats)
        cpu_out = cpu_input.repeat(repeats)

        # Same upstream gradient on both devices.
        cpu_grad = torch.randn(cpu_out.shape)
        mps_grad = cpu_grad.to('mps')

        mps_out.backward(gradient=mps_grad)
        cpu_out.backward(gradient=cpu_grad)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(mps_input.grad, cpu_input.grad)

    cases = (
        ((2, 3, 4, 5), (2, 3, 4, 5)),
        ((2, 3, 4), (4, 3, 2, 5, 7, 2)),
        ((3, 4, 5), (2, 3, 4, 5)),
        ((3, 4, 5), (2, 2, 2)),
    )
    for shape, repeats in cases:
        helper(shape, repeats)
3811
def test_torch_repeat_interleave(self, device="mps"):
    """Check output size and dtype of ``repeat_interleave`` for int and long repeats.

    Args:
        device: device on which every tensor in the test is created.
    """
    y = torch.tensor([[1, 2], [3, 4]], device=device)
    # exercise single argument function signature
    temp = y.repeat_interleave(2)
    self.assertEqual(torch.Size([8]), temp.size())

    for dtype in [torch.int, torch.long]:
        # Create the repeats on the requested device rather than a
        # hard-coded "mps", so `y` and `lengths` always share a device.
        lengths = torch.tensor([1, 2], dtype=dtype, device=device)
        output_size = torch.sum(lengths)
        a = torch.repeat_interleave(
            y,
            lengths,
            dim=0,
        )
        self.assertEqual(a.dtype, y.dtype)
        self.assertEqual(a.size(), torch.Size([3, 2]))

        # Same call, with the total output length supplied up front.
        a_with_output = torch.repeat_interleave(
            y,
            lengths,
            dim=0,
            output_size=output_size,
        )
        self.assertEqual(a_with_output.dtype, y.dtype)
        self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
3837
def test_repeat_interleave(self, device="mps"):
    """Exercise ``torch.repeat_interleave`` argument forms, error cases and empty inputs."""
    counts = torch.tensor([0, 1, 2, 3], device=device)
    expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
    # Prior to macos 13.3, input of dtype=torch.int64 returns dtype=torch.int32
    self.assertEqual(torch.repeat_interleave(counts), expected, exact_dtype=product_version >= 13.3)

    # The single-argument form must raise for 2-D, floating point and
    # negative repeats.  Inputs are built inside the assertRaises scope.
    invalid_single_args = (
        lambda: torch.arange(4, device=device).reshape(2, 2),
        lambda: torch.arange(4.0, device=device),
        lambda: torch.tensor([1, 2, -1, 3, 4], device=device),
    )
    for make_repeats in invalid_single_args:
        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(make_repeats())

    y = torch.tensor([[1, 2], [3, 4]], device=device)

    # An int, a 0-d tensor and a 1-element tensor are equivalent repeat specs.
    flat_results = [
        torch.repeat_interleave(y, 2),
        torch.repeat_interleave(y, torch.tensor(2, device=device)),
        torch.repeat_interleave(y, torch.tensor([2], device=device)),
    ]
    flat_expected = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
    for flat in flat_results:
        self.assertEqual(flat, flat_expected)

    along_cols = torch.repeat_interleave(y, 3, dim=1)
    along_cols_expected = torch.tensor([[1, 1, 1, 2, 2, 2],
                                        [3, 3, 3, 4, 4, 4]], device=device)
    self.assertEqual(along_cols, along_cols_expected)

    along_rows = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
    along_rows_expected = torch.tensor([[1, 2],
                                        [3, 4],
                                        [3, 4]], device=device)
    self.assertEqual(along_rows, along_rows_expected)

    # Repeat tensors whose shape does not fit dim 0 of `y` must raise.
    invalid_dim_repeats = (
        lambda: torch.tensor([1, 2, 3], device=device),
        lambda: torch.arange(9, device=device).reshape(3, 3),
    )
    for make_repeats in invalid_dim_repeats:
        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(y, make_repeats(), dim=0)

    # test zero sized dimension
    empty_cols = torch.zeros((5, 0), device=device)
    repeated_empty = torch.repeat_interleave(empty_cols, repeats=3, dim=1)
    self.assertEqual(repeated_empty, empty_cols.new_zeros(5, 0, device=device))

    empty_long = torch.tensor([], dtype=torch.int64, device=device)
    self.assertEqual(torch.repeat_interleave(empty_long, empty_long), empty_long)
3888
def test_repeat_interleave_simple(self):
    """``repeat_interleave`` with tensor repeats on MPS matches the CPU result."""

    def helper(shape, num_repeats, dtype=torch.float32, dim=None):
        mps_input = torch.randn(shape, dtype=dtype, device="mps")
        cpu_input = mps_input.detach().clone().cpu()
        cpu_repeats = num_repeats.detach().clone().cpu()

        mps_result = torch.repeat_interleave(mps_input, num_repeats, dim)
        cpu_result = torch.repeat_interleave(cpu_input, cpu_repeats, dim)

        self.assertEqual(mps_result, cpu_result)

    # Calls stay sequential so each random repeat tensor is drawn right
    # before the input of its own case.
    helper(shape=3, num_repeats=torch.tensor([100], device="mps"))
    helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
    helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
    helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
    helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)
3905
def test_count_nonzero(self):
    """``count_nonzero`` over all elements, one dim and two dims matches CPU."""
    data = [
        [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
        [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
    ]
    # No dim (all elements), dim=1, dim=(0, 1)
    reductions = ({}, {"dim": 1}, {"dim": (0, 1)})

    for dtype in (torch.int32, torch.int64, torch.float16, torch.float32):
        cpu_x = torch.tensor(data, dtype=dtype)
        mps_x = torch.tensor(data, dtype=dtype).to('mps')
        for kwargs in reductions:
            self.assertEqual(
                torch.count_nonzero(cpu_x, **kwargs),
                torch.count_nonzero(mps_x, **kwargs),
            )
3936
def _test_module_empty_input(self, module, inp, check_size=True):
    """Run ``module`` forward/backward on ``inp`` and assert all gradients are zero.

    Args:
        module: callable module under test; its ``parameters()`` are inspected.
        inp: input tensor; ``requires_grad`` is switched on in place.
        check_size: when true, also assert the output size equals the input size.
    """
    inp.requires_grad_(True)
    output = module(inp)
    upstream_grad = torch.rand_like(output)
    output.backward(upstream_grad)

    if check_size:
        self.assertEqual(output.size(), inp.size())

    trainable = (param for param in module.parameters() if param.requires_grad)
    for param in trainable:
        self.assertEqual(param.grad, torch.zeros_like(param.grad))
    self.assertEqual(inp.grad, torch.zeros_like(inp))
3948
def test_to(self):
    """Test dtype casting, with and without a simultaneous device change."""
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    # Whole-tensor casts on MPS match the same casts on CPU.
    self.assertEqual(cpu_x.int(), mps_x.int().cpu())
    self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
    self.assertEqual(cpu_x.float(), mps_x.float().cpu())

    # Scalar casts, including chained casts.
    self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
                     torch.tensor(1, dtype=torch.int32))
    self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
                     torch.tensor(1, dtype=torch.int32))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
                     torch.tensor(1.0))

    # Casts combined with a device move in a single .to() call.
    self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
                     torch.tensor(4, dtype=torch.int32))
    self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
                     torch.tensor(4, dtype=torch.int32))
    self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                     torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))

    # Cast int8 and uint8 to float and compare results
    # See https://github.com/pytorch/pytorch/issues/80009 for more details
    cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8)
    # Signed data needs a signed dtype: the negative values below do not
    # fit in uint8, and int8 is the case this check is meant to cover.
    cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.int8)
    for x_cpu in [cpu_byte, cpu_char]:
        x_mps = x_cpu.to('mps')
        self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32))
3980
Lukas Hoeniga52bfe22022-05-24 20:09:45 +00003981
def test_setitem_scalar(self) -> None:
    """Write Python scalars into single elements of an MPS tensor and read them back."""
    device = 'mps'
    dtypes = [torch.int32, torch.float32, torch.int64]
    for dtype, rows, cols in itertools.product(dtypes, range(3, 6), range(3, 6)):
        t = torch.zeros(rows, cols, dtype=dtype, device=device)
        self.assertEqual(t.sum(), 0)

        t[1, 1] = 1
        t[2, 1] = cols
        t[1, 2] = rows

        self.assertEqual(t[1, 1], 1)
        self.assertEqual(t[1, 2], rows)
        self.assertEqual(t[2, 1], cols)
        self.assertEqual(t.sum(), 1 + rows + cols)
Nikita Shulga437ecfc2022-05-27 20:46:53 +00003996
def test_stride_of_strides(self) -> None:
    """Move an ``as_strided`` view of an ``as_strided`` view from MPS to CPU."""
    base = torch.rand(32, 1, device='mps')
    first_view = base.as_strided(size=(32, 2), stride=(1, 0))
    # Casting stride of strided tensor to CPU used to crash with "buffer is not large enough." assert
    # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
    second_view_cpu = first_view.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
    expected = base.to("cpu").as_strided(size=(32, 3), stride=(1, 0))
    self.assertEqual(expected, second_view_cpu)
4004
def test_type_casting(self):
    """``Tensor.type`` with legacy tensor-type classes gives the same result on MPS and CPU.

    Regression test for https://github.com/pytorch/pytorch/issues/81567
    """
    data = [9.0, 3.0, 5.0, 4.0]
    target_types = (
        torch.LongTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
    )
    for target in target_types:
        cpu_src = torch.tensor(data)
        mps_src = cpu_src.to(torch.device('mps'))
        self.assertEqual(cpu_src.type(target), mps_src.type(target))
4022
def test_to_casting(self):
    """``Tensor.to(dtype)`` gives the same result on MPS and CPU.

    Regression test for https://github.com/pytorch/pytorch/issues/81567
    """
    data = [9.0, 3.0, 5.0, 4.0]
    target_dtypes = (
        torch.int64,
        torch.float,
        torch.int32,
        torch.short,
        torch.half,
        torch.int8,
        torch.uint8,
    )
    for target in target_dtypes:
        cpu_src = torch.tensor(data)
        mps_src = cpu_src.to(torch.device('mps'))
        self.assertEqual(cpu_src.to(target), mps_src.to(target))
4040
def test_storage_offset_greater_than_src_nbytes(self):
    """Copy slice views with growing storage offsets to MPS and back.

    Regression test for https://github.com/pytorch/pytorch/issues/80844
    """
    n_tensors = 100
    n_tensor_elems = 784
    elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)

    # Contiguous view tensors created by the slice op; only the first
    # n_tensors - 1 chunks are used.
    chunks = [
        elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
        for i in range(n_tensors - 1)
    ]

    for i, chunk in enumerate(chunks):
        t = chunk.view(1, n_tensor_elems)
        t_mps = t.to("mps")
        self.assertEqual(t, t_mps.cpu(), f"i={i}")
Kulin Sethe011a8e2022-05-13 18:28:53 +00004057
def test_full_bugs(self):
    """``torch.full`` works on MPS for bool and uint8 fill values.

    See https://github.com/pytorch/pytorch/issues/82427
    and https://github.com/pytorch/pytorch/issues/83692
    """
    # Test should not crash; the result itself is not inspected.
    torch.full((3, 3), True, device='mps')

    # torch.full should work for uint8
    full_kwargs = {"dtype": torch.uint8}
    on_mps = torch.full((2, 2), 247, device='mps', **full_kwargs)
    on_cpu = torch.full((2, 2), 247, device='cpu', **full_kwargs)
    self.assertEqual(on_mps, on_cpu)
Nikita Shulgabdd0a4a2022-08-01 19:42:24 +00004067
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_div_bugs(self):
    """Integer division of 1..10 by 101 rounds to zero for every non-int64 integral dtype.

    See https://github.com/pytorch/pytorch/issues/84995
    """
    for dtype, mode in itertools.product(integral_types(), ['trunc', 'floor']):
        # int64 is excluded from this check.
        if dtype == torch.int64:
            continue
        numerators = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
        quotients = torch.div(numerators, 101, rounding_mode=mode)
        self.assertEqual(quotients.sum(), 0)
Nikita Shulga1a6cf6e2022-09-14 23:40:20 +00004076
def test_bool_expand(self):
    """Expanded bool tensors with different contents must not compare equal.

    See https://github.com/pytorch/pytorch/issues/82663
    """
    column = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
    row = torch.tensor([0, 1], dtype=torch.bool, device='mps')
    same = torch.equal(column.expand(2, 2), row.expand(2, 2))
    self.assertFalse(same)
Nikita Shulgadcf51882022-08-03 14:54:47 +00004082
def test_int_expand(self):
    """Expanded int8 tensors with different contents must not compare equal."""
    column = torch.tensor([[1], [0]], dtype=torch.int8, device='mps')
    row = torch.tensor([0, 1], dtype=torch.int8, device='mps')
    same = torch.equal(column.expand(2, 2), row.expand(2, 2))
    self.assertFalse(same)
4087
def test_empty_neg(self):
    """Empty unary op should return a tensor of the same size."""
    empty = torch.tensor([[]], device='mps')
    negated = -empty
    self.assertEqual(empty, negated)
4093
def _test_unique_scalar_empty(self, dtype, device, f):
    """Check a unique-style function ``f`` on a 0-d tensor and on a zero-sized tensor.

    Args:
        dtype: dtype of the inputs and of the expected unique values.
        device: device on which inputs and expectations are created.
        f: callable accepting ``return_inverse`` / ``return_counts`` keywords
            and returning ``(unique, inverse, counts)``.
    """
    def check(inp, expected_unique, expected_inverse, expected_counts):
        unique, inverse, counts = f(inp, return_inverse=True, return_counts=True)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)

    # test scalar
    check(
        torch.tensor(0, dtype=dtype, device=device),
        torch.tensor([0], dtype=dtype, device=device),
        torch.tensor(0, device=device),
        torch.tensor([1], device=device),
    )

    # test zero sized tensor
    check(
        torch.zeros((0, 0, 3), dtype=dtype, device=device),
        torch.tensor([], dtype=dtype, device=device),
        torch.empty((0, 0, 3), dtype=torch.long, device=device),
        torch.tensor([], dtype=torch.long, device=device),
    )
4114
def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
    """Compare a unique-style function ``f`` on ``x`` against known expectations.

    Every combination of ``return_inverse`` / ``return_counts`` is tried,
    followed each time by the same input viewed as ``additional_shape``.
    ``device`` and ``dtype`` are accepted but not read by this helper.
    """
    def as_tuple(result):
        # A bare tensor is returned when no extra outputs were requested.
        return (result,) if isinstance(result, torch.Tensor) else result

    for return_inverse, return_counts in itertools.product([True, False], [True, False]):
        # test with expected
        outputs = as_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
        self.assertEqual(len(outputs), 1 + int(return_inverse) + int(return_counts))
        self.assertEqual(expected_unique, outputs[0])
        if return_inverse:
            self.assertEqual(expected_inverse, outputs[1])
        if return_counts:
            # counts come last, after the optional inverse
            self.assertEqual(expected_counts, outputs[1 + int(return_inverse)])

        # tests per-element unique on a higher rank tensor.
        reshaped = x.view(additional_shape)
        r_unique, r_inverse, r_counts = f(reshaped, return_inverse=True, return_counts=True)
        self.assertEqual(expected_unique, r_unique)
        self.assertEqual(expected_inverse.view(additional_shape), r_inverse)
        self.assertEqual(expected_counts, r_counts)
4139
def test_unique_all_dtypes(self, device="mps"):
    """Check sorted and unsorted ``unique`` for several dtypes on ``device``."""

    def ensure_tuple(result):
        return (result,) if isinstance(result, torch.Tensor) else result

    def helper(dtype):
        if dtype is torch.bool:
            values = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
            expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
            expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
            expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
        else:
            values = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
            expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
            expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
            expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)

        # Same data stored in every other slot of a buffer twice as long.
        strided = torch.empty(values.size(0) * 2, dtype=dtype, device=device)[::2].copy_(values)
        inputs = (values, strided)

        # test sorted unique
        sorted_fns = (
            lambda t, **kwargs: torch.unique(t, sorted=True, **kwargs),
            lambda t, **kwargs: t.unique(sorted=True, **kwargs),
        )
        for fn, inp in product(sorted_fns, inputs):
            self._test_unique_with_expects(device, dtype, fn, inp, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
            self._test_unique_scalar_empty(dtype, device, fn)

        # test unsorted unique
        unsorted_fns = (
            lambda t, **kwargs: torch.unique(t, sorted=False, **kwargs),
            lambda t, **kwargs: t.unique(sorted=False, **kwargs),
        )
        for fn, inp in product(unsorted_fns, inputs):
            self._test_unique_scalar_empty(dtype, device, fn)
            for return_inverse, return_counts in product((True, False), repeat=2):
                ret = ensure_tuple(fn(inp, return_inverse=return_inverse, return_counts=return_counts))
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                inp_list = inp.tolist()
                unique_list = ret[0].tolist()
                # Order is unspecified here, so compare after sorting.
                self.assertEqual(expected_unique.tolist(), sorted(unique_list))
                if return_inverse:
                    # The inverse must map every input position to its unique value.
                    for position, unique_position in enumerate(ret[1].tolist()):
                        self.assertEqual(inp_list[position], unique_list[unique_position])
                if return_counts:
                    counts_list = ret[1 + int(return_inverse)].tolist()
                    # Each reported count must equal the occurrences in the input.
                    for value, reported in zip(unique_list, counts_list):
                        self.assertEqual(reported, inp_list.count(value))

    for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]:
        helper(dtype)
4196
def test_unique(self):
    """``torch.unique`` on MPS matches CPU for various flag combinations."""

    def helper(x, return_inverse, return_counts):
        flags = {"return_inverse": return_inverse, "return_counts": return_counts}
        mps_x = x.detach().clone().to('mps')

        mps_result = torch.unique(mps_x, **flags)
        cpu_result = torch.unique(x, **flags)

        self.assertEqual(mps_result, cpu_result)

    helper(torch.tensor([1, 2, 4, 2, 1]), False, False)
    # Random inputs of length 10 for each flag combination.
    for return_inverse, return_counts in ((False, False), (True, False), (False, True), (True, True)):
        helper(torch.randint(3, (10, )), return_inverse, return_counts)
    # Single-element and empty inputs.
    for length in (1, 0):
        helper(torch.randint(3, (length, )), True, True)

    # Regression test for https://github.com/pytorch/pytorch/issues/104879
    pair = torch.arange(2, device="mps")
    self.assertEqual(pair.reshape(1, 1, 2).unique(), pair)
Kulin Sethfc596642023-01-04 22:15:13 +00004216
def test_unique_consecutive(self):
    """``torch.unique_consecutive`` on MPS matches CPU along dims 0 and 1."""

    def helper(x, dim, return_inverse, return_counts):
        kwargs = {"dim": dim, "return_inverse": return_inverse, "return_counts": return_counts}
        mps_x = x.detach().clone().to('mps')

        mps_result = torch.unique_consecutive(mps_x, **kwargs)
        cpu_result = torch.unique_consecutive(x, **kwargs)

        self.assertEqual(mps_result, cpu_result)

    # 1-D inputs
    helper(torch.tensor([1, 2, 4, 2, 1]), 0, False, False)
    one_d_flags = ((False, False), (True, False), (False, True), (True, True), (True, True))
    for return_inverse, return_counts in one_d_flags:
        helper(torch.randint(3, (10, )), 0, return_inverse, return_counts)
    helper(torch.randint(3, (1, )), 0, True, True)
    helper(torch.randint(3, (0, )), 0, True, True)

    fixed_rows = [[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]

    # 2-D inputs, reduced along dim 0
    helper(torch.tensor(fixed_rows), 0, False, False)
    helper(torch.tensor(fixed_rows), 0, True, True)
    for shape in ((20, 2), (1, 2), (0, 2)):
        helper(torch.randint(2, shape), 0, True, True)

    # 2-D inputs, reduced along dim 1
    helper(torch.tensor(fixed_rows), 1, False, False)
    helper(torch.tensor(fixed_rows), 1, True, True)
    for shape in ((2, 20), (2, 1), (2, 0)):
        helper(torch.randint(2, shape), 1, True, True)
4246
def test_cat_non_contiguous(self):
    """Concatenate non-contiguous slices along every dim on CPU and on MPS.

    See https://github.com/pytorch/pytorch/issues/85675
    """
    def rotate_subset(data, dim):
        front = data[:, :, :2, :]
        back = data[:, :, 2:, :]
        self.assertFalse(front.is_contiguous())
        self.assertFalse(back.is_contiguous())
        return torch.concat((front, back), dim=dim)

    non_bool_dtypes = (dtype for dtype in MPS_DTYPES if not dtype == torch.bool)
    for dtype in non_bool_dtypes:
        cpu_data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
        cpu_data = cpu_data.to(memory_format=torch.channels_last)
        mps_data = cpu_data.to("mps")
        self.assertEqual(cpu_data, mps_data)
        for dim in range(cpu_data.dim()):
            cpu_result = rotate_subset(cpu_data, dim)
            mps_result = rotate_subset(mps_data, dim)
            self.assertEqual(cpu_result, mps_result.to("cpu"))
            # TODO: enable memory format test
            # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
Nikita Shulga1367f242022-09-27 15:44:53 +00004268
def test_from_numpy_non_contiguous(self):
    """Build tensors from a column-sliced NumPy array on CPU and on MPS and compare.

    See https://github.com/pytorch/pytorch/issues/85967
    """
    sliced = np.arange(9).reshape(3, 3)[:, :2]
    on_cpu = torch.tensor(sliced, device="cpu")
    on_mps = torch.tensor(sliced, device="mps")
    self.assertEqual(on_cpu, on_mps.to("cpu"))
4275
def test_copy_non_contiguous(self):
    """Copy permuted and strided tensors between CPU and MPS.

    See https://github.com/pytorch/pytorch/issues/86954
    """
    # Permuted tensor: the copy on MPS is expected to stay non-contiguous.
    permuted = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
    self.assertFalse(permuted.is_contiguous())
    permuted_mps = permuted.to('mps')
    self.assertFalse(permuted_mps.is_contiguous())
    self.assertEqual(permuted, permuted_mps.to('cpu'))

    # Permuted and strided slice.
    strided = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
    self.assertEqual(strided, strided.to('mps').to('cpu'))

    # Assignment into a permuted, strided destination view.
    dst_cpu = torch.full((4, 4, 4, 4), 13, device="cpu")
    dst_mps = torch.full((4, 4, 4, 4), 13, device="mps")
    src = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
    dst_cpu.permute(3, 2, 1, 0)[1::, ::2] = src
    # As dst_mps is on MPS and src on CPU, this dispatches to a copy operator
    dst_mps.permute(3, 2, 1, 0)[1::, ::2] = src
    self.assertEqual(dst_cpu, dst_mps.to('cpu'))
4295
Li-Huai (Allan) Linb7c2a652023-02-28 05:24:31 +00004296 # See https://github.com/pytorch/pytorch/issues/95417
def test_copy_storage_offset(self):
    """Slice assignment into the middle of a tensor matches CPU across dtypes and devices."""
    devices = ("cpu", "mps")
    cpu_buf, mps_buf = (torch.zeros(5, device=dev, dtype=torch.float32) for dev in devices)
    cpu_update, mps_update = (torch.tensor([1, 1], device=dev, dtype=torch.int64) for dev in devices)
    window = slice(2, 4)

    cpu_buf[window] = cpu_update
    mps_buf[window] = mps_update  # implicit type casting and copy
    self.assertEqual(cpu_buf, mps_buf)

    cpu_buf[window] = mps_update  # implicit device moving and copy
    self.assertEqual(cpu_buf, mps_buf)
4308
def test_copy_broadcasting(self):
    """copy_ with a broadcast source and mixed dtypes gives the same result on MPS as on CPU."""
    def helper(src_shape, dst_shape, src_dtype, dst_dtype):
        cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)
        cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
        # Snapshot the operands on MPS *before* the in-place CPU copy_.
        # Otherwise mps_dst would already hold the expected values and a
        # broken (no-op) MPS copy_ would still pass the comparison below.
        mps_src = cpu_src.to("mps")
        mps_dst = cpu_dst.to("mps")
        cpu_result = cpu_dst.copy_(cpu_src)
        mps_result = mps_dst.copy_(mps_src)
        self.assertEqual(cpu_result, mps_result)

    test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]

    # (source shape, destination shape) pairs, including empty destinations.
    shape_pairs = (
        ((2, 1), (2, 3)),
        ((2, 1), (2, 2)),
        ((3, 1, 4, 1), (3, 4, 4, 5)),
        ((3,), (2, 3)),
        ((2,), (2, 2)),
        ((4, 1, 5), (3, 4, 4, 5)),
        ((4, 1, 5), (4, 0, 5)),
        ((1, 5), (4, 0, 5)),
        ((3, 1, 0), (3, 5, 0)),
        ((0, 1, 0), (0, 5, 0)),
    )

    for src_dtype, dst_dtype in itertools.product(test_dtypes, test_dtypes):
        for src_shape, dst_shape in shape_pairs:
            helper(src_shape, dst_shape, src_dtype, dst_dtype)
    # Regression test for https://github.com/pytorch/pytorch/issues/107867
    self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)
Peter Stefekc9c2b142023-08-03 04:03:28 +00004334
Lukas Hoenig81a8fdc2022-11-17 04:54:23 +00004335 # See https://github.com/pytorch/pytorch/pull/84742
4336 # and https://github.com/pytorch/pytorch/pull/78319
def test_binops_dtype_precedence(self):
    """Check dtype precedence (casting order) of binary ops on MPS against CPU.

    Every pair of dtypes is combined with add/sub/mul/div over several
    operand layouts: 0-d tensors, 1-element vectors and torch.full tensors.
    """
    # Example values for all dtypes supported on the MPS backend
    sample_vals = {
        torch.bool: [False, True],
        torch.int16: [-15, 0, 1, 10],
        torch.int32: [-376, 0, 1, 13],
        torch.int64: [-8, 0, 1, 77],
        torch.float16: [-234.5, 0.0, 1.0, 2.0],
        torch.float32: [-1.0, 0.0, 0.1, 111.99],
    }
    full_shape = (10,)

    def as_scalar(val, dtype, device):
        return torch.tensor(val, dtype=dtype, device=device)

    def as_vector(val, dtype, device):
        return torch.tensor([val], dtype=dtype, device=device)

    def as_full(val, dtype, device):
        return torch.full(full_shape, val, dtype=dtype, device=device)

    # (left operand builder, right operand builder) for each comparison.
    operand_layouts = (
        (as_scalar, as_scalar),
        (as_vector, as_vector),
        (as_scalar, as_vector),
        (as_vector, as_scalar),
        # Tensors created with torch.full
        (as_full, as_scalar),
        (as_scalar, as_full),
        (as_scalar, as_full),  # this layout is checked twice
    )

    # Test all combinations of dtypes, operations, dimensionality
    for dtype1, dtype2, binop in itertools.product(
            sample_vals, sample_vals, ['add', 'sub', 'mul', 'div']):
        # bool minus bool is generally unsupported, so skip
        if binop == 'sub' and torch.bool in (dtype1, dtype2):
            continue
        for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
            for make_lhs, make_rhs in operand_layouts:
                mps_result, cpu_result = (
                    getattr(make_lhs(val1, dtype1, dev), binop)(make_rhs(val2, dtype2, dev))
                    for dev in ('mps', 'cpu')
                )
                self.assertEqual(mps_result, cpu_result)
Nikita Shulgaae62cf72022-10-21 14:10:05 +00004397
def test_nansum(self):
    """Compare torch.nansum on MPS with CPU, including the out= variant.

    Runs over float and integer dtypes, contiguous and non-contiguous
    inputs, and reductions over dim 0, dim 1 and all elements (dim=None).
    """
    def check(dtype, noncontiguous, dim):
        scalar_zero = torch.zeros((), dtype=dtype)

        # Randomly scale the values
        bound = random.randint(10, 100)
        cpu_in: torch.Tensor = make_tensor(
            (5, 5), dtype=dtype, device='cpu',
            low=-bound, high=bound, noncontiguous=noncontiguous)

        if not dtype.is_floating_point:
            # No NaNs are injected for integer dtypes; sum the input as-is.
            cpu_reference_in = cpu_in
        else:
            # Entries below 20% of the scale become NaN in the input and
            # zero in the reference, so a plain sum gives the expected value.
            to_nan = cpu_in < (0.2 * bound)
            cpu_reference_in = torch.where(to_nan, scalar_zero, cpu_in)
            cpu_in[to_nan] = np.nan

        mps_in = cpu_in.to('mps')
        mps_out = torch.empty(0, dtype=dtype, device='mps')
        cpu_out = torch.empty(0, dtype=dtype)
        reduce_kwargs = {} if dim is None else {"dim": dim}
        expected = torch.sum(cpu_reference_in, **reduce_kwargs)

        # Sanity check on CPU
        self.assertEqual(expected, torch.nansum(cpu_in, **reduce_kwargs))

        # Test MPS
        mps_result = torch.nansum(mps_in, **reduce_kwargs)
        # Test out= variant
        torch.nansum(mps_in, out=mps_out, **reduce_kwargs)
        torch.nansum(cpu_in, out=cpu_out, **reduce_kwargs)
        self.assertEqual(expected, mps_result)
        self.assertEqual(cpu_out, mps_out)

    for dtype in (torch.float16, torch.float32, torch.int32, torch.int64):
        for noncontiguous in (True, False):
            for dim in (0, 1, None):
                with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
                    check(dtype, noncontiguous, dim)
4442
def test_cumsum_all_dtypes(self):
    """cumsum with an explicit output dtype matches CPU for each tested dtype.

    int64 may be rejected by the backend; in that case the error message
    must be the documented "Support has been added in macOS 13.3" one.
    """
    def helper(dtype):
        t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
        t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

        a = t.cumsum(0, dtype=dtype)
        a_cpu = t_cpu.cumsum(0, dtype=dtype)

        self.assertEqual(a.cpu(), a_cpu)

    # Plain loop: helper is called for its assertions, not for a result.
    for dtype in (torch.int8, torch.int16, torch.int32, torch.float32):
        helper(dtype)

    try:
        helper(torch.int64)
    except RuntimeError as e:
        # Only backend errors are inspected here; an AssertionError raised by
        # the comparison inside helper propagates with its own message.
        self.assertEqual(str(e), "MPS does not support cumsum_out_mps op with int64 input."
                                 " Support has been added in macOS 13.3")
Denis Vieriu92d8c4b2023-02-10 17:40:29 +00004460
def test_cumsum_bool(self):
    """cumsum over a large (2**16 element) bool tensor matches CPU."""
    flags = torch.ones(2**16, dtype=torch.bool)
    expected = flags.cumsum(0)
    actual = flags.to("mps").cumsum(0)

    self.assertEqual(expected, actual)
4467
def test_cumsum_minus_one_axis(self):
    """cumsum along the last axis (dim=-1) matches CPU for float and integer inputs."""
    def helper(dtype):
        # Test with axis -1
        if dtype == torch.float32:
            cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
        else:
            # Build the input in the requested dtype. Previously this branch
            # always created float32 data, so the integer dtypes listed below
            # were never actually exercised.
            cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        cpu_y = cpu_x.cumsum(-1)
        y = x.cumsum(-1)

        self.assertEqual(y, cpu_y)

    # Plain loop: helper is called for its assertions, not for a result.
    for dtype in (torch.float32, torch.int16, torch.int32, torch.uint8):
        helper(dtype)
Nikita Shulgabdd0a4a2022-08-01 19:42:24 +00004484
def test_cumprod_all_dtypes(self):
    """cumprod with an explicit output dtype matches CPU for each tested dtype.

    int64 may be rejected by the backend; in that case the error message
    must be the documented "Support has been added in macOS 13.3" one.
    """
    def helper(dtype):
        t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
        t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

        a = t.cumprod(0, dtype=dtype)
        a_cpu = t_cpu.cumprod(0, dtype=dtype)

        self.assertEqual(a.cpu(), a_cpu)

    # Plain loop: helper is called for its assertions, not for a result.
    for dtype in (torch.int8, torch.int16, torch.int32, torch.float32):
        helper(dtype)

    try:
        helper(torch.int64)
    except RuntimeError as e:
        # Only backend errors are inspected here; an AssertionError raised by
        # the comparison inside helper propagates with its own message.
        self.assertEqual(str(e), "MPS does not support cumprod_out_mps op with int64 input."
                                 " Support has been added in macOS 13.3")
4502
def test_cumprod_minus_one_axis(self):
    """cumprod along the last axis (dim=-1) matches CPU for float and integer inputs."""
    def helper(dtype):
        # Test with axis -1
        if dtype == torch.float32:
            cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
        else:
            # Build the input in the requested dtype. Previously this branch
            # always created float32 data, so the integer dtypes listed below
            # were never actually exercised.
            cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        cpu_y = cpu_x.cumprod(-1)
        y = x.cumprod(-1)

        self.assertEqual(y, cpu_y)

    # Plain loop: helper is called for its assertions, not for a result.
    for dtype in (torch.float32, torch.int16, torch.int32, torch.uint8):
        helper(dtype)
4519
def test_median_int16(self):
    """torch.median over all elements of an int16 tensor matches CPU."""
    def compare(shape, dtype):
        cpu_values = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
        mps_values = cpu_values.detach().clone().to('mps')

        mps_median = torch.median(mps_values)
        cpu_median = torch.median(cpu_values)
        self.assertEqual(mps_median, cpu_median)

    compare((2, 8, 4, 5), torch.int16)
4530
def test_activation_checkpoint_does_not_error(self):
    """Activation checkpointing runs forward and backward on MPS in both reentrant modes."""
    from torch.utils.checkpoint import checkpoint

    def fn(x):
        return x.sin().cos().exp()

    for use_reentrant in (True, False):
        leaf = torch.tensor(1., device="mps", requires_grad=True)
        result = checkpoint(fn, leaf, use_reentrant=use_reentrant)
        result.backward()
4542
def test_as_strided(self):
    """torch.as_strided views on MPS match CPU, with and without a storage offset."""
    size, stride = (2, 2), (1, 2)

    cpu_base = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], device='cpu')
    mps_ones = torch.tensor([[1.0, 1.0], [1.0, 1.0]], device='mps')
    mps_base = cpu_base.detach().clone().to('mps').requires_grad_()
    cpu_view = torch.as_strided(cpu_base, size, stride)
    mps_view = torch.as_strided(mps_base, size, stride)
    self.assertEqual(mps_view, cpu_view)
    # The strided views must also behave the same as operands.
    cpu_sum = cpu_view + mps_ones.to('cpu')
    mps_sum = mps_view + mps_ones
    self.assertEqual(cpu_sum, mps_sum)

    # test with storage offsets
    cpu_base = torch.rand(3, 3, device='cpu')
    mps_base = cpu_base.to('mps')
    cpu_at_0 = torch.as_strided(cpu_base, size, stride, 0)
    mps_at_0 = torch.as_strided(mps_base, size, stride, 0)
    cpu_at_1 = torch.as_strided(cpu_base, size, stride, 1)
    mps_at_1 = torch.as_strided(mps_base, size, stride, 1)
    cpu_diff = cpu_at_0 - cpu_at_1
    mps_diff = mps_at_0 - mps_at_1
    self.assertEqual(cpu_diff, mps_diff)
Kulin Sethe011a8e2022-05-13 18:28:53 +00004566
def test_unfold(self):
    """Tensor.unfold with size 2 and step 1 over a 1-d range matches CPU."""
    cpu_seq = torch.arange(1., 8)
    mps_seq = torch.arange(1., 8, device="mps")

    cpu_windows = cpu_seq.unfold(0, 2, 1)
    mps_windows = mps_seq.unfold(0, 2, 1)

    self.assertEqual(cpu_windows, mps_windows)
4575
def test_unfold_all_devices_and_dtypes(self):
    """unfold on an empty MPS tensor yields the expected shape for each dtype."""
    expected_shape = (0, 1, 1, 0, 3)
    for dt in (torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8):
        empty = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
        self.assertEqual(expected_shape, empty.unfold(2, 3, 2).shape)
4581
def test_unfold_scalars(self):
    """unfold on a 0-d MPS tensor returns a 1-d tensor of the requested size."""
    scalar = torch.tensor(0.5, device="mps")
    # unfold on a 0-dimensional tensor should always return a 1-dimensional
    # tensor of shape [size] (i.e., the second parameter to unfold)
    empty = torch.empty(0, device="mps")

    self.assertEqual(empty, scalar.unfold(0, 0, 1))
    self.assertEqual(torch.empty(0, device="mps"), scalar.unfold(0, 0, 2))
    self.assertEqual(torch.tensor([0.5], device="mps"), scalar.unfold(0, 1, 1))
Kulin Sethe011a8e2022-05-13 18:28:53 +00004590
def test_bincount_simple(self):
    """bincount on MPS matches CPU, both unweighted and weighted."""
    mps_values = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
    cpu_values = mps_values.to("cpu")
    mps_weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
    cpu_weights = mps_weights.to("cpu")

    # Function form, no weights.
    self.assertEqual(torch.bincount(mps_values), torch.bincount(cpu_values))

    # Method form, with weights.
    self.assertEqual(mps_values.bincount(mps_weights), cpu_values.bincount(cpu_weights))
4604
def test_bincount_reduction(self):
    """Exercise bincount on MPS: argument validation, empty inputs, weights,
    minlength, non-contiguous operands and larger inputs."""
    device = "mps"
    # negative input throws
    with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
        torch.bincount(torch.tensor([1, -1], device=device, dtype=torch.int32))
    # n-d input, with n > 1 throws
    with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
        torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
    # minlength < 0 throws
    with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
        torch.bincount(torch.tensor([1, 3], device=device),
                       torch.tensor([.2, .2], device=device),
                       minlength=-1)
    # n-d weights, with n > 1 throws
    with self.assertRaisesRegex(RuntimeError, '1-d'):
        torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                       torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float))
    # input and weights dim mismatch
    with self.assertRaisesRegex(RuntimeError, 'same length'):
        torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                       torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float))
    # 1-d input with no elements and default minlength
    self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
                     torch.zeros(0, dtype=torch.long, device=device))
    # 1-d input with no elements and specified minlength
    self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
                     torch.zeros(10, dtype=torch.long, device=device))

    # test tensor method without weights
    long_counts = torch.tensor(
        [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
    self.assertEqual(
        torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
        long_counts)
    # test avoiding overflow for uint8 (#76979)
    count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
    count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
    self.assertEqual(count_uint8, count_int16)
    # test minlength functionality
    int_counts = torch.bincount(
        torch.tensor([1, 1, 1, 1], device=device, dtype=torch.int32), minlength=5)
    self.assertEqual(
        torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
        int_counts)
    # test weights
    byte_counts = torch.bincount(
        torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
        torch.tensor([.1, .2, .3, .4, .5], device=device))
    self.assertEqual(
        torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
    byte_counts = torch.bincount(
        torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
        torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
    self.assertEqual(
        torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), byte_counts)
    # test non-contiguous inputs and weights
    inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device, dtype=torch.int32)
    weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
    for i in [0, 1]:
        # Use unittest assertions rather than bare `assert`, which is
        # stripped when Python runs with -O.
        self.assertFalse(inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous")
        self.assertFalse(weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous")
    # inputs are non-contiguous but weights are contiguous
    self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
    # inputs and weights are non-contiguous
    self.assertEqual(
        inputs[:, 1].bincount(weights[:, 1]),
        torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
    # weights are non-contiguous but inputs are contiguous
    self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
                     torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))

    # test bincount on non-contiguous slices
    all0s = torch.zeros((32, 2), dtype=torch.int32, device=device)
    self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))

    all1s = torch.ones((32, 2), dtype=torch.int32, device=device)
    self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))

    # test large number of bins - global memory use
    big_exp = torch.zeros(100, device=device)
    big_exp[-1] = 50.0
    big_w = torch.tensor([.5] * 100, device=device)
    big_out = torch.tensor([99] * 100, device=device, dtype=torch.int32).bincount(big_w)
    self.assertEqual(big_exp, big_out)
    # test large input size
    big_exp = torch.zeros(2, device=device, dtype=torch.int64)
    big_exp[1] = 10
    big_out = torch.ones(10, dtype=torch.int8, device=device).bincount()
    self.assertEqual(big_exp, big_out)
4694
def test_bincount(self):
    """bincount on 5000 random MPS values matches CPU across value ranges."""
    device = "mps"
    input_size = (5000,)
    mps_weights = torch.randn(input_size, dtype=torch.float, device=device)
    cpu_weights = mps_weights.cpu()

    # (exclusive upper bound of the random values, input dtype)
    for high, dtype in ((50, torch.int8), (500, torch.int32), (2000, torch.int32)):
        values = torch.randint(high, input_size, dtype=dtype, device=device)
        self.assertEqual(values.cpu().bincount(), values.bincount())
        self.assertEqual(values.cpu().bincount(cpu_weights), values.bincount(mps_weights))

    # A single large value combined with a large minlength.
    values = torch.zeros([10], dtype=torch.int32, device=device)
    values[0] = 35488
    counted = values.bincount(minlength=65536)
    self.assertEqual(torch.sum(counted), 10)
4717
def test_sum_backward(self):
    """torch.sum forward value and input gradient on MPS match CPU."""
    def run(n, c):
        # n and c are accepted but not used by this check.
        matrix = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        cpu_leaf = torch.tensor(matrix, device='cpu', requires_grad=True)
        mps_leaf = cpu_leaf.detach().clone().to('mps').requires_grad_()

        mps_total = torch.sum(mps_leaf)
        cpu_total = torch.sum(cpu_leaf)

        mps_total.backward()
        cpu_total.backward()
        self.assertEqual(mps_total, cpu_total)
        self.assertEqual(mps_leaf.grad, cpu_leaf.grad)

    run(3, 3)
4733
qqaatwff44bfa2022-06-24 17:18:30 +00004734 # L1 loss
def test_l1_loss(self):
    """L1Loss forward (and backward for reduced outputs) on MPS agrees with CPU."""
    def compare(shape, reduction):
        criterion = torch.nn.L1Loss(reduction=reduction)

        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        cpu_target = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()
        mps_target = cpu_target.detach().clone().to('mps')

        # forward pass
        cpu_loss = criterion(cpu_input, cpu_target)
        mps_loss = criterion(mps_input, mps_target)
        self.assertEqual(cpu_loss, mps_loss)

        # The backward pass is only checked for reduced outputs.
        if reduction == 'none':
            return
        # chose 2 just to make the grad_output > 1 in backward pass
        cpu_loss.backward(gradient=torch.full_like(cpu_loss, 2))
        mps_loss.backward(gradient=torch.full_like(mps_loss, 2))
        self.assertEqual(cpu_input.grad, mps_input.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
    )
    for shape, reduction in cases:
        compare(shape, reduction)
4762
Kulin Sethe011a8e2022-05-13 18:28:53 +00004763 # Mean Squared Error
def test_mse_loss(self):
    """MSELoss forward (and backward for reduced outputs) on MPS agrees with CPU."""
    def compare(shape, reduction):
        criterion = torch.nn.MSELoss(reduction=reduction)

        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        cpu_target = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()
        mps_target = cpu_target.detach().clone().to('mps')

        # forward pass
        cpu_loss = criterion(cpu_input, cpu_target)
        mps_loss = criterion(mps_input, mps_target)
        self.assertEqual(cpu_loss, mps_loss)

        # The backward pass is only checked for reduced outputs.
        if reduction == 'none':
            return
        # chose 2 just to make the grad_output > 1 in backward pass
        cpu_loss.backward(gradient=torch.full_like(cpu_loss, 2))
        mps_loss.backward(gradient=torch.full_like(mps_loss, 2))
        self.assertEqual(cpu_input.grad, mps_input.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
    )
    for shape, reduction in cases:
        compare(shape, reduction)
4791
def test_mse_loss_strided_output(self):
    """Unreduced MSELoss on a permuted, sliced Conv1d output matches CPU."""
    # https://github.com/pytorch/pytorch/issues/124621
    criterion = nn.MSELoss(reduction='none')
    cpu_model = nn.Sequential(
        nn.Conv1d(3, 3, 1),
    )
    mps_model = copy.deepcopy(cpu_model).to("mps")

    cpu_in = torch.randn(128, 10, 3).permute(0, 2, 1)

    # Permute twice on MPS, ending up with the same dim order as cpu_in.
    mps_in = cpu_in.detach().clone().to("mps").permute(0, 2, 1)
    mps_in = mps_in.permute(0, 2, 1)

    cpu_pred = cpu_model(cpu_in)
    mps_pred = mps_model(mps_in)

    # Permute back and keep the first 5 entries along dim 1.
    cpu_pred = cpu_pred.permute(0, 2, 1)[:, :5, :]
    mps_pred = mps_pred.permute(0, 2, 1)[:, :5, :]

    cpu_target = torch.randn(128, 5, 3)
    mps_target = cpu_target.detach().clone().to("mps")

    cpu_loss = criterion(cpu_pred, cpu_target)
    mps_loss = criterion(mps_pred, mps_target)
    self.assertEqual(cpu_loss, mps_loss)
4818
# Binary Cross Entropy
def test_bce_loss_simple(self):
    """BCELoss forward (and backward for reduced outputs) on MPS agrees with CPU."""
    def compare(shape, reduction):
        criterion = torch.nn.BCELoss(reduction=reduction)

        # input and target must be within [0..1]
        input_np = np.random.random_sample(size=shape).astype(np.float32)
        target_np = np.random.random_sample(size=shape).astype(np.float32)
        cpu_input = torch.tensor(input_np, device='cpu', dtype=torch.float, requires_grad=True)
        cpu_target = torch.tensor(target_np, device='cpu', dtype=torch.float, requires_grad=False)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()
        mps_target = cpu_target.detach().clone().to('mps')

        # forward pass
        cpu_loss = criterion(cpu_input, cpu_target)
        mps_loss = criterion(mps_input, mps_target)
        self.assertEqual(cpu_loss, mps_loss)

        # The backward pass is only checked for reduced outputs.
        if reduction == 'none':
            return
        # chose 0.6 just to have the grad_output != 1
        cpu_loss.backward(gradient=torch.full_like(cpu_loss, 0.6))
        mps_loss.backward(gradient=torch.full_like(mps_loss, 0.6))
        self.assertEqual(cpu_input.grad, mps_input.grad)

    cases = (
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # verify if changes in shape would cause cached graph lookup problems
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
        ([1, 1, 32, 32], 'mean'),
    )
    for shape, reduction in cases:
        compare(shape, reduction)
4851
def test_bce_loss_always_nonnegative(self):
    """BCELoss is never negative when input equals target at all-ones or all-zeros."""
    for factory in (torch.ones, torch.zeros):
        target = factory(5, device='mps')
        prediction = factory(5, device='mps')
        negatives = (nn.BCELoss()(prediction, target) < 0).sum()
        self.assertEqual(negatives, 0)
4860
def test_bce_loss_size_mismatch(self):
    """BCELoss raises ValueError when target shape (25, 1) differs from input shape (25,)."""
    criterion = nn.BCELoss()
    prediction = torch.rand(25, device='mps')
    target = torch.rand(25, 1, device='mps')
    with self.assertRaisesRegex(ValueError, r'Using a target size \('):
        criterion(prediction, target)
4867
def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
    """BCEWithLogitsLoss equals sigmoid + BCELoss, in value and input gradient, on a 1024x256 tensor."""
    rows, cols = 1024, 256
    target = torch.rand(rows, cols, device='mps')

    for reduction in ['none', 'mean', 'sum']:
        # Two independent leaves holding the same logits.
        sig_leaf = torch.rand(rows, cols, device='mps') - 0.5
        logits_leaf = sig_leaf.clone().detach()
        sig_leaf.requires_grad_(True)
        logits_leaf.requires_grad_(True)

        weight = torch.rand(cols, device='mps')

        bce = nn.BCELoss(weight, reduction=reduction)
        bce_with_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)
        loss_sig = bce(torch.sigmoid(sig_leaf), target)
        loss_logits = bce_with_logits(logits_leaf, target)

        self.assertEqual(loss_logits, loss_sig)

        if reduction != 'none':
            loss_sig.backward()
            loss_logits.backward()
        else:
            # Unreduced losses need an explicit upstream gradient.
            upstream = torch.rand(rows, cols, device='mps')
            loss_sig.backward(upstream)
            loss_logits.backward(upstream)

        self.assertEqual(sig_leaf.grad, logits_leaf.grad)
4899
def test_bce_with_logits_has_correct_grad_at_zero(self):
    """At zero logits and zero targets, the summed BCEWithLogitsLoss gradient is 0.5 everywhere."""
    logits = torch.zeros(3, 1, requires_grad=True, device='mps')
    target = torch.zeros(3, 1, device='mps')
    criterion = nn.BCEWithLogitsLoss(reduction='sum')
    criterion(logits, target).backward()
    self.assertEqual(logits.grad, torch.full((3, 1), 0.5, device='mps'))
4906
def test_bce_with_logits_broadcasts_weights(self):
    """A broadcastable weight gives the same BCEWithLogitsLoss as its expanded copy."""
    target = torch.rand(16, 4, device='mps')
    logits = torch.rand(16, 4, device='mps') - 0.5

    # Per-column weight of shape (4,), then per-row weight of shape (16, 1).
    for weight_shape in ((4,), (16, 1)):
        weight = torch.rand(*weight_shape, device='mps')
        loss_broadcast = nn.BCEWithLogitsLoss(weight)(logits, target)

        expanded = weight.expand(16, 4).contiguous()
        loss_expanded = nn.BCEWithLogitsLoss(expanded)(logits, target)

        self.assertEqual(loss_broadcast, loss_expanded)
4926
def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
    """An all-ones pos_weight gives the same loss as passing no pos_weight."""
    target = torch.rand(64, 4, device='mps')
    logits = torch.rand(64, 4, device='mps') - 0.5
    unit_pos_weight = torch.ones(64, 4, device='mps')

    loss_default = nn.BCEWithLogitsLoss()(logits, target)
    loss_weighted = nn.BCEWithLogitsLoss(pos_weight=unit_pos_weight)(logits, target)
    self.assertEqual(loss_default, loss_weighted)
4934
def test_bce_with_logits_broadcasts_pos_weights(self):
    """pos_weight of shape (4,), (1, 4) or (64, 4) produces the same loss."""
    target = torch.rand(64, 4, device='mps')
    logits = torch.rand(64, 4, device='mps') - 0.5
    pos_weight = torch.rand(4, device='mps')

    def loss_with(weight):
        return nn.BCEWithLogitsLoss(pos_weight=weight)(logits, target)

    baseline = loss_with(pos_weight)
    as_row = loss_with(pos_weight.expand(1, 4))
    as_full = loss_with(pos_weight.expand(64, 4))

    self.assertEqual(baseline, as_row)
    self.assertEqual(baseline, as_full)
4949
def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
    """With a unit pos_weight, the summed-loss gradient at zero logits is 0.5 everywhere."""
    logits = torch.zeros(3, 1, requires_grad=True, device='mps')
    target = torch.zeros(3, 1, device='mps')
    unit_pos_weight = torch.ones(3, 1, device='mps')
    criterion = nn.BCEWithLogitsLoss(pos_weight=unit_pos_weight, reduction='sum')
    criterion(logits, target).backward()
    self.assertEqual(logits.grad, torch.full((3, 1), 0.5, device='mps'))
4958
def test_bce_with_logits_stability(self):
    """BCEWithLogitsLoss stays finite for a very negative logit (-120), with and without pos_weight."""
    logits = torch.tensor([0., -120.], device='mps')
    target = torch.tensor([0., 1.], device='mps')
    pos_weight = torch.tensor([1., 1.], device='mps')

    plain_loss = nn.BCEWithLogitsLoss()(logits, target)
    self.assertTrue(torch.isfinite(plain_loss).all().item())

    weighted_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, target)
    self.assertTrue(torch.isfinite(weighted_loss).all().item())
4969
def test_bce_loss_broadcasts_weights(self):
    """A broadcastable weight gives the same BCELoss as its expanded copy."""
    sigmoid = nn.Sigmoid()
    target = torch.rand(16, 4, device='mps')
    logits = torch.rand(16, 4, device='mps') - 0.5

    # Per-column weight of shape (4,), then per-row weight of shape (16, 1).
    for weight_shape in ((4,), (16, 1)):
        weight = torch.rand(*weight_shape, device='mps')
        loss_broadcast = nn.BCELoss(weight)(sigmoid(logits), target)

        expanded = weight.expand(16, 4).contiguous()
        loss_expanded = nn.BCELoss(expanded)(sigmoid(logits), target)

        self.assertEqual(loss_broadcast, loss_expanded)
Kulin Sethe011a8e2022-05-13 18:28:53 +00004990
def test_cross_entropy_loss(self):
    """CrossEntropyLoss forward and backward run on float16 MPS predictions without raising."""
    # Regression test for https://github.com/pytorch/pytorch/issues/116095
    criterion = nn.CrossEntropyLoss()
    logits = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
    labels = torch.ones(3, dtype=torch.long, device='mps')
    criterion(logits, labels).backward()
4998
def test_log_softmax(self):
    """F.log_softmax over dim 0 matches CPU in output and input gradient."""
    data = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_leaf = torch.tensor(data, device='cpu', requires_grad=True)
    mps_leaf = torch.tensor(data, device='mps', requires_grad=True)

    cpu_out = F.log_softmax(cpu_leaf, dim=0)
    mps_out = F.log_softmax(mps_leaf, dim=0)
    self.assertEqual(cpu_out, mps_out.to('cpu'))

    # Backpropagate an all-ones upstream gradient through both outputs.
    cpu_upstream = torch.ones_like(cpu_out)
    mps_upstream = torch.ones_like(cpu_out).to('mps')
    cpu_out.backward(gradient=cpu_upstream)
    mps_out.backward(gradient=mps_upstream)

    self.assertEqual(cpu_leaf.grad, mps_leaf.grad.to('cpu'))
5015
def test_log_softmax_large_numbers(self):
    """Compare F.log_softmax (dim=-1) between CPU and MPS on inputs with magnitudes up to 1e6."""
    magnitudes = [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0]
    # One row of growing positive values, one row of their negations.
    values = [magnitudes, [-m for m in magnitudes]]
    x_cpu = torch.tensor(values, device='cpu', requires_grad=True)
    x_mps = torch.tensor(values, device='mps', requires_grad=True)

    # Forward pass
    out_cpu = F.log_softmax(x_cpu, dim=-1)
    out_mps = F.log_softmax(x_mps, dim=-1)
    self.assertEqual(out_cpu, out_mps.to('cpu'))

    # Backward pass with an all-ones upstream gradient on both devices
    upstream_cpu = torch.ones_like(out_cpu)
    upstream_mps = torch.ones_like(out_cpu).to('mps')
    out_cpu.backward(gradient=upstream_cpu)
    out_mps.backward(gradient=upstream_mps)

    self.assertEqual(x_cpu.grad, x_mps.grad.to('cpu'))
5035
def test_eq(self):
    """Compare torch.eq between CPU and MPS on two float tensors that differ in two positions."""
    lhs_values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    rhs_values = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]

    actual = torch.eq(
        torch.tensor(lhs_values, device='mps'),
        torch.tensor(rhs_values, device='mps'),
    )
    expected = torch.eq(
        torch.tensor(lhs_values, device='cpu'),
        torch.tensor(rhs_values, device='cpu'),
    )

    self.assertEqual(expected, actual.to('cpu'))
5047
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_signed_vs_unsigned_comparison(self):
    """Compare uint8 tensors against the scalar -1 with ==, > and < and check CPU and MPS agree."""
    raw = (-1, 2, 3)
    unsigned_cpu = torch.tensor(raw, device='cpu', dtype=torch.uint8)
    unsigned_mps = torch.tensor(raw, device='mps', dtype=torch.uint8)
    # in the comparison of signed vs. unsigned we should always cast to unsigned
    comparisons = (
        lambda t: t == -1,
        lambda t: t > -1,
        lambda t: t < -1,
    )
    for compare in comparisons:
        self.assertEqual(compare(unsigned_cpu), compare(unsigned_mps))
5056
def test_eq_int64(self):
    """Compare torch.eq between CPU and MPS on tensors built from Python int lists."""
    lhs_values = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
    rhs_values = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]

    actual = torch.eq(
        torch.tensor(lhs_values, device='mps'),
        torch.tensor(rhs_values, device='mps'),
    )
    expected = torch.eq(
        torch.tensor(lhs_values, device='cpu'),
        torch.tensor(rhs_values, device='cpu'),
    )

    self.assertEqual(expected, actual.to('cpu'))
5068
def test_ne(self):
    """Compare torch.ne of two random float tensors between CPU and MPS."""
    shape = (2, 3, 4, 5)
    lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    lhs_mps = lhs_cpu.detach().clone().to('mps')
    rhs_mps = rhs_cpu.detach().clone().to('mps')

    actual = torch.ne(lhs_mps, rhs_mps)
    expected = torch.ne(lhs_cpu, rhs_cpu)

    self.assertEqual(expected, actual.to('cpu'))
5081
def test_ne_scalar(self):
    """Compare torch.ne of a random float tensor against the scalar 0.0 between CPU and MPS."""
    sample_cpu = torch.randn((2, 3, 4, 5), device='cpu', dtype=torch.float)
    sample_mps = sample_cpu.detach().clone().to('mps')

    actual = torch.ne(sample_mps, 0.0)
    expected = torch.ne(sample_cpu, 0.0)

    self.assertEqual(expected, actual.to('cpu'))
5092
def test_lt(self):
    """Compare torch.lt of two random float tensors between CPU and MPS."""
    shape = (2, 3, 4, 5)
    lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    lhs_mps = lhs_cpu.detach().clone().to('mps')
    rhs_mps = rhs_cpu.detach().clone().to('mps')

    actual = torch.lt(lhs_mps, rhs_mps)
    expected = torch.lt(lhs_cpu, rhs_cpu)

    self.assertEqual(expected, actual.to('cpu'))
5105
def test_lt_scalar(self):
    """Compare torch.lt of a random float tensor against the scalar 0.0 between CPU and MPS."""
    sample_cpu = torch.randn((2, 3, 4, 5), device='cpu', dtype=torch.float)
    sample_mps = sample_cpu.detach().clone().to('mps')

    actual = torch.lt(sample_mps, 0.0)
    expected = torch.lt(sample_cpu, 0.0)

    self.assertEqual(expected, actual.to('cpu'))
5116
def test_le(self):
    """Compare torch.le of two random float tensors between CPU and MPS."""
    shape = (2, 3, 4, 5)
    lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    lhs_mps = lhs_cpu.detach().clone().to('mps')
    rhs_mps = rhs_cpu.detach().clone().to('mps')

    actual = torch.le(lhs_mps, rhs_mps)
    expected = torch.le(lhs_cpu, rhs_cpu)

    self.assertEqual(expected, actual.to('cpu'))
5129
def test_le_scalar(self):
    """Compare torch.le of a random float tensor against the scalar 0.0 between CPU and MPS."""
    sample_cpu = torch.randn((2, 3, 4, 5), device='cpu', dtype=torch.float)
    sample_mps = sample_cpu.detach().clone().to('mps')

    actual = torch.le(sample_mps, 0.0)
    expected = torch.le(sample_cpu, 0.0)

    self.assertEqual(expected, actual.to('cpu'))
5140
def test_ge(self):
    """Compare torch.ge of two random float tensors between CPU and MPS."""
    shape = (2, 3, 4, 5)
    lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    lhs_mps = lhs_cpu.detach().clone().to('mps')
    rhs_mps = rhs_cpu.detach().clone().to('mps')

    actual = torch.ge(lhs_mps, rhs_mps)
    expected = torch.ge(lhs_cpu, rhs_cpu)

    self.assertEqual(expected, actual.to('cpu'))
5153
def test_ge_scalar(self):
    """Compare torch.ge of a random float tensor against the scalar 0.0 between CPU and MPS."""
    sample_cpu = torch.randn((2, 3, 4, 5), device='cpu', dtype=torch.float)
    sample_mps = sample_cpu.detach().clone().to('mps')

    actual = torch.ge(sample_mps, 0.0)
    expected = torch.ge(sample_cpu, 0.0)

    self.assertEqual(expected, actual.to('cpu'))
5164
def test_gt(self):
    """Compare torch.gt of two random float tensors between CPU and MPS."""
    shape = (2, 3, 4, 5)
    lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
    lhs_mps = lhs_cpu.detach().clone().to('mps')
    rhs_mps = rhs_cpu.detach().clone().to('mps')

    actual = torch.gt(lhs_mps, rhs_mps)
    expected = torch.gt(lhs_cpu, rhs_cpu)

    self.assertEqual(expected, actual.to('cpu'))
5177
def test_gt_scalar(self):
    """Compare torch.gt of a random float tensor against the scalar 0.0 between CPU and MPS."""
    sample_cpu = torch.randn((2, 3, 4, 5), device='cpu', dtype=torch.float)
    sample_mps = sample_cpu.detach().clone().to('mps')

    actual = torch.gt(sample_mps, 0.0)
    expected = torch.gt(sample_cpu, 0.0)

    self.assertEqual(expected, actual.to('cpu'))
5188
def test_argmax(self):
    """Compare torch.argmax along dim=1 between CPU and MPS for an integer and a float input."""
    # https://github.com/pytorch/pytorch/issues/98191
    ints_cpu = torch.tensor([[0, 1], [2, 1], [1, 0]])
    expected = torch.argmax(ints_cpu, dim=1)

    ints_mps = ints_cpu.to(torch.device('mps'))
    actual = torch.argmax(ints_mps, dim=1)
    self.assertEqual(expected, actual)

    # https://github.com/pytorch/pytorch/issues/92311
    floats_mps = torch.randn(10, 2, device='mps', dtype=torch.float32)
    floats_cpu = floats_mps.detach().clone().cpu()

    actual = torch.argmax(floats_mps, dim=1)
    expected = torch.argmax(floats_cpu, dim=1)
    self.assertEqual(expected, actual)
5205
# Test forward argmin argmax
def test_argmin_argmax(self):
    """Compare torch.argmax / torch.argmin between CPU and MPS.

    Covers the full reduction plus every dim of a 4-D input, with and
    without keepdim, for several dtypes.
    """
    def helper(n, c, h, w, reduction_type, dtype=torch.float32):
        arg_reduction_fn = torch.argmax if reduction_type == "max" else torch.argmin

        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            # bool inputs draw from {0, 1}; every other dtype draws from [0, 50)
            upper = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(upper, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        # Reduction over the whole tensor
        self.assertEqual(arg_reduction_fn(x), arg_reduction_fn(cpu_x))

        # Reduction over each dim, first dropping and then keeping the reduced dim
        for dim in range(4):
            for keepdim in (False, True):
                actual = arg_reduction_fn(x, dim=dim, keepdim=keepdim)
                expected = arg_reduction_fn(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(actual, expected)

    for reduction_type in ("max", "min"):
        for dtype in (torch.float32, torch.int32, torch.float16, torch.int64):
            helper(2, 8, 4, 4, reduction_type, dtype)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005270
@unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
def test_reduction_sum_max_long_val(self):
    """Compare torch.sum between CPU and MPS on values at or just below sys.maxsize."""
    offsets = (0, 10, 5, 18)
    x_mps = torch.tensor([sys.maxsize - delta for delta in offsets], device="mps")
    x_cpu = x_mps.detach().clone().cpu()

    self.assertEqual(torch.sum(x_mps), torch.sum(x_cpu))
5279
# Test forward max
# Note - don't test grad now
def test_max_el(self):
    """Compare torch.max between CPU and MPS.

    Checks the full reduction, the (values, indices) form for every dim of
    a 4-D input, and the ``out=`` form writing into preallocated tensors.
    """
    def helper(n, c, h, w, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
        else:
            # bool inputs draw from {0, 1}; every other dtype draws from [0, 50)
            upper = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(upper, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        # Reduction over the whole tensor
        self.assertEqual(torch.max(x), torch.max(cpu_x))

        # Functional form returning (values, indices)
        for dim in [0, 1, 2, 3]:
            for keepdim in [True, False]:
                values, indices = torch.max(x, dim=dim, keepdim=keepdim)
                ref_values, ref_indices = torch.max(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(values, ref_values)
                self.assertEqual(indices, ref_indices)

        # out= form: results are written into preallocated MPS tensors
        sizes = (n, c, h, w)
        for dim in range(4):
            for keepdim in (False, True):
                if keepdim:
                    out_shape = sizes[:dim] + (1,) + sizes[dim + 1:]
                else:
                    out_shape = sizes[:dim] + sizes[dim + 1:]
                values_out = torch.ones(out_shape, device='mps', dtype=dtype)
                indices_out = torch.ones(out_shape, device='mps', dtype=torch.int64)
                torch.max(x, dim=dim, keepdim=keepdim, out=(values_out, indices_out))
                ref_values, ref_indices = torch.max(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(values_out, ref_values)
                self.assertEqual(indices_out, ref_indices)

    helper(2, 8, 4, 5, torch.float32)
    helper(2, 8, 4, 5, torch.int32)
    # helper(2, 8, 4, 5, torch.int64)
5365
def test_median(self):
    """Compare torch.median between CPU and MPS for int32 and float32 3-D inputs."""
    def compare_on_both_devices(cpu_x):
        mps_x = cpu_x.detach().clone().to('mps')

        # Median of the flattened tensor
        self.assertEqual(torch.median(cpu_x), torch.median(mps_x))

        # Median along each dim, returning (values, indices)
        for dim in [0, 1, 2]:
            for keepdim in [True, False]:
                cpu_values, cpu_indices = torch.median(cpu_x, dim=dim, keepdim=keepdim)
                mps_values, mps_indices = torch.median(mps_x, dim=dim, keepdim=keepdim)
                self.assertEqual(cpu_values, mps_values)
                self.assertEqual(cpu_indices, mps_indices)

    int_shapes = [
        (10, 10, 10),  # median at even place
        (3, 3, 3),  # median at odd place
        (1, 1, 1),
        (1, 2, 3),
    ]
    for shape in int_shapes:
        compare_on_both_devices(torch.randint(50, shape, device='cpu', dtype=torch.int32))

    float_shapes = [(10, 10, 10), (3, 3, 3), (1, 1, 1)]
    for shape in float_shapes:
        compare_on_both_devices(torch.randn(shape, device='cpu', dtype=torch.float32))
5406
def test_any(self):
    """Compare torch.any between CPU and MPS.

    For each shape, builds float, int and bool inputs (random, ramp,
    all-ones, all-zeros) and checks the full reduction plus the reduction
    over every dim of that shape, with and without keepdim.
    """
    def helper(shape):
        numel = math.prod(shape)
        input_xs = [
            torch.randn(numel, dtype=torch.float).reshape(shape),
            torch.arange(0, numel, dtype=torch.float).reshape(shape),
            torch.ones(numel, dtype=torch.float).reshape(shape),
            torch.zeros(numel, dtype=torch.float).reshape(shape),
            torch.arange(0, numel, dtype=torch.int).reshape(shape),
            torch.ones(numel, dtype=torch.int).reshape(shape),
            torch.zeros(numel, dtype=torch.int).reshape(shape),
            torch.arange(0, numel, dtype=torch.int).reshape(shape).bool(),
            torch.ones(numel, dtype=torch.int).reshape(shape).bool(),
            torch.zeros(numel, dtype=torch.int).reshape(shape).bool(),
        ]

        for cpu_x in input_xs:
            x = cpu_x.detach().clone().to('mps')

            # Reduction over the whole tensor
            self.assertEqual(torch.any(x), torch.any(cpu_x))

            # Iterate the dims the shape actually has instead of hard-coding
            # 0..3 behind a rank guard; this works for any rank.
            for dim in range(len(shape)):
                for keepdim in (False, True):
                    actual = torch.any(x, dim=dim, keepdim=keepdim)
                    expected = torch.any(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(actual, expected)

    helper((1, 1, 1, 1))
    helper((1, 1, 3, 3))
    helper((7, 13))
    helper((2, 8, 4, 5))
5471
def test_reduction_ops_5D(self):
    """Compare torch.any / torch.all between CPU and MPS on tensors with more than four dims."""
    shape = (1, 1, 2, 1, 1)

    def helper(fn, dim):
        x_cpu = fn(torch.zeros(shape), dim=dim)
        x_mps = fn(torch.zeros(shape, device="mps"), dim=dim)
        self.assertEqual(x_cpu, x_mps.cpu())

    for fn in [torch.any, torch.all]:
        # Cover every dim of the 5-D shape, including the last one (index 4).
        for dim in range(len(shape)):
            helper(fn, dim)

    # 6D tensor reductions
    # Regression test for https://github.com/pytorch/pytorch/issues/95538
    x = (torch.rand(2, 3, 4, 3, 4, 2, device="mps") - .5).relu()
    self.assertEqual(x.all(), x.cpu().all())
    for i in range(-5, 6):
        self.assertEqual(x.all(dim=i), x.cpu().all(dim=i))
5488
def test_all(self):
    """Compare torch.all between CPU and MPS.

    For each shape, builds float, int and bool inputs (random, ramp,
    all-ones, all-zeros) and checks the full reduction plus the reduction
    over every dim of that shape, with and without keepdim. Also checks
    the full reduction of an empty bool tensor.
    """
    def helper(shape):
        numel = math.prod(shape)
        input_xs = [
            torch.randn(numel, dtype=torch.float).reshape(shape),
            torch.arange(0, numel, dtype=torch.float).reshape(shape),
            torch.ones(numel, dtype=torch.float).reshape(shape),
            torch.zeros(numel, dtype=torch.float).reshape(shape),
            torch.arange(0, numel, dtype=torch.int).reshape(shape),
            torch.ones(numel, dtype=torch.int).reshape(shape),
            torch.zeros(numel, dtype=torch.int).reshape(shape),
            torch.arange(0, numel, dtype=torch.int).reshape(shape).bool(),
            torch.ones(numel, dtype=torch.int).reshape(shape).bool(),
            torch.zeros(numel, dtype=torch.int).reshape(shape).bool(),
        ]

        for cpu_x in input_xs:
            x = cpu_x.detach().clone().to('mps')

            # Reduction over the whole tensor
            self.assertEqual(torch.all(x), torch.all(cpu_x))

            # Iterate the dims the shape actually has instead of hard-coding
            # 0..3 behind a rank guard; this works for any rank.
            for dim in range(len(shape)):
                for keepdim in (False, True):
                    actual = torch.all(x, dim=dim, keepdim=keepdim)
                    expected = torch.all(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(actual, expected)

    helper((1, 1, 1, 1))
    helper((1, 1, 3, 3))
    helper((7, 13))
    helper((2, 8, 4, 5))
    # Empty tensor
    x_cpu = torch.tensor([], dtype=torch.bool)
    x_mps = x_cpu.to("mps")
    self.assertEqual(x_cpu.all(), x_mps.all().cpu())
Kulin Sethe011a8e2022-05-13 18:28:53 +00005557
# Test forward min
def test_min_el(self):
    """Compare torch.min between CPU and MPS.

    Checks the full reduction and, for every dim of a 4-D float input with
    and without keepdim, both the functional (values, indices) form and the
    ``out=`` form writing into preallocated tensors.
    """
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        # Reduction over the whole tensor
        self.assertEqual(torch.min(x), torch.min(cpu_x))

        sizes = (n, c, h, w)
        for dim in range(4):
            for keepdim in (False, True):
                ref_values, ref_indices = torch.min(cpu_x, dim=dim, keepdim=keepdim)

                # Functional form returning (values, indices)
                values, indices = torch.min(x, dim=dim, keepdim=keepdim)
                self.assertEqual(values, ref_values)
                self.assertEqual(indices, ref_indices)

                # out= form: results are written into preallocated MPS tensors
                if keepdim:
                    out_shape = sizes[:dim] + (1,) + sizes[dim + 1:]
                else:
                    out_shape = sizes[:dim] + sizes[dim + 1:]
                values_out = torch.ones(out_shape, device='mps', dtype=torch.float)
                indices_out = torch.ones(out_shape, device='mps', dtype=torch.int64)
                torch.min(x, dim=dim, keepdim=keepdim, out=(values_out, indices_out))
                self.assertEqual(values_out, ref_values)
                self.assertEqual(indices_out, ref_indices)

    helper(2, 8, 4, 5)
5665
# Test forward sum
def test_sum(self):
    """Compare torch.sum between CPU and MPS.

    Checks the full reduction and several ``dim`` lists (including the
    empty list), with and without keepdim, for float, int and bool inputs.
    """
    def helper(n, c, h, w, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            # bool inputs draw from {0, 1}; every other dtype draws from [0, 50)
            upper = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(upper, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        # Reduction over the whole tensor
        self.assertEqual(torch.sum(x), torch.sum(cpu_x))

        # Reduction over explicit dim lists
        for dims in ([], [0], [0, 1], [2, 3]):
            for keepdim in (False, True):
                actual = torch.sum(x, dim=dims, keepdim=keepdim)
                expected = torch.sum(cpu_x, dim=dims, keepdim=keepdim)
                self.assertEqual(actual, expected)

    helper(2, 8, 4, 5)
    for dtype in (torch.int32, torch.int64, torch.bool):
        helper(2, 8, 4, 5, dtype=dtype)

    # Regression test for https://github.com/pytorch/pytorch/issues/136132
    reduced = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2)
    self.assertEqual(reduced.numel(), 8)
    self.assertEqual(reduced.max().item(), 30.0)
Kulin Sethe011a8e2022-05-13 18:28:53 +00005734
# Test forward prod
def test_prod(self):
    """Compare torch.prod between CPU and MPS for the full reduction and each dim, with and without keepdim."""
    def helper(shape, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        elif dtype == torch.bool:
            cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')
        else:
            # Integer inputs are drawn from [1, 6)
            cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        # Reduction over the whole tensor
        self.assertEqual(torch.prod(x), torch.prod(cpu_x))

        # Reduction over each dim, first dropping and then keeping the reduced dim
        for dim, _ in enumerate(shape):
            self.assertEqual(torch.prod(x, dim=dim), torch.prod(cpu_x, dim=dim))
            self.assertEqual(
                torch.prod(x, dim=dim, keepdim=True),
                torch.prod(cpu_x, dim=dim, keepdim=True),
            )

    for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
        helper((2, 3), dtype)
5768
# Test forward mean
def test_mean(self):
    """Compare torch.mean between CPU and MPS.

    Checks the full reduction and several ``dim`` lists (including the
    empty list), with and without keepdim, on a random float input.
    """
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # Reduction over the whole tensor
        self.assertEqual(torch.mean(x), torch.mean(cpu_x))

        # Reduction over explicit dim lists
        for dims in ([], [0], [0, 1], [2, 3]):
            for keepdim in (False, True):
                actual = torch.mean(x, dim=dims, keepdim=keepdim)
                expected = torch.mean(cpu_x, dim=dims, keepdim=keepdim)
                self.assertEqual(actual, expected)

    helper(2, 8, 4, 5)
5821
5822 # Test std
def test_std(self):
    """Check torch.std on MPS against CPU for biased and unbiased estimates."""
    def helper(shape):
        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        # The biased pass runs to completion before the unbiased one.
        for unbiased in (False, True):
            # Full reduction: no dim argument.
            self.assertEqual(
                torch.std(dev, unbiased=unbiased),
                torch.std(ref, unbiased=unbiased),
            )

            # Explicit dim lists, each without and then with keepdim.
            for dims in ([], [0], [0, 1], [2, 3]):
                self.assertEqual(
                    torch.std(dev, dim=dims, unbiased=unbiased),
                    torch.std(ref, dim=dims, unbiased=unbiased),
                )
                self.assertEqual(
                    torch.std(dev, dim=dims, keepdim=True, unbiased=unbiased),
                    torch.std(ref, dim=dims, keepdim=True, unbiased=unbiased),
                )

    helper((4, 5, 6, 7))
    # verify if a change in shape of input would cause problems with graph caching
    helper((9, 5, 6, 7))
Kulin Sethe011a8e2022-05-13 18:28:53 +00005921
5922 # Test var
def test_var_simple(self):
    """Check variance on MPS against CPU over every unbiased/keepdim combination."""
    def helper():
        shape = [2, 3, 4, 5]

        ref = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        # product() yields (False, False), (False, True), (True, False),
        # (True, True): unbiased is the outer loop, keepdim the inner one.
        for unbiased, keepdim in itertools.product([False, True], repeat=2):
            flags = {"keepdim": keepdim, "unbiased": unbiased}

            # Tensor-method form, reducing over the last dimension.
            self.assertEqual(dev.var(-1, **flags), ref.var(-1, **flags))

            # Full reduction; keepdim is not passed to this call.
            self.assertEqual(
                torch.var(dev, unbiased=unbiased),
                torch.var(ref, unbiased=unbiased),
            )

            # Function form with explicit dim lists, including a negative index.
            for dims in ([], [0], [0, -1], [2, 3]):
                self.assertEqual(
                    torch.var(dev, dim=dims, **flags),
                    torch.var(ref, dim=dims, **flags),
                )

    helper()
Kulin Sethe011a8e2022-05-13 18:28:53 +00005965
Abhishek Pathak074dc742022-06-18 00:14:05 +00005966 # Test forward amax
def test_amax(self):
    """Check torch.amax forward values and input gradients on MPS against CPU."""
    def helper(shape, dim, keepdim):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_in = ref_in.detach().clone().to('mps').requires_grad_()

        dev_out = torch.amax(dev_in, dim=dim, keepdim=keepdim)
        ref_out = torch.amax(ref_in, dim=dim, keepdim=keepdim)

        # The same random upstream gradient drives both backward passes.
        ref_upstream = torch.randn(ref_out.shape)
        dev_upstream = ref_upstream.to('mps')

        ref_out.backward(gradient=ref_upstream)
        dev_out.backward(gradient=dev_upstream)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)

    for dim, keepdim in itertools.product(([], [0], [0, 1], [2, 3]), [False, True]):
        helper((2, 8, 4, 5), dim, keepdim)
5987
5988 # Test forward amin
def test_amin(self):
    """Check torch.amin forward values and input gradients on MPS against CPU."""
    def helper(shape, dim, keepdim):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_in = ref_in.detach().clone().to('mps').requires_grad_()

        dev_out = torch.amin(dev_in, dim=dim, keepdim=keepdim)
        ref_out = torch.amin(ref_in, dim=dim, keepdim=keepdim)

        # The same random upstream gradient drives both backward passes.
        ref_upstream = torch.randn(ref_out.shape)
        dev_upstream = ref_upstream.to('mps')

        ref_out.backward(gradient=ref_upstream)
        dev_out.backward(gradient=dev_upstream)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_in.grad, ref_in.grad)

    for dim, keepdim in itertools.product(([], [0], [0, 1], [2, 3]), [False, True]):
        helper((2, 8, 4, 5), dim, keepdim)
6009
Kulin Sethe011a8e2022-05-13 18:28:53 +00006010 # Test minimum and maximum
def test_minimum_maximum(self):
    """Check elementwise torch.minimum and torch.maximum on MPS against CPU."""
    def helper(n, c, h, w):
        lhs_cpu = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_cpu = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        # minimum is checked first, then maximum, on the same operands.
        for binary_op in (torch.minimum, torch.maximum):
            expected = binary_op(lhs_cpu, rhs_cpu)
            actual = binary_op(lhs_mps, rhs_mps)
            self.assertEqual(expected, actual)

    helper(1, 1, 4, 5)
6027
def test_clamp_fp16_fp32(self):
    """Check torch.clamp with a float32 input and float16 tensor bounds.

    The input is created as float32 while both bound tensors are converted
    to float16, so the clamp mixes dtypes. The MPS result must equal the
    CPU result.
    """
    cpu_x = torch.randn(10, device='cpu', dtype=torch.float, requires_grad=False)
    x = cpu_x.detach().clone().to('mps')

    # dtype of the bound tensors; previously assigned but never used.
    dtype = torch.float16

    def make_bounds(device):
        # Lower bound is all ones, upper bound all tens, both in `dtype`.
        lower = torch.ones(10, device=device).to(dtype)
        upper = torch.ones(10, device=device).to(dtype) * 10
        return lower, upper

    clamp_min_vals_mps, clamp_max_vals_mps = make_bounds("mps")
    clamp_result_mps = torch.clamp(x, clamp_min_vals_mps, clamp_max_vals_mps)

    clamp_min_vals_cpu, clamp_max_vals_cpu = make_bounds("cpu")
    clamp_result_cpu = torch.clamp(cpu_x, clamp_min_vals_cpu, clamp_max_vals_cpu)

    self.assertEqual(clamp_result_mps, clamp_result_cpu)
6043
def test_clamp_nan(self):
    """Check torch.clamp on an input containing NaN, MPS against CPU."""
    values = [torch.nan, 1, 2]
    t_mps = torch.tensor(values, device="mps")
    t_cpu = torch.tensor(values, device="cpu")

    # Both bounds, then lower bound only, then upper bound only.
    for bounds in ({"min": -100, "max": 100}, {"min": -100}, {"max": 100}):
        on_mps = torch.clamp(t_mps, **bounds)
        on_cpu = torch.clamp(t_cpu, **bounds)
        self.assertEqual(on_mps, on_cpu)
6062
Kulin Sethe011a8e2022-05-13 18:28:53 +00006063 # Test clamp_min
def test_clamp_min(self):
    """Check torch.clamp_min with scalar and tensor lower bounds, MPS against CPU."""
    def helper(n, c, h, w):
        dims = (n, c, h, w)
        ref = torch.randn(*dims, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        ref_floor = torch.randn(*dims, device='cpu', dtype=torch.float, requires_grad=False)
        dev_floor = ref_floor.detach().clone().to('mps')

        # Scalar lower bound.
        self.assertEqual(torch.clamp_min(dev, min=5.0), torch.clamp_min(ref, min=5.0))

        # Elementwise tensor lower bound of the same shape as the input.
        self.assertEqual(
            torch.clamp_min(dev, min=dev_floor),
            torch.clamp_min(ref, min=ref_floor),
        )

    helper(2, 8, 4, 5)
6083
6084 # Test clamp_max
6085
def test_clamp_max(self):
    """Check torch.clamp_max with scalar and tensor upper bounds, MPS against CPU."""
    def helper(n, c, h, w):
        dims = (n, c, h, w)
        ref = torch.randn(*dims, device='cpu', dtype=torch.float, requires_grad=False)
        dev = ref.detach().clone().to('mps')

        ref_ceiling = torch.randn(*dims, device='cpu', dtype=torch.float, requires_grad=False)
        dev_ceiling = ref_ceiling.detach().clone().to('mps')

        # Scalar upper bound.
        # NOTE(review): randn samples are unlikely to ever reach 100, so this
        # call probably leaves the data unchanged - confirm that is intended.
        self.assertEqual(torch.clamp_max(dev, max=100.0), torch.clamp_max(ref, max=100.0))

        # Elementwise tensor upper bound of the same shape as the input.
        self.assertEqual(
            torch.clamp_max(dev, max=dev_ceiling),
            torch.clamp_max(ref, max=ref_ceiling),
        )

    helper(2, 8, 4, 5)
6105
6106 # Test clamp
def test_clamp(self):
    """Check torch.clamp on MPS against CPU.

    Covers scalar and tensor bounds, max-only calls, strided (movedim)
    operands and the in-place variant.
    """
    def helper(n, c, h, w):
        import numpy as np
        upper_bound = 1000
        half_upper_bound = upper_bound / 2
        dims = (n, c, h, w)

        def paired(arr):
            # Build the CPU tensor from a numpy array plus its MPS copy.
            on_cpu = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=False)
            return on_cpu, on_cpu.detach().clone().to('mps')

        def moved(t):
            # Moving dim 0 to the end is used below to get a strided operand.
            return t.movedim(0, -1)

        # input values in [0..1000)
        cpu_x, x = paired(upper_bound * np.random.random_sample(size=dims).astype(np.float32))

        # lower bounds in [0..500)
        cpu_min_t, min_t = paired(half_upper_bound * np.random.random_sample(size=dims).astype(np.float32))

        # upper bounds in [500..1000), to ensure max's are greater than mins
        cpu_max_t, max_t = paired(
            (half_upper_bound * np.random.random_sample(size=dims).astype(np.float32)) + half_upper_bound
        )

        # [200..600]: just an arbitrary range between [0..1000]
        self.assertEqual(
            torch.clamp(x, min=200.0, max=600.0),
            torch.clamp(cpu_x, min=200.0, max=600.0),
        )

        # test optional scalar refs and cached graph keys by passing only max
        self.assertEqual(torch.clamp(x, max=600.0), torch.clamp(cpu_x, max=600.0))

        # tensor bounds on both sides
        self.assertEqual(
            torch.clamp(x, min=min_t, max=max_t),
            torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t),
        )

        # test optional tensor refs and cached graph keys by passing only max
        self.assertEqual(torch.clamp(x, max=max_t), torch.clamp(cpu_x, max=cpu_max_t))

        # test strided x
        self.assertEqual(
            torch.clamp(moved(x), min=200.0, max=600.0),
            torch.clamp(moved(cpu_x), min=200.0, max=600.0),
        )

        # test strided x, min_t, max_t
        self.assertEqual(
            torch.clamp(moved(x), min=moved(min_t), max=moved(max_t)),
            torch.clamp(moved(cpu_x), min=moved(cpu_min_t), max=moved(cpu_max_t)),
        )

        # test strided min_t, max_t (x is made contiguous first)
        contiguous_mps = moved(x).clone(memory_format=torch.contiguous_format)
        strided_bounds_mps = torch.clamp(contiguous_mps, min=moved(min_t), max=moved(max_t))
        contiguous_cpu = moved(cpu_x).clone(memory_format=torch.contiguous_format)
        strided_bounds_cpu = torch.clamp(contiguous_cpu, min=moved(cpu_min_t), max=moved(cpu_max_t))
        self.assertEqual(strided_bounds_mps, strided_bounds_cpu)

        # test inplace clamping
        x.clamp_(min=200.0, max=600.0)
        cpu_x.clamp_(min=200.0, max=600.0)
        self.assertEqual(cpu_x, x)

    helper(2, 8, 4, 5)
6176
def test_divmode(self):
    """Check torch.div rounding modes and torch.floor_divide on MPS against CPU.

    The helper's rounding_mode is None, "floor" or "trunc" (forwarded to
    torch.div), or the pseudo-mode "floor_divide", which selects
    torch.floor_divide instead.
    """
    def helper(shape, rounding_mode):
        for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
            # Skipped combinations: any mode containing "floor" (which
            # includes "floor_divide") with int64, and "trunc" with float16.
            floor_on_int64 = rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64
            trunc_on_half = rounding_mode is not None and "trunc" in rounding_mode and dtype == torch.float16
            if floor_on_int64 or trunc_on_half:
                continue

            if dtype in [torch.float32, torch.float16]:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            else:
                # randint's upper bound is exclusive, so values lie in
                # [-10, -1] and the integer divisor is never 0.
                cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
                cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)

            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')

            if rounding_mode == "floor_divide":
                result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
                result_div_mps = torch.floor_divide(mps_x, mps_y)
            else:
                result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
                result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
            self.assertEqual(result_div_mps, result_div_cpu)

    helper((2, 8, 4, 5), None)
    helper((2, 8, 4, 5), "floor")
    helper((2, 8, 4, 5), "trunc")
    helper((2, 8, 4, 5), "floor_divide")
Kulin Sethe011a8e2022-05-13 18:28:53 +00006208
def test_rounding(self):
    """Check torch.floor, ceil, trunc and round on MPS against CPU."""
    def helper(shape):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_in = cpu_in.detach().clone().to('mps')

        # Each unary op is applied to the same pair of inputs.
        for rounding_op in (torch.floor, torch.ceil, torch.trunc, torch.round):
            expected = rounding_op(cpu_in)
            actual = rounding_op(mps_in)
            self.assertEqual(actual, expected)

    helper((2, 6, 3, 5))
    helper((2, 8, 4, 5))
6232
def test_remainder(self):
    """Check torch.remainder on MPS against CPU for int32 dividends."""
    def int_by_int_tensor(device):
        # Mixed-sign int32 dividend with a 0-dim int32 tensor divisor.
        dividend = torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device=device)
        divisor = torch.tensor(2, device=device, dtype=torch.int32)
        return torch.remainder(dividend, divisor)

    def int_by_float_scalar(device):
        # int32 dividend with a negative Python float divisor.
        dividend = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device=device)
        return torch.remainder(dividend, -1.5)

    for case in (int_by_int_tensor, int_by_float_scalar):
        res_cpu = case("cpu")
        res_mps = case("mps")
        self.assertEqual(res_cpu, res_mps)
6245
def test_expand(self):
    """Check an as_strided view with a zero stride on MPS against CPU."""
    def helper(n, c):
        # n and c are accepted but not used; the data below is fixed.
        column = [[1.0], [4.0], [7.0]]
        on_cpu = torch.tensor(column, device='cpu')
        on_mps = on_cpu.detach().clone().to('mps')

        # Size (3, 4) with strides (1, 0): the second dimension has stride 0.
        view_size, view_stride = (3, 4), (1, 0)
        view_cpu = torch.as_strided(on_cpu, view_size, view_stride)
        view_mps = torch.as_strided(on_mps, view_size, view_stride)

        self.assertEqual(view_mps, view_cpu)

    helper(3, 1)
6258
def test_im2col(self):
    """Check torch.nn.functional.unfold on MPS against CPU."""
    unfold_args = {"kernel_size": (10, 15), "dilation": 2, "padding": 5, "stride": 3}
    x_cpu = torch.rand(1, 1, 200, 100)
    x_mps = x_cpu.detach().clone().to('mps')

    expected = torch.nn.functional.unfold(x_cpu, **unfold_args)
    actual = torch.nn.functional.unfold(x_mps, **unfold_args)
    self.assertEqual(expected, actual)
6265
def test_select(self):
    """Check as_strided views, with and without a storage offset, on MPS against CPU."""
    def helper(n, c):
        on_cpu = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
        on_mps = on_cpu.detach().clone().to('mps').requires_grad_()

        # (size, stride, extra keyword arguments) for each view under test.
        view_specs = (
            ((3, 1), (3, 1), {}),
            ((1, 3), (3, 1), {}),
            ((3, 1), (3, 1), {"storage_offset": 1}),
        )
        for size, stride, extra in view_specs:
            view_cpu = torch.as_strided(on_cpu, size, stride, **extra)
            view_mps = torch.as_strided(on_mps, size, stride, **extra)
            self.assertEqual(view_mps, view_cpu)

    helper(3, 3)
6285
def test_sort(self):
    """Check torch.sort on MPS: out= variant, argsort agreement and a fixed example."""
    for SIZE in (4, 2049):
        device = 'mps'
        data = torch.rand(4, SIZE, device=device)
        sorted_vals, sorted_idx = torch.sort(data)

        # The out= overload must fill the empty tensors with identical results.
        out_vals = torch.tensor((), device=device)
        out_idx = torch.tensor((), device=device, dtype=torch.long)
        torch.sort(data, out=(out_vals, out_idx))
        self.assertEqual(sorted_vals, out_vals, atol=0, rtol=0)
        self.assertEqual(sorted_idx, out_idx, atol=0, rtol=0)

        # Function and method forms of argsort agree with sort's indices.
        self.assertEqual(torch.argsort(data), sorted_idx)
        self.assertEqual(data.argsort(), sorted_idx)

        # A descending literal must come back ascending.
        descending = torch.tensor((50, 40, 30, 20, 10), device=device)
        ascending = torch.tensor((10, 20, 30, 40, 50), device=device)
        self.assertEqual(torch.sort(descending)[0], ascending, atol=0, rtol=0)
6305
def test_upsample_nearest2d(self):
    """Check nn.UpsamplingNearest2d forward and backward on MPS against CPU."""
    def helper(N, C, H, W, memory_format):
        inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
                                requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
        # inputCPU is derived from the arange tensor via reshape/to, so
        # retain_grad() is needed to read .grad on it after backward.
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().to('mps').requires_grad_()

        def forward_both(module):
            # Run one module on both inputs and require matching outputs.
            out_cpu = module(inputCPU)
            out_mps = module(inputMPS)
            self.assertEqual(out_cpu, out_mps)
            return out_cpu, out_mps

        factors = [1, 2, 5, 10, 40]
        for i, j in itertools.product(factors, repeat=2):
            # Scale-factor form: forward check only.
            forward_both(nn.UpsamplingNearest2d(scale_factor=(i, j)))

            # Explicit output-size form: forward check, then backward.
            out_cpu, out_mps = forward_both(nn.UpsamplingNearest2d((i * H, j * W)))

            out_cpu.backward(gradient=torch.full_like(out_cpu, 0.3))
            out_mps.backward(gradient=torch.full_like(out_mps, 0.3))

            # .grad is never reset here, so both sides accumulate across
            # iterations and are compared in their accumulated state.
            self.assertEqual(inputCPU.grad, inputMPS.grad)

    for memory_format in [torch.channels_last, torch.contiguous_format]:
        helper(1, 1, 4, 4, memory_format=memory_format)
        helper(7, 5, 3, 2, memory_format=memory_format)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006338
def test_upsample_bilinear2d(self):
    """Check nn.UpsamplingBilinear2d forward and backward on MPS against CPU."""
    def helper(N, C, H, W):
        inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
                                requires_grad=True).reshape(N, C, H, W)
        # inputCPU is the reshape result rather than the arange tensor, so
        # retain_grad() is needed to read .grad on it after backward.
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

        def forward_both(module):
            # Run one module on both inputs and require matching outputs.
            out_cpu = module(inputCPU)
            out_mps = module(inputMPS)
            self.assertEqual(out_cpu, out_mps)
            return out_cpu, out_mps

        factors = [1, 2, 5, 10, 40]
        for i, j in itertools.product(factors, repeat=2):
            # Scale-factor form: forward check only.
            forward_both(nn.UpsamplingBilinear2d(scale_factor=(i, j)))

            # Explicit output-size form: forward check, then backward.
            out_cpu, out_mps = forward_both(nn.UpsamplingBilinear2d((i * H, j * W)))

            out_cpu.backward(gradient=torch.full_like(out_cpu, 0.3))
            out_mps.backward(gradient=torch.full_like(out_mps, 0.3))

            # .grad is never reset here, so both sides accumulate across
            # iterations and are compared in their accumulated state.
            self.assertEqual(inputCPU.grad, inputMPS.grad)

    helper(1, 1, 4, 4)
    helper(7, 5, 3, 2)
6371
def test_interpolate(self):
    """Check nn.functional.interpolate forward and backward on MPS against CPU."""
    def helper(shape, output_size, scales, mode, align_corners=False):
        inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

        # Pass scale_factor when scales is given, otherwise the target size.
        call_args = {"mode": mode}
        if scales is not None:
            call_args["scale_factor"] = scales
        else:
            call_args["size"] = output_size

        # align_corners is used for 2D interpolation only; in every other
        # case the argument is left out of the call entirely.
        if align_corners is True and len(shape) > 3 and mode == 'bilinear':
            call_args["align_corners"] = align_corners

        outputCPU = nn.functional.interpolate(inputCPU, **call_args)
        outputMPS = nn.functional.interpolate(inputMPS, **call_args)

        self.assertEqual(outputCPU, outputMPS)

        # backward pass (chose 0.6 just to have the grad_output != 1)
        outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
        outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
        self.assertEqual(inputCPU.grad, inputMPS.grad)

    # 1D interpolation
    shape_1d = [2, 3, 4]
    for mode in ['nearest', 'nearest-exact']:
        helper(shape_1d, [3], None, mode)  # downsample with size
        helper(shape_1d, [6], None, mode)  # upsample with size
        helper(shape_1d, None, [0.6], mode)  # downsample with scale factor
        helper(shape_1d, None, [1.7], mode)  # upsample with scale factor

    # 2D interpolation
    shape_2d = [2, 3, 4, 5]
    for mode in ['nearest', 'nearest-exact', 'bilinear']:
        helper(shape_2d, [3, 4], None, mode)  # downsample with size
        helper(shape_2d, [6, 7], None, mode)  # upsample with size
        helper(shape_2d, None, [0.6, 0.7], mode)  # downsample with scale factor
        helper(shape_2d, None, [1.4, 1.7], mode)  # upsample with scale factor

    # align_corners=True
    helper(shape_2d, [3, 4], None, 'bilinear', True)
    helper(shape_2d, None, [1.4, 1.7], 'bilinear', True)
Kulin Seth067c8062022-07-13 21:39:50 +00006415
Kulin Sethe011a8e2022-05-13 18:28:53 +00006416 # Test concat forward
def test_cat1(self):
    """Check torch.cat along dim=1 on MPS against CPU, including empty operands."""
    def helper(shape_x, shape_y, shape_z):
        cpu_parts = []
        mps_parts = []
        for shape in (shape_x, shape_y, shape_z):
            part = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_parts.append(part)
            mps_parts.append(part.detach().clone().to('mps'))

        joined_mps = torch.cat(mps_parts, dim=1)
        joined_cpu = torch.cat(cpu_parts, dim=1)

        self.assertEqual(joined_mps, joined_cpu)

    shape_triples = (
        ([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5]),
        ([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]),
        # zero-sized leading dimension on every operand
        ([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5]),
        # one operand is a 1-D tensor of shape [0]
        ([2, 2, 6, 5], [0], [2, 5, 6, 5]),
        ([0], [2, 3, 6, 5], [2, 5, 6, 5]),
        ([2, 3, 4, 5], [2, 5, 4, 5], [0]),
        # one or two operands have size 0 along the concatenation dim
        ([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5]),
        ([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]),
        ([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5]),
    )
    for shape_x, shape_y, shape_z in shape_triples:
        helper(shape_x, shape_y, shape_z)
Kulin Sethe011a8e2022-05-13 18:28:53 +00006442
Kulin Sethe011a8e2022-05-13 18:28:53 +00006443 # Test stack forward
def test_stack(self):
    """Check torch.stack along dim=1 on MPS against CPU for several dtypes."""
    # All shapes must be same
    def helper(shape, dtype=torch.float32):

        def make_pair():
            # Returns (cpu_tensor, mps_copy) for the requested dtype.
            if dtype == torch.float32:
                # float32 is the only dtype drawn from randn with grad enabled.
                on_cpu = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                return on_cpu, on_cpu.detach().clone().to('mps').requires_grad_()
            # bool draws from {0, 1}; every other dtype (float16 included)
            # draws integers in [0, 50).
            high = 2 if dtype == torch.bool else 50
            on_cpu = torch.randint(high, shape, device='cpu', dtype=dtype, requires_grad=False)
            return on_cpu, on_cpu.detach().clone().to('mps')

        cpu_x, x = make_pair()
        cpu_y, y = make_pair()
        cpu_z, z = make_pair()

        stacked_mps = torch.stack([x, y, z], dim=1)
        stacked_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)

        self.assertEqual(stacked_mps, stacked_cpu)

    helper([2, 8, 4, 5])
    for dtype in (torch.float16, torch.int32, torch.int64, torch.bool):
        helper([2, 8, 4, 5], dtype=dtype)
    # Empty test - Currently failing! Empty tensor not handled!
    # helper([0, 2, 4, 5])
6486
6487 # Test abs
def test_abs(self):
    """Check that torch.abs on an MPS tensor matches the CPU result."""
    def helper(shape):
        reference_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        device_input = reference_input.detach().clone().to('mps')

        # MPS output first, CPU reference second
        mps_out = torch.abs(device_input)
        cpu_out = torch.abs(reference_input)
        self.assertEqual(mps_out, cpu_out)

    helper((2, 8, 4, 5))
6499
def test_log(self):
    """Check that torch.log on an MPS tensor matches the CPU result."""
    def helper(shape):
        host_values = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_values = host_values.detach().clone().to('mps')

        actual = torch.log(mps_values)
        expected = torch.log(host_values)

        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
6511
def test_log_ten(self):
    """Check that torch.log10 on an MPS tensor matches the CPU result."""
    def helper(shape):
        host_values = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_values = host_values.detach().clone().to('mps')

        actual = torch.log10(mps_values)
        expected = torch.log10(host_values)

        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
6523
def test_log_two(self):
    """Check that torch.log2 on an MPS tensor matches the CPU result."""
    def helper(shape):
        host_values = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_values = host_values.detach().clone().to('mps')

        actual = torch.log2(mps_values)
        expected = torch.log2(host_values)

        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
6535
def test_log1p(self):
    """Check that torch.log1p on an MPS tensor matches the CPU result."""
    def helper(shape):
        host_values = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_values = host_values.detach().clone().to('mps')

        actual = torch.log1p(mps_values)
        expected = torch.log1p(host_values)

        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
6547
def test_logaddexp(self):
    """Compare torch.logaddexp on MPS against the CPU result for float inputs."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')

        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        mps_out = torch.logaddexp(lhs_mps, rhs_mps)
        cpu_out = torch.logaddexp(lhs_cpu, rhs_cpu)

        self.assertEqual(mps_out, cpu_out)

    helper((2, 8, 4, 5))
6562
def test_logaddexp2(self):
    """Compare torch.logaddexp2 on MPS against the CPU result for float inputs."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')

        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        mps_out = torch.logaddexp2(lhs_mps, rhs_mps)
        cpu_out = torch.logaddexp2(lhs_cpu, rhs_cpu)

        self.assertEqual(mps_out, cpu_out)

    helper((2, 8, 4, 5))
6577
def test_logsumexp(self):
    """Compare torch.logsumexp over the last dimension between MPS and CPU."""
    def helper(shape):
        host_values = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_values = host_values.detach().clone().to('mps')

        # Reduce over the trailing dimension on both devices
        actual = torch.logsumexp(mps_values, -1)
        expected = torch.logsumexp(host_values, -1)

        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
6589
Kulin Sethe011a8e2022-05-13 18:28:53 +00006590 # Test concat forward
def test_cat2(self):
    """Compare torch.cat along dim=1 on MPS against CPU for three and four inputs."""
    def helper(*shapes):
        # One random CPU tensor per requested shape, each mirrored onto MPS.
        cpu_tensors = [
            torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            for shape in shapes
        ]
        mps_tensors = [t.detach().clone().to('mps') for t in cpu_tensors]

        cat = torch.cat(mps_tensors, dim=1)
        cat_cpu = torch.cat(cpu_tensors, dim=1)

        self.assertEqual(cat, cat_cpu)

    helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
    helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
    # Four-input concatenation; shapes differ only along the concatenated dim.
    helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5], [2, 1, 4, 5])
    # Empty test - Currently failing! Empty tensor not handled!
    # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
6630
6631 # Test isnan
def test_isnan(self):
    """Check torch.isnan on MPS against CPU for an input containing NaNs."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        # Overwrite one randomly chosen index along dim 0 with NaN
        nan_row = random.randrange(0, shape[0])
        cpu_x.index_put_(indices=[torch.tensor([nan_row])], values=torch.tensor(float('nan')))
        x = cpu_x.detach().clone().to('mps')

        mps_mask = torch.isnan(x)
        cpu_mask = torch.isnan(cpu_x)

        self.assertEqual(mps_mask, cpu_mask)

    helper((8, 2, 4, 5))
6646
6647 # Test reciprocal
def test_reciprocal(self):
    """Compare torch.reciprocal forward and backward between MPS and CPU."""
    def helper(shape):
        ref_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_input = ref_input.detach().clone().to('mps').requires_grad_()

        dev_out = torch.reciprocal(dev_input)
        ref_out = torch.reciprocal(ref_input)

        # Seed both backward passes with the same all-ones upstream gradient
        ref_upstream = torch.ones_like(ref_out)
        dev_upstream = ref_upstream.to('mps')
        dev_out.backward(gradient=dev_upstream)
        ref_out.backward(gradient=ref_upstream)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_input.grad, ref_input.grad)

    helper((2, 8, 4, 5))
6666
6667 # Test sqrt
def test_sqrt(self):
    """Compare torch.sqrt forward and backward between MPS and CPU."""
    def helper(shape):
        ref_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        dev_input = ref_input.detach().clone().to('mps').requires_grad_()

        dev_out = torch.sqrt(dev_input)
        ref_out = torch.sqrt(ref_input)

        # Seed both backward passes with the same all-ones upstream gradient
        ref_upstream = torch.ones_like(ref_out)
        dev_upstream = ref_upstream.to('mps')
        dev_out.backward(gradient=dev_upstream)
        ref_out.backward(gradient=ref_upstream)

        self.assertEqual(dev_out, ref_out)
        self.assertEqual(dev_input.grad, ref_input.grad)

    helper((2, 8, 4, 5))
6686
6687 # Test selu, elu, celu
def test_elu(self):
    """Compare ELU, CELU and SELU forward/backward between MPS and CPU."""
    def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
        cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()

        x = cpu_x.detach().clone().to('mps').requires_grad_(True)
        activations = (torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU())
        # Gradients accumulate on x / cpu_x across the activations; both
        # devices accumulate the same way, so the comparison stays valid.
        for activation in activations:
            mps_out = activation(x)
            cpu_out = activation(cpu_x)

            upstream_cpu = torch.randn(cpu_out.shape)
            upstream_mps = upstream_cpu.to('mps')

            mps_out.backward(gradient=upstream_mps)
            cpu_out.backward(gradient=upstream_cpu)

            self.assertEqual(mps_out, cpu_out)
            self.assertEqual(x.grad, cpu_x.grad)

    # Sweep memory formats, shapes and alpha values
    memory_formats = [torch.channels_last, torch.contiguous_format]
    shapes = [(2, 8, 4, 5)]
    alphas = [0.000001, 1.0, 2.3, 0.34, 23]
    for memory_format, shape, alpha in itertools.product(memory_formats, shapes, alphas):
        helper(shape, alpha, memory_format)
Kulin Setha6347f52022-06-07 18:22:10 +00006712
def test_elu_strided_output(self):
    """Check F.elu on a transposed input gives the same result on CPU and MPS."""
    # https://github.com/pytorch/pytorch/issues/124834
    strided_input = torch.randn(1, 1024, 500).transpose(1, 2)
    alpha, inplace = float(1), False

    cpu_out = F.elu(strided_input.to('cpu'), alpha, inplace)
    mps_out = F.elu(strided_input.to('mps'), alpha, inplace)
    self.assertEqual(cpu_out, mps_out)
6724
qqaatwc980fc32022-06-30 08:58:42 +00006725 # Test glu
def test_glu(self):
    """Compare torch.nn.GLU forward/backward between MPS and CPU along every dim."""
    def helper(shape, dim=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        glu = torch.nn.GLU(dim=dim)
        mps_out = glu(x)
        cpu_out = glu(cpu_x)

        upstream_cpu = torch.randn(cpu_out.shape)
        upstream_mps = upstream_cpu.to('mps')

        mps_out.backward(gradient=upstream_mps)
        cpu_out.backward(gradient=upstream_cpu)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(x.grad, cpu_x.grad)

    for shape in [[4], (2, 4), (2, 8, 4, 6)]:
        for dim, _ in enumerate(shape):
            helper(shape, dim)
6747
6748 # Test softplus
def test_softplus(self):
    """Compare torch.nn.Softplus forward/backward between MPS and CPU."""
    def helper(shape, beta, threshold, dtype):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        softplus = torch.nn.Softplus(beta=beta, threshold=threshold)
        mps_out = softplus(x)
        cpu_out = softplus(cpu_x)

        # The upstream gradient takes its shape from the MPS output
        upstream_cpu = torch.randn(mps_out.shape)
        upstream_mps = upstream_cpu.to('mps')

        mps_out.backward(gradient=upstream_mps)
        cpu_out.backward(gradient=upstream_cpu)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(x.grad, cpu_x.grad)

    # Sweep scalar and multi-dimensional shapes over all parameter combinations
    shapes = [(), (2, 3), (10, 10), (2, 3, 4, 5)]
    betas = [0.5, 1, 2, 3, 4]
    thresholds = [0.5, 20, 30, 40, 50]
    dtypes = [torch.float16, torch.float32]
    for shape, beta, threshold, dtype in product(shapes, betas, thresholds, dtypes):
        helper(shape, beta, threshold, dtype)
Kulin Setha6347f52022-06-07 18:22:10 +00006774
Kulin Sethe011a8e2022-05-13 18:28:53 +00006775 # Test silu
6776
def test_silu(self):
    """Compare torch.nn.SiLU forward/backward between MPS and CPU."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        mps_out = torch.nn.SiLU()(x)
        cpu_out = torch.nn.SiLU()(cpu_x)

        upstream_cpu = torch.randn(cpu_out.shape)
        upstream_mps = upstream_cpu.to('mps')

        mps_out.backward(gradient=upstream_mps)
        cpu_out.backward(gradient=upstream_cpu)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(x.grad, cpu_x.grad)

    # Includes the scalar shape []
    shapes = ([], (2, 3), (2, 8, 4, 5))
    for shape in shapes:
        helper(shape)
6797
def test_cast_mps_to_cpu(self):
    """Check that copying MPS -> CPU with a dtype change matches a CPU-only cast."""
    def helper(src_dtype, dst_dtype):
        original = torch.rand((1, 3, 128, 128), dtype=src_dtype)
        on_mps = original.to('mps')
        back_on_cpu = on_mps.to('cpu', dtype=dst_dtype)

        # needs to match the initial Tensor
        self.assertEqual(back_on_cpu, original.to(dtype=dst_dtype))

    for src_dtype, dst_dtype in [(torch.half, torch.float), (torch.float, torch.half)]:
        helper(src_dtype, dst_dtype)
6808
def test_cast_mps_to_mps(self):
    """Check that dtype casts performed on MPS match the same casts on CPU."""
    def helper(src_dtype, dst_dtype):
        input_cpu = torch.rand((1, 3, 128, 128), dtype=src_dtype)
        input_mps = input_cpu.to('mps')
        cast_on_mps = input_mps.to(dtype=dst_dtype)
        cast_on_cpu = input_cpu.to(dtype=dst_dtype)
        self.assertEqual(cast_on_mps.cpu(), cast_on_cpu)

    dtype_pairs = [
        (torch.half, torch.float),
        (torch.float, torch.half),
        (torch.half, torch.long),
        (torch.float, torch.int),
    ]
    for src_dtype, dst_dtype in dtype_pairs:
        helper(src_dtype, dst_dtype)
6820
def test_avg_pool2d_count_include_pad(self):
    """Compare AvgPool2d (ceil_mode and count_include_pad enabled) between MPS and CPU."""
    cpu_x = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True)
    x = cpu_x.detach().clone().to('mps').requires_grad_()
    pool = torch.nn.AvgPool2d(
        kernel_size=(3, 3),
        padding=(1, 1),
        stride=(1, 1),
        ceil_mode=True,
        count_include_pad=True,
    )

    # Forward
    ref_y = pool(cpu_x)
    y = pool(x)
    self.assertEqual(y, ref_y)

    # Backward with the same random upstream gradient on both devices
    cpu_grad = torch.randn(ref_y.shape)
    grad = cpu_grad.to('mps')
    ref_y.backward(gradient=cpu_grad)
    y.backward(gradient=grad)
    self.assertEqual(x.grad, cpu_x.grad)
6833
Kulin Sethe011a8e2022-05-13 18:28:53 +00006834 # Test adaptive avg pool2d - when the input size is a multiple of output size
6835 # Not testing for channels last right now
def test_adaptive_avg_pool2d_simple(self):
    """Compare AdaptiveAvgPool2d forward/backward between MPS and CPU."""
    def helper(input_shape, out_shape, channels_last):
        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        mps_out = torch.nn.AdaptiveAvgPool2d(out_shape)(x)
        cpu_out = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x)

        upstream_cpu = torch.randn(cpu_out.shape)
        upstream_mps = upstream_cpu.to('mps')

        mps_out.backward(gradient=upstream_mps)
        cpu_out.backward(gradient=upstream_cpu)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(x.grad, cpu_x.grad)

    # Output shape no larger than the input shape
    shrinking_cases = [
        ((2, 2, 4, 4), (2, 2)),
        ((2, 2, 9, 9), (3, 3)),
        ((2, 2, 9, 9), (9, 9)),
        ((2, 2, 16, 16), (2, 2)),
        ((2, 2, 16, 16), (2, 16)),
        ((2, 16, 16), (4, 4)),
    ]
    # Output shape larger than input shape
    growing_cases = [
        ((2, 2, 4, 4), (8, 8)),
        ((2, 2, 2, 2), (4, 4)),
        ((2, 2, 3, 3), (9, 9)),
        ((2, 2, 2, 2), (16, 16)),
        ((2, 2, 2, 16), (16, 16)),
        ((2, 4, 4), (16, 16)),
    ]
    for input_shape, out_shape in shrinking_cases + growing_cases:
        helper(input_shape, out_shape, False)

    # Output size that is not a multiple of the input size: any exception
    # (including a comparison failure) is deliberately tolerated here.
    try:
        helper((2, 2, 3, 3), (7, 7), False)
    except Exception:
        pass
6878
Kulin Seth2e32d5f2022-05-27 11:59:07 +00006879 # Test max avg pool2d - when the input size is a multiple of output size
6880 # Not testing for channels last right now
def test_adaptive_max_pool2d_simple(self):
    """Compare AdaptiveMaxPool2d outputs, indices and gradients between MPS and CPU."""
    def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
        if dtype in [torch.float16, torch.float32]:
            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
        else:
            cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        pool = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)
        mps_out = pool(x)
        cpu_out = pool(cpu_x)

        max_indices, max_indices_cpu = None, None
        if return_indices:
            # The module returns an (output, indices) pair in this mode
            max_result, max_indices = mps_out
            max_result_cpu, max_indices_cpu = cpu_out
        else:
            max_result, max_result_cpu = mps_out, cpu_out

        upstream_cpu = torch.randn(max_result_cpu.shape)
        upstream_mps = upstream_cpu.to('mps')

        max_result.backward(gradient=upstream_mps)
        max_result_cpu.backward(gradient=upstream_cpu)

        self.assertEqual(max_result, max_result_cpu)
        if return_indices:
            self.assertEqual(max_indices, max_indices_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    cases = [
        ((2, 2, 4, 4), (2, 2)),
        ((2, 2, 9, 9), (3, 3)),
        ((2, 2, 9, 9), (9, 9)),
        ((2, 2, 16, 16), (2, 2)),
        ((2, 2, 16, 16), (2, 16)),
        ((2, 16, 16), (4, 4)),
    ]
    for dtype in [torch.float32]:
        for return_indices in [False, True]:
            for input_shape, out_shape in cases:
                helper(input_shape, out_shape, return_indices, dtype)
6922
def test_gelu_simple(self):
    """Compare torch.nn.GELU forward/backward between MPS and a float32 CPU reference."""
    def helper(shape, dtype=torch.float, contiguous=True):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        can_transpose = 0 not in shape and len(shape) >= 2
        if not contiguous and can_transpose:
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            x = x.transpose(0, 1)
            assert not x.is_contiguous()

        cpu_x.requires_grad_()
        x.requires_grad_()

        gelu = torch.nn.GELU()
        mps_out = gelu(x)
        # The CPU reference is always computed in float32 and cast back to
        # dtype for the comparison below.
        cpu_out = gelu(cpu_x.to(torch.float))

        upstream_cpu = torch.ones_like(cpu_out)
        upstream_mps = upstream_cpu.to('mps')

        mps_out.backward(gradient=upstream_mps)
        cpu_out.backward(gradient=upstream_cpu)

        # Looser tolerances for anything other than float32
        if dtype == torch.float:
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 1e-2, 1e-2
        self.assertEqual(mps_out, cpu_out.to(dtype), atol=atol, rtol=rtol)

        assert x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

    # Includes the scalar shape and shapes with a zero-sized dimension
    dtypes = [torch.float, torch.half]
    shapes = [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]
    for dtype, shape, contiguous in itertools.product(dtypes, shapes, [True, False]):
        helper(shape, dtype, contiguous)

    # Test that gelu would raise an assert for integral types
    for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        with self.assertRaises(RuntimeError):
            torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps"))
Kulin Sethe011a8e2022-05-13 18:28:53 +00006962
def test_mish_simple(self):
    """Compare torch.nn.Mish forward/backward between MPS and CPU."""
    def helper(shape, dtype=torch.float, contiguous=True):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        can_transpose = 0 not in shape and len(shape) >= 2
        if not contiguous and can_transpose:
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            x = x.transpose(0, 1)
            assert not x.is_contiguous()

        cpu_x.requires_grad_()
        x.requires_grad_()

        mish = torch.nn.Mish()
        mps_out = mish(x)
        cpu_out = mish(cpu_x)

        upstream_cpu = torch.ones_like(cpu_out)
        upstream_mps = upstream_cpu.to('mps')

        mps_out.backward(gradient=upstream_mps)
        cpu_out.backward(gradient=upstream_cpu)

        # Looser tolerances for anything other than float32
        if dtype == torch.float:
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 1e-2, 1e-2
        self.assertEqual(mps_out, cpu_out.to(dtype), atol=atol, rtol=rtol)

        assert x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

    # Includes the scalar shape and shapes with a zero-sized dimension
    dtypes = [torch.float, torch.half]
    shapes = [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]
    for dtype, shape, contiguous in itertools.product(dtypes, shapes, [True, False]):
        helper(shape, dtype, contiguous)
6998
def test_gelu(self):
    """Check F.gelu on CPU and MPS against an erf-based closed-form reference.

    Runs with full and column-sliced inputs, then once more with the
    intra-op thread count temporarily set to 4.
    """
    def _test_gelu(n, m, dtype, contiguous, atol=1e-5, rtol=1e-3):
        def _gelu_ref(X):
            # Exact GELU, x * Phi(x), with the standard normal CDF written via erf
            return X * 0.5 * (1.0 + torch.erf(X / math.sqrt(2.0)))

        for d in ['cpu', 'mps']:
            X = torch.rand(n, m, dtype=dtype, device=d)
            if not contiguous:
                # Keep every other column so the op sees a sliced view
                X = X[:, ::2]
            res = F.gelu(X)
            # The reference is evaluated on CPU in double precision
            ref = _gelu_ref(X.detach().cpu().to(torch.double))
            self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)

    for n in [1, 5, 10]:
        for m in [1, 5, 10]:
            _test_gelu(n, m, torch.float32, True)
            _test_gelu(n, m, torch.float32, False)

    # Test multi threaded
    num_threads = torch.get_num_threads()
    torch.set_num_threads(4)
    try:
        _test_gelu(32, 32, torch.float32, False)
    finally:
        # Always restore the previous thread count
        torch.set_num_threads(num_threads)
7028
def test_gelu_tanh(self):
    """Compare the tanh-approximated GELU between MPS and CPU."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
        x = cpu_x.detach().clone().to('mps')

        mps_out = torch.nn.functional.gelu(x, approximate='tanh')
        cpu_out = torch.nn.functional.gelu(cpu_x, approximate='tanh')
        self.assertEqual(mps_out, cpu_out)

    helper((2, 8, 4, 5))
7039
Kulin Sethe011a8e2022-05-13 18:28:53 +00007040 # Test hardtanh
def test_hardtanh(self):
    """Compare torch.nn.Hardtanh between MPS and CPU, in-place and out-of-place."""
    def helper(shape, min_val, max_val, inplace=False):
        # Gradients are only tracked (and compared) for the out-of-place variant
        track_grad = not inplace
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=track_grad)
        x = cpu_x.detach().clone().to('mps')
        if track_grad:
            x = x.requires_grad_()

        mps_out = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x)
        cpu_out = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x)

        self.assertEqual(mps_out, cpu_out)

        if not track_grad:
            return

        upstream_cpu = torch.randn(cpu_out.shape)
        upstream_mps = upstream_cpu.to('mps')
        mps_out.backward(gradient=upstream_mps)
        cpu_out.backward(gradient=upstream_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    # Includes a shape with a zero-sized dimension and the scalar shape
    bounds = [(-1, 1), (-2, -1), (3, 4)]
    for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
        for min_val, max_val in bounds:
            helper(shape, min_val, max_val)
            helper(shape, min_val, max_val, inplace=True)
7070
def test_hardswish(self):
    """Compare nn.Hardswish between MPS and CPU for all inplace/requires_grad combinations."""
    def helper(shape, inplace=False, requires_grad=True):
        hardswish = nn.Hardswish(inplace=inplace)

        input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

        if inplace and requires_grad:  # check that both raise runtime error
            with self.assertRaises(RuntimeError):
                hardswish(input_cpu)
            with self.assertRaises(RuntimeError):
                hardswish(input_mps)
            return

        output_cpu = hardswish(input_cpu)
        output_mps = hardswish(input_mps)

        self.assertEqual(output_cpu, output_mps)

        if not requires_grad:
            return

        upstream_cpu = torch.ones_like(output_cpu)
        upstream_mps = upstream_cpu.to('mps')
        output_cpu.backward(gradient=upstream_cpu)
        output_mps.backward(gradient=upstream_mps)

        self.assertEqual(input_cpu.grad, input_mps.grad)

    for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
        for requires_grad, inplace in itertools.product((False, True), (False, True)):
            helper(shape, inplace=inplace, requires_grad=requires_grad)
7102
def test_transpose_2D(self):
    """Check torch.transpose of a 3x3 matrix on MPS against the CPU result."""
    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    cpu_transpose = torch.transpose(cpu_x, 0, 1)
    mps_transpose = torch.transpose(mps_x, 0, 1)
    self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))
7113
def test_transpose_3D(self):
    """Check torch.transpose on a 3-D tensor for every pair of distinct dims."""
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    for dim0, dim1 in [(0, 1), (0, 2), (1, 2)]:
        expected = torch.transpose(cpu_x, dim0, dim1)
        actual = torch.transpose(mps_x, dim0, dim1).to('cpu')
        self.assertEqual(expected, actual)
7130
7131
def test_transpose_4D(self):
    """Check torch.transpose on a 4-D tensor for several dim pairs."""
    values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
              [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    dim_pairs = [(0, 1), (0, 2), (0, 3), (3, 1), (3, 2), (1, 2)]
    for dim0, dim1 in dim_pairs:
        expected = torch.transpose(cpu_x, dim0, dim1)
        actual = torch.transpose(mps_x, dim0, dim1).to('cpu')
        self.assertEqual(expected, actual)
7161
Kulin Sethe011a8e2022-05-13 18:28:53 +00007162 # Test sign
def test_sign(self):
    """Compare torch.sign forward and backward between MPS and CPU."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        sign_result = torch.sign(x)
        sign_result_cpu = torch.sign(cpu_x)

        # Seed both backward passes with the same all-ones upstream gradient
        cpu_grad = torch.ones_like(sign_result_cpu)
        grad = cpu_grad.to('mps')

        sign_result.backward(gradient=grad)
        sign_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(sign_result, sign_result_cpu)
        # The backward passes above are only meaningful if the gradients
        # they produce are compared as well.
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
7180
def test_signbit(self):
    """Compare torch.signbit between MPS and CPU for int, float and int64 inputs."""
    def helper(shape, dtype):
        # Values are drawn as floats and then cast to the requested dtype
        cpu_x = torch.randn(shape, device='cpu').to(dtype)
        x = cpu_x.clone().to('mps')

        mps_out = torch.signbit(x)
        cpu_out = torch.signbit(cpu_x)

        self.assertEqual(mps_out, cpu_out)

    for dtype in (torch.int, torch.float, torch.int64):
        helper((2, 8, 4, 5), dtype)
7194
Kulin Sethe011a8e2022-05-13 18:28:53 +00007195 # Test neg
def test_neg(self):
    """Compare torch.neg forward and backward between MPS and CPU."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        neg_result = torch.neg(x)
        neg_result_cpu = torch.neg(cpu_x)

        # Seed both backward passes with the same all-ones upstream gradient
        cpu_grad = torch.ones_like(neg_result_cpu)
        grad = cpu_grad.to('mps')

        neg_result.backward(gradient=grad)
        neg_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(neg_result, neg_result_cpu)
        # The backward passes above are only meaningful if the gradients
        # they produce are compared as well.
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
7213
def test_neg_strided_input(self):
    """Check that neg() of a permuted-and-indexed MPS view cancels the view exactly."""
    # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337
    base = torch.arange(18.0, device='mps').reshape(2, 3, 3)
    strided_view = base.permute(1, 0, 2)[..., 1]
    residual = strided_view + strided_view.neg()
    largest_abs = residual.abs().max().item()
    self.assertEqual(largest_abs, 0.0)
7220
qqaatw1caa25e2022-07-14 23:40:00 +00007221 # Test index add
def test_index_add(self):
    """Compare torch.index_add between MPS and CPU over several dims and shapes."""
    def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):
        cpu_x = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False)
        source = cpu_source.detach().clone().to('mps')

        mps_out = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha)
        cpu_out = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha)
        self.assertEqual(mps_out, cpu_out)

    # (shape, dim, index, source_shape, alpha)
    cases = [
        ((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5),
        ((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0),
        ((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5),
        ((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0),
        ((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4),
        ((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0),
        # test result dim=1
        ((2,), 0, [1], (1,), 6.0),
        (2, 0, 1, 1, 6),
    ]
    for shape, dim, index, source_shape, alpha in cases:
        helper(shape, dim, index, source_shape, alpha)
    # test float16
    helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16)
qqaatw1caa25e2022-07-14 23:40:00 +00007248
def test_index_64bit(self):
    """ Test that index operations work for 4Gb+ tensors """
    if product_version < 14.0:
        raise unittest.SkipTest("Sonoma is needed for large tensors, see https://github.com/pytorch/pytorch/issues/84039")
    # Cleanup memory
    gc.collect()
    torch.mps.empty_cache()
    x = None
    try:
        # Check that index operations work for 4+GB tensors
        x = torch.rand(16000, 67120, device="mps")
        self.assertGreater(x.element_size() * x.numel(), 2**32)
        idx = torch.arange(0, 2, device="mps")
        x_sampled = x[:, idx]
        self.assertEqual(x[:, 0], x_sampled[:, 0])
    finally:
        # Reclaim memory after running the test, also when an assertion
        # above fails (best effort: drops this frame's reference to x)
        del x
        gc.collect()
        torch.mps.empty_cache()
7266
def test_mm_large(self):
    """ Test that MM works for matrices with index larger than 32K """
    def compare_mm(m, n, k, dtype=torch.float):
        # Multiply an (m, n) by an (n, k) matrix on MPS and compare with the CPU product
        lhs = torch.rand(m, n, device="mps", dtype=dtype)
        rhs = torch.rand(n, k, device="mps", dtype=dtype)
        mps_product = torch.mm(lhs, rhs).cpu()
        cpu_product = torch.mm(lhs.cpu(), rhs.cpu())
        self.assertEqual(mps_product, cpu_product)

    column = torch.rand(10, 1, device="mps")
    wide_row = torch.rand(1, 32769, device="mps")
    # This used to crash with:
    # error: subRange.start (24576) is not less than length of dimension[0] (16384)
    # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
    sliced_product = torch.mm(column, wide_row[:, 16384:32768])
    self.assertNotEqual(sliced_product.abs().max().item(), 0.0)

    # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal
    compare_mm(1024, 1, 32769)
    # one more time, but with dimensions inverted
    # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
    compare_mm(32769, 1, 1025)

    # bfloat16 mm is only exercised on product_version 14.0 and newer
    if product_version >= 14.0:
        compare_mm(1024, 1, 32769, torch.bfloat16)
7292
@unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test")
@unittest.skipIf(product_version < 14.0, "Can't allocate 4Gb tensor on MacOS 13")
def test_copy_large(self):
    """ Test that copy of 4Gb+ tensors works """
    # 2**30 + 11 float32 elements is just over 4Gb of data
    num_elements = 2**30 + 11
    cpu_ones = torch.ones((num_elements,), dtype=torch.float32)
    mps_ones = cpu_ones.to(device="mps")
    one_on_mps = torch.tensor(1.0, device="mps")
    self.assertTrue(torch.all(mps_ones == one_on_mps))
    # Drop the large tensors explicitly, device copy first
    del mps_ones
    del cpu_ones
7302
qqaatwc4da23e2022-06-28 19:51:43 +00007303 # Test flip
def test_flip(self):
    """ Compare torch.flip on MPS against the CPU result for several shapes and dims """
    def check(shape, dims):
        reference = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        on_mps = reference.detach().clone().to('mps')
        self.assertEqual(torch.flip(on_mps, dims=dims), torch.flip(reference, dims=dims))

    cases = [
        ((2, 8, 4, 5), [0]),
        ((8, 8, 4, 5), [0, 1]),
        ((2, 8, 4, 5), (0, 1, 2, 3)),
        ((2, 3, 3), (-1,)),
        # empty dims
        ((2, 8, 4, 5), []),
        # input.numel() == 1
        ((1,), (0,)),
        # input.numel() == 0
        ((0,), (0,)),
        # none of dims that needs to be flipped
        ((1, 3), [0]),
    ]
    for shape, dims in cases:
        check(shape, dims)
qqaatwc4da23e2022-06-28 19:51:43 +00007326
Kulin Sethe011a8e2022-05-13 18:28:53 +00007327 # Test index select
def test_index_select(self):
    """ Compare torch.index_select on MPS against the CPU result """
    def check(shape, dim, index, idx_dtype=torch.int32):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_input = cpu_input.detach().clone().to('mps')

        cpu_index = torch.tensor(index, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        mps_selected = torch.index_select(mps_input, dim=dim, index=mps_index)
        cpu_selected = torch.index_select(cpu_input, dim=dim, index=cpu_index)
        self.assertEqual(mps_selected, cpu_selected)

    cases = [
        ((2, 8, 4, 5), 0, [1]),
        ((8, 8, 4, 5), 0, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 1, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 2, [3, 0, 1]),
        ((2, 8, 4, 5), 3, [2, 3, 0]),
        ((2, 3, 3), -1, [1, 2]),
        # zero-dim input
        ((), 0, [0]),
        # bare int size (the original spelled this `(5)`) with an empty index list
        (5, 0, []),
    ]
    for shape, dim, index in cases:
        check(shape, dim, index)
Li-Huai (Allan) Linccbdf492023-01-19 14:08:02 +00007349
def test_index_select_scalar(self):
    """ Compare torch.index_select on a scalar tensor between MPS and CPU """
    def check(value, dim, index, idx_dtype=torch.int32):
        cpu_scalar = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
        mps_scalar = cpu_scalar.detach().clone().to('mps')

        cpu_index = torch.tensor(index, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        mps_selected = torch.index_select(mps_scalar, dim=dim, index=mps_index)
        cpu_selected = torch.index_select(cpu_scalar, dim=dim, index=cpu_index)
        self.assertEqual(mps_selected, cpu_selected)

    check(22, 0, [0])
    # An empty index into a scalar is expected to be rejected
    with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
        check(22, 0, [])
Kulin Sethe011a8e2022-05-13 18:28:53 +00007366
def test_embedding_dense_backward(self):
    """ Compare gradients flowing through nn.Embedding (with max_norm) on MPS and CPU """
    def forward_backward(embedding, proj, indices):
        # weight must be cloned for this to be differentiable
        a = embedding.weight.clone() @ proj.t()
        a.retain_grad()
        # the embedding lookup modifies weight in-place
        b = embedding(indices) @ proj.t()
        b.retain_grad()
        loss = (a.unsqueeze(0) + b).sigmoid().prod()
        loss.backward()
        return a, b

    def helper(n, d, m, idx):
        mps_embedding = nn.Embedding(n, d, max_norm=True, device='mps')
        # Snapshot the weight before the MPS forward pass touches it, so the CPU
        # module below starts from the same values
        initial_weight = mps_embedding.weight.detach().cpu()
        mps_proj = torch.randn((m, d), requires_grad=True, device='mps')
        mps_indices = torch.tensor(idx, device='mps')
        a_mps, b_mps = forward_backward(mps_embedding, mps_proj, mps_indices)

        cpu_embedding = nn.Embedding(n, d, max_norm=True, _weight=initial_weight)
        cpu_proj = mps_proj.to('cpu')
        cpu_indices = torch.tensor(idx)
        a_cpu, b_cpu = forward_backward(cpu_embedding, cpu_proj, cpu_indices)

        self.assertEqual(b_cpu.grad, b_mps.grad)
        self.assertEqual(a_cpu.grad, a_mps.grad)

    helper(3, 5, 7, [0, 1, 2])
    helper(3, 6, 7, [0, 1, 2])  # verify if changes in shape would cause cached graph lookup problems
    helper(3, 5, 7, 2)  # test scalar index
Kulin Sethe011a8e2022-05-13 18:28:53 +00007398
7399 # Test pytorch gather
def test_gather(self):
    """ Compare torch.gather forward and backward on MPS against CPU """
    def check(shape, dim, idx_shape, idx_dtype=torch.int64):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        random_indices = np.random.randint(0, shape[dim], idx_shape)
        cpu_index = torch.tensor(random_indices, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        mps_gathered = torch.gather(mps_input, dim=dim, index=mps_index)
        cpu_gathered = torch.gather(cpu_input, dim=dim, index=cpu_index)

        # Push the same upstream gradient through both results
        cpu_upstream = torch.randn(idx_shape, device='cpu', dtype=torch.float)
        mps_upstream = cpu_upstream.to('mps')
        mps_gathered.backward(gradient=mps_upstream)
        cpu_gathered.backward(gradient=cpu_upstream)

        self.assertEqual(mps_gathered, cpu_gathered)
        self.assertEqual(cpu_input.grad, mps_input.grad)

    cases = [
        ((6, 3, 3), 0, (3, 3, 3)),
        ((2, 3, 3, 3), 0, (10, 3, 3, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5)),
        ((2, 8, 4, 5), 0, (10, 6, 3, 2)),
        ((8, 8, 4, 5), 0, (6, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (6, 7, 2, 3)),
        ((2, 8, 4, 5), 1, (2, 5, 3, 4)),
        ((2, 8, 4, 5), 2, (1, 8, 10, 3)),
        ((2, 8, 4, 5), 3, (2, 5, 3, 12)),
    ]
    for shape, dim, idx_shape in cases:
        check(shape, dim, idx_shape)
7431
Abhishek Pathak81b366a2022-09-30 00:24:16 +00007432 # Test pytorch gather
def test_gather_scalar(self):
    """ Compare torch.gather forward and backward on a scalar input between MPS and CPU """
    cpu_scalar = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
    mps_scalar = cpu_scalar.detach().clone().to('mps').requires_grad_()

    cpu_index = torch.tensor([0], device='cpu', dtype=torch.int64)
    mps_index = cpu_index.detach().clone().to('mps')

    mps_gathered = torch.gather(mps_scalar, dim=0, index=mps_index)
    cpu_gathered = torch.gather(cpu_scalar, dim=0, index=cpu_index)

    # Push the same upstream gradient through both results
    cpu_upstream = torch.randn([1], device='cpu', dtype=torch.float)
    mps_upstream = cpu_upstream.to('mps')
    mps_gathered.backward(gradient=mps_upstream)
    cpu_gathered.backward(gradient=cpu_upstream)

    self.assertEqual(mps_gathered, cpu_gathered)
    self.assertEqual(cpu_scalar.grad, mps_scalar.grad)
7453
Kulin Sethe011a8e2022-05-13 18:28:53 +00007454 # Test pytorch scatter_add and scatter
def test_scatter_add(self):
    """ Compare torch.scatter_add / torch.scatter (src variant) on MPS against CPU """
    def check(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        cpu_source = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_source = cpu_source.detach().clone().to('mps').requires_grad_()

        if do_add:
            # Indices should be taken from range of axis along which gathering is done
            index_values = np.random.randint(0, shape[dim], idx_shape)
        else:
            # Plain scatter uses a fixed index table, independent of idx_shape
            index_values = np.array([[0, 1, 2],
                                     [1, 2, 3],
                                     [2, 3, 4],
                                     [3, 4, 5],
                                     [4, 5, 6]])

        cpu_index = torch.tensor(index_values, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        scatter_fn = torch.scatter_add if do_add else torch.scatter
        mps_result = scatter_fn(mps_input, dim=dim, index=mps_index, src=mps_source)
        cpu_result = scatter_fn(cpu_input, dim=dim, index=cpu_index, src=cpu_source)

        # Backward is only exercised when index and src have the same shape
        run_backward = idx_shape == src_shape
        if run_backward:
            cpu_upstream = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_upstream = cpu_upstream.to('mps')
            mps_result.backward(gradient=mps_upstream)
            cpu_result.backward(gradient=cpu_upstream)

        self.assertEqual(mps_result, cpu_result)
        if run_backward:
            self.assertEqual(cpu_input.grad, mps_input.grad)
            self.assertEqual(cpu_source.grad, mps_source.grad)

    add_cases = [
        ((2, 3), 0, (5, 3), (5, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)),

        ((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)),
        ((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)),
        ((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)),

        ((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)),
        ((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)),
        ((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)),
    ]
    for shape, dim, idx_shape, src_shape in add_cases:
        check(shape, dim, idx_shape, src_shape)

    # Test scatter src
    check((8, 3), 0, (5, 3), (5, 3), do_add=False)
    check((10, 3), 0, (5, 3), (5, 8), do_add=False)
7520
Abhishek Pathak81b366a2022-09-30 00:24:16 +00007521 # Test pytorch scatter_add and scatter for scalar input
def test_scatter_add_scalar(self):
    """ Compare torch.scatter_add / torch.scatter on scalar inputs between MPS and CPU """
    def check(idx_dtype=torch.int64, do_add=True):
        cpu_input = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        cpu_source = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
        mps_source = cpu_source.detach().clone().to('mps').requires_grad_()

        # The only valid index into a scalar
        cpu_index = torch.tensor([0], device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        scatter_fn = torch.scatter_add if do_add else torch.scatter
        mps_result = scatter_fn(mps_input, dim=0, index=mps_index, src=mps_source)
        cpu_result = scatter_fn(cpu_input, dim=0, index=cpu_index, src=cpu_source)

        # Push the same upstream gradient through both results
        cpu_upstream = torch.tensor(1.2, device='cpu', dtype=torch.float)
        mps_upstream = cpu_upstream.to('mps')
        mps_result.backward(gradient=mps_upstream)
        cpu_result.backward(gradient=cpu_upstream)

        self.assertEqual(mps_result, cpu_result)
        self.assertEqual(cpu_input.grad, mps_input.grad)
        self.assertEqual(cpu_source.grad, mps_source.grad)

    check()
    check(do_add=False)
7560
Kulin Sethe011a8e2022-05-13 18:28:53 +00007561 # Test pytorch scatter_reduce
def test_scatter_reduce(self):
    """ Compare torch.scatter with a `reduce` argument on MPS against CPU """
    def check(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        cpu_source = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_source = cpu_source.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        random_indices = np.random.randint(0, shape[dim], idx_shape)
        cpu_index = torch.tensor(random_indices, device='cpu', dtype=idx_dtype)
        mps_index = cpu_index.detach().clone().to('mps')

        mps_result = torch.scatter(mps_input, dim=dim, index=mps_index, src=mps_source, reduce=reduce_str)
        cpu_result = torch.scatter(cpu_input, dim=dim, index=cpu_index, src=cpu_source, reduce=reduce_str)
        self.assertEqual(mps_result, cpu_result)

    cases = [
        ((2, 3), 0, (5, 3), (5, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)),

        ((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)),
        ((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)),
        ((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)),

        ((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)),
        ((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)),
        ((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)),
    ]
    # Only "add" and "multiply" are exercised; the "sum"/"prod"/"amax"/"amin"
    # modes mentioned in the original comment are not covered here.
    for reduce_type in ["add", "multiply"]:
        for shape, dim, idx_shape, src_shape in cases:
            check(shape, dim, idx_shape, src_shape, reduce_str=reduce_type)
Kulin Sethe011a8e2022-05-13 18:28:53 +00007598
def test_is_nonzero(self):
    """ Check torch.is_nonzero on single-element MPS tensors of several dtypes """
    expectations = [
        ([0.], False),
        ([1.5], True),
        ([False], False),
        ([3], True),
    ]
    for values, expected in expectations:
        outcome = torch.is_nonzero(torch.tensor(values).to('mps'))
        if expected:
            self.assertTrue(outcome)
        else:
            self.assertFalse(outcome)
7604
7605 # Test triu
def test_triu(self):
    """ Compare torch.triu forward and backward on MPS against CPU """
    def check(shape, diag=0):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        mps_upper = torch.triu(mps_input, diag)
        cpu_upper = torch.triu(cpu_input, diag)

        # Push the same upstream gradient through both results
        cpu_upstream = torch.randn(cpu_upper.shape)
        mps_upstream = cpu_upstream.to('mps')
        mps_upper.backward(gradient=mps_upstream)
        cpu_upper.backward(gradient=cpu_upstream)

        self.assertEqual(mps_upper, cpu_upper)
        self.assertEqual(mps_input.grad, cpu_input.grad)

    # Main diagonal first, then diagonals above and below it
    for diagonal in (0, 1, 2, 3, -1, -2, -3):
        check((2, 8, 4, 5), diag=diagonal)
7630
Kulin Seth8ecb49b2022-12-19 22:00:07 +00007631 # Test inverse
def test_inverse(self):
    """ Compare torch.linalg.inv of random square matrices on MPS against CPU """
    def check(n):
        cpu_matrix = torch.randn(n, n, device='cpu')
        mps_matrix = cpu_matrix.to('mps')

        cpu_inverse = torch.linalg.inv(cpu_matrix)
        mps_inverse = torch.linalg.inv(mps_matrix)
        self.assertEqual(cpu_inverse, mps_inverse)

    for size in (2, 6, 3, 8):
        check(size)
7645
Kulin Sethe011a8e2022-05-13 18:28:53 +00007646 # Test tril
def test_tril(self):
    """ Compare torch.tril forward and backward on MPS against CPU """
    def check(shape, diag=0):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        mps_lower = torch.tril(mps_input, diag)
        cpu_lower = torch.tril(cpu_input, diag)

        # Push the same upstream gradient through both results
        cpu_upstream = torch.randn(cpu_lower.shape)
        mps_upstream = cpu_upstream.to('mps')
        mps_lower.backward(gradient=mps_upstream)
        cpu_lower.backward(gradient=cpu_upstream)

        self.assertEqual(mps_lower, cpu_lower)
        self.assertEqual(mps_input.grad, cpu_input.grad)

    # Main diagonal first, then diagonals above and below it
    for diagonal in (0, 1, 2, 3, -1, -2, -3):
        check((2, 8, 4, 5), diag=diagonal)
7671
Kulin Seth8552acb2022-05-27 17:07:02 +00007672 # test eye
def test_eye(self):
    """ Compare torch.eye on MPS against CPU for square and rectangular sizes and several dtypes """
    def helper(n, m, dtype):
        if n == m:
            # Square case: exercise the single-size overload
            cpu_result = torch.eye(n, dtype=dtype, device='cpu')
            result = torch.eye(n, dtype=dtype, device='mps')
        else:
            # Rectangular case: dtype must be forwarded here too, otherwise every
            # iteration of the dtype loop below re-tests the default dtype only
            cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
            result = torch.eye(n, m, dtype=dtype, device='mps')

        self.assertEqual(result, cpu_result)

    for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]:
        helper(2, 2, dtype)
        helper(2, 3, dtype)
        helper(0, 2, dtype)
        helper(0, 0, dtype)
        helper(3, 8, dtype)
        helper(8, 3, dtype)
7694
Kulin Sethe011a8e2022-05-13 18:28:53 +00007695 # Test diag
def test_diag(self):
    """ Compare the torch.diag forward result on MPS against CPU """
    def check(shape, diag=0):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        mps_diag = torch.diag(mps_input, diag)
        cpu_diag = torch.diag(cpu_input, diag)

        # Only the forward result is compared; the backward pass is not run here
        self.assertEqual(mps_diag, cpu_diag)

    shapes = [(5, 5), (5, 6), (6, 5), (5,), (6,)]
    offsets = [0, 1, 2, 3, 4, -1, -2, -3, -4]
    for shape in shapes:
        for offset in offsets:
            check(shape, diag=offset)
7716
Kulin Setha3bdafe2022-06-01 13:47:14 +00007717 # Test linspace
def test_linspace(self):
    """ Compare torch.linspace on MPS against numpy.linspace converted to the same dtype """
    def check(start, end, steps, dtype=torch.float32):
        expected = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
        actual = torch.linspace(start, end, steps, dtype=dtype, device='mps')
        self.assertEqual(expected, actual)

    # (start, end, steps): increasing, constant, decreasing and zero-step ranges
    ranges = [(2, 5, 10), (2, 2, 10), (5, 2, 10), (2, 2, 0)]
    for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
        for start, end, steps in ranges:
            check(start, end, steps, dtype)
7729
Nikita Shulga55cac222022-06-03 21:54:41 +00007730 # Test argange
7731 def test_arange(self):
7732 self.assertEqual(np.arange(10), torch.arange(10, device='mps'))
7733 self.assertEqual(np.arange(7, 1, -1), torch.arange(7, 1, -1, device='mps'))
7734 self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps'))
7735 self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps'))
7736
def test_arange_empty(self):
    """ Check that an empty torch.arange written into `out=` matches between MPS and CPU """
    mps_buffer = torch.tensor([], device="mps")
    cpu_buffer = torch.tensor([], device="cpu")

    mps_range = torch.arange(0, 0, 1, out=mps_buffer)
    cpu_range = torch.arange(0, 0, 1, out=cpu_buffer)
    self.assertEqual(mps_range, cpu_range)
7744
# Test range
def test_range(self):
    """ Compare torch.range (end-inclusive) on MPS against the equivalent numpy arrays """
    self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
    self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
    self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
    # Fractional end: the original only checked torch.arange here, which left
    # torch.range itself untested for this input. Both are checked now.
    self.assertEqual(np.arange(6.3, dtype=np.float32), torch.range(0, 6.3, device='mps'))
    self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps'))
7751
Kulin Sethe011a8e2022-05-13 18:28:53 +00007752 # Test softmax
def test_softmax(self):
    """ Compare softmax forward (and backward for non-channels-last inputs) on MPS against CPU """
    def check_nd(shape, dim, channels_last=False):
        cpu_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_input = cpu_input.to(memory_format=torch.channels_last)
            cpu_input.retain_grad()
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        mps_probs = torch.nn.functional.softmax(mps_input, dim=dim)
        cpu_probs = torch.nn.functional.softmax(cpu_input, dim=dim)

        # Currently NOT testing backward for channels last backward
        check_backward = not channels_last
        if check_backward:
            cpu_upstream = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_upstream = cpu_upstream.to('mps')
            mps_probs.backward(gradient=mps_upstream)
            cpu_probs.backward(gradient=cpu_upstream)

        self.assertEqual(mps_probs, cpu_probs)
        if check_backward:
            self.assertEqual(mps_input.grad, cpu_input.grad)

    def check_scalar(dim):
        cpu_input = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()

        mps_probs = torch.nn.functional.softmax(mps_input, dim=dim)
        cpu_probs = torch.nn.functional.softmax(cpu_input, dim=dim)

        cpu_upstream = torch.tensor(2.34, device='cpu', dtype=torch.float)
        mps_upstream = cpu_upstream.to('mps')
        mps_probs.backward(gradient=mps_upstream)
        cpu_probs.backward(gradient=cpu_upstream)

        self.assertEqual(mps_probs, cpu_probs)
        self.assertEqual(mps_input.grad, cpu_input.grad)

    check_scalar(0)

    # Only the contiguous layout is exercised at the moment
    for channels_last in [False]:
        for shape in [(2, 4, 8, 5), (3, 4, 6, 7, 2)]:
            # channels_last is only meaningful for 4-D inputs
            if channels_last and len(shape) != 4:
                continue
            for dim in [0, 1, 2, 3, -1, -2, -3]:
                check_nd(shape, dim, channels_last)
7803
def test_nan_to_num(self):
    """ Compare torch.nan_to_num with explicit replacement values on MPS against CPU """
    replacements = dict(nan=2.0, posinf=1.0, neginf=-1.0)
    cpu_input = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
    mps_input = cpu_input.detach().clone().to('mps').requires_grad_()
    cpu_output = torch.nan_to_num(cpu_input, **replacements)
    mps_output = torch.nan_to_num(mps_input, **replacements)
    self.assertEqual(mps_output, cpu_output)
7810
Kulin Sethe011a8e2022-05-13 18:28:53 +00007811 # Test where
def test_where(self):
    """ Compare torch.where forward/backward (with broadcasting) and its `out=` variant on MPS against CPU """
    def check(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):
        cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
        mps_cond = cpu_cond.detach().clone().to('mps')

        cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
        mps_x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
        mps_y = cpu_y.detach().clone().to('mps').requires_grad_()

        cpu_selected = torch.where(cpu_cond, cpu_x, cpu_y)
        mps_selected = torch.where(mps_cond, mps_x, mps_y)

        # Push the same upstream gradient through both results
        cpu_upstream = torch.randn(cpu_selected.shape)
        mps_upstream = cpu_upstream.to('mps')
        cpu_selected.backward(gradient=cpu_upstream)
        mps_selected.backward(gradient=mps_upstream)

        self.assertEqual(mps_selected, cpu_selected)
        self.assertEqual(mps_x.grad, cpu_x.grad)
        self.assertEqual(mps_y.grad, cpu_y.grad)

    # Same shape for condition and both operands
    for shape in [(0, 3), [], (2, 3), (9,)]:
        check(shape, shape, shape)

    # Shapes that require broadcasting: (condition, x, y)
    broadcast_cases = [
        ((2, 3, 1), (2, 3, 4), (2, 1, 4)),
        ((2, 1, 1), (2, 3, 4), (1, 3, 4)),
        ((1, 1, 1), (1, 1, 4), (2, 3, 1)),
        ([], (1, 1, 4), (2, 3, 1)),
        ([], (2, 3, 4), []),
        ((5, 2, 3), (2, 3), (2, 3)),
        ((2, 3), (5, 2, 3), (2, 3)),
        ((2, 3), (2, 3), (5, 2, 3)),
        ((2, 3), (5, 2, 3), (6, 5, 2, 3)),
    ]
    for cond_shape, x_shape, y_shape in broadcast_cases:
        check(cond_shape, x_shape, y_shape)

    # Test that the output is correctly resized
    # TODO: Remove me when out OpInfo testing is enabled on MPS
    output = torch.tensor(0.0, device="mps")
    cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
    inp = torch.rand(3, 3, device="mps")
    other = torch.rand(3, 3, device="mps")
    returned = torch.where(cond, inp, other, out=output)
    self.assertEqual(id(returned), id(output))
    self.assertEqual(returned.shape, (3, 3))
Kulin Sethe011a8e2022-05-13 18:28:53 +00007858
7859 # Test normal
def test_normal(self):
    """ Smoke-test the torch.normal overloads on MPS and check the output sizes """
    def check(shape, mean=0.0, std=1.0):
        # scalar mean / scalar std with an explicit size
        mps_out = torch.normal(mean, std, shape, device='mps')

        cpu_mean_tensor = torch.tensor(np.ones(shape) * mean, device='cpu', dtype=torch.float, requires_grad=False)
        mean_tensor = cpu_mean_tensor.detach().clone().to('mps')

        cpu_std_tensor = torch.tensor(np.ones(shape) * std, device='cpu', dtype=torch.float, requires_grad=False)
        std_tensor = cpu_std_tensor.detach().clone().to('mps')

        # test out
        for mean_arg, std_arg in ((mean_tensor, std), (mean, std_tensor), (mean_tensor, std_tensor)):
            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_arg, std_arg, out=mps_out)

        # test without out
        mps_out = torch.normal(mean_tensor, std)
        self.assertEqual(mps_out.size(), mean_tensor.size())

        mps_out = torch.normal(mean, std_tensor)
        self.assertEqual(mps_out.size(), std_tensor.size())

        inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
        mps_out = torch.normal(mean_tensor, std_tensor)
        self.assertEqual(mps_out.size(), inferred_shape)

    check((2, 3, 4, 5, 6))
    check((100, 100), 2.5, 1.2)
7897
def test_bernoulli(self):
    """ Sanity-check torch.bernoulli on MPS for p=0.5, p=0, p=1 and several dtypes """
    shape = (10, 10)
    ones = torch.ones(shape, device='mps')
    zeros = torch.zeros(shape, device='mps')

    # probability of drawing "1" is 0.5
    sample = torch.bernoulli(ones * 0.5)
    # We can't check reliably the mean and std.
    # Just make sure we don't return constant values
    sample_cpu = sample.to('cpu')
    self.assertNotEqual(sample_cpu.mean(), 0.)
    self.assertNotEqual(sample_cpu.std() ** 2, 0.)

    # probability of drawing "1" is 0
    self.assertEqual(torch.bernoulli(zeros), zeros)

    # probability of drawing "1" is 1
    self.assertEqual(torch.bernoulli(ones), ones)

    # Check it works for different dtypes
    for dtype in [torch.float16, torch.int8, torch.int16, torch.int32, torch.int64]:
        sample = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5)
        # Check that output is not all zeros or ones
        if product_version > 13.0:
            distinct = sample.unique()
            self.assertEqual(distinct, torch.arange(2, device='mps', dtype=dtype))
        else:
            self.assertEqual(sample.min().item(), 0.)
            self.assertEqual(sample.max().item(), 1.)
7929
def test_mps_generator(self):
    """ Check seeding and state save/restore of an explicitly created MPS generator """
    # explicit manual seeding by creating an MPS Generator
    generator = torch.Generator(device='mps')
    generator.manual_seed(999)
    first_draw = torch.randn(5, device='mps', generator=generator)
    generator.manual_seed(999)
    # generate random numbers with offset `0`
    reseeded_draw = torch.randn(5, device='mps', generator=generator)
    # seed values were the same, so the random tensor contents should match
    self.assertEqual(first_draw, reseeded_draw)
    # save generator's state (offset = 1) to restore it later
    saved_state = generator.get_state()

    # generate random numbers with offset `1`
    next_draw = torch.randn(5, device='mps', generator=generator)
    # in this case, the random results must differ from the last generated random results
    self.assertNotEqual(next_draw, reseeded_draw)

    # next_draw was produced from saved_state, so it is the reference for the replay below.
    # restore the previously saved state, and the results should match again
    generator.set_state(saved_state)
    replayed_draw = torch.randn(5, device='mps', generator=generator)
    self.assertEqual(replayed_draw, next_draw)
7955
@serialTest()
def test_default_mps_generator(self):
    """ Check seeding and state save/restore of the default MPS generator """
    # manual seeding on the "default" MPS generator using
    # the global torch.manual_seed()
    torch.manual_seed(230)
    first_draw = torch.randn(5, device='mps')
    # manual seeding using torch.mps.manual_seed()
    # which should set the "default" MPS generator
    # like the global torch.manual_seed()
    torch.mps.manual_seed(230)
    # generate random numbers with offset `0`
    reseeded_draw = torch.randn(5, device='mps')
    # seed values were the same, so the random tensor contents should match
    self.assertEqual(first_draw, reseeded_draw)

    # save the default generator's state (offset = 1) to restore it later
    saved_state = torch.mps.get_rng_state()

    # generate random numbers with offset `1`
    next_draw = torch.randn(5, device='mps')
    # in this case, the random results must differ from the last generated random results
    self.assertNotEqual(next_draw, reseeded_draw)
    # since we called randn twice after seeding, the offset should be 2
    default_generator = torch.mps._get_default_mps_generator()
    self.assertEqual(default_generator.get_offset(), 2)

    # next_draw was produced from saved_state, so it is the reference for the replay below.
    # restore the previously saved state to the "default" MPS generator, and the results should match again
    torch.mps.set_rng_state(saved_state)
    replayed_draw = torch.randn(5, device='mps')
    self.assertEqual(replayed_draw, next_draw)
7988
def test_device_synchronize(self):
    # Runs a forward and a backward pass, calling torch.mps.synchronize()
    # after each step. Nothing is asserted on values; the test only checks
    # that every call completes.
    deconv = torch.nn.ConvTranspose2d(
        128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
    ).to(device='mps', dtype=torch.float)

    inp = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
    torch.mps.synchronize()
    out = deconv(inp)
    torch.mps.synchronize()
    out.backward(torch.randn_like(out))
    torch.mps.synchronize()
8001
@serialTest()
def test_mps_allocator_module(self):
    # Allocating a tensor must grow both the allocator and the driver memory counters.

    # start from a clean slate: collect garbage, then release cached blocks
    gc.collect()
    torch.mps.empty_cache()

    allocated_before = torch.mps.current_allocated_memory()
    # after garbage collection and emptying the cache nothing may be allocated
    self.assertEqual(allocated_before, 0)
    driver_before = torch.mps.driver_allocated_memory()

    # 8M-element tensor; the reference is held so the memory is still in use
    # when the counters are read below
    x = torch.ones(1024 * 1024 * 8, device="mps")

    allocated_after = torch.mps.current_allocated_memory()
    driver_after = torch.mps.driver_allocated_memory()
    self.assertGreater(allocated_after, allocated_before)
    self.assertGreater(driver_after, driver_before)
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00008023
def test_mps_allocator_stats(self):
    # The recommended maximum memory reported by torch.mps must be positive.
    recommended = torch.mps.recommended_max_memory()
    in_gib = recommended / 1024 ** 3
    print(f"Recommended Max Memory : {in_gib} GB")
    self.assertGreater(recommended, 0)
Kulin Seth8df56af2024-06-12 16:03:57 +00008028
# To verify this test manually, open the XCode Instruments "Metal System Trace"
# or "Logging" tool, press record, run this python test, and press stop. Then
# expand os_signposts->PyTorchMPS and check that events or intervals are logged,
# for example:
# "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
def test_mps_profiler_module(self):
    def make_input():
        return torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)

    # context-manager form of the profiler, "event" mode
    with torch.mps.profiler.profile(mode="event", wait_until_completed=False):
        # run a few ops so there is something for the OS Signposts trace to capture
        deconv = torch.nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device='mps', dtype=torch.float)
        deconv(make_input())

    # explicit start/stop form of the profiler, "interval" mode
    torch.mps.profiler.start(mode="interval", wait_until_completed=True)
    deconv(make_input())
    torch.mps.profiler.stop()
8047
def test_mps_event_module(self):
    # Brackets a small workload with two timing-enabled MPS events and
    # expects a strictly positive elapsed time between them.
    start_event = torch.mps.Event(enable_timing=True)
    start_event.record()

    deconv = torch.nn.ConvTranspose2d(
        128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
    ).to(device='mps', dtype=torch.float)
    deconv(torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True))

    end_event = torch.mps.Event(enable_timing=True)
    end_event.record()

    elapsed = start_event.elapsed_time(end_event)
    self.assertGreater(elapsed, 0.0)
Ramin Azarmehrcdfd0ea2023-08-08 03:45:45 +00008059
def test_jit_save_load(self):
    # A scripted module holding an MPS tensor attribute must survive a
    # torch.jit.save / torch.jit.load round trip through an in-memory buffer.
    module = torch.nn.Module()
    module.x = torch.rand(3, 3, device='mps')

    stream = io.BytesIO()
    torch.jit.save(torch.jit.script(module), stream)
    stream.seek(0)  # rewind so load() reads from the start

    restored = torch.jit.load(stream)
    self.assertEqual(restored.x, module.x)
8068
# Test random_, random_.to and random_.from
def test_random(self):
    def check_not_constant(shape, low, high, dtype=torch.int32):
        sample = torch.randint(low, high, shape, dtype=dtype, device='mps').float()
        # We can't check reliably the mean and std.
        # Just make sure we don't return constant values
        self.assertNotEqual(sample.mean().item(), 0.)
        self.assertNotEqual(sample.std().item(), 0.)

    randint_cases = (
        (0, 10, torch.int32),
        (23, 89, torch.int32),
        (23, 89, torch.float32),
        (23, 89, torch.int64),
        (0, 2, torch.bool),
    )
    for low, high, dtype in randint_cases:
        check_not_constant([100, 100], low, high, dtype=dtype)

    # Test random_: filling an uninitialized tensor must produce a non-zero maximum
    for dtype in (torch.bool, torch.int8, torch.uint8, torch.int32, torch.float16, torch.float32):
        filled = torch.empty(10, 10, dtype=dtype, device='mps')
        filled.random_()
        self.assertNotEqual(filled.max().item(), 0)
8091
# Test exponential
def test_exponential(self):
    # Fills an MPS tensor with exponential_ for several rates and dtypes.
    # Nothing is asserted: the sample mean and variance are only printed
    # next to 1/rate and 1/rate**2.
    def report(shape, rate, dtype=torch.float32):
        sample = torch.zeros(shape, device='mps', dtype=dtype)
        sample.exponential_(rate)
        on_cpu = sample.to('cpu').float()

        print(on_cpu.mean(), 1 / rate)
        print(on_cpu.std() ** 2, 1 / (rate**2))

    for dtype in (torch.float32, torch.float16):
        for rate in (2, 1, 3, 0.5):
            report([100, 100], rate, dtype)
8107
8108 def test_exponential_1(self):
8109 rate = torch.randn(5, 5).abs().requires_grad_()
8110 rate_1d = torch.randn(1).abs().requires_grad_()
8111 self.assertEqual(Exponential(rate).sample().size(), (5, 5))
8112 self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
8113 self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
8114 self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
8115 self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
8116 self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
8117
# Test add and sub, in-place and out-of-place, with an explicit alpha
def test_add_sub(self):
    def helper(shape, alpha, op_name, inplace):
        # (out-of-place, in-place) callables per operation name
        candidates = {
            "add": (torch.add, torch.Tensor.add_),
            "sub": (torch.sub, torch.Tensor.sub_),
        }
        if op_name in candidates:
            op = candidates[op_name][int(inplace)]

        def scalar_pair(dtype):
            # a 0-dim CPU tensor and its MPS copy
            cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
            return cpu_s, cpu_s.detach().clone().to('mps')

        for dtype in (torch.float16, torch.float32):
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            mps_y = cpu_y.detach().clone().to('mps')

            cpu_out = op(cpu_x, cpu_y, alpha=alpha)
            mps_out = op(mps_x, mps_y, alpha=alpha)
            # fp16 isn't accurate when alpha is passed
            # TODO: remove or fix 'tol' when we fix problems with fp16
            tol = None if dtype is not torch.float16 else 2e-3
            self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol)

            # in-place output cannot be broadcasted, so the scalar is only used as
            # the primary tensor when the op is out-of-place or cpu_y is 0-dim too
            if cpu_y.shape == () or not inplace:
                cpu_s, mps_s = scalar_pair(dtype)
                self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y))

            # fresh scalar pair: secondary tensor is scalar
            cpu_s, mps_s = scalar_pair(dtype)
            self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol)

    shape_alpha_cases = (
        ((), 0.0),
        ((2, 8, 4, 5), 0.0),
        ((2, 8, 4, 5), 0.1),
        ((2, 8, 4, 5), 1.0),
        ((2, 8, 3, 5), 0.1),
        ((2, 8, 3, 5), 0.2),
    )
    for op_name, inplace in product(["add", "sub"], [True, False]):
        for shape, alpha in shape_alpha_cases:
            helper(shape, alpha, op_name, inplace)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008159
# Test add on 0-dim tensors
def test_add_scalars(self):
    def check(alpha):
        for dtype in (torch.float16, torch.float32):
            cpu_x = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False)
            mps_y = cpu_y.detach().clone().to('mps')

            expected = torch.add(cpu_x, cpu_y, alpha=alpha)
            actual = torch.add(mps_x, mps_y, alpha=alpha)
            # fp16 isn't accurate when alpha is passed
            tol = None if dtype is not torch.float16 else 1e-3
            self.assertEqual(actual, expected, rtol=tol, atol=tol)

    for alpha in (1.0, 0.0, 0.1, 0.2):
        check(alpha)

    # Test int32 tensor + int64 scalar add
    # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
    ones = torch.ones(4, dtype=torch.int32, device='mps')
    self.assertEqual(ones + 1, torch.full((4,), 2, dtype=torch.int32, device='mps'))
    self.assertTrue(torch.equal(ones + 1.5, torch.full((4,), 2.5, device='mps')))
Nikita Shulga06f874e2022-06-25 02:21:34 +00008186
def test_types_binary_op(self):
    # Multiplying a float32 range by tensors of other dtypes must agree between CPU and MPS.
    def scaled_range(multiplier, device):
        base = torch.arange(5, dtype=torch.float32, device=device)
        return base * torch.tensor(multiplier, device=device)

    # Float * Bool
    bool_multiplier = [True, False, True, False, True]
    cpu_x = scaled_range(bool_multiplier, "cpu")
    mps_x = scaled_range(bool_multiplier, "mps")
    self.assertEqual(cpu_x, mps_x)

    # Float * Int64
    int_multiplier = [1, 0, 1, 0, 1]
    cpu_y = scaled_range(int_multiplier, "cpu")
    mps_y = scaled_range(int_multiplier, "mps")
    self.assertEqual(cpu_y, mps_y)
8196
def test_unary_ops(self):
    # Compares unary ops between CPU and MPS on float inputs, integer inputs,
    # a strided slice and a reshaped view.
    def helper(shape, op):
        def float_pair():
            cpu = torch.randn(shape, device='cpu', dtype=torch.float32, requires_grad=False)
            return cpu, cpu.detach().clone().to('mps')

        # plain float32 input
        cpu_x, mps_x = float_pair()
        self.assertEqual(op(cpu_x), op(mps_x))

        # integer inputs, compared with an explicit tolerance
        for int_dtype in (torch.int32, torch.int16):
            cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=int_dtype, requires_grad=False)
            mps_x = cpu_x.to('mps')
            self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4)

        # test slice: every second entry along dim 1
        cpu_x, mps_x = float_pair()
        self.assertEqual(op(cpu_x[:, ::2, :, :]), op(mps_x[:, ::2, :, :]))

        # test view: merge the last two dimensions into one
        cpu_x, mps_x = float_pair()
        merged_dims = [*shape[:-2], shape[-1] * shape[-2]]
        self.assertEqual(op(cpu_x.view(*merged_dims)), op(mps_x.view(*merged_dims)))

    helper((2, 8, 4, 5), torch.exp)
    for unary_op in (torch.exp2, torch.expm1, torch.log, torch.cos, torch.erfinv):
        helper((2, 8, 3, 5), unary_op)
8232
Kulin Setha6347f52022-06-07 18:22:10 +00008233
def test_non_dense_in_storage_unary_ops(self):
    # Unary ops applied to a step-2 slice of a 1-D tensor must agree between CPU and MPS.
    def helper(op):
        cpu_x = torch.randn(100, device='cpu', dtype=torch.float32, requires_grad=False)
        mps_x = cpu_x.detach().clone().to('mps')
        self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]))

        # integer inputs, compared with an explicit tolerance
        for int_dtype in (torch.int32, torch.int16, torch.int8):
            cpu_x = torch.randint(127, device='cpu', size=(100,), dtype=int_dtype, requires_grad=False)
            mps_x = cpu_x.to('mps')
            self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]), rtol=1e-4, atol=1e-4)

    for unary_op in (torch.exp, torch.exp2, torch.expm1, torch.log, torch.cos):
        helper(unary_op)
8251
def test_unary_ops_storage_offset_strided(self):
    def helper(shape, op, inplace, dtype=torch.float32):
        def random_pair():
            cpu = torch.randn(shape, device='cpu', dtype=dtype)
            return cpu, cpu.detach().clone().to('mps')

        # apply the op to row 1 of the tensor (a view with a storage offset)
        cpu_x, mps_x = random_pair()
        y = op(mps_x[1])
        cpu_y = op(cpu_x[1])
        self.assertEqual(y, cpu_y)

        if inplace:
            # the out= variant below is only exercised for out-of-place ops
            return

        # See https://github.com/pytorch/pytorch/issues/100764
        # write the result into a transposed `out` tensor
        cpu_x, mps_x = random_pair()
        cpu_y = torch.empty(shape, device='cpu', dtype=dtype).t()
        mps_y = cpu_y.detach().clone().to('mps')
        op(cpu_x, out=cpu_y)
        op(mps_x, out=mps_y)
        self.assertEqual(mps_y, cpu_y)

    for unary_op in (torch.exp, torch.cos, torch.neg, torch.tanh):
        helper((5, 5), unary_op, False)
    helper((5, 5), torch.tanh_, True)
8278
def test_atan2(self):
    # torch.atan2 on random inputs must agree between CPU and MPS.
    def helper(shape):
        input_cpu = torch.randn(shape)
        input_mps = input_cpu.detach().clone().to("mps")

        other_cpu = torch.randn(shape)
        other_mps = other_cpu.detach().clone().to("mps")

        expected = torch.atan2(input_cpu, other_cpu)
        actual = torch.atan2(input_mps, other_mps)
        self.assertEqual(expected, actual.to("cpu"))

    for shape in (4, 10000, (10000, 40)):
        helper(shape)
8295
def test_multinomial(self):
    # Draws multinomial samples on MPS. Nothing is asserted: the samples (or
    # their mean and variance next to the theoretical values) are only printed.
    # Test with num_dist = 1
    def report(probs, compare_mean, compare_var, num_samples=5, replacement=True):
        cpu_probs = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
        mps_probs = cpu_probs.detach().clone().to('mps')

        drawn = torch.multinomial(mps_probs, num_samples, replacement=replacement)
        if replacement:
            # Compare "real" with theoretical values
            print(drawn.to('cpu').float().mean(), compare_mean)
            print(drawn.to('cpu').float().std() ** 2, compare_var)
        else:
            print(drawn.to('cpu'))

    # TODO: Add tests for data types
    report(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)

    # mean and variance of a uniform choice over {0, 1, 2, 3, 4}
    uniform_mean = (0 + 1 + 2 + 3 + 4) / 5
    uniform_var = (6 - 2 * 2)
    report(np.array([[.2, .2, .2, .2, .2]]), uniform_mean, uniform_var, 10000)
    report(np.array([[1, 1, 1, 1, 1]]), uniform_mean, uniform_var, 10000)
    report(np.array([1, 1, 1, 1, 1]), uniform_mean, uniform_var, 10000)

    # without replacement: 7 samples from 7 equally weighted categories
    report(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
Kulin Sethe011a8e2022-05-13 18:28:53 +00008316
def test_cumsum_dim_check(self):
    # Negative dims must alias the positive ones and out-of-range dims must raise.
    x = torch.rand((3, 3), device="mps")
    for positive_dim, negative_dim in ((1, -1), (0, -2)):
        self.assertEqual(x.cumsum(positive_dim), x.cumsum(negative_dim))
    for invalid_dim in (2, -3):
        with self.assertRaises(IndexError):
            x.cumsum(invalid_dim)
8323
def test_cumprod_dim_check(self):
    # Negative dims must alias the positive ones and out-of-range dims must raise.
    x = torch.rand((3, 3), device="mps")
    for positive_dim, negative_dim in ((1, -1), (0, -2)):
        self.assertEqual(x.cumprod(positive_dim), x.cumprod(negative_dim))
    for invalid_dim in (2, -3):
        with self.assertRaises(IndexError):
            x.cumprod(invalid_dim)
8330
class TestLogical(TestCaseMPS):
    # Compares logical ops, min/max reductions and isin on MPS with CPU results.

    def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
        # Thin wrapper around torch.tensor so the test cases below stay short.
        return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)

    def _check_binary_logical_op(self, op):
        # Runs the binary logical `op` on a fixed set of operand pairs, once on
        # CPU and once on MPS copies of the same operands, and asserts that the
        # results agree.
        operand_pairs = [
            (self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1])),
            (
                self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
                self._wrap_tensor([1, 0, 0, 1], dtype=torch.float),
            ),
            (self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True])),
        ]
        # 1-D operand against 0-dim integer and boolean operands
        for scalar in (1, 0, True, False):
            operand_pairs.append((self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(scalar)))

        for cpu_x, cpu_other in operand_pairs:
            x = cpu_x.detach().clone().to('mps')
            other = cpu_other.detach().clone().to('mps')

            result = op(x, other)
            result_cpu = op(cpu_x, cpu_other)
            self.assertEqual(result, result_cpu)

    def test_logical_not(self):
        inputs = [
            self._wrap_tensor([1, 1, 0, 0]),
            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
            self._wrap_tensor([True, True, False, False]),
            self._wrap_tensor(1),
            self._wrap_tensor(0),
            self._wrap_tensor(True),
            self._wrap_tensor(False),
        ]
        for cpu_x in inputs:
            x = cpu_x.detach().clone().to('mps')

            result = torch.logical_not(x)
            result_cpu = torch.logical_not(cpu_x)
            self.assertEqual(result, result_cpu)

    def test_logical_and(self):
        self._check_binary_logical_op(torch.logical_and)

    def test_logical_or(self):
        self._check_binary_logical_op(torch.logical_or)

    def test_logical_xor(self):
        self._check_binary_logical_op(torch.logical_xor)

    def test_min_max(self):
        def helper(dtype):
            for _ in range(10):
                if dtype == torch.float32 or dtype == torch.float16:
                    x = torch.randn((30, 15), device='mps', dtype=dtype)
                else:
                    x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
                x_cpu = x.to("cpu")

                self.assertEqual(x.max(), x_cpu.max())
                self.assertEqual(x.min(), x_cpu.min())

        for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]:
            helper(dtype)

    def test_min_max_nan_propagation(self):
        def helper(dtype):
            # Build the input in the dtype under test. Previously `dtype` was
            # ignored, so every iteration silently re-ran the float32 case.
            cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu", dtype=dtype)
            mps_x = cpu_x.detach().clone().to('mps')

            for reduction in (torch.max, torch.amax, torch.min, torch.amin):
                cpu_result = reduction(cpu_x)
                mps_result = reduction(mps_x).to('cpu')
                self.assertEqual(cpu_result, mps_result)

        dtypes = [torch.float32, torch.float16]
        # NOTE(review): bfloat16 tensors are assumed to require macOS 14+ on MPS,
        # so the dtype is only exercised there - TODO confirm.
        if product_version >= 14.0:
            dtypes.append(torch.bfloat16)
        for dtype in dtypes:
            helper(dtype)

    def test_isin(self):
        def helper(dtype):
            shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),
                      ([5], [10]), ([0], [5]), ([5], [0])]
            for shape_tuple in shapes:
                for inverted in [True, False]:
                    if dtype.is_floating_point:
                        # Half is not supported for CPU isin, so the reference is
                        # computed in FP32. The inputs are first rounded to `dtype`
                        # so that CPU and MPS see exactly the same values; the
                        # upcast back to FP32 is exact, so equality is unchanged.
                        A = torch.randn(size=shape_tuple[0], device='cpu', dtype=torch.float32).to(dtype)
                        B = torch.randn(size=shape_tuple[1], device='cpu', dtype=torch.float32).to(dtype)
                        cpu_ref = torch.isin(A.float(), B.float(), invert=inverted)
                    else:
                        A = torch.randint(0, 100, size=shape_tuple[0], device='cpu', dtype=dtype)
                        B = torch.randint(0, 100, size=shape_tuple[1], device='cpu', dtype=dtype)
                        cpu_ref = torch.isin(A, B, invert=inverted)

                    # the MPS op runs in the dtype under test
                    A_mps = A.clone().detach().to('mps')
                    B_mps = B.clone().detach().to('mps')

                    mps_out = torch.isin(A_mps, B_mps, invert=inverted)
                    self.assertEqual(mps_out, cpu_ref)

        dtypes = [torch.float32, torch.float16]
        if product_version >= 14.0:
            # Int types expected to fail on MacOS < 14.0
            # NOTE(review): bfloat16 is gated here as well because it now actually
            # reaches MPS and is assumed to require macOS 14+ - TODO confirm.
            dtypes += [torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]

        for dtype in dtypes:
            helper(dtype)

    def test_isin_asserts(self):
        # mismatching dtypes must be rejected
        # (in the patterns below `.` matches any character and `()*` is an empty
        # group, so only the text before the parentheses is effectively matched)
        A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.dtype()*'):
            torch.isin(A, B)

        # mixing an MPS tensor with a CPU tensor must be rejected
        C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        D = torch.randn(size=[1, 4], device='cpu', dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.is_mps()*'):
            torch.isin(C, D)
8508
class TestSmoothL1Loss(TestCaseMPS):

    def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
        # Evaluates F.smooth_l1_loss (beta=1.0) on the same random data on CPU
        # and MPS and asserts that the losses agree. With requires_grad=True it
        # also runs backward() on both losses and compares the input gradients.
        # Returns the (cpu_loss, mps_loss) pair.
        def compute_loss(inp, tgt):
            return F.smooth_l1_loss(inp, tgt, beta=1.0, reduction=reduction)

        # CPU
        input_cpu = torch.randn(4, 7, requires_grad=requires_grad)
        target_cpu = torch.randn(4, 7)

        # MPS (this copy of the input always has requires_grad enabled)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        loss_cpu = compute_loss(input_cpu, target_cpu)
        loss_mps = compute_loss(input_mps, target_mps)
        self.assertEqual(loss_cpu, loss_mps)

        if requires_grad:
            for loss in (loss_cpu, loss_mps):
                loss.backward()
            self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))

        return loss_cpu, loss_mps

    def test_smooth_l1_loss_reduction_none(self):
        self._smooth_l1_loss_helper(reduction="none")

    def test_smooth_l1_loss_reduction_mean(self):
        self._smooth_l1_loss_helper(reduction="mean")

    def test_smooth_l1_loss_reduction_sum(self):
        self._smooth_l1_loss_helper(reduction="sum")

    def test_smooth_l1_loss_reduction_mean_backward(self):
        self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)

    def test_smooth_l1_loss_reduction_mean_sum_backward(self):
        # despite the name, this exercises the "sum" reduction with backward
        self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)
8546
8547class TestNLLLoss(TestCaseMPS):
8548 def test_nll_loss_mismatched_batch(self, device='mps'):
8549 x = torch.randn((10, 3), requires_grad=True, device=device)
8550 # t should have size (10,)
8551 t = torch.zeros((3,), dtype=torch.int64, device=device)
8552 with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
8553 F.nll_loss(x, t)
8554
def test_nll_loss_out_of_bounds_ignore_index(self):
    # nll_loss with an out-of-bounds ignore_index (255) and with the default
    # ignore_index must give the same results on CPU and MPS.
    def compute_losses(device):
        rows = [[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1],
                [0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]]
        x = torch.tensor(rows, device=device)
        t1 = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
        t2 = torch.tensor([0, 1, 1, 0, -100, 2], dtype=torch.int64, device=device)

        losses = []
        for reduction in ('mean', 'none'):
            # out of bound ignore_index
            losses.append(F.nll_loss(x, t1, ignore_index=255, reduction=reduction))
            # default ignore_index
            losses.append(F.nll_loss(x, t2, reduction=reduction))
        return losses

    output_cpu = compute_losses(device='cpu')
    output_mps = compute_losses(device='mps')

    for cpu_loss, mps_loss in zip(output_cpu, output_mps):
        self.assertEqual(cpu_loss, mps_loss)
8575
def test_nll_loss_invalid_target_dim(self):
    # A 2-D target for a 2-D input must be rejected on both CPU and MPS.
    def expect_rejection(device):
        rows = [[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1],
                [0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]]
        x = torch.tensor(rows, device=device)
        two_dim_target = torch.zeros((6, 2), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
            F.nll_loss(x, two_dim_target)

    for device in ('cpu', 'mps'):
        expect_rejection(device=device)
8588
def test_nll_loss_invalid_weights(self):
    # Weight tensors whose shape does not match the 3 classes must be
    # rejected on both CPU and MPS.
    expected_message = "weight tensor should be defined either for all 3 classes or no classes"

    def expect_rejection(device):
        rows = [[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1],
                [0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]]
        x = torch.tensor(rows, device=device)
        t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
        # one weight too many, and a 2-D weight
        bad_weights = (
            torch.zeros(4, device=device),
            torch.zeros((1, 3), device=device),
        )
        for weight in bad_weights:
            with self.assertRaisesRegex(RuntimeError, expected_message):
                F.nll_loss(x, t, weight=weight)

    for device in ('cpu', 'mps'):
        expect_rejection(device=device)
8606
def _nll_loss_helper(self, input_size, reduction, expected):
    # Compares weighted F.nll_loss and its input gradient between CPU and MPS
    # on random data of shape `input_size` (dim 1 is the class dimension).
    # NOTE: `expected` is accepted but never read by this helper.
    num_channels = input_size[1]

    # CPU
    input = torch.rand(input_size, requires_grad=True, device='cpu')
    target_size = (input_size[0], *input_size[2:])
    target = torch.randint(num_channels, target_size, device='cpu')
    weights = torch.randn(num_channels)

    # MPS
    input_mps = input.detach().clone().to('mps').requires_grad_()
    target_mps = target.detach().clone().to('mps')
    weights_mps = weights.to("mps")

    output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
    output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
    self.assertEqual(output_cpu, output_mps.to('cpu'))

    # sum() reduces the loss to a scalar so backward() works for every reduction
    for output in (output_cpu, output_mps):
        output.sum().backward()
    self.assertEqual(input.grad, input_mps.grad.to('cpu'))
8628
def _nll_loss_1d_helper(self, input_size, reduction):
    # Compares unweighted F.nll_loss and its input gradient between CPU and
    # MPS for a 1-D input and a 0-dim target.
    num_channels = input_size[0]

    # CPU
    input = torch.rand(input_size, requires_grad=True, device='cpu')
    target = torch.randint(num_channels, [], device='cpu')

    # MPS
    input_mps = input.detach().clone().to('mps').requires_grad_()
    target_mps = target.detach().clone().to('mps')

    output_cpu = F.nll_loss(input, target, reduction=reduction)
    output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
    self.assertEqual(output_cpu, output_mps.to('cpu'))

    for output in (output_cpu, output_mps):
        output.sum().backward()
    self.assertEqual(input.grad, input_mps.grad.to('cpu'))
8647
def test_nll_loss_1d(self, device='cpu'):
    """Run the 1-D nll_loss CPU/MPS comparison for every reduction mode.

    `device` is part of the signature but is not used by this test.
    """
    for reduction in ("none", "mean", "sum"):
        self._nll_loss_1d_helper([10], reduction)
8652
def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
    """Drive _nll_loss_helper with reduction="none" over several input ranks.

    Each case pairs an input size with the shape of the `expected` tensor
    (allocated on `device`) that is handed to the helper.
    """
    cases = (
        ([1, 3], [0]),
        ([3, 5, 7], [5, 7]),
        ([2, 3, 1, 7], [2, 1, 7]),
        ([2, 3, 5, 1], [2, 5, 1]),
        ([2, 3, 5, 7, 1], [2, 5, 7, 1]),
    )
    for input_size, expected_shape in cases:
        self._nll_loss_helper(input_size, "none", torch.empty(expected_shape, device=device))
8659
def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
    """Drive _nll_loss_helper with reduction="mean" over several input ranks.

    A NaN scalar created on `device` is passed as the helper's `expected`
    argument for every case.
    """
    nan = torch.tensor(float('nan'), device=device)
    input_sizes = (
        [1, 3],
        [1, 3, 5, 7],
        [2, 3, 1, 7],
        [2, 3, 5, 1],
        [2, 3, 5, 7, 1],
    )
    for input_size in input_sizes:
        self._nll_loss_helper(input_size, "mean", nan)
8667
def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
    """Drive _nll_loss_helper with reduction="sum" over several input ranks.

    A zero scalar created on `device` is passed as the helper's `expected`
    argument for every case.
    """
    zero = torch.tensor(0, device=device)
    input_sizes = (
        [1, 3],
        [1, 3, 5, 7],
        [2, 3, 1, 7],
        [2, 3, 5, 1],
        [2, 3, 5, 7, 1],
    )
    for input_size in input_sizes:
        self._nll_loss_helper(input_size, "sum", zero)
8675
def test_nll_loss_byte_target_matches_long(self, device='cpu'):
    """NLLLoss must give the same loss and gradient for uint8 and int64 targets.

    For every reduction mode the loss and the input gradient are computed on
    both 'cpu' and 'mps', once with a long target and once with the same
    target cast to uint8. The test then checks
      * MPS against CPU for each target dtype, and
      * the uint8 result against the long result on each device.
    Previously the uint8 results were computed but never asserted on.

    Args:
        device: device on which the shared random input/target are created
            before being moved to each backend.
    """
    N, C = 10, 4
    input = torch.randn(N, C, device=device, requires_grad=True)
    target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)

    def compute_result_and_gradient(reduction, target_dtype):
        """Return ({dev: loss}, {dev: grad wrt input}) for dev in cpu/mps."""
        result, grad = {}, {}
        for dev in ['cpu', 'mps']:
            # Fresh leaf per device so gradients do not accumulate across runs.
            input_ = input.to(dev).detach()
            input_.requires_grad_()

            target_dev = target.to(dev)

            prob = F.log_softmax(input_, dim=-1)
            loss = nn.NLLLoss(reduction=reduction)
            result[dev] = loss(prob, target_dev.to(target_dtype))
            result[dev].sum().backward()
            grad[dev] = input_.grad

        return result, grad

    for reduction in ["none", "mean", "sum"]:
        result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
        result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)

        # MPS agrees with CPU for both target dtypes.
        for result, grad in ((result_long, grad_long), (result_byte, grad_byte)):
            self.assertEqual(result['mps'].to('cpu'), result['cpu'])
            self.assertEqual(grad['mps'].to('cpu'), grad['cpu'])

        # The target dtype must not change the outcome on either device.
        for dev in ['cpu', 'mps']:
            self.assertEqual(result_byte[dev], result_long[dev])
            self.assertEqual(grad_byte[dev], grad_long[dev])
Soof Golane4fe11e2023-02-09 10:42:48 +00008704
class TestTopK(TestCase):
    """Checks that torch.topk on MPS matches the CPU result."""

    def _test_topk(self, shape, largest):
        """Compare topk values and indices between CPU and MPS.

        Every dimension of the tensor is tested with every k from 1 up to and
        including that dimension's size.

        Args:
            shape: an int (1-D tensor of that length) or a tuple of ints.
            largest: forwarded to torch.topk.
        """
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        # Treat a bare int as a one-dimensional shape so both kinds of input
        # share one loop. The former int-only branch stopped at k == shape - 1,
        # so shape 1 tested nothing and k == shape was never covered.
        dims = shape if isinstance(shape, tuple) else (shape,)
        for curr_dim, dim_size in enumerate(dims):
            for k in range(1, dim_size + 1):
                topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
                self.assertEqual(topk_values, topk_values_cpu)
                self.assertEqual(topk_indices, topk_indices_cpu)

    def test_topk(self):
        """Run _test_topk over empty and non-empty shapes, for both orderings."""
        largest_vals = [True, False]
        shapes = [
            # Zero Element Tensors
            0,
            (1, 0),
            (0, 1),
            (1, 0, 1),
            # Multiple Element Tensors
            1,
            2,
            (5, 1),
            (1, 5),
            (5, 9, 7, 4),
        ]

        for shape in shapes:
            for largest_val in largest_vals:
                with self.subTest(shape=shape, largest_val=largest_val):
                    self._test_topk(shape, largest_val)
8743
Kulin Sethe011a8e2022-05-13 18:28:53 +00008744class TestNNMPS(NNTestCase):
8745
def _create_basic_net(self):
    """Build three small modules that hold dummy parameters and buffers.

    Returns:
        A tuple ``(layer, net, seq)``: a standalone ``Layer``, a ``Net`` that
        owns its own ``Layer``, and an ``nn.Sequential`` containing that same
        ``Net`` instance twice.
    """
    class Layer(nn.Module):
        # Leaf module: one parameter and one buffer.
        def __init__(self) -> None:
            super().__init__()
            self.layer_dummy_param = Parameter(torch.empty(3, 5))
            self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7))

    class Net(nn.Module):
        # Wraps a Layer and adds a parameter and a buffer of its own.
        def __init__(self) -> None:
            super().__init__()
            self.l1 = Layer()
            self.dummy_param = Parameter(torch.empty(3, 5))
            self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1))

    layer = Layer()
    net = Net()
    seq = nn.Sequential(net, net)

    return layer, net, seq
8765
def test_requires_grad_(self):
    """Module.requires_grad_() toggles parameters, leaves buffers alone, returns self.

    The preconditions on the fixture are checked with unittest assertions
    rather than bare ``assert`` statements, so they still run under
    ``python -O``. The former ``all(...) > 0`` comparisons are reduced to the
    boolean they meant.
    """
    m = self._create_basic_net()[-1]

    # Sanity-check the fixture: it must have buffers (none requiring grad)
    # and parameters (all requiring grad), otherwise the test proves nothing.
    buffers = list(m.buffers())
    params = list(m.parameters())
    self.assertGreater(len(buffers), 0, 'invalid test')
    self.assertTrue(all(not b.requires_grad for b in buffers), 'invalid test')
    self.assertGreater(len(params), 0, 'invalid test')
    self.assertTrue(all(p.requires_grad for p in params), 'invalid test')

    for requires_grad in (False, True):
        # requires_grad_ must hand back the module itself.
        self.assertIs(m.requires_grad_(requires_grad), m)
        for p in m.parameters():
            self.assertEqual(p.requires_grad, requires_grad)
        for b in m.buffers():
            self.assertFalse(b.requires_grad)
8778
def test_module_backcompat(self):
    """A module saved with an older PyTorch still loads and can be called.

    The loaded object is called on a 2 x 3 float input and must produce a
    2 x 5 output.
    """
    from torch.serialization import SourceChangeWarning
    path = download_file('https://download.pytorch.org/test_data/linear.pt')
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', SourceChangeWarning)
        # The file stores a whole pickled module, so full unpickling is
        # requested explicitly instead of relying on torch.load's default.
        # NOTE: unpickling can run arbitrary code; this is acceptable here only
        # because the file comes from the project's own test-data host.
        m = torch.load(path, weights_only=False)
    input = torch.randn(2, 3, dtype=torch.float)
    self.assertEqual(m(input).size(), (2, 5))
8787
def test_conv_backcompat(self):
    """A Conv2d pickled by PyTorch 1.0.1 on Python 2 still loads and runs.

    The loaded module is called on a 1 x 1 x 1 x 1 float input and must return
    an output of the same size.
    """
    from torch.serialization import SourceChangeWarning
    # This file was generated by running on PyTorch 1.0.1 on Python 2:
    #
    #     import torch
    #     from torch import nn
    #     m = nn.Conv2d(1, 1, 1)
    #     torch.save(m, 'legacy_conv2d.pt')
    #
    # NB: This Pickle also contains some Unicode data!
    path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', SourceChangeWarning)
        # The file stores a whole pickled module, so full unpickling is
        # requested explicitly instead of relying on torch.load's default.
        # NOTE: unpickling can run arbitrary code; this is acceptable here only
        # because the file comes from the project's own test-data host.
        m = torch.load(path, encoding='utf-8', weights_only=False)
    input = torch.randn((1, 1, 1, 1), dtype=torch.float)
    self.assertEqual(m(input).size(), (1, 1, 1, 1))
8804
def test_conv_expand(self):
    """conv2d on MPS accepts a weight produced by Tensor.expand().

    There are no assertions: the test passes as long as the call completes.
    """
    dev = 'mps'
    image = torch.rand(2, 3, 16, 16, device=dev)
    single_channel_kernel = torch.rand(1, 1, 3, 11, device=dev)
    # Grow the kernel's channel dimension from 1 to 3 via expand().
    expanded_kernel = single_channel_kernel.expand(-1, 3, -1, -1)
    F.conv2d(image, expanded_kernel, groups=1, padding=0, stride=1)
8811
# The test should not crash
def test_permute(self):
    """permute(1, 0) of a 5 x 5 matrix gives the same values and size on MPS as on CPU."""
    cpu_matrix = torch.randn(5, 5)
    mps_matrix = cpu_matrix.to('mps')

    cpu_swapped = cpu_matrix.permute(1, 0)
    mps_swapped = mps_matrix.permute(1, 0)

    self.assertEqual(cpu_swapped, mps_swapped)
    self.assertEqual(cpu_swapped.size(), mps_swapped.size())
Kulin Seth017b0ae2022-05-31 02:09:03 +00008822
# Printing of non_contiguous should not crash
def test_print_non_contiguous(self):
    """Print the nonzero() result of an MPS tensor, first as returned, then made contiguous."""
    for force_contiguous in (False, True):
        indices = torch.ones(100, 100, device='mps').nonzero()
        if force_contiguous:
            indices = indices.contiguous()
        print(indices)
8827
def test_zero_grad(self):
    """Module.zero_grad() sets grads to None by default and to zeros on request.

    A Linear(5, 5) layer is exercised in three stages: no trainable
    parameters, weight only, then weight and bias.
    """
    sample = torch.randn(2, 5, requires_grad=True)
    linear = nn.Linear(5, 5)

    # Stage 1: nothing requires grad; zero_grad() must still be callable.
    for param in linear.parameters():
        param.requires_grad_(False)
    linear.zero_grad()

    # Stage 2: only the weight is trainable.
    linear.weight.requires_grad_(True)
    linear.zero_grad()
    self.assertIsNone(linear.weight.grad)  # uninitialized grad

    linear(sample).sum().backward()
    self.assertIsNotNone(linear.weight.grad)
    self.assertGreater(linear.weight.grad.data.abs().sum(), 0)
    linear.zero_grad()
    self.assertIsNone(linear.weight.grad)

    # Stage 3: the bias becomes trainable as well.
    linear.bias.requires_grad_(True)
    linear.zero_grad()
    trainable = (linear.weight, linear.bias)
    for param in trainable:
        self.assertIsNone(param.grad)
    linear(sample).sum().backward()
    for param in trainable:
        self.assertIsNotNone(param.grad)
    for param in trainable:
        self.assertGreater(param.grad.data.abs().sum(), 0)

    # Force set to zeros.
    linear.zero_grad(set_to_none=False)
    for param in trainable:
        self.assertEqual(param.grad.data, param.data.clone().zero_())

    # Back to the default: grads are dropped entirely.
    linear.zero_grad()
    for param in trainable:
        self.assertIsNone(param.grad)
8863
Kulin Sethe011a8e2022-05-13 18:28:53 +00008864
def test_no_grad(self):
    """Outputs computed under torch.no_grad() do not require grad and cannot be backpropagated.

    Checked for a small Conv2d in bfloat16, float and double.
    """
    for dtype in [torch.bfloat16, torch.float, torch.double]:
        conv = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
        data = torch.randn(1, 2, 10, 10).to(dtype)
        data_copy = data.clone()

        # Normal forward: the result is part of the autograd graph.
        tracked = conv(data)
        self.assertTrue(tracked.requires_grad)
        tracked.backward(torch.ones(1, 5, 10, 10))

        with torch.no_grad():
            untracked = conv(data_copy)
            self.assertFalse(untracked.requires_grad)
            with self.assertRaises(RuntimeError):
                untracked.backward(torch.ones(1, 5, 10, 10))
8880
def test_invalid_conv1d(self):
    """Conv1d rejects a kernel longer than the input and a negative stride.

    Both checks are repeated for bfloat16, float and double.
    """
    kernel_too_large = (
        r'Calculated padded input size per channel: \(4\). '
        r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'
    )
    for dtype in [torch.bfloat16, torch.float, torch.double]:
        # Kernel of 10 applied to an input of length 4.
        conv = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
        signal = torch.randn(1, 3, 4).to(dtype)
        with self.assertRaisesRegex(RuntimeError, kernel_too_large):
            conv(signal)

        # Negative stride check
        conv = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
        signal = torch.randn(1, 3, 4).to(dtype)
        with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
            conv(signal)
8895
def test_conv2d_discontiguous_weight(self):
    """conv2d gives the right answer for a non-contiguous weight tensor.

    The result is compared with a second run that has MKLDNN disabled (when
    MKLDNN is available) and with a fixed expected sum.
    """
    # Test for https://github.com/pytorch/pytorch/issues/55781
    conv2d = torch.nn.functional.conv2d
    ones_input = torch.ones(64, 16, 16, 16)
    # Slicing the last dim with step 2 leaves a strided, non-contiguous view.
    strided_weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
    self.assertFalse(strided_weight.is_contiguous())

    result = conv2d(ones_input, strided_weight, None)
    if torch.backends.mkldnn.is_available():
        # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
        with torch.backends.mkldnn.flags(enabled=False):
            result_no_mkldnn = conv2d(ones_input, strided_weight, None)
            self.assertEqual(result, result_no_mkldnn)
    self.assertEqual(result.sum(), 4186112.)
8908
def test_invalid_conv2d(self):
    """Conv2d / torch.conv2d raise RuntimeError for a range of invalid set-ups.

    For bfloat16, float and double the loop checks: a dilated kernel on a
    4 x 4 input, a kernel larger than the input, non-positive strides, and
    input/weight tensors living on different devices.
    """
    kernel_too_large = (
        r'Calculated padded input size per channel: \(1 x 1\). '
        r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'
    )
    for dtype in [torch.bfloat16, torch.float, torch.double]:
        # 3 x 3 kernel with dilation 2 on a 4 x 4 input.
        conv = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
        image = torch.empty(1, 1, 4, 4).to(dtype)
        with self.assertRaises(RuntimeError):
            conv(image)

        # 10 x 10 kernel on a 1 x 1 input (this case is not cast to `dtype`).
        conv = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
        image = torch.randn(1, 3, 1, 1)
        with self.assertRaisesRegex(RuntimeError, kernel_too_large):
            conv(image)

        # Negative stride check, then zero stride check
        for bad_stride in (-1, 0):
            conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=bad_stride, bias=True).to(dtype)
            image = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                conv(image)

        # Input and weights on different devices
        with self.assertRaisesRegex(RuntimeError, 'must be on the same device'):
            torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps'))
        with self.assertRaisesRegex(
                RuntimeError,
                'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same'):
            torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3))
8941
8942
def test_conv2d_valid_padding(self, device='mps'):
    """F.conv2d with padding='valid' equals F.conv2d with no padding argument.

    Args:
        device: device on which the input and the kernel are created.
    """
    # Test F.conv2d padding='valid' is the same as no padding
    signal = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
    kernel = torch.rand(1, 1, 1, 4, device=device).to(torch.float)

    unpadded = F.conv2d(signal, kernel)
    valid_padded = F.conv2d(signal, kernel, padding='valid')
    self.assertEqual(unpadded.to('cpu'), valid_padded.to('cpu'))
8951
def test_conv2d_backward_collision(self):
    """Backward through two Conv2d layers with equal output shapes must not crash.

    The layers differ only in kernel size (3 vs 4) and share the same input.
    """
    # Test for https://github.com/pytorch/pytorch/issues/112998
    shared_input = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True)
    conv_k3 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps")
    conv_k4 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps")
    out_k3 = conv_k3(shared_input)
    out_k4 = conv_k4(shared_input)
    self.assertEqual(out_k3.shape, out_k4.shape)
    out_k3.sum().backward()
    # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
    out_k4.sum().backward()
8962
@unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
def test_conv3d_backward_collision(self):
    """Backward through two Conv3d layers with equal output shapes must not crash.

    The layers differ only in kernel size (3 vs 4) and share the same input.
    """
    # Conv3D is only available from MacOS 13.2 onwards
    shared_input = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
    conv_k3 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
    conv_k4 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
    out_k3 = conv_k3(shared_input)
    out_k4 = conv_k4(shared_input)
    self.assertEqual(out_k3.shape, out_k4.shape)
    out_k3.sum().backward()
    # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
    out_k4.sum().backward()
Nikita Shulga265d6aa2023-11-10 04:29:33 +00008974
def test_gemm_permute_transpose(self):
    """view -> permute(0, 2, 1, 3) -> transpose(-1, -2) gives the same tensor on MPS as on CPU."""
    batch_size, n, hidden = 32, 20, 768
    num_attention_heads = 12
    attention_head_size = hidden // num_attention_heads

    def split_heads_and_transpose(t: torch.Tensor) -> torch.Tensor:
        # Split the last dim into (heads, head_size), move heads in front of
        # the sequence dim, then swap the two trailing dims.
        per_head_shape = t.size()[:-1] + (num_attention_heads, attention_head_size)
        by_head = t.view(per_head_shape).permute(0, 2, 1, 3)
        return by_head.transpose(-1, -2)

    cpu_data = torch.randn(batch_size, n, hidden)
    mps_data = cpu_data.detach().clone().to("mps")

    cpu_result = split_heads_and_transpose(cpu_data)
    mps_result = split_heads_and_transpose(mps_data)

    self.assertEqual(cpu_result, mps_result.to("cpu"))
9000
def test_group_norm_backward(self, device='mps'):
    """The gradient through a conv / GroupNorm / interpolate stack must not contain NaN.

    Args:
        device: device on which all tensors and modules are created.
    """
    # See https://github.com/pytorch/pytorch/issues/88331 for more detail
    constant_input = torch.full([1, 4, 16, 16], 7.0, device=device)
    target = torch.ones((1, 3, 128, 128), device=device)

    conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
    conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
    norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device)

    with torch.enable_grad():
        leaf = constant_input.detach().requires_grad_()
        hidden = conv_in(5.5 * leaf)
        # Three residual applications of the same GroupNorm.
        for _ in range(3):
            hidden = hidden + norm(hidden)
        hidden = F.interpolate(hidden, scale_factor=8.0, mode="nearest")
        prediction = conv_out(norm(hidden))

        loss = (prediction - target).norm(dim=-1).sum()
        grad = -torch.autograd.grad(loss, leaf)[0]
        self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd')
9026
9027
Kulin Sethe011a8e2022-05-13 18:28:53 +00009028 # def test_conv2d_same_padding(self, device='mps'):
9029 # x = torch.rand(1, 1, 10, 11, device=device)
9030 # y = torch.rand(1, 1, 4, 5, device=device)
9031 # expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
9032 # actual = F.conv2d(x, y, padding='same')
9033 # self.assertEqual(expect.to('cpu'), actual.to('cpu'))
9034
9035 # # With dilation
9036 # y = torch.rand(1, 1, 3, 4, device=device)
9037 # expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
9038 # actual = F.conv2d(x, y, padding='same', dilation=2)
9039 # self.assertEqual(expect, actual)
9040
9041 # # Dilation with asymmetric padding
9042 # y = torch.rand(1, 1, 4, 4, device=device)
9043 # expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
9044 # actual = F.conv2d(x, y, padding='same', dilation=3)
9045 # self.assertEqual(expect, actual)
9046
9047
class TestPad(TestCaseMPS):
    """Padding operators, mostly comparing MPS results against CPU results."""

    def test_constant_pad(self):
        """Constant padding matches CPU for negative pads and for a 10-D input."""
        cropper = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
        cpu_in = torch.randn(1, 16, 16, 16)
        mps_in = cpu_in.detach().clone().to("mps")
        cpu_out = cropper(cpu_in)
        mps_out = cropper(mps_in)
        self.assertEqual(cpu_out, mps_out.to("cpu"))

        # Arbitrary input dimensions
        pad_widths = (1, 1, 0, 0, 0, 0)
        fill = 3.5
        cpu_in = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
        mps_in = cpu_in.detach().clone().to("mps")
        cpu_out = F.pad(cpu_in, pad=pad_widths, value=fill)
        mps_out = F.pad(mps_in, pad=pad_widths, value=fill)
        self.assertEqual(cpu_out, mps_out.to("cpu"))

    def test_circular_pad(self):
        """Circular padding followed by conv2d gives the same result on MPS as on CPU."""
        # https://github.com/pytorch/pytorch/issues/80856
        cpu_kernel = torch.ones(3, 3, 9, 9)
        mps_kernel = cpu_kernel.detach().clone().to("mps")

        cpu_image = torch.rand(1, 3, 32, 32)
        mps_image = cpu_image.detach().clone().to("mps")

        cpu_padded = F.pad(cpu_image, (2, 2, 2, 2), mode='circular')
        mps_padded = F.pad(mps_image, (2, 2, 2, 2), mode='circular')

        cpu_conv = F.conv2d(cpu_padded, cpu_kernel)
        mps_conv = F.conv2d(mps_padded, mps_kernel)

        self.assertEqual(cpu_conv, mps_conv.cpu())

    def test_constant_pad_4d_warning(self):
        """F.pad with an 8-entry pad list on a 6-D input matches CPU."""
        pad_widths = [0, 0, 0, 0, 0, 0, 1, 0]
        cpu_in = torch.rand((1, 2, 2, 2, 1, 1))
        mps_in = cpu_in.detach().clone().to('mps')
        cpu_out = F.pad(cpu_in, pad_widths)
        mps_out = F.pad(mps_in, pad_widths)
        self.assertEqual(cpu_out, mps_out)

    def test_pad(self):
        """Forward and backward of the nn padding modules match between CPU and MPS."""
        constant_pads = [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]

        def helper(shape, padding, op, value=0):
            cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            cpu_in.retain_grad()
            mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

            # Only the constant pads take a fill value.
            pad_module = op(padding, value) if op in constant_pads else op(padding)
            cpu_out = pad_module(cpu_in)
            mps_out = pad_module(mps_in)
            self.assertEqual(cpu_out, mps_out)

            # backward pass (chose 0.6 just to have the grad_output != 1)
            cpu_out.backward(gradient=torch.full_like(cpu_out, 0.6))
            mps_out.backward(gradient=torch.full_like(mps_out, 0.6))
            self.assertEqual(cpu_in.grad, mps_in.grad)

        # (input shape, padding, module class); the order is kept because
        # several cases exist to probe graph caching across consecutive calls.
        cases = [
            # 1D Padding
            ((2, 4, 3), 2, nn.ReflectionPad1d),
            # verify if a change in shape of input would cause problems with graph caching
            ((2, 4, 4), (1, 3), nn.ReflectionPad1d),
            # Replication 1D
            ((2, 1, 6), 3, nn.ReplicationPad1d),
            # Constant Pad 1D
            ((2, 3, 4), 2, nn.ConstantPad1d),
            # Constant Pad 1D with single dimension input (shape is the int 16)
            (16, (1, 2), nn.ConstantPad1d),

            # 2D Padding
            ((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d),
            # verify if a change in shape of input would cause problems with graph caching
            ((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d),
            # this should make the padding (2, 2, 2, 2)
            ((2, 1, 6, 8), 2, nn.ReplicationPad2d),
            # verify if a change in shape of padding would cause problems with graph caching
            ((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d),
            # Constant Pad 2D
            ((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d),
            # input size < pad size
            ((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d),
            # pad dims < input dims
            ((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d),
            # pad dims == input dims
            ((1, 3), (0, 2, 0, 1), nn.ConstantPad2d),
            # input.numel() == 0 but output.numel() > 0
            ((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d),
            # pad dims < input dims - 2
            ((1, 2, 3, 4), (1, 2), nn.ConstantPad2d),

            # 3D Padding
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d),
            # verify if a change in shape of padding would cause problems with graph caching
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d),
            # case where input_d == pad_front/back for ReplicationPad3d
            ((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d),
            # Constant Pad 3D
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d),
            # input size < pad size
            ((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d),
            # check the workaround for the right padding bug in Monterey
            ((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d),
        ]
        for shape, padding, op in cases:
            helper(shape, padding, op)

    def test_constant_pad_nd_preserves_memory_format(self):
        """constant_pad_nd keeps the memory format of its input.

        Both tensors are created without a device argument, i.e. on the
        default device.
        """
        contiguous_in = torch.rand((1, 2, 5, 3))
        contiguous_out = torch.constant_pad_nd(contiguous_in, [1, 2], 0.5)
        self.assertTrue(contiguous_out.is_contiguous(memory_format=torch.contiguous_format))

        channels_last_in = contiguous_in.contiguous(memory_format=torch.channels_last)
        channels_last_out = torch.constant_pad_nd(channels_last_in, [1, 2], 0.5)
        self.assertTrue(channels_last_out.is_contiguous(memory_format=torch.channels_last))
9161
9162
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00009163class TestLinalgMPS(TestCaseMPS):
def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
    """Check ``f(t, m, v, alpha=..., beta=...)`` against its out= variant and a NumPy reference.

    Args:
        f: the operator under test; called with and without ``out=``.
        t: the tensor that is scaled by ``beta`` and added.
        m, v: the two operands multiplied together (``m @ v``).
        alpha: scale of the product; 1.2 when None.
        beta: scale of ``t``; 0.8 when None.
        transpose_out: when True the out= buffer is given a transposed layout.
    """
    dtype = t.dtype
    numpy_dtype = dtype
    if alpha is None:
        alpha = 1.2
    if beta is None:
        beta = 0.8

    functional_result = f(t, m, v, alpha=alpha, beta=beta)

    # Pre-fill the out= buffer with NaN so stale contents cannot pass for a result.
    out_buffer = torch.full_like(functional_result, math.nan)
    if transpose_out:
        out_buffer = out_buffer.t().clone(memory_format=torch.contiguous_format).t()
    f(t, m, v, alpha=alpha, beta=beta, out=out_buffer)

    # Reference: alpha * (m @ v) [+ beta * t], computed with NumPy.
    m_np = m.to(numpy_dtype).cpu().numpy()
    v_np = v.to(numpy_dtype).cpu().numpy()
    reference = alpha * (m_np @ v_np)
    if beta != 0:
        reference += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
    reference = torch.from_numpy(reference).to(dtype)

    self.assertEqual(functional_result, out_buffer)
    self.assertEqual(functional_result, reference)
Kulin Sethe011a8e2022-05-13 18:28:53 +00009180
def test_addmm(self, device="mps", dtype=torch.float32):
    """torch.addmm: plain case, beta=0 with a NaN addend, and all layout combinations.

    Args:
        device: device on which the operands are created.
        dtype: dtype the operands are cast to.
    """
    def rand(*size):
        return torch.randn(*size, device=device).to(dtype)

    def maybe_transpose(cond, mat):
        # Same values, but stored with a transposed memory layout when cond is set.
        if not cond:
            return mat
        return mat.t().clone(memory_format=torch.contiguous_format).t()

    addend = rand(10, 25)
    lhs = rand(10, 50)
    rhs = rand(50, 25)
    self._test_addmm_addmv(torch.addmm, addend, lhs, rhs)

    # Test beta=0, M=nan
    addend = torch.full((10, 25), math.nan, device=device).to(dtype)
    lhs = rand(10, 50)
    rhs = rand(50, 25)
    self._test_addmm_addmv(torch.addmm, addend, lhs, rhs, beta=0)

    # Test transpose
    for t_addend, t_lhs, t_rhs, t_out in itertools.product([True, False], repeat=4):
        addend = maybe_transpose(t_addend, rand(10, 25))
        lhs = maybe_transpose(t_lhs, rand(10, 50))
        rhs = maybe_transpose(t_rhs, rand(50, 25))
        self._test_addmm_addmv(torch.addmm, addend, lhs, rhs, transpose_out=t_out)
Kulin Sethe011a8e2022-05-13 18:28:53 +00009204
def _test_addr(self, f, t, m, v, alpha=None, beta=None):
    """Check ``f(t, m, v, alpha=..., beta=...)`` against a NumPy outer-product reference.

    Args:
        f: the operator under test.
        t: the tensor that is scaled by ``beta`` and added.
        m, v: the two vectors whose outer product is taken.
        alpha: scale of the outer product; 1.2 when None.
        beta: scale of ``t``; 0.8 when None.
    """
    dtype = t.dtype
    numpy_dtype = dtype
    if alpha is None:
        alpha = 1.2
    if beta is None:
        beta = 0.8

    actual = f(t, m, v, alpha=alpha, beta=beta)

    # Reference: alpha * outer(m, v) [+ beta * t], computed with NumPy.
    m_np = m.to(numpy_dtype).cpu().numpy()
    v_np = v.to(numpy_dtype).cpu().numpy()
    reference = alpha * np.outer(m_np, v_np)
    if beta != 0:
        reference += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
    reference = torch.from_numpy(reference).to(dtype)

    self.assertEqual(actual, reference)
9216
def test_addr(self, device="mps", dtype=torch.float32):
    """torch.addr: plain case and beta=0 with a NaN addend.

    Args:
        device: device on which the operands are created.
        dtype: dtype the operands are cast to.
    """
    def rand(*size):
        return torch.randn(*size, device=device).to(dtype)

    addend = rand(10, 25)
    col = rand(10)
    row = rand(25)
    self._test_addr(torch.addr, addend, col, row)

    # Test beta=0, M=nan
    addend = torch.full((10, 25), math.nan, device=device).to(dtype)
    col = rand(10)
    row = rand(25)
    self._test_addr(torch.addr, addend, col, row, beta=0)
9228
def test_matrix_rank(self, device="mps", dtype=torch.float32):
    """torch.linalg.matrix_rank is self-consistent and agrees with NumPy.

    A NotImplementedError is tolerated only when its message names the
    missing ``aten::_linalg_svd.U`` MPS operator.

    Args:
        device: device on which the random matrices are created.
        dtype: dtype of the random matrices.
    """
    rank_of = torch.linalg.matrix_rank
    np_rank_of = np.linalg.matrix_rank

    def run_test(shape0, shape1, batch):
        a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
        rank_a = rank_of(a)

        # Rank is invariant under the conjugate transpose.
        self.assertEqual(rank_a, rank_of(a.mH))
        aaH = torch.matmul(a, a.mH)
        rank_aaH = rank_of(aaH)
        rank_aaH_hermitian = rank_of(aaH, hermitian=True)
        self.assertEqual(rank_aaH, rank_aaH_hermitian)
        aHa = torch.matmul(a.mH, a)
        self.assertEqual(rank_of(aHa), rank_of(aHa, hermitian=True))

        # check against NumPy
        self.assertEqual(rank_a, np_rank_of(a.cpu().numpy()))
        self.assertEqual(rank_of(a, 0.01), np_rank_of(a.cpu().numpy(), 0.01))

        self.assertEqual(rank_aaH, np_rank_of(aaH.cpu().numpy()))
        self.assertEqual(rank_of(aaH, 0.01), np_rank_of(aaH.cpu().numpy(), 0.01))

        # hermitian flag for NumPy was added in 1.14.0
        if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
            self.assertEqual(rank_aaH_hermitian,
                             np_rank_of(aaH.cpu().numpy(), hermitian=True))
            self.assertEqual(rank_of(aaH, 0.01, True),
                             np_rank_of(aaH.cpu().numpy(), 0.01, True))

        # check out= variant
        out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
        ans = rank_of(a, out=out)
        self.assertEqual(ans, out)
        self.assertEqual(ans, rank_a)

    shapes = (3, 13)
    batches = ((), (0, ), (4, ), (3, 5, ))
    missing_op = "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."
    # zip() pairs each of the four (rows, cols) combinations with one batch shape.
    shape_pairs = itertools.product(shapes, shapes[::-1])
    for (shape0, shape1), batch in zip(shape_pairs, batches):
        # escape only when NotImplementedError of downstream function is raised
        # TODO: remove this once the required function is implemented
        try:
            run_test(shape0, shape1, batch)
        except NotImplementedError as err:
            self.assertRegex(str(err), missing_op)
9276
def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
    """torch.linalg.pinv satisfies the pseudo-inverse identities and agrees with NumPy.

    For the hermitian cases a NotImplementedError is tolerated only when its
    message names the missing ``aten::_linalg_eigh.eigenvalues`` MPS operator.

    Args:
        device: device on which the test matrices are created.
        dtype: dtype of the test matrices.
        precision: used as both atol and rtol in every comparison.
    """
    from torch.testing._internal.common_utils import random_hermitian_pd_matrix

    tol = dict(atol=precision, rtol=precision)

    def run_test_main(A, hermitian):
        # Testing against definition for pseudo-inverses
        A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
        np_A = A.cpu().numpy()
        np_A_pinv = A_pinv.cpu().numpy()
        if A.numel() > 0:
            self.assertEqual(A, np_A @ np_A_pinv @ np_A, **tol)
            self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, **tol)
            # Both products must equal their own conjugate transpose.
            self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), **tol)
            self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), **tol)
        else:
            self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))

        # Check out= variant
        out = torch.empty_like(A_pinv)
        ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
        self.assertEqual(ans, out)
        self.assertEqual(ans, A_pinv)

    def run_test_numpy(A, hermitian):
        # Check against NumPy output
        # Test float rcond, and specific value for each matrix
        rconds = [float(torch.rand(1)), ]
        # Test different types of rcond tensor
        for rcond_type in MPS_DTYPES:
            rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
        # Test broadcasting of rcond
        if A.ndim > 2:
            rconds.append(torch.rand(A.shape[-3], device=device))
        for rcond in rconds:
            actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
            torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
            self.assertEqual(actual, torch_rtol, **tol)
            numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
            expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
            self.assertEqual(actual, expected, **tol)

    general_sizes = [
        (5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
        (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
        (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
        (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3),  # zero numel matrices
    ]
    for sizes in general_sizes:
        A = torch.randn(*sizes, dtype=dtype, device=device)
        run_test_main(A, False)
        run_test_numpy(A, False)

    # Check hermitian = True
    missing_op = "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."
    hermitian_sizes = [
        (5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
        (0, 0), (3, 0, 0),  # zero numel square matrices
    ]
    for sizes in hermitian_sizes:
        A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
        # escape only when NotImplementedError of downstream function is raised
        # TODO: remove this once the required function is implemented
        for check in (run_test_main, run_test_numpy):
            try:
                check(A, True)
            except NotImplementedError as err:
                self.assertRegex(str(err), missing_op)
9347
@parametrize("m", [1, 32, 64])
@parametrize("n", [48, 64])
@parametrize("q_group", [32, 64, 128, 256])
@parametrize("num_groups", [1, 2])
def test__int4_mm(self, m, n, q_group, num_groups):
    """torch._weight_int4pack_mm stays close to a plain torch.mm on MPS.

    The (k, n) weight is group-quantized to 4 bits on CPU, packed on MPS and
    multiplied by an (m, k) activation; the mean relative error against
    torch.mm must be below 0.05 for each tested dtype.

    Args:
        m, n: outer dimensions of the product.
        q_group: quantization group size.
        num_groups: number of groups; k = q_group * num_groups.
    """
    k = q_group * num_groups
    inner_k_tiles = 2

    torch.manual_seed(1)
    a_f32 = torch.rand((m, k), device="mps")
    b_f32 = torch.rand((k, n), device="mps")

    # Quantize on CPU, then move the pieces to MPS and pack the weight.
    quantized, scales_and_zeros = _group_quantize_tensor(
        b_f32.to("cpu"), n_bit=4, q_group_size=q_group
    )
    quantized = quantized.to("mps")
    b_scales_and_zeros_f32 = scales_and_zeros.to("mps")
    b_int4pack = torch._convert_weight_to_int4pack(quantized, inner_k_tiles)

    dtypes = [torch.float16, torch.float32]
    if product_version > 14.0:
        dtypes.append(torch.bfloat16)

    for dtype in dtypes:
        a = a_f32.to(dtype=dtype)
        b = b_f32.to(dtype=dtype)
        b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype)
        ref = torch.mm(a, b)
        res = torch._weight_int4pack_mm(a, b_int4pack, q_group, b_scales_and_zeros)

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertLess(mean_err, 0.05)
Nikita Shulga30610252024-05-03 15:20:39 +00009388
@parametrize("m", [1, 32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [32, 64])
def test__int8_mm(self, m, k, n):
    """Compare `torch._weight_int8pack_mm` with `a @ b.T` computed by `torch.mm`.

    The (n, k) weight matrix is quantized per channel to int8; the mean
    relative error against the unquantized product must stay below 0.05 for
    every tested dtype.
    """
    torch.manual_seed(1)
    lhs_f32 = torch.rand((m, k), device="mps")
    rhs_f32 = torch.rand((n, k), device="mps")

    # The third value returned by the quantizer is not used by this test.
    packed, scales_f32, _ = _dynamically_quantize_per_channel(
        rhs_f32, -128, 127, torch.int8
    )

    # bfloat16 is only exercised when product_version is above 14.0.
    dtypes = [torch.float16, torch.float32]
    if product_version > 14.0:
        dtypes.append(torch.bfloat16)

    for dtype in dtypes:
        lhs = lhs_f32.to(dtype=dtype)
        rhs = rhs_f32.to(dtype=dtype)
        scales = scales_f32.to(dtype=dtype)
        actual = torch._weight_int8pack_mm(lhs, packed, scales)
        expected = torch.mm(lhs, rhs.transpose(0, 1))

        rel_err = (actual - expected).abs() / expected
        self.assertLess(rel_err.mean(), 0.05)
Nikita Shulga4ff91132024-05-24 16:08:04 +00009416
Nikita Shulga30610252024-05-03 15:20:39 +00009417
class TestSDPA(TestCaseMPS):
    """Checks `F.scaled_dot_product_attention` on MPS against the CPU result."""

    def _compare_tensors(self, y, ref):
        """Assert that the mean relative error of `y` against `ref` is below 0.01.

        The denominator is `|ref|` clamped from below at 1e-6 so that zero
        entries of `ref` do not cause a division by zero.
        """
        denom = torch.maximum(ref.abs(), torch.tensor([1e-6], device=ref.device, dtype=ref.dtype))
        err = ((y - ref).abs() / denom).mean().item()
        self.assertLess(err, 0.01)

    def _test_sdpa_no_mask(
        self,
        is_causal: bool,
        dtype: torch.dtype,
        L: int = 1,
        S: int = 72,
        NH: int = 32,
        HS: int = 128,
        requires_grad: bool = False
    ):
        """Run SDPA without an explicit mask on MPS and CPU and compare the outputs.

        Args:
            is_causal: forwarded to `scaled_dot_product_attention`.
            dtype: dtype of the query, key and value tensors.
            L: size of dim 2 of the query.
            S: size of dim 2 of the key and value.
            NH: size of dim 1 of all three inputs.
            HS: size of the last dim of all three inputs.
            requires_grad: whether the query requires grad. When grad mode is
                also enabled, the query gradients are compared as well.
        """
        torch.manual_seed(1729)
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            # Detach first so the CPU query is an independent leaf tensor
            # (a single device copy is enough).
            q_cpu = q.detach().cpu().requires_grad_(requires_grad)
            k_cpu = k.cpu()
            v_cpu = v.cpu()

            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
            y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal)

            self._compare_tensors(y.cpu(), y_ref)

            if requires_grad and torch.is_grad_enabled():
                y.sum().backward()
                y_ref.sum().backward()

                self._compare_tensors(q.grad.cpu(), q_cpu.grad)

    def test_sdpa_no_mask_no_causal_fp32(self):
        self._test_sdpa_no_mask(False, torch.float32)

    def test_sdpa_no_mask_no_causal_fp16(self):
        self._test_sdpa_no_mask(False, torch.float16)

    def test_sdpa_no_mask_causal_fp32(self):
        self._test_sdpa_no_mask(True, torch.float32)

    def test_sdpa_no_mask_causal_fp16(self):
        self._test_sdpa_no_mask(True, torch.float16)

    def test_sdpa_no_mask_causal_fp16_L7(self):
        self._test_sdpa_no_mask(True, torch.float16, 7)

    def test_sdpa_no_mask_causal_fp16_L7_S17(self):
        self._test_sdpa_no_mask(True, torch.float16, 7, 17)

    def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self):
        self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)

    def test_sdpa_no_mask_no_causal_fp32_grad(self):
        # First with autograd enabled (gradients are compared) ...
        self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)

        # ... then under no_grad, where only the forward outputs are compared.
        with torch.no_grad():
            self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)

    def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
        """Run SDPA with one row of a lower-triangular boolean mask on MPS and CPU.

        The mask is row `i` of an S x S lower-triangular matrix, indexed so
        that it has shape (1, 1, 1, S). The same mask is used for the MPS run
        and (copied to CPU) for the reference run.
        """
        torch.manual_seed(1729)
        causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            # Row 42 only exists when S > 42; for smaller S pick a row that is
            # inside the S x S mask instead of indexing out of bounds.
            i = 42 if S > 42 else S // 2

            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")

            input_pos = torch.tensor([i], dtype=torch.int32, device='mps')
            mask = causal_mask[None, None, input_pos]

            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
            y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)

            self._compare_tensors(y.cpu(), y_ref)

    def test_sdpa_mask_fp32(self):
        self._test_sdpa_mask(torch.float32)

    def test_sdpa_mask_fp16(self):
        self._test_sdpa_mask(torch.float16)

    def test_sdpa_mask_fp16_L6(self):
        self._test_sdpa_mask(torch.float16, 6)

    def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
        # NOTE(review): the test name says L6 but L=7 is passed -- confirm
        # which one is intended before renaming or changing the argument.
        self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
watarungurunnnd444a3b2024-02-05 15:36:55 +00009511
9512
class TestGatherScatter(TestCaseMPS):
    """Slicing / in-place update tests comparing MPS results with CPU results."""

    def test_slicing_with_step(self):
        """Strided slice assignment must match CPU.

        https://github.com/pytorch/pytorch/issues/78886
        """
        on_mps = torch.zeros(10, dtype=torch.float32, device="mps")
        on_mps[::2] = 1.0

        on_cpu = torch.zeros(10, dtype=torch.float32, device="cpu")
        on_cpu[::2] = 1.0

        self.assertEqual(on_cpu, on_mps)

    def test_cast_gather_scatter(self):
        """uint8 -> long -> float -> in-place divide must match CPU at each step."""
        for _ in range(50):
            raw = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8)
            with torch.no_grad():
                on_mps = torch.tensor(raw, dtype=torch.uint8, device="mps").unsqueeze(0)
                on_cpu = torch.tensor(raw, dtype=torch.uint8, device="cpu").unsqueeze(0)

                on_mps = on_mps.long()
                on_cpu = on_cpu.long()
                self.assertEqual(on_mps.cpu(), on_cpu)

                on_mps = on_mps.float()
                on_cpu = on_cpu.float()
                self.assertEqual(on_mps.cpu(), on_cpu)

                on_mps /= 255
                on_cpu /= 255
                self.assertEqual(on_mps.cpu(), on_cpu)

    def test_slicing_replace_column(self):
        """Overwriting column 0 with a scalar must match CPU.

        https://github.com/pytorch/pytorch/issues/78074
        """
        matrices = (
            [[1, 2, 3], [4, 5, 6]],
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
        )
        for rows in matrices:
            on_cpu = torch.tensor(rows)
            on_mps = on_cpu.to('mps')

            on_cpu[:, 0] = 7
            on_mps[:, 0] = 7

            self.assertEqual(on_cpu, on_mps)

    def test_inplace_scatter(self):
        """In-place and out-of-place updates of a column slice must match CPU.

        https://github.com/pytorch/pytorch/issues/79672
        """
        mps_dev = torch.device("mps")
        cpu_dev = torch.device("cpu")

        a_mps = torch.ones((2, 2)).to(mps_dev)
        b_mps = torch.ones((2, 2)).to(mps_dev)

        a_cpu = torch.ones((2, 2)).to(cpu_dev)
        b_cpu = torch.ones((2, 2)).to(cpu_dev)

        # Augmented assignment on the slice.
        a_mps[:, 0] += b_mps[:, 0]
        a_cpu[:, 0] += b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)

        # Plain assignment of a freshly computed sum to the slice.
        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)
9573
# These tests were taken from test/test_view_ops.py.
# They are a subset of those tests, as currently only this subset is working.
# This whole `class` will be removed when we add generic device testing. There
# are no additional tests added apart from what is part of test_view_ops.py.
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +00009578class TestViewOpsMPS(TestCaseMPS):
Kulin Sethb744e1c2022-07-01 15:10:56 +00009579 exact_dtype = True
9580
def test_permute_slicing(self):
    """Regression test for the crash reported in
    https://github.com/pytorch/pytorch/issues/94190
    """
    order = (2, 0, 1)
    x_cpu = (torch.randn([3, 2, 2]).float())
    x_mps = x_cpu.detach().clone().to('mps')
    out_cpu = x_cpu.permute(order) * 2.0
    out_mps = x_mps.permute(order) * 2.0
    # this print caused a crash prior to fix PR#94259
    print(torch.zeros_like(out_mps))
    # test the fix for fill_scalar_mps() mentioned in issue #94190
    self.assertEqual(torch.zeros_like(out_cpu), torch.zeros_like(out_mps))
    filled_cpu = x_cpu[:, 1, :].fill_(1)
    filled_mps = x_mps[:, 1, :].fill_(1)
    self.assertEqual(filled_cpu, filled_mps)
Ramin Azarmehr36062dd2023-02-07 15:51:26 +00009593
def is_view_of(self, base, other):
    """Return True if `other` is a view whose `_base` is exactly `base`.

    `other` must be a view, must not be `base` itself, must have `base` as
    its `_base`, and must live on the same device. For MPS tensors the
    untyped storages must additionally share the same data pointer.
    """
    is_direct_view = (
        other._is_view()
        and other is not base
        and other._base is base
        and base.device == other.device
    )
    if not is_direct_view:
        return False

    # Note: only validates storage on native device types
    # because some accelerators, like XLA, do not expose storage
    if base.device.type != 'mps':
        return True
    return base.untyped_storage().data_ptr() == other.untyped_storage().data_ptr()
9607
9608 # Returns true if v1 and v2 are views of the same base
def is_view_of_same_base(self, v1, v2):
    """Return True if `v1` is a view and `v2` is a view of `v1`'s base.

    Returns False when `v1` is not a view or when both arguments are the
    same object.
    """
    if v1._is_view() and v1 is not v2:
        return self.is_view_of(v1._base, v2)
    return False
9613
# Returns the input tensor as is if contiguous=True, else performs the transpose
9615 def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
9616 if contiguous:
9617 return x
9618 else:
9619 return x.transpose(dim0, dim1)
9620
def test_diagonal_view(self, device="mps"):
    """`torch.diagonal` must return a view for both 2-D and batched 3-D inputs.

    After writing through the view, the corresponding element of the base
    tensor must compare equal to the written element.
    """
    t = torch.ones((5, 5), device=device)
    v = torch.diagonal(t)
    self.assertTrue(self.is_view_of(t, v))

    v[0] = 0
    self.assertEqual(t[0, 0], v[0])

    # Use the `device` argument here as well (it was hard-coded to "mps"),
    # so both halves of the test run on the same, caller-selected device.
    t = torch.ones((3, 3, 3), device=device)
    v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
    self.assertTrue(self.is_view_of(t, v))

    v[0, 0] = 0
    self.assertEqual(t[0, 0, 1], v[0, 0])
9635
def test_select_view(self, device="mps") -> None:
    """`Tensor.select` must return a view; a write through it must show in the base."""
    base = torch.ones((5, 5), device=device)
    row = base.select(0, 2)
    self.assertTrue(self.is_view_of(base, row))

    row[0] = 0
    self.assertEqual(base[2, 0], row[0])
9643
def test_unbind_view(self, device="mps") -> None:
    """Every tensor returned by `torch.unbind` must be a view of the input."""
    base = torch.zeros((5, 5), device=device)

    for row, piece in enumerate(torch.unbind(base)):
        self.assertTrue(self.is_view_of(base, piece))

        # Write a distinct value per piece and check it shows up in the base.
        piece[0] = row + 1
        self.assertEqual(base[row, 0], piece[0])
9653
def test_expand_view(self, device="mps") -> None:
    """`Tensor.expand` must return a view; a write through it must show in the base."""
    base = torch.ones((5, 1), device=device)
    expanded = base.expand(5, 5)
    self.assertTrue(self.is_view_of(base, expanded))

    expanded[2, 2] = 0
    self.assertEqual(base[2, 0], expanded[2, 2])
9661
def test_expand_as_view(self, device="mps"):
    """`Tensor.expand_as` must return a view; a write through it must show in the base."""
    base = torch.ones((5, 1), device=device)
    template = torch.empty((5, 5), device=device)
    expanded = base.expand_as(template)
    self.assertTrue(self.is_view_of(base, expanded))

    expanded[2, 2] = 0
    self.assertEqual(base[2, 0], expanded[2, 2])
9670
def test_narrow_view(self, device="mps"):
    """`torch.narrow` must return a view; a write through it must show in the base."""
    base = torch.ones((5, 5), device=device)
    window = torch.narrow(base, 1, 2, 2)
    self.assertTrue(self.is_view_of(base, window))

    # The window starts at column 2 of the base.
    window[0, 0] = 0
    self.assertEqual(base[0, 2], window[0, 0])
9678
def test_permute_view(self, device="mps") -> None:
    """`Tensor.permute` must return a view; a write through it must show in the base."""
    base = torch.ones((5, 5), device=device)
    permuted = base.permute(1, 0)
    self.assertTrue(self.is_view_of(base, permuted))

    permuted[0, 1] = 0
    self.assertEqual(base[1, 0], permuted[0, 1])
9686
def test_transpose_view(self, device="mps"):
    """`swapdims`, `swapaxes` and `transpose` must each return a view of the input."""
    for swap in (torch.swapdims, torch.swapaxes, torch.transpose):
        base = torch.ones((5, 5), device=device)
        swapped = swap(base, 0, 1)
        self.assertTrue(self.is_view_of(base, swapped))

        swapped[0, 1] = 0
        self.assertEqual(base[1, 0], swapped[0, 1])
9695
def test_transpose_inplace_view(self, device="mps"):
    """In-place `swapdims_`, `swapaxes_` and `transpose_` on a view keep it a view of the base."""
    for inplace_name in ("swapdims_", "swapaxes_", "transpose_"):
        base = torch.ones(5, 5, device=device)
        alias = base.view_as(base)
        alias = getattr(alias, inplace_name)(0, 1)
        self.assertTrue(self.is_view_of(base, alias))
        alias[0, 1] = 0
        self.assertEqual(base[1, 0], alias[0, 1])
9717
def test_t_view(self, device="mps"):
    """`Tensor.t` must return a view; a write through it must show in the base."""
    base = torch.ones((5, 5), device=device)
    transposed = base.t()
    self.assertTrue(self.is_view_of(base, transposed))

    transposed[0, 1] = 0
    self.assertEqual(base[1, 0], transposed[0, 1])
9725
def test_inplace_view_add(self):
    """Adding a scalar to a reshaped row slice must match CPU.

    https://github.com/pytorch/pytorch/issues/96153
    """
    on_mps = torch.ones((2, 6,), device='mps')[1].reshape(2, 3)
    on_cpu = torch.ones((2, 6,), device='cpu')[1].reshape(2, 3)
    on_mps = on_mps + 1
    on_cpu = on_cpu + 1
    self.assertEqual(on_mps, on_cpu)
9733
def test_t_inplace_view(self, device="mps"):
    """In-place `t_` on a view must keep it a view of the base."""
    base = torch.ones(5, 5, device=device)
    alias = base.view_as(base)
    alias = alias.t_()
    self.assertTrue(self.is_view_of(base, alias))
    alias[0, 1] = 0
    self.assertEqual(base[1, 0], alias[0, 1])
9741
def test_T_view(self, device="mps"):
    """The `T`, `H`, `mT` and `mH` attributes must each be a view of the tensor."""
    for attr in ("T", "H", "mT", "mH"):
        base = torch.ones((5, 5), device=device)
        flipped = getattr(base, attr)
        self.assertTrue(self.is_view_of(base, flipped))

        flipped[0, 1] = 0
        self.assertEqual(base[1, 0], flipped[0, 1])
9750
def test_unfold_view(self, device="mps"):
    """`Tensor.unfold` must return a view; a write through it must show in the base."""
    base = torch.ones(10, device=device)
    windows = base.unfold(0, 3, 2)
    self.assertTrue(self.is_view_of(base, windows))

    # With size 3 and step 2, window 1 starts at element 2 of the base.
    windows[1, 0] = 0
    self.assertEqual(base[2], windows[1, 0])
Kulin Sethb744e1c2022-07-01 15:10:56 +00009758
def test_squeeze_view(self, device="mps"):
    """`torch.squeeze` must return a view whose `_base` is still the input after a write."""
    base = torch.ones(5, 1, 5, device=device)
    squeezed = torch.squeeze(base)
    self.assertTrue(self.is_view_of(base, squeezed))
    squeezed[0, 1] = 0
    self.assertIs(base, squeezed._base)
Kulin Sethb744e1c2022-07-01 15:10:56 +00009765
def test_squeeze_inplace_view(self, device="mps"):
    """In-place `squeeze_` on a view must keep it a view of the base."""
    base = torch.ones(5, 5, device=device)
    alias = base.view_as(base)
    alias = alias.squeeze_()
    self.assertTrue(self.is_view_of(base, alias))
    alias[0, 1] = 0
    self.assertIs(base, alias._base)
Kulin Sethb744e1c2022-07-01 15:10:56 +00009773
def test_unsqueeze_view(self, device="mps"):
    """`torch.unsqueeze` must return a view; a write through it must show in the base."""
    base = torch.ones(5, 5, device=device)
    widened = torch.unsqueeze(base, 1)
    self.assertTrue(self.is_view_of(base, widened))

    widened[0, 0, 1] = 0
    self.assertEqual(base[0, 1], widened[0, 0, 1])
9781
def test_unsqueeze_inplace_view(self, device="mps"):
    """In-place `unsqueeze_` on a view must keep it a view of the base."""
    base = torch.ones(5, 5, device=device)
    alias = base.view_as(base)
    alias = alias.unsqueeze_(1)
    self.assertTrue(self.is_view_of(base, alias))
    alias[0, 0, 1] = 0
    self.assertEqual(base[0, 1], alias[0, 0, 1])
9789
def test_as_strided_view(self, device="mps"):
    """`torch.as_strided` must return a view; a write through it must show in the base."""
    base = torch.ones(5, 5, device=device)
    flat = torch.as_strided(base, (25,), (1,))
    self.assertTrue(self.is_view_of(base, flat))

    # Flat index 6 is compared against base[1, 1].
    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
9797
def test_as_strided_inplace_view(self, device="mps"):
    """In-place `as_strided_` on a view must keep it a view of the base."""
    base = torch.ones(5, 5, device=device)
    alias = base.view_as(base)
    alias = alias.as_strided_((25,), (1,))
    self.assertTrue(self.is_view_of(base, alias))
    alias[6] = 0
    self.assertEqual(base[1, 1], alias[6])
9805
def test_view_view(self, device="mps"):
    """`Tensor.view` must return a view; a write through it must show in the base."""
    base = torch.ones(5, 5, device=device)
    flat = base.view(25)
    self.assertTrue(self.is_view_of(base, flat))

    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
9813
def test_view_as_view(self, device="mps"):
    """`Tensor.view_as` must return a view; a write through it must show in the base."""
    base = torch.ones(5, 5, device=device)
    # The template is created without an explicit device, as in the original.
    template = torch.empty((25,))
    flat = base.view_as(template)
    self.assertTrue(self.is_view_of(base, flat))

    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
9822
def test_contiguous_self(self, device="mps"):
    """`contiguous()` on a freshly created tensor must return the very same object."""
    original = torch.ones(5, 5, device=device)
    self.assertIs(original.contiguous(), original)
Kulin Sethb744e1c2022-07-01 15:10:56 +00009827
def test_contiguous_nonview(self, device="mps"):
    """`t.t().contiguous()` must not be a view of `t`; writes must not reach `t`."""
    base = torch.ones(5, 5, device=device)
    copied = base.t().contiguous()
    self.assertFalse(self.is_view_of(base, copied))

    copied[0, 0] = 0
    self.assertNotEqual(base[0, 0], copied[0, 0])
9835
def test_reshape_view(self, device="mps"):
    """`torch.reshape` of this input must return a view; a write must show in the base."""
    base = torch.ones(5, 5, device=device)
    flat = torch.reshape(base, (25,))
    self.assertTrue(self.is_view_of(base, flat))

    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
9843
def test_reshape_as_view(self, device="mps"):
    """`Tensor.reshape_as` of this input must return a view; a write must show in the base."""
    base = torch.ones(5, 5, device=device)
    template = torch.empty((25,), device=device)
    flat = base.reshape_as(template)
    self.assertTrue(self.is_view_of(base, flat))

    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
9852
def test_reshape_nonview(self, device="mps"):
    """Reshaping the transposed tensor to 1-D must not give a view of `t`."""
    base = torch.ones(5, 5, device=device)
    copied = torch.reshape(base.t(), (25,))
    self.assertFalse(self.is_view_of(base, copied))

    copied[6] = 0
    self.assertNotEqual(base[1, 1], copied[6])
9860
def test_flatten_view(self, device="mps"):
    """`flatten` must return a view for the inputs exercised below."""

    def check_write_through(base, view):
        # Zero the first element of the view and compare it with the first
        # element of the base.
        base_origin = (0,) * base.ndim
        view_origin = (0,) * view.ndim
        view[view_origin] = 0
        self.assertEqual(base[base_origin], view[view_origin])

    source = torch.ones(1, 2, 3, 4, device=device)
    flat = source.flatten()
    self.assertTrue(self.is_view_of(source, flat))
    check_write_through(source, flat)

    # zero-dimensional tensor
    source = torch.tensor(1, device=device)
    flat = source.flatten()
    check_write_through(source, flat)
    self.assertTrue(self.is_view_of(source, flat))

    source = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
    flat = source.flatten(0, 1)
    check_write_through(source, flat)
    self.assertTrue(self.is_view_of_same_base(source, flat))

    # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
    sizes = (2, 3, 2, 3, 5, 4)
    strides = (6, 2, 15, 5, 1, 0)
    #        [--1--|---2---|-3-] for both tuples above
    source = torch.ones(720, device=device).as_strided(sizes, strides)
    v1 = source.flatten(0, 1)
    v2 = v1.flatten(1, 3)
    v3 = v2.flatten(2, 2)
    for flat in (v1, v2, v3):
        check_write_through(source, flat)
        self.assertTrue(self.is_view_of_same_base(source, flat))
9897
def test_flatten_nonview(self, device="mps"):
    """`flatten` must not return a view for the transposed inputs exercised below."""

    def expect_copy(source, result):
        # The result must not be a view, and zeroing its first element must
        # leave the first element of the source different from it.
        source_origin = (0,) * source.ndim
        result_origin = (0,) * result.ndim
        self.assertFalse(result._is_view())
        result[result_origin] = 0
        self.assertNotEqual(source[source_origin], result[result_origin])

    source = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
    expect_copy(source, source.flatten(1, 3))

    source = torch.ones(2, 2, device=device).T
    expect_copy(source, source.flatten())

    # flatten returns the original object if start_dim=end_dim
    source = torch.ones(2, 2, device=device)
    self.assertIs(source, source.flatten(1, 1))
Kulin Sethb744e1c2022-07-01 15:10:56 +00009917
def test_basic_indexing_slice_view(self, device="mps"):
    """Slice indexing must return a view; a write through it must show in the base."""
    base = torch.ones(5, 5, device=device)
    corner = base[:2, :3]
    self.assertTrue(self.is_view_of(base, corner))

    corner[0, 0] = 0
    self.assertEqual(base[0, 0], corner[0, 0])
9925
def test_basic_indexing_ellipses_view(self, device="mps"):
    """Ellipsis indexing must return a view; a write through it must show in the base."""
    base = torch.ones(5, 5, device=device)
    columns = base[..., :2]
    self.assertTrue(self.is_view_of(base, columns))

    columns[0, 0] = 0
    self.assertEqual(base[0, 0], columns[0, 0])
9933
def test_basic_indexing_newaxis_view(self, device="mps"):
    """Indexing with `None` must return a view; a write through it must show in the base."""
    base = torch.ones(5, 5, device=device)
    picked = base[None, :2, 3]
    self.assertTrue(self.is_view_of(base, picked))

    # Column 3 of the base was selected, so picked[0, 0] pairs with base[0, 3].
    picked[0, 0] = 0
    self.assertEqual(base[0, 3], picked[0, 0])
9941
def test_chunk_view(self, device="mps"):
    """Every tensor returned by `torch.chunk` must be a view of the input."""
    base = torch.zeros(3, 3, device=device)

    for row, piece in enumerate(torch.chunk(base, 3)):
        self.assertTrue(self.is_view_of(base, piece))

        piece[0, 0] = row + 1
        self.assertEqual(base[row, 0], piece[0, 0])
9951
def test_split_view(self, device="mps"):
    """Every tensor returned by `torch.split` must be a view of the input."""
    base = torch.zeros(3, 3, device=device)

    for row, piece in enumerate(torch.split(base, [1, 1, 1])):
        self.assertTrue(self.is_view_of(base, piece))

        piece[0, 0] = row + 1
        self.assertEqual(base[row, 0], piece[0, 0])
9961
def test_movedim_view(self, device="mps"):
    """`movedim` / `moveaxis` must return views, for tuple and int dim arguments."""
    for fn in [torch.movedim, torch.moveaxis]:
        for source, destination in (((0, 1), (1, 0)), (0, 1)):
            op = partial(fn, source=source, destination=destination)

            base = torch.zeros(3, 3, device=device)
            moved = op(base)
            self.assertTrue(self.is_view_of(base, moved))

            # Randomly change values in output
            # and verify that original is changed
            # as well.
            for _ in range(3):
                row, col = random.randint(0, 2), random.randint(0, 2)
                moved[row, col] = random.random()
                self.assertEqual(base[col, row], moved[row, col])
9983
9984 # Testing that the generated view_copy kernel and its derivative are implemented correctly
def test_view_copy(self, device="mps"):
    """`torch.view_copy` must match `view` in value and gradient without being a view."""
    src = torch.randn(4, device=device, requires_grad=True)
    src_ref = src.clone().detach().requires_grad_()
    as_view = src_ref.view(2, 2)
    as_copy = torch.view_copy(src, (2, 2))

    # view_copy ops don't preserve view relationship
    self.assertTrue(self.is_view_of(src_ref, as_view))
    self.assertFalse(self.is_view_of(src, as_copy))

    as_copy.sum().backward()
    as_view.sum().backward()

    # forward and backward give the same shape + result
    self.assertEqual(as_copy, as_view)
    self.assertEqual(src.grad, src_ref.grad)
10001
def test_view_copy_out(self, device="mps"):
    """The `out=` variants of `diagonal_copy` and `split_copy` must match the functional results."""
    matrix = torch.randn(2, 2, device=device)
    diag_out = torch.empty(2, device=device)

    torch.diagonal_copy(matrix, out=diag_out)
    self.assertEqual(torch.diagonal_copy(matrix), diag_out)

    vector = torch.randn(4, device=device)
    first_out = torch.empty(2, device=device)
    second_out = torch.empty(2, device=device)

    torch.split_copy(vector, 2, out=(first_out, second_out))
    first_expected, second_expected = torch.split_copy(vector, 2)

    self.assertEqual(first_expected, first_out)
    self.assertEqual(second_expected, second_out)
10020
def test_detached_view_copy(self, device="mps"):
    """Copying a detached element of a CPU tensor to `device` must preserve its value.

    https://github.com/pytorch/pytorch/issues/86052
    """
    source = torch.arange(2)
    # .detach() makes y not a view, but contig tensor
    # with non-zero offset
    detached = source[1].detach()
    moved = detached.to(device)
    self.assertEqual(detached, moved.cpu())
10029
def test_empty_reshape(self, device="mps"):
    """Reshaping a zero-element tensor keeps the data pointer and rejects an ambiguous -1."""
    empty = torch.randn(0, 6, device=device)
    new_shape = (1, 0, 6, 1, 1)
    self.assertEqual(new_shape, empty.reshape(*new_shape).shape)
    # should be viewable -- i.e. data_ptr is the same.
    self.assertEqual(empty.data_ptr(), empty.reshape(*new_shape).data_ptr())

    # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
    with self.assertRaises(RuntimeError):
        empty.reshape(0, -1)
10038
def test_expand(self, device="mps"):
    """Checks `expand` / `expand_as` sizes, values, strides and error cases."""
    wide = torch.rand(1, 8, 1, device=device)
    narrow = torch.rand(5, device=device)
    template = torch.rand(4, 8, 5, device=device)
    target = template.size()

    # Both sources must expand to the template's size through every spelling.
    for source in (wide, narrow):
        self.assertEqual(source.expand_as(template).size(), target)
        self.assertEqual(source.expand(4, 8, 5).size(), target)
        self.assertEqual(source.expand(target).size(), target)

    # test double expand
    self.assertEqual(narrow.expand(1, 5).expand(2, 2, 5), narrow.repeat(2, 2, 1))

    # test non-contiguous
    noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
    self.assertFalse(noncontig.is_contiguous())
    self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

    # make sure it's compatible with unsqueeze
    expanded = narrow.expand(1, 1, 5)
    unsqueezed = narrow.unsqueeze(0).unsqueeze(1)
    self.assertEqual(expanded, unsqueezed)
    self.assertEqual(expanded.stride(), unsqueezed.stride())

    # test -1 as target size
    self.assertEqual(wide.expand(4, -1, 5), wide.expand(4, 8, 5))
    with self.assertRaises(RuntimeError):
        narrow.expand(-1, -1)

    # test expanding empty to empty
    self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))
10071
def test_view_empty(self, device="mps"):
    """`view` on a zero-element tensor must produce the requested shape."""
    empty = torch.randn(0, 6, device=device)
    viewed = empty.view(1, 0, 6, 1, 1)
    self.assertEqual((1, 0, 6, 1, 1), viewed.shape)
10075
def test_reshape(self, device="mps"):
    """Checks when `reshape` shares the data pointer, when it does not, and its error cases."""
    x = torch.randn(3, 3, device=device)
    self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
    self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
    with self.assertRaises(RuntimeError):
        x.reshape(-1, -1)

    sliced = torch.randn(4, 4, 4, device=device)[:, 0, :]
    # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
    if device != "meta":
        self.assertNotEqual(sliced.data_ptr(), sliced.reshape(-1).data_ptr())
    self.assertEqual(sliced.contiguous().view(-1), sliced.reshape(-1))
    self.assertEqual(sliced.reshape(2, 2, 4).data_ptr(), sliced.data_ptr())

    scalar = torch.randn((), device=device)
    self.assertEqual(scalar.data_ptr(), scalar.reshape(()).data_ptr())
    self.assertEqual(scalar.reshape(-1).shape, (1,))
    with self.assertRaises(RuntimeError):
        scalar.reshape(2)

    empty = torch.tensor([], device=device)
    self.assertEqual(empty, empty.reshape(-1))
    self.assertEqual(empty, empty.reshape([0]))
    # TODO: fix these once we have multi-dimensional empty tensors
    self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
    self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
    with self.assertRaises(RuntimeError):
        empty.reshape(1)

    x = torch.randn(3, 3, device=device)
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
    with self.assertRaises(RuntimeError):
        x.reshape_as(torch.rand(10, device=device))
10107
def test_narrow(self, device="mps"):
    """Checks `Tensor.narrow` with positive and negative dims and start offsets.

    NOTE(review): `device` is not referenced in the body; the tensors are
    created without an explicit device -- confirm whether that is intended.
    """
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    # (narrow arguments, expected rows)
    cases = (
        ((0, 0, 1), [[0, 1, 2]]),
        ((0, 0, 2), [[0, 1, 2], [3, 4, 5]]),
        ((0, 1, 1), [[3, 4, 5]]),
        ((0, -1, 1), [[6, 7, 8]]),
        ((0, -2, 2), [[3, 4, 5], [6, 7, 8]]),
        ((0, -3, 3), [[0, 1, 2], [3, 4, 5], [6, 7, 8]]),
        ((-1, -1, 1), [[2], [5], [8]]),
        ((-2, -1, 1), [[6, 7, 8]]),
    )
    for narrow_args, expected in cases:
        self.assertEqual(x.narrow(*narrow_args), torch.tensor(expected))
10118
def test_narrow_tensor(self, device="mps"):
    """`narrow` accepts a 0-dim integer tensor as start and rejects other tensors.

    NOTE(review): `device` is not referenced in the body; the tensors are
    created without an explicit device -- confirm whether that is intended.
    """
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))

    # A float scalar, a 1-element vector and a 2-element vector must all raise.
    for bad_start in (0., [0], [0, 1]):
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor(bad_start), 1)
10128
def test_t(self, device="mps"):
    """Checks `Tensor.t` on 0-D to 3-D tensors, dense and sparse.

    NOTE(review): `device` is not referenced in the body; the tensors are
    created without an explicit device -- confirm whether that is intended.
    """
    # Test 0D and 1D tensors: t() must leave them unchanged
    for dense in (torch.randn(()), torch.arange(4)):
        self.assertEqual(dense, dense.t())
        sparse = dense.to_sparse()
        self.assertEqual(sparse, sparse.t())

    # Test 2D tensors: t() must equal transpose(0, 1)
    dense = torch.rand((2, 2))
    self.assertEqual(dense.t(), dense.transpose(0, 1))
    sparse = dense.to_sparse()
    self.assertEqual(sparse.t(), sparse.transpose(0, 1))

    # Test 3D tensor: t() must raise
    dense = torch.rand((2, 2, 2))
    with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
        dense.t()
    sparse = dense.to_sparse()
    with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
        sparse.t()
10155
def test_split(self, device="mps"):
    """Checks `Tensor.split` with a fixed split size and with explicit section sizes.

    NOTE(review): `device` is not referenced in the body; the tensors are
    created without an explicit device -- confirm whether that is intended.
    """

    def check_pieces(tensor, split_arg, dim, target_sizes):
        # Each piece must have the expected size and equal the matching
        # `narrow` of the input exactly (atol=0, rtol=0).
        offset = 0
        for target_size, piece in zip(target_sizes, tensor.split(split_arg, dim)):
            self.assertEqual(piece.size(), target_size)
            self.assertEqual(tensor.narrow(dim, offset, target_size[dim]), piece, atol=0, rtol=0)
            offset = offset + target_size[dim]

    # Fixed split size: the last piece is smaller
    tensor = torch.rand(7, 4)
    check_pieces(tensor, 3, 0, ([3, 4], [3, 4], [1, 4]))

    # Variable sections split
    tensor = torch.randn(20, 10)
    check_pieces(tensor, [5, 5, 10], 0, ([[5, 10], [5, 10], [10, 10]]))
    check_pieces(tensor, [2, 2, 6], 1, ([20, 2], [20, 2], [20, 6]))
10189
def test_chunk(self, device="mps"):
    """Check ``Tensor.chunk`` piece sizes/contents and its rejection of counts <= 0.

    NOTE(review): ``device`` is never referenced by the body.
    """
    source = torch.rand(4, 7)
    axis = 1
    expected_sizes = ([4, 3], [4, 3], [4, 1])

    offset = 0
    for expected, piece in zip(expected_sizes, source.chunk(3, axis)):
        extent = expected[axis]
        self.assertEqual(piece.size(), expected)
        self.assertEqual(source.narrow(axis, offset, extent), piece, atol=0, rtol=0)
        offset += extent

    # Invalid chunk sizes
    for bad_count in (0, -2):
        with self.assertRaisesRegex(RuntimeError, 'chunk expects.*greater than 0'):
            source.chunk(bad_count)
10209
def test_unsqueeze(self, device="mps") -> None:
    """Check ``unsqueeze``/``unsqueeze_`` against ``view`` on contiguous and sliced inputs.

    NOTE(review): ``device`` is never referenced by the body.
    """
    base = torch.randn(2, 3, 4)
    self.assertEqual(base.unsqueeze(1), base.view(2, 1, 3, 4))
    self.assertEqual(base.clone().unsqueeze_(2), base.view(2, 3, 1, 4))

    # Selecting along dim 1 yields a non-contiguous tensor (asserted below).
    sliced = base[:, 1]
    self.assertFalse(sliced.is_contiguous())
    self.assertEqual(sliced.unsqueeze(1), sliced.contiguous().view(2, 1, 4))
    self.assertEqual(sliced.clone().unsqueeze_(2), sliced.contiguous().view(2, 4, 1))
10223
10224 # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
def test_big_transpose(self, device="mps"):
    """Compare ``t().contiguous()`` of a 456x789 tensor with NumPy's transpose."""
    source = torch.rand(456, 789, device=device)
    transposed = source.t().contiguous()

    host_array = source.cpu().numpy()
    expected = torch.from_numpy(host_array.transpose())
    self.assertEqual(transposed, expected)
10230
def test_T(self, device="mps"):
    """Check that ``.T`` reverses the dims of a 3-D tensor and leaves a 1-D tensor unchanged."""
    cube = torch.randn(2, 3, 4, device=device)
    reversed_dims = cube.T
    self.assertEqual(cube.permute(2, 1, 0), reversed_dims)

    vector = torch.randn(10, device=device)
    self.assertEqual(vector, vector.T)
Kulin Sethb744e1c2022-07-01 15:10:56 +000010238
def test_transposes(self, device="mps", dtype=torch.float32):
    """Check ``T``, ``H``, ``mT``, ``mH`` and ``adjoint()`` against transpose (+ conj).

    The matrix variants (``mT``/``mH``/``adjoint``) are also run on a batched
    3-D input; ``T`` and ``H`` only on a 2-D one.
    """
    for op in ("T", "H", "mT", "mH", "adjoint"):
        is_adjoint = op == "adjoint"
        if op.startswith("m") or is_adjoint:
            shapes = ((2, 3), (2, 3, 4))
        else:
            shapes = ((2, 3),)

        for shape in shapes:
            source = make_tensor(shape, device=device, dtype=dtype)

            actual = getattr(source, op)
            if is_adjoint:
                # adjoint is a method; the others are attributes.
                actual = actual()

            expected = source.transpose(-2, -1) if source.ndim != 0 else source
            if op.endswith("H") or is_adjoint:
                expected = expected.conj()

            self.assertEqual(expected, actual)
10253
def test_transposes_errors(self, device="mps", dtype=torch.float32):
    """Check that ``H``, ``mT``, ``mH`` and ``adjoint()`` reject non-matrix inputs."""
    # Shapes expected to be rejected, per operation (insertion order is iteration order).
    rejected_shapes = {
        "H": ((2,), (2, 3, 4)),
        "mT": ((2,),),
        "mH": ((2,),),
        "adjoint": ((2,),),
    }
    for op, shapes in rejected_shapes.items():
        for shape in shapes:
            source = make_tensor(shape, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                accessed = getattr(source, op)
                if op == "adjoint":
                    # adjoint is a method, so it has to be called to trigger the check.
                    accessed()
10263
def test_python_types(self, device="mps"):
    """Check that Python ``int``/``bool`` are accepted as dtypes and map to int64/bool."""
    def float32_sample():
        return torch.randn((1, 2), device=device, dtype=torch.float32)

    first_float = float32_sample()
    second_float = float32_sample()
    self.assertEqual(first_float.dtype, second_float.dtype)

    torch_int, python_int = (
        torch.arange(10, 20, dtype=int_dtype, device=device)
        for int_dtype in (torch.int64, int)
    )
    self.assertEqual(torch_int.dtype, python_int.dtype)

    torch_bool, python_bool = (
        torch.tensor([True, False], dtype=bool_dtype, device=device)
        for bool_dtype in (torch.bool, bool)
    )
    self.assertEqual(torch_bool.dtype, python_bool.dtype)
10276
10277 # TODO: is resize best put in test_view_ops?
def test_resize_as_preserves_strides(self, device="mps"):
    """Check that ``resize_as_`` with the tensor itself keeps its (transposed) strides.

    NOTE(review): ``device`` is never referenced by the body.
    """
    transposed = torch.empty(2, 3).t()
    strides_before = transposed.stride()
    transposed.resize_as_(transposed)
    strides_after = transposed.stride()
    self.assertEqual(strides_after, strides_before)
10283
def test_memory_format_resize_as(self, device="mps"):
    """Check that ``resize_as_(..., memory_format=preserve_format)`` adopts channels-last layouts.

    The tensors are always created on "mps"; the ``device`` argument is not consulted.
    """
    cases = (
        ((10, 3, 32, 32), torch.channels_last),
        ((3, 10, 3, 32, 32), torch.channels_last_3d),
    )
    for shape, memory_format in cases:
        reference = torch.randn(shape, device="mps").contiguous(memory_format=memory_format)
        flat = torch.randn(reference.numel(), device="mps")
        flat.resize_as_(reference, memory_format=torch.preserve_format)
        self.assertTrue(flat.is_contiguous(memory_format=memory_format))
10293
def test_memory_format_resize_(self, device="mps"):
    """Check that ``resize_`` with an explicit memory format yields that layout.

    The tensors are always created on "mps"; the ``device`` argument is not consulted.
    """
    cases = (
        ((10, 3, 32, 32), torch.channels_last),
        ((3, 10, 3, 32, 32), torch.channels_last_3d),
    )
    for shape, memory_format in cases:
        # Start from a flat tensor holding exactly as many elements as the target shape.
        flat = torch.randn(math.prod(shape), device="mps")
        flat.resize_(shape, memory_format=memory_format)
        self.assertTrue(flat.is_contiguous(memory_format=memory_format))
10302
10303 # TODO: OpInfo this
def _test_atleast(self, device, torch_fn):
    """Run ``gradcheck`` and ``gradgradcheck`` on ``torch_fn`` for 0- to 4-dim double inputs.

    ``torch_fn`` is first checked on a single 0-dim and a single 1-dim input,
    then on all five inputs passed together.

    NOTE(review): ``device`` is accepted but never referenced by the body.
    """
    def apply(*tensors):
        return torch_fn(*tensors)

    def random_leaf(*shape):
        return torch.rand(*shape, dtype=torch.double, requires_grad=True)

    # 0-dim
    scalar = torch.tensor(0.5, dtype=torch.double, requires_grad=True)
    gradcheck(apply, scalar)
    gradgradcheck(apply, scalar)

    # 1-dim
    vector = random_leaf(4)
    gradcheck(apply, vector)
    gradgradcheck(apply, vector)

    # 2,3,4-dim, checked in one call together with the two inputs above
    matrix = random_leaf(4, 3)
    cube = random_leaf(4, 3, 2)
    hypercube = random_leaf(4, 3, 2, 1)

    all_inputs = (scalar, vector, matrix, cube, hypercube)
    gradcheck(apply, all_inputs)
    gradgradcheck(apply, all_inputs)
10325
def test_atleast_gradient(self, device="mps"):
    """Run the shared gradient checks for ``atleast_1d``, ``atleast_2d`` and ``atleast_3d``."""
    for atleast_fn in (torch.atleast_1d, torch.atleast_2d, torch.atleast_3d):
        self._test_atleast(device, atleast_fn)
10330
def test_view(self, device="mps"):
    """Check ``view``/``view_as`` size inference, empty-tensor handling and error cases."""
    flat = torch.rand(15, device=device)
    template = torch.rand(3, 5, device=device)
    empty = torch.empty(0, device=device)
    target = template.size()

    self.assertEqual(flat.view_as(template).size(), target)
    # Explicit sizes, a torch.Size, and -1 in either position must all give 3x5.
    for view_args in ((3, 5), (torch.Size([3, 5]),), (-1, 5), (3, -1)):
        self.assertEqual(flat.view(*view_args).size(), target)

    flat.view(5, 3).fill_(random.uniform(0, 1))

    self.assertEqual(empty.view_as(empty), empty)
    self.assertEqual(empty.view(0), empty)
    self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
    self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)

    # test size inference with empty tensors
    self.assertEqual(empty.view(-1).size(), torch.Size([0]))
    self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))

    # -1 cannot be inferred when another requested dimension is 0.
    for ambiguous in ((-1, 0), (3, 0, -1, 0)):
        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
            empty.view(*ambiguous)

    # Element-count mismatches and more than one -1 are rejected.
    for invalid in ((15, 0), (7, -1), (15, -1, -1)):
        with self.assertRaises(RuntimeError):
            flat.view(*invalid)
10361
def test_contiguous(self, device="mps"):
    """Check that changing the stride of a size-1 dimension keeps the tensor contiguous."""
    x = torch.randn(1, 16, 5, 5, device=device)
    self.assertTrue(x.is_contiguous())

    # Only the stride of dimension 0 changes; since size[0] is 1 the tensor
    # is still expected to report itself as contiguous.
    altered_strides = [20, *x.stride()[1:]]
    x.set_(x.storage(), 0, x.size(), altered_strides)
    self.assertTrue(x.is_contiguous())
Kulin Sethb744e1c2022-07-01 15:10:56 +000010370
def test_resize_mps_dtypes(self, device="mps"):
    """Check that ``resize_`` shrinks a 3x2 tensor to 2x2 for every dtype in ``MPS_DTYPES``."""
    target_shape = (2, 2)
    values = [[1, 2], [3, 4], [5, 6]]
    for dtype in MPS_DTYPES:
        resized = torch.tensor(values, dtype=dtype, device=device)
        resized.resize_(target_shape)
        self.assertEqual(target_shape, resized.shape)
10377
def test_resize_as_mps_dtypes(self, device="mps"):
    """Check that ``resize_as_`` turns a 3x2 tensor into 2x3 for every dtype in ``MPS_DTYPES``."""
    tall_values = [[1, 2], [3, 4], [5, 6]]
    wide_values = [[1, 2, 3], [4, 5, 6]]
    for dtype in MPS_DTYPES:
        resized = torch.tensor(tall_values, dtype=dtype, device=device)
        template = torch.tensor(wide_values, dtype=dtype, device=device)
        resized.resize_as_(template)
        self.assertEqual(template.shape, resized.shape)
10384
def test_resize_overflow(self, device="mps"):
    """Check that ``resize_`` to an overflowing size raises instead of wrapping around.

    NOTE(review): ``device`` is never referenced; the tensor is a float64
    created without an explicit device.
    """
    scalar = torch.empty((), dtype=torch.float64)
    huge = 2**29
    overflow_cases = (
        ('Storage size calculation overflowed', [2, 4, huge, huge]),
        ('overflow', [8, 8, huge, huge]),
    )
    for message_pattern, shape in overflow_cases:
        with self.assertRaisesRegex(RuntimeError, message_pattern):
            scalar.resize_(shape)
10391
def test_view_all_dtypes_and_devices(self, device="mps"):
    """Check that a 3x2 tensor can be viewed as 6 elements for float and bool dtypes."""
    values = [[1, 2], [3, 4], [5, 6]]
    for dtype in (torch.float, torch.bool):
        flattened = torch.tensor(values, dtype=dtype, device=device).view(6)
        self.assertEqual(flattened.shape, [6])
Kulin Sethe011a8e2022-05-13 18:28:53 +000010396
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +000010397class TestConvolutionMPS(TestCaseMPS):
def test_conv1d_all_strides_paddings(self):
    """Compare bias-free Conv1d forward on CPU and MPS for strides/paddings 1..3.

    See https://github.com/pytorch/pytorch/issues/82921
    """
    for stride, padding in itertools.product(range(1, 4), repeat=2):
        input_cpu = torch.randn(1, 57, 40)
        conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False)
        # Deep copy so both devices run with identical weights.
        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
        expected = conv_cpu(input_cpu)

        input_mps = input_cpu.to(device='mps')
        actual = conv_mps(input_mps)
        self.assertEqual(expected, actual.cpu())
10412
10413
def test_conv1d_channels_last(self):
    """Compare Conv1d forward on CPU and MPS for an input produced by ``permute``.

    See https://github.com/pytorch/pytorch/issues/81557
    """
    conv = torch.nn.Conv1d(1, 128, 3)
    ramp = torch.arange((128 * 176), dtype=torch.float32)
    input_cpu = ramp.view(128, 176, 1).permute(0, 2, 1)
    expected = conv(input_cpu)

    input_mps = input_cpu.detach().clone().to("mps")
    # Module.to() moves this same module's parameters; the CPU output was
    # already computed above.
    conv_mps = conv.to("mps")
    actual = conv_mps(input_mps)

    self.assertEqual(expected, actual.cpu(), rtol=2.6e-05, atol=2e-04)
10426
def test_conv_transpose_1d_all_strides(self):
    """Compare bias-free ConvTranspose1d forward on CPU and MPS for strides 1, 2 and 3.

    See https://github.com/pytorch/pytorch/issues/82711
    """
    def helper(stride):
        y_cpu = torch.ones(1, 1, 2)
        deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1)
        # The weight is overwritten after construction with a length-2 kernel of ones.
        deconv_cpu.weight.data = torch.ones(1, 1, 2)
        # Deep copy so both devices run with identical weights.
        deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps')
        x_cpu = deconv_cpu(y_cpu)

        y_gpu = y_cpu.to(device='mps')
        x_gpu = deconv_gpu(y_gpu)
        self.assertEqual(x_cpu, x_gpu.cpu())

    # The helper is called only for its assertions, so use a plain loop
    # instead of building a throwaway list of None values.
    for stride in (1, 2, 3):
        helper(stride)
10440
def test_conv_transpose_1d_nn_functional(self):
    """Compare the functional ``conv_transpose1d`` (stride 8, padding 4) on CPU and MPS.

    See https://github.com/pytorch/pytorch/issues/82563
    """
    tin = torch.rand((1, 512, 1245), dtype=torch.float32)
    tparams = torch.rand((512, 256, 16), dtype=torch.float32)
    tbias = torch.rand((256), dtype=torch.float32)

    def run_on(device):
        # Same operands and hyper-parameters, moved to the requested device.
        return torch.nn.functional.conv_transpose1d(
            tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)

    tcpu = run_on('cpu')
    tgpu = run_on('mps')

    self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
10454
def test_conv_backward_1d_channels_last(self):
    """Compare Conv1d forward and backward on CPU and MPS for permuted inputs.

    The input is generated with its last two dimensions swapped relative to
    what Conv1d consumes, then permuted with ``(0, 2, 1)`` and made
    contiguous. Forward outputs, weight gradients and input gradients are
    compared; bias gradients are not.
    """
    def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
        # https://github.com/pytorch/pytorch/issues/84511
        conv_cpu = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
        conv_mps = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
        # Overwrite the MPS module's parameters with copies of the CPU ones so
        # both modules compute with the same values.
        conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
        conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)

        data = torch.rand(shape, dtype=torch.float32)
        x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
        x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
        res_cpu = conv_cpu(x_cpu)
        res_mps = conv_mps(x_mps)
        self.assertEqual(res_cpu, res_mps)

        # backward() returns None, so its result is deliberately not bound
        # to anything.
        res_cpu.sum().backward()
        res_mps.sum().backward()

        self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x_cpu.grad, x_mps.grad)

    helper(shape=(1, 176, 1))
    helper(shape=(2, 12, 1))
    helper(shape=(3, 176, 1))
    helper(shape=(4, 376, 1))
    helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
    helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)
Kulin Seth077db3d2022-09-20 06:19:40 +000010484
def test_conv1d_contiguous(self):
    """Compare Conv1d forward on CPU and MPS for an all-ones input."""
    conv = torch.nn.Conv1d(1, 128, 3)
    input_cpu = torch.ones(128, 1, 176)
    expected = conv(input_cpu)

    input_mps = input_cpu.detach().clone().to("mps")
    # Module.to() moves this same module's parameters; the CPU output was
    # already computed above.
    conv_mps = conv.to("mps")
    actual = conv_mps(input_mps)

    self.assertEqual(expected.shape, actual.shape)
    self.assertEqual(expected, actual.cpu())
10496
def test_conv2d_all_strides_paddings(self):
    """Compare Conv2d forward and backward on CPU and MPS over strides and memory formats.

    For each combination of input/weight memory format the helper runs every
    ``(strideX, strideY)`` pair in 1..3 and compares the forward output and
    the weight, bias and input gradients.

    See https://github.com/pytorch/pytorch/issues/83180
    """
    def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
        # NOTE(review): N doubles as in_channels, C as out_channels and H as
        # kernel_size below; every call in this test passes N == C.
        x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
        x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()

        if permute_data:
            # NOTE(review): permute() returns a new tensor and the result is
            # discarded, so these two calls do not change x_cpu / x_mps and
            # `permute_data` currently has no effect. Kept as-is because the
            # intended layout is not clear from this test - confirm with the
            # author before wiring it up or deleting it.
            x_cpu.permute(0, 2, 3, 1)
            x_mps.permute(0, 2, 3, 1)

        for strideX in range(1, 4):
            for strideY in range(1, 4):
                conv_cpu = torch.nn.Conv2d(
                    in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
                conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()

                conv_mps = torch.nn.Conv2d(
                    in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
                # Overwrite the MPS parameters with copies of the CPU ones so
                # both modules compute with the same values.
                conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
                conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()

                res_cpu = conv_cpu(x_cpu)
                res_mps = conv_mps(x_mps)
                self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)

                # backward() returns None. The previous code rebound
                # res_cpu/res_mps to that None and then "compared" them,
                # which asserted nothing; the real checks are the gradient
                # comparisons below.
                res_cpu.sum().backward()
                res_mps.sum().backward()

                self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                # x_cpu / x_mps are created once per helper call, so their
                # gradients accumulate across stride pairs on both devices.
                self.assertEqual(x_cpu.grad, x_mps.grad)

    for mem_format_input in [torch.contiguous_format, torch.channels_last]:
        for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
            for permute_data in [True, False]:
                helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
                helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
                helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
10535
def test_conv_transpose_2d_strided(self):
    """Compare ConvTranspose2d forward on CPU and MPS for two kernel/stride setups.

    Each setup is run with a contiguous and with a channels-last input.
    """
    def compare(module_cpu, memory_format):
        module_mps = copy.deepcopy(module_cpu).requires_grad_()
        # Replace the copied parameters' data with MPS copies of the CPU values.
        for param_name in ("weight", "bias"):
            cpu_values = getattr(module_cpu, param_name).data
            getattr(module_mps, param_name).data = cpu_values.detach().clone().to("mps").requires_grad_()

        input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
        input_mps = input_cpu.detach().clone().to("mps")

        self.assertEqual(module_cpu(input_cpu), module_mps(input_mps))

    # Built lazily so each module is constructed right before it is used.
    module_factories = (
        # With square kernels and equal stride
        lambda: nn.ConvTranspose2d(16, 33, 3, stride=2),
        # non-square kernels and unequal stride and with padding
        lambda: nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)),
    )
    for input_format in (torch.contiguous_format, torch.channels_last):
        for build_module in module_factories:
            compare(build_module().requires_grad_(), input_format)
10555
def test_conv_transpose_2d_specified_output(self):
    """Compare a Conv2d downsample + ConvTranspose2d upsample (explicit ``output_size``) on CPU and MPS."""
    def mirror_params(source, destination):
        # Give the MPS module the same parameter values as its CPU twin.
        for param_name in ("weight", "bias"):
            cpu_values = getattr(source, param_name).data
            getattr(destination, param_name).data = cpu_values.detach().clone().to("mps").requires_grad_()

    input_cpu = torch.randn(1, 16, 12, 12)
    input_mps = input_cpu.detach().clone().to("mps")

    down_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
    down_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
    mirror_params(down_cpu, down_mps)

    up_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
    up_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
    mirror_params(up_cpu, up_mps)

    hidden_cpu = down_cpu(input_cpu)
    hidden_mps = down_mps(input_mps)
    self.assertEqual(hidden_cpu, hidden_mps)
    self.assertEqual(hidden_cpu.size(), hidden_mps.size())

    # Upsample back to the size of the original input.
    restored_cpu = up_cpu(hidden_cpu, output_size=input_cpu.size())
    restored_mps = up_mps(hidden_mps, output_size=input_mps.size())
    self.assertEqual(restored_cpu, restored_mps)
    self.assertEqual(restored_cpu.size(), restored_mps.size())
Kulin Seth31d4b6f2022-08-17 00:26:41 +000010582
def test_conv2d_single_stride(self):
    """Compare Conv2d forward on CPU and MPS when ``stride`` is a single int (1..3)."""
    input_cpu = torch.randn(2, 2, 3, 6)
    input_mps = input_cpu.to(device='mps')
    for stride in (1, 2, 3):
        conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride)
        # Deep copy so both devices run with identical parameters.
        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
        expected = conv_cpu(input_cpu)
        actual = conv_mps(input_mps)
        self.assertEqual(expected, actual.cpu(), rtol=1e-03, atol=1e-05)
10592
@unittest.skipIf(product_version < 13.2, "Conv3d on MPS requires macOS 13.2 or newer")
def test_conv3d_single_stride(self):
    """Compare Conv3d forward on CPU and MPS when ``stride`` is a single int (1..3).

    The input is 4-D; per the ``torch.nn.Conv3d`` documentation this is
    treated as an unbatched ``(C, D, H, W)`` input.
    """
    # Conv3d is only available from MacOS 13.2 onwards
    y_cpu = torch.randn(2, 2, 3, 6)
    y_gpu = y_cpu.to(device='mps')
    for stride in range(1, 4):
        conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
        # Deep copy so both devices run with identical parameters.
        conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
        x_cpu = conv_cpu(y_cpu)
        x_gpu = conv_gpu(y_gpu)
        self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
10604
def test_grid_sample(self):
    """Compare ``F.grid_sample`` on MPS against hard-coded ground-truth values.

    The loops at the bottom run a fixed 1x1x2x5 input and a fixed sampling
    grid through ``F.grid_sample`` for modes ``'bilinear'``/``'nearest'``,
    padding modes ``'zeros'``/``'reflection'`` and both ``align_corners``
    settings, and compare with ``atol=1e-5``.

    NOTE(review): the nested ``test`` helper (and ``test_shape`` inside it)
    is defined but never called anywhere in this method, so the CPU-vs-MPS
    shape/gradient comparisons it contains do not run. Confirm whether it
    should be invoked or removed.

    NOTE(review): the ``'border'`` padding branches and the ``'bicubic'``
    mode branch below are not reachable with the current loop values.
    """
    def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
        def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
            for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
                # grid_dim_contig_order specifies the dimension order that can
                # make grid to be contiguous.
                # i.e., grid.permute(grid_dim_contig_order) is contiguous.
                # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
                # initialized with contiguous tensor of shape [N, 2, H, W]
                # and permuted to [N, H, W, 2] afterwards.
                grid_shape = [N, H, W, 2]
                grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
                # Inverse permutation: maps the init layout back to [N, H, W, 2].
                grid_fwd_permute = [None, None, None, None]
                for i, d in enumerate(grid_dim_contig_order):
                    grid_fwd_permute[d] = i

                def get_grid(device='cpu', data=None):
                    # Build a grid of shape `grid_shape` whose memory layout follows
                    # `grid_dim_contig_order`, either random or from existing `data`.
                    if data is not None:
                        assert list(data.shape) == grid_shape
                        data = data.permute(grid_dim_contig_order).to(device)
                    else:
                        data = torch.randn(grid_init_shape, device=device)
                    grid = data.permute(grid_fwd_permute)
                    assert grid.permute(grid_dim_contig_order).is_contiguous()
                    return grid

                input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
                grid_cpu = get_grid().requires_grad_()
                out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                        align_corners=align_corners)
                self.assertEqual(out_cpu.size(), torch.Size([N, C, H, W]))

                gradients = torch.randn_like(out_cpu)
                out_cpu.backward(gradients)


                # Compare against unvectorized CPU fallback

                # NOTE [ grid_sample CPU fallback ]
                # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
                # 32-bit floats. So we also have a fallback that is used only for float tensors
                # requiring 64-bit indexing. That requires too much memory to run on CI, so we
                # also export the fallback and test it here to ensure feature parity with
                # the vectorized version.
                input_fallback = input_cpu.float().detach_().requires_grad_()
                grid_fallback = grid_cpu.float().detach_().requires_grad_()
                out_fallback = torch._grid_sampler_2d_cpu_fallback(
                    input_fallback, grid_fallback,
                    F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
                    F.GRID_SAMPLE_PADDING_MODES[padding_mode],
                    align_corners)
                self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)

                out_fallback.backward(gradients.float())
                if input_requires_grad:
                    self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
                self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)

                # Repeat the forward/backward pass on MPS with the same data and gradients.
                input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad)
                grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
                out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                self.assertEqual(out_cpu, out_mps)
                out_mps.backward(gradients.to("mps"))
                if input_requires_grad:
                    self.assertEqual(input_cpu.grad, input_mps.grad)
                self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0)

                # check that zero-dimensional input strides don't error out
                base_input = torch.randn(N, C, 1, IW)
                input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad)
                out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                        align_corners=align_corners)

                input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
                out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                self.assertEqual(out_cpu, out_mps)

        # test same size output
        test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)

        # test larger output
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(IH + 1, 12)
        W = random.randint(IW + 1, 12)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # test smaller output
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(2, IH)
        W = random.randint(2, IW)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # test 1x1 input
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = 1
        IW = 1
        H = random.randint(2, 5)
        W = random.randint(2, 5)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # testing empty grid
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        W = random.randint(3, IW + 2)
        test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)

        # testing empty channel
        N = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(3, IH + 2)
        W = random.randint(3, IW + 2)
        test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)

        # testing empty batch
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(3, IH + 2)
        W = random.randint(3, IW + 2)
        test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)

    for mode in ('bilinear', 'nearest'):
        for padding_mode in ('zeros', 'reflection'):
            for align_corners in (True, False):
                # test known input
                input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
                grid = torch.tensor(
                    [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
                     [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
                # Select the expected output for this (mode, padding_mode, align_corners).
                if mode == 'bilinear':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
                                 [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                 [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                 [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                elif mode == 'nearest':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[0., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0., 8., 5., 7., 0.],
                                 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                elif mode == 'bicubic':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
                                 [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
                                 [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
                                 [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
                                 [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
                                 [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
                                 [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")

                else:
                    raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
                output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
                                       align_corners=align_corners)
                self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
                                 msg=f"groundtruth comparison failed for mode={mode}, "
                                 f"padding_mode={padding_mode}")
Denis Vieriu5b8e4852023-02-09 02:25:46 +000010842
Ramin Azarmehrb57e6fd2023-02-13 17:56:24 +000010843class TestAdvancedIndexing(TestCaseMPS):
# Torch dtypes iterated by tests in this class (``test_nonzero`` loops over them).
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
# NumPy dtypes listed in the same order as ``supported_dtypes`` above.
supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
Kulin Sethce7177f2022-08-18 06:03:16 +000010846
@unittest.skipIf(product_version < 14.0, "Skipped on macOS < 14")
def test_nonzero_no_warning(self):
    """Check that neither ``torch.nonzero`` nor ``Tensor.nonzero`` emits a warning on MPS."""
    sample = torch.randn((2, 2), device="mps")
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones that would normally be shown only once.
        warnings.simplefilter("always")
        torch.nonzero(sample)
        sample.nonzero()
        self.assertEqual(len(caught), 0)
10856
def test_nonzero(self):
    """Cross-check every nonzero calling convention on MPS against NumPy.

    Runs over each dtype in ``supported_dtypes`` and a set of 1-D to 3-D shapes.
    """
    device = "mps"
    shapes = [
        torch.Size(dims)
        for dims in ((12,), (12, 1), (1, 12), (6, 2), (3, 2, 2), (5, 5, 5))
    ]

    def make_input(shape, dtype):
        # Values are drawn from {0, 1} so zero and nonzero entries both appear.
        if dtype == torch.bfloat16:
            # Workaround kept from the original: bfloat16 is sampled as float
            # and cast afterwards.
            return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
        return torch.randint(2, shape, device=device, dtype=dtype)

    def check(dtype):
        for shape in shapes:
            tensor = make_input(shape, dtype)
            via_function = torch.nonzero(tensor, as_tuple=False)
            via_method = tensor.nonzero(as_tuple=False)
            # Start from an empty long tensor that nonzero must resize.
            via_out = torch.empty([], dtype=torch.long, device=device).resize_(0)
            torch.nonzero(tensor, out=via_out)

            if dtype == torch.bfloat16:
                np_array = tensor.float().cpu().numpy()
            else:
                np_array = tensor.cpu().numpy()
            expected = torch.from_numpy(np.stack(np_array.nonzero())).t()

            for actual in (via_function, via_method, via_out):
                self.assertEqual(actual.cpu(), expected, atol=0, rtol=0)

            tuple_from_function = torch.nonzero(tensor, as_tuple=True)
            tuple_from_method = tensor.nonzero(as_tuple=True)
            for per_dim in (tuple_from_function, tuple_from_method):
                # Stack the per-dimension index vectors into the (N, ndim) layout.
                stacked = torch.stack(per_dim).t().cpu()
                self.assertEqual(stacked, expected, atol=0, rtol=0)

    for dtype in self.supported_dtypes:
        check(dtype)
10895
def test_nonzero_astuple_out(self):
    """Check how ``as_tuple`` interacts with ``out=`` and with JIT script/trace."""
    device = "mps"
    t = torch.randn((3, 3, 3), device=device)
    out = torch.empty([], dtype=torch.long, device=device).resize_(0)

    # as_tuple=True together with an out tensor is expected to be rejected.
    with self.assertRaises(RuntimeError):
        torch.nonzero(t, as_tuple=True, out=out)

    explicit_flag = torch.nonzero(t, as_tuple=False, out=out)
    default_flag = torch.nonzero(t, out=out)
    self.assertEqual(explicit_flag, default_flag)

    # Verifies that JIT script cannot handle the as_tuple kwarg
    # See Issue https://github.com/pytorch/pytorch/issues/45499.
    def _foo(inp):
        as_tuple_res = torch.nonzero(inp, as_tuple=True)
        plain_res = torch.nonzero(inp, as_tuple=False)
        buf = torch.empty_like(plain_res)
        torch.nonzero(inp, as_tuple=False, out=buf)
        return as_tuple_res, plain_res, buf

    with self.assertRaises(RuntimeError):
        torch.jit.script(_foo)

    # Verifies that JIT tracing works fine
    traced = torch.jit.trace(_foo, t)
    traced_tuple, traced_nontuple, traced_out = traced(t)
    want_tuple = torch.nonzero(t, as_tuple=True)
    want_nontuple = torch.nonzero(t)

    self.assertEqual(traced_tuple, want_tuple)
    for got in (traced_nontuple, traced_out):
        self.assertEqual(got, want_nontuple)
10928
def test_nonzero_discontiguous(self):
    """Check nonzero with a strided input and with contiguous/strided out tensors."""
    device = "mps"
    rows, cols = 4, 4
    dense = torch.randint(2, (rows, cols), device=device)
    # Same values, stored in every other column of a wider buffer.
    wide = torch.empty(rows, cols * 2, device=device)
    strided = wide[:, ::2].copy_(dense)

    from_dense = dense.nonzero(as_tuple=False)
    from_strided = strided.nonzero(as_tuple=False)
    self.assertEqual(from_dense, from_strided, atol=0, rtol=0)

    # expect the out tensor's storage to be reused
    contiguous_out = torch.empty_like(from_dense)
    ptr_before = contiguous_out.data_ptr()
    torch.nonzero(dense, out=contiguous_out)
    self.assertEqual(ptr_before, contiguous_out.data_ptr())
    self.assertEqual(from_dense, contiguous_out, atol=0, rtol=0)

    # discontiguous out
    n_hits, n_dims = from_dense.size(0), from_dense.size(1)
    strided_out = torch.empty(n_hits, n_dims * 2, dtype=torch.long, device=device)[:, ::2]
    ptr_before = strided_out.data_ptr()
    strides_before = strided_out.stride()
    torch.nonzero(dense, out=strided_out)
    self.assertEqual(ptr_before, strided_out.data_ptr())
    self.assertEqual(from_dense, strided_out, atol=0, rtol=0)
    self.assertEqual(strides_before, strided_out.stride())
10951
def test_nonzero_non_diff(self):
    """The index tensor returned by nonzero must not require grad."""
    source = torch.randn(10, requires_grad=True, device="mps")
    indices = source.nonzero()
    self.assertFalse(indices.requires_grad)
10957
def test_nonzero_multi_threading(self):
    """Call nonzero on the same MPS tensor from two threads at once."""
    # Test that MPS doesn't crash if nonzero called concurrently
    # See https://github.com/pytorch/pytorch/issues/100285
    x = torch.rand(3, 3, device="mps")
    workers = [threading.Thread(target=torch.nonzero, args=(x,)) for _ in range(2)]
    for worker in workers:
        worker.start()
    # Wait for both calls to finish so the work runs inside this test
    # instead of leaking into whatever executes next.
    for worker in workers:
        worker.join()
10966
def test_sliced_view_cast(self):
    """Reinterpret a column slice of a float16 tensor as float32 and add to it."""
    # This used to crash on MacOS Sequoia
    # See https://github.com/pytorch/pytorch/issues/137800
    half = torch.rand(16, 16, device='mps', dtype=torch.float16)
    first_two_cols = half[:, 0:2]
    # The result is not inspected; the test only runs the operation.
    _ = first_two_cols.view(torch.float32) + 1
10972
def test_masked_select(self):
    """masked_select on MPS must match the CPU result for the same data."""
    cpu_values = torch.randn(3, 4)
    mps_values = cpu_values.to("mps")
    # Each mask is built on the device of the tensor it selects from.
    cpu_mask = cpu_values.ge(0.5)
    mps_mask = mps_values.ge(0.5)

    picked_cpu = torch.masked_select(cpu_values, cpu_mask)
    picked_mps = torch.masked_select(mps_values, mps_mask)

    self.assertEqual(picked_cpu, picked_mps)
10983
Kulin Sethce7177f2022-08-18 06:03:16 +000010984 # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
def test_indexing_get(self):
    """Compare integer-list indexing on CPU and MPS for each supported dtype."""
    row_idx, col_idx = [0, 1, 2], [0, 1, 0]
    for dtype in self.supported_dtypes:
        x_cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
        x_mps = x_cpu.detach().clone().to("mps")

        picked_cpu = x_cpu[row_idx, col_idx]
        picked_mps = x_mps[row_idx, col_idx]
        # The dtype is passed as the failure message.
        self.assertEqual(picked_cpu, picked_mps, str(dtype))
10994
def test_indexing_select_corners(self):
    """Index the four corners of a 4x3 matrix with tensor indices on CPU and MPS."""
    def check(dtype):
        x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        x_mps = x_cpu.detach().clone().to("mps")

        rows_cpu = torch.tensor([[0, 0], [3, 3]])
        cols_cpu = torch.tensor([[0, 2], [0, 2]])
        rows_mps = rows_cpu.detach().clone().to("mps")
        cols_mps = cols_cpu.detach().clone().to("mps")

        corners_cpu = x_cpu[rows_cpu, cols_cpu]
        corners_mps = x_mps[rows_mps, cols_mps]
        self.assertEqual(corners_cpu, corners_mps, str(dtype))

    for dtype in self.supported_dtypes:
        check(dtype)
11011
11012 # FIXME: uint8 fails for this testcase, needs further debugging
def test_slicing_using_advanced_index_for_column(self):
    """Compare plain slicing and slice-plus-list indexing between CPU and MPS."""
    # FIXME: use supported_dtypes once uint8 is fixed
    dtypes = (torch.float32, torch.float16, torch.int64, torch.int32, torch.int16)
    for dtype in dtypes:
        x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        x_mps = x_cpu.detach().clone().to("mps")

        # basic slicing on both dimensions
        self.assertEqual(x_cpu[1:4, 1:3], x_mps[1:4, 1:3], str(dtype))

        # using advanced index for column
        column_picks = [1, 2]
        self.assertEqual(x_cpu[1:4, column_picks], x_mps[1:4, column_picks], str(dtype))
11028
def test_boolean_array_indexing(self):
    """Select elements with a boolean mask computed from the tensor itself."""
    def check(dtype):
        x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        x_mps = x_cpu.detach().clone().to("mps")

        above_five_cpu = x_cpu[x_cpu > 5]
        above_five_mps = x_mps[x_mps > 5]

        self.assertEqual(above_five_cpu, above_five_mps, str(dtype))

    for dtype in self.supported_dtypes:
        # MPS support binary op with uint8 natively starting from macOS 13.0
        uint8_unavailable = product_version < 13.0 and dtype == torch.uint8
        if not uint8_unavailable:
            check(dtype)
Denis Vieriuce4f1872022-09-28 00:47:52 +000011043
def test_advanced_indexing_3D_get(self):
    """Read a 3-D tensor through mixed list/int/slice indices on CPU and MPS."""
    everything = slice(None)
    # Each tuple is what x[a, b, c] passes to __getitem__.
    index_patterns = (
        ([1, 2], 3, everything),
        ([0, 2], everything, everything),
        (everything, [1, 0], [1]),
    )

    def check(x_cpu):
        x_mps = x_cpu.detach().clone().to("mps")
        for pattern in index_patterns:
            self.assertEqual(x_cpu[pattern], x_mps[pattern])

    fixed_input = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                                 [0.5, 0.6, 0.7, 0.8],
                                 [0.9, 1.0, 1.1, 1.2],
                                 [1.3, 1.4, 1.5, 1.6]],

                                [[2.0, 2.1, 2.2, 2.3],
                                 [2.4, 2.5, 2.6, 2.7],
                                 [2.8, 2.9, 3.0, 3.1],
                                 [3.2, 3.3, 3.4, 3.5]],

                                [[4.0, 4.1, 4.2, 4.3],
                                 [4.4, 4.5, 4.6, 4.7],
                                 [4.8, 4.9, 5.0, 5.1],
                                 [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
    check(fixed_input)

    # torch.randn / torch.rand don't work with all dtypes, so the data is
    # generated with NumPy for each dtype and then moved to torch.
    for np_dtype, torch_dtype in zip(self.supported_np_dtypes, self.supported_dtypes):
        raw = np.random.random_sample(size=[3, 4, 4]).astype(np_dtype)
        check(torch.tensor(raw, device='cpu', dtype=torch_dtype))
11073
def test_advanced_indexing_3D_put(self):
    """Write a one-element view into a 3-D tensor through mixed indices."""
    everything = slice(None)
    # Each tuple is what x[a, b, c] = value passes to __setitem__.
    index_patterns = (
        ([1, 2], 3, everything),
        ([0, 2], everything, everything),
        (everything, [1, 0], [1]),
    )

    def check(x_cpu):
        dtype = x_cpu.dtype
        x_mps = x_cpu.detach().clone().to("mps")

        # The assigned value is a view (the last element) of a two-element tensor.
        value_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")[1:]
        value_mps = torch.tensor([88, 99], dtype=dtype, device="mps")[1:]

        # The writes accumulate on the same tensors, so the order is kept.
        for pattern in index_patterns:
            x_cpu[pattern] = value_cpu
            x_mps[pattern] = value_mps
            self.assertEqual(x_cpu, x_mps)

    fixed_input = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                                 [0.5, 0.6, 0.7, 0.8],
                                 [0.9, 1.0, 1.1, 1.2],
                                 [1.3, 1.4, 1.5, 1.6]],

                                [[2.0, 2.1, 2.2, 2.3],
                                 [2.4, 2.5, 2.6, 2.7],
                                 [2.8, 2.9, 3.0, 3.1],
                                 [3.2, 3.3, 3.4, 3.5]],

                                [[4.0, 4.1, 4.2, 4.3],
                                 [4.4, 4.5, 4.6, 4.7],
                                 [4.8, 4.9, 5.0, 5.1],
                                 [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
    check(fixed_input)

    # torch.randn / torch.rand don't work with all dtypes, so the data is
    # generated with NumPy for each dtype and then moved to torch.
    for np_dtype, torch_dtype in zip(self.supported_np_dtypes, self.supported_dtypes):
        raw = np.random.random_sample(size=[3, 4, 4]).astype(np_dtype)
        check(torch.tensor(raw, device='cpu', dtype=torch_dtype))
11119
def test_index_put_with_view_indices(self):
    """Accumulating index_put_ with index tensors taken from a transposed tensor."""
    def accumulate_on(dev, dtype):
        target = torch.zeros([5, 3], device=dev, dtype=dtype)
        indices = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device=dev)
        value = torch.ones(indices.shape[0], device=dev, dtype=dtype)
        # Iterating the transposed tensor yields one index tensor per dimension.
        target.index_put_(tuple(indices.t()), value, accumulate=True)
        return target

    for dtype in (torch.int32, torch.float):
        on_cpu = accumulate_on("cpu", dtype)
        on_mps = accumulate_on("mps", dtype)
        self.assertEqual(on_cpu, on_mps)
11137
11138 # tests from 'test_indexing.py'
def test_advancedindex_big(self, device="mps"):
    """Gather a few scattered positions out of a 123344-element arange."""
    reference = torch.arange(0, 123344, dtype=torch.int, device=device)
    picks = [0, 123, 44488, 68807, 123343]
    # An arange holds its own index at every position, so the gathered values
    # equal the indices themselves.
    expected = torch.tensor(picks, dtype=torch.int)

    self.assertEqual(reference[picks, ], expected)
11144
def test_set_item_to_scalar_tensor(self, device="mps"):
    """Gradient of a scalar tensor broadcast into a whole column."""
    rows = random.randint(1, 10)
    cols = random.randint(1, 10)
    matrix = torch.randn([rows, cols], device=device)
    scalar_value = 1.0
    scalar = torch.tensor(scalar_value, requires_grad=True, device=device)

    # The scalar is written into every row of column 0, so summing the matrix
    # gives it one unit of gradient per row.
    matrix[:, 0] = scalar
    matrix.sum().backward()
    self.assertEqual(scalar.grad, rows * scalar_value)
11154
def test_single_int(self, device="mps"):
    """Indexing with one integer drops the leading dimension."""
    values = torch.randn(5, 7, 3, device=device)
    last_slab = values[4]
    self.assertEqual(last_slab.shape, (7, 3))
11158
def test_multiple_int(self, device="mps"):
    """Each integer index removes the dimension it is applied to."""
    values = torch.randn(5, 7, 3, device=device)
    shape_checks = (
        (values[4], (7, 3)),
        (values[4, :, 1], (7,)),
    )
    for indexed, expected_shape in shape_checks:
        self.assertEqual(indexed.shape, expected_shape)
11163
def test_none(self, device="mps"):
    """Indexing with None inserts a size-1 dimension at that position."""
    values = torch.randn(5, 7, 3, device=device)
    shape_checks = (
        (values[None], (1, 5, 7, 3)),
        (values[:, None], (5, 1, 7, 3)),
        (values[:, None, None], (5, 1, 1, 7, 3)),
        (values[..., None], (5, 7, 3, 1)),
    )
    for indexed, expected_shape in shape_checks:
        self.assertEqual(indexed.shape, expected_shape)
11170
def test_step(self, device="mps"):
    """Strided slices of an arange return the expected elements."""
    sequence = torch.arange(10, device=device)
    # A step of one returns every element.
    self.assertEqual(sequence[::1], sequence)

    list_checks = (
        (sequence[::2], [0, 2, 4, 6, 8]),
        (sequence[::3], [0, 3, 6, 9]),
        (sequence[::11], [0]),
        (sequence[1:6:2], [1, 3, 5]),
    )
    for sliced, expected in list_checks:
        self.assertEqual(sliced.tolist(), expected)
11178
def test_step_assignment(self, device="mps"):
    """Assigning through a strided slice touches only the selected elements."""
    grid = torch.zeros(4, 4, device=device)
    grid[0, 1::2] = torch.tensor([3., 4.], device=device)

    first_row, other_rows = grid[0], grid[1:]
    self.assertEqual(first_row.tolist(), [0, 3, 0, 4])
    # The rows that were not assigned must still be all zeros.
    self.assertEqual(other_rows.sum(), 0)
11184
def test_bool_indices(self, device="mps"):
    """Index with bool masks and check uint8 masks give the same result.

    The two uint8-mask indexing operations are expected to produce exactly
    two recorded warnings.
    """
    v = torch.randn(5, 7, 3, device=device)
    boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
    self.assertEqual(v[boolIndices].shape, (3, 7, 3))
    self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))

    v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
    boolIndices = torch.tensor([True, False, False], dtype=torch.bool, device=device)
    uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of ambient filters, so the count
        # below does not depend on how the test runner configured warnings.
        warnings.simplefilter("always")
        self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
        self.assertEqual(v[boolIndices], v[uint8Indices])
        self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
        self.assertEqual(len(w), 2)
11199
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_bool_indices_accumulate(self, device="mps"):
    """Accumulating through an all-False bool mask leaves the tensor unchanged."""
    # Comparing zeros with > 0 yields a bool mask that selects nothing.
    nothing_selected = torch.zeros(size=(10, ), dtype=torch.uint8, device=device) > 0
    target = torch.ones(size=(10, 10), device=device)
    selected_rows = target[nothing_selected]
    target.index_put_((nothing_selected, ), selected_rows, accumulate=True)
    self.assertEqual(target, torch.ones(size=(10, 10), device=device))
11207
def test_multiple_bool_indices(self, device="mps"):
    """Two bool masks separated by a slice give a (3, 7) result."""
    values = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    first_dim_mask = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
    last_dim_mask = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
    picked = values[first_dim_mask, :, last_dim_mask]
    self.assertEqual(picked.shape, (3, 7))
11214
def test_byte_mask(self, device="mps"):
    """Index with a uint8 mask and check the result and the warning count.

    The two uint8-mask indexing operations are expected to produce exactly
    two recorded warnings.
    """
    v = torch.randn(5, 7, 3, device=device)
    mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of ambient filters, so the count
        # below does not depend on how the test runner configured warnings.
        warnings.simplefilter("always")
        self.assertEqual(v[mask].shape, (3, 7, 3))
        self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
        self.assertEqual(len(w), 2)

    # A comparison mask that selects nothing yields an empty tensor.
    v = torch.tensor([1.], device=device)
    self.assertEqual(v[v == 0], torch.tensor([], device=device))
11225
def test_byte_mask_accumulate(self, device="mps"):
    """Accumulating through an all-zero uint8 mask is a no-op that records two warnings."""
    byte_mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
    target = torch.ones(size=(10, 10), device=device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # The mask is used twice: once to read and once to write.
        selected_rows = target[byte_mask]
        target.index_put_((byte_mask, ), selected_rows, accumulate=True)
        self.assertEqual(target, torch.ones(size=(10, 10), device=device))
        self.assertEqual(len(caught), 2)
11234
def test_index_put_accumulate_expanded_values(self, device="mps"):
    """Accumulating index_put_ with values that must be expanded to the index shape."""
    def compare(cpu_target, cpu_indices, value_variants):
        dev_target = cpu_target.to(device)
        dev_indices = [index.to(device) for index in cpu_indices]
        # Both targets are updated in place, so later variants accumulate
        # on top of earlier ones on each side.
        for values in value_variants:
            out_mps = dev_target.index_put_(dev_indices, values.to(device), accumulate=True)
            out_cpu = cpu_target.index_put_(cpu_indices, values, accumulate=True)
            self.assertEqual(out_mps.cpu(), out_cpu)

    # 2-D target: a 0-d value and a 1-element 1-d value.
    compare(
        torch.zeros((5, 2)),
        [torch.tensor([0, 1, 2, 3]), torch.tensor([1, ])],
        [torch.tensor(1.0), torch.tensor([1.0, ])],
    )

    # 3-D target with broadcasting indices: a 1-d value and a 2-d value.
    compare(
        torch.zeros(4, 3, 2),
        [torch.tensor([0, ]), torch.arange(3)[:, None], torch.arange(2)[None, :]],
        [torch.tensor([-1.0, -2.0]), torch.tensor([[-1.0, -2.0], ])],
    )
11273
def test_index_put_accumulate_non_contiguous(self, device="mps"):
    """Accumulating index_put_ into a non-contiguous slice, on CPU and on device."""
    base_cpu = torch.zeros((5, 2, 2))
    base_dev = base_cpu.to(device)
    # Selecting index 0 of the middle dimension gives non-contiguous views.
    view_dev = base_dev[:, 0, :]
    view_cpu = base_cpu[:, 0, :]
    for view in (view_dev, view_cpu):
        self.assertFalse(view.is_contiguous())

    cpu_indices = [torch.tensor([0, 1]), ]
    dev_indices = [index.to(device) for index in cpu_indices]
    value = torch.randn(2, 2)
    out_mps = view_dev.index_put_(dev_indices, value.to(device), accumulate=True)
    out_cpu = view_cpu.index_put_(cpu_indices, value, accumulate=True)
    # The in-place update must not change the views' layout.
    for view in (view_dev, view_cpu):
        self.assertFalse(view.is_contiguous())

    self.assertEqual(out_mps.cpu(), out_cpu)
11291
def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
    """Accumulating index_put_ where the first index entry is None.

    The values are moved to ``device`` rather than to a hard-coded "mps",
    so the parameter is honoured throughout.
    """
    # TODO: replace with a better solution.
    # Currently, here using torchscript to put None into indices.
    # on C++ it gives indices as a list of 2 optional tensors: first is null and
    # the second is a valid tensor.
    @torch.jit.script
    def func(x, i, v):
        idx = [None, i]
        x.index_put_(idx, v, accumulate=True)
        return x

    n = 4
    t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
    t_dev = t.to(device)
    indices = torch.tensor([1, 0])
    indices_dev = indices.to(device)
    value0d = torch.tensor(10.0)
    value1d = torch.tensor([1.0, 2.0])

    # func updates its target in place, so the second value accumulates
    # on top of the first on both sides.
    for value in (value0d, value1d):
        out_mps = func(t_dev, indices_dev, value.to(device))
        out_cpu = func(t, indices, value)
        self.assertEqual(out_mps.cpu(), out_cpu)
11318
def test_index_put_accumulate_duplicate_indices(self, device="mps"):
    """Accumulating index_put with many duplicate indices matches a Python loop.

    All tensors are created on ``device`` rather than a hard-coded "mps".
    """
    for length in range(1, 128):
        # generate indices by random walk, this will create indices with
        # lots of duplicates interleaved with each other
        delta = torch.empty(length, dtype=torch.float32, device=device).uniform_(-1, 1)
        indices = delta.cumsum(0).long()

        # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it.
        # The size is turned into a plain int instead of being passed as a tensor;
        # max(|index|) + 1 keeps negative indices in range as well.
        size = int(indices.cpu().abs().max()) + 1
        base = torch.randn(size, device=device)
        values = torch.randn(indices.size(0), device=device)
        output = base.index_put((indices,), values, accumulate=True)

        # Reference result: apply the updates one at a time in Python.
        expected = base.tolist()
        for index, value in zip(indices.tolist(), values.tolist()):
            expected[index] += value

        self.assertEqual(output, expected)
11339
def test_index_put_deterministic(self, device="mps"):
    """index_put_ with duplicate indices under each accumulate/deterministic setting."""
    def run_trials(dtype, accumulate, deterministic, num_tests=128):
        # One expectation for accumulating writes, one for plain writes.
        if accumulate:
            expected = torch.tensor([233, 187, 360], device=device, dtype=dtype)
        else:
            expected = torch.tensor([38, 37, 39], device=device, dtype=dtype)
        t_idx = torch.tensor(
            [0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 2,
             0, 0, 2, 1, 2, 1, 0, 0, 2, 0, 2, 1, 1, 2, 2, 0, 2, 1, 0, 2]
        )
        for _ in range(num_tests):
            try:
                torch.use_deterministic_algorithms(deterministic)
                target = torch.zeros(3, dtype=dtype, device=device)
                source = torch.arange(len(t_idx), device=device, dtype=dtype)
                target.index_put_((t_idx,), source, accumulate=accumulate)
                self.assertEqual(target, expected)
            finally:
                # Always restore the global flag, even if the assertion fails.
                torch.use_deterministic_algorithms(False)

    for accumulate, deterministic in product((False, True), (False, True)):
        dtype = torch.float if accumulate else torch.long
        if accumulate or deterministic:
            run_trials(dtype, accumulate, deterministic)
        else:
            # Plain writes without determinism are expected to mismatch at
            # least once over the trials.
            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not equal!"):
                run_trials(dtype, accumulate, deterministic)
11367
def test_multiple_byte_mask(self, device="mps"):
    """Two uint8 masks in one index give a (3, 7) result and two warnings."""
    values = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    first_dim_mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    last_dim_mask = torch.ByteTensor([1, 1, 1]).to(device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        picked = values[first_dim_mask, :, last_dim_mask]
        self.assertEqual(picked.shape, (3, 7))
        self.assertEqual(len(caught), 2)
11377
def test_byte_mask2d(self, device="mps"):
    """A 2-D mask over the first two dims leaves one row per selected position."""
    values = torch.randn(5, 7, 3, device=device)
    scores = torch.randn(5, 7, device=device)
    positive = scores > 0
    num_ones = positive.sum()
    selected = values[scores > 0]
    self.assertEqual(selected.shape, (num_ones, 3))
11384
def test_jit_indexing(self, device="mps"):
    """Scripted mask assignment and slice assignment give the same result."""
    def assign_by_mask(x):
        x[x < 50] = 1.0
        return x

    def assign_by_slice(x):
        x[0:50] = 1.0
        return x

    scripted_fns = [torch.jit.script(assign_by_mask), torch.jit.script(assign_by_slice)]
    data = torch.arange(100, device=device, dtype=torch.float)
    # First 50 entries become one, the rest keep their arange values.
    ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
    for scripted in scripted_fns:
        # Each function mutates its argument, so it gets a fresh copy.
        out = scripted(data.detach().clone())
        self.assertEqual(out, ref)
Denis Vieriuce4f1872022-09-28 00:47:52 +000011402
def test_int_indices(self, device="mps"):
    """Integer-list indices replace the indexed dimension with the index shape."""
    values = torch.randn(5, 7, 3, device=device)
    shape_checks = (
        (values[[0, 4, 2]], (3, 7, 3)),
        (values[:, [0, 4, 2]], (5, 3, 3)),
        (values[:, [[0, 1], [4, 3]]], (5, 2, 2, 3)),
    )
    for indexed, expected_shape in shape_checks:
        self.assertEqual(indexed.shape, expected_shape)
11408
def test_index_put_src_datatype(self):
    """Accumulating index_put_ keeps the target shape for float and int32."""
    def check(device, dtype):
        target = torch.ones(3, 2, 4, device=device, dtype=dtype)
        updates = torch.ones(3, 2, 4, device=device, dtype=dtype)
        permutation = (torch.tensor([0, 2, 1]),)
        result = target.index_put_(permutation, updates, accumulate=True)
        self.assertEqual(result.shape, target.shape)

    for dtype in (torch.float, torch.int32):
        check(device="mps", dtype=dtype)
11417
@unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_index_src_datatype(self):
    """List-index read and non-accumulating write keep the shape for several dtypes."""
    def check(device, dtype):
        wants_bool = dtype is torch.bool
        # Bool inputs are built as uint8 ones and then converted by comparison.
        storage_dtype = torch.uint8 if wants_bool else dtype

        src = torch.ones(3, 2, 4, device=device, dtype=storage_dtype)
        if wants_bool:
            src = src == 1
        # test index
        gathered = src[[0, 2, 1], :, :]
        self.assertEqual(gathered.shape, src.shape)
        # test index_put, no accum
        src[[0, 2, 1], :, :] = gathered
        self.assertEqual(gathered.shape, src.shape)

    for dtype in (torch.float, torch.float16, torch.long, torch.bool):
        check(device="mps", dtype=dtype)
11435
def test_int_indices2d(self, device="mps"):
    """Pick the four corners of a 4x3 matrix with 2-D index tensors."""
    # From the NumPy indexing example
    matrix = torch.arange(0, 12, device=device).view(4, 3)
    corner_rows = torch.tensor([[0, 0], [3, 3]], device=device)
    corner_cols = torch.tensor([[0, 2], [0, 2]], device=device)
    corners = matrix[corner_rows, corner_cols]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
11442
def test_int_indices_broadcast(self, device="mps"):
    """Pick the four corners using 1-D indices that broadcast against each other."""
    # From the NumPy indexing example
    matrix = torch.arange(0, 12, device=device).view(4, 3)
    row_picks = torch.tensor([0, 3], device=device)
    col_picks = torch.tensor([0, 2], device=device)
    # The row index gets a trailing axis so it broadcasts with the column index.
    corners = matrix[row_picks[:, None], col_picks]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
11450
def test_empty_index(self, device="mps"):
    """Empty index tensors and all-False masks select nothing and write nothing."""
    original = torch.arange(0, 12, device=device).view(4, 3)
    no_indices = torch.tensor([], dtype=torch.long, device=device)
    self.assertEqual(original[no_indices].numel(), 0)

    # empty assignment should have no effect but not throw an exception
    working = original.clone()
    working[no_indices] = -1
    self.assertEqual(original, working)

    all_false = torch.zeros(4, 3, device=device).bool()
    working[all_false] = -1
    self.assertEqual(original, working)
11464
def test_empty_ndim_index(self, device="mps"):
    """Multi-dimensional empty indices and indexing into zero-sized dimensions."""
    vector = torch.randn(5, device=device)
    empty_2d_index = torch.empty(0, 2, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(0, 2, device=device), vector[empty_2d_index])

    four_d = torch.randn(2, 3, 4, 5, device=device)
    empty_index = torch.empty(0, 6, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device), four_d[:, empty_index])

    zero_width = torch.empty(10, 0, device=device)
    self.assertEqual(zero_width[[1, 2]].shape, (2, 0))
    self.assertEqual(zero_width[[], []].shape, (0,))
    # Any concrete index into the size-0 dimension is out of range.
    with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
        zero_width[:, [0, 1]]
11478
def test_empty_ndim_index_bool(self, device="mps"):
    """A (0, 2) uint8 mask applied to a length-5 vector raises IndexError."""
    vector = torch.randn(5, device=device)
    with self.assertRaises(IndexError):
        vector[torch.empty(0, 2, dtype=torch.uint8, device=device)]
11482
def test_empty_slice(self, device="mps"):
    """An empty slice of a view keeps the expected shape, strides and contiguity."""
    four_d = torch.randn(2, 3, 4, 5, device=device)
    column = four_d[:, :, :, 1]
    empty = column[:, 1:1, :]
    self.assertEqual((2, 0, 4), empty.shape)
    # this isn't technically necessary, but matches NumPy stride calculations.
    self.assertEqual((60, 20, 5), empty.stride())
    self.assertTrue(empty.is_contiguous())
11491
def test_index_getitem_copy_bools_slices(self, device="mps"):
    """Which scalar-like indices return new storage and which keep the data pointer."""
    true = torch.tensor(1, dtype=torch.uint8, device=device)
    false = torch.tensor(0, dtype=torch.uint8, device=device)

    tensors = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]

    for a in tensors:
        # A true-valued index (Python bool or uint8 scalar) gives a new data pointer.
        for truthy in (True, true):
            self.assertNotEqual(a.data_ptr(), a[truthy].data_ptr())
        # A false-valued index gives an empty result with a leading 0 dimension.
        for falsy in (False, false):
            self.assertEqual(torch.empty(0, *a.shape), a[falsy])
        # None and Ellipsis keep the same data pointer.
        for same_storage_index in (None, ...):
            self.assertEqual(a.data_ptr(), a[same_storage_index].data_ptr())
11505
def test_index_setitem_bools_slices(self, device="mps"):
    """Assign through bool, uint8-scalar, None and Ellipsis indices in sequence."""
    true = torch.tensor(1, dtype=torch.uint8, device=device)
    false = torch.tensor(0, dtype=torch.uint8, device=device)

    tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]

    for a in tensors:
        # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
        # (some of these ops already prefix a 1 to the size)
        neg_ones = torch.ones_like(a) * -1
        neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)

        # (index, assigned value, expected content afterwards), applied in order.
        # A false-valued index must leave the tensor as the previous step left it.
        steps = (
            (True, neg_ones_expanded, neg_ones),
            (False, 5, neg_ones),
            (true, neg_ones_expanded * 2, neg_ones * 2),
            (false, 5, neg_ones * 2),
            (None, neg_ones_expanded * 3, neg_ones * 3),
            (..., neg_ones_expanded * 4, neg_ones * 4),
        )
        for index, value, expected in steps:
            a[index] = value
            self.assertEqual(a, expected)

        if a.dim() == 0:
            # A 0-d tensor cannot be sliced.
            with self.assertRaises(IndexError):
                a[:] = neg_ones_expanded * 5
11532
def test_index_scalar_with_bool_mask(self, device="mps"):
    """0-d uint8 and bool masks index a 0-d tensor identically."""
    uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
    boolMask = torch.tensor(True, dtype=torch.bool, device=device)

    scalars = (
        torch.tensor(1, device=device),
        torch.tensor(True, dtype=torch.bool, device=device),
    )
    for scalar in scalars:
        self.assertEqual(scalar[uintMask], scalar[boolMask])
        self.assertEqual(scalar[uintMask].dtype, scalar[boolMask].dtype)
11543
def test_setitem_expansion_error(self, device="mps"):
    """Assigning a value with a non-1 leading prefix through a True index fails."""
    true = torch.tensor(True, device=device)
    a = torch.randn(2, 3, device=device)
    # check prefix with non-1s doesn't work
    a_expanded = a.expand(torch.Size([5, 1]) + a.size())
    # NumPy: ValueError
    for truthy in (True, true):
        with self.assertRaises(RuntimeError):
            a[truthy] = a_expanded
11554
def test_getitem_scalars(self, device="mps"):
    """0-d integer tensors behave like Python ints when used as indices."""
    zero = torch.tensor(0, dtype=torch.int64, device=device)
    one = torch.tensor(1, dtype=torch.int64, device=device)

    # non-scalar indexed with scalars
    a = torch.randn(2, 3, device=device)
    equivalent_pairs = (
        (a[0], a[zero]),
        (a[0][1], a[zero][one]),
        (a[0, 1], a[zero, one]),
        (a[0, one], a[zero, 1]),
    )
    for with_ints, with_tensors in equivalent_pairs:
        self.assertEqual(with_ints, with_tensors)

    # indexing by a scalar should slice (not copy)
    self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
    for narrowed_one in (one.int(), one.short()):
        self.assertEqual(a[1].data_ptr(), a[narrowed_one].data_ptr())

    # scalar indexed with scalar
    r = torch.randn((), device=device)
    with self.assertRaises(IndexError):
        r[:]
    with self.assertRaises(IndexError):
        r[zero]
    self.assertEqual(r, r[...])
11578
def test_setitem_scalars(self, device="mps"):
    """Check that assignment through a 0-dim integer tensor index matches assignment through an int."""
    # NOTE: this index tensor is created without a device argument.
    idx0 = torch.tensor(0, dtype=torch.int64)

    # non-scalar indexed with scalars
    base = torch.randn(2, 3, device=device)
    via_int = base.clone()
    via_tensor = base.clone()
    row = torch.randn(3, device=device)

    via_int[0] = row
    via_tensor[idx0] = row
    self.assertEqual(via_int, via_tensor)
    base[1, idx0] = 7.7
    self.assertEqual(7.7, base[1, 0])

    # scalar indexed with scalars
    scalar = torch.randn((), device=device)
    # A full slice and an integer index are both rejected on a 0-dim tensor.
    for bad_index in (slice(None), idx0):
        with self.assertRaises(IndexError):
            scalar[bad_index] = 8.8
    scalar[...] = 9.9
    self.assertEqual(9.9, scalar)
11602
def test_basic_advanced_combined(self, device="mps"):
    """Mix a slice with a list index: reads give a copy, assignment writes through."""
    # From the NumPy indexing example
    x = torch.arange(0, 12, device=device).view(4, 3)
    rows, cols = slice(1, 2), [1, 2]
    self.assertEqual(x[1:2, 1:3], x[rows, cols])
    self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])

    # Check that it is a copy
    snapshot = x.clone()
    x[rows, cols].zero_()
    self.assertEqual(x, snapshot)

    # But assignment should modify the original
    snapshot = x.clone()
    x[rows, cols] = 0
    self.assertNotEqual(x, snapshot)
11618
11619 def test_int_assignment(self, device="mps"):
11620 x = torch.arange(0, 4, device=device).view(2, 2)
11621 x[1] = 5
11622 self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
11623
11624 x = torch.arange(0, 4, device=device).view(2, 2)
11625 x[1] = torch.arange(5, 7, device=device)
11626 self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
11627
def test_byte_tensor_assignment(self, device="mps"):
    """Assign through a uint8 mask: exactly one warning is recorded and only masked rows change."""
    grid = torch.arange(0., 16, device=device).view(4, 4)
    row_mask = torch.ByteTensor([True, False, True, False]).to(device)
    new_row = torch.tensor([3., 4., 5., 6.], device=device)

    with warnings.catch_warnings(record=True) as caught:
        grid[row_mask] = new_row
        self.assertEqual(len(caught), 1)

    # Rows 0 and 2 are selected by the mask; rows 1 and 3 keep their original values.
    expected_rows = (
        new_row,
        torch.arange(4., 8, device=device),
        new_row,
        torch.arange(12., 16, device=device),
    )
    for row, expected in enumerate(expected_rows):
        self.assertEqual(grid[row], expected)
11641
def test_variable_slicing(self, device="mps"):
    """Use elements of an int tensor as slice bounds and compare with literal bounds."""
    x = torch.arange(0, 16, device=device).view(4, 4)
    start, stop = torch.IntTensor([0, 1]).to(device)
    self.assertEqual(x[start:stop], x[0:1])
11647
11648 def test_ellipsis_tensor(self, device="mps"):
11649 x = torch.arange(0, 9, device=device).view(3, 3)
11650 idx = torch.tensor([0, 2], device=device)
11651 self.assertEqual(x[..., idx].tolist(), [[0, 2],
11652 [3, 5],
11653 [6, 8]])
11654 self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2],
11655 [6, 7, 8]])
11656
11657 def test_invalid_index(self, device="mps"):
11658 x = torch.arange(0, 16, device=device).view(4, 4)
11659 self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
11660
Denis Vieriuce4f1872022-09-28 00:47:52 +000011661 def test_out_of_bound_index(self, device="mps"):
11662 x = torch.arange(0, 100, device=device).view(2, 5, 10)
11663 self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5])
11664 self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5])
11665 self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10',
11666 lambda: x[0, 1, 15])
11667 self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10',
11668 lambda: x[:, :, 12])
11669
11670 def test_zero_dim_index(self, device="mps"):
11671 x = torch.tensor(10, device=device)
11672 self.assertEqual(x, x.item())
11673
11674 def runner():
11675 print(x[0])
11676 return x[0]
11677
11678 self.assertRaisesRegex(IndexError, 'invalid index', runner)
11679
def test_cpu_indices(self, device="mps"):
    """Write to and read from a `device` tensor using an index tensor created without a device."""
    plain_idx = torch.tensor([0, 1])
    zeros = torch.zeros(2, device=device)
    x = torch.ones(10, device=device)
    x[plain_idx] = zeros  # index_put_

    expected = torch.ones(10, device=device)
    expected[:2] = 0
    self.assertEqual(x, expected, atol=0, rtol=0)

    gathered = x[plain_idx]  # index
    self.assertEqual(gathered, torch.zeros(2, device=device), atol=0, rtol=0)
11690
def test_nextafter(self, device="mps"):
    """Compare the direction of `torch.nextafter` on `device` with the CPU result."""
    for dtype in [torch.float16, torch.float32]:
        start = torch.tensor([1, -1, 0, 0, 2, -2], device=device, dtype=dtype)
        toward = torch.tensor([2, -2, -1, 1, -3, 3], device=device, dtype=dtype)
        stepped = torch.nextafter(start, toward)
        stepped_cpu = torch.nextafter(start.cpu(), toward.cpu())
        # greater is broken on MPS, see https://github.com/pytorch/pytorch/issues/125051
        # so both comparisons are evaluated on CPU copies.
        moved_up = stepped.cpu() > start.cpu()
        moved_up_cpu = stepped_cpu > start.cpu()
        self.assertEqual(moved_up, moved_up_cpu)
11701
11702
class TestRNNMPS(TestCaseMPS):
    """LSTM and RNN-cell tests that compare results on the target device with CPU."""

    def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
                     seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
        """Run one `nn.LSTM` configuration on CPU and on `device` and compare the results.

        The forward output and the final hidden/cell states are always
        compared. When `backward` is true, the gradients with respect to the
        input, the initial states and every parameter are compared as well,
        for several choices of which outputs feed the differentiated scalar.
        """
        rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            bidirectional=bidirectional,
            batch_first=batch_first,
            device="cpu"
        )
        bidirectional_mul = 2 if bidirectional else 1

        # Only the input layout depends on batch_first; the state layout does not.
        if batch_first:
            input_shape = (batch_size, seq_len, input_size)
        else:
            input_shape = (seq_len, batch_size, input_size)
        state_shape = (num_layers * bidirectional_mul, batch_size, hidden_size)

        # Draw order (input, hx, cx) is kept so the random values are unchanged.
        input = torch.randn(*input_shape, device="cpu", dtype=dtype, requires_grad=backward)
        hx = torch.randn(*state_shape, device="cpu", dtype=dtype, requires_grad=backward)
        cx = torch.randn(*state_shape, device="cpu", dtype=dtype, requires_grad=backward)

        cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))

        rnn = rnn.to(device)
        input = input.to(device)
        hx = hx.to(device)
        cx = cx.to(device)
        output, (hn, cn) = rnn(input, (hx, cx))

        self.assertEqual(cpu_output, output)
        self.assertEqual(cpu_hn, hn)
        self.assertEqual(cpu_cn, cn)

        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
            """Move everything to `device`, run forward, and return the output and gradients."""
            rnn = rnn.to(device)
            inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)

            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"

            # Scalar to differentiate, built only from the requested outputs.
            f = 0
            if output_grad_presented:
                f = f + 3 * output.sum()
            if states_grad_presented:
                f = f + (hx_out * cx_out).sum()

            param_names, params = zip(*rnn.named_parameters())
            # retain_graph=True because the same graph is differentiated again below.
            param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))

            input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
            return output, param_grads, input_grad, hx_grad, cx_grad

        if backward:
            grad_cases = [
                dict(output_grad_presented=True, states_grad_presented=True),
                dict(output_grad_presented=False, states_grad_presented=True),
                dict(output_grad_presented=True, states_grad_presented=False),
            ]

            for grad_case in grad_cases:
                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
                    get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
                    get_backward_results(rnn, device, input, hx, cx, **grad_case)

                self.assertEqual(cpu_hx_grad, mps_hx_grad)
                self.assertEqual(cpu_cx_grad, mps_cx_grad)
                self.assertEqual(cpu_output, mps_output)
                self.assertEqual(cpu_input_grad, mps_input_grad)
                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")

    # Keyword-argument combinations passed to _lstm_helper by the LSTM tests.
    LSTM_TEST_CASES = [
        {},  # default
        dict(batch_first=True),
        dict(bias=False),
        dict(bidirectional=True),
        dict(batch_first=True, bias=False),
        dict(bidirectional=True, bias=False),
        dict(bidirectional=True, batch_first=True),
        dict(bidirectional=True, batch_first=True, bias=False)
    ]

    def test_lstm_forward(self, device="mps", dtype=torch.float32):
        """Forward-only comparison for every option combination and several depths."""
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)

    def test_lstm_backward(self, device="mps", dtype=torch.float32):
        """Forward and gradient comparison for every option combination and several depths."""
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)

    def test_RNN_cell_no_broadcasting(self):
        """RNN/GRU/LSTM cells must raise RuntimeError on mismatched shapes instead of broadcasting."""
        def test(cell_module, input, hx, input_size, hidden_size):
            cell = cell_module(input_size, hidden_size, device='mps')
            self.assertRaises(RuntimeError, lambda: cell(input, hx))

        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)

        hidden_size = 20
        input_size = 10
        input = torch.randn(3, input_size, device='mps')
        bad_hx = torch.randn(1, hidden_size, device='mps')
        good_hx = torch.randn(3, hidden_size, device='mps')

        # Test hidden/input batch size broadcasting
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test hx's hidden_size vs module's hidden_size broadcasting
        # Created on MPS so the error cannot come from a CPU/MPS device mismatch.
        bad_hx = torch.randn(3, 1, device='mps')
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test input's input_size vs module's input_size broadcasting
        # Created on MPS for the same reason as above.
        bad_input = torch.randn(3, 1, device='mps')
        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)

    def test_LSTM_cell(self):
        """Smoke test: repeated LSTMCell steps followed by a backward pass."""
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        for bias in (True, False):
            input = torch.randn(3, 10, device='mps')
            hx = torch.randn(3, 20, device='mps')
            cx = torch.randn(3, 20, device='mps')
            lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
            for _ in range(6):
                hx, cx = lstm(input, (hx, cx))

            (hx + cx).sum().backward()

    def test_LSTM_cell_forward_input_size(self):
        """An input whose feature size differs from the cell's input_size must raise."""
        input = torch.randn(3, 11, device='mps')
        hx = torch.randn(3, 20, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))

    def test_LSTM_cell_forward_hidden_size(self):
        """A state whose size differs from the cell's hidden_size must raise, in either position."""
        input = torch.randn(3, 10, device='mps')
        hx = torch.randn(3, 21, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
11858
11859
class TestFallbackWarning(TestCase):
    """Checks the behaviour of an op that is not implemented for MPS, with and without the CPU fallback."""

    # TODO: Remove once test_testing.py is running on MPS devices
    def test_no_warning_on_import(self):
        """`import torch` in a fresh interpreter with all warnings enabled must print nothing."""
        out = subprocess.check_output(
            [sys.executable, "-W", "always", "-c", "import torch"],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
        self.assertEqual(out, "")

    def _get_not_implemented_op(self):
        """Return `(fn, args, kwargs, source_string)` for an op that MPS does not implement."""
        # This can be changed once we actually implement 'lcm'
        # Should return fn, args, kwargs, string_version
        return (torch.lcm,
                [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
                "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")

    def test_error_on_not_implemented(self):
        """Without the fallback, calling the op must raise NotImplementedError."""
        fn, args, kwargs, _ = self._get_not_implemented_op()

        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
            fn(*args, **kwargs)

    def test_warn_on_not_implemented_with_fallback(self):
        """With PYTORCH_ENABLE_MPS_FALLBACK=1, the op must run and emit exactly one warning.

        The check runs in a subprocess because the environment variable has
        to be set before torch is imported. The script's exit code identifies
        which step failed: 1 for a warning during import, 2 for a warning
        count other than one when running the op.
        """
        _, _, _, op = self._get_not_implemented_op()
        script = f"""
import os
# MUST happen before pytorch's import
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import warnings

with warnings.catch_warnings(record=True) as w:
    import torch

if len(w) > 0:
    print(w)
    exit(1)

# This should run just fine and raise warning about perf
with warnings.catch_warnings(record=True) as w:
    {op}

if len(w) != 1:
    print(w)
    exit(2)
"""
        try:
            subprocess.check_output(
                [sys.executable, '-W', 'always', '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            # Decode once so every failure message shows readable text rather than a bytes repr.
            output = e.output.decode("utf-8")
            if e.returncode == 1:
                self.fail("There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set." +
                          output)
            elif e.returncode == 2:
                self.fail("There wasn't exactly one warning when running not implemented op with "
                          f"PYTORCH_ENABLE_MPS_FALLBACK set. {output}")
            else:
                self.fail("Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
                          output)
Kulin Sethe011a8e2022-05-13 18:28:53 +000011924
class TestNoRegression(TestCase):
    """Assorted checks on MPS tensors: assert_close, float64 errors, legacy constructor, serialization."""

    def test_assert_close(self):
        """`torch.testing.assert_close` must reject a finite value compared with infinity."""
        a = torch.ones(1, device="mps")
        b = torch.zeros(1, device="mps")
        pos_inf = a / b
        nan_value = b / b

        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            torch.testing.assert_close(a, pos_inf)

        # TODO: The NaN test is failing when all the tests in test_mps are run
        # together but passes when run separately. There seems to be memory
        # corruption which needs to be fixed for this test to be enabled.
        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
        #     torch.testing.assert_close(a, nan_value)

    def test_double_error(self):
        """Creating or converting to a float64 tensor on MPS must raise TypeError."""
        expected_message = "the MPS framework doesn't support float64"

        with self.assertRaisesRegex(TypeError, expected_message):
            torch.ones(2, dtype=torch.float64, device="mps")

        single = torch.ones(2, device="mps")
        with self.assertRaisesRegex(TypeError, expected_message):
            single.double()

    def test_legacy_constructor(self):
        """Smoke test: the legacy `Tensor.new` constructor can be called on an MPS tensor."""
        torch.ones(2, device="mps").new(1)

    def test_serialization_map_location(self):
        """Round-trip tensors through torch.save / torch.load across cpu and mps."""
        # Each entry: (kwargs for torch.rand, kwargs for torch.load, expected device type)
        round_trips = (
            # Ensures that cpu Tensor can be loaded on mps
            ({}, {"map_location": "mps"}, "mps"),
            # Ensures that mps Tensors can be loaded on mps
            ({"device": "mps"}, {}, "mps"),
            # Ensures that mps Tensors can be loaded on cpu
            ({"device": "mps"}, {"map_location": "cpu"}, "cpu"),
            # Ensures that `mps:0` Tensors can be loaded on mps
            ({"device": "mps:0"}, {"map_location": "mps:0"}, "mps"),
        )

        for rand_kwargs, load_kwargs, expected_device_type in round_trips:
            with tempfile.NamedTemporaryFile() as f:
                saved = torch.rand(2, **rand_kwargs)
                torch.save(saved, f)

                # Rewind so the load reads what was just written.
                f.seek(0)
                loaded = torch.load(f, **load_kwargs)

                self.assertEqual(saved, loaded)
                self.assertEqual(loaded.device.type, expected_device_type)
11999
Alban Desmaison0a651a22022-06-14 17:54:30 +000012000
# Dtypes used by the consistency tests: everything returned by
# get_all_dtypes() except the four dtypes removed here.
MPS_DTYPES = get_all_dtypes()
for t in (torch.double, torch.cdouble, torch.cfloat, torch.bfloat16):
    MPS_DTYPES.remove(t)

# Dtypes used by the gradient consistency test.
MPS_GRAD_DTYPES = [torch.float32, torch.float16]
12006
soulitzerbfdfeec2022-08-31 17:53:32 -040012007
class TestConsistency(TestCaseMPS):
    """Runs each op on CPU sample inputs and on MPS copies of them and compares the results."""

    # TODO: This is only used while some ops are being added.
    # This list should contain all ops and dtypes eventually
    # This can be generated automatically in the `new_mps_allowlist.txt` file
    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
    # You most likely do NOT want to modify this manually

    # Ops compared with (1e-2, 1e-2) tolerances when the dtype is float16.
    FP16_LOW_PRECISION_LIST = {
        'add', 'sub', 'div', 'addcdiv',
        '__rdiv__', '__rmul__',
        'nn.functional.huber_loss',
        'true_divide', 'kron',
        'gradient', 'var', 'std', 'std_mean', 'ldexp',
        'linalg.vector_norm', 'lerp',
        'addr', 'var_mean',
        'var_mean_unbiased',
        'acosh', 'asinh', 'asin',
        'masked.std',
        'nn.functional.normalize',
        'nn.functional.triplet_margin_loss',
        'nn.functional.triplet_margin_with_distance_loss',
        'nn.functional.batch_norm',
        'nn.functional.instance_norm',
        'round', 'xlogy', 'addcmul',
        'nn.functional.cross_entropy',
        'nn.functional.binary_cross_entropy',
        'nn.functional.nll_loss',
        'nn.functional.max_pool2d',
        'nn.functional.gelu',
        'nn.functional.glu',
        '_native_batch_norm_legit',
        '_batch_norm_with_update',
        'native_batch_norm',
        'softmax',
        '_softmax_backward_data',
        'log_softmax',
        'masked.softmax',
        'masked.log_softmax',
        'masked.softmin',
        'nn.functional.kl_div',
        'nn.functional.softmin',
        'cross', 'linalg.cross',
        'prod', 'masked.prod',
        'nextafter',
        'native_layer_norm',
        'nn.functional.layer_norm',
        'nn.functional.interpolate',
        'nn.functional.upsample_bilinear',
        'nn.functional.upsample_nearest',

        # for macOS 12
        'masked.normalize', 'masked.sum', 'masked.var',
        'outer',
        'sum_to_size', 'sum',
        'mul',
        'nansum', 'nanmean',
        'norm',
    }

    # Ops compared with (1e-4, 3e-5) tolerances when the dtype is float32 or complex64.
    FP32_LOW_PRECISION_LIST = {
        # conv2d and conv_transpose2d results have a very small
        # difference compared to CPU/CUDA, so we use lower precision on FP32
        'nn.functional.conv2d',
        'nn.functional.conv_transpose2d',
        'matmul', '__rmatmul__',
        'linalg.multi_dot',
        'addbmm',
    }

    def _compute_tolerances(self, op, dtype):
        """Return the `(atol, rtol)` override for `op` at `dtype`.

        The checks are evaluated in order and the first match wins.
        `(None, None)` is returned when no override applies.
        """
        if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype in [torch.float32, torch.complex64]:
            return (1e-4, 3e-5)

        if op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
            return (1e-2, 1e-2)

        if op.name in ['nn.functional.conv_transpose1d',
                       'nn.functional.conv_transpose2d',
                       'nn.functional.conv_transpose3d',
                       '__rmatmul__', 'addbmm', 'addmv',
                       'baddbmm', 'cov', 'matmul', 'mv'] and dtype == torch.float16:
            return (5e-2, 5e-2)
        if op.name == "masked.mean":
            return (7e-4, 2e-3)
        if op.name == "native_layer_norm":
            return (1e-4, 1.3e-5)
        if op.name in ["pow", "__rpow__"] and product_version < 13.3:
            # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
            # fixed in macOS 13.3+
            return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6)
        if op.name == "nn.functional.interpolate":
            return (1e-3, 1e-4)
        if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
            # TODO: Investigate why this is needed
            # See https://github.com/pytorch/pytorch/issues/120237
            return (3e-5, 3e-5)
        return (None, None)

    def _cpu_samples(self, op, device, dtype):
        """Return the op's sample inputs on `device`, requesting grad for float/complex dtypes."""
        return op.sample_inputs(
            device,
            dtype,
            requires_grad=(dtype.is_floating_point or dtype.is_complex),
            # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
            set_seed=False,
        )

    def _forward_cpu_and_mps(self, op, cpu_sample):
        """Run `op` on `cpu_sample` and on an MPS copy of it.

        Returns `(cpu_args, cpu_kwargs, mps_args, mps_kwargs, cpu_out, mps_out)`.
        """
        # Tensors are detached and copied to MPS, keeping their requires_grad flag;
        # non-tensor values are passed through unchanged.
        mps_sample = cpu_sample.transform(
            lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

        cpu_args = [cpu_sample.input] + list(cpu_sample.args)
        cpu_kwargs = cpu_sample.kwargs
        mps_args = [mps_sample.input] + list(mps_sample.args)
        mps_kwargs = mps_sample.kwargs

        # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
        if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
            mps_args[1] = cpu_args[1]

        cpu_out = op(*cpu_args, **cpu_kwargs)
        mps_out = op(*mps_args, **mps_kwargs)
        return cpu_args, cpu_kwargs, mps_args, mps_kwargs, cpu_out, mps_out

    # Used for accept mode only
    NEW_ALLOW_LIST = defaultdict(list)
    NEW_ALLOW_LIST_GRAD = defaultdict(list)

    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES + [torch.complex64])
    def test_output_match(self, device, dtype, op):
        """Forward outputs on MPS must match the CPU outputs within the op's tolerances."""
        self.assertEqual(device, "cpu")

        for cpu_sample in self._cpu_samples(op, device, dtype):
            #
            # Forward check
            #
            _, _, _, _, cpu_out, mps_out = self._forward_cpu_and_mps(op, cpu_sample)

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
                atol = 1.0
                rtol = 0.0

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
    def test_output_grad_match(self, device, dtype, op):
        """Forward outputs and input gradients on MPS must match CPU within the op's tolerances."""
        self.assertEqual(device, "cpu")

        def req_grad(t):
            return isinstance(t, torch.Tensor) and t.requires_grad

        for cpu_sample in self._cpu_samples(op, device, dtype):
            #
            # Forward check
            #
            cpu_args, cpu_kwargs, mps_args, mps_kwargs, cpu_out, mps_out = \
                self._forward_cpu_and_mps(op, cpu_sample)

            # Samples of `unique` with sorted=False are not compared.
            if op.name == "unique" and cpu_kwargs["sorted"] is False:
                continue

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name in ["renorm", "norm", "linalg.norm"] and dtype == torch.float16:
                atol = 7e-4
                rtol = 1.5e-3

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

            #
            # Backward check
            #
            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
            diff_cpu_arg = tuple(t for t in pytree.tree_leaves((cpu_args, cpu_kwargs)) if req_grad(t))
            diff_mps_arg = tuple(t for t in pytree.tree_leaves((mps_args, mps_kwargs)) if req_grad(t))
            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))

            if len(diff_cpu_out) == 0:
                continue
            # rand_like does not work with certain dtypes, so cast to double and cast back
            cpu_grad_outputs = tuple(torch.rand_like(t, dtype=torch.double).to(dtype=t.dtype) for t in diff_cpu_out)
            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)

            # Compare computed gradients with cpu given random grad_output vector
            # Sometimes when the derivative is 0, we just don't bother creating the graph
            # allow_unused is needed in those cases.
            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
Alex620dbc42022-10-21 19:03:00 +000012229
Li-Huai (Allan) Lin71aea7f2023-04-12 17:19:08 +000012230
class TestErrorInputs(TestCase):
    # Runs every OpInfo error sample on the MPS device and checks that the op
    # raises the exception type / message the sample declares.

    # NOTE(review): presumably consumed by the device-type test framework to
    # tolerate ops that are not implemented on MPS -- confirm against the framework.
    _ignore_not_implemented_error = True

    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
    def test_error_inputs(self, device, op):
        self.assertEqual(device, "mps:0")

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        mps_samples = op.error_inputs(device, set_seed=False)

        for mps_sample in mps_samples:
            mps_sample_input = mps_sample.sample_input
            error_type = mps_sample.error_type
            error_regex = mps_sample.error_regex

            # Positional arguments are the primary input followed by the extra args.
            mps_args = [mps_sample_input.input] + list(mps_sample_input.args)
            mps_kwargs = mps_sample_input.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only.
            # The length check keeps a sample without a second positional argument from
            # failing here with an unrelated IndexError before the op is even called.
            if op.name == "tensor_split" and len(mps_args) > 1 and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = mps_args[1].cpu()

            with self.assertRaisesRegex(error_type, error_regex):
                op(*mps_args, **mps_kwargs)
12255
class TestComplex(TestCase):
    # Binary arithmetic between MPS tensors and scalars of mixed real/complex types.

    def test_tensor_scalar_binops(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/119088
        def to_cpu(x):
            # Non-tensor operands (Python scalars, NotImplemented) pass through untouched.
            return x.cpu() if isinstance(x, torch.Tensor) else x

        # Allocate tensors on mps
        with torch.device("mps"):
            inputs = [torch.rand(2, dtype=dtype) for dtype in [torch.float, torch.half, torch.cfloat]]
        self.assertTrue(all(x.device.type == "mps" for x in inputs))
        # Add scalars. The 0-dim chalf tensor is created outside the device context above,
        # so it is allocated on the default device rather than explicitly on mps.
        inputs.extend([7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)])

        # Iterate over all ordered pairs of operands (int, float, complex, half) and ops (excluding div)
        for x, y in itertools.product(inputs, inputs):
            # The CPU copies depend only on the operand pair, so build them once
            # instead of once per operator.
            x_cpu, y_cpu = to_cpu(x), to_cpu(y)
            for op_name in ["__add__", "__sub__", "__mul__"]:
                res = getattr(x, op_name)(y)
                res_cpu = getattr(x_cpu, op_name)(y_cpu)
                self.assertEqual(to_cpu(res), res_cpu, f"{op_name}({x}, {y}) produces different results {res} vs {res_cpu}")
12276
Li-Huai (Allan) Lin71aea7f2023-04-12 17:19:08 +000012277
# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if not IS_CI:
            return

        # Assure no opinfo entry has dynamic_dtypes
        filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
        if filtered_ops:
            header = (
                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries. "
                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
            )
            details = [opinfo.utils.str_format_dynamic_dtype(op) for op in filtered_ops]
            # Raise explicitly rather than using `assert`, so the check still runs
            # when Python is started with -O (which strips assert statements).
            raise AssertionError("\n".join([header, *details]))

    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
    # MPS still requires some fairly heavy special casing in the test framework.
    # When MPS becomes more consistent, this can probably be merged with that test using
    # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
    @suppress_warnings
    # MPS only supports float32
    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
    def test_numpy_ref_mps(self, device, dtype, op):
        # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
        # does not support float64 Tensors.
        # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
        # get patched up and this workaround removed.
        broken_on_ref_inputs = op.name in ('where',)

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        inputs = (
            op.reference_inputs(device, dtype, set_seed=False) if not broken_on_ref_inputs
            else op.sample_inputs(device, dtype, set_seed=False)
        )
        for sample_input in inputs:
            self.compare_with_reference(op, op.ref, sample_input)

    @dtypes(*get_all_dtypes())
    def test_tensor_creation(self, device, dtype):
        def ones(device):
            return torch.ones((2, 2), dtype=dtype, device=device)

        # Dtypes accepted in addition to MPS_DTYPES; bfloat16 is only expected
        # to work when product_version is above 14.0.
        extra_dtypes = [torch.bfloat16, torch.complex64] if product_version > 14.0 else [torch.complex64]
        if dtype not in MPS_DTYPES + extra_dtypes:
            # Creating a tensor of an unsupported dtype must fail with TypeError.
            with self.assertRaises(TypeError):
                ones(device)
        else:
            mps_tensor = ones(device)
            cpu_tensor = ones("cpu")
            self.assertEqual(mps_tensor.cpu(), cpu_tensor)
12334
Nikita Shulga30610252024-05-03 15:20:39 +000012335
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
# to achieve this.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
# The remaining device-generic suites are instantiated for the "mps" device only.
instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
instantiate_parametrized_tests(TestMPS)

if __name__ == "__main__":
    run_tests()