[MPS] Fix channels last copies in ELU, ReLU and Hardswish (#94664)
Fixes the following test_modules.py tests:
```
test_memory_format_nn_Hardswish_mps_float32
test_non_contiguous_tensors_nn_Hardswish_mps_float32
test_memory_format_nn_ReLU_mps_float32
```
Fixes ELU when run with the `ChannelsLast` memory format.
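For context, a minimal repro sketch of the failure mode (not part of this PR; assumes an MPS-capable build of PyTorch where `torch.backends.mps.is_available()` is true):
```python
import torch
import torch.nn as nn

# Run activations on a channels_last tensor on MPS and compare with CPU.
cpu_x = torch.randn(2, 8, 4, 5).to(memory_format=torch.channels_last)
mps_x = cpu_x.clone().to('mps')

for act in (nn.ELU(), nn.ReLU(), nn.Hardswish()):
    cpu_y = act(cpu_x)
    mps_y = act(mps_x)
    # Before this fix, copying the channels_last result back from MPS could
    # produce wrong values/layout; after it, the results should match.
    torch.testing.assert_close(mps_y.cpu(), cpu_y)
```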
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94664
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index a5dae09..9e643eb 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -18,14 +18,15 @@
Tensor relu_mps(const Tensor& self) {
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
- Tensor output = at::empty_like(self);
- resize_tensor(&output);
- TORCH_CHECK(output.is_mps());
-
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
MPSStream* stream = getCurrentMPSStream();
+ bool executeGatherOp = !(self.is_contiguous(MemoryFormat::Contiguous) ||
+ self.is_contiguous(MemoryFormat::ChannelsLast) ||
+ self.is_contiguous(MemoryFormat::ChannelsLast3d));
+ Tensor output = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
@autoreleasepool {
string key = "relu" + getTensorsStringKey({self});
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
@@ -51,8 +52,8 @@
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
- Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
- Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+ Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nil, false);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -75,7 +76,13 @@
using CachedGraph = MPSUnaryCachedGraph;
// Inplace relu
Tensor &output = self;
- TORCH_CHECK(output.is_mps());
+ bool executeGatherOp = !(self.is_contiguous(MemoryFormat::Contiguous) ||
+ self.is_contiguous(MemoryFormat::ChannelsLast) ||
+ self.is_contiguous(MemoryFormat::ChannelsLast3d));
+ Tensor out;
+ if (executeGatherOp) {
+ out = at::empty_like(self, MemoryFormat::Contiguous);
+ }
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@@ -106,8 +113,8 @@
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
- Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
- Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+ Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, executeGatherOp ? out : output, nil, false);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -119,7 +126,9 @@
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-
+ if (executeGatherOp) {
+ output.copy_(out);
+ }
}
return output;
@@ -1052,11 +1061,17 @@
string func_name) {
using namespace mps;
- TORCH_CHECK(self.is_mps());
+ auto resultMemFormat = result.suggest_memory_format();
+ bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && result.is_contiguous(resultMemFormat));
+ Tensor out;
+ if (executeGatherOp && resultMemFormat == MemoryFormat::ChannelsLast) {
+ out = at::empty_like(result, MemoryFormat::Contiguous);
+ }
// Empty output
- if(result.numel() == 0)
+ if(result.numel() == 0) {
return;
+ }
struct CachedGraph : public MPSCachedGraph
{
@@ -1137,8 +1152,8 @@
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
- Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
- Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+ Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : result, nil, false);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -1150,8 +1165,10 @@
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ if (out.has_storage()) {
+ result.copy_(out);
+ }
}
-
}
// scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) ))
@@ -1174,13 +1191,18 @@
const Tensor& self_or_result,
const Tensor& grad_input
) {
-
using namespace mps;
- TORCH_CHECK(grad_output.is_mps());
+ auto gradMemFormat = grad_input.suggest_memory_format();
+ bool executeGatherOp = !(grad_output.is_contiguous(gradMemFormat) && self_or_result.is_contiguous(gradMemFormat) && grad_input.is_contiguous(gradMemFormat));
+ Tensor out;
+ if (executeGatherOp && gradMemFormat == MemoryFormat::ChannelsLast) {
+ out = at::empty_like(grad_input, MemoryFormat::Contiguous);
+ }
// Empty output
- if(grad_input.numel() == 0)
+ if(grad_input.numel() == 0) {
return;
+ }
struct CachedGraph : public MPSCachedGraph
{
@@ -1281,14 +1303,14 @@
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
- Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
+ Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp);
Placeholder selfPlaceholder = Placeholder();
Placeholder resultPlaceholder = Placeholder();
if(is_result)
- resultPlaceholder = Placeholder(cachedGraph->resultTensor_, self_or_result);
+ resultPlaceholder = Placeholder(cachedGraph->resultTensor_, self_or_result, nil, executeGatherOp);
else
- selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result);
- Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
+ selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result, nil, executeGatherOp);
+ Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
@@ -1309,8 +1331,10 @@
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ if (out.has_storage()) {
+ grad_input.copy_(out);
+ }
}
-
}
TORCH_IMPL_FUNC(glu_out_mps) (
@@ -1390,7 +1414,6 @@
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
-
}
Tensor& glu_backward_mps_out (
@@ -2210,12 +2233,17 @@
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
- TORCH_CHECK(self.is_mps());
-
if (output.numel() == 0) {
return output;
}
+ auto resultMemFormat = output.suggest_memory_format();
+ bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && output.is_contiguous(resultMemFormat));
+ Tensor out;
+ if (executeGatherOp && !output.is_contiguous(MemoryFormat::Contiguous)) {
+ out = at::empty_like(output, MemoryFormat::Contiguous);
+ }
+
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
MPSStream* stream = at::mps::getCurrentMPSStream();
@@ -2296,9 +2324,9 @@
});
cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
}
- Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
+ Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
- Placeholder(cachedGraph->outputTensor_, output);
+ Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : output, nil, false);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -2312,6 +2340,9 @@
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ if (out.has_storage()) {
+ output.copy_(out);
+ }
}
return output;
}
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 1e47b57..eade505 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -103,18 +103,20 @@
static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
{
+ auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
+
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* stream = getCurrentMPSStream();
Tensor dst;
Tensor src;
- if (!dst_.is_contiguous()) {
+ if (!dst_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
dst = at::empty_like(dst_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
} else {
dst = dst_;
}
auto storage_byte_offset = src_.storage_offset() * src_.itemsize();
- if (!src_.is_contiguous()) {
+ if (!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
Tensor emptyShell = Tensor();
src = gatherViewTensor(src_, emptyShell);
if (src.has_storage()) {
@@ -250,8 +252,9 @@
// gather into dst. This reduces the overhead of doing an additional blit for most cases
bool returnGatherOutput = (dst_.is_contiguous() && !dst_byte_offset && src_.dtype() == dst_.dtype());
Tensor src;
+ auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
- if (src_.is_view() || !src_.is_contiguous()) {
+ if (!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
Tensor emptyShell = Tensor();
src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell);
@@ -273,7 +276,7 @@
// Scatter to `dst` if the memory is not contiguous
// If the memory is not contiguous, it means that the tensor has strides and we would not be
// able to do the copy using a single blit
- if (!dst_.is_contiguous()) {
+ if (!dst_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
return scatterViewTensor(src, dst_);
}
src._set_conj(src_.is_conj());
diff --git a/test/test_mps.py b/test/test_mps.py
index bd3f5c1..8b282a9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4739,10 +4739,11 @@
# Test selu, elu, celu
def test_elu(self):
- def helper(shape, alpha=1.0):
- cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
- x = cpu_x.detach().clone().to('mps').requires_grad_()
+ def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
+ cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+ cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()
+ x = cpu_x.detach().clone().to('mps').requires_grad_(True)
for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]:
elu_result = activation_func(x)
elu_result_cpu = activation_func(cpu_x)
@@ -4757,9 +4758,10 @@
self.assertEqual(x.grad, cpu_x.grad)
# Test empty shape too
- for shape in [[], (2, 3), (2, 8, 4, 5)]:
- for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
- helper(shape, alpha)
+    for memory_format in [torch.channels_last, torch.contiguous_format]:
+        for shape in [(2, 8, 4, 5)]:
+            for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
+                helper(shape, alpha, memory_format)
# Test glu
def test_glu(self):