[MPS]: Add fix for squeezed input axes handling in BCE loss (#79676)

Fixes #79527
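
For context, a minimal sketch of the case this change covers, assuming a machine where the MPS backend is available. It mirrors the new `helper([1, 1, 32, 32], 'mean')` test below: BCE loss on an input whose leading axes are size 1 (squeezable) should match the CPU result in both the forward value and the gradient. The tolerance and the CPU comparison here are illustrative, not part of the PR's test code.

```python
import torch
import torch.nn as nn

# Input/target with singleton leading axes, the shape exercised by the new test.
shape = [1, 1, 32, 32]
x_cpu = torch.rand(shape, requires_grad=True)
y_cpu = torch.rand(shape)
loss_cpu = nn.BCELoss(reduction='mean')(x_cpu, y_cpu)

# Same computation on MPS.
x_mps = x_cpu.detach().clone().to('mps').requires_grad_()
y_mps = y_cpu.to('mps')
loss_mps = nn.BCELoss(reduction='mean')(x_mps, y_mps)

# With the fix, forward values and gradients agree with CPU for such shapes.
loss_cpu.backward()
loss_mps.backward()
print(torch.allclose(loss_cpu, loss_mps.cpu(), atol=1e-5))
print(torch.allclose(x_cpu.grad, x_mps.grad.cpu(), atol=1e-5))
```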

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79676
Approved by: https://github.com/razarmehr, https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index 45550b0..679564d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1698,7 +1698,7 @@
         helper([8, 4, 5, 7, 6], 'mean')
 
     # Binary Cross Entropy
-    def test_bce_loss(self):
+    def test_bce_loss_simple(self):
         def helper(shape, reduction):
             # create the criterion
             loss = torch.nn.BCELoss(reduction=reduction)
@@ -1728,6 +1728,146 @@
         # verify if changes in shape would cause cached graph lookup problems
         helper([7, 5, 2, 4, 6], 'sum')
         helper([8, 4, 5, 7, 6], 'mean')
+        helper([1, 1, 32, 32], 'mean')
+
+    def test_bce_loss_always_nonnegative(self):
+        target = torch.ones(5, device='mps')
+        input = torch.ones(5, device='mps')
+        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
+
+        target = torch.zeros(5, device='mps')
+        input = torch.zeros(5, device='mps')
+        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
+
+    def test_bce_loss_size_mismatch(self):
+        bceloss = nn.BCELoss()
+        a = torch.rand(25, device='mps')
+        b = torch.rand(25, 1, device='mps')
+        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
+            bceloss(a, b)
+
+    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
+        x_size = 1024
+        y_size = 256
+        target = torch.rand(x_size, y_size, device='mps')
+
+        for reduction in ['none', 'mean', 'sum']:
+            output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
+            output_logits = output_sig.clone().detach()
+
+            output_sig.requires_grad = True
+            output_logits.requires_grad = True
+            weight = torch.rand(y_size, device='mps')
+
+            loss_sig = nn.BCELoss(weight, reduction=reduction)(
+                torch.sigmoid(output_sig), target
+            )
+            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
+                output_logits, target
+            )
+
+            self.assertEqual(loss_logits, loss_sig)
+
+            if reduction == 'none':
+                grad = torch.rand(x_size, y_size, device='mps')
+                loss_sig.backward(grad)
+                loss_logits.backward(grad)
+            else:
+                loss_sig.backward()
+                loss_logits.backward()
+
+            self.assertEqual(output_sig.grad, output_logits.grad)
+
+    def test_bce_with_logits_has_correct_grad_at_zero(self):
+        output = torch.zeros(3, 1, requires_grad=True, device='mps')
+        target = torch.zeros(3, 1, device='mps')
+        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
+        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
+        self.assertEqual(output.grad, expected_grad)
+
+    def test_bce_with_logits_broadcasts_weights(self):
+        target = torch.rand(16, 4, device='mps')
+        output = torch.rand(16, 4, device='mps') - 0.5
+
+        weight = torch.rand(4, device='mps')
+        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        self.assertEqual(out1, out2)
+
+        weight = torch.rand(16, 1, device='mps')
+        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        self.assertEqual(out1, out2)
+
+    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
+        target = torch.rand(64, 4, device='mps')
+        output = torch.rand(64, 4, device='mps') - 0.5
+        pos_weight = torch.ones(64, 4, device='mps')
+
+        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
+                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
+
+    def test_bce_with_logits_broadcasts_pos_weights(self):
+        target = torch.rand(64, 4, device='mps')
+        output = torch.rand(64, 4, device='mps') - 0.5
+        pos_weight = torch.rand(4, device='mps')
+        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
+
+        pos_weight1 = pos_weight.expand(1, 4)
+        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)
+
+        pos_weight2 = pos_weight.expand(64, 4)
+        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)
+
+        self.assertEqual(out1, out2)
+        self.assertEqual(out1, out3)
+
+    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
+        output = torch.zeros(3, 1, requires_grad=True, device='mps')
+        target = torch.zeros(3, 1, device='mps')
+        pos_weight = torch.ones(3, 1, device='mps')
+        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
+        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
+        grad = output.grad
+        self.assertEqual(grad, expected_grad)
+
+    def test_bce_with_logits_stability(self):
+        output = torch.tensor([0., -120.], device='mps')
+        target = torch.tensor([0., 1.], device='mps')
+        pos_weight = torch.tensor([1., 1.], device='mps')
+
+        out1 = nn.BCEWithLogitsLoss()(output, target)
+        self.assertTrue(torch.isfinite(out1).all().item())
+
+        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
+        self.assertTrue(torch.isfinite(out2).all().item())
+
+    def test_bce_loss_broadcasts_weights(self):
+        sigmoid = nn.Sigmoid()
+        target = torch.rand(16, 4, device='mps')
+        output = torch.rand(16, 4, device='mps') - 0.5
+
+        weight = torch.rand(4, device='mps')
+        out1 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        self.assertEqual(out1, out2)
+
+        weight = torch.rand(16, 1, device='mps')
+        out1 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        self.assertEqual(out1, out2)
 
     def test_log_softmax(self):
         values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]