Add mask_type=2 to masked_softmax for when mask.size() == input.size() (#85915)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85915
Approved by: https://github.com/cpuhrsch
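
For context, a minimal usage sketch of the new mask type. This is an illustration, not code from the patch; it assumes the private torch._masked_softmax op with the (input, mask, dim, mask_type) call order used by the tests below, and that a True mask entry means "exclude this element from the softmax":

    import torch

    B, L = 4, 8
    x = torch.randn(B, L)
    mask = torch.randint(0, 2, (B, L)).bool()   # True => masked out; same shape as x
    dim = x.dim() - 1

    # mask_type=2: the mask must have exactly the same shape as the input.
    out = torch._masked_softmax(x, mask, dim, 2)

    # Reference semantics (mirroring what _test_masked_softmax_helper compares
    # against): fill masked positions with -inf, then take a regular softmax.
    ref = torch.softmax(x.masked_fill(mask, float("-inf")), dim=dim)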
diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp
index 21a94d5..d9d1b90 100644
--- a/aten/src/ATen/native/SoftMax.cpp
+++ b/aten/src/ATen/native/SoftMax.cpp
@@ -137,7 +137,8 @@
if (MaskedSoftMax) {
TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
int64_t mask_type = mask_type_.value();
- TORCH_CHECK((mask_type == 0) || (mask_type == 1), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask)");
+ // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
+ TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)");
// TODO: Add support for TxT src_mask
TORCH_CHECK(mask_type != 0, "src_mask not currently supported on CPU");
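
A rough summary of the shape conventions the mask_type values encode, written as a hypothetical Python helper (not part of the patch; B = batch size, L = sequence length, following the nn.Transformer convention):

    def expected_mask_shape(input, mask_type):
        # Hypothetical helper mirroring the TORCH_CHECKs above.
        B, L = input.shape[0], input.shape[-1]
        if mask_type == 0:      # src_mask (rejected on CPU by the check above)
            return (L, L)
        if mask_type == 1:      # src_key_padding_mask
            return (B, L)
        if mask_type == 2:      # default_mask: must match the input exactly
            return tuple(input.shape)
        raise ValueError("mask_type must be 0, 1, or 2")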
diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu
index 99ebc91..529d70c 100644
--- a/aten/src/ATen/native/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/cuda/SoftMax.cu
@@ -963,7 +963,7 @@
TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
int64_t mask_type = mask_type_.value();
- TORCH_CHECK((mask_type == 0) || (mask_type == 1), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask)");
+ TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)");
// If input is [B, H, T, T] and mask is [B, T]
// we have special fast kernel
@@ -975,6 +975,7 @@
// TODO We should have special fast kernel for TxT mask as well
// mask_type == 0 => mask_ is a src_mask
bool is_TxT_mask = (mask_type == 0) && input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1);
+ // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes());
auto input = input_.dim() == 0 ? input_.view(1) : input_;
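
On CUDA, the [B, T] key-padding and [T, T] src-mask shapes keep their fast-kernel paths; mask_type=2 simply requires mask_.sizes() == input_.sizes(). A hedged sketch of that full-shape case (assumes a CUDA device and the same private op and signature as above):

    import torch

    if torch.cuda.is_available():
        B, H, T = 2, 4, 16
        x = torch.randn(B, H, T, T, device="cuda")
        # default_mask: identical shape to the input, True => masked out
        mask = torch.randint(0, 2, (B, H, T, T), device="cuda").bool()
        out = torch._masked_softmax(x, mask, 3, 2)   # dim=3, mask_type=2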
diff --git a/test/test_nn.py b/test/test_nn.py
index 9694c7d..7ddb45f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -15449,26 +15449,26 @@
for shape in shapes:
dims = [0, len(shape) - 1] if len(shape) > 0 else [0]
for dim in dims:
- input = torch.randn(shape, requires_grad=True)
- mask = torch.randint(0, 2, shape).bool()
- mask_type = 1 # BxL => src_key_padding_mask
- if (self.device_type == "cuda"):
- input = input.cuda().detach().requires_grad_()
- mask = mask.cuda()
- self._test_masked_softmax_helper(input, dim, mask, mask_type)
+ for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask, 2 = full-size mask (same shape as input)
+ input = torch.randn(shape, requires_grad=True)
+ mask = torch.randint(0, 2, shape).bool()
+ if (self.device_type == "cuda"):
+ input = input.cuda().detach().requires_grad_()
+ mask = mask.cuda()
+ self._test_masked_softmax_helper(input, dim, mask, mask_type)

# In this test, the forward pass is expected to produce nan's because when dim=0, we only have unspecified values
def test_masked_softmax_forward_with_nans(self, device):
dim = 0
shapes = [(4, 5), (50, 100), (1500, 1200)]
for (x, y) in shapes:
- input = torch.randn((x, y), requires_grad=True)
- mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
- mask_type = 1 # BxL => src_key_padding_mask
- if (self.device_type == "cuda"):
- input = input.cuda().detach().requires_grad_()
- mask = mask.cuda()
- self._test_masked_softmax_helper(input, dim, mask, mask_type)
+ for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask, 2 = full-size mask (same shape as input)
+ input = torch.randn((x, y), requires_grad=True)
+ mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
+ if (self.device_type == "cuda"):
+ input = input.cuda().detach().requires_grad_()
+ mask = mask.cuda()
+ self._test_masked_softmax_helper(input, dim, mask, mask_type)

@onlyCUDA
def test_masked_softmax_transformer_layout(self, device):