Add mask_type=2 to masked_softmax for when mask.size() == input.size() (#85915)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85915
Approved by: https://github.com/cpuhrsch
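
Usage sketch (illustrative, not part of the patch): a minimal example of the new mask_type=2 path, assuming the private op torch._masked_softmax(input, mask, dim, mask_type) and that True mask entries mean "masked out" (filled with -inf before the softmax). The sizes and mask pattern below are made up for the example.

```python
import torch

B, L = 4, 8
x = torch.randn(B, L)

# Boolean mask with exactly the input's shape; True marks positions to drop.
mask = torch.zeros(B, L, dtype=torch.bool)
mask[:, ::2] = True  # mask out every other column; every row keeps some entries

# mask_type=2 ("default_mask"): mask.sizes() must equal input.sizes().
out = torch._masked_softmax(x, mask, 1, 2)  # softmax over dim=1

# Reference: fill masked positions with -inf, then take a regular softmax.
ref = torch.softmax(x.masked_fill(mask, float("-inf")), dim=1)
print(torch.allclose(out, ref))  # expected: True
```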
diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp
index 21a94d5..d9d1b90 100644
--- a/aten/src/ATen/native/SoftMax.cpp
+++ b/aten/src/ATen/native/SoftMax.cpp
@@ -137,7 +137,8 @@
   if (MaskedSoftMax) {
     TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
     int64_t mask_type = mask_type_.value();
-    TORCH_CHECK((mask_type == 0) || (mask_type == 1), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask)");
+    // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
+    TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)");
 
     // TODO: Add support for TxT src_mask
     TORCH_CHECK(mask_type != 0, "src_mask not currently supported on CPU");
diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu
index 99ebc91..529d70c 100644
--- a/aten/src/ATen/native/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/cuda/SoftMax.cu
@@ -963,7 +963,7 @@
 
   TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
   int64_t mask_type = mask_type_.value();
-  TORCH_CHECK((mask_type == 0) || (mask_type == 1), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask)");
+  TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)");
 
   // If input is [B, H, T, T] and mask is [B, T]
   // we have special fast kernel
@@ -975,6 +975,7 @@
   // TODO We should have special fast kernel for TxT mask as well
   // mask_type == 0 => mask_ is a src_mask
   bool is_TxT_mask = (mask_type == 0) && input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1);
+  // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
   TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes());
 
   auto input = input_.dim() == 0 ? input_.view(1) : input_;
diff --git a/test/test_nn.py b/test/test_nn.py
index 9694c7d..7ddb45f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -15449,26 +15449,26 @@
         for shape in shapes:
             dims = [0, len(shape) - 1] if len(shape) > 0 else [0]
             for dim in dims:
-                input = torch.randn(shape, requires_grad=True)
-                mask = torch.randint(0, 2, shape).bool()
-                mask_type = 1   # BxL => src_key_padding_mask
-                if (self.device_type == "cuda"):
-                    input = input.cuda().detach().requires_grad_()
-                    mask = mask.cuda()
-                self._test_masked_softmax_helper(input, dim, mask, mask_type)
+                for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask, 2 = mask with the input's shape
+                    input = torch.randn(shape, requires_grad=True)
+                    mask = torch.randint(0, 2, shape).bool()
+                    if (self.device_type == "cuda"):
+                        input = input.cuda().detach().requires_grad_()
+                        mask = mask.cuda()
+                    self._test_masked_softmax_helper(input, dim, mask, mask_type)
 
     # In this test, the forward pass is expected to produce nan's because when dim=0, we only have unspecified values
     def test_masked_softmax_forward_with_nans(self, device):
         dim = 0
         shapes = [(4, 5), (50, 100), (1500, 1200)]
         for (x, y) in shapes:
-            input = torch.randn((x, y), requires_grad=True)
-            mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
-            mask_type = 1   # BxL => src_key_padding_mask
-            if (self.device_type == "cuda"):
-                input = input.cuda().detach().requires_grad_()
-                mask = mask.cuda()
-            self._test_masked_softmax_helper(input, dim, mask, mask_type)
+            for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask, 2 = mask with the input's shape
+                input = torch.randn((x, y), requires_grad=True)
+                mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
+                if (self.device_type == "cuda"):
+                    input = input.cuda().detach().requires_grad_()
+                    mask = mask.cuda()
+                self._test_masked_softmax_helper(input, dim, mask, mask_type)
 
     @onlyCUDA
     def test_masked_softmax_transformer_layout(self, device):