# -*- coding: utf-8 -*-
# Owner(s): ["module: mps"]

import sys
import math
import random
import unittest
import warnings
import subprocess
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from functools import reduce  # used by TestAvgPool._avg_pool2d/_avg_pool3d below
from torch._six import inf
from torch.nn import Parameter
from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN
import torch.backends.mps
from torch.distributions import Uniform

from torch.testing._internal.common_nn import NNTestCase
import numpy as np

# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = object  # noqa: F811
    NNTestCase = object  # noqa: F811

class MPSReluTest(TestCase):
    def _npRelu(self, np_features):
        return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)

    def testNpRelu(self):
        torch.testing.assert_allclose(
            np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
            self._npRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                         0.9]])))

    def _testRelu(self, np_features, device):
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        py_relu = torch.nn.ReLU(inplace=False)(py_tensor)
        py_relu_cpu = py_relu.to("cpu")

        torch.testing.assert_allclose(np_relu, py_relu_cpu)

    def _testReluInPlace(self, np_features, device):
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        py_relu = torch.nn.ReLU(inplace=True)(py_tensor)
        py_relu_cpu = py_relu.to("cpu")

        torch.testing.assert_allclose(np_relu, py_relu_cpu)
        # In-place ReLU modifies the input tensor itself, so it should now match the ReLU output
        torch.testing.assert_allclose(np_relu, py_tensor.to("cpu"))

    def testNumbersCPU(self):
        for t in [np.int32]:
            # Force execution on CPU even if a GPU kernel is available for the type.
            self._testRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="cpu")
            self._testReluInPlace(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="cpu")

    def testNumbersGPU(self):
        for t in [np.float16, np.float32]:
            self._testRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="mps")
            self._testReluInPlace(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="mps")

class MatmulTest(TestCase):
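    # Compares torch.matmul on MPS against CPU; the expand_* shapes create
    # broadcasted, non-contiguous views to exercise matmul on expanded inputs.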
    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
        if expand_tensor_1_shape:
            tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
        else:
            tensor1_mps = torch.randn(shape_tensor_1, device="mps")

        if expand_tensor_2_shape:
            tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
        else:
            tensor2_mps = torch.randn(shape_tensor_2, device="mps")

        tensor1_cpu = tensor1_mps.to("cpu")
        tensor2_cpu = tensor2_mps.to("cpu")

        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)

        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))

    def test_vector_x_vector(self):
        # uses `dot`
        self._helper(3, 3)

    def test_matrix_x_vector(self):
        # uses `addmv`
        self._helper((3, 4), 4)

    def test_batched_matrix_x_broadcasted_vector(self):
        self._helper((10, 3, 4), 4)

    def test_batched_matrix_x_batched_matrix(self):
        # uses `bmm.out`
        self._helper((10, 3, 4), (10, 4, 5))

    def test_batched_matrix_x_broadcasted_matrix(self):
        self._helper((10, 3, 4), (4, 5))


class MPSLeakyReluTest(TestCase):
    def _npLeakyRelu(self, np_features, negative_slope=0.1):
        return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype)

    def testNpLeakyRelu(self):
        torch.testing.assert_allclose(
            np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
                      [0.1, -0.03, 0.5, -0.07, 0.9]]),
            self._npLeakyRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                         0.9]]),
                negative_slope=0.1))

    def _testLeakyRelu(self, np_features, negative_slope, device):
        cpu_x = torch.from_numpy(np_features).requires_grad_()
        mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_allclose(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass
        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')
        cpu_leaky_relu.backward(gradient=cpu_grad)
        mps_leaky_relu.backward(gradient=mps_grad)
        torch.testing.assert_allclose(cpu_x.grad, mps_x.grad.to('cpu'))

    def testNumbersCPU(self):
        for t in [np.float32]:
            self._testLeakyRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                negative_slope=0.2,
                device="cpu")


class TestAvgPool(TestCase):
    def _sum_pool2d(self, x, kernel_size):
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        # unfold does not support 3D sliding windows, so split the tensor along the
        # depth dimension and sum 2D pools over the slices
        h = kernel_size[0]
        splited_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # sum_pool2d assumes tensor in (1, 1, n, m) view, so unsqueeze two times
        splited_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in splited_x]
        joined_x = torch.cat(splited_x)
        return joined_x.view(1, joined_x.numel())

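    # Average pooling below is computed as sum pooling divided by the window
    # size, i.e. the product of the kernel dimensions.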
    def _avg_pool2d(self, x, kernel_size):
        size = reduce((lambda x, y: x * y), kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        size = reduce((lambda x, y: x * y), kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_avg_pool2d_with_zero_divisor(self):
        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))

    def test_doubletensor_avg_pool2d_with_divisor(self):
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertTrue(not torch.isnan(y).any())
        y = torch.nn.functional.avg_pool2d(
            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertTrue(not torch.isnan(y).any())


class TestMPS(TestCase):
    def test_exp(self, device="mps", dtype=torch.float):
        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
            b = torch.arange(18, device="cpu") / 3 * math.pi
            a = torch.tensor(v, dtype=dtype, device="cpu") * b
            a = a.to(dtype).to("mps")
            self.compare_with_numpy(torch.exp, np.exp, a)

    def test_exp1(self, device="mps", dtype=torch.float):
        input = torch.tensor([-0.1, 3.0, -0.9]).to('mps')
        output = torch.exp(input).to('cpu')
        # Verify against CPU; without this the test computed a result but asserted nothing.
        self.assertEqual(output, torch.exp(torch.tensor([-0.1, 3.0, -0.9])))

    def _testLeakyRelu(self, np_features, negative_slope, device):
        cpu_x = torch.from_numpy(np_features).requires_grad_()
        mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_allclose(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass
        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')
        cpu_leaky_relu.backward(gradient=cpu_grad)
        mps_leaky_relu.backward(gradient=mps_grad)
        torch.testing.assert_allclose(cpu_x.grad, mps_x.grad.to('cpu'))

    def testNumbersGPU(self):
        for t in [np.float32]:
            self._testLeakyRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                negative_slope=0.1,
                device="mps")

    def test_fill(self):

        def helper(val, shape):
            tensor = torch.zeros(shape, device='mps')
            tensor_mps = tensor.fill_(val)
            tensor_mps = torch.tanh(tensor_mps)

            tensor_0 = torch.zeros(shape, device='cpu')
            tensor_cpu = tensor_0.fill_(val)
            tensor_cpu = torch.tanh(tensor_cpu)

            self.assertEqual(tensor_mps, tensor_cpu)

        helper(0, [1024])
        helper(0.2, [2, 3])

    def test_mm(self):
        B = torch.ones(5, 6).to("mps")
        C = torch.ones(6, 5).to("mps")
        D = torch.mm(B, C).cpu()
        torch.testing.assert_allclose(D, torch.full((5, 5), 6.0))

    def test_addmm(self):
        A = torch.ones(5, 5).to("mps")
        B = torch.ones(5, 6).to("mps")
        C = torch.ones(6, 5).to("mps")
        D = torch.addmm(A, B, C).to("cpu")
        torch.testing.assert_allclose(D, torch.full((5, 5), 7.0))

    def test_bmm(self):
        batch1_cpu = torch.randn(10, 3, 4)
        batch2_cpu = torch.randn(10, 4, 5)

        batch1_mps = batch1_cpu.detach().clone().to("mps")
        batch2_mps = batch2_cpu.detach().clone().to("mps")

        output_cpu = torch.bmm(batch1_cpu, batch2_cpu)
        output_mps = torch.bmm(batch1_mps, batch2_mps)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_addbmm(self):
        M_cpu = torch.randn(3, 5)
        batch1_cpu = torch.randn(10, 3, 4)
        batch2_cpu = torch.randn(10, 4, 5)

        M_mps = M_cpu.detach().clone().to("mps")
        batch1_mps = batch1_cpu.detach().clone().to("mps")
        batch2_mps = batch2_cpu.detach().clone().to("mps")

        output_cpu = torch.addbmm(M_cpu, batch1_cpu, batch2_cpu)
        output_mps = torch.addbmm(M_mps, batch1_mps, batch2_mps)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_baddbmm(self):
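        # baddbmm computes beta * input + alpha * (batch1 @ batch2) per batch,
        # with input broadcasting against the (b, n, p) result.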
        def helper(input_shape, batch1_shape, batch2_shape):
            M_cpu = torch.randn(input_shape)
            batch1_cpu = torch.randn(batch1_shape)
            batch2_cpu = torch.randn(batch2_shape)
            alpha = 1.2
            beta = 0.8

            M_mps = M_cpu.detach().clone().to("mps")
            batch1_mps = batch1_cpu.detach().clone().to("mps")
            batch2_mps = batch2_cpu.detach().clone().to("mps")

            output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
            output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)

            self.assertEqual(output_cpu, output_mps)
            self.assertEqual(output_cpu.size(), output_mps.size())

        helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
        helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
        helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))

    def test_local_scalar_dense_mps(self):
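        # .item() dispatches to _local_scalar_dense, which copies a single
        # scalar from the MPS device back to the host.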
        x_cpu = torch.randn(1)
        y_mps = x_cpu.to("mps")
        torch.testing.assert_allclose(x_cpu.item(), y_mps.item())

    def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
        cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
        mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)

        # Use the same weights and bias as the ones from the cpu
        mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")

        if bias:
            mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")

        linear_mps_input = torch.randn(shape).to('mps')
        linear_cpu_input = linear_mps_input.detach().clone().to('cpu')

        if backward_pass:
            linear_mps_input = linear_mps_input.requires_grad_()
            linear_cpu_input = linear_cpu_input.requires_grad_()

        linear_cpu_output = cpu_linear(linear_cpu_input)
        linear_mps_output = mps_linear(linear_mps_input)

        self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu'))
        self.assertEqual(linear_cpu_output.size(), linear_mps_output.size())

        if backward_pass:
            cpu_grad = torch.ones_like(linear_cpu_output)
            grad = cpu_grad.to('mps')

            linear_cpu_output.backward(gradient=cpu_grad)
            linear_mps_output.backward(gradient=grad)

            self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size())
            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

            self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size())
            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
            if bias:
                self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size())
                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

    def test_linear2D(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=True, backward_pass=False)

    def test_linear2D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=True, backward_pass=True)

    def test_linear2D_no_bias(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=False, backward_pass=False)

    def test_linear2D_no_bias_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=False, backward_pass=True)

    def test_linear3D(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=True, backward_pass=False)

    def test_linear3D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=True, backward_pass=True)

    def test_linear3D_no_bias(self):
        # bias=False here; these two previously passed bias=True, duplicating the biased tests
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=False, backward_pass=False)

    def test_linear3D_no_bias_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=False, backward_pass=True)

    def test_uniform(self):
        low = torch.zeros(5, 5, requires_grad=True)
        high = (torch.ones(5, 5) * 3).requires_grad_()
        low_1d = torch.zeros(1, requires_grad=True)
        high_1d = (torch.ones(1) * 3).requires_grad_()
        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

        # Check log_prob computation when value outside range
        uniform = Uniform(low_1d, high_1d, validate_args=False)
        above_high = torch.tensor([4.0])
        below_low = torch.tensor([-1.0])
        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
        self.assertEqual(uniform.log_prob(below_low).item(), -inf)

        # check cdf computation when value outside range
        self.assertEqual(uniform.cdf(below_low).item(), 0)
        self.assertEqual(uniform.cdf(above_high).item(), 1)

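        # rsample is reparameterized as low + rand * (high - low), so
        # d(u)/d(low) = 1 - rand and d(u)/d(high) = rand; restoring the RNG
        # state below recovers the same `rand` draw for the gradient checks.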
        state = torch.get_rng_state()
        rand = low.new(low.size()).uniform_()
        torch.set_rng_state(state)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        self.assertEqual(low.grad, 1 - rand)
        self.assertEqual(high.grad, rand)
        low.grad.zero_()
        high.grad.zero_()

    # Test forward maxpool2d
    def test_max_pool2d(self):
        def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):

            cpu_x = None
            if(test_ties):
                cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True)
            else:
                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
                                      ceil_mode=ceil_mode, return_indices=return_indices)

            if(return_indices is False):
                y = pool(x)
                ref_y = pool(cpu_x)

                cpu_grad = torch.ones_like(ref_y)
                grad = cpu_grad.to('mps')

                y.backward(gradient=grad)
                ref_y.backward(gradient=cpu_grad)

                self.assertEqual(y, ref_y)
                self.assertEqual(x.grad, cpu_x.grad)
            else:
                y, idx = pool(x)
                ref_y, ref_idx = pool(cpu_x)

                cpu_grad = torch.ones_like(ref_y)
                grad = cpu_grad.to('mps')

                y.backward(gradient=grad)
                ref_y.backward(gradient=cpu_grad)

                self.assertEqual(y, ref_y)
                self.assertEqual(idx, ref_idx)
                self.assertEqual(x.grad, cpu_x.grad)

        # Test with no batch dimension
        helper((8, 4, 4), ks=2)
        helper((2, 8, 4, 4), ks=2)
        helper((1, 1000, 32, 32), ks=4)
        helper((1, 1000, 1, 4), ks=(1, 4))  # test for max_pool1d
        # Test padding
        helper((1, 1000, 32, 32), ks=4, padding=1)
        helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1))  # test for max_pool1d
        # Test dilation
        helper((1, 1000, 32, 32), ks=4, dilation=2)
        helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2))  # test for max_pool1d
        # Test ceil mode
        helper((1, 1000, 32, 32), ks=4, ceil_mode=True)
        helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True)  # test for max_pool1d

        # Test return indices
        for test_ties in [False, True]:
            # Test with no batch dimension
            helper((8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
            helper((2, 8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 32, 32), ks=4, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 1, 4), ks=(1, 4), return_indices=True, test_ties=test_ties)  # test for max_pool1d
            # Test padding
            helper((1, 1000, 32, 32), ks=4, padding=1, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1),
                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
            # Test dilation
            helper((1, 1000, 32, 32), ks=4, dilation=2, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2),
                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
            # Test ceil mode
            helper((1, 1000, 32, 32), ks=4, ceil_mode=True, return_indices=True, test_ties=test_ties)
            helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True,
                   return_indices=True, test_ties=test_ties)  # test for max_pool1d

    def test_adaptive_avg_pool2d_output_size_one(self):
        def helper(size, memory_format):
            x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
            if memory_format == 'non_contiguous':
                x = x[::2, ::2, ::2, ::2]
            else:
                x = x.to(memory_format=memory_format)

            net = torch.nn.AdaptiveAvgPool2d((1, 1))
            out = net(x)
            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))

            out.sum().backward()  # make sure it doesn't crash

            self.assertEqual(out, ref_out)
            if memory_format == torch.channels_last:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, c, c])
            else:
                self.assertTrue(out.is_contiguous())
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, 1, 1])

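        # Only the contiguous layout is exercised on MPS here; the channels_last
        # and 'non_contiguous' branches of helper() above are not yet covered.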
        helper((2, 3, 6, 6), torch.contiguous_format)

    def test_masked_fill(self):
        device = "mps"
        dtype = torch.float32
        mask_dtype = torch.bool
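        # With a bool mask no deprecation warning is expected; the uint8 branch
        # near the end of this test would verify the warning text otherwise.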

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            num_dest = 10
            dst = torch.zeros(num_dest, dtype=dtype, device=device)
            mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
            val = random.random()
            dst2 = torch.zeros(num_dest, dtype=dtype)
            mask_cpu = mask.to("cpu")

            dst.masked_fill_(mask, val)
            for i in range(num_dest):
                if mask_cpu[i]:
                    dst2[i] = val
            self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)

            # test non-contiguous case
            dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
            dst2 = dst.contiguous()
            if dtype.is_complex:
                mask = dst.abs() > 0
            else:
                mask = dst > 0
            self.assertTrue(not dst.is_contiguous())
            self.assertTrue(dst2.is_contiguous())
            dst.masked_fill_(mask.to(mask_dtype), val)
            dst2.masked_fill_(mask.to(mask_dtype), val)
            self.assertEqual(dst, dst2, atol=0, rtol=0)

            if mask_dtype == torch.uint8:
                self.assertEqual(len(w), 3)

                warn = 'masked_fill_ received a mask with dtype torch.uint8,'
                for wi in w:
                    self.assertEqual(str(wi.message)[0:52], str(warn))
            else:
                self.assertEqual(len(w), 0)

    def test_nhwc_operation(self):
        def helper(shape, channels_last=False):
            import numpy as np
            np.random.seed(332)
            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
            if(channels_last):
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            # This passes
            self.assertEqual(x, cpu_x)

        helper((2, 2, 2, 2), True)

    # Test forward batch norm
    def test_batch_norm(self):
        def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
                   track_running_stats=True, test_module=False):

            import numpy as np
            np.random.seed(332)
            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
            if(channels_last):
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            mean_shape = [shape[1]]
            cpu_running_mean = None
            cpu_running_var = None
            running_mean = None
            running_var = None
            if(track_running_stats):
                mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
                cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
                var_arr = 32 * np.random.random_sample(size=mean_shape)
                cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
                running_mean = cpu_running_mean.detach().clone().to('mps')
                running_var = cpu_running_var.detach().clone().to('mps')

            weight = None
            cpu_weight = None
            bias = None
            cpu_bias = None
            if(wts):
                cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
                cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            y = None
            ref_y = None

            if(not test_module):
                y = torch.nn.functional.batch_norm(x, running_mean, running_var,
                                                   weight=weight,
                                                   bias=bias,
                                                   training=training,
                                                   momentum=momentum, eps=eps)
                ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                       weight=cpu_weight,
                                                       bias=cpu_bias,
                                                       training=training,
                                                       momentum=momentum, eps=eps)

            else:

                batchnorm_op = None
                mps_batchnorm_op = None

                if(len(shape) == 3):
                    batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')
                elif(len(shape) == 4):
                    batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')
                elif(len(shape) == 5):
                    batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')

                if(track_running_stats):
                    batchnorm_op.running_mean = cpu_running_mean
                    batchnorm_op.running_var = cpu_running_var
                    mps_batchnorm_op.running_mean = running_mean
                    mps_batchnorm_op.running_var = running_var
                if(wts):
                    batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
                    batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
                    mps_batchnorm_op.weight = torch.nn.Parameter(weight)
                    mps_batchnorm_op.bias = torch.nn.Parameter(bias)

                ref_y = batchnorm_op(cpu_x)
                y = mps_batchnorm_op(x)

            self.assertEqual(y, ref_y)
            if(not test_module):
                self.assertEqual(running_mean, cpu_running_mean)
                self.assertEqual(running_var, cpu_running_var)
            else:
                self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
                self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)

            cpu_grad = torch.randn(ref_y.shape)
            grad = cpu_grad.to('mps')
            ref_y.backward(gradient=cpu_grad)
            y.backward(gradient=grad)

            self.assertEqual(x.grad, cpu_x.grad)
            if(wts):
                if(not test_module):
                    self.assertEqual(weight.grad, cpu_weight.grad)
                    self.assertEqual(bias.grad, cpu_bias.grad)
                else:
                    self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
                    self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)

        for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
            for test_module in [False, True]:
                for track_running_stats in [True, False]:
                    for channels_last in [False]:
                        if(channels_last and len(shape) != 4):
                            continue
                        # Running stats must be tracked in eval mode
                        if(track_running_stats):
                            helper(shape, eps=0, momentum=1, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
744 # Test forward instance norm
745 def test_instance_norm(self):
746 def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
747
748 import numpy as np
749 np.random.seed(332)
750 arr = (256 - 128) * np.random.random_sample(size=shape) + 128
751 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
752 if(channels_last):
753 cpu_x = cpu_x.to(memory_format=torch.channels_last)
754 cpu_x.retain_grad()
755 x = cpu_x.detach().clone().to('mps').requires_grad_()
756
757 mean_shape = [shape[1]]
758 cpu_running_mean = None
759 cpu_running_var = None
760 running_mean = None
761 running_var = None
762 if(track_running_stats):
763 mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
764 cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
765 var_arr = 32 * np.random.random_sample(size=mean_shape)
766 cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
767 running_mean = cpu_running_mean.detach().clone().to('mps')
768 running_var = cpu_running_var.detach().clone().to('mps')
769
770 weight = None
771 cpu_weight = None
772 bias = None
773 cpu_bias = None
774 if(wts):
775 cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
776 weight = cpu_weight.detach().clone().to('mps').requires_grad_()
777 cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
778 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
779
780 y = None
781 ref_y = None
782
783 if(not test_module):
784 ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
785 weight=cpu_weight,
786 bias=cpu_bias,
787 momentum=momentum, eps=eps)
788 y = torch.nn.functional.instance_norm(x, running_mean, running_var,
789 weight=weight,
790 bias=bias,
791 momentum=momentum, eps=eps)
792
793 else:
794
795 instancenorm_op = None
796 mps_instancenorm_op = None
797
798 if(len(shape) == 3):
799 instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
800 eps=eps,
801 momentum=momentum,
802 affine=wts,
803 track_running_stats=track_running_stats,
804 device='cpu')
805 mps_instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
806 eps=eps,
807 momentum=momentum,
808 affine=wts,
809 track_running_stats=track_running_stats,
810 device='mps')
811 elif(len(shape) == 4):
812 instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
813 eps=eps,
814 momentum=momentum,
815 affine=wts,
816 track_running_stats=track_running_stats,
817 device='cpu')
818 mps_instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
819 eps=eps,
820 momentum=momentum,
821 affine=wts,
822 track_running_stats=track_running_stats,
823 device='mps')
824 elif(len(shape) == 5):
825 instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
826 eps=eps,
827 momentum=momentum,
828 affine=wts,
829 track_running_stats=track_running_stats,
830 device='cpu')
831 mps_instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
832 eps=eps,
833 momentum=momentum,
834 affine=wts,
835 track_running_stats=track_running_stats,
836 device='mps')
837
838 if(track_running_stats):
839 instancenorm_op.running_mean = cpu_running_mean
840 instancenorm_op.running_var = cpu_running_var
841 mps_instancenorm_op.running_mean = running_mean
842 mps_instancenorm_op.running_var = running_var
843 if(wts):
844 instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
845 instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
846 mps_instancenorm_op.weight = torch.nn.Parameter(weight)
847 mps_instancenorm_op.bias = torch.nn.Parameter(bias)
848
849 ref_y = instancenorm_op(cpu_x)
850 y = mps_instancenorm_op(x)
851
852 self.assertEqual(y, ref_y)
853 if(not test_module):
854 self.assertEqual(running_mean, cpu_running_mean)
855 self.assertEqual(running_var, cpu_running_var)
856 else:
857 self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
858 self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)
859
860 cpu_grad = torch.randn(ref_y.shape)
861 grad = cpu_grad.to('mps')
862 ref_y.backward(gradient=cpu_grad)
863 y.backward(gradient=grad)
864
865 self.assertEqual(x.grad, cpu_x.grad)
866 if(wts):
867 if(not test_module):
868 self.assertEqual(weight.grad, cpu_weight.grad)
869 self.assertEqual(bias.grad, cpu_bias.grad)
870 else:
871 self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
872 self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)
873
874 for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
875 for test_module in [False, True]:
876 for track_running_stats in [True, False]:
877 for channels_last in [False]:
878 if(channels_last and len(shape) != 4):
879 continue
880 # Running stats must be tracked in eval mode
881 if(track_running_stats):
882 helper(shape, eps=0, momentum=1, channels_last=channels_last,
883 track_running_stats=track_running_stats, test_module=test_module)
884 helper(shape, channels_last=channels_last,
885 track_running_stats=track_running_stats, test_module=test_module)
886 helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
887 track_running_stats=track_running_stats, test_module=test_module)
888 helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
889 track_running_stats=track_running_stats, test_module=test_module)
890 helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
891 track_running_stats=track_running_stats, test_module=test_module)
892 helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
893 track_running_stats=track_running_stats, test_module=test_module)
894 helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
895 track_running_stats=track_running_stats, test_module=test_module)
896 helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
897 track_running_stats=track_running_stats, test_module=test_module)
898 helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
899 track_running_stats=track_running_stats, test_module=test_module)
900 helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
901 track_running_stats=track_running_stats, test_module=test_module)
902
    # Test conv2d
    def test_conv2d_unit(self):
        def helper(input_shape, wt_shape,
                   stride=1, padding=0,
                   dilation=1, groups=1,
                   bias_shape=None):

            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
            wt = cpu_wt.detach().clone().to('mps').requires_grad_()

            cpu_bias = None
            bias = None

            if(bias_shape is not None):
                cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            y = torch.nn.functional.conv2d(x, wt, bias=bias, stride=stride,
                                           padding=padding, dilation=dilation, groups=groups)
            ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, stride=stride,
                                               padding=padding, dilation=dilation, groups=groups)

            cpu_grad = torch.ones_like(ref_y)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
            self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
            self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
            if(bias_shape is not None):
                self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)

        N = 1
        C_in = 3
        C_out = 64
        H = 64
        W = 64
        kH = 4
        kW = 4
        stride = 2
        padding = 1

        helper((N, C_in, H, W), (C_out, C_in, kH, kW), stride=stride, padding=padding)

        N = 4
        C_in = 16
        H = 32
        W = 32

        C_out = 8
        kH = 3
        kW = 3

        for groups in [1, 2, 4]:
            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)
            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)

            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)
            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)

            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)
            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)

            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
                   kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
                   kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)

    # Test conv transpose 2d
    def test_conv_transpose2d(self):
        def helper(input_shape, wt_shape,
                   stride=1, padding=0,
                   output_padding=0,
                   dilation=1, groups=1,
                   bias_shape=None):

            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
            wt = cpu_wt.detach().clone().to('mps').requires_grad_()

            cpu_bias = None
            bias = None

            if(bias_shape is not None):
                cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            y = torch.nn.functional.conv_transpose2d(
                x, wt, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
            ref_y = torch.nn.functional.conv_transpose2d(
                cpu_x, cpu_wt, bias=cpu_bias, stride=stride, padding=padding,
                output_padding=output_padding, groups=groups, dilation=dilation)

            cpu_grad = torch.randn(ref_y.shape)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
            self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
            self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)

            # if(bias_shape is not None):
            #     print(cpu_bias.grad)
            #     print(bias.grad.to('cpu'))
            #     self.assertEqual(bias.grad, cpu_bias.grad)

        N = 4
        C_in = 2
        H = 32
        W = 32

        C_out = 8
        groups = 1
        kH = 3
        kW = 3

        for stride in [1, 2, 3]:
            for padding in [0, 1, 2]:
                for output_padding in [0, 1, 2]:
                    for dilation in [1, 2]:
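                        # conv_transpose2d requires output_padding to be smaller
                        # than either stride or dilation; skip invalid combinations.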
                        if(output_padding >= stride or output_padding >= dilation):
                            continue
                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
                               padding=padding, output_padding=output_padding, dilation=dilation)
                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
                               padding=padding, output_padding=output_padding, dilation=dilation)

                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
                               padding=padding, output_padding=output_padding, dilation=dilation)
                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
                               padding=padding, output_padding=output_padding, dilation=dilation)

    # Test sigmoid
    def test_sigmoid(self):
        def helper(shape):

            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            sigmoid_op = torch.nn.Sigmoid()

            y = sigmoid_op(x)
            ref_y = sigmoid_op(cpu_x)

            cpu_grad = torch.ones_like(ref_y)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 3, 4, 5))
        helper((2, 3, 4))
        helper((2, 8, 4, 5))

    # Test tanh
    def test_tanh(self):
        def helper(shape):

            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            tanh_op = torch.nn.Tanh()

            y = tanh_op(x)
            ref_y = tanh_op(cpu_x)

            cpu_grad = torch.ones_like(ref_y)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 3, 4, 5))
        helper((2, 3, 4))
        helper((2, 8, 4, 5))

    def test_threshold(self):
        def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
            m = nn.Threshold(threshold=threshold, value=value, inplace=inplace)

            input_cpu = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
            input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

            output_cpu = m(input_cpu)
            output_mps = m(input_mps)

            cpu_grad = torch.ones_like(output_cpu)
            mps_grad = cpu_grad.to('mps')

            self.assertEqual(output_cpu, output_mps)

            if requires_grad:
                output_cpu.backward(gradient=cpu_grad)
                output_mps.backward(gradient=mps_grad)

                self.assertEqual(input_cpu.grad, input_mps.grad)

        helper(threshold=0.1, value=20, num_elems=2)
        helper(threshold=-0.1, value=10, num_elems=10)
        helper(threshold=0.5, value=-15, num_elems=100)
        helper(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False)

    # Test pow
    def test_pow(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')
            z = torch.pow(x, y)
            ref_z = torch.pow(cpu_x, cpu_y)

            self.assertEqual(z, ref_z)

            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')
            exp = random.random()
            z = torch.pow(x, exp)
            ref_z = torch.pow(cpu_x, exp)

            self.assertEqual(z, ref_z)

        helper((2, 8, 4, 5))

    # Test addcmul
    def test_addcmul(self):
        def helper(shape, value):

            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            z = cpu_z.detach().clone().to('mps')

            y = torch.addcmul(x, y, z, value=value)
            ref_y = torch.addcmul(cpu_x, cpu_y, cpu_z, value=value)

            self.assertEqual(y, ref_y)

        helper((2, 3, 4, 5), 0.1)
        helper((2, 8, 4, 5), 0.1)
        helper((2, 3, 4, 5), 0.2)
        helper((2, 8, 4, 5), 0.2)

    # Test addcdiv
    def test_addcdiv(self):
        def helper(shape, value):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            # clamp to avoid division by 0
            cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False).clamp_min_(0.1)
            cpu_out = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)

            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            mps_z = cpu_z.detach().clone().to('mps')
            mps_out = cpu_out.detach().clone().to('mps')

            result_div_mps = torch.addcdiv(mps_x, mps_y, mps_z, value=value)
            result_div_cpu = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value)
            self.assertEqual(result_div_mps, result_div_cpu)
            # test .out variant
            self.assertEqual(torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value), result_div_cpu)

        helper((2, 3, 4, 5), 0.1)
        helper((2, 8, 4, 5), 0.2)
        helper((2, 3, 4, 5), 1.0)  # value of 1 should be ignored internally

    def test_buffer_size_match(self):
        # this test shouldn't cause any crash
        size = 16
        cpu_A = torch.rand(size, device='cpu')
        cpu_F = torch.rand(size, size, size, device='cpu')

        mps_A = cpu_A.to('mps')
        mps_F = cpu_F.to('mps')
        self.assertEqual(cpu_A @ cpu_F, mps_A @ mps_F)

    def test_transpose_inplace(self):
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_x.transpose_(0, 1)
        mps_x.transpose_(0, 1)
        self.assertEqual(cpu_x, mps_x.to('cpu'))

    def test_slice(self):
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = (torch.tensor(values, device='mps', dtype=torch.float))

        cpu_slice1 = cpu_x[:2, :]
        mps_slice1 = mps_x[:2, :]
        self.assertEqual(cpu_slice1, mps_slice1)

        cpu_slice2 = cpu_x[:, :1]
        mps_slice2 = mps_x[:, :1]
        self.assertEqual(cpu_slice2, mps_slice2)

        cpu_slice3 = cpu_x[1:2, :]
        mps_slice3 = mps_x[1:2, :]
        self.assertEqual(cpu_slice3, mps_slice3.to('cpu'))

        cpu_slice4 = cpu_x[1, :]
        mps_slice4 = mps_x[1, :].to('cpu')
        self.assertEqual(cpu_slice4, mps_slice4)

    def test_slice_contiguous_view(self):
        # https://github.com/pytorch/pytorch/issues/77750

        def helper(operator):
            t_mps = torch.tensor([1, 2, 3, 4], device="mps")
            t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")

            # contiguous view
            x_mps = t_mps[2:]  # 3, 4
            y_mps = t_mps[:2]  # 1, 2

            x_cpu = t_cpu[2:]
            y_cpu = t_cpu[:2]

            res_mps = res_cpu = None
            if operator == "<=":
                res_mps = x_mps <= y_mps
                res_cpu = x_cpu <= y_cpu
            if operator == "<":
                res_mps = x_mps < y_mps
                res_cpu = x_cpu < y_cpu
            if operator == ">=":
                res_mps = x_mps >= y_mps
                res_cpu = x_cpu >= y_cpu
            if operator == ">":
                res_mps = x_mps > y_mps  # was '>=', which silently duplicated the ">=" case
                res_cpu = x_cpu > y_cpu
            if operator == "==":
                res_mps = x_mps == y_mps
                res_cpu = x_cpu == y_cpu
            if operator == "!=":
                res_mps = x_mps != y_mps
                res_cpu = x_cpu != y_cpu

            self.assertEqual(res_mps, res_cpu)

        for op in ["<=", "<", ">=", ">", "==", "!="]:
            helper(op)

    def test_index_storage_offset(self):
        # https://github.com/pytorch/pytorch/issues/78107

        a = torch.tensor([8.2670e-01, -1.0293e+00])
        b_cpu = a[0]
        c_cpu = a[1]

        # both 'b' and 'c' are views of 'a'
        # 'b' has a storage offset of 0, while 'c' has a storage offset of 1
        # when copying from 'cpu' to 'mps', 'c' keeps its storage offset of 1, which must be
        # taken into account, otherwise it ends up with the same value as 'b'
        b = b_cpu.to('mps')
        c = c_cpu.to('mps')

        res_mps = b > c
        res_cpu = b_cpu > c_cpu
        self.assertEqual(res_mps, res_cpu)

        res_mps = c > b
        res_cpu = c_cpu > b_cpu
        self.assertEqual(res_mps, res_cpu)

    def test_flatten(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_flatten1 = cpu_x.flatten()
        mps_flatten1 = mps_x.flatten().to('cpu')
        self.assertEqual(cpu_flatten1, mps_flatten1)

        cpu_flatten2 = cpu_x.flatten(start_dim=1)
        mps_flatten2 = mps_x.flatten(start_dim=1).to('cpu')
        self.assertEqual(cpu_flatten2, mps_flatten2)

        cpu_flatten3 = cpu_x.flatten(end_dim=1)
        mps_flatten3 = mps_x.flatten(end_dim=1).to('cpu')
        self.assertEqual(cpu_flatten3, mps_flatten3)

    # Test repeat
    def test_repeat(self):
        def helper(shape, repeats):

            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            y = x.repeat(repeats)
            ref_y = cpu_x.repeat(repeats)

            cpu_grad = torch.randn(ref_y.shape)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 3, 4, 5), (2, 3, 4, 5))
        helper((2, 3, 4), (4, 3, 2, 5, 7, 2))
        helper((3, 4, 5), (2, 3, 4, 5))
        helper((3, 4, 5), (2, 2, 2))

    def test_count_nonzero(self):
        def helper(dtype):
            n = [
                [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
                [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
            ]
            cpu_x = torch.tensor(n, dtype=dtype)
            mps_x = torch.tensor(n, dtype=dtype).to('mps')

            # All non-zeros
            self.assertEqual(
                torch.count_nonzero(cpu_x),
                torch.count_nonzero(mps_x)
            )

            # dim=1
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=1),
                torch.count_nonzero(mps_x, dim=1)
            )

            # dim=(0, 1)
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=(0, 1)),
                torch.count_nonzero(mps_x, dim=(0, 1))
            )
        helper(torch.int32)
        helper(torch.int64)
        helper(torch.float16)
        helper(torch.float32)

    def _test_module_empty_input(self, module, inp, check_size=True):
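        # An empty input should still run forward/backward cleanly, yielding
        # zero gradients for the input and for every trainable parameter.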
        inp.requires_grad_(True)
        out = module(inp)
        gO = torch.rand_like(out)
        out.backward(gO)
        if check_size:
            self.assertEqual(out.size(), inp.size())
        for p in module.parameters():
            if p.requires_grad:
                self.assertEqual(p.grad, torch.zeros_like(p.grad))
        self.assertEqual(inp.grad, torch.zeros_like(inp))

    # Test dtype casting, with and without simultaneous device change
    def test_to(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        self.assertEqual(cpu_x.int(), mps_x.int().cpu())
        self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
        self.assertEqual(cpu_x.float(), mps_x.float().cpu())

        self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
                         torch.tensor(1, dtype=torch.int32))
        self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
                         torch.tensor(1, dtype=torch.int32))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
                         torch.tensor(1.0))
        self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
                         torch.tensor(4, dtype=torch.int32))
        self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
                         torch.tensor(4, dtype=torch.int32))
        self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                         torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))

    def test_setitem_scalar(self) -> None:
        device = 'mps'
        for dtype in [torch.int32, torch.float32, torch.int64]:
            for i in range(3, 6):
                for j in range(3, 6):
                    t = torch.zeros(i, j, dtype=dtype, device=device)
                    self.assertEqual(t.sum(), 0)
                    t[1, 1] = 1
                    t[2, 1] = j
                    t[1, 2] = i
                    self.assertEqual(t[1, 1], 1)
                    self.assertEqual(t[1, 2], i)
                    self.assertEqual(t[2, 1], j)
                    self.assertEqual(t.sum(), 1 + i + j)


class TestSmoothL1Loss(TestCase):

    def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
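        # smooth_l1(d) = 0.5 * d**2 / beta if |d| < beta else |d| - 0.5 * beta (beta=1.0 here)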
        # CPU
        input_cpu = torch.randn(4, 7, requires_grad=requires_grad)
        target_cpu = torch.randn(4, 7)

        # MPS
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        smooth_l1_loss_cpu = F.smooth_l1_loss(input_cpu, target_cpu, beta=1.0, reduction=reduction)
        smooth_l1_loss_mps = F.smooth_l1_loss(input_mps, target_mps, beta=1.0, reduction=reduction)

        self.assertEqual(smooth_l1_loss_cpu, smooth_l1_loss_mps)

        if requires_grad:
            smooth_l1_loss_cpu.backward()
            smooth_l1_loss_mps.backward()
            self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))

        return smooth_l1_loss_cpu, smooth_l1_loss_mps

    def test_smooth_l1_loss_reduction_none(self):
        self._smooth_l1_loss_helper(reduction="none")

    def test_smooth_l1_loss_reduction_mean(self):
        self._smooth_l1_loss_helper(reduction="mean")

    def test_smooth_l1_loss_reduction_sum(self):
        self._smooth_l1_loss_helper(reduction="sum")

    def test_smooth_l1_loss_reduction_mean_backward(self):
        self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)

    def test_smooth_l1_loss_reduction_mean_sum_backward(self):
        self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)


class TestNLLLoss(TestCase):

    def test_nll_loss_mismatched_batch(self, device='mps'):
        x = torch.randn((10, 3), requires_grad=True, device=device)
        # t should have size (10,)
        t = torch.zeros((3,), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
            F.nll_loss(x, t)

    def test_nll_loss_out_of_bounds_ignore_index(self):

        def _test_nll_loss_out_of_bounds_ignore_index(device):
1466 output = []
1467 x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
1468 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
1469 t = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
1470 for reduction in ['mean', 'none']:
1471 output.append(F.nll_loss(x, t, ignore_index=255, reduction=reduction))
1472 return output
1473
1474 output_cpu = _test_nll_loss_out_of_bounds_ignore_index(device='cpu')
1475 output_mps = _test_nll_loss_out_of_bounds_ignore_index(device='mps')
1476
1477 for cpu, mps in zip(output_cpu, output_mps):
1478 self.assertEqual(cpu, mps.to('cpu'))
1479
1480 def test_nll_loss_invalid_target_dim(self):
1481
1482 def _test_nll_loss_invalid_target_dim(device):
1483 output = []
1484 x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
1485 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
1486 t = torch.zeros((6, 2), dtype=torch.int64, device=device)
1487 with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
1488 F.nll_loss(x, t)
1489
1490 _test_nll_loss_invalid_target_dim(device='cpu')
1491 _test_nll_loss_invalid_target_dim(device='mps')
1492
1493 def test_nll_loss_invalid_weights(self):
1494
1495 def _test_nll_loss_invalid_weights(device):
1496 x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
1497 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
1498 t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
1499 invalid_weights = [
1500 torch.zeros(4, device=device),
1501 torch.zeros((1, 3), device=device),
1502 ]
1503 msg = "weight tensor should be defined either for all 3 classes or no classes"
1504 for weight in invalid_weights:
1505 with self.assertRaisesRegex(RuntimeError, msg):
1506 F.nll_loss(x, t, weight=weight)
1507
1508 _test_nll_loss_invalid_weights(device='cpu')
1509 _test_nll_loss_invalid_weights(device='mps')
1510
1511 def _nll_loss_helper(self, input_size, reduction, expected):
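        # `expected` is currently unused: results are validated below by comparing
        # the CPU and MPS outputs directly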
1512
1513 # CPU
1514 input = torch.rand(input_size, requires_grad=True, device='cpu')
1515 num_channels = input_size[1]
1516 target_size = (input_size[0], ) + tuple(input_size[2:])
1517 target = torch.randint(num_channels, target_size, device='cpu')
1518
1519 # MPS
1520 input_mps = input.detach().clone().to('mps').requires_grad_()
1521 target_mps = target.detach().clone().to('mps')
1522
1523 output_cpu = F.nll_loss(input, target, reduction=reduction)
1524 output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
1525 # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
1526 self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu'))
1527
1528 output_cpu.sum().backward()
1529 output_mps.sum().backward()
1530 self.assertEqual(input.grad, input_mps.grad.to('cpu'))
1531
1532 def test_as_strided(self):
1533 def helper(n, c):
1534 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
1535 values_1 = [[1.0, 1.0], [1.0, 1.0]]
1536 cpu_x = torch.tensor(values, device='cpu')
1537 ones1 = torch.tensor(values_1, device='mps')
1538 x = cpu_x.detach().clone().to('mps').requires_grad_()
1539            strided_cpu = torch.as_strided(cpu_x, (2, 2), (1, 2))
1540 strided_mps = torch.as_strided(x, (2, 2), (1, 2))
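            # with row-major storage [1..9], element (i, j) of the (2, 2) view reads flat
            # index i*1 + j*2, giving [[1., 3.], [2., 4.]]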
1541
1542 self.assertEqual(strided_mps, strided_cpu)
1543
1544 helper(3, 3)
1545
1546    def test_sum_backward(self):
1547 def helper(n, c):
1548 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
1549 cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
1550 x = cpu_x.detach().clone().to('mps').requires_grad_()
1551
1552 all_sum = torch.sum(x)
1553 all_sum_cpu = torch.sum(cpu_x)
1554
1555 all_sum.backward()
1556 all_sum_cpu.backward()
1557            self.assertEqual(all_sum, all_sum_cpu)
1558 self.assertEqual(x.grad, cpu_x.grad)
1559
1560 helper(3, 3)
1561
1562    def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
1563 self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
1564 self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))
1565 self._nll_loss_helper([2, 3, 1, 7], "none", torch.empty([2, 1, 7], device=device))
1566 self._nll_loss_helper([2, 3, 5, 1], "none", torch.empty([2, 5, 1], device=device))
1567 self._nll_loss_helper([2, 3, 5, 7, 1], "none", torch.empty([2, 5, 7, 1], device=device))
1568
1569 @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
1570 def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
1571 nan = torch.tensor(float('nan'), device=device)
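        # mean-reduction over an empty batch divides 0 by 0, so NaN is the expected value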
1572 self._nll_loss_helper([1, 3], "mean", nan)
1573 self._nll_loss_helper([1, 3, 5, 7], "mean", nan)
1574 self._nll_loss_helper([2, 3, 1, 7], "mean", nan)
1575 self._nll_loss_helper([2, 3, 5, 1], "mean", nan)
1576 self._nll_loss_helper([2, 3, 5, 7, 1], "mean", nan)
1577
1578 def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
1579 zero = torch.tensor(0, device=device)
1580 self._nll_loss_helper([1, 3], "sum", zero)
1581 self._nll_loss_helper([1, 3, 5, 7], "sum", zero)
1582 self._nll_loss_helper([2, 3, 1, 7], "sum", zero)
1583 self._nll_loss_helper([2, 3, 5, 1], "sum", zero)
1584 self._nll_loss_helper([2, 3, 5, 7, 1], "sum", zero)
1585
1586 def test_nll_loss_byte_target_matches_long(self, device='cpu'):
1587 N, C = 10, 4
1588 input = torch.randn(N, C, device=device, requires_grad=True)
1589 target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
1590
1591 def compute_result_and_gradient(reduction, target_dtype):
1592 result, grad = {}, {}
1593 for dev in ['cpu', 'mps']:
1594 input_dev = input.to(dev)
1595 input_ = input_dev.detach()
1596 input_.requires_grad_()
1597
1598 target_dev = target.to(dev)
1599
1600 prob = F.log_softmax(input_, dim=-1)
1601 loss = nn.NLLLoss(reduction=reduction)
1602 result[dev] = loss(prob, target_dev.to(target_dtype))
1603 result[dev].sum().backward()
1604 grad[dev] = input_.grad
1605
1606 return result, grad
1607
1608 for reduction in ["none", "mean", "sum"]:
1609 result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
1610 result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)
1611
1612 self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
1613 self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])
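            # also validate the byte-target path, so result_byte/grad_byte don't go unchecked
            self.assertEqual(result_byte['mps'].to('cpu'), result_byte['cpu'])
            self.assertEqual(grad_byte['mps'].to('cpu'), grad_byte['cpu'])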
1614
1615 # Mean Squared Error
1616 def test_mse_loss(self):
1617 def helper(shape, reduction):
1618 # create the criterion
1619 loss = torch.nn.MSELoss(reduction=reduction)
1620
1621 inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
1622 targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
1623 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
1624 targetMPS = targetCPU.detach().clone().to('mps')
1625
1626 # forward pass
1627 outputCPU = loss(inputCPU, targetCPU)
1628 outputMPS = loss(inputMPS, targetMPS)
1629 self.assertEqual(outputCPU, outputMPS)
1630
1631 # backward pass
1632 if reduction != 'none':
1633 # chose 2 just to make the grad_output > 1 in backward pass
1634 outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
1635 outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
1636 self.assertEqual(inputCPU.grad, inputMPS.grad)
1637
1638 helper([8, 5, 4], 'none')
1639 helper([7, 5, 2, 4], 'sum')
1640        # verify that changing the input shape does not break cached-graph lookup
1641 helper([7, 5, 2, 4, 6], 'sum')
1642 helper([8, 4, 5, 7, 6], 'mean')
1643
1644    # Binary Cross Entropy
1645 def test_bce_loss(self):
1646 def helper(shape, reduction):
1647 # create the criterion
1648 loss = torch.nn.BCELoss(reduction=reduction)
1649
1650 # input and target must be within [0..1]
1651 input_t = np.random.random_sample(size=shape).astype(np.float32)
1652 target_t = np.random.random_sample(size=shape).astype(np.float32)
1653 inputCPU = torch.tensor(input_t, device='cpu', dtype=torch.float, requires_grad=True)
1654 targetCPU = torch.tensor(target_t, device='cpu', dtype=torch.float, requires_grad=False)
1655 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
1656 targetMPS = targetCPU.detach().clone().to('mps')
1657
1658 # forward pass
1659 outputCPU = loss(inputCPU, targetCPU)
1660 outputMPS = loss(inputMPS, targetMPS)
1661 self.assertEqual(outputCPU, outputMPS)
1662
1663 # backward pass
1664 if reduction != 'none':
1665 # chose 0.6 just to have the grad_output != 1
1666 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
1667 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
1668 self.assertEqual(inputCPU.grad, inputMPS.grad)
1669
1670 helper([8, 5, 4], 'none')
1671 helper([7, 5, 2, 4], 'sum')
1672        # verify that changing the input shape does not break cached-graph lookup
1673 helper([7, 5, 2, 4, 6], 'sum')
1674 helper([8, 4, 5, 7, 6], 'mean')
1675
1676 def test_log_softmax(self):
1677 values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
1678 cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
1679 mps_x = torch.tensor(values, device='mps', requires_grad=True)
1680
1681 cpu_log_softmax = F.log_softmax(cpu_x, dim=0)
1682 mps_log_softmax = F.log_softmax(mps_x, dim=0)
1683 self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))
1684
1685 cpu_grad = torch.ones_like(cpu_log_softmax)
1686 mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
1687
1688 cpu_log_softmax.backward(gradient=cpu_grad)
1689 mps_log_softmax.backward(gradient=mps_grad)
1690
1691 self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
1692
1693 def test_eq(self):
1694 values1 = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
1695 values2 = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]
1696 mps_x = torch.tensor(values1, device='mps')
1697 mps_y = torch.tensor(values2, device='mps')
1698 cpu_x = torch.tensor(values1, device='cpu')
1699 cpu_y = torch.tensor(values2, device='cpu')
1700 result_mps = torch.eq(mps_x, mps_y)
1701 result_cpu = torch.eq(cpu_x, cpu_y)
1702
1703 self.assertEqual(result_cpu, result_mps.to('cpu'))
1704
1705 def test_eq_int64(self):
1706 values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
1707 values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]
1708 mps_x = torch.tensor(values1, device='mps')
1709 mps_y = torch.tensor(values2, device='mps')
1710 cpu_x = torch.tensor(values1, device='cpu')
1711 cpu_y = torch.tensor(values2, device='cpu')
1712 result_mps = torch.eq(mps_x, mps_y)
1713 result_cpu = torch.eq(cpu_x, cpu_y)
1714
1715 self.assertEqual(result_cpu, result_mps.to('cpu'))
1716
1717 def test_ne(self):
1718 def helper(shape):
1719 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1720 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
1721 mps_x = cpu_x.detach().clone().to('mps')
1722 mps_y = cpu_y.detach().clone().to('mps')
1723 result_mps = torch.ne(mps_x, mps_y)
1724 result_cpu = torch.ne(cpu_x, cpu_y)
1725
1726 self.assertEqual(result_cpu, result_mps.to('cpu'))
1727
1728 helper((2, 3, 4, 5))
1729
1730 def test_ne_scalar(self):
1731 def helper(shape):
1732 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1733 mps_x = cpu_x.detach().clone().to('mps')
1734 result_mps = torch.ne(mps_x, 0.0)
1735 result_cpu = torch.ne(cpu_x, 0.0)
1736
1737 self.assertEqual(result_cpu, result_mps.to('cpu'))
1738
1739 helper((2, 3, 4, 5))
1740
1741 def test_lt(self):
1742 def helper(shape):
1743 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1744 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
1745 mps_x = cpu_x.detach().clone().to('mps')
1746 mps_y = cpu_y.detach().clone().to('mps')
1747 result_mps = torch.lt(mps_x, mps_y)
1748 result_cpu = torch.lt(cpu_x, cpu_y)
1749
1750 self.assertEqual(result_cpu, result_mps.to('cpu'))
1751
1752 helper((2, 3, 4, 5))
1753
1754 def test_lt_scalar(self):
1755 def helper(shape):
1756 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1757 mps_x = cpu_x.detach().clone().to('mps')
1758 result_mps = torch.lt(mps_x, 0.0)
1759 result_cpu = torch.lt(cpu_x, 0.0)
1760
1761 self.assertEqual(result_cpu, result_mps.to('cpu'))
1762
1763 helper((2, 3, 4, 5))
1764
1765 def test_le(self):
1766 def helper(shape):
1767 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1768 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
1769 mps_x = cpu_x.detach().clone().to('mps')
1770 mps_y = cpu_y.detach().clone().to('mps')
1771 result_mps = torch.le(mps_x, mps_y)
1772 result_cpu = torch.le(cpu_x, cpu_y)
1773
1774 self.assertEqual(result_cpu, result_mps.to('cpu'))
1775
1776 helper((2, 3, 4, 5))
1777
1778 def test_le_scalar(self):
1779 def helper(shape):
1780 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1781 mps_x = cpu_x.detach().clone().to('mps')
1782 result_mps = torch.le(mps_x, 0.0)
1783 result_cpu = torch.le(cpu_x, 0.0)
1784
1785 self.assertEqual(result_cpu, result_mps.to('cpu'))
1786
1787 helper((2, 3, 4, 5))
1788
1789 def test_ge(self):
1790 def helper(shape):
1791 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1792 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
1793 mps_x = cpu_x.detach().clone().to('mps')
1794 mps_y = cpu_y.detach().clone().to('mps')
1795 result_mps = torch.ge(mps_x, mps_y)
1796 result_cpu = torch.ge(cpu_x, cpu_y)
1797
1798 self.assertEqual(result_cpu, result_mps.to('cpu'))
1799
1800 helper((2, 3, 4, 5))
1801
1802 def test_ge_scalar(self):
1803 def helper(shape):
1804 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1805 mps_x = cpu_x.detach().clone().to('mps')
1806 result_mps = torch.ge(mps_x, 0.0)
1807 result_cpu = torch.ge(cpu_x, 0.0)
1808
1809 self.assertEqual(result_cpu, result_mps.to('cpu'))
1810
1811 helper((2, 3, 4, 5))
1812
1813 def test_gt(self):
1814 def helper(shape):
1815 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1816 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
1817 mps_x = cpu_x.detach().clone().to('mps')
1818 mps_y = cpu_y.detach().clone().to('mps')
1819 result_mps = torch.gt(mps_x, mps_y)
1820 result_cpu = torch.gt(cpu_x, cpu_y)
1821
1822 self.assertEqual(result_cpu, result_mps.to('cpu'))
1823
1824 helper((2, 3, 4, 5))
1825
1826 def test_gt_scalar(self):
1827 def helper(shape):
1828 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
1829 mps_x = cpu_x.detach().clone().to('mps')
1830 result_mps = torch.gt(mps_x, 0.0)
1831 result_cpu = torch.gt(cpu_x, 0.0)
1832
1833 self.assertEqual(result_cpu, result_mps.to('cpu'))
1834
1835 helper((2, 3, 4, 5))
1836
1837 # Test forward argmax
1838 def test_argmax(self):
1839 def helper(n, c, h, w, dtype=torch.float32):
1840 cpu_x = None
1841 x = None
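            # integral dtypes are sampled with randint, since randn only produces floating-point values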
1842 if(dtype not in [torch.float32, torch.bool]):
1843 cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
1844 x = cpu_x.detach().clone().to('mps')
1845 elif (dtype == torch.bool):
1846 cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
1847 x = cpu_x.detach().clone().to('mps')
1848 else:
1849 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
1850 x = cpu_x.detach().clone().to('mps').requires_grad_()
1851
1852 y = torch.argmax(x)
1853 ref_y = torch.argmax(cpu_x)
1854 self.assertEqual(y, ref_y)
1855
1856 y_0 = torch.argmax(x, dim=0)
1857 refy_0 = torch.argmax(cpu_x, dim=0)
1858 self.assertEqual(y_0, refy_0)
1859
1860 y_0dim = torch.argmax(x, dim=0, keepdim=True)
1861 refy_0dim = torch.argmax(cpu_x, dim=0, keepdim=True)
1862 self.assertEqual(y_0dim, refy_0dim)
1863
1864 y_1 = torch.argmax(x, dim=1)
1865 refy_1 = torch.argmax(cpu_x, dim=1)
1866 self.assertEqual(y_1, refy_1)
1867
1868 y_1dim = torch.argmax(x, dim=1, keepdim=True)
1869 refy_1dim = torch.argmax(cpu_x, dim=1, keepdim=True)
1870 self.assertEqual(y_1dim, refy_1dim)
1871
1872 y_2 = torch.argmax(x, dim=2)
1873 refy_2 = torch.argmax(cpu_x, dim=2)
1874 self.assertEqual(y_2, refy_2)
1875
1876 y_2dim = torch.argmax(x, dim=2, keepdim=True)
1877 refy_2dim = torch.argmax(cpu_x, dim=2, keepdim=True)
1878 self.assertEqual(y_2dim, refy_2dim)
1879
1880 y_3 = torch.argmax(x, dim=3)
1881 refy_3 = torch.argmax(cpu_x, dim=3)
1882 self.assertEqual(y_3, refy_3)
1883
1884 y_3dim = torch.argmax(x, dim=3, keepdim=True)
1885 refy_3dim = torch.argmax(cpu_x, dim=3, keepdim=True)
1886 self.assertEqual(y_3dim, refy_3dim)
1887
1888 helper(2, 8, 4, 4, torch.float32)
1889 helper(2, 8, 4, 4, torch.int32)
1890 helper(2, 8, 4, 4, torch.float16)
1891 helper(2, 8, 4, 4, torch.int64)
1892
1893 # Test forward max
1894 # Note - don't test grad now
1895 def test_max_el(self):
1896 def helper(n, c, h, w, dtype=torch.float32):
1897
1898 if(dtype not in [torch.float32, torch.bool]):
1899 cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
1900 x = cpu_x.detach().clone().to('mps')
1901 elif (dtype == torch.bool):
1902 cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
1903 x = cpu_x.detach().clone().to('mps')
1904 else:
1905 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
1906 x = cpu_x.detach().clone().to('mps')
1907
1908 ref_y = torch.max(cpu_x)
1909 y = torch.max(x)
1910 self.assertEqual(y, ref_y)
1911
1912 for dim in [0, 1, 2, 3]:
1913 for keepdim in [True, False]:
1914 y, idx = torch.max(x, dim=dim, keepdim=keepdim)
1915 refy, refidx = torch.max(cpu_x, dim=dim, keepdim=keepdim)
1916 self.assertEqual(y, refy)
1917 self.assertEqual(idx, refidx)
1918
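            # also exercise the out= overload with preallocated value/index tensors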
1919 y_0 = torch.ones(c, h, w, device='mps', dtype=dtype)
1920 idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
1921 torch.max(x, dim=0, out=(y_0, idx_0))
1922 refy_0, refidx_0 = torch.max(cpu_x, dim=0)
1923 self.assertEqual(y_0, refy_0)
1924 self.assertEqual(idx_0, refidx_0)
1925
1926 y_0dim = torch.ones(1, c, h, w, device='mps', dtype=dtype)
1927 idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
1928 torch.max(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
1929 refy_0dim, refidx_0dim = torch.max(cpu_x, dim=0, keepdim=True)
1930 self.assertEqual(y_0dim, refy_0dim)
1931 self.assertEqual(idx_0dim, refidx_0dim)
1932
1933 y_1 = torch.ones(n, h, w, device='mps', dtype=dtype)
1934 idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
1935 torch.max(x, dim=1, out=(y_1, idx_1))
1936 refy_1, refidx_1 = torch.max(cpu_x, dim=1)
1937 self.assertEqual(y_1, refy_1)
1938 self.assertEqual(idx_1, refidx_1)
1939
1940 y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=dtype)
1941 idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
1942 torch.max(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
1943 refy_1dim, refidx_1dim = torch.max(cpu_x, keepdim=True, dim=1)
1944 self.assertEqual(y_1dim, refy_1dim)
1945 self.assertEqual(idx_1dim, refidx_1dim)
1946
1947 y_2 = torch.ones(n, c, w, device='mps', dtype=dtype)
1948 idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
1949 torch.max(x, dim=2, out=(y_2, idx_2))
1950 refy_2, refidx_2 = torch.max(cpu_x, dim=2)
1951 self.assertEqual(y_2, refy_2)
1952 self.assertEqual(idx_2, refidx_2)
1953
1954 y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=dtype)
1955 idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
1956 torch.max(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
1957 refy_2dim, refidx_2dim = torch.max(cpu_x, dim=2, keepdim=True,)
1958 self.assertEqual(y_2dim, refy_2dim)
1959 self.assertEqual(idx_2dim, refidx_2dim)
1960
1961 y_3 = torch.ones(n, c, h, device='mps', dtype=dtype)
1962 idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
1963 torch.max(x, dim=3, out=(y_3, idx_3))
1964 refy_3, refidx_3 = torch.max(cpu_x, dim=3)
1965 self.assertEqual(y_3, refy_3)
1966 self.assertEqual(idx_3, refidx_3)
1967
1968 y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=dtype)
1969 idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
1970 torch.max(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
1971 refy_3dim, refidx_3dim = torch.max(cpu_x, dim=3, keepdim=True,)
1972 self.assertEqual(y_3dim, refy_3dim)
1973 self.assertEqual(idx_3dim, refidx_3dim)
1974
1975 helper(2, 8, 4, 5, torch.float32)
1976 helper(2, 8, 4, 5, torch.int32)
1977 # helper(2, 8, 4, 5, torch.int64)
1978
1979 def test_any(self):
1980 def helper(shape):
1981 input_xs = []
1982 prod = 1
1983
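            # build float, int, and bool inputs covering random, arange, all-ones, and all-zeros data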
1984 for i in range(len(shape)):
1985 prod *= shape[i]
1986 input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
1987 input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
1988 input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
1989 input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
1990 input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
1991 input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
1992 input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
1993 input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
1994 input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
1995 input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())
1996
1997 for i, cpu_x in enumerate(input_xs):
1998 x = cpu_x.detach().clone().to('mps')
1999 y = torch.any(x)
2000 ref_y = torch.any(cpu_x)
2001 self.assertEqual(y, ref_y)
2002
2003 y_0 = torch.any(x, dim=0)
2004 refy_0 = torch.any(cpu_x, dim=0)
2005 self.assertEqual(y_0, refy_0)
2006
2007 y_0dim = torch.any(x, dim=0, keepdim=True)
2008 refy_0dim = torch.any(cpu_x, dim=0, keepdim=True)
2009 self.assertEqual(y_0dim, refy_0dim)
2010
2015 y_1 = torch.any(x, dim=1)
2016 refy_1 = torch.any(cpu_x, dim=1)
2017 self.assertEqual(y_1, refy_1)
2018
2019 y_1dim = torch.any(x, dim=1, keepdim=True)
2020 refy_1dim = torch.any(cpu_x, dim=1, keepdim=True)
2021 self.assertEqual(y_1dim, refy_1dim)
2022
2023 if (len(shape) > 2):
2024 y_2 = torch.any(x, dim=2)
2025 refy_2 = torch.any(cpu_x, dim=2)
2026 self.assertEqual(y_2, refy_2)
2027
2028 y_2dim = torch.any(x, dim=2, keepdim=True)
2029 refy_2dim = torch.any(cpu_x, dim=2, keepdim=True)
2030 self.assertEqual(y_2dim, refy_2dim)
2031
2032 y_3 = torch.any(x, dim=3)
2033 refy_3 = torch.any(cpu_x, dim=3)
2034 self.assertEqual(y_3, refy_3)
2035
2036 y_3dim = torch.any(x, dim=3, keepdim=True)
2037 refy_3dim = torch.any(cpu_x, dim=3, keepdim=True)
2038 self.assertEqual(y_3dim, refy_3dim)
2039 helper((1, 1, 1, 1))
2040 helper((1, 1, 3, 3))
2041 helper((7, 13))
2042 helper((2, 8, 4, 5))
2043
2044 def test_all(self):
2045 def helper(shape):
2046 input_xs = []
2047 prod = 1
2048
2049 for i in range(len(shape)):
2050 prod *= shape[i]
2051 input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
2052 input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
2053 input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
2054 input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
2055 input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
2056 input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
2057 input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
2058 input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
2059 input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
2060 input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())
2061
2062 for i, cpu_x in enumerate(input_xs):
2063 x = cpu_x.detach().clone().to('mps')
2064 y = torch.all(x)
2065 ref_y = torch.all(cpu_x)
2066 self.assertEqual(y, ref_y)
2067
2068 y_0 = torch.all(x, dim=0)
2069 refy_0 = torch.all(cpu_x, dim=0)
2070 self.assertEqual(y_0, refy_0)
2071
2072 y_0dim = torch.all(x, dim=0, keepdim=True)
2073 refy_0dim = torch.all(cpu_x, dim=0, keepdim=True)
2074 self.assertEqual(y_0dim, refy_0dim)
2075
2080 y_1 = torch.all(x, dim=1)
2081 refy_1 = torch.all(cpu_x, dim=1)
2082 self.assertEqual(y_1, refy_1)
2083
2084 y_1dim = torch.all(x, dim=1, keepdim=True)
2085 refy_1dim = torch.all(cpu_x, dim=1, keepdim=True)
2086 self.assertEqual(y_1dim, refy_1dim)
2087 if (len(shape) > 2):
2088 y_2 = torch.all(x, dim=2)
2089 refy_2 = torch.all(cpu_x, dim=2)
2090 self.assertEqual(y_2, refy_2)
2091
2092 y_2dim = torch.all(x, dim=2, keepdim=True)
2093 refy_2dim = torch.all(cpu_x, dim=2, keepdim=True)
2094 self.assertEqual(y_2dim, refy_2dim)
2095
2096 y_3 = torch.all(x, dim=3)
2097 refy_3 = torch.all(cpu_x, dim=3)
2098 self.assertEqual(y_3, refy_3)
2099
2100 y_3dim = torch.all(x, dim=3, keepdim=True)
2101 refy_3dim = torch.all(cpu_x, dim=3, keepdim=True)
2102 self.assertEqual(y_3dim, refy_3dim)
2103
2104 helper((1, 1, 1, 1))
2105 helper((1, 1, 3, 3))
2106 helper((7, 13))
2107 helper((2, 8, 4, 5))
2108
2109 # Test forward min
2110 def test_min_el(self):
2111 def helper(n, c, h, w):
2112 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
2113 x = cpu_x.detach().clone().to('mps')
2114
2115 y = torch.min(x)
2116 ref_y = torch.min(cpu_x)
2117 self.assertEqual(y, ref_y)
2118
2119 y_0, idx_0 = torch.min(x, dim=0)
2120 refy_0, refidx_0 = torch.min(cpu_x, dim=0)
2121 self.assertEqual(y_0, refy_0)
2122 self.assertEqual(idx_0, refidx_0)
2123
2124 y_0 = torch.ones(c, h, w, device='mps', dtype=torch.float)
2125 idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
2126 torch.min(x, dim=0, out=(y_0, idx_0))
2127 refy_0, refidx_0 = torch.min(cpu_x, dim=0)
2128 self.assertEqual(y_0, refy_0)
2129 self.assertEqual(idx_0, refidx_0)
2130
2131 y_0dim, idx_0dim = torch.min(x, dim=0, keepdim=True)
2132 refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
2133 self.assertEqual(y_0dim, refy_0dim)
2134 self.assertEqual(idx_0dim, refidx_0dim)
2135
2136 y_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.float)
2137 idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
2138 torch.min(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
2139 refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
2140 self.assertEqual(y_0dim, refy_0dim)
2141 self.assertEqual(idx_0dim, refidx_0dim)
2142
2143 y_1, idx_1 = torch.min(x, dim=1)
2144 refy_1, refidx_1 = torch.min(cpu_x, dim=1)
2145 self.assertEqual(y_1, refy_1)
2146 self.assertEqual(idx_1, refidx_1)
2147
2148 y_1 = torch.ones(n, h, w, device='mps', dtype=torch.float)
2149 idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
2150 torch.min(x, dim=1, out=(y_1, idx_1))
2151 refy_1, refidx_1 = torch.min(cpu_x, dim=1)
2152 self.assertEqual(y_1, refy_1)
2153 self.assertEqual(idx_1, refidx_1)
2154
2155 y_1dim, idx_1dim = torch.min(x, dim=1, keepdim=True)
2156 refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
2157 self.assertEqual(y_1dim, refy_1dim)
2158 self.assertEqual(idx_1dim, refidx_1dim)
2159
2160 y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.float)
2161 idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
2162 torch.min(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
2163 refy_1dim, refidx_1dim = torch.min(cpu_x, keepdim=True, dim=1)
2164 self.assertEqual(y_1dim, refy_1dim)
2165 self.assertEqual(idx_1dim, refidx_1dim)
2166
2167 y_2, idx_2 = torch.min(x, dim=2)
2168 refy_2, refidx_2 = torch.min(cpu_x, dim=2)
2169 self.assertEqual(y_2, refy_2)
2170 self.assertEqual(idx_2, refidx_2)
2171
2172 y_2 = torch.ones(n, c, w, device='mps', dtype=torch.float)
2173 idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
2174 torch.min(x, dim=2, out=(y_2, idx_2))
2175 refy_2, refidx_2 = torch.min(cpu_x, dim=2)
2176 self.assertEqual(y_2, refy_2)
2177 self.assertEqual(idx_2, refidx_2)
2178
2179 y_2dim, idx_2dim = torch.min(x, dim=2, keepdim=True)
2180 refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
2181 self.assertEqual(y_2dim, refy_2dim)
2182 self.assertEqual(idx_2dim, refidx_2dim)
2183
2184 y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.float)
2185 idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
2186 torch.min(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
2187 refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True,)
2188 self.assertEqual(y_2dim, refy_2dim)
2189 self.assertEqual(idx_2dim, refidx_2dim)
2190
2191 y_3, idx_3 = torch.min(x, dim=3)
2192 refy_3, refidx_3 = torch.min(cpu_x, dim=3)
2193 self.assertEqual(y_3, refy_3)
2194 self.assertEqual(idx_3, refidx_3)
2195
2196 y_3 = torch.ones(n, c, h, device='mps', dtype=torch.float)
2197 idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
2198 torch.min(x, dim=3, out=(y_3, idx_3))
2199 refy_3, refidx_3 = torch.min(cpu_x, dim=3)
2200 self.assertEqual(y_3, refy_3)
2201 self.assertEqual(idx_3, refidx_3)
2202
2203 y_3dim, idx_3dim = torch.min(x, dim=3, keepdim=True)
2204 refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
2205 self.assertEqual(y_3dim, refy_3dim)
2206 self.assertEqual(idx_3dim, refidx_3dim)
2207
2208 y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.float)
2209 idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
2210 torch.min(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
2211 refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True,)
2212 self.assertEqual(y_3dim, refy_3dim)
2213 self.assertEqual(idx_3dim, refidx_3dim)
2214
2215 helper(2, 8, 4, 5)
2216
2217 # Test forward sum
2218 def test_sum(self):
2219 def helper(n, c, h, w, dtype=torch.float32):
2220 cpu_x = None
2221 x = None
2222 if(dtype not in [torch.float32, torch.bool]):
2223 cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
2224 x = cpu_x.detach().clone().to('mps')
2225 elif (dtype == torch.bool):
2226 cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
2227 x = cpu_x.detach().clone().to('mps')
2228 else:
2229 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
2230 x = cpu_x.detach().clone().to('mps').requires_grad_()
2231
2232 all_sum = torch.sum(x)
2233 all_sum_cpu = torch.sum(cpu_x)
2234
2235 self.assertEqual(all_sum, all_sum_cpu)
2236
2237 nil_dim_sum = torch.sum(x, dim=[])
2238 nil_dim_sum_cpu = torch.sum(cpu_x, dim=[])
2239
2240 self.assertEqual(nil_dim_sum, nil_dim_sum_cpu)
2241
2242 nil_dim_sum_keepdim = torch.sum(x, dim=[], keepdim=True)
2243 nil_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[], keepdim=True)
2244
2245 self.assertEqual(nil_dim_sum_keepdim, nil_dim_sum_cpu_keepdim)
2246
2247 zero_dim_sum = torch.sum(x, dim=[0])
2248 zero_dim_sum_cpu = torch.sum(cpu_x, dim=[0])
2249
2250 self.assertEqual(zero_dim_sum, zero_dim_sum_cpu)
2251
2252 zero_dim_sum_keepdim = torch.sum(x, dim=[0], keepdim=True)
2253 zero_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0], keepdim=True)
2254
2255 self.assertEqual(zero_dim_sum_keepdim, zero_dim_sum_cpu_keepdim)
2256
2257 zero_one_dim_sum = torch.sum(x, dim=[0, 1])
2258 zero_one_dim_sum_cpu = torch.sum(cpu_x, dim=[0, 1])
2259
2260 self.assertEqual(zero_one_dim_sum, zero_one_dim_sum_cpu)
2261
2262 zero_one_dim_sum_keepdim = torch.sum(x, dim=[0, 1], keepdim=True)
2263 zero_one_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0, 1], keepdim=True)
2264
2265 self.assertEqual(zero_one_dim_sum_keepdim, zero_one_dim_sum_cpu_keepdim)
2266
2267 two_three_dim_sum = torch.sum(x, dim=[2, 3])
2268 two_three_dim_sum_cpu = torch.sum(cpu_x, dim=[2, 3])
2269
2270 self.assertEqual(two_three_dim_sum, two_three_dim_sum_cpu)
2271
2272 two_three_keepdim_sum = torch.sum(x, dim=[2, 3], keepdim=True)
2273 two_three_dim_keepsum_cpu = torch.sum(cpu_x, dim=[2, 3], keepdim=True)
2274
2275 self.assertEqual(two_three_keepdim_sum, two_three_dim_keepsum_cpu)
2276
2277 helper(2, 8, 4, 5)
2278 helper(2, 8, 4, 5, dtype=torch.int32)
2279 helper(2, 8, 4, 5, dtype=torch.int64)
2280 helper(2, 8, 4, 5, dtype=torch.bool)
2281
2282 # Test forward prod
2283 def test_prod(self):
2284 def helper(shape, dtype=torch.float32):
2285 cpu_x = None
2286 x = None
2287 if(dtype not in [torch.float32, torch.bool]):
2288 cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
2289 x = cpu_x.detach().clone().to('mps')
2290 elif (dtype == torch.bool):
2291 cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
2292 x = cpu_x.detach().clone().to('mps')
2293 else:
2294 cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
2295 x = cpu_x.detach().clone().to('mps').requires_grad_()
2296
2297 all_prod = torch.prod(x)
2298 all_prod_cpu = torch.prod(cpu_x)
2299
2300 self.assertEqual(all_prod, all_prod_cpu)
2301
2302 for dim in range(len(shape)):
2303 dim_prod = torch.prod(x, dim=dim)
2304 dim_prod_cpu = torch.prod(cpu_x, dim=dim)
2305
2306 self.assertEqual(dim_prod, dim_prod_cpu)
2307
2308 dim_prod_keepdim = torch.prod(x, dim=dim, keepdim=True)
2309 dim_prod_cpu_keepdim = torch.prod(cpu_x, dim=dim, keepdim=True)
2310
2311 self.assertEqual(dim_prod_keepdim, dim_prod_cpu_keepdim)
2312
2313 for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
2314 helper((2, 3), dtype)
2315
2316 # Test forward mean
2317 def test_mean(self):
2318 def helper(n, c, h, w):
2319 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
2320 x = cpu_x.detach().clone().to('mps').requires_grad_()
2321
2322 all_mean = torch.mean(x)
2323 all_mean_cpu = torch.mean(cpu_x)
2324
2325 self.assertEqual(all_mean, all_mean_cpu)
2326
2327 nil_dim_mean = torch.mean(x, dim=[])
2328 nil_dim_mean_cpu = torch.mean(cpu_x, dim=[])
2329
2330 self.assertEqual(nil_dim_mean, nil_dim_mean_cpu)
2331
2332 nil_dim_mean_keepdim = torch.mean(x, dim=[], keepdim=True)
2333 nil_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[], keepdim=True)
2334
2335 self.assertEqual(nil_dim_mean_keepdim, nil_dim_mean_cpu_keepdim)
2336
2337 zero_dim_mean = torch.mean(x, dim=[0])
2338 zero_dim_mean_cpu = torch.mean(cpu_x, dim=[0])
2339
2340 self.assertEqual(zero_dim_mean, zero_dim_mean_cpu)
2341
2342 zero_dim_mean_keepdim = torch.mean(x, dim=[0], keepdim=True)
2343 zero_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0], keepdim=True)
2344
2345 self.assertEqual(zero_dim_mean_keepdim, zero_dim_mean_cpu_keepdim)
2346
2347 zero_one_dim_mean = torch.mean(x, dim=[0, 1])
2348 zero_one_dim_mean_cpu = torch.mean(cpu_x, dim=[0, 1])
2349
2350 self.assertEqual(zero_one_dim_mean, zero_one_dim_mean_cpu)
2351
2352 zero_one_dim_mean_keepdim = torch.mean(x, dim=[0, 1], keepdim=True)
2353 zero_one_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0, 1], keepdim=True)
2354
2355 self.assertEqual(zero_one_dim_mean_keepdim, zero_one_dim_mean_cpu_keepdim)
2356
2357 two_three_dim_mean = torch.mean(x, dim=[2, 3])
2358 two_three_dim_mean_cpu = torch.mean(cpu_x, dim=[2, 3])
2359
2360 self.assertEqual(two_three_dim_mean, two_three_dim_mean_cpu)
2361
2362 two_three_keepdim_mean = torch.mean(x, dim=[2, 3], keepdim=True)
2363 two_three_dim_keepmean_cpu = torch.mean(cpu_x, dim=[2, 3], keepdim=True)
2364
2365 self.assertEqual(two_three_keepdim_mean, two_three_dim_keepmean_cpu)
2366
2367 helper(2, 8, 4, 5)
2368
2369 # Test std
2370 def test_std(self):
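        # covers both the biased (unbiased=False) and Bessel-corrected (unbiased=True)
        # estimators over assorted dim subsets, with and without keepdim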
2371 def helper(shape):
2372 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2373 x = cpu_x.detach().clone().to('mps')
2374
2375 all_std = torch.std(x, unbiased=False)
2376 all_std_cpu = torch.std(cpu_x, unbiased=False)
2377
2378 self.assertEqual(all_std, all_std_cpu)
2379
2380 nil_dim_std = torch.std(x, dim=[], unbiased=False)
2381 nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=False)
2382
2383 self.assertEqual(nil_dim_std, nil_dim_std_cpu)
2384
2385 nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=False)
2386 nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=False)
2387
2388 self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)
2389
2390 zero_dim_std = torch.std(x, dim=[0], unbiased=False)
2391 zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=False)
2392
2393 self.assertEqual(zero_dim_std, zero_dim_std_cpu)
2394
2395 zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=False)
2396 zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=False)
2397
2398 self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)
2399
2400 zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=False)
2401 zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=False)
2402
2403 self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)
2404
2405 zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=False)
2406 zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
2407
2408 self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)
2409
2410 two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=False)
2411 two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=False)
2412
2413 self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)
2414
2415 two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=False)
2416 two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
2417
2418 self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
2419
2420 all_std = torch.std(x, unbiased=True)
2421 all_std_cpu = torch.std(cpu_x, unbiased=True)
2422
2423 self.assertEqual(all_std, all_std_cpu)
2424
2425 nil_dim_std = torch.std(x, dim=[], unbiased=True)
2426 nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=True)
2427
2428 self.assertEqual(nil_dim_std, nil_dim_std_cpu)
2429
2430 nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=True)
2431 nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=True)
2432
2433 self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)
2434
2435 zero_dim_std = torch.std(x, dim=[0], unbiased=True)
2436 zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=True)
2437
2438 self.assertEqual(zero_dim_std, zero_dim_std_cpu)
2439
2440 zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=True)
2441 zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=True)
2442
2443 self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)
2444
2445 zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=True)
2446 zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=True)
2447
2448 self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)
2449
2450 zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=True)
2451 zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
2452
2453 self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)
2454
2455 two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=True)
2456 two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=True)
2457
2458 self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)
2459
2460 two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=True)
2461 two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
2462
2463 self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
2464
2465 helper((4, 5, 6, 7))
2466
2467 # Test var
2468 def test_var(self):
2469 def helper(shape):
2470 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2471 x = cpu_x.detach().clone().to('mps')
2472
2473 all_var = torch.var(x, unbiased=False)
2474 all_var_cpu = torch.var(cpu_x, unbiased=False)
2475
2476 self.assertEqual(all_var, all_var_cpu)
2477
2478 nil_dim_var = torch.var(x, dim=[], unbiased=False)
2479 nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=False)
2480
2481 self.assertEqual(nil_dim_var, nil_dim_var_cpu)
2482
2483 nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=False)
2484 nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=False)
2485
2486 self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
2487
2488 zero_dim_var = torch.var(x, dim=[0], unbiased=False)
2489 zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=False)
2490
2491 self.assertEqual(zero_dim_var, zero_dim_var_cpu)
2492
2493 zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=False)
2494 zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=False)
2495
2496 self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
2497
2498 zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=False)
2499 zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=False)
2500
2501 self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
2502
2503 zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=False)
2504 zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
2505
2506 self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
2507
2508 two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=False)
2509 two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=False)
2510
2511 self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
2512
2513 two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=False)
2514 two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
2515
2516 self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
2517
2518 all_var = torch.var(x, unbiased=True)
2519 all_var_cpu = torch.var(cpu_x, unbiased=True)
2520
2521 self.assertEqual(all_var, all_var_cpu)
2522
2523 nil_dim_var = torch.var(x, dim=[], unbiased=True)
2524 nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=True)
2525
2526 self.assertEqual(nil_dim_var, nil_dim_var_cpu)
2527
2528 nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=True)
2529 nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=True)
2530
2531 self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
2532
2533 zero_dim_var = torch.var(x, dim=[0], unbiased=True)
2534 zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=True)
2535
2536 self.assertEqual(zero_dim_var, zero_dim_var_cpu)
2537
2538 zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=True)
2539 zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=True)
2540
2541 self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
2542
2543 zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=True)
2544 zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=True)
2545
2546 self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
2547
2548 zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=True)
2549 zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
2550
2551 self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
2552
2553 two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=True)
2554 two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=True)
2555
2556 self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
2557
2558 two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=True)
2559 two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
2560
2561 self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
2562
2563 helper((4, 5, 6, 7))
2564
2565    # Test minimum and maximum
2566 def test_minimum_maximum(self):
2567 def helper(n, c, h, w):
2568 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
2569 cpu_y = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
2570 mps_x = cpu_x.detach().clone().to('mps')
2571 mps_y = cpu_y.detach().clone().to('mps')
2572
2573 minimum_result_cpu = torch.minimum(cpu_x, cpu_y)
2574 minimum_result_mps = torch.minimum(mps_x, mps_y)
2575 self.assertEqual(minimum_result_cpu, minimum_result_mps)
2576
2577 maximum_result_cpu = torch.maximum(cpu_x, cpu_y)
2578 maximum_result_mps = torch.maximum(mps_x, mps_y)
2579 self.assertEqual(maximum_result_cpu, maximum_result_mps)
2580
2581 helper(1, 1, 4, 5)
2582
2583 # Test clamp_min
2584 def test_clamp_min(self):
2585 def helper(n, c, h, w):
2586 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
2587 x = cpu_x.detach().clone().to('mps')
2588
2589 cpu_min_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
2590 min_t = cpu_min_t.detach().clone().to('mps')
2591
2592 clamp_min_result = torch.clamp_min(x, min=5.0)
2593 clamp_min_result_cpu = torch.clamp_min(cpu_x, min=5.0)
2594
2595 self.assertEqual(clamp_min_result, clamp_min_result_cpu)
2596
2597 clamp_min_t_result = torch.clamp_min(x, min=min_t)
2598 clamp_min_t_result_cpu = torch.clamp_min(cpu_x, min=cpu_min_t)
2599
2600 self.assertEqual(clamp_min_t_result, clamp_min_t_result_cpu)
2601
2602 helper(2, 8, 4, 5)
2603
2604 # Test clamp_max
2606    def test_clamp_max(self):
2607 def helper(n, c, h, w):
2608 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
2609 x = cpu_x.detach().clone().to('mps')
2610
2611 cpu_max_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
2612 max_t = cpu_max_t.detach().clone().to('mps')
2613
2614 clamp_max_result = torch.clamp_max(x, max=100.0)
2615 clamp_max_result_cpu = torch.clamp_max(cpu_x, max=100.0)
2616
2617 self.assertEqual(clamp_max_result, clamp_max_result_cpu)
2618
2619 clamp_max_t_result = torch.clamp_max(x, max=max_t)
2620 clamp_max_t_result_cpu = torch.clamp_max(cpu_x, max=cpu_max_t)
2621
2622 self.assertEqual(clamp_max_t_result, clamp_max_t_result_cpu)
2623
2624 helper(2, 8, 4, 5)
2625
2626 # Test clamp
2627 def test_clamp(self):
2628 def helper(n, c, h, w):
2630 upper_bound = 1000
2631 half_upper_bound = upper_bound / 2
2632
2633 # x=[0..1000)
2634 x_arr = upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
2635 cpu_x = torch.tensor(x_arr, device='cpu', dtype=torch.float, requires_grad=False)
2636 x = cpu_x.detach().clone().to('mps')
2637
2638 # x=[0..500)
2639 min_arr = half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
2640 cpu_min_t = torch.tensor(min_arr, device='cpu', dtype=torch.float, requires_grad=False)
2641 min_t = cpu_min_t.detach().clone().to('mps')
2642
2643            # x=[500..1000), so each max is greater than the corresponding min
2644 max_arr = (half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)) + half_upper_bound
2645 cpu_max_t = torch.tensor(max_arr, device='cpu', dtype=torch.float, requires_grad=False)
2646 max_t = cpu_max_t.detach().clone().to('mps')
2647
2648 # [200..600]: just an arbitrary range between [0..1000]
2649 clamp_result = torch.clamp(x, min=200.0, max=600.0)
2650 clamp_result_cpu = torch.clamp(cpu_x, min=200.0, max=600.0)
2651 self.assertEqual(clamp_result, clamp_result_cpu)
2652
2653 # test optional scalar refs and cached graph keys by passing only max
2654 clamp_opt_result = torch.clamp(x, max=600.0)
2655 clamp_opt_result_cpu = torch.clamp(cpu_x, max=600.0)
2656 self.assertEqual(clamp_opt_result, clamp_opt_result_cpu)
2657
2658 clamp_t_result = torch.clamp(x, min=min_t, max=max_t)
2659 clamp_t_result_cpu = torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t)
2660 self.assertEqual(clamp_t_result, clamp_t_result_cpu)
2661
2662 # test optional tensor refs and cached graph keys by passing only max
2663 clamp_topt_result = torch.clamp(x, max=max_t)
2664 clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
2665 self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)
2666
2667 # test inplace clamping
2668 x.clamp_(min=200.0, max=600.0)
2669 cpu_x.clamp_(min=200.0, max=600.0)
2670 self.assertEqual(cpu_x, x)
2671
2672 helper(2, 8, 4, 5)
2673
2674 def test_divmode(self):
2675 def helper(shape, rounding_mode):
2676 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2677 mps_x = cpu_x.detach().clone().to('mps')
2678 # clamp to avoid division by 0
2679 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False).clamp_min_(0.1)
2680 mps_y = cpu_y.detach().clone().to('mps')
2681
2682 result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
2683 result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
2684 self.assertEqual(result_div_mps, result_div_cpu)
2685
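        # floor rounds toward -inf while trunc rounds toward zero; a minimal sketch:
        #   torch.div(torch.tensor(-7.), torch.tensor(2.), rounding_mode='floor')  # -> tensor(-4.)
        #   torch.div(torch.tensor(-7.), torch.tensor(2.), rounding_mode='trunc')  # -> tensor(-3.)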
2686 helper((2, 8, 4, 5), "floor")
2687 helper((2, 8, 4, 5), "trunc")
2688
2689 def test_rounding(self):
2690 def helper(shape):
2691 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2692 mps_x = cpu_x.detach().clone().to('mps')
2693
2694 result_floor_cpu = torch.floor(cpu_x)
2695 result_floor_mps = torch.floor(mps_x)
2696 self.assertEqual(result_floor_mps, result_floor_cpu)
2697
2698 result_ceil_cpu = torch.ceil(cpu_x)
2699 result_ceil_mps = torch.ceil(mps_x)
2700 self.assertEqual(result_ceil_mps, result_ceil_cpu)
2701
2702 result_trunc_cpu = torch.trunc(cpu_x)
2703 result_trunc_mps = torch.trunc(mps_x)
2704 self.assertEqual(result_trunc_mps, result_trunc_cpu)
2705
2706 result_round_cpu = torch.round(cpu_x)
2707 result_round_mps = torch.round(mps_x)
2708 self.assertEqual(result_round_mps, result_round_cpu)
2709
2710 helper((2, 6, 3, 5))
2711 helper((2, 8, 4, 5))
2712
2713 def test_expand(self):
2714 def helper(n, c):
2715 values = [[1.0], [4.0], [7.0]]
2716 cpu_x = torch.tensor(values, device='cpu')
2717 x = cpu_x.detach().clone().to('mps')
2718
2719 strided_cpu = torch.as_strided(cpu_x, (3, 4), (1, 0))
2720 strided_mps = torch.as_strided(x, (3, 4), (1, 0))
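            # stride 0 along dim 1 re-reads the same storage element across all 4 columns,
            # so this view matches cpu_x.expand(3, 4) without copying data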
2721
2722            self.assertEqual(strided_mps, strided_cpu)
2723
2724 helper(3, 1)
2725
2726 def test_select(self):
2727 def helper(n, c):
2728 cpu_x = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
2729 x = cpu_x.detach().clone().to('mps').requires_grad_()
2730
2731 strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1))
2732 strided_mps = torch.as_strided(x, (3, 1), (3, 1))
2733 self.assertEqual(strided_mps, strided_cpu)
2734
2735 strided_cpu = torch.as_strided(cpu_x, (1, 3), (3, 1))
2736 strided_mps = torch.as_strided(x, (1, 3), (3, 1))
2737 self.assertEqual(strided_mps, strided_cpu)
2738
2739 strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1), storage_offset=1)
2740 strided_mps = torch.as_strided(x, (3, 1), (3, 1), storage_offset=1)
2741
2742 self.assertEqual(strided_mps, strided_cpu)
2743
2744 helper(3, 3)
2745
2746 def test_topk(self):
2747 def helper(shape):
2748 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2749 x = cpu_x.detach().clone().to('mps')
2750 for largest_val in [True, False]:
2751 if (type(shape) == tuple):
2752 for curr_dim in range(0, len(shape)):
2753 dim_size = shape[curr_dim]
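                        # sweep k from 1 to the full dimension size to cover partial and complete selections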
2754 for k in range(1, dim_size + 1):
2755 topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest_val)
2756 topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest_val)
2757 self.assertEqual(topk_values, topk_values_cpu)
2758 self.assertEqual(topk_indices, topk_indices_cpu)
2759 else:
2760 for k in range(1, shape):
2761 topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest_val)
2762 topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest_val)
2763 self.assertEqual(topk_values, topk_values_cpu)
2764 self.assertEqual(topk_indices, topk_indices_cpu)
2765
2766 helper(2)
2767 helper((5, 1))
2768 helper((1, 5))
2769 helper((5, 9, 7, 4))
2770
2771 def test_upsample_nearest_exact2d(self):
2772 def helper(N, C, H, W):
2773 inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
2774 requires_grad=True).reshape(N, C, H, W)
2775 inputCPU.retain_grad()
2776 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
2777
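            # 'nearest-exact' follows the pixel-center rounding convention (matching PIL and
            # scikit-image), unlike the legacy 'nearest' mode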
2778 outputCPU = torch.nn.functional.interpolate(inputCPU, size=(5, 5), mode='nearest-exact')
2779 outputMPS = torch.nn.functional.interpolate(inputMPS, size=(5, 5), mode='nearest-exact')
2780
2781 self.assertEqual(outputCPU, outputMPS)
2782
2783 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
2784 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
2785
2786 self.assertEqual(inputCPU.grad, inputMPS.grad)
2787
2788 helper(1, 1, 4, 4)
2789 helper(7, 5, 3, 2)
2790
2791 def test_upsample_nearest2d(self):
2792 def helper(N, C, H, W):
2793 inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
2794 requires_grad=True).reshape(N, C, H, W)
2795 inputCPU.retain_grad()
2796            inputMPS = inputCPU.detach().to('mps').requires_grad_()
2797
2798            values = [1, 2, 5, 10, 40]
2799
2800            for i in values:
2801                for j in values:
2802                    upsample_nearest2d = nn.UpsamplingNearest2d(scale_factor=(i, j))
2803
2804 outputCPU = upsample_nearest2d(inputCPU)
2805 outputMPS = upsample_nearest2d(inputMPS)
2806
2807 self.assertEqual(outputCPU, outputMPS)
2808 upsample_nearest2d = nn.UpsamplingNearest2d((i * H, j * W))
2809
2810 outputCPU = upsample_nearest2d(inputCPU)
2811 outputMPS = upsample_nearest2d(inputMPS)
2812
2813 self.assertEqual(outputCPU, outputMPS)
2814
2815 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
2816 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
2817
2818 self.assertEqual(inputCPU.grad, inputMPS.grad)
2819
2820 helper(1, 1, 4, 4)
2821 helper(7, 5, 3, 2)
2822
2823 def test_upsample_bilinear2d(self):
2824 def helper(N, C, H, W):
2825 inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
2826 requires_grad=True).reshape(N, C, H, W)
2827 inputCPU.retain_grad()
2828 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
2829
2830            values = [1, 2, 5, 10, 40]
2831
2832            for i in values:
2833                for j in values:
2834                    upsample_bilinear2d = nn.UpsamplingBilinear2d(scale_factor=(i, j))
2835
2836 outputCPU = upsample_bilinear2d(inputCPU)
2837 outputMPS = upsample_bilinear2d(inputMPS)
2838
2839 self.assertEqual(outputCPU, outputMPS)
2840
2841 upsample_bilinear2d = nn.UpsamplingBilinear2d((i * H, j * W))
2842
2843 outputCPU = upsample_bilinear2d(inputCPU)
2844 outputMPS = upsample_bilinear2d(inputMPS)
2845
2846 self.assertEqual(outputCPU, outputMPS)
2847
2848 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
2849 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
2850
2851 self.assertEqual(inputCPU.grad, inputMPS.grad)
2852
2853 helper(1, 1, 4, 4)
2854 helper(7, 5, 3, 2)
2855
2856 # Test concat forward
2857 def test_cat1(self):
2858 def helper(shape_x, shape_y, shape_z):
2859 cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
2860 x = cpu_x.detach().clone().to('mps')
2861
2862 cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
2863 y = cpu_y.detach().clone().to('mps')
2864
2865 cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
2866 z = cpu_z.detach().clone().to('mps')
2867
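            # only the concat dim may differ across inputs; every other dim must match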
2868 cat = torch.cat([x, y, z], dim=1)
2869 cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)
2870
2871 self.assertEqual(cat, cat_cpu)
2872
2873 helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
2874 # Empty test - Currently failing! Empty tensor not handled!
2875 # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
2876
2877 def test_pad(self):
2878 def helper(shape, padding, op):
2879 inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
2880 inputCPU.retain_grad()
2881 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
2882
2883 padCriteria = op(padding)
2884 outputCPU = padCriteria(inputCPU)
2885 outputMPS = padCriteria(inputMPS)
2886 self.assertEqual(outputCPU, outputMPS)
2887
2888 # backward pass (0.6 chosen so that grad_output != 1)
2889 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
2890 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
2891 self.assertEqual(inputCPU.grad, inputMPS.grad)
2892
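# padding is either a single int (applied to every edge) or a tuple ordered
# from the last dimension inward, e.g. (left, right, top, bottom) in 2D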
2893 # 1D Padding
2894 helper((2, 4, 3), 2, nn.ReflectionPad1d)
2895 # verify that a change in the input shape does not break graph caching
2896 helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
2897 # Replication 1D
2898 helper((2, 1, 6), 3, nn.ReplicationPad1d)
2899
2900 # 2D Padding
2901 helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
2902 # verify that a change in the input shape does not break graph caching
2903 helper((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
2904 # this should make the padding (2, 2, 2, 2)
2905 helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
2906 # verify that a change in the padding tuple does not break graph caching
2907 helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
2908
2909 # 3D Padding
2910 helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
2911 # same padding as above but a different op, to exercise graph caching
2912 helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
2913
2914 # Test stack forward
2915 def test_stack(self):
2916 # All shapes must be the same
2917 def helper(shape):
2918 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2919 x = cpu_x.detach().clone().to('mps')
2920
2921 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2922 y = cpu_y.detach().clone().to('mps')
2923
2924 cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2925 z = cpu_z.detach().clone().to('mps')
2926
2927 stack = torch.stack([x, y, z], dim=1)
2928 stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
2929
2930 self.assertEqual(stack, stack_cpu)
2931
2932 helper([2, 8, 4, 5])
2933 # Empty test - Currently failing! Empty tensor not handled!
2934 # helper([0, 2, 4, 5])
2935
2936 # Test abs
2937 def test_abs(self):
2938 def helper(shape):
2939 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2940 x = cpu_x.detach().clone().to('mps')
2941
2942 abs_result = torch.abs(x)
2943 abs_result_cpu = torch.abs(cpu_x)
2944
2945 self.assertEqual(abs_result, abs_result_cpu)
2946
2947 helper((2, 8, 4, 5))
2948
2949 def test_log(self):
2950 def helper(shape):
2951 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2952 x = cpu_x.detach().clone().to('mps')
2953
2954 log_result = torch.log(x)
2955 log_result_cpu = torch.log(cpu_x)
2956
2957 self.assertEqual(log_result, log_result_cpu)
2958
2959 helper((2, 8, 4, 5))
2960
2961 def test_log_ten(self):
2962 def helper(shape):
2963 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2964 x = cpu_x.detach().clone().to('mps')
2965
2966 log_ten_result = torch.log10(x)
2967 log_ten_result_cpu = torch.log10(cpu_x)
2968
2969 self.assertEqual(log_ten_result, log_ten_result_cpu)
2970
2971 helper((2, 8, 4, 5))
2972
2973 def test_log_two(self):
2974 def helper(shape):
2975 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2976 x = cpu_x.detach().clone().to('mps')
2977
2978 log_two_result = torch.log2(x)
2979 log_two_result_cpu = torch.log2(cpu_x)
2980
2981 self.assertEqual(log_two_result, log_two_result_cpu)
2982
2983 helper((2, 8, 4, 5))
2984
2985 def test_log1p(self):
2986 def helper(shape):
2987 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
2988 x = cpu_x.detach().clone().to('mps')
2989
2990 log_result = torch.log1p(x)
2991 log_result_cpu = torch.log1p(cpu_x)
2992
2993 self.assertEqual(log_result, log_result_cpu)
2994
2995 helper((2, 8, 4, 5))
2996
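# logaddexp(x, y) = log(exp(x) + exp(y)) evaluated in a numerically stable
# way; logaddexp2 below is the base-2 analogue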
2997 def test_logaddexp(self):
2998 def helper(shape):
2999 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3000 x = cpu_x.detach().clone().to('mps')
3001
3002 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3003 y = cpu_y.detach().clone().to('mps')
3004
3005 log_result = torch.logaddexp(x, y)
3006 log_result_cpu = torch.logaddexp(cpu_x, cpu_y)
3007
3008 self.assertEqual(log_result, log_result_cpu)
3009
3010 helper((2, 8, 4, 5))
3011
3012 def test_logaddexp2(self):
3013 def helper(shape):
3014 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3015 x = cpu_x.detach().clone().to('mps')
3016
3017 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3018 y = cpu_y.detach().clone().to('mps')
3019
3020 log_result = torch.logaddexp2(x, y)
3021 log_result_cpu = torch.logaddexp2(cpu_x, cpu_y)
3022
3023 self.assertEqual(log_result, log_result_cpu)
3024
3025 helper((2, 8, 4, 5))
3026
3027 # Test concat forward
3028 def test_cat2(self):
3029
3030 def helper1(shape_x, shape_y, shape_z, shape_w):
3031 cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
3032 x = cpu_x.detach().clone().to('mps')
3033
3034 cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
3035 y = cpu_y.detach().clone().to('mps')
3036
3037 cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
3038 z = cpu_z.detach().clone().to('mps')
3039
3040 cpu_w = torch.randn(shape_w, device='cpu', dtype=torch.float, requires_grad=False)
3041 w = cpu_w.detach().clone().to('mps')
3042
3043 cat = torch.cat([x, y, z, w], dim=1)
3044 cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z, cpu_w], dim=1)
3045
3046 self.assertEqual(cat, cat_cpu)
3047
3048 def helper(shape_x, shape_y, shape_z):
3049 cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
3050 x = cpu_x.detach().clone().to('mps')
3051
3052 cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
3053 y = cpu_y.detach().clone().to('mps')
3054
3055 cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
3056 z = cpu_z.detach().clone().to('mps')
3057
3058 cat = torch.cat([x, y, z], dim=1)
3059 cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)
3060
3061 self.assertEqual(cat, cat_cpu)
3062
3063 helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
3064 helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
3065 # Empty test - Currently failing! Empty tensor not handled!
3066 # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
3067
3068 # Test isnan
3069 def test_isnan(self):
3070 def helper(shape):
3071 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3072 nan_index = [random.randrange(0, shape[0])]
3073 # make a randomly selected row NaN
3074 cpu_x.index_put_(indices=[torch.tensor(nan_index)], values=torch.tensor(float('nan')))
3075 x = cpu_x.detach().clone().to('mps')
3076
3077 isnan_result = torch.isnan(x)
3078 isnan_result_cpu = torch.isnan(cpu_x)
3079
3080 self.assertEqual(isnan_result, isnan_result_cpu)
3081
3082 helper((8, 2, 4, 5))
3083
3084 # Test reciprocal
3085 def test_reciprocal(self):
3086 def helper(shape):
3087 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3088 x = cpu_x.detach().clone().to('mps').requires_grad_()
3089
3090 reciprocal_result = torch.reciprocal(x)
3091 reciprocal_result_cpu = torch.reciprocal(cpu_x)
3092
3093 cpu_grad = torch.ones_like(reciprocal_result_cpu)
3094 grad = cpu_grad.to('mps')
3095
3096 reciprocal_result.backward(gradient=grad)
3097 reciprocal_result_cpu.backward(gradient=cpu_grad)
3098
3099 self.assertEqual(reciprocal_result, reciprocal_result_cpu)
3100 self.assertEqual(x.grad, cpu_x.grad)
3101
3102 helper((2, 8, 4, 5))
3103
3104 # Test sqrt
3105 def test_sqrt(self):
3106 def helper(shape):
3107 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3108 x = cpu_x.detach().clone().to('mps').requires_grad_()
3109
3110 sqrt_result = torch.sqrt(x)
3111 sqrt_result_cpu = torch.sqrt(cpu_x)
3112
3113 cpu_grad = torch.ones_like(sqrt_result_cpu)
3114 grad = cpu_grad.to('mps')
3115
3116 sqrt_result.backward(gradient=grad)
3117 sqrt_result_cpu.backward(gradient=cpu_grad)
3118
3119 self.assertEqual(sqrt_result, sqrt_result_cpu)
3120 self.assertEqual(x.grad, cpu_x.grad)
3121
3122 helper((2, 8, 4, 5))
3123
3124 # Test selu, elu, celu
3125 def test_elu(self):
3126 def helper(shape, alpha=1.0):
3127 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3128 x = cpu_x.detach().clone().to('mps').requires_grad_()
3129
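# CELU(x) uses alpha * (exp(x / alpha) - 1) on the negative side, and SELU is
# ELU with fixed alpha ~ 1.6733 times a fixed scale ~ 1.0507, so one loop
# covers all three activations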
3130 for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]:
3131 elu_result = activation_func(x)
3132 elu_result_cpu = activation_func(cpu_x)
3133
3134 cpu_grad = torch.randn(elu_result_cpu.shape)
3135 grad = cpu_grad.to('mps')
3136
3137 elu_result.backward(gradient=grad)
3138 elu_result_cpu.backward(gradient=cpu_grad)
3139
3140 self.assertEqual(elu_result, elu_result_cpu)
3141 self.assertEqual(x.grad, cpu_x.grad)
3142
3143 # Test empty shape too
3144 for shape in [[], (2, 3), (2, 8, 4, 5)]:
3145 for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
3146 helper(shape, alpha)
3147
3148 # Test silu
3149 def test_silu(self):
3150 def helper(shape):
3151 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3152 x = cpu_x.detach().clone().to('mps').requires_grad_()
3153
3154 silu_result = torch.nn.SiLU()(x)
3155 silu_result_cpu = torch.nn.SiLU()(cpu_x)
3156
3157 cpu_grad = torch.randn(silu_result_cpu.shape)
3158 grad = cpu_grad.to('mps')
3159
3160 silu_result.backward(gradient=grad)
3161 silu_result_cpu.backward(gradient=cpu_grad)
3162
3163 self.assertEqual(silu_result, silu_result_cpu)
3164 self.assertEqual(x.grad, cpu_x.grad)
3165
3166 # Test empty shape too
3167 for shape in [[], (2, 3), (2, 8, 4, 5)]:
3168 helper(shape)
3169
3170 # Test adaptive avg pool2d - when the input size is a multiple of output size
3171 # Not testing for channels last right now
3172 def test_adaptive_avg_pool2d_simple(self):
3173 def helper(input_shape, out_shape, channels_last):
3174 cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
3175 if(channels_last):
3176 cpu_x = cpu_x.to(memory_format=torch.channels_last)
3177 cpu_x.retain_grad()
3178 x = cpu_x.detach().clone().to('mps').requires_grad_()
3179
3180 avg_result = torch.nn.AdaptiveAvgPool2d(out_shape)(x)
3181 avg_result_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x)
3182
3183 cpu_grad = torch.randn(avg_result_cpu.shape)
3184 grad = cpu_grad.to('mps')
3185
3186 avg_result.backward(gradient=grad)
3187 avg_result_cpu.backward(gradient=cpu_grad)
3188
3189 self.assertEqual(avg_result, avg_result_cpu)
3190 self.assertEqual(x.grad, cpu_x.grad)
3191
3192 helper((2, 2, 4, 4), (2, 2), False)
3193 helper((2, 2, 9, 9), (3, 3), False)
3194 helper((2, 2, 9, 9), (9, 9), False)
3195 helper((2, 2, 16, 16), (2, 2), False)
3196 helper((2, 2, 16, 16), (2, 16), False)
3197
3198 helper((2, 16, 16), (4, 4), False)
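# when the input size is an integer multiple k of the output size (as in every
# case above), AdaptiveAvgPool2d reduces to AvgPool2d(kernel_size=k, stride=k)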
3199
Kulin Seth2e32d5f2022-05-27 11:59:07 +00003200 # Test adaptive max pool2d - when the input size is a multiple of output size
3201 # Not testing for channels last right now
3202 def test_adaptive_max_pool2d_simple(self):
3203 def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
3204 cpu_x = None
3205 if(dtype in [torch.float16, torch.float32]):
3206 cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
3207 else:
3208 cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
3209 if(channels_last):
3210 cpu_x = cpu_x.to(memory_format=torch.channels_last)
3211 cpu_x.retain_grad()
3212 x = cpu_x.detach().clone().to('mps').requires_grad_()
3213
3214 max_result, max_indices = None, None
3215 max_result_cpu, max_indices_cpu = None, None
3216
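# with return_indices=True the pooling also returns argmax positions
# (flattened over each H*W plane), compared against the CPU result below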
3217 if(return_indices):
3218 max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
3219 max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
3220 else:
3221 max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
3222 max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
3223
3224 cpu_grad = torch.randn(max_result_cpu.shape)
3225 grad = cpu_grad.to('mps')
3226
3227 max_result.backward(gradient=grad)
3228 max_result_cpu.backward(gradient=cpu_grad)
3229
3230 self.assertEqual(max_result, max_result_cpu)
3231 if(return_indices):
3232 self.assertEqual(max_indices, max_indices_cpu)
3233 self.assertEqual(x.grad, cpu_x.grad)
3234
3235 for dtype in [torch.float32]:
3236 for return_indices in [False, True]:
3237 helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
3238 helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
3239 helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
3240 helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
3241 helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
3242 helper((2, 16, 16), (4, 4), return_indices, dtype)
3243
Kulin Sethe011a8e2022-05-13 18:28:53 +00003244 def test_gelu_simple(self):
3245 def helper(shape):
3246 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3247 x = cpu_x.detach().clone().to('mps').requires_grad_()
3248
3249 gelu_result = torch.nn.GELU()(x)
3250 gelu_result_cpu = torch.nn.GELU()(cpu_x)
3251
3252 cpu_grad = torch.ones_like(gelu_result_cpu)
3253 grad = cpu_grad.to('mps')
3254
3255 gelu_result.backward(gradient=grad)
3256 gelu_result_cpu.backward(gradient=cpu_grad)
3257
3258 self.assertEqual(gelu_result, gelu_result_cpu)
3259 self.assertEqual(x.grad, cpu_x.grad)
3260
3261 # Test empty shape too
3262 for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
3263 helper(shape)
3264
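# GELU(x) = x * Phi(x), where Phi is the standard normal CDF; the reference
# below evaluates that definition directly and compares it against F.gelu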
Kulin Seth3d833212022-05-20 03:18:09 +00003265 def test_gelu(self):
3266 def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
3267 numpy_dtype = {
3268 torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double
3269 }[dtype]
3270 devices = ['cpu']
3271 devices += ['mps']
3272
3273 def _gelu_ref(X):
# scipy is a test-only dependency, so it is imported locally
from scipy import stats
3274 return X * stats.norm.cdf(X)
3275
3276 for d in devices:
3277 X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)
if not contiguous:
X = X[:, ::2]  # strided view to exercise non-contiguous inputs
3278 res = F.gelu(X)
3279 ref = _gelu_ref(X.to(numpy_dtype).cpu().detach().numpy())
3280 self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)
3281
Alban Desmaisonbde246f2022-05-30 10:36:31 -04003282 for n in [1, 5, 10]:
3283 for m in [1, 5, 10]:
Kulin Seth3d833212022-05-20 03:18:09 +00003284 _test_gelu(n, m, torch.float32, True)
3285 _test_gelu(n, m, torch.float32, False)
3286
3287 # Test multi threaded
3288 num_threads = torch.get_num_threads()
3289 torch.set_num_threads(4)
3290 try:
3291 _test_gelu(32, 32, torch.float32, False)
3292 finally:
3293 torch.set_num_threads(num_threads)
3294
Kulin Sethe011a8e2022-05-13 18:28:53 +00003295 # Test hardtanh
3296 def test_hardtanh(self):
3297 def helper(shape, min_val, max_val, inplace=False):
3298 cpu_x = None
3299 x = None
3300
3301 if(not inplace):
3302 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3303 x = cpu_x.detach().clone().to('mps').requires_grad_()
3304 else:
3305 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3306 x = cpu_x.detach().clone().to('mps')
3307
3308 hardtanh_result = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x)
3309 hardtanh_result_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x)
3310
3311 self.assertEqual(hardtanh_result, hardtanh_result_cpu)
3312
3313 if(not inplace):
3314 cpu_grad = torch.randn(hardtanh_result_cpu.shape)
3315 grad = cpu_grad.to('mps')
3316 hardtanh_result.backward(gradient=grad)
3317 hardtanh_result_cpu.backward(gradient=cpu_grad)
3318 self.assertEqual(x.grad, cpu_x.grad)
3319
3320 # Test empty shape too
3321 for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
3322 for min_val, max_val in zip([-1, -2, 3], [1, -1, 4]):
3323 helper(shape, min_val, max_val)
3324 helper(shape, min_val, max_val, inplace=True)
3325
Kulin Seth3d833212022-05-20 03:18:09 +00003326 def test_transpose_2D(self):
3327 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
3329 cpu_x = torch.tensor(values, device='cpu')
3330 mps_x = torch.tensor(values, device='mps')
3332
3333 cpu_transpose = torch.transpose(cpu_x, 0, 1)
3334 mps_transpose = torch.transpose(mps_x, 0, 1)
3335 self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))
3336
3337 def test_transpose_3D(self):
3338 values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
3339 cpu_x = torch.tensor(values, device='cpu')
3340 mps_x = torch.tensor(values, device='mps')
3341
3342 cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
3343 mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
3344 self.assertEqual(cpu_transpose1, mps_transpose1)
3345
3346 cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
3347 mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
3348 self.assertEqual(cpu_transpose2, mps_transpose2)
3349
3350 cpu_transpose3 = torch.transpose(cpu_x, 1, 2)
3351 mps_transpose3 = torch.transpose(mps_x, 1, 2).to('cpu')
3352 self.assertEqual(cpu_transpose3, mps_transpose3)
3353
3355 def test_transpose_4D(self):
3356 values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
3357 [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]]
3358 cpu_x = torch.tensor(values, device='cpu')
3359 mps_x = torch.tensor(values, device='mps')
3360
3361 cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
3362 mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
3363 self.assertEqual(cpu_transpose1, mps_transpose1)
3364
3365 cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
3366 mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
3367 self.assertEqual(cpu_transpose2, mps_transpose2)
3368
3369 cpu_transpose3 = torch.transpose(cpu_x, 0, 3)
3370 mps_transpose3 = torch.transpose(mps_x, 0, 3).to('cpu')
3371 self.assertEqual(cpu_transpose3, mps_transpose3)
3372
3373 cpu_transpose4 = torch.transpose(cpu_x, 3, 1)
3374 mps_transpose4 = torch.transpose(mps_x, 3, 1).to('cpu')
3375 self.assertEqual(cpu_transpose4, mps_transpose4)
3376
3377 cpu_transpose5 = torch.transpose(cpu_x, 3, 2)
3378 mps_transpose5 = torch.transpose(mps_x, 3, 2).to('cpu')
3379 self.assertEqual(cpu_transpose5, mps_transpose5)
3380
3381 cpu_transpose6 = torch.transpose(cpu_x, 1, 2)
3382 mps_transpose6 = torch.transpose(mps_x, 1, 2).to('cpu')
3383 self.assertEqual(cpu_transpose6, mps_transpose6)
3384
Kulin Sethe011a8e2022-05-13 18:28:53 +00003385 # Test sign
3386 def test_sign(self):
3387 def helper(shape):
3388 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3389 x = cpu_x.detach().clone().to('mps').requires_grad_()
3390
3391 sign_result = torch.sign(x)
3392 sign_result_cpu = torch.sign(cpu_x)
3393
3394 cpu_grad = torch.ones_like(sign_result_cpu)
3395 grad = cpu_grad.to('mps')
3396
3397 sign_result.backward(gradient=grad)
3398 sign_result_cpu.backward(gradient=cpu_grad)
3399
3400 self.assertEqual(sign_result, sign_result_cpu)
3401
3402 helper((2, 8, 4, 5))
3403
3404 # Test neg
3405 def test_neg(self):
3406 def helper(shape):
3407 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3408 x = cpu_x.detach().clone().to('mps').requires_grad_()
3409
3410 neg_result = torch.neg(x)
3411 neg_result_cpu = torch.neg(cpu_x)
3412
3413 cpu_grad = torch.ones_like(neg_result_cpu)
3414 grad = cpu_grad.to('mps')
3415
3416 neg_result.backward(gradient=grad)
3417 neg_result_cpu.backward(gradient=cpu_grad)
3418
3419 self.assertEqual(neg_result, neg_result_cpu)
3420
3421 helper((2, 8, 4, 5))
3422
3423 # Test index select
3424 def test_index_select(self):
3425 def helper(shape, dim, index, idx_dtype=torch.int32):
3426 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3427 x = cpu_x.detach().clone().to('mps')
3428
3429 cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
3430 idx = cpu_idx.detach().clone().to('mps')
3431
Kulin Sethe011a8e2022-05-13 18:28:53 +00003432 idx_result = torch.index_select(x, dim=dim, index=idx)
3433 idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)
3434
3435 self.assertEqual(idx_result, idx_result_cpu)
3436
3437 helper((2, 8, 4, 5), 0, [1])
3438 helper((8, 8, 4, 5), 0, [0, 3, 2, 7, 6])
3439 helper((2, 8, 4, 5), 1, [0, 3, 2, 7, 6])
3440 helper((2, 8, 4, 5), 2, [3, 0, 1])
3441 helper((2, 8, 4, 5), 3, [2, 3, 0])
3442 helper((2, 3, 3), -1, [1, 2])
3443
3444 def test_embedding_dense_backward(self):
3445 def helper(n, d, m):
3446 embeddingMPS = nn.Embedding(n, d, max_norm=1.0, device='mps')
3447 W_MPS = torch.randn((m, d), requires_grad=True, device='mps')
3448 idx_MPS = torch.tensor([0, 1, 2]).to('mps')
3449 a_MPS = embeddingMPS.weight.clone() @ W_MPS.t() # weight must be cloned for this to be differentiable
3450 a_MPS.retain_grad()
3451 b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t() # modifies weight in-place
3452 b_MPS.retain_grad()
3453 out_MPS = (a_MPS.unsqueeze(0) + b_MPS.unsqueeze(1))
3454 loss_MPS = out_MPS.sigmoid().prod()
3455 loss_MPS.backward()
3456
3457 embeddingCPU = nn.Embedding(n, d, max_norm=1.0, scale_grad_by_freq=True)
3458 W_CPU = W_MPS.to('cpu')
3459 idx_CPU = torch.tensor([0, 1, 2])
3460 a_CPU = embeddingCPU.weight.clone() @ W_CPU.t() # weight must be cloned for this to be differentiable
3461 a_CPU.retain_grad()
3462 b_CPU = embeddingCPU(idx_CPU) @ W_CPU.t() # modifies weight in-place
3463 b_CPU.retain_grad()
3464 out_CPU = (a_CPU.unsqueeze(0) + b_CPU.unsqueeze(1))
3465 loss_CPU = out_CPU.sigmoid().prod()
3466 loss_CPU.backward()
3467
3468 self.assertEqual(b_CPU.grad, b_MPS.grad)
3469 self.assertEqual(a_CPU.grad, a_MPS.grad)
3470
3471 helper(3, 5, 7)
3472
3473 # Test pytorch gather
3474 def test_gather(self):
3475 def helper(shape, dim, idx_shape, idx_dtype=torch.int64):
3476 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3477 x = cpu_x.detach().clone().to('mps').requires_grad_()
3478
3479 # Indices must lie within the size of the axis along which the gather is done
3480 idx_np = np.random.randint(0, shape[dim], idx_shape)
3481
3482 cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
3483 idx = cpu_idx.detach().clone().to('mps')
3484
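# gather semantics along dim=d:
#   out[i_0, .., i_d, ..] = x[i_0, .., idx[i_0, .., i_d, ..], ..]
# idx may be smaller than x in the non-gathered dimensions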
3485 gather_result = torch.gather(x, dim=dim, index=idx)
3486 gather_result_cpu = torch.gather(cpu_x, dim=dim, index=cpu_idx)
3487
3488 cpu_grad = torch.randn(idx_shape, device='cpu', dtype=torch.float)
3489 grad = cpu_grad.to('mps')
3490 gather_result.backward(gradient=grad)
3491 gather_result_cpu.backward(gradient=cpu_grad)
3492
3493 self.assertEqual(gather_result, gather_result_cpu)
3494 self.assertEqual(cpu_x.grad, x.grad)
3495
3496 helper((6, 3, 3), 0, (3, 3, 3))
3497 helper((2, 3, 3, 3), 0, (10, 3, 3, 3))
3498 helper((2, 8, 4, 5), 0, (10, 8, 4, 5))
3499 helper((2, 8, 4, 5), 0, (10, 6, 3, 2))
3500 helper((8, 8, 4, 5), 0, (6, 8, 4, 5))
3501 helper((8, 8, 4, 5), 0, (6, 7, 2, 3))
3502 helper((2, 8, 4, 5), 1, (2, 5, 3, 4))
3503 helper((2, 8, 4, 5), 2, (1, 8, 10, 3))
3504 helper((2, 8, 4, 5), 3, (2, 5, 3, 12))
3505
3506 # Test pytorch scatter_add and scatter
3507 def test_scatter_add(self):
3508 def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
3509 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3510 x = cpu_x.detach().clone().to('mps').requires_grad_()
3511
3512 cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
3513 src = cpu_src.detach().clone().to('mps').requires_grad_()
3514
3515 # Indices must lie within the size of x along the scatter dimension
3516 idx_np = None
3517 if(do_add):
3518 idx_np = np.random.randint(0, shape[dim], idx_shape)
3519 else:
3520 idx_np = np.array([[0, 1, 2],
3521 [1, 2, 3],
3522 [2, 3, 4],
3523 [3, 4, 5],
3524 [4, 5, 6]])
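# indices are unique within each column on purpose: plain scatter leaves the
# result unspecified when several sources map to the same destination slot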
3525
3526 cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
3527 idx = cpu_idx.detach().clone().to('mps')
3528
3529 scatter_result = None
3530 scatter_result_cpu = None
3531
3532 if(do_add):
3533 scatter_result = torch.scatter_add(x, dim=dim, index=idx, src=src)
3534 scatter_result_cpu = torch.scatter_add(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)
3535 else:
3536 scatter_result = torch.scatter(x, dim=dim, index=idx, src=src)
3537 scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)
3538
3539 cpu_grad = None
3540 grad = None
3541
3542 if(idx_shape == src_shape):
3543 cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
3544 grad = cpu_grad.to('mps')
3545 scatter_result.backward(gradient=grad)
3546 scatter_result_cpu.backward(gradient=cpu_grad)
3547
3548 self.assertEqual(scatter_result, scatter_result_cpu)
3549 if(idx_shape == src_shape):
3550 self.assertEqual(cpu_x.grad, x.grad)
3551 self.assertEqual(cpu_src.grad, src.grad)
3552
3553 helper((2, 3), 0, (5, 3), (5, 3))
3554 helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
3555 helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
3556 helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2))
3557 helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2))
3558 helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5))
3559
3560 helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5))
3561 helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2))
3562 helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3))
3563 helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3))
3564
3565 helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8))
3566 helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6))
3567 helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6))
3568
3569 # Test scatter src
3570 helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
3571 helper((10, 3), 0, (5, 3), (5, 8), do_add=False)
3572
3573 # Test pytorch scatter_reduce
3574 def test_scatter_reduce(self):
3575 def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
3576 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3577 x = cpu_x.detach().clone().to('mps').requires_grad_()
3578
3579 cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
3580 src = cpu_src.detach().clone().to('mps').requires_grad_()
3581
3582 # Indices must lie within the size of x along the scatter dimension
3583 idx_np = np.random.randint(0, shape[dim], idx_shape)
3584
3585 cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
3586 idx = cpu_idx.detach().clone().to('mps')
3587
3588 scatter_result = torch.scatter(x, dim=dim, index=idx, src=src, reduce=reduce_str)
3589 scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src, reduce=reduce_str)
3590
3591 self.assertEqual(scatter_result, scatter_result_cpu)
3592
3593 # for reduce in ["sum", "prod", "amax", "amin"]:
3594 for reduce in ["add", "multiply"]:
3595 helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce)
3596 helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce)
3597 helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce)
3598 helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce)
3599 helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce)
3600 helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce)
3601
3602 helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce)
3603 helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce)
3604 helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce)
3605 helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce)
3606
3607 helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce)
3608 helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce)
3609 helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce)
3610
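# torch.is_nonzero is only defined for single-element tensors (it raises for
# anything else), so scalar-like MPS tensors are sufficient coverage here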
3611 def test_is_nonzero(self):
3612 self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))
3613 self.assertTrue(torch.is_nonzero(torch.tensor([1.5]).to('mps')))
3614 self.assertFalse(torch.is_nonzero(torch.tensor([False]).to('mps')))
3615 self.assertTrue(torch.is_nonzero(torch.tensor([3]).to('mps')))
3616
3617 # Test triu
3618 def test_triu(self):
3619 def helper(shape, diag=0):
3620 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3621 x = cpu_x.detach().clone().to('mps').requires_grad_()
3622
3623 triu_result = torch.triu(x, diag)
3624 triu_result_cpu = torch.triu(cpu_x, diag)
3625
3626 cpu_grad = torch.randn(triu_result_cpu.shape)
3627 grad = cpu_grad.to('mps')
3628
3629 triu_result.backward(gradient=grad)
3630 triu_result_cpu.backward(gradient=cpu_grad)
3631
3632 self.assertEqual(triu_result, triu_result_cpu)
3633 self.assertEqual(x.grad, cpu_x.grad)
3634
3635 helper((2, 8, 4, 5))
3636 helper((2, 8, 4, 5), diag=1)
3637 helper((2, 8, 4, 5), diag=2)
3638 helper((2, 8, 4, 5), diag=3)
3639 helper((2, 8, 4, 5), diag=-1)
3640 helper((2, 8, 4, 5), diag=-2)
3641 helper((2, 8, 4, 5), diag=-3)
3642
3643 # Test tril
3644 def test_tril(self):
3645 def helper(shape, diag=0):
3646 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3647 x = cpu_x.detach().clone().to('mps').requires_grad_()
3648
3649 tril_result = torch.tril(x, diag)
3650 tril_result_cpu = torch.tril(cpu_x, diag)
3651
3652 cpu_grad = torch.randn(tril_result_cpu.shape)
3653 grad = cpu_grad.to('mps')
3654
3655 tril_result.backward(gradient=grad)
3656 tril_result_cpu.backward(gradient=cpu_grad)
3657
3658 self.assertEqual(tril_result, tril_result_cpu)
3659 self.assertEqual(x.grad, cpu_x.grad)
3660
3661 helper((2, 8, 4, 5))
3662 helper((2, 8, 4, 5), diag=1)
3663 helper((2, 8, 4, 5), diag=2)
3664 helper((2, 8, 4, 5), diag=3)
3665 helper((2, 8, 4, 5), diag=-1)
3666 helper((2, 8, 4, 5), diag=-2)
3667 helper((2, 8, 4, 5), diag=-3)
3668
Kulin Seth8552acb2022-05-27 17:07:02 +00003669 # test eye
3670 def test_eye(self):
3671 def helper(n, m, dtype):
3672 cpu_result = None
3673 result = None
3674
3675 if(n == m):
3676 cpu_result = torch.eye(n, dtype=dtype, device='cpu')
3677 result = torch.eye(n, dtype=dtype, device='mps')
3678 else:
3679 cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
3680 result = torch.eye(n, m, dtype=dtype, device='mps')
3681
3682 self.assertEqual(result, cpu_result)
3683
3684 for dtype in [torch.float32, torch.int32, torch.int64]:
3685 helper(2, 2, dtype)
3686 helper(2, 3, dtype)
3687 helper(0, 2, dtype)
3688 helper(0, 0, dtype)
3689 helper(3, 8, dtype)
3690 helper(8, 3, dtype)
3691
Kulin Sethe011a8e2022-05-13 18:28:53 +00003692 # Test diag
3693 def test_diag(self):
3694 def helper(shape, diag=0):
3695 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3696 x = cpu_x.detach().clone().to('mps').requires_grad_()
3697
3698 diag_result = torch.diag(x, diag)
3699 diag_result_cpu = torch.diag(cpu_x, diag)
3700
3701 # cpu_grad = torch.randn(diag_result_cpu.shape)
3702 # grad = cpu_grad.to('mps')
3703
3704 # diag_result.backward(gradient=grad)
3705 # diag_result_cpu.backward(gradient=cpu_grad)
3706
3707 self.assertEqual(diag_result, diag_result_cpu)
3708 # self.assertEqual(x.grad, cpu_x.grad)
3709
3710 for shape in [(5, 5), (5, 6), (6, 5), (5,), (6,)]:
3711 for diag in [0, 1, 2, 3, 4, -1, -2, -3, -4]:
3712 helper(shape, diag=diag)
3713
3714 # Test softmax
3715 def test_softmax(self):
3716 def helper(shape, dim, channels_last=False):
3717 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3718 if(channels_last):
3719 cpu_x = cpu_x.to(memory_format=torch.channels_last)
3720 cpu_x.retain_grad()
3721 x = cpu_x.detach().clone().to('mps').requires_grad_()
3722
3723 softmax_result = torch.nn.functional.softmax(x, dim=dim)
3724 softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)
3725
3726 # Currently NOT testing backward for channels last
3727 cpu_grad = None
3728 grad = None
3729
3730 if(not channels_last):
3731 cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
3732 grad = cpu_grad.to('mps')
3733
3734 softmax_result.backward(gradient=grad)
3735 softmax_result_cpu.backward(gradient=cpu_grad)
3736
3737 self.assertEqual(softmax_result, softmax_result_cpu)
3738 if(not channels_last):
3739 self.assertEqual(x.grad, cpu_x.grad)
3740
3741 def helper2(dim):
3742 cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
3743 x = cpu_x.detach().clone().to('mps').requires_grad_()
3744
3745 softmax_result = torch.nn.functional.softmax(x, dim=dim)
3746 softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)
3747
3748 cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float)
3749 grad = cpu_grad.to('mps')
3750
3751 softmax_result.backward(gradient=grad)
3752 softmax_result_cpu.backward(gradient=cpu_grad)
3753
3754 self.assertEqual(softmax_result, softmax_result_cpu)
3755 self.assertEqual(x.grad, cpu_x.grad)
3756
3757 helper2(0)
3758
Kulin Seth3d833212022-05-20 03:18:09 +00003759 for channels_last in [False]:
Kulin Sethe011a8e2022-05-13 18:28:53 +00003760 for shape in [(2, 4, 8, 5), (3, 4, 6, 7, 2)]:
3761 if(len(shape) != 4 and channels_last):
3762 continue
3763 for dim in [0, 1, 2, 3, -1, -2, -3]:
3764 helper(shape, dim, channels_last)
3765
3766 # Test sub
3767 def test_sub(self):
3768 def helper(shape, alpha):
3769 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3770 x = cpu_x.detach().clone().to('mps')
3771
3772 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3773 y = cpu_y.detach().clone().to('mps')
3774
3775 cpu_out = torch.sub(cpu_x, cpu_y, alpha=alpha)
3776 out = torch.sub(x, y, alpha=alpha)
3777
3778 self.assertEqual(out, cpu_out)
3779
3780 helper((2, 8, 4, 5), 0.1)
3781 helper((2, 8, 3, 5), 0.1)
3782 helper((2, 8, 3, 5), 0.2)
3783
3784 # Test where
3785 def test_where(self):
3786 def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):
3787
3788 cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
3789 cond = cpu_cond.detach().clone().to('mps')
3790
3791 cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
3792 x = cpu_x.detach().clone().to('mps').requires_grad_()
3793
3794 cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
3795 y = cpu_y.detach().clone().to('mps').requires_grad_()
3796
3797 cpu_out = torch.where(cpu_cond, cpu_x, cpu_y)
3798 out = torch.where(cond, x, y)
3799
3800 cpu_grad = torch.randn(cpu_out.shape)
3801 grad = cpu_grad.to('mps')
3802
3803 cpu_out.backward(gradient=cpu_grad)
3804 out.backward(gradient=grad)
3805
3806 self.assertEqual(out, cpu_out)
3807 self.assertEqual(x.grad, cpu_x.grad)
3808 self.assertEqual(y.grad, cpu_y.grad)
3809
3810 for shape in ([(0, 3), [], (2, 3), (9,)]):
3811 helper(shape, shape, shape)
3812
3813 helper((2, 3, 1), (2, 3, 4), (2, 1, 4))
3814 helper((2, 1, 1), (2, 3, 4), (1, 3, 4))
3815 helper((1, 1, 1), (1, 1, 4), (2, 3, 1))
3816 helper([], (1, 1, 4), (2, 3, 1))
3817 helper([], (2, 3, 4), [])
3818
3819 # Test normal
3820 def test_normal(self):
3821 def helper(shape, mean=0.0, std=1.0):
# smoke test: torch.normal output is random, so just exercise each overload
# (scalar/tensor mean and std) and rely on the calls not crashing
3824
3825 mps_out = torch.normal(mean, std, shape, device='mps')
3826
Kulin Sethe011a8e2022-05-13 18:28:53 +00003827 mean_array = np.ones(shape)
3828 mean_array *= mean
3829 cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False)
3830 mean_tensor = cpu_mean_tensor.detach().clone().to('mps')
3831
3832 std_array = np.ones(shape)
3833 std_array *= std
3834 cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
3835 std_tensor = cpu_std_tensor.detach().clone().to('mps')
3836
3837 mps_out = torch.zeros(shape, device='mps')
3838 torch.normal(mean_tensor, std, out=mps_out)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003839
3840 mps_out = torch.zeros(shape, device='mps')
3841 torch.normal(mean, std_tensor, out=mps_out)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003842
3843 mps_out = torch.zeros(shape, device='mps')
3844 torch.normal(mean_tensor, std_tensor, out=mps_out)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003845
3846 helper((2, 3, 4, 5, 6))
3847 helper((100, 100), 2.5, 1.2)
3848
3849 def test_bernoulli(self):
3850 def helper(shape, prob=0.5):
3853
3854 prob_array = np.ones(shape)
3855 prob_array *= prob
3856 cpu_prob_tensor = torch.tensor(prob_array, device='cpu', dtype=torch.float, requires_grad=False)
3857 prob_tensor = cpu_prob_tensor.detach().clone().to('mps')
3858
3859 mps_out = torch.bernoulli(prob_tensor)
Alban Desmaison02551a02022-05-28 12:39:10 -04003860 # We can't check reliably the mean and std.
3861 # Just make sure we don't return constant values
3862 self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
3863 self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003864
3865 mps_out = torch.zeros(shape, device='mps')
3866 mps_out = torch.bernoulli(mps_out, prob)
3867
Alban Desmaison02551a02022-05-28 12:39:10 -04003868 self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
3869 self.assertNotEqual(mps_out.to('cpu').std(), 0.)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003870
3871 helper((100, 100), 0.50)
3872 helper((100, 100), 0.76)
3873 helper((100, 100), 0.23)
3874
3875 # Test random_.to and random_.from (exercised via torch.randint)
3876 def test_random(self):
3877 def helper(shape, low, high, dtype=torch.int32):
3878
Kulin Sethe011a8e2022-05-13 18:28:53 +00003879 mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')
3880
Alban Desmaison02551a02022-05-28 12:39:10 -04003881 # We can't check reliably the mean and std.
3882 # Just make sure we don't return constant values
3883 self.assertNotEqual(mps_out.to('cpu').float().mean(), 0.)
3884 self.assertNotEqual(mps_out.to('cpu').float().std(), 0.)
Kulin Sethe011a8e2022-05-13 18:28:53 +00003885
3886 helper([100, 100], 0, 10)
3887 helper([100, 100], 23, 89)
3888 helper([100, 100], 23, 89, dtype=torch.float32)
3889 helper([100, 100], 23, 89, dtype=torch.int64)
3890 helper([100, 100], 0, 2, dtype=torch.bool)
3891
3892 # Test add
3893 def test_add_binary_op(self):
3894 def helper(shape, alpha):
3895 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3896 x = cpu_x.detach().clone().to('mps')
3897
3898 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3899 y = cpu_y.detach().clone().to('mps')
3900
3901 cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha)
3902 out = torch.add(x, y, alpha=alpha)
3903
3904 self.assertEqual(out, cpu_out)
3905
3906 helper((2, 8, 4, 5), 0.1)
3907 helper((2, 8, 3, 5), 0.1)
3908 helper((2, 8, 3, 5), 0.2)
3909
3910 # Test add with scalar tensors
3911 def test_add_scalars(self):
3912 def helper(alpha=1.0):
3913 cpu_x = torch.tensor(2.3, device='cpu', dtype=torch.float, requires_grad=False)
3914 x = cpu_x.detach().clone().to('mps')
3915
3916 cpu_y = torch.tensor(3.4, device='cpu', dtype=torch.float, requires_grad=False)
3917 y = cpu_y.detach().clone().to('mps')
3918
3919 cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha)
3920 out = torch.add(x, y, alpha=alpha)
3921
Kulin Sethe011a8e2022-05-13 18:28:53 +00003922 self.assertEqual(out, cpu_out)
3923
3924 helper()
3925 helper(0.1)
3926 helper(0.2)
3927
3928 def test_atan2(self):
3929 def helper(shape):
3930 input_cpu = torch.randn(shape)
3931 input_mps = input_cpu.detach().clone().to("mps")
3932
3933 other_cpu = torch.randn(shape)
3934 other_mps = other_cpu.detach().clone().to("mps")
3935
3936 atan2_cpu = torch.atan2(input_cpu, other_cpu)
3937 atan2_mps = torch.atan2(input_mps, other_mps)
3938
3939 self.assertEqual(atan2_cpu, atan2_mps.to("cpu"))
3940
3941 helper(4)
3942 helper(10000)
3943 helper((10000, 40))
3944
3945
3946class TestNNMPS(NNTestCase):
3947
3948 def _create_basic_net(self):
3949 class Layer(nn.Module):
3950 def __init__(self):
3951 super(Layer, self).__init__()
3952 self.layer_dummy_param = Parameter(torch.empty(3, 5))
3953 self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))
3954
3955 class Net(nn.Module):
3956 def __init__(self):
3957 super(Net, self).__init__()
3958 self.l1 = Layer()
3959 self.dummy_param = Parameter(torch.empty(3, 5))
3960 self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))
3961
3962 l = Layer()
3963 n = Net()
3964 s = nn.Sequential(n, n)
3965
3966 return l, n, s
3967
3968 def test_requires_grad_(self):
3969 m = self._create_basic_net()[-1]
3970 assert len(list(m.buffers())) > 0, 'invalid test'
3971 assert all(not b.requires_grad for b in m.buffers()), 'invalid test'
3972 assert len(list(m.parameters())) > 0, 'invalid test'
3973 assert all(p.requires_grad for p in m.parameters()), 'invalid test'
3974 for requires_grad in (False, True):
3975 self.assertIs(m.requires_grad_(requires_grad), m)
3976 for p in m.parameters():
3977 self.assertEqual(p.requires_grad, requires_grad)
3978 for b in m.buffers():
3979 self.assertFalse(b.requires_grad)
3980
3981 def test_module_backcompat(self):
3982 from torch.serialization import SourceChangeWarning
3983 path = download_file('https://download.pytorch.org/test_data/linear.pt')
3984 with warnings.catch_warnings():
3985 warnings.simplefilter('ignore', SourceChangeWarning)
3986 m = torch.load(path)
3987 input = torch.randn(2, 3, dtype=torch.float)
3988 self.assertEqual(m(input).size(), (2, 5))
3989
3990 def test_conv_backcompat(self):
3991 from torch.serialization import SourceChangeWarning
3992 # This file was generated by running on PyTorch 1.0.1 on Python 2:
3993 #
3994 # import torch
3995 # from torch import nn
3996 # m = nn.Conv2d(1, 1, 1)
3997 # torch.save(m, 'legacy_conv2d.pt')
3998 #
3999 # NB: This Pickle also contains some Unicode data!
4000 path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
4001 with warnings.catch_warnings():
4002 warnings.simplefilter('ignore', SourceChangeWarning)
4003 m = torch.load(path, encoding='utf-8')
4004 input = torch.randn((1, 1, 1, 1), dtype=torch.float)
4005 self.assertEqual(m(input).size(), (1, 1, 1, 1))
4006
Kulin Seth017b0ae2022-05-31 02:09:03 +00004007 def test_conv_expand(self):
4008 device = 'mps'
4009 input_ = torch.rand(2, 3, 16, 16, device=device)
4010 kernel = torch.rand(1, 1, 3, 11, device=device)
4011 tmp_kernel = kernel.expand(-1, 3, -1, -1)
4012 output = F.conv2d(input_, tmp_kernel, groups=1, padding=0, stride=1)
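# expand() returns a stride-0 view without copying; the point of this test is
# that conv2d on MPS accepts such a non-contiguous weight without crashing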
4013
4014 # The test should not crash
4015 def test_permute(self):
4016 X = torch.randn(5, 5).to('mps')
4017 torch.log(X)
4018 X = X.permute(1, 0)
4019 torch.log(X)
4020
4021 # Printing of non-contiguous tensors should not crash
4022 def test_print_non_contiguous(self):
4023 print(torch.ones(100, 100, device='mps').nonzero())
4024 print(torch.ones(100, 100, device='mps').nonzero().contiguous())
4025
Kulin Sethe011a8e2022-05-13 18:28:53 +00004026 def test_zero_grad(self):
4027 i = torch.randn(2, 5, requires_grad=True)
4028 module = nn.Linear(5, 5)
4029 for p in module.parameters():
4030 p.requires_grad = False
4031 module.zero_grad()
4032
4033 module.weight.requires_grad = True
4034 module.zero_grad()
4035 self.assertIsNone(module.weight.grad) # uninitialized grad
4036
4037 module(i).sum().backward()
4038 self.assertIsNotNone(module.weight.grad)
4039 self.assertGreater(module.weight.grad.data.abs().sum(), 0)
4040 module.zero_grad()
4041 self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
4042
4043 module.bias.requires_grad = True
4044 module.zero_grad()
4045 self.assertIsNotNone(module.weight.grad)
4046 self.assertIsNone(module.bias.grad)
4047 module(i).sum().backward()
4048 self.assertIsNotNone(module.weight.grad)
4049 self.assertIsNotNone(module.bias.grad)
4050 self.assertGreater(module.weight.grad.data.abs().sum(), 0)
4051 self.assertGreater(module.bias.grad.data.abs().sum(), 0)
4052 module.zero_grad()
4053 self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
4054 self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
4055
4056 # Force set to None.
4057 module.zero_grad(set_to_none=True)
4058 self.assertIsNone(module.weight.grad)
4059
4060 def test_no_grad(self):
4061 for dtype in [torch.bfloat16, torch.float, torch.double]:
4062 module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
4063 input = torch.randn(1, 2, 10, 10).to(dtype)
4064 x = input
4065 y = input.clone()
4066
4067 output = module(x)
4068 self.assertTrue(output.requires_grad)
4069 output.backward(torch.ones(1, 5, 10, 10))
4070
4071 with torch.no_grad():
4072 output2 = module(y)
4073 self.assertFalse(output2.requires_grad)
4074 self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))
4075
4076 def test_invalid_conv1d(self):
4077 for dtype in [torch.bfloat16, torch.float, torch.double]:
4078 module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
4079 input = torch.randn(1, 3, 4).to(dtype)
4080 with self.assertRaisesRegex(RuntimeError,
4081 r'Calculated padded input size per channel: \(4\). ' +
4082 r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'):
4083 module(input)
4084
4085 # Negative stride check
4086 module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
4087 input = torch.randn(1, 3, 4).to(dtype)
4088 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
4089 module(input)
4090
4091 def test_conv2d_discontiguous_weight(self):
4092 # Test for https://github.com/pytorch/pytorch/issues/55781
4093 x = torch.ones(64, 16, 16, 16)
4094 weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
4095 self.assertFalse(weight.is_contiguous())
4096 y = torch.nn.functional.conv2d(x, weight, None)
4097 if torch.backends.mkldnn.is_available():
4098 # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
4099 with torch.backends.mkldnn.flags(enabled=False):
4100 y_ = torch.nn.functional.conv2d(x, weight, None)
4101 self.assertEqual(y, y_)
4102 self.assertEqual(y.sum(), 4186112.)
4103
4104 def test_invalid_conv2d(self):
4105 for dtype in [torch.bfloat16, torch.float, torch.double]:
4106 module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
4107 input = torch.empty(1, 1, 4, 4).to(dtype)
4108 self.assertRaises(RuntimeError, lambda: module(input))
4109
4110 module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
4111 input = torch.randn(1, 3, 1, 1)
4112 with self.assertRaisesRegex(RuntimeError,
4113 r'Calculated padded input size per channel: \(1 x 1\). ' +
4114 r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'):
4115 module(input)
4116
4117 # Negative stride check
4118 module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
4119 input = torch.randn(1, 3, 4, 4).to(dtype)
4120 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
4121 module(input)
4122
4123 # Zero stride check
4124 module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
4125 input = torch.randn(1, 3, 4, 4).to(dtype)
4126 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
4127 module(input)
4128
4129 def test_conv2d_valid_padding(self, device='mps'):
4130 # Test F.conv2d padding='valid' is the same as no padding
4131 x = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
4132 y = torch.rand(1, 1, 1, 4, device=device).to(torch.float)
4133
4134 expect = F.conv2d(x, y)
4135 actual = F.conv2d(x, y, padding='valid')
4136 self.assertEqual(expect.to('cpu'), actual.to('cpu'))
4137
4138 # def test_conv2d_same_padding(self, device='mps'):
4139 # x = torch.rand(1, 1, 10, 11, device=device)
4140 # y = torch.rand(1, 1, 4, 5, device=device)
4141 # expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
4142 # actual = F.conv2d(x, y, padding='same')
4143 # self.assertEqual(expect.to('cpu'), actual.to('cpu'))
4144
4145 # # With dilation
4146 # y = torch.rand(1, 1, 3, 4, device=device)
4147 # expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
4148 # actual = F.conv2d(x, y, padding='same', dilation=2)
4149 # self.assertEqual(expect, actual)
4150
4151 # # Dilation with asymmetric padding
4152 # y = torch.rand(1, 1, 4, 4, device=device)
4153 # expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
4154 # actual = F.conv2d(x, y, padding='same', dilation=3)
4155 # self.assertEqual(expect, actual)
4156
4157
4158class TestConstantPadNd(TestCase):
4159 def test_preserves_memory_format(self):
4160 nchw_tensor = torch.rand((1, 2, 5, 3))
4161 nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
4162 self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format))
4163
4164 nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last)
4165 nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
4166 self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))
4167
4168
4169class TestLinalgMPS(TestCase):
4170 def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
4171 dtype = t.dtype
4172 numpy_dtype = dtype
4173 alpha = 1.2 if alpha is None else alpha
4174 beta = 0.8 if beta is None else beta
4175 res1 = f(t, m, v, alpha=alpha, beta=beta)
4176 res2 = torch.full_like(res1, math.nan)
4177 if transpose_out:
4178 res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
4179 f(t, m, v, alpha=alpha, beta=beta, out=res2)
4180 res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
4181 if beta != 0:
4182 res3 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
4183 res3 = torch.from_numpy(res3).to(dtype)
Kulin Seth978304f2022-05-14 13:33:16 +00004184 self.assertEqual(res1, res2)
4185 self.assertEqual(res1, res3)
Kulin Sethe011a8e2022-05-13 18:28:53 +00004186
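# addmm computes beta * M + alpha * (m1 @ m2); the helper above checks both
# the returned tensor and the out= variant against a numpy reference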
4187 def test_addmm(self, device="mps", dtype=torch.float32):
4188 M = torch.randn(10, 25, device=device).to(dtype)
4189 m1 = torch.randn(10, 50, device=device).to(dtype)
4190 m2 = torch.randn(50, 25, device=device).to(dtype)
4191 self._test_addmm_addmv(torch.addmm, M, m1, m2)
4192
Kulin Sethe011a8e2022-05-13 18:28:53 +00004193 # Test beta=0, M=nan
4194 M = torch.full((10, 25), math.nan, device=device).to(dtype)
4195 m1 = torch.randn(10, 50, device=device).to(dtype)
4196 m2 = torch.randn(50, 25, device=device).to(dtype)
4197 self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)
4198
Kulin Seth978304f2022-05-14 13:33:16 +00004199 # Test transpose
4200 for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
4201 def maybe_transpose(cond, m):
4202 if not cond:
4203 return m
4204 return m.t().clone(memory_format=torch.contiguous_format).t()
Kulin Sethe011a8e2022-05-13 18:28:53 +00004205
Kulin Seth978304f2022-05-14 13:33:16 +00004206 M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
4207 m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
4208 m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
4209 self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
Kulin Sethe011a8e2022-05-13 18:28:53 +00004210
4211
4212class TestRNNMPS(TestCase):
4213 def test_lstm_1(self, device="mps", dtype=torch.float32):
4214
4215 rnn = nn.LSTM(1, 4, 2, device="cpu")
4216 input = torch.randn(2, 3, 1, device="cpu")
4217 hx = torch.zeros(2, 3, 4, device="cpu")
4218 cx = torch.zeros(2, 3, 4, device="cpu")
Kulin Sethe011a8e2022-05-13 18:28:53 +00004219
Alban Desmaison02551a02022-05-28 12:39:10 -04004220 cpu_output, _ = rnn(input, (hx, cx))
4221
4222 device = torch.device("mps")
4223 rnn = rnn.to(device)
4224 input = input.to(device)
4225 hx = hx.to(device)
4226 cx = cx.to(device)
4227 output, _ = rnn(input, (hx, cx))
4228 self.assertEqual(cpu_output, output)
4229
4230 @unittest.skipIf(True, "Backward of lstm returns wrong result")
Kulin Sethe011a8e2022-05-13 18:28:53 +00004231 def test_lstm_2(self, device="mps", dtype=torch.float32):
Alban Desmaison02551a02022-05-28 12:39:10 -04004232 def get_results(device):
4233 rnn = nn.LSTM(1, 4, 1, device=device)
4234 inp = torch.randn(2, 3, 1, device=device, requires_grad=True)
4235 hx = torch.zeros(1, 3, 4, device=device)
4236 cx = torch.zeros(1, 3, 4, device=device)
Kulin Sethe011a8e2022-05-13 18:28:53 +00004237
Alban Desmaison02551a02022-05-28 12:39:10 -04004238 output, _ = rnn(inp, (hx, cx))
4239 output.sum().backward()
Kulin Sethe011a8e2022-05-13 18:28:53 +00004240
Alban Desmaison02551a02022-05-28 12:39:10 -04004241 weight_grad = rnn.weight_ih_l0.grad.clone()
4242 input_grad = inp.grad.clone()
4243
4244 return output, weight_grad, input_grad
4245
4246
4247 cpu_output, cpu_weight_grad, cpu_input_grad = get_results("cpu")
4248 mps_output, mps_weight_grad, mps_input_grad = get_results("mps")
4249
4250 self.assertEqual(cpu_output, mps_output)
4251 self.assertEqual(cpu_input_grad, mps_input_grad)
4252 self.assertEqual(cpu_weight_grad, mps_weight_grad)
Kulin Sethe011a8e2022-05-13 18:28:53 +00004253
Kulin Seth3d833212022-05-20 03:18:09 +00004254class TestFallbackWarning(TestCase):
4255 def test_no_warning_on_import(self):
4256 script = """
4257import warnings
4258
4259with warnings.catch_warnings(record=True) as w:
4260 import torch
4261
4262exit(len(w))
4263"""
4264 try:
4265 subprocess.check_output(
4266 [sys.executable, '-W', 'all', '-c', script],
4267 stderr=subprocess.STDOUT,
4268 # On Windows, opening the subprocess with the default CWD makes `import torch`
4269 # fail, so just set CWD to this script's directory
4270 cwd=os.path.dirname(os.path.realpath(__file__)),)
4271 except subprocess.CalledProcessError:
4272 self.fail("There was a warning when importing torch.")
4273
4274 def _get_not_implemented_op(self):
Kulin Seth8552acb2022-05-27 17:07:02 +00004275 # This can be changed once we actually implement `torch.bincount`
Kulin Seth3d833212022-05-20 03:18:09 +00004276 # Should return fn, args, kwargs, string_version
Kulin Seth8552acb2022-05-27 17:07:02 +00004277 return (torch.bincount,
Kulin Sethd63db522022-05-28 14:41:56 +00004278 torch.tensor([4], device='mps'), {},
Kulin Seth8552acb2022-05-27 17:07:02 +00004279 "torch.bincount(torch.tensor([4, 3, 6, 3, 4], device='mps'))")
Kulin Seth3d833212022-05-20 03:18:09 +00004280
4281 def test_error_on_not_implemented(self):
4282 fn, args, kwargs, _ = self._get_not_implemented_op()
4283
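# "not current implemented" (sic) mirrors the exact wording of the error
# raised by the MPS backend, so the regex must keep the typo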
4284 with self.assertRaisesRegex(NotImplementedError, "not current implemented for the MPS device"):
4285 fn(*args, **kwargs)
4286
4287 def test_warn_on_not_implemented_with_fallback(self):
4288 _, _, _, op = self._get_not_implemented_op()
4289 script = f"""
4290import os
4291# MUST happen before pytorch's import
4292os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4293import warnings
4294
4295with warnings.catch_warnings(record=True) as w:
4296 import torch
4297
4298if len(w) > 0:
4299 exit(1)
4300
4301 # This should run just fine and raise a warning about perf
4302with warnings.catch_warnings(record=True) as w:
4303 {op}
4304
4305if len(w) != 1:
4306 exit(2)
4307
4308"""
4309 try:
4310 subprocess.check_output(
4311 [sys.executable, '-W', 'all', '-c', script],
4312 stderr=subprocess.STDOUT,
4313 # On Windows, opening the subprocess with the default CWD makes `import torch`
4314 # fail, so just set CWD to this script's directory
4315 cwd=os.path.dirname(os.path.realpath(__file__)),)
4316 except subprocess.CalledProcessError as e:
4317 if e.returncode == 1:
4318 self.fail("There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set.")
4319 elif e.returncode == 2:
4320 self.fail("There wasn't exactly one warning when running the not implemented op with "
4321 "PYTORCH_ENABLE_MPS_FALLBACK set.")
4322 else:
4323 self.fail("Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set.")
Kulin Sethe011a8e2022-05-13 18:28:53 +00004324
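# For reference, a minimal sketch of opting into the CPU fallback from user
# code (assuming the environment variable is honored as tested above):
#
#   import os
#   os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # must be set before `import torch`
#   import torch
#   torch.bincount(torch.tensor([4, 3, 6, 3, 4], device='mps'))  # warns, runs on CPU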
Alban Desmaison04ac80c2022-05-20 20:25:12 +00004325class TestNoRegression(TestCase):
4326 def test_assert_close(self):
4327 a = torch.ones(1, device="mps")
4328 b = torch.zeros(1, device="mps")
4329 inf = a / b
4330 nan = b / b
4331
4332 with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
4333 torch.testing.assert_close(a, inf)
4334
4335 with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
4336 torch.testing.assert_close(a, nan)
4337
4338 def test_double_error(self):
4339 with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
4340 a = torch.ones(2, dtype=torch.float64, device="mps")
4341
4342 a = torch.ones(2, device="mps")
4343 with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
4344 a = a.double()
4345
4346 def test_legacy_constructor(self):
4347 a = torch.ones(2, device="mps")
4348
4349 b = a.new(1)
4350
4351
4352
Kulin Sethe011a8e2022-05-13 18:28:53 +00004353if __name__ == "__main__":
4354 run_tests()