[MPS]: Add fix for squeezed input axes handling in BCE loss (#79676)
Fixes #79527
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79676
Approved by: https://github.com/razarmehr, https://github.com/albanD
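For context, the snippet below is a minimal sketch of the case this PR targets (a reproduction of #79527, not code from the PR itself): on the MPS backend, `BCELoss` mishandled inputs whose shape contains size-1 ("squeezed") axes, such as `[1, 1, 32, 32]`. It assumes a macOS build of PyTorch >= 1.12 with MPS support and falls back to CPU otherwise.

```python
import torch
import torch.nn as nn

# Assumption: PyTorch >= 1.12 built with MPS; fall back to CPU where MPS is absent.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# A shape with size-1 ("squeezed") leading axes, as in issue #79527.
x = torch.rand(1, 1, 32, 32, device=device)       # valid probabilities in [0, 1)
target = torch.rand(1, 1, 32, 32, device=device)

loss = nn.BCELoss(reduction="mean")(x, target)

# After this fix, the MPS result should match the CPU reference.
ref = nn.BCELoss(reduction="mean")(x.cpu(), target.cpu())
print(loss.item(), ref.item())
```

The new `helper([1, 1, 32, 32], 'mean')` case added to `test_bce_loss_simple` below exercises exactly this shape.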
diff --git a/test/test_mps.py b/test/test_mps.py
index 45550b0..679564d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1698,7 +1698,7 @@
         helper([8, 4, 5, 7, 6], 'mean')
 
     # Binary Cross Entropy
-    def test_bce_loss(self):
+    def test_bce_loss_simple(self):
         def helper(shape, reduction):
             # create the criterion
             loss = torch.nn.BCELoss(reduction=reduction)
@@ -1728,6 +1728,156 @@
         # verify if changes in shape would cause cached graph lookup problems
         helper([7, 5, 2, 4, 6], 'sum')
         helper([8, 4, 5, 7, 6], 'mean')
+        helper([1, 1, 32, 32], 'mean')
+
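+    # BCE loss is non-negative by construction; exact predictions (all ones or
+    # all zeros matching the target) must not produce negative entries.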
+    def test_bce_loss_always_nonnegative(self):
+        target = torch.ones(5, device='mps')
+        input = torch.ones(5, device='mps')
+        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
+
+        target = torch.zeros(5, device='mps')
+        input = torch.zeros(5, device='mps')
+        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
+
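+    # Mismatched input/target shapes should raise rather than silently broadcast.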
+    def test_bce_loss_size_mismatch(self):
+        bceloss = nn.BCELoss()
+        a = torch.rand(25, device='mps')
+        b = torch.rand(25, 1, device='mps')
+        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
+            bceloss(a, b)
+
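+    # BCEWithLogitsLoss must agree with sigmoid followed by BCELoss, for both
+    # the loss value and the gradients, across all three reduction modes.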
+    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
+        x_size = 1024
+        y_size = 256
+        target = torch.rand(x_size, y_size, device='mps')
+
+        for reduction in ['none', 'mean', 'sum']:
+            output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
+            output_logits = output_sig.clone().detach()
+
+            output_sig.requires_grad = True
+            output_logits.requires_grad = True
+            weight = torch.rand(y_size, device='mps')
+
+            loss_sig = nn.BCELoss(weight, reduction=reduction)(
+                torch.sigmoid(output_sig), target
+            )
+            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
+                output_logits, target
+            )
+
+            self.assertEqual(loss_logits, loss_sig)
+
+            if reduction == 'none':
+                grad = torch.rand(x_size, y_size, device='mps')
+                loss_sig.backward(grad)
+                loss_logits.backward(grad)
+            else:
+                loss_sig.backward()
+                loss_logits.backward()
+
+            self.assertEqual(output_sig.grad, output_logits.grad)
+
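+    # At a logit of 0, sigmoid(0) = 0.5 and the target is 0, so with reduction
+    # 'sum' each element's gradient is sigmoid(x) - target = 0.5.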
+    def test_bce_with_logits_has_correct_grad_at_zero(self):
+        output = torch.zeros(3, 1, requires_grad=True, device='mps')
+        target = torch.zeros(3, 1, device='mps')
+        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
+        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
+        self.assertEqual(output.grad, expected_grad)
+
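+    # A 1-D (per-column) or (16, 1) (per-row) weight should give the same loss
+    # as its materialized (16, 4) expansion.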
+    def test_bce_with_logits_broadcasts_weights(self):
+        target = torch.rand(16, 4, device='mps')
+        output = torch.rand(16, 4, device='mps') - 0.5
+
+        weight = torch.rand(4, device='mps')
+        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        self.assertEqual(out1, out2)
+
+        weight = torch.rand(16, 1, device='mps')
+        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        self.assertEqual(out1, out2)
+
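+    # A pos_weight of all ones must be equivalent to passing no pos_weight.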
+    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
+        target = torch.rand(64, 4, device='mps')
+        output = torch.rand(64, 4, device='mps') - 0.5
+        pos_weight = torch.ones(64, 4, device='mps')
+
+        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
+                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
+
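+    # pos_weight should broadcast across leading dimensions: shapes (4,),
+    # (1, 4), and (64, 4) must all produce the same loss.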
+    def test_bce_with_logits_broadcasts_pos_weights(self):
+        target = torch.rand(64, 4, device='mps')
+        output = torch.rand(64, 4, device='mps') - 0.5
+        pos_weight = torch.rand(4, device='mps')
+        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
+
+        pos_weight1 = pos_weight.expand(1, 4)
+        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)
+
+        pos_weight2 = pos_weight.expand(64, 4)
+        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)
+
+        self.assertEqual(out1, out2)
+        self.assertEqual(out1, out3)
+
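+    # With a unit pos_weight the gradient at a zero logit is still
+    # sigmoid(0) - target = 0.5, as in the unweighted test above.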
+    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
+        output = torch.zeros(3, 1, requires_grad=True, device='mps')
+        target = torch.zeros(3, 1, device='mps')
+        pos_weight = torch.ones(3, 1, device='mps')
+        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
+        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
+        grad = output.grad
+        self.assertEqual(grad, expected_grad)
+
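+    # sigmoid(-120) underflows to 0 in float32, so a naive log(sigmoid(x))
+    # with target 1 would yield inf; the fused loss must stay finite.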
+    def test_bce_with_logits_stability(self):
+        output = torch.tensor([0., -120.], device='mps')
+        target = torch.tensor([0., 1.], device='mps')
+        pos_weight = torch.tensor([1., 1.], device='mps')
+
+        out1 = nn.BCEWithLogitsLoss()(output, target)
+        self.assertTrue(torch.isfinite(out1).all().item())
+
+        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
+        self.assertTrue(torch.isfinite(out2).all().item())
+
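+    # Same weight-broadcasting checks as for BCEWithLogitsLoss above, but for
+    # plain BCELoss on explicitly sigmoided outputs.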
+    def test_bce_loss_broadcasts_weights(self):
+        sigmoid = nn.Sigmoid()
+        target = torch.rand(16, 4, device='mps')
+        output = torch.rand(16, 4, device='mps') - 0.5
+
+        weight = torch.rand(4, device='mps')
+        out1 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        self.assertEqual(out1, out2)
+
+        weight = torch.rand(16, 1, device='mps')
+        out1 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        self.assertEqual(out1, out2)
 
     def test_log_softmax(self):
         values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]