# blob: 3c888e0a837f7c553d0f418ad183d36c46fd3d14 [file] [log] [blame]
import torch
from torch._C import _add_docstr, _nested # type: ignore[attr-defined]
# Public names exported by a star-import of this module.
__all__ = ["to_padded_tensor"]

# Module-level shorthand for the tensor class.
Tensor = torch.Tensor
# Note: This not only adds doc strings for the nested ops, but
# also connects the torch.nested Python namespace to the torch._C._nested builtins.
to_padded_tensor = _add_docstr(
    _nested.nested_to_padded_tensor,
    # Raw string so backslashes are never interpreted; the text below is
    # reStructuredText, so blank lines and indentation are significant.
    r"""
to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor

Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor.
The leading entries will be filled with the nested data,
while the trailing entries will be padded.

.. warning::

    :func:`to_padded_tensor` always copies the underlying data,
    since the nested and the non-nested tensors differ in memory layout.

Args:
    input (Tensor): The nested tensor to pad.
    padding (float): The padding value for the trailing entries.

Keyword args:
    output_size (Tuple[int], optional): The size of the output tensor.
        If given, it must be large enough to contain all nested data;
        else, will infer by taking the max size of each nested sub-tensor along each dimension.
    out (Tensor, optional): the output tensor.

Example::

    >>> nt = torch.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
    >>> nt
    nested_tensor([
      tensor([[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
              [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995]]),
      tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
              [ 0.2773,  0.8793, -0.5183, -0.6447],
              [ 1.8009,  1.8468, -0.9832, -1.5272]])
    ])
    >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
    >>> pt_infer
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  0.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  0.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  0.0000]]])
    >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
    >>> pt_large
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276,  1.0000],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  1.0000,  1.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  1.0000,  1.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]]])
    >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
    RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.

""",
)