[MPS] Handle 1D inputs for NLL (#81290)

* Add a test for NLL loss with 1D input
* Fix NLL forward for the 1D case
* Handle NLL backward for the 1D case

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81290
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm
index ab41f87..63fbfe3 100644
--- a/aten/src/ATen/native/mps/operations/LossOps.mm
+++ b/aten/src/ATen/native/mps/operations/LossOps.mm
@@ -314,10 +314,10 @@
 
 // NLLLoss
 void nllnd_loss_backward_impl(
-Tensor& grad_input,
+Tensor& grad_input_arg,
 const Tensor& grad_output,
-const Tensor& input,
-const Tensor& target,
+const Tensor& input_arg,
+const Tensor& target_arg,
 const Tensor& weight,
 int64_t reduction,
 int64_t ignore_index,
@@ -325,7 +325,7 @@
 bool is2D)
 {
     // Empty output
-    if(grad_input.numel() == 0)
+    if(grad_input_arg.numel() == 0)
         return;
 
     MPSStream* stream = getCurrentMPSStream();
@@ -342,6 +342,10 @@
 
     MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
+    auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
+    auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
+    auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg;
+
     @autoreleasepool {
 
         auto numClasses = grad_input.sizes()[1];
@@ -472,24 +476,24 @@
 void nllnd_loss_forward_impl
 (Tensor& output,
  Tensor& total_weight,
- const Tensor& input,
- const Tensor& target,
+ const Tensor& input_arg,
+ const Tensor& target_arg,
  const Tensor& weight,
  int64_t reduction,
  int64_t ignore_index,
  bool is2D)
 {
-    std::vector<long long> reshapedTarget(target.sizes().begin(), target.sizes().end());
+    std::vector<long long> reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end());
     reshapedTarget.push_back(1);
 
-    Tensor batchSizeTensor = at::empty_like(input).resize_(IntArrayRef(1));
+    Tensor batchSizeTensor = at::empty_like(input_arg).resize_(IntArrayRef(1));
     float batchVal = 1.0f;
     for(size_t i = 0; i < reshapedTarget.size(); ++i)
         batchVal *= reshapedTarget[i];
     batchSizeTensor[0] = batchVal;
 
     if(reduction == Reduction::None)
-        output.resize_(target.sizes());
+        output.resize_(target_arg.sizes());
     if(reduction == Reduction::Sum)
         output.resize_({});
     if(reduction == Reduction::Mean)
@@ -516,6 +520,9 @@
 
     MPSStream* stream = getCurrentMPSStream();
 
+    auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
+    auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
+
     @autoreleasepool {
 
         bool isWeightsArrayValid = (weight.numel() > 0);
diff --git a/test/test_mps.py b/test/test_mps.py
index 9d94ad1..fd735b1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1701,6 +1701,26 @@
         output_mps.sum().backward()
         self.assertEqual(input.grad, input_mps.grad.to('cpu'))
 
+    def _nll_loss_1d_helper(self, input_size, reduction):
+        # Compare F.nll_loss forward and backward between CPU and MPS for a 1D input with a 0-dim target.
+        # CPU reference: random input of shape input_size; target is a 0-dim class index in [0, num_channels)
+        input = torch.rand(input_size, requires_grad=True, device='cpu')
+        num_channels = input_size[0]
+        target = torch.randint(num_channels, [], device='cpu')
+
+        # MPS copies of the same data; the input is detached and cloned so it accumulates its own grad
+        input_mps = input.detach().clone().to('mps').requires_grad_()
+        target_mps = target.detach().clone().to('mps')
+
+        output_cpu = F.nll_loss(input, target, reduction=reduction)
+        output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
+        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
+        self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu'))
+
+        output_cpu.sum().backward()  # backward through both paths, then compare input gradients
+        output_mps.sum().backward()
+        self.assertEqual(input.grad, input_mps.grad.to('cpu'))
+
     def test_as_strided(self):
         values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
         values_1 = [[1.0, 1.0], [1.0, 1.0]]
@@ -1743,6 +1763,11 @@
 
         helper(3, 3)
 
+    def test_nll_loss_1d(self, device='cpu'):  # 1D input of size 10, checked for each reduction mode
+        self._nll_loss_1d_helper([10], "none")
+        self._nll_loss_1d_helper([10], "mean")
+        self._nll_loss_1d_helper([10], "sum")
+
     def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
         self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
         self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))