Fix placeholder tensor is empty for relu in mps (#118965)
Fixes #118845
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118965
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index 99f044c..c5e6102 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -49,6 +49,10 @@
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
+ if (self.numel() == 0) {
+ return self;
+ }
+
MPSStream* stream = getCurrentMPSStream();
bool executeGatherOp =
@@ -81,6 +85,10 @@
Tensor& relu_mps_(Tensor& self) {
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
+
+ if (self.numel() == 0) {
+ return self;
+ }
// Inplace relu
Tensor& output = self;
bool executeGatherOp =
diff --git a/test/test_mps.py b/test/test_mps.py
index 186518b..6a8807b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1380,6 +1380,8 @@
self._testReluInPlace(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
device="mps")
+ self._testRelu(np.array([]).astype(t), device="mps")
+ self._testReluInPlace(np.array([]).astype(t), device="mps")
class MatmulTest(TestCaseMPS):
def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):