[MPS] Fused Adam & AdamW (#127242)
Summary:
This PR adds fused Adam and AdamW implementations for the MPS backend.
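For context, a minimal sketch of how a user would opt into the fused path through the public optimizer API; the model and tensor shapes below are placeholders:
```python
import torch

# Placeholder model; any nn.Module whose parameters live on the MPS device works.
model = torch.nn.Linear(128, 128).to("mps")

# fused=True requests the fused implementation added in this PR.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

x = torch.randn(32, 128, device="mps")
loss = model(x).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```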
Benchmark on a MacBook Pro with an M1 Max chip and 64 GB of unified memory:
**Fast math enabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
| Fused: True | Fused: False
1 threads: -----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 89
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 90
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 83
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 12 | 94
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 88
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 12 | 90
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 100
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 27 | 100
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 23 | 100
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 27 | 100
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 23 | 98
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 480
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 72 | 450
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 82 | 450
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 73 | 420
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 91 | 500
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 83 | 400
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 94 | 500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 78 | 400
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 170 | 500
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 140 | 600
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 170 | 600
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 140 | 500
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 250 | 890
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 220 | 850
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 250 | 830
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 220 | 770
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 270 | 870
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 230 | 840
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 270 | 810
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 240 | 800
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 400 | 1000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 360 | 2000
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 430 | 2000
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 360 | 1300
Times are in milliseconds (ms).
```
**Fast math disabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
| Fused: True | Fused: False
1 threads: -----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 84
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 84
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 79
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 93
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 10 | 90
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 91
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 81
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 34 | 100
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 31 | 100
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 34 | 95
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 31 | 100
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 94 | 500
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 430
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 92 | 430
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 81 | 390
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 98 | 500
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 88 | 430
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 100 | 500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 88 | 400
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 210 | 500
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 190 | 610
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 210 | 510
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 190 | 500
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 300 | 900
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 260 | 850
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 295 | 900
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 260 | 800
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 320 | 910
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 280 | 900
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 320 | 900
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 300 | 900
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 500 | 2000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 480 | 2000
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 540 | 1500
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 480 | 1200
Times are in milliseconds (ms).
```
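The benchmark numbers above were collected with the following script: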
```python
import torch


def profile_fused_adam():
    from torch.optim import adam, adamw
    import torch.utils.benchmark as benchmark
    import itertools

    def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
        # Run a single functional optimizer step and wait for the MPS queue to drain.
        fn(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            foreach=False,
            capturable=False,
            fused=fused,
            amsgrad=amsgrad,
            beta1=0.9,
            beta2=0.99,
            lr=1e-3,
            weight_decay=0.0,
            eps=1e-5,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        torch.mps.synchronize()

    device = "mps"
    results = []
    # Sweep tensor counts, tensor sizes, Adam vs. AdamW, and amsgrad on/off.
    for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
        print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
        params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
        max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
        state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
        if adamWflag:
            fn = adamw.adamw
        else:
            fn = adam.adam

        for fused in [True, False]:
            t = benchmark.Timer(
                stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                label='Fused Adam',
                sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                globals=locals(),
                description=f"Fused: {fused}",
            ).blocked_autorange(min_run_time=5)
            results.append(t)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()


if __name__ == "__main__":
    profile_fused_adam()
```
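The script drives the functional `adam.adam` / `adamw.adamw` entry points directly with preallocated parameter, gradient, and optimizer-state tensors, so it measures only the optimizer step itself rather than autograd or state initialization.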
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
diff --git a/test/test_mps.py b/test/test_mps.py
index 311cf82..a97b8fb 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -76,7 +76,6 @@
XFAILLIST_GRAD = {
# precision issues
- 'digamma': [torch.float32],
'special.polygammaspecial_polygamma_n_0': [torch.float16],
'polygammapolygamma_n_0': [torch.float16],
'nn.functional.binary_cross_entropy': [torch.float16],
@@ -95,7 +94,6 @@
'masked.scatter': [torch.float16, torch.float32],
'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`.
'aminmax': [torch.float32, torch.float16],
- 'polar': [torch.float32],
# Correctness issues
'atanh': [torch.float32],
@@ -569,7 +567,6 @@
'special.ndtr': [torch.uint8],
'sqrt': [torch.uint8],
'sub': [torch.uint8],
- 'tanh': [torch.uint8],
'trapezoid': [torch.uint8],
'trapz': [torch.uint8],
'true_divide': [torch.uint8],
@@ -586,28 +583,13 @@
'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
# cpu not giving nan for x/0.0
- 'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+ 'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
# inconsistency errors between cpu and mps, max seen atol is 2
'nn.functional.interpolatebilinear': [torch.uint8],
}
MACOS_BEFORE_13_3_XFAILLIST = {
- # Failure due to precision issues (still present on 13.3+) as well as non-standard behavior of
- # cpu ops for the negative integers.
- # Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)):
- # - CPU output: tensor([102.668, 1.129e+15])
- # - MPS output: tensor([102.6681, inf])
- # In the latter case, inf is probably correct (this is what scipy does).
- 'polygamma': [torch.float32, torch.uint8],
- 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
- 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8],
- 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8],
- 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8],
- 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8],
- 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
- 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8],
-
# Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
'tan': [torch.float32],
'cdist': [torch.float32],
@@ -656,20 +638,6 @@
# Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
# The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
-
- # Failure due to precision issues as well as non-standard behavior of cpu ops for the
- # negative integers. Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)):
- # - CPU output: tensor([102.668, 1.129e+15])
- # - MPS output: tensor([102.6681, inf])
- # In the latter case, inf is probably correct (this is what scipy does).
- 'polygamma': [torch.float32, torch.uint8],
- 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
- 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8],
- 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8],
- 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8],
- 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8],
- 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
- 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8],
}
MACOS_BEFORE_14_4_XFAILLIST = {