[MPS] Add floor_divide() op and its test case (#91126)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91126
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index a246bb0..2c50a5e 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -181,7 +181,7 @@
assert(0 && "Invalid rounding mode\n");
return nullptr;
};
- binaryOpTensor(self, other, Scalar(1.0), output, op_name + "_out_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), div_mode_op_block);
+ binaryOpTensor(self, other, Scalar(1.0), output, op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), div_mode_op_block);
}
void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output, std::string op_name)
@@ -287,11 +287,11 @@
TORCH_IMPL_FUNC(div_out_mode_mps) (const Tensor& self, const Tensor& other, c10::optional<c10::string_view> rounding_mode, const Tensor& output) {
- mps::div_mode_template(self, other, rounding_mode, output, "div_mode");
+ mps::div_mode_template(self, other, rounding_mode, output, "div_mode_out");
}
TORCH_IMPL_FUNC(div_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
- mps::div_mode_template(self, other, c10::nullopt, output, "div");
+ mps::div_mode_template(self, other, c10::nullopt, output, "div_out");
}
TORCH_IMPL_FUNC(add_out_mps) (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
@@ -302,6 +302,27 @@
 mps::add_sub_template(self, other, alpha, output, "sub");
 }
+// out= variant: floor(self / other) written into `result`.
+Tensor& floor_divide_out_mps(const Tensor& self, const Tensor& other, Tensor& result) {
+  mps::div_mode_template(self, other, "floor", result, "floor_divide_out");
+  return result;
+}
+
+// Functional variant: allocates the output itself. Size it with the broadcast
+// shape and the promoted dtype so mixed-shape / mixed-dtype operands match
+// CPU/CUDA behavior — empty_like(self) would mis-size the result whenever
+// `other` broadcasts over `self`, and would skip type promotion.
+Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {
+  Tensor output = at::empty(at::infer_size(self.sizes(), other.sizes()),
+                            self.options().dtype(at::result_type(self, other)));
+  mps::div_mode_template(self, other, "floor", output, "floor_divide");
+  return output;
+}
+
+// In-place variant: reuses the out-kernel with `self` as the destination.
+Tensor& floor_divide_mps_(Tensor& self, const Tensor& other) {
+  return floor_divide_out_mps(self, other, self);
+}
TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index fc6eb0b..abf9417 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2571,6 +2571,7 @@
variants: function, method
dispatch:
CPU, CUDA: floor_divide
+ MPS: floor_divide_mps
SparseCPU, SparseCUDA: floor_divide_sparse
- func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
@@ -2578,12 +2579,14 @@
variants: method
dispatch:
CPU, CUDA: floor_divide_
+ MPS: floor_divide_mps_
SparseCPU, SparseCUDA: floor_divide_sparse_
- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: floor_divide_out
+ MPS: floor_divide_out_mps
SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim
- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 200889f..748b7f9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3401,13 +3401,19 @@
# clamp to avoid division by 0
mps_y = cpu_y.detach().clone().to('mps')
- result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
- result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
- self.assertEqual(result_div_mps, result_div_cpu)
+ if (rounding_mode == "floor_divide"):
+ result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
+ result_div_mps = torch.floor_divide(mps_x, mps_y)
+ self.assertEqual(result_div_mps, result_div_cpu)
+ else:
+ result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
+ result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
+ self.assertEqual(result_div_mps, result_div_cpu)
helper((2, 8, 4, 5), None)
helper((2, 8, 4, 5), "floor")
helper((2, 8, 4, 5), "trunc")
+ helper((2, 8, 4, 5), "floor_divide")
def test_rounding(self):
def helper(shape):
@@ -7450,6 +7456,7 @@
'flipud': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'float': ['f32'],
'floor': ['f32', 'f16', 'i16', 'i32', 'i64'],
+ 'floor_divide': ['f32', 'f16'],
'frac': ['f16', 'f32'],
'gradient': ['f16', 'f32', 'i16'],
'half': ['f16'],