[MPS] Handle 1D inputs for NLL (#81290)
* Add test for NLL 1d
* Fix forward NLL for 1D case
* Handle NLL backward for 1D case
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81290
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm
index ab41f87..63fbfe3 100644
--- a/aten/src/ATen/native/mps/operations/LossOps.mm
+++ b/aten/src/ATen/native/mps/operations/LossOps.mm
@@ -314,10 +314,10 @@
// NLLLoss
void nllnd_loss_backward_impl(
-Tensor& grad_input,
+Tensor& grad_input_arg,
const Tensor& grad_output,
-const Tensor& input,
-const Tensor& target,
+const Tensor& input_arg,
+const Tensor& target_arg,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index,
@@ -325,7 +325,7 @@
bool is2D)
{
// Empty output
- if(grad_input.numel() == 0)
+ if(grad_input_arg.numel() == 0)
return;
MPSStream* stream = getCurrentMPSStream();
@@ -342,6 +342,10 @@
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+ auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
+ auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
+ auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg;
+
@autoreleasepool {
auto numClasses = grad_input.sizes()[1];
@@ -472,24 +476,24 @@
void nllnd_loss_forward_impl
(Tensor& output,
Tensor& total_weight,
- const Tensor& input,
- const Tensor& target,
+ const Tensor& input_arg,
+ const Tensor& target_arg,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index,
bool is2D)
{
- std::vector<long long> reshapedTarget(target.sizes().begin(), target.sizes().end());
+ std::vector<long long> reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end());
reshapedTarget.push_back(1);
- Tensor batchSizeTensor = at::empty_like(input).resize_(IntArrayRef(1));
+ Tensor batchSizeTensor = at::empty_like(input_arg).resize_(IntArrayRef(1));
float batchVal = 1.0f;
for(size_t i = 0; i < reshapedTarget.size(); ++i)
batchVal *= reshapedTarget[i];
batchSizeTensor[0] = batchVal;
if(reduction == Reduction::None)
- output.resize_(target.sizes());
+ output.resize_(target_arg.sizes());
if(reduction == Reduction::Sum)
output.resize_({});
if(reduction == Reduction::Mean)
@@ -516,6 +520,9 @@
MPSStream* stream = getCurrentMPSStream();
+ auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
+ auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
+
@autoreleasepool {
bool isWeightsArrayValid = (weight.numel() > 0);
diff --git a/test/test_mps.py b/test/test_mps.py
index 9d94ad1..fd735b1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1701,6 +1701,26 @@
output_mps.sum().backward()
self.assertEqual(input.grad, input_mps.grad.to('cpu'))
+ def _nll_loss_1d_helper(self, input_size, reduction):
+
+ # CPU
+ input = torch.rand(input_size, requires_grad=True, device='cpu')
+ num_channels = input_size[0]
+ target = torch.randint(num_channels, [], device='cpu')
+
+ # MPS
+ input_mps = input.detach().clone().to('mps').requires_grad_()
+ target_mps = target.detach().clone().to('mps')
+
+ output_cpu = F.nll_loss(input, target, reduction=reduction)
+ output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
+ # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
+ self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu'))
+
+ output_cpu.sum().backward()
+ output_mps.sum().backward()
+ self.assertEqual(input.grad, input_mps.grad.to('cpu'))
+
def test_as_strided(self):
values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
values_1 = [[1.0, 1.0], [1.0, 1.0]]
@@ -1743,6 +1763,11 @@
helper(3, 3)
+ def test_nll_loss_1d(self, device='cpu'):
+ self._nll_loss_1d_helper([10], "none")
+ self._nll_loss_1d_helper([10], "mean")
+ self._nll_loss_1d_helper([10], "sum")
+
def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))