blob: b6201d4d9e84d4a80484796b2a3f34bc803ff71b [file] [log] [blame]
# Owner(s): ["module: distributions"]
"""
Note [Randomized statistical tests]
-----------------------------------
This note describes how to maintain tests in this file as random sources
change. This file contains two types of randomized tests:
1. The easier type of randomized tests are those that should always pass but are
initialized with random data. If these fail, something is wrong, but it's
fine to use a fixed seed by inheriting from common.TestCase.
2. The trickier tests are statistical tests. These tests explicitly call
set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
These statistical tests have a known positive failure rate
(we set failure_rate=1e-3 by default). We need to balance strength of these
tests with annoyance of false alarms. One way that works is to specifically
set seeds in each of the randomized tests. When a random generator
occasionally changes (as in #4312 vectorizing the Box-Muller sampler), some
of these statistical tests may (rarely) fail. If one fails in this case,
it's fine to increment the seed of the failing test (but you shouldn't need
to increment it more than once; otherwise something is probably actually
wrong).
3. `test_geometric_sample`, `test_binomial_sample` and `test_poisson_sample`
are validated against `scipy.stats`, whose results are not guaranteed to be
identical across different versions of scipy (namely, they yield invalid
results in 1.7+).
"""
import math
import numbers
import unittest
from collections import namedtuple
from itertools import product
from random import shuffle
from packaging import version
import torch
# TODO: remove this global setting
# Distributions tests use double as the default dtype
torch.set_default_dtype(torch.double)
from torch._six import inf, nan
from torch.testing._internal.common_utils import \
(TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests,
gradcheck)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.autograd import grad
import torch.autograd.forward_ad as fwAD
from torch.autograd.functional import jacobian
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
Cauchy, Chi2, ContinuousBernoulli, Dirichlet,
Distribution, Exponential, ExponentialFamily,
FisherSnedecor, Gamma, Geometric, Gumbel,
HalfCauchy, HalfNormal, Independent, Kumaraswamy,
LKJCholesky, Laplace, LogisticNormal,
LogNormal, LowRankMultivariateNormal,
MixtureSameFamily, Multinomial, MultivariateNormal,
NegativeBinomial, Normal,
OneHotCategorical, OneHotCategoricalStraightThrough,
Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
StudentT, TransformedDistribution, Uniform,
VonMises, Weibull, Wishart, constraints, kl_divergence)
from torch.distributions.constraint_registry import transform_to
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.kl import _kl_expfamily_expfamily
from torch.distributions.transforms import (AffineTransform, CatTransform, ExpTransform,
StackTransform, identity_transform)
from torch.distributions.utils import (probs_to_logits, lazy_property, tril_matrix_to_vec,
vec_to_tril_matrix)
from torch.nn.functional import softmax
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
TEST_NUMPY = True
try:
import numpy as np
import scipy.stats
import scipy.special
except ImportError:
TEST_NUMPY = False
def pairwise(Dist, *params):
    """
    Build two `Dist` instances whose parameters form square grids, so that
    every element of a parameter list is paired with every other element.

    Each entry `p` of `params` is tiled `len(p)` times into a
    `len(p) x len(p)` tensor. The first distribution is constructed from the
    tiled tensors and the second from their transposes.
    """
    grids = []
    for values in params:
        grids.append(torch.tensor([values] * len(values)))
    transposed = [grid.transpose(0, 1) for grid in grids]
    return Dist(*grids), Dist(*transposed)
def is_all_nan(tensor):
    """
    Return a boolean tensor that is true iff every entry of `tensor` is NaN.
    """
    nan_mask = torch.isnan(tensor)
    return nan_mask.all()
# Register all distributions for generic tests.
# An Example pairs a distribution class (`Dist`) with a list of keyword-argument
# dicts (`params`); the generic tests below build `Dist(**param)` for each dict.
Example = namedtuple('Example', ['Dist', 'params'])
# Parameter sets used by the generic tests (repr, sample/rsample gradient
# flags, expand, support attributes, ...). Several entries mix tensors that
# require grad, plain tensors and Python scalars for the same distribution.
EXAMPLES = [
Example(Bernoulli, [
{'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([0.3], requires_grad=True)},
{'probs': 0.3},
{'logits': torch.tensor([0.], requires_grad=True)},
]),
Example(Geometric, [
{'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([0.3], requires_grad=True)},
{'probs': 0.3},
]),
Example(Beta, [
{
'concentration1': torch.randn(2, 3).exp().requires_grad_(),
'concentration0': torch.randn(2, 3).exp().requires_grad_(),
},
{
'concentration1': torch.randn(4).exp().requires_grad_(),
'concentration0': torch.randn(4).exp().requires_grad_(),
},
]),
Example(Categorical, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
{'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
]),
Example(Binomial, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10])},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10, 8])},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
'total_count': torch.tensor([[10., 8.], [5., 3.]])},
# total_count of zero as a 0-dim tensor
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
'total_count': torch.tensor(0.)},
]),
Example(NegativeBinomial, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
{'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': 10},
{'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': torch.tensor([10])},
{'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': torch.tensor([10, 8])},
{'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
'total_count': torch.tensor([[10., 8.], [5., 3.]])},
# total_count of zero as a 0-dim tensor
{'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
'total_count': torch.tensor(0.)},
]),
Example(Multinomial, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
]),
Example(Cauchy, [
{'loc': 0.0, 'scale': 1.0},
{'loc': torch.tensor([0.0]), 'scale': 1.0},
{'loc': torch.tensor([[0.0], [0.0]]),
'scale': torch.tensor([[1.0], [1.0]])}
]),
Example(Chi2, [
{'df': torch.randn(2, 3).exp().requires_grad_()},
{'df': torch.randn(1).exp().requires_grad_()},
]),
Example(StudentT, [
{'df': torch.randn(2, 3).exp().requires_grad_()},
{'df': torch.randn(1).exp().requires_grad_()},
]),
Example(Dirichlet, [
{'concentration': torch.randn(2, 3).exp().requires_grad_()},
{'concentration': torch.randn(4).exp().requires_grad_()},
]),
Example(Exponential, [
{'rate': torch.randn(5, 5).abs().requires_grad_()},
{'rate': torch.randn(1).abs().requires_grad_()},
]),
Example(FisherSnedecor, [
{
'df1': torch.randn(5, 5).abs().requires_grad_(),
'df2': torch.randn(5, 5).abs().requires_grad_(),
},
{
'df1': torch.randn(1).abs().requires_grad_(),
'df2': torch.randn(1).abs().requires_grad_(),
},
{
'df1': torch.tensor([1.0]),
'df2': 1.0,
}
]),
Example(Gamma, [
{
'concentration': torch.randn(2, 3).exp().requires_grad_(),
'rate': torch.randn(2, 3).exp().requires_grad_(),
},
{
'concentration': torch.randn(1).exp().requires_grad_(),
'rate': torch.randn(1).exp().requires_grad_(),
},
]),
Example(Gumbel, [
{
'loc': torch.randn(5, 5, requires_grad=True),
'scale': torch.randn(5, 5).abs().requires_grad_(),
},
{
'loc': torch.randn(1, requires_grad=True),
'scale': torch.randn(1).abs().requires_grad_(),
},
]),
Example(HalfCauchy, [
{'scale': 1.0},
{'scale': torch.tensor([[1.0], [1.0]])}
]),
Example(HalfNormal, [
{'scale': torch.randn(5, 5).abs().requires_grad_()},
{'scale': torch.randn(1).abs().requires_grad_()},
{'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)}
]),
# Independent wraps the same kind of Normal base with increasing
# reinterpreted_batch_ndims.
Example(Independent, [
{
'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
torch.randn(2, 3).abs().requires_grad_()),
'reinterpreted_batch_ndims': 0,
},
{
'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
torch.randn(2, 3).abs().requires_grad_()),
'reinterpreted_batch_ndims': 1,
},
{
'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
torch.randn(2, 3).abs().requires_grad_()),
'reinterpreted_batch_ndims': 2,
},
{
'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
torch.randn(2, 3, 5).abs().requires_grad_()),
'reinterpreted_batch_ndims': 2,
},
{
'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
torch.randn(2, 3, 5).abs().requires_grad_()),
'reinterpreted_batch_ndims': 3,
},
]),
Example(Kumaraswamy, [
{
'concentration1': torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
'concentration0': torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
},
{
'concentration1': torch.rand(4).uniform_(1, 2).requires_grad_(),
'concentration0': torch.rand(4).uniform_(1, 2).requires_grad_(),
},
]),
Example(LKJCholesky, [
{
'dim': 2,
'concentration': 0.5
},
{
'dim': 3,
'concentration': torch.tensor([0.5, 1., 2.]),
},
{
'dim': 100,
'concentration': 4.
},
]),
Example(Laplace, [
{
'loc': torch.randn(5, 5, requires_grad=True),
'scale': torch.randn(5, 5).abs().requires_grad_(),
},
{
'loc': torch.randn(1, requires_grad=True),
'scale': torch.randn(1).abs().requires_grad_(),
},
{
'loc': torch.tensor([1.0, 0.0], requires_grad=True),
'scale': torch.tensor([1e-5, 1e-5], requires_grad=True),
},
]),
Example(LogNormal, [
{
'loc': torch.randn(5, 5, requires_grad=True),
'scale': torch.randn(5, 5).abs().requires_grad_(),
},
{
'loc': torch.randn(1, requires_grad=True),
'scale': torch.randn(1).abs().requires_grad_(),
},
{
'loc': torch.tensor([1.0, 0.0], requires_grad=True),
'scale': torch.tensor([1e-5, 1e-5], requires_grad=True),
},
]),
Example(LogisticNormal, [
{
'loc': torch.randn(5, 5).requires_grad_(),
'scale': torch.randn(5, 5).abs().requires_grad_(),
},
{
'loc': torch.randn(1).requires_grad_(),
'scale': torch.randn(1).abs().requires_grad_(),
},
{
'loc': torch.tensor([1.0, 0.0], requires_grad=True),
'scale': torch.tensor([1e-5, 1e-5], requires_grad=True),
},
]),
Example(LowRankMultivariateNormal, [
{
'loc': torch.randn(5, 2, requires_grad=True),
'cov_factor': torch.randn(5, 2, 1, requires_grad=True),
'cov_diag': torch.tensor([2.0, 0.25], requires_grad=True),
},
{
'loc': torch.randn(4, 3, requires_grad=True),
'cov_factor': torch.randn(3, 2, requires_grad=True),
'cov_diag': torch.tensor([5.0, 1.5, 3.], requires_grad=True),
}
]),
# MultivariateNormal is exercised through each of its three
# parameterizations: covariance_matrix, precision_matrix and scale_tril.
Example(MultivariateNormal, [
{
'loc': torch.randn(5, 2, requires_grad=True),
'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True),
},
{
'loc': torch.randn(2, 3, requires_grad=True),
'precision_matrix': torch.tensor([[2.0, 0.1, 0.0],
[0.1, 0.25, 0.0],
[0.0, 0.0, 0.3]], requires_grad=True),
},
{
'loc': torch.randn(5, 3, 2, requires_grad=True),
'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]],
[[2.0, 0.0], [0.3, 0.25]],
[[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True),
},
{
'loc': torch.tensor([1.0, -1.0]),
'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
},
]),
Example(Normal, [
{
'loc': torch.randn(5, 5, requires_grad=True),
'scale': torch.randn(5, 5).abs().requires_grad_(),
},
{
'loc': torch.randn(1, requires_grad=True),
'scale': torch.randn(1).abs().requires_grad_(),
},
{
'loc': torch.tensor([1.0, 0.0], requires_grad=True),
'scale': torch.tensor([1e-5, 1e-5], requires_grad=True),
},
]),
Example(OneHotCategorical, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
{'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
]),
Example(OneHotCategoricalStraightThrough, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
{'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
]),
Example(Pareto, [
{
'scale': 1.0,
'alpha': 1.0
},
{
'scale': torch.randn(5, 5).abs().requires_grad_(),
'alpha': torch.randn(5, 5).abs().requires_grad_()
},
{
'scale': torch.tensor([1.0]),
'alpha': 1.0
}
]),
Example(Poisson, [
{
'rate': torch.randn(5, 5).abs().requires_grad_(),
},
{
'rate': torch.randn(3).abs().requires_grad_(),
},
{
'rate': 0.2,
},
# rate of exactly zero, as a tensor and as a Python float
{
'rate': torch.tensor([0.0], requires_grad=True),
},
{
'rate': 0.0,
}
]),
Example(RelaxedBernoulli, [
{
'temperature': torch.tensor([0.5], requires_grad=True),
'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True),
},
{
'temperature': torch.tensor([2.0]),
'probs': torch.tensor([0.3]),
},
{
'temperature': torch.tensor([7.2]),
'logits': torch.tensor([-2.0, 2.0, 1.0, 5.0])
}
]),
Example(RelaxedOneHotCategorical, [
{
'temperature': torch.tensor([0.5], requires_grad=True),
'probs': torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)
},
{
'temperature': torch.tensor([2.0]),
'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]])
},
{
'temperature': torch.tensor([7.2]),
'logits': torch.tensor([[-2.0, 2.0], [1.0, 5.0]])
}
]),
# TransformedDistribution is given an empty list, a single transform and a
# list of transforms.
Example(TransformedDistribution, [
{
'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
torch.randn(2, 3).abs().requires_grad_()),
'transforms': [],
},
{
'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
torch.randn(2, 3).abs().requires_grad_()),
'transforms': ExpTransform(),
},
{
'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
torch.randn(2, 3, 5).abs().requires_grad_()),
'transforms': [AffineTransform(torch.randn(3, 5), torch.randn(3, 5)),
ExpTransform()],
},
{
'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
torch.randn(2, 3, 5).abs().requires_grad_()),
'transforms': AffineTransform(1, 2),
},
{
'base_distribution': Uniform(torch.tensor(1e8).log(), torch.tensor(1e10).log()),
'transforms': ExpTransform(),
},
]),
Example(Uniform, [
{
'low': torch.zeros(5, 5, requires_grad=True),
'high': torch.ones(5, 5, requires_grad=True),
},
{
'low': torch.zeros(1, requires_grad=True),
'high': torch.ones(1, requires_grad=True),
},
{
'low': torch.tensor([1.0, 1.0], requires_grad=True),
'high': torch.tensor([2.0, 3.0], requires_grad=True),
},
]),
Example(Weibull, [
{
'scale': torch.randn(5, 5).abs().requires_grad_(),
'concentration': torch.randn(1).abs().requires_grad_()
}
]),
Example(Wishart, [
{
'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True),
'df': torch.tensor([3.], requires_grad=True),
},
{
'precision_matrix': torch.tensor([[2.0, 0.1, 0.0],
[0.1, 0.25, 0.0],
[0.0, 0.0, 0.3]], requires_grad=True),
'df': torch.tensor([5., 4], requires_grad=True),
},
{
'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]],
[[2.0, 0.0], [0.3, 0.25]],
[[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True),
'df': torch.tensor([5., 3.5, 3], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
'df': torch.tensor([3.0]),
},
{
'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
'df': 3.0,
},
]),
Example(MixtureSameFamily, [
{
'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
'component_distribution': Normal(torch.randn(5, requires_grad=True),
torch.rand(5, requires_grad=True)),
},
{
'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
'component_distribution': MultivariateNormal(
loc=torch.randn(5, 2, requires_grad=True),
covariance_matrix=torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True)),
},
]),
Example(VonMises, [
{
'loc': torch.tensor(1.0, requires_grad=True),
'concentration': torch.tensor(10.0, requires_grad=True)
},
{
'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True),
'concentration': torch.tensor([1.0, 10.0], requires_grad=True)
},
]),
Example(ContinuousBernoulli, [
{'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([0.3], requires_grad=True)},
{'probs': 0.3},
{'logits': torch.tensor([0.], requires_grad=True)},
])
]
# Same (Dist, params) layout as EXAMPLES, but the parameter values here are
# deliberately out of range (negative scales, probabilities outside [0, 1],
# non-positive-definite matrices, non-Transform objects, ...).
# NOTE(review): presumably consumed by a test asserting that constructing these
# raises under argument validation -- the consumer is outside this excerpt; confirm.
BAD_EXAMPLES = [
Example(Bernoulli, [
{'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([-0.5], requires_grad=True)},
{'probs': 1.00001},
]),
Example(Beta, [
{
'concentration1': torch.tensor([0.0], requires_grad=True),
'concentration0': torch.tensor([0.0], requires_grad=True),
},
{
'concentration1': torch.tensor([-1.0], requires_grad=True),
'concentration0': torch.tensor([-2.0], requires_grad=True),
},
]),
Example(Geometric, [
{'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([-0.3], requires_grad=True)},
{'probs': 1.00000001},
]),
Example(Categorical, [
{'probs': torch.tensor([[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
{'probs': torch.tensor([[-1.0, 10.0], [0.0, -1.0]], requires_grad=True)},
]),
Example(Binomial, [
{'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True),
'total_count': 10},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
'total_count': 10},
]),
Example(NegativeBinomial, [
{'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True),
'total_count': 10},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
'total_count': 10},
]),
Example(Cauchy, [
{'loc': 0.0, 'scale': -1.0},
{'loc': torch.tensor([0.0]), 'scale': 0.0},
{'loc': torch.tensor([[0.0], [-2.0]]),
'scale': torch.tensor([[-0.000001], [1.0]])}
]),
Example(Chi2, [
{'df': torch.tensor([0.], requires_grad=True)},
{'df': torch.tensor([-2.], requires_grad=True)},
]),
Example(StudentT, [
{'df': torch.tensor([0.], requires_grad=True)},
{'df': torch.tensor([-2.], requires_grad=True)},
]),
Example(Dirichlet, [
{'concentration': torch.tensor([0.], requires_grad=True)},
{'concentration': torch.tensor([-2.], requires_grad=True)}
]),
Example(Exponential, [
{'rate': torch.tensor([0., 0.], requires_grad=True)},
{'rate': torch.tensor([-2.], requires_grad=True)}
]),
Example(FisherSnedecor, [
{
'df1': torch.tensor([0., 0.], requires_grad=True),
'df2': torch.tensor([-1., -100.], requires_grad=True),
},
{
'df1': torch.tensor([1., 1.], requires_grad=True),
'df2': torch.tensor([0., 0.], requires_grad=True),
}
]),
Example(Gamma, [
{
'concentration': torch.tensor([0., 0.], requires_grad=True),
'rate': torch.tensor([-1., -100.], requires_grad=True),
},
{
'concentration': torch.tensor([1., 1.], requires_grad=True),
'rate': torch.tensor([0., 0.], requires_grad=True),
}
]),
Example(Gumbel, [
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'scale': torch.tensor([0., 1.], requires_grad=True),
},
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'scale': torch.tensor([1., -1.], requires_grad=True),
},
]),
Example(HalfCauchy, [
{'scale': -1.0},
{'scale': 0.0},
{'scale': torch.tensor([[-0.000001], [1.0]])}
]),
Example(HalfNormal, [
{'scale': torch.tensor([0., 1.], requires_grad=True)},
{'scale': torch.tensor([1., -1.], requires_grad=True)},
]),
Example(LKJCholesky, [
{
'dim': -2,
'concentration': 0.1
},
{
'dim': 1,
'concentration': 2.,
},
{
'dim': 2,
'concentration': 0.,
},
]),
Example(Laplace, [
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'scale': torch.tensor([0., 1.], requires_grad=True),
},
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'scale': torch.tensor([1., -1.], requires_grad=True),
},
]),
Example(LogNormal, [
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'scale': torch.tensor([0., 1.], requires_grad=True),
},
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'scale': torch.tensor([1., -1.], requires_grad=True),
},
]),
Example(MultivariateNormal, [
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True),
},
]),
Example(Normal, [
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'scale': torch.tensor([0., 1.], requires_grad=True),
},
{
'loc': torch.tensor([1., 1.], requires_grad=True),
'scale': torch.tensor([1., -1.], requires_grad=True),
},
{
'loc': torch.tensor([1.0, 0.0], requires_grad=True),
'scale': torch.tensor([1e-5, -1e-5], requires_grad=True),
},
]),
Example(OneHotCategorical, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
{'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
]),
Example(OneHotCategoricalStraightThrough, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
{'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
]),
Example(Pareto, [
{
'scale': 0.0,
'alpha': 0.0
},
{
'scale': torch.tensor([0.0, 0.0], requires_grad=True),
'alpha': torch.tensor([-1e-5, 0.0], requires_grad=True)
},
{
'scale': torch.tensor([1.0]),
'alpha': -1.0
}
]),
Example(Poisson, [
{
'rate': torch.tensor([-0.1], requires_grad=True),
},
{
'rate': -1.0,
}
]),
Example(RelaxedBernoulli, [
{
'temperature': torch.tensor([1.5], requires_grad=True),
'probs': torch.tensor([1.7, 0.2, 0.4], requires_grad=True),
},
{
'temperature': torch.tensor([2.0]),
'probs': torch.tensor([-1.0]),
}
]),
Example(RelaxedOneHotCategorical, [
{
'temperature': torch.tensor([0.5], requires_grad=True),
'probs': torch.tensor([[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)
},
{
'temperature': torch.tensor([2.0]),
'probs': torch.tensor([[-1.0, 0.0], [-1.0, 1.1]])
}
]),
# Plain lambdas are passed where transform objects are passed in EXAMPLES.
Example(TransformedDistribution, [
{
'base_distribution': Normal(0, 1),
'transforms': lambda x: x,
},
{
'base_distribution': Normal(0, 1),
'transforms': [lambda x: x],
},
]),
# low equal to high, or low greater than high
Example(Uniform, [
{
'low': torch.tensor([2.0], requires_grad=True),
'high': torch.tensor([2.0], requires_grad=True),
},
{
'low': torch.tensor([0.0], requires_grad=True),
'high': torch.tensor([0.0], requires_grad=True),
},
{
'low': torch.tensor([1.0], requires_grad=True),
'high': torch.tensor([0.0], requires_grad=True),
}
]),
Example(Weibull, [
{
'scale': torch.tensor([0.0], requires_grad=True),
'concentration': torch.tensor([0.0], requires_grad=True)
},
{
'scale': torch.tensor([1.0], requires_grad=True),
'concentration': torch.tensor([-1.0], requires_grad=True)
}
]),
Example(Wishart, [
{
'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True),
'df': torch.tensor([1.5], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True),
'df': torch.tensor([3.], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True),
'df': 3.,
},
]),
Example(ContinuousBernoulli, [
{'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([-0.5], requires_grad=True)},
{'probs': 1.00001},
])
]
class DistributionsTestCase(TestCase):
    """Base test case that turns on distribution argument validation before each test."""

    def setUp(self):
        """Enable the default validation flag, which the tests assume is set, then run the parent setUp."""
        base_distribution = torch.distributions.Distribution
        base_distribution.set_default_validate_args(True)
        super().setUp()
class TestDistributions(DistributionsTestCase):
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
def _gradcheck_log_prob(self, dist_ctor, ctor_params):
# performs gradient checks on log_prob
distribution = dist_ctor(*ctor_params)
s = distribution.sample()
if not distribution.support.is_discrete:
s = s.detach().requires_grad_()
expected_shape = distribution.batch_shape + distribution.event_shape
self.assertEqual(s.size(), expected_shape)
def apply_fn(s, *params):
return dist_ctor(*params).log_prob(s)
gradcheck(apply_fn, (s,) + tuple(ctor_params), raise_exception=True)
def _check_forward_ad(self, fn):
with fwAD.dual_level():
x = torch.tensor(1.)
t = torch.tensor(1.)
dual = fwAD.make_dual(x, t)
dual_out = fn(dual)
self.assertEqual(torch.count_nonzero(fwAD.unpack_dual(dual_out).tangent).item(), 0)
def _check_log_prob(self, dist, asset_fn):
# checks that the log_prob matches a reference function
s = dist.sample()
log_probs = dist.log_prob(s)
log_probs_data_flat = log_probs.view(-1)
s_data_flat = s.view(len(log_probs_data_flat), -1)
for i, (val, log_prob) in enumerate(zip(s_data_flat, log_probs_data_flat)):
asset_fn(i, val.squeeze(), log_prob)
def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=False,
circular=False, num_samples=10000, failure_rate=1e-3):
# Checks that the .sample() method matches a reference function.
torch_samples = torch_dist.sample((num_samples,)).squeeze()
torch_samples = torch_samples.cpu().numpy()
ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
if multivariate:
# Project onto a random axis.
axis = np.random.normal(size=(1,) + torch_samples.shape[1:])
axis /= np.linalg.norm(axis)
torch_samples = (axis * torch_samples).reshape(num_samples, -1).sum(-1)
ref_samples = (axis * ref_samples).reshape(num_samples, -1).sum(-1)
samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
if circular:
samples = [(np.cos(x), v) for (x, v) in samples]
shuffle(samples) # necessary to prevent stable sort from making uneven bins for discrete
samples.sort(key=lambda x: x[0])
samples = np.array(samples)[:, 1]
# Aggregate into bins filled with roughly zero-mean unit-variance RVs.
num_bins = 10
samples_per_bin = len(samples) // num_bins
bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
stddev = samples_per_bin ** -0.5
threshold = stddev * scipy.special.erfinv(1 - 2 * failure_rate / num_bins)
message = '{}.sample() is biased:\n{}'.format(message, bins)
for bias in bins:
self.assertLess(-threshold, bias, message)
self.assertLess(bias, threshold, message)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def _check_sampler_discrete(self, torch_dist, ref_dist, message,
                            num_samples=10000, failure_rate=1e-3):
    """
    Run a chi-squared goodness-of-fit test of `torch_dist.sample()` against
    the probability mass function `ref_dist.pmf`.

    Only the values actually observed are considered, and the reference pmf
    is renormalized over them. Values whose observed or expected count is not
    above 5 are pooled into a single remainder bucket. The test fails with
    `message` when the p-value is not greater than `failure_rate`.
    """
    draws = torch_dist.sample((num_samples,)).squeeze().cpu().numpy()
    values, observed = np.unique(draws, return_counts=True)
    expected_pmf = ref_dist.pmf(values)
    expected_pmf = expected_pmf / expected_pmf.sum()  # renormalize to 1.0 for chisq test
    well_populated = (observed > 5) & ((expected_pmf * num_samples) > 5)
    self.assertGreater(expected_pmf[well_populated].sum(), 0.9,
                       "Distribution is too sparse for test; try increasing num_samples")
    if not well_populated.all():
        # Pool every sparsely populated value into one remainder bucket.
        sparse = ~well_populated
        observed = np.concatenate([observed[well_populated],
                                   np.sum(observed[sparse], keepdims=True)])
        expected_pmf = np.concatenate([expected_pmf[well_populated],
                                       np.sum(expected_pmf[sparse], keepdims=True)])
    _, p_value = scipy.stats.chisquare(observed, expected_pmf * num_samples)
    self.assertGreater(p_value, failure_rate, message)
def _check_enumerate_support(self, dist, examples):
for params, expected in examples:
params = {k: torch.tensor(v) for k, v in params.items()}
d = dist(**params)
actual = d.enumerate_support(expand=False)
expected = torch.tensor(expected, dtype=actual.dtype)
self.assertEqual(actual, expected)
actual = d.enumerate_support(expand=True)
expected_with_expand = expected.expand((-1,) + d.batch_shape + d.event_shape)
self.assertEqual(actual, expected_with_expand)
def test_repr(self):
    """The repr of every example distribution starts with its class name."""
    for Dist, params in EXAMPLES:
        for kwargs in params:
            instance = Dist(**kwargs)
            class_name = instance.__class__.__name__
            self.assertTrue(repr(instance).startswith(class_name))
def test_sample_detached(self):
    """`.sample()` must not require grad, even when some parameter does."""
    for Dist, params in EXAMPLES:
        total = len(params)
        for position, kwargs in enumerate(params, start=1):
            # Only examples with at least one grad-requiring parameter matter here.
            if not any(getattr(value, 'requires_grad', False) for value in kwargs.values()):
                continue
            drawn = Dist(**kwargs).sample()
            failure = '{} example {}/{}, .sample() is not detached'.format(
                Dist.__name__, position, total)
            self.assertFalse(drawn.requires_grad, msg=failure)
def test_rsample_requires_grad(self):
    """`.rsample()` must require grad whenever some parameter requires grad."""
    for Dist, params in EXAMPLES:
        total = len(params)
        for position, kwargs in enumerate(params, start=1):
            grad_flags = [getattr(value, 'requires_grad', False) for value in kwargs.values()]
            if not any(grad_flags):
                continue
            dist = Dist(**kwargs)
            if dist.has_rsample:
                drawn = dist.rsample()
                failure = '{} example {}/{}, .rsample() does not require grad'.format(
                    Dist.__name__, position, total)
                self.assertTrue(drawn.requires_grad, msg=failure)
def test_enumerate_support_type(self):
    """`sample()` and `enumerate_support()` return the same type, where both are implemented."""
    for Dist, params in EXAMPLES:
        total = len(params)
        for position, kwargs in enumerate(params, start=1):
            dist = Dist(**kwargs)
            failure = ('{} example {}/{}, return type mismatch between '
                       'sample and enumerate_support.').format(Dist.__name__, position, total)
            try:
                sample_type = type(dist.sample())
                support_type = type(dist.enumerate_support())
            except NotImplementedError:
                # Nothing to compare when either method is not implemented.
                continue
            self.assertTrue(sample_type is support_type, msg=failure)
def test_lazy_property_grad(self):
x = torch.randn(1, requires_grad=True)
class Dummy(object):
@lazy_property
def y(self):
return x + 1
def test():
x.grad = None
Dummy().y.backward()
self.assertEqual(x.grad, torch.ones(1))
test()
with torch.no_grad():
test()
mean = torch.randn(2)
cov = torch.eye(2, requires_grad=True)
distn = MultivariateNormal(mean, cov)
with torch.no_grad():
distn.scale_tril
distn.scale_tril.sum().backward()
self.assertIsNotNone(cov.grad)
def test_has_examples(self):
    """Every Distribution subclass in this module's globals has an entry in EXAMPLES."""
    covered = {example.Dist for example in EXAMPLES}
    for candidate in globals().values():
        if not isinstance(candidate, type) or not issubclass(candidate, Distribution):
            continue
        # The two base classes themselves are exempt.
        if candidate is Distribution or candidate is ExponentialFamily:
            continue
        self.assertIn(candidate, covered,
                      "Please add {} to the EXAMPLES list in test_distributions.py".format(candidate.__name__))
def test_support_attributes(self):
    """
    `support.event_dim` matches the length of `event_shape`, and the
    class-level `support` agrees with the instance-level one on `event_dim`
    and `is_discrete` wherever it does not raise NotImplementedError.
    """
    for Dist, params in EXAMPLES:
        for kwargs in params:
            instance = Dist(**kwargs)

            event_dim = len(instance.event_shape)
            self.assertEqual(instance.support.event_dim, event_dim)
            try:
                class_event_dim = Dist.support.event_dim
                self.assertEqual(class_event_dim, event_dim)
            except NotImplementedError:
                pass

            is_discrete = instance.support.is_discrete
            try:
                class_is_discrete = Dist.support.is_discrete
                self.assertEqual(class_is_discrete, is_discrete)
            except NotImplementedError:
                pass
def test_distribution_expand(self):
    """
    `expand` keeps the class, prepends the requested batch dimensions and
    leaves `log_prob`, `mean` and `variance` consistent with the original
    distribution.
    """
    extra_batch_shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
    for Dist, params in EXAMPLES:
        for kwargs, extra in product(params, extra_batch_shapes):
            d = Dist(**kwargs)
            expanded_shape = extra + d.batch_shape
            original_shape = d.batch_shape + d.event_shape
            expected_shape = extra + original_shape

            expanded = d.expand(batch_shape=list(expanded_shape))
            sample = expanded.sample()
            actual_shape = expanded.sample().shape

            self.assertEqual(expanded.__class__, d.__class__)
            self.assertEqual(d.sample().shape, original_shape)
            self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
            self.assertEqual(actual_shape, expected_shape)
            self.assertEqual(expanded.batch_shape, expanded_shape)

            moment_shape = expanded_shape + d.event_shape
            try:
                self.assertEqual(expanded.mean, d.mean.expand(moment_shape))
                self.assertEqual(expanded.variance, d.variance.expand(moment_shape))
            except NotImplementedError:
                # mean / variance are not implemented for this distribution.
                pass
def test_distribution_subclass_expand(self):
    """`expand` on a trivial subclass of each distribution returns an instance of that subclass."""
    prefix = torch.Size((2,))
    for Dist, params in EXAMPLES:

        class SubClass(Dist):
            pass

        for kwargs in params:
            d = SubClass(**kwargs)
            original_shape = d.batch_shape + d.event_shape
            expanded = d.expand(batch_shape=prefix + d.batch_shape)
            sample = expanded.sample()
            actual_shape = expanded.sample().shape

            self.assertEqual(expanded.__class__, d.__class__)
            self.assertEqual(d.sample().shape, original_shape)
            self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
            self.assertEqual(actual_shape, prefix + original_shape)
def test_bernoulli(self):
    """Cover Bernoulli sample shapes, log_prob, rsample rejection, entropy and forward AD of bernoulli ops."""
    probs_vec = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
    probs_0d = torch.tensor(0.3, requires_grad=True)
    probs_float = 0.3

    # Sample shapes, and samples never carry gradients.
    self.assertEqual(Bernoulli(probs_vec).sample((8,)).size(), (8, 3))
    self.assertFalse(Bernoulli(probs_vec).sample().requires_grad)
    self.assertEqual(Bernoulli(probs_0d).sample((8,)).size(), (8,))
    self.assertEqual(Bernoulli(probs_0d).sample().size(), ())
    self.assertEqual(Bernoulli(probs_0d).sample((3, 2)).size(), (3, 2,))
    self.assertEqual(Bernoulli(probs_float).sample().size(), ())
    self._gradcheck_log_prob(Bernoulli, (probs_vec,))

    def ref_log_prob(idx, val, log_prob):
        # Closed form: log(p) for an outcome of 1, log(1 - p) for 0.
        prob = probs_vec[idx]
        mass = prob if val else 1 - prob
        self.assertEqual(log_prob, math.log(mass))

    self._check_log_prob(Bernoulli(probs_vec), ref_log_prob)
    logits = probs_vec.log() - (-probs_vec).log1p()
    self._check_log_prob(Bernoulli(logits=logits), ref_log_prob)
    self.assertRaises(NotImplementedError, Bernoulli(probs_0d).rsample)

    # check entropy computation
    self.assertEqual(Bernoulli(probs_vec).entropy(),
                     torch.tensor([0.6108, 0.5004, 0.6730]), atol=1e-4, rtol=0)
    self.assertEqual(Bernoulli(torch.tensor([0.0])).entropy(), torch.tensor([0.0]))
    self.assertEqual(Bernoulli(probs_float).entropy(), torch.tensor(0.6108), atol=1e-4, rtol=0)

    # Forward-mode AD checks for the functional and in-place bernoulli variants.
    self._check_forward_ad(torch.bernoulli)
    self._check_forward_ad(lambda x: x.bernoulli_())
    self._check_forward_ad(lambda x: x.bernoulli_(x.clone().detach()))
    self._check_forward_ad(lambda x: x.bernoulli_(x))
def test_bernoulli_enumerate_support(self):
    """Pass (constructor kwargs, expected support) pairs for Bernoulli to ``_check_enumerate_support``."""
    cases = [
        (dict(probs=[0.1]), [[0], [1]]),
        (dict(probs=[0.1, 0.9]), [[0], [1]]),
        (dict(probs=[[0.1, 0.2], [0.3, 0.4]]), [[[0]], [[1]]]),
    ]
    self._check_enumerate_support(Bernoulli, cases)
def test_bernoulli_3d(self):
    """Check Bernoulli sample sizes when ``probs`` is a (2, 3, 5) tensor."""
    batch = (2, 3, 5)
    probs = torch.full(batch, 0.5).requires_grad_()
    self.assertEqual(Bernoulli(probs).sample().size(), batch)
    self.assertEqual(Bernoulli(probs).sample(sample_shape=(2, 5)).size(), (2, 5) + batch)
    self.assertEqual(Bernoulli(probs).sample((2,)).size(), (2,) + batch)
def test_geometric(self):
    """Cover Geometric sample shapes, the degenerate ``probs=1`` case, argument validation and rsample rejection."""
    probs_vec = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
    probs_0d = torch.tensor(0.3, requires_grad=True)
    probs_float = 0.3

    self.assertEqual(Geometric(probs_vec).sample((8,)).size(), (8, 3))

    # With probs == 1 the sample is asserted to be 0 and all mass sits on 0.
    self.assertEqual(Geometric(1).sample(), 0)
    self.assertEqual(Geometric(1).log_prob(torch.tensor(1.)), -inf)
    self.assertEqual(Geometric(1).log_prob(torch.tensor(0.)), 0)

    self.assertFalse(Geometric(probs_vec).sample().requires_grad)
    self.assertEqual(Geometric(probs_0d).sample((8,)).size(), (8,))
    self.assertEqual(Geometric(probs_0d).sample().size(), ())
    self.assertEqual(Geometric(probs_0d).sample((3, 2)).size(), (3, 2))
    self.assertEqual(Geometric(probs_float).sample().size(), ())
    self._gradcheck_log_prob(Geometric, (probs_vec,))

    # probs == 0 must be rejected at construction time.
    self.assertRaises(ValueError, lambda: Geometric(0))
    self.assertRaises(NotImplementedError, Geometric(probs_0d).rsample)
    self._check_forward_ad(lambda x: x.geometric_(0.2))
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_geometric_log_prob_and_entropy(self):
    """Compare Geometric log_prob and entropy with ``scipy.stats.geom`` shifted by ``loc=-1``."""
    probs_vec = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
    probs_float = 0.3

    def ref_log_prob(idx, val, log_prob):
        reference = scipy.stats.geom(probs_vec[idx].detach(), loc=-1)
        self.assertEqual(log_prob, reference.logpmf(val))

    self._check_log_prob(Geometric(probs_vec), ref_log_prob)
    logits = probs_vec.log() - (-probs_vec).log1p()
    self._check_log_prob(Geometric(logits=logits), ref_log_prob)

    # check entropy computation
    vec_reference = scipy.stats.geom(probs_vec.detach().numpy(), loc=-1)
    self.assertEqual(Geometric(probs_vec).entropy(), vec_reference.entropy(), atol=1e-3, rtol=0)
    scalar_reference = scipy.stats.geom(probs_float, loc=-1)
    self.assertEqual(float(Geometric(probs_float).entropy()),
                     scalar_reference.entropy().item(), atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_geometric_sample(self):
    """Run ``_check_sampler_discrete`` on Geometric against ``scipy.stats.geom`` for three probabilities."""
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    for prob in (0.01, 0.18, 0.8):
        label = 'Geometric(prob={})'.format(prob)
        self._check_sampler_discrete(Geometric(prob), scipy.stats.geom(p=prob, loc=-1), label)
def test_binomial(self):
    """Gradcheck Binomial log_prob via probs and via logits, and check that rsample is not implemented."""
    probs = torch.arange(0.05, 1, 0.1).requires_grad_()
    for total_count in (1, 2, 10):
        # Both lambdas are invoked inside the helper during this iteration,
        # so capturing the loop variable is safe.
        self._gradcheck_log_prob(lambda q: Binomial(total_count, q), [probs])
        self._gradcheck_log_prob(lambda q: Binomial(total_count, None, q.log()), [probs])
    self.assertRaises(NotImplementedError, Binomial(10, probs).rsample)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_binomial_sample(self):
    """Run ``_check_sampler_discrete`` on Binomial against ``scipy.stats.binom`` over a grid of probs and counts."""
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    # product() iterates counts fastest, matching a nested prob/count loop.
    for prob, count in product([0.01, 0.1, 0.5, 0.8, 0.9], [2, 10, 100, 500]):
        label = 'Binomial(total_count={}, probs={})'.format(count, prob)
        self._check_sampler_discrete(Binomial(total_count=count, probs=prob),
                                     scipy.stats.binom(count, prob),
                                     label)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_binomial_log_prob_and_entropy(self):
    """Compare Binomial log_prob and entropy with ``scipy.stats.binom`` for several total counts.

    The distribution is built both from ``probs`` and from the matching
    ``logits``; the entropy check uses the logits-parameterized instance.
    """
    probs = torch.arange(0.05, 1, 0.1)
    for total_count in [1, 2, 10]:

        def ref_log_prob(idx, x, log_prob):
            p = probs.view(-1)[idx].item()
            expected = scipy.stats.binom(total_count, p).logpmf(x)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Binomial(total_count, probs), ref_log_prob)
        logits = probs_to_logits(probs, is_binary=True)
        self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)

        # Named `binomial` rather than `bin` so the builtin is not shadowed.
        binomial = Binomial(total_count, logits=logits)
        # NOTE(review): loc=-1 looks carried over from the geometric test;
        # presumably harmless for entropy, which should not depend on a shift.
        reference = scipy.stats.binom(total_count, binomial.probs.detach().numpy(), loc=-1)
        self.assertEqual(binomial.entropy(), reference.entropy(), atol=1e-3, rtol=0)
def test_binomial_stable(self):
    """Check Binomial log_prob is finite for logits of +/-100 and has the asserted gradient at logits=0."""
    total_count = 1.
    extreme_logits = torch.tensor([-100., 100.], dtype=torch.float)
    zeros = torch.tensor([0., 0.], dtype=torch.float)
    finite_check = Binomial(total_count, logits=extreme_logits).log_prob(zeros)
    self.assertTrue(torch.isfinite(finite_check).all())

    # At logits=0, value=0 the gradient of log_prob w.r.t. the logit is asserted to be -0.5.
    logit = torch.tensor(0., requires_grad=True)
    log_prob_at_zero = Binomial(total_count, logits=logit).log_prob(torch.tensor(0.))
    d_logit = grad(log_prob_at_zero, logit)[0]
    self.assertEqual(d_logit, torch.tensor(-0.5))
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_binomial_log_prob_vectorized_count(self):
    """Compare Binomial log_prob with scipy when ``total_count`` is a tensor (one element, then per-element)."""
    probs = torch.tensor([0.2, 0.7, 0.9])
    cases = [
        (torch.tensor([10]), torch.tensor([7., 3., 9.])),
        (torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.])),
    ]
    for total_count, sample in cases:
        actual = Binomial(total_count, probs).log_prob(sample)
        reference = scipy.stats.binom(total_count.cpu().numpy(), probs.cpu().numpy())
        self.assertEqual(actual, reference.logpmf(sample), atol=1e-4, rtol=0)
def test_binomial_enumerate_support(self):
    """Pass (constructor kwargs, expected support) pairs for Binomial to ``_check_enumerate_support``."""
    cases = [
        (dict(probs=[0.1], total_count=2), [[0], [1], [2]]),
        (dict(probs=[0.1, 0.9], total_count=2), [[0], [1], [2]]),
        (dict(probs=[[0.1, 0.2], [0.3, 0.4]], total_count=3), [[[0]], [[1]], [[2]], [[3]]]),
    ]
    self._check_enumerate_support(Binomial, cases)
def test_binomial_extreme_vals(self):
    """Check Binomial samples and log_prob for probs of 0 and 1, and for a total_count of zeros."""
    total_count = 100

    # probs == 0: every draw is 0 and the value 1 has zero probability.
    never = Binomial(total_count, 0)
    self.assertEqual(never.sample(), 0)
    self.assertEqual(never.log_prob(torch.tensor([0.]))[0], 0, atol=1e-3, rtol=0)
    self.assertEqual(float(never.log_prob(torch.tensor([1.])).exp()), 0)

    # probs == 1: every draw equals total_count.
    always = Binomial(total_count, 1)
    top = float(total_count)
    self.assertEqual(always.sample(), total_count)
    self.assertEqual(always.log_prob(torch.tensor([top]))[0], 0, atol=1e-3, rtol=0)
    self.assertEqual(float(always.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0)

    # total_count == 0 everywhere: samples and log_prob are all zeros.
    zero_counts = torch.zeros(torch.Size((2, 2)))
    empty = Binomial(zero_counts, 1)
    self.assertEqual(empty.sample(), zero_counts)
    self.assertEqual(empty.log_prob(zero_counts), zero_counts)
def test_binomial_vectorized_count(self):
    """Check Binomial with a 2x2 ``total_count`` tensor: exact draws at probs=1, moments at probs=0.5."""
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    counts = torch.tensor([[4, 7], [3, 8]], dtype=torch.float64)

    certain = Binomial(counts, torch.tensor(1.))
    self.assertEqual(certain.sample(), counts)

    fair = Binomial(counts, torch.tensor(0.5))
    draws = fair.sample(torch.Size((100000,)))
    within_bounds = draws <= counts.type_as(draws)
    self.assertTrue(within_bounds.all())
    # Empirical moments over the 100000 draws versus the analytic ones.
    self.assertEqual(draws.mean(dim=0), fair.mean, atol=0.02, rtol=0)
    self.assertEqual(draws.var(dim=0), fair.variance, atol=0.02, rtol=0)
def test_negative_binomial(self):
    """Gradcheck NegativeBinomial log_prob via probs and logits; rsample and entropy must be unimplemented."""
    probs = torch.arange(0.05, 1, 0.1).requires_grad_()
    for total_count in (1, 2, 10):
        # The lambdas run inside the helper within this iteration.
        self._gradcheck_log_prob(lambda q: NegativeBinomial(total_count, q), [probs])
        self._gradcheck_log_prob(lambda q: NegativeBinomial(total_count, None, q.log()), [probs])
    self.assertRaises(NotImplementedError, NegativeBinomial(10, probs).rsample)
    self.assertRaises(NotImplementedError, NegativeBinomial(10, probs).entropy)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_negative_binomial_log_prob(self):
    """Compare NegativeBinomial log_prob with ``scipy.stats.nbinom`` evaluated at ``1 - p``."""
    probs = torch.arange(0.05, 1, 0.1)
    for total_count in (1, 2, 10):

        def ref_log_prob(idx, x, log_prob):
            p = probs.view(-1)[idx].item()
            # The scipy reference is constructed with the complementary probability.
            reference = scipy.stats.nbinom(total_count, 1 - p)
            self.assertEqual(log_prob, reference.logpmf(x), atol=1e-3, rtol=0)

        self._check_log_prob(NegativeBinomial(total_count, probs), ref_log_prob)
        logits = probs_to_logits(probs, is_binary=True)
        self._check_log_prob(NegativeBinomial(total_count, logits=logits), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_negative_binomial_log_prob_vectorized_count(self):
    """Compare NegativeBinomial log_prob with scipy when ``total_count`` is a tensor."""
    probs = torch.tensor([0.2, 0.7, 0.9])
    cases = [
        (torch.tensor([10]), torch.tensor([7., 3., 9.])),
        (torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.])),
    ]
    for total_count, sample in cases:
        actual = NegativeBinomial(total_count, probs).log_prob(sample)
        reference = scipy.stats.nbinom(total_count.cpu().numpy(), 1 - probs.cpu().numpy())
        self.assertEqual(actual, reference.logpmf(sample), atol=1e-4, rtol=0)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
def test_zero_excluded_binomial(self):
    """Check CUDA Binomial(total_count=1) samples stay inside {0, 1} and are balanced at probs=0.5.

    The final two checks use ``self.assertTrue`` instead of bare ``assert``
    so they still run under ``python -O`` and report through unittest like
    the bound checks before them.
    """

    def draw(prob, num_samples):
        # Both parameters are created on the CUDA device, as in each case below.
        dist = Binomial(total_count=torch.tensor(1.0).cuda(),
                        probs=torch.tensor(prob).cuda())
        return dist.sample(torch.Size((num_samples,)))

    vals = draw(0.9, 100000000)
    self.assertTrue((vals >= 0).all())
    vals = draw(0.1, 100000000)
    self.assertTrue((vals < 2).all())
    vals = draw(0.5, 10000)
    # vals should be roughly half zeroes, half ones
    self.assertTrue((vals == 0.0).sum() > 4000)
    self.assertTrue((vals == 1.0).sum() > 4000)
def test_multinomial_1d(self):
    """Check Multinomial sample sizes for a 3-element probs vector, gradcheck log_prob, and reject rsample."""
    total_count = 10
    probs = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
    self.assertEqual(Multinomial(total_count, probs).sample().size(), (3,))
    self.assertEqual(Multinomial(total_count, probs).sample((2, 2)).size(), (2, 2, 3))
    self.assertEqual(Multinomial(total_count, probs).sample((1,)).size(), (1, 3))
    # Gradcheck through the probs argument, then through logits.
    self._gradcheck_log_prob(lambda q: Multinomial(total_count, q), [probs])
    self._gradcheck_log_prob(lambda q: Multinomial(total_count, None, q.log()), [probs])
    self.assertRaises(NotImplementedError, Multinomial(10, probs).rsample)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_multinomial_1d_log_prob_and_entropy(self):
    """Compare Multinomial log_prob (probs and logits forms) and entropy with ``scipy.stats.multinomial``."""
    total_count = 10
    weights = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)

    def scipy_logpmf(dist, x):
        # The reference uses the distribution's own ``probs`` attribute.
        return torch.tensor(scipy.stats.multinomial.logpmf(
            x.numpy(), n=total_count, p=dist.probs.detach().numpy()))

    from_probs = Multinomial(total_count, probs=weights)
    x = from_probs.sample()
    self.assertEqual(from_probs.log_prob(x), scipy_logpmf(from_probs, x))

    from_logits = Multinomial(total_count, logits=weights.log())
    x = from_logits.sample()
    self.assertEqual(from_logits.log_prob(x), scipy_logpmf(from_logits, x))

    # Entropy is checked on the logits-parameterized instance only.
    expected_entropy = scipy.stats.multinomial.entropy(total_count, from_logits.probs.detach().numpy())
    self.assertEqual(from_logits.entropy(), expected_entropy, atol=1e-3, rtol=0)
def test_multinomial_2d(self):
    """Check batched Multinomial sample sizes, gradcheck log_prob, and exact draws for one-hot probs."""
    total_count = 10
    probs = torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)
    one_hot_probs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)

    self.assertEqual(Multinomial(total_count, probs).sample().size(), (2, 3))
    self.assertEqual(Multinomial(total_count, probs).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
    self.assertEqual(Multinomial(total_count, probs).sample((6,)).size(), (6, 2, 3))

    set_rng_seed(0)
    self._gradcheck_log_prob(lambda q: Multinomial(total_count, q), [probs])
    self._gradcheck_log_prob(lambda q: Multinomial(total_count, None, q.log()), [probs])

    # sample check for extreme value of probs
    all_in_one_category = torch.tensor([[total_count, 0], [0, total_count]], dtype=torch.float64)
    self.assertEqual(Multinomial(total_count, one_hot_probs).sample(), all_in_one_category)
def test_categorical_1d(self):
    """Check Categorical on a 3-element probs vector: NaN moments, sample sizes, gradcheck, no rsample."""
    probs = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
    # mean and variance are asserted to be all-NaN for Categorical.
    self.assertTrue(is_all_nan(Categorical(probs).mean))
    self.assertTrue(is_all_nan(Categorical(probs).variance))
    self.assertEqual(Categorical(probs).sample().size(), ())
    self.assertFalse(Categorical(probs).sample().requires_grad)
    self.assertEqual(Categorical(probs).sample((2, 2)).size(), (2, 2))
    self.assertEqual(Categorical(probs).sample((1,)).size(), (1,))
    self._gradcheck_log_prob(Categorical, (probs,))
    self.assertRaises(NotImplementedError, Categorical(probs).rsample)
def test_categorical_2d(self):
    """Check batched Categorical: moments, sample sizes, log_prob, entropy, and entropy with -inf logits."""
    weights = torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)
    one_hot_probs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)

    self.assertEqual(Categorical(weights).mean.size(), (2,))
    self.assertEqual(Categorical(weights).variance.size(), (2,))
    self.assertTrue(is_all_nan(Categorical(weights).mean))
    self.assertTrue(is_all_nan(Categorical(weights).variance))
    self.assertEqual(Categorical(weights).sample().size(), (2,))
    self.assertEqual(Categorical(weights).sample(sample_shape=(3, 4)).size(), (3, 4, 2))
    self.assertEqual(Categorical(weights).sample((6,)).size(), (6, 2))
    self._gradcheck_log_prob(Categorical, (weights,))

    # sample check for extreme value of probs
    set_rng_seed(0)
    self.assertEqual(Categorical(one_hot_probs).sample(sample_shape=(2,)),
                     torch.tensor([[0, 1], [0, 1]]))

    def ref_log_prob(idx, val, log_prob):
        # The rows of `weights` are not normalized, so normalize here.
        row = weights[idx]
        self.assertEqual(log_prob, math.log(row[val] / row.sum()))

    self._check_log_prob(Categorical(weights), ref_log_prob)
    self._check_log_prob(Categorical(logits=weights.log()), ref_log_prob)

    # check entropy computation
    self.assertEqual(Categorical(weights).entropy(), torch.tensor([1.0114, 1.0297]), atol=1e-4, rtol=0)
    self.assertEqual(Categorical(one_hot_probs).entropy(), torch.tensor([0.0, 0.0]))

    # issue gh-40553
    masked_logits = weights.log()
    masked_logits[1, 1] = masked_logits[0, 2] = float('-inf')
    masked_entropy = Categorical(logits=masked_logits).entropy()
    self.assertEqual(masked_entropy, torch.tensor([0.6365, 0.5983]), atol=1e-4, rtol=0)
def test_categorical_enumerate_support(self):
    """Pass (constructor kwargs, expected support) pairs for Categorical to ``_check_enumerate_support``."""
    cases = [
        (dict(probs=[0.1, 0.2, 0.7]), [0, 1, 2]),
        (dict(probs=[[0.1, 0.9], [0.3, 0.7]]), [[0], [1]]),
    ]
    self._check_enumerate_support(Categorical, cases)
def test_one_hot_categorical_1d(self):
    """Check OneHotCategorical on a 3-element probs vector: sample sizes, gradcheck, no rsample."""
    probs = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
    self.assertEqual(OneHotCategorical(probs).sample().size(), (3,))
    self.assertFalse(OneHotCategorical(probs).sample().requires_grad)
    self.assertEqual(OneHotCategorical(probs).sample((2, 2)).size(), (2, 2, 3))
    self.assertEqual(OneHotCategorical(probs).sample((1,)).size(), (1, 3))
    self._gradcheck_log_prob(OneHotCategorical, (probs,))
    self.assertRaises(NotImplementedError, OneHotCategorical(probs).rsample)
def test_one_hot_categorical_2d(self):
    """Check batched OneHotCategorical sample sizes, gradcheck, and log_prob agreement with Categorical.

    The unused one-hot ``probabilities_1`` / ``s`` locals were dropped; no
    assertion in this test read them.
    """
    probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
    p = torch.tensor(probabilities, requires_grad=True)
    self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
    self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
    self.assertEqual(OneHotCategorical(p).sample((6,)).size(), (6, 2, 3))
    self._gradcheck_log_prob(OneHotCategorical, (p,))

    dist = OneHotCategorical(p)
    x = dist.sample()
    # x.max(-1)[1] is the index of the hot entry, i.e. the matching Categorical value.
    self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))
def test_one_hot_categorical_enumerate_support(self):
    """Pass (constructor kwargs, expected support) pairs for OneHotCategorical to ``_check_enumerate_support``."""
    cases = [
        (dict(probs=[0.1, 0.2, 0.7]), [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
        (dict(probs=[[0.1, 0.9], [0.3, 0.7]]), [[[1, 0]], [[0, 1]]]),
    ]
    self._check_enumerate_support(OneHotCategorical, cases)
def test_poisson_forward_ad(self):
    """Run the shared ``_check_forward_ad`` helper on ``torch.poisson``."""
    sampler = torch.poisson
    self._check_forward_ad(sampler)
def test_poisson_shape(self):
    """Check Poisson sample sizes for a (2, 3) rate, a 1-element rate and a Python float rate."""
    rate_2d = torch.randn(2, 3).abs().requires_grad_()
    rate_1d = torch.randn(1).abs().requires_grad_()
    self.assertEqual(Poisson(rate_2d).sample().size(), (2, 3))
    self.assertEqual(Poisson(rate_2d).sample((7,)).size(), (7, 2, 3))
    self.assertEqual(Poisson(rate_1d).sample().size(), (1,))
    self.assertEqual(Poisson(rate_1d).sample((1,)).size(), (1, 1))
    self.assertEqual(Poisson(2.0).sample((2,)).size(), (2,))
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_poisson_log_prob(self):
    """Compare Poisson log_prob with ``scipy.stats.poisson``, gradcheck it, and check the zero-rate gradient."""
    rate = torch.randn(2, 3).abs().requires_grad_()
    rate_1d = torch.randn(1).abs().requires_grad_()
    rate_zero = torch.zeros([], requires_grad=True)

    def ref_log_prob(ref_rate, idx, x, log_prob):
        lam = ref_rate.view(-1)[idx].detach()
        self.assertEqual(log_prob, scipy.stats.poisson.logpmf(x, lam), atol=1e-3, rtol=0)

    set_rng_seed(0)
    self._check_log_prob(Poisson(rate), lambda *args: ref_log_prob(rate, *args))
    self._check_log_prob(Poisson(rate_zero), lambda *args: ref_log_prob(rate_zero, *args))
    self._gradcheck_log_prob(Poisson, (rate,))
    self._gradcheck_log_prob(Poisson, (rate_1d,))

    # We cannot check gradients automatically for zero rates because the finite difference
    # approximation enters the forbidden parameter space. We instead compare with the
    # theoretical results.
    zero_rate_dist = Poisson(rate_zero)
    log_prob_of_one = zero_rate_dist.log_prob(torch.ones_like(rate_zero))
    log_prob_of_one.backward()
    torch.testing.assert_allclose(rate_zero.grad, torch.inf)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_poisson_sample(self):
    """Run ``_check_sampler_discrete`` on Poisson against ``scipy.stats.poisson`` for three rates."""
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    for rate in (0.1, 1.0, 5.0):
        label = 'Poisson(lambda={})'.format(rate)
        self._check_sampler_discrete(Poisson(rate), scipy.stats.poisson(rate), label,
                                     failure_rate=1e-3)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_poisson_gpu_sample(self):
    """Run ``_check_sampler_discrete`` on a CUDA-rate Poisson against ``scipy.stats.poisson``."""
    set_rng_seed(1)
    for rate in (0.12, 0.9, 4.0):
        cuda_rate = torch.tensor([rate]).cuda()
        label = 'Poisson(lambda={}, cuda)'.format(rate)
        self._check_sampler_discrete(Poisson(cuda_rate), scipy.stats.poisson(rate), label,
                                     failure_rate=1e-3)
def test_relaxed_bernoulli(self):
    """Check RelaxedBernoulli sample sizes, gradcheck log_prob, and that rsample is differentiable."""
    probs_vec = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
    probs_0d = torch.tensor(0.3, requires_grad=True)
    probs_float = 0.3
    temp = torch.tensor(0.67, requires_grad=True)

    self.assertEqual(RelaxedBernoulli(temp, probs_vec).sample((8,)).size(), (8, 3))
    self.assertFalse(RelaxedBernoulli(temp, probs_vec).sample().requires_grad)
    self.assertEqual(RelaxedBernoulli(temp, probs_0d).sample((8,)).size(), (8,))
    self.assertEqual(RelaxedBernoulli(temp, probs_0d).sample().size(), ())
    self.assertEqual(RelaxedBernoulli(temp, probs_0d).sample((3, 2)).size(), (3, 2,))
    self.assertEqual(RelaxedBernoulli(temp, probs_float).sample().size(), ())
    for params in ((temp, probs_vec), (temp, probs_0d)):
        self._gradcheck_log_prob(RelaxedBernoulli, params)

    # test that rsample doesn't fail
    relaxed = RelaxedBernoulli(temp, probs_vec).rsample()
    relaxed.backward(torch.ones_like(relaxed))
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_rounded_relaxed_bernoulli(self):
    """Check rounded RelaxedBernoulli samples against ``scipy.stats.bernoulli`` and the huge-temperature case."""
    set_rng_seed(0)  # see Note [Randomized statistical tests]

    class Rounded:
        """Wrap a distribution so that ``sample`` returns rounded draws."""

        def __init__(self, dist):
            self.dist = dist

        def sample(self, *args, **kwargs):
            raw = self.dist.sample(*args, **kwargs)
            return torch.round(raw)

    for probs, temp in product([0.1, 0.2, 0.8], [0.1, 1.0, 10.0]):
        label = 'Rounded(RelaxedBernoulli(temp={}, probs={}))'.format(temp, probs)
        self._check_sampler_discrete(Rounded(RelaxedBernoulli(temp, probs)),
                                     scipy.stats.bernoulli(probs),
                                     label,
                                     failure_rate=1e-3)

    # With a temperature of 1e10 every rsample is asserted to equal 0.5,
    # whatever the probs.
    half = torch.tensor(0.5)
    for probs in (0.001, 0.2, 0.999):
        draw = RelaxedBernoulli(1e10, probs).rsample()
        self.assertEqual(half, draw)
def test_relaxed_one_hot_categorical_1d(self):
    """Check RelaxedOneHotCategorical sample sizes for a 3-element probs vector and gradcheck log_prob."""
    probs = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
    temp = torch.tensor(0.67, requires_grad=True)

    def make():
        # A fresh distribution per assertion.
        return RelaxedOneHotCategorical(probs=probs, temperature=temp)

    self.assertEqual(make().sample().size(), (3,))
    self.assertFalse(make().sample().requires_grad)
    self.assertEqual(make().sample((2, 2)).size(), (2, 2, 3))
    self.assertEqual(make().sample((1,)).size(), (1, 3))
    # Argument validation is switched off for the gradcheck.
    self._gradcheck_log_prob(lambda t, q: RelaxedOneHotCategorical(t, q, validate_args=False), (temp, probs))
def test_relaxed_one_hot_categorical_2d(self):
    """Check batched RelaxedOneHotCategorical sample sizes and gradcheck log_prob at two temperatures.

    The unused one-hot ``probabilities_1`` / ``s`` locals were dropped; no
    assertion in this test read them.
    """
    probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
    temp = torch.tensor([3.0], requires_grad=True)
    # The lower the temperature, the more unstable the log_prob gradcheck is
    # w.r.t. the sample. Values below 0.25 empirically fail the default tol.
    temp_2 = torch.tensor([0.25], requires_grad=True)
    p = torch.tensor(probabilities, requires_grad=True)

    self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
    self.assertEqual(RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
    self.assertEqual(RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3))
    # Argument validation is switched off for the gradchecks.
    self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p))
    self._gradcheck_log_prob(lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp_2, p))
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_argmax_relaxed_categorical(self):
    """Check argmax of RelaxedOneHotCategorical samples against a scipy multinomial(1, probs) reference.

    The label passed to ``_check_sampler_discrete`` now names the ``ArgMax``
    wrapper that is actually used (it previously said ``Rounded``).
    The second part checks that with a temperature of 1e10 every rsample
    equals the uniform vector.
    """
    set_rng_seed(0)  # see Note [Randomized statistical tests]

    class ArgMax(object):
        """Wrap a distribution so that ``sample`` returns the argmax over the last dim."""

        def __init__(self, dist):
            self.dist = dist

        def sample(self, *args, **kwargs):
            s = self.dist.sample(*args, **kwargs)
            _, idx = torch.max(s, -1)
            return idx

    class ScipyCategorical(object):
        """Expose a ``pmf`` over integer categories on top of a frozen scipy multinomial."""

        def __init__(self, dist):
            self.dist = dist

        def pmf(self, samples):
            # One-hot encode the integer samples before calling the wrapped pmf.
            new_samples = np.zeros(samples.shape + self.dist.p.shape)
            new_samples[np.arange(samples.shape[0]), samples] = 1
            return self.dist.pmf(new_samples)

    for probs, temp in product([torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])], [0.1, 1.0, 10.0]):
        self._check_sampler_discrete(ArgMax(RelaxedOneHotCategorical(temp, probs)),
                                     ScipyCategorical(scipy.stats.multinomial(1, probs)),
                                     'ArgMax(RelaxedOneHotCategorical(temp={}, probs={}))'.format(temp, probs),
                                     failure_rate=1e-3)

    for probs in [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])]:
        equal_probs = torch.ones(probs.size()) / probs.size()[0]
        dist = RelaxedOneHotCategorical(1e10, probs)
        s = dist.rsample()
        self.assertEqual(equal_probs, s)
def test_uniform(self):
    """Check Uniform sample sizes, out-of-range log_prob/cdf, gradcheck, and rsample gradients."""
    low = torch.zeros(5, 5, requires_grad=True)
    high = (torch.ones(5, 5) * 3).requires_grad_()
    low_1d = torch.zeros(1, requires_grad=True)
    high_1d = (torch.ones(1) * 3).requires_grad_()

    self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
    self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
    self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
    self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
    self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

    # Check log_prob computation when value outside range
    unvalidated = Uniform(low_1d, high_1d, validate_args=False)
    above_high = torch.tensor([4.0])
    below_low = torch.tensor([-1.0])
    self.assertEqual(unvalidated.log_prob(above_high).item(), -inf)
    self.assertEqual(unvalidated.log_prob(below_low).item(), -inf)

    # check cdf computation when value outside range
    self.assertEqual(unvalidated.cdf(below_low).item(), 0)
    self.assertEqual(unvalidated.cdf(above_high).item(), 1)

    set_rng_seed(1)
    for params in ((low, high), (low, 1.0), (0.0, high)):
        self._gradcheck_log_prob(Uniform, params)

    # Draw the noise once, rewind the RNG, and let rsample draw again from the
    # same state; the gradient assertions below compare against this noise.
    saved_state = torch.get_rng_state()
    noise = low.new(low.size()).uniform_()
    torch.set_rng_state(saved_state)
    draw = Uniform(low, high).rsample()
    draw.backward(torch.ones_like(draw))
    self.assertEqual(low.grad, 1 - noise)
    self.assertEqual(high.grad, noise)
    low.grad.zero_()
    high.grad.zero_()

    self._check_forward_ad(lambda x: x.uniform_())
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_vonmises_sample(self):
    """Run ``_check_sampler_sampler`` on VonMises against ``scipy.stats.vonmises`` over locs and concentrations."""
    locs = [0.0, math.pi / 2.0]
    concentrations = [0.03, 0.3, 1.0, 10.0, 100.0]
    # product() varies concentration fastest, like a nested loc/concentration loop.
    for loc, concentration in product(locs, concentrations):
        label = "VonMises(loc={}, concentration={})".format(loc, concentration)
        self._check_sampler_sampler(VonMises(loc, concentration),
                                    scipy.stats.vonmises(loc=loc, kappa=concentration),
                                    label,
                                    num_samples=int(1e5), circular=True)
def test_vonmises_logprob(self):
    """Check that exp(log_prob) of VonMises integrates to about 1 over [0, 2*pi) for several concentrations.

    The integral is approximated as the mean density on a grid with step
    1e-4 times the interval length ``2 * pi``. The grid does not depend on
    the concentration, so it is built once outside the loop.
    """
    concentrations = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]
    grid = torch.arange(0., 2 * math.pi, 1e-4)
    for concentration in concentrations:
        prob = VonMises(0.0, concentration).log_prob(grid).exp()
        norm = prob.mean().item() * 2 * math.pi
        self.assertLess(abs(norm - 1), 1e-3)
def test_cauchy(self):
    """Check Cauchy moments, sample sizes, gradcheck, and rsample gradients w.r.t. loc and scale."""
    loc = torch.zeros(5, 5, requires_grad=True)
    scale = torch.ones(5, 5, requires_grad=True)
    loc_1d = torch.zeros(1, requires_grad=True)
    scale_1d = torch.ones(1, requires_grad=True)

    # mean is asserted to be NaN and variance to be inf.
    self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
    self.assertEqual(Cauchy(loc_1d, scale_1d).variance, inf)

    self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
    self.assertEqual(Cauchy(loc, scale).sample((7,)).size(), (7, 5, 5))
    self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
    self.assertEqual(Cauchy(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
    self.assertEqual(Cauchy(0.0, 1.0).sample((1,)).size(), (1,))

    set_rng_seed(1)
    for params in ((loc, scale), (loc, 1.0), (0.0, scale)):
        self._gradcheck_log_prob(Cauchy, params)

    # Draw standard Cauchy noise, rewind the RNG, then rsample from the same
    # state; the scale gradient is asserted to equal that noise.
    saved_state = torch.get_rng_state()
    noise = loc.new(loc.size()).cauchy_()
    torch.set_rng_state(saved_state)
    draw = Cauchy(loc, scale).rsample()
    draw.backward(torch.ones_like(draw))
    self.assertEqual(loc.grad, torch.ones_like(scale))
    self.assertEqual(scale.grad, noise)
    loc.grad.zero_()
    scale.grad.zero_()

    self._check_forward_ad(lambda x: x.cauchy_())
def test_halfcauchy(self):
    """Check HalfCauchy moments, sample sizes, gradcheck, and the rsample gradient w.r.t. scale."""
    scale = torch.ones(5, 5, requires_grad=True)
    scale_1d = torch.ones(1, requires_grad=True)

    # Both mean and variance are asserted to be infinite.
    self.assertTrue(torch.isinf(HalfCauchy(scale_1d).mean).all())
    self.assertEqual(HalfCauchy(scale_1d).variance, inf)

    self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
    self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
    self.assertEqual(HalfCauchy(scale_1d).sample().size(), (1,))
    self.assertEqual(HalfCauchy(scale_1d).sample((1,)).size(), (1, 1))
    self.assertEqual(HalfCauchy(1.0).sample((1,)).size(), (1,))

    set_rng_seed(1)
    for params in ((scale,), (1.0,)):
        self._gradcheck_log_prob(HalfCauchy, params)

    # Draw |Cauchy| noise, rewind the RNG, then rsample from the same state;
    # the scale gradient is asserted to equal that noise.
    saved_state = torch.get_rng_state()
    noise = scale.new(scale.size()).cauchy_().abs_()
    torch.set_rng_state(saved_state)
    draw = HalfCauchy(scale).rsample()
    draw.backward(torch.ones_like(draw))
    self.assertEqual(scale.grad, noise)
    scale.grad.zero_()
def test_halfnormal(self):
    """Check HalfNormal sample sizes, near-zero-scale draws, gradcheck, and log_prob broadcasting."""
    scale = torch.randn(5, 5).abs().requires_grad_()
    scale_1d = torch.randn(1).abs().requires_grad_()
    tiny_scale = torch.tensor([1e-5, 1e-5])

    self.assertEqual(HalfNormal(scale).sample().size(), (5, 5))
    self.assertEqual(HalfNormal(scale).sample((7,)).size(), (7, 5, 5))
    self.assertEqual(HalfNormal(scale_1d).sample((1,)).size(), (1, 1))
    self.assertEqual(HalfNormal(scale_1d).sample().size(), (1,))
    self.assertEqual(HalfNormal(.6).sample((1,)).size(), (1,))
    self.assertEqual(HalfNormal(50.0).sample((1,)).size(), (1,))

    # sample check for extreme value of std
    set_rng_seed(1)
    near_zero = HalfNormal(tiny_scale).sample(sample_shape=(1, 2))
    self.assertEqual(near_zero,
                     torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
                     atol=1e-4, rtol=0)

    for params in ((scale,), (1.0,)):
        self._gradcheck_log_prob(HalfNormal, params)

    # check .log_prob() can broadcast.
    broadcasting = HalfNormal(torch.ones(2, 1, 4))
    broadcast_log_prob = broadcasting.log_prob(torch.ones(3, 1))
    self.assertEqual(broadcast_log_prob.shape, (2, 3, 4))
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_halfnormal_logprob(self):
    """Compare HalfNormal log_prob with ``scipy.stats.halfnorm`` for a random (5, 1) scale."""
    scale = torch.randn(5, 1).abs().requires_grad_()

    def ref_log_prob(idx, x, log_prob):
        reference = scipy.stats.halfnorm(scale=scale.view(-1)[idx].detach())
        self.assertEqual(log_prob, reference.logpdf(x), atol=1e-3, rtol=0)

    self._check_log_prob(HalfNormal(scale), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_halfnormal_sample(self):
    """Run ``_check_sampler_sampler`` on HalfNormal against ``scipy.stats.halfnorm`` for three scales."""
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    for scale in (0.1, 1.0, 10.0):
        label = 'HalfNormal(scale={})'.format(scale)
        self._check_sampler_sampler(HalfNormal(scale), scipy.stats.halfnorm(scale=scale), label)
def test_lognormal(self):
    """Check LogNormal sample sizes, near-deterministic draws, gradcheck, broadcasting, and forward AD."""
    loc = torch.randn(5, 5, requires_grad=True)
    scale = torch.randn(5, 5).abs().requires_grad_()
    loc_1d = torch.randn(1, requires_grad=True)
    scale_1d = torch.randn(1).abs().requires_grad_()
    fixed_loc = torch.tensor([1.0, 0.0])
    tiny_scale = torch.tensor([1e-5, 1e-5])

    self.assertEqual(LogNormal(loc, scale).sample().size(), (5, 5))
    self.assertEqual(LogNormal(loc, scale).sample((7,)).size(), (7, 5, 5))
    self.assertEqual(LogNormal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
    self.assertEqual(LogNormal(loc_1d, scale_1d).sample().size(), (1,))
    self.assertEqual(LogNormal(0.2, .6).sample((1,)).size(), (1,))
    self.assertEqual(LogNormal(-0.7, 50.0).sample((1,)).size(), (1,))

    # sample check for extreme value of mean, std
    set_rng_seed(1)
    nearly_exact = LogNormal(fixed_loc, tiny_scale).sample(sample_shape=(1, 2))
    self.assertEqual(nearly_exact,
                     torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
                     atol=1e-4, rtol=0)

    for params in ((loc, scale), (loc, 1.0), (0.0, scale)):
        self._gradcheck_log_prob(LogNormal, params)

    # check .log_prob() can broadcast.
    broadcasting = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
    broadcast_log_prob = broadcasting.log_prob(torch.ones(3, 1))
    self.assertEqual(broadcast_log_prob.shape, (2, 3, 4))

    self._check_forward_ad(lambda x: x.log_normal_())
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lognormal_logprob(self):
    """Compare LogNormal log_prob with ``scipy.stats.lognorm`` for random (5, 1) parameters."""
    loc = torch.randn(5, 1, requires_grad=True)
    scale = torch.randn(5, 1).abs().requires_grad_()

    def ref_log_prob(idx, x, log_prob):
        m = loc.view(-1)[idx].detach()
        s = scale.view(-1)[idx].detach()
        # The scipy reference takes the shape as ``s`` and ``exp(mean)`` as its ``scale``.
        reference = scipy.stats.lognorm(s=s, scale=math.exp(m))
        self.assertEqual(log_prob, reference.logpdf(x), atol=1e-3, rtol=0)

    self._check_log_prob(LogNormal(loc, scale), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lognormal_sample(self):
    """Run ``_check_sampler_sampler`` on LogNormal against ``scipy.stats.lognorm`` over a mean/std grid."""
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    means = [-1.0, 0.0, 1.0]
    stds = [0.1, 1.0, 10.0]
    for mean, std in product(means, stds):
        label = 'LogNormal(loc={}, scale={})'.format(mean, std)
        reference = scipy.stats.lognorm(scale=math.exp(mean), s=std)
        self._check_sampler_sampler(LogNormal(mean, std), reference, label)
def test_logisticnormal(self):
    """Check LogisticNormal sample sizes (one wider than the parameters), near-deterministic draws, and gradcheck."""
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    loc = torch.randn(5, 5).requires_grad_()
    scale = torch.randn(5, 5).abs().requires_grad_()
    loc_1d = torch.randn(1).requires_grad_()
    scale_1d = torch.randn(1).abs().requires_grad_()
    fixed_loc = torch.tensor([1.0, 0.0])
    tiny_scale = torch.tensor([1e-5, 1e-5])

    # The last dimension of every sample is one larger than that of the parameters.
    self.assertEqual(LogisticNormal(loc, scale).sample().size(), (5, 6))
    self.assertEqual(LogisticNormal(loc, scale).sample((7,)).size(), (7, 5, 6))
    self.assertEqual(LogisticNormal(loc_1d, scale_1d).sample((1,)).size(), (1, 2))
    self.assertEqual(LogisticNormal(loc_1d, scale_1d).sample().size(), (2,))
    self.assertEqual(LogisticNormal(0.2, .6).sample().size(), (2,))
    self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,))

    # sample check for extreme value of mean, std
    set_rng_seed(1)
    denominator = 1. + 1. + math.exp(1)
    expected = torch.tensor([math.exp(1) / denominator,
                             1. / denominator,
                             1. / denominator])
    self.assertEqual(LogisticNormal(fixed_loc, tiny_scale).sample(), expected, atol=1e-4, rtol=0)

    # TODO: gradcheck seems to mutate the sample values so that the simplex
    # constraint fails by a very small margin.
    for params in ((loc, scale), (loc, 1.0), (0.0, scale)):
        self._gradcheck_log_prob(lambda m, s: LogisticNormal(m, s, validate_args=False), params)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_logisticnormal_logprob(self):
    """Smoke-test that LogisticNormal log_prob of its own sample has shape ``(5,)`` for (5, 7) parameters.

    Uses ``self.assertEqual`` rather than a bare ``assert`` so the check is
    not stripped under ``python -O`` and fails through unittest.
    """
    mean = torch.randn(5, 7).requires_grad_()
    std = torch.randn(5, 7).abs().requires_grad_()

    # Smoke test for now
    # TODO: Once _check_log_prob works with multidimensional distributions,
    #       add proper testing of the log probabilities.
    dist = LogisticNormal(mean, std)
    log_prob = dist.log_prob(dist.sample())
    self.assertEqual(log_prob.detach().cpu().numpy().shape, (5,))
def _get_logistic_normal_ref_sampler(self, base_dist):
    """Return a function ``num_samples -> ndarray`` that transforms draws from ``base_dist``.

    The returned sampler calls ``base_dist.rvs(num_samples)`` and maps each
    row of width ``d`` to a row of width ``d + 1`` using only NumPy: a
    sigmoid of the draws shifted by ``log(d + 1 - k)`` for position ``k``
    (1-based), combined with the running product of the complements.
    """

    def _sampler(num_samples):
        draws = base_dist.rvs(num_samples)
        width = draws.shape[-1]
        # Position-dependent shift: log(width + 1 - k) for k = 1..width.
        shift = np.log((width + 1) - np.ones_like(draws).cumsum(-1))
        z = 1. / (1. + np.exp(shift - draws))
        leftover = np.cumprod(1 - z, axis=-1)
        # Pad only along the last axis: a trailing 1 for z, a leading 1 for the
        # cumulative products, so both arrays have width + 1 columns.
        z_padded = np.pad(z, ((0, 0), (0, 1)), mode='constant', constant_values=1.)
        leftover_padded = np.pad(leftover, ((0, 0), (1, 0)), mode='constant', constant_values=1.)
        return z_padded * leftover_padded

    return _sampler
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_logisticnormal_sample(self):
    """Run ``_check_sampler_sampler`` on LogisticNormal against a transformed scipy multivariate normal."""
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    means = [np.asarray(m) for m in [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)]]
    covs = [np.diag(v) for v in [(0.1, 0.1), (1.0, 1.0), (10.0, 10.0)]]
    for mean, cov in product(means, covs):
        # Two separate frozen normals: one feeds the reference sampler, the
        # other has its ``rvs`` replaced by that sampler.
        base_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
        ref_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
        ref_dist.rvs = self._get_logistic_normal_ref_sampler(base_dist)

        loc = torch.tensor(mean)
        scale = torch.tensor(np.sqrt(np.diag(cov)))
        label = 'LogisticNormal(loc={}, scale={})'.format(loc, scale)
        self._check_sampler_sampler(LogisticNormal(loc, scale), ref_dist, label,
                                    multivariate=True)
def test_mixture_same_family_shape(self):
    # Checks the size of MixtureSameFamily samples for scalar and 2-vector
    # Normal components, with zero, one and two batch dimensions.
    normal_case_1d = MixtureSameFamily(
        Categorical(torch.rand(5)),
        Normal(torch.randn(5), torch.rand(5)))
    normal_case_1d_batch = MixtureSameFamily(
        Categorical(torch.rand(3, 5)),
        Normal(torch.randn(3, 5), torch.rand(3, 5)))
    normal_case_1d_multi_batch = MixtureSameFamily(
        Categorical(torch.rand(4, 3, 5)),
        Normal(torch.randn(4, 3, 5), torch.rand(4, 3, 5)))
    normal_case_2d = MixtureSameFamily(
        Categorical(torch.rand(5)),
        Independent(Normal(torch.randn(5, 2), torch.rand(5, 2)), 1))
    normal_case_2d_batch = MixtureSameFamily(
        Categorical(torch.rand(3, 5)),
        Independent(Normal(torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1))
    normal_case_2d_multi_batch = MixtureSameFamily(
        Categorical(torch.rand(4, 3, 5)),
        Independent(Normal(torch.randn(4, 3, 5, 2), torch.rand(4, 3, 5, 2)), 1))

    # (distribution, expected size of a single draw)
    cases = [
        (normal_case_1d, ()),
        (normal_case_1d_batch, (3,)),
        (normal_case_1d_multi_batch, (4, 3)),
        (normal_case_2d, (2,)),
        (normal_case_2d_batch, (3, 2)),
        (normal_case_2d_multi_batch, (4, 3, 2)),
    ]
    sample_shapes = [torch.Size(), (2,), (2, 7)]
    for mixture, single_draw_size in cases:
        for sample_shape in sample_shapes:
            self.assertEqual(mixture.sample(sample_shape).size(),
                             tuple(sample_shape) + single_draw_size)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_mixture_same_family_log_prob(self):
    # Compares MixtureSameFamily.log_prob with a SciPy reference built from
    # per-component normal log-densities and the log mixture weights.
    probs = torch.rand(5, 5).softmax(dim=-1)
    loc = torch.randn(5, 5)
    scale = torch.rand(5, 5)

    def ref_log_prob(idx, x, log_prob):
        weights = probs[idx].numpy()
        centers = loc[idx].numpy()
        widths = scale[idx].numpy()
        mixing = scipy.stats.multinomial(1, weights)
        components = scipy.stats.norm(centers, widths)
        expected = scipy.special.logsumexp(components.logpdf(x) + np.log(mixing.p))
        self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

    mixture = MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale))
    self._check_log_prob(mixture, ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_mixture_same_family_sample(self):
    # Statistical comparison of MixtureSameFamily samples against a small
    # SciPy-backed mixture-of-normals sampler defined inline.
    probs = torch.rand(5).softmax(dim=-1)
    loc = torch.randn(5)
    scale = torch.rand(5)

    class ScipyMixtureNormal:
        def __init__(self, probs, mu, std):
            self.probs = probs
            self.mu = mu
            self.std = std

        def rvs(self, n_sample):
            # draw n_sample values from every component, then keep, per row,
            # the value of the component picked by the one-hot multinomial
            per_component = [scipy.stats.norm(m, s).rvs(n_sample)
                             for m, s in zip(self.mu, self.std)]
            one_hot = scipy.stats.multinomial(1, self.probs).rvs(n_sample)
            chosen = [per_component[one_hot[i].argmax()][i] for i in range(n_sample)]
            return np.asarray(chosen)

    torch_dist = MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale))
    ref_dist = ScipyMixtureNormal(probs.numpy(), loc.numpy(), scale.numpy())
    description = '''MixtureSameFamily(Categorical(probs={}),
Normal(loc={}, scale={}))'''.format(probs, loc, scale)
    self._check_sampler_sampler(torch_dist, ref_dist, description)
def test_normal(self):
# Exercises Normal: sample sizes, sampling with a tiny scale, gradient checks
# of log_prob, gradients of rsample() w.r.t. loc and scale, log_prob values
# against a closed-form reference, and the self._check_forward_ad helper on
# torch.normal / Tensor.normal_.
loc = torch.randn(5, 5, requires_grad=True)
# abs() keeps the randomly drawn scales non-negative
scale = torch.randn(5, 5).abs().requires_grad_()
loc_1d = torch.randn(1, requires_grad=True)
scale_1d = torch.randn(1).abs().requires_grad_()
loc_delta = torch.tensor([1.0, 0.0])
scale_delta = torch.tensor([1e-5, 1e-5])
# sample sizes for tensor and Python-float parameters
self.assertEqual(Normal(loc, scale).sample().size(), (5, 5))
self.assertEqual(Normal(loc, scale).sample((7,)).size(), (7, 5, 5))
self.assertEqual(Normal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
self.assertEqual(Normal(loc_1d, scale_1d).sample().size(), (1,))
self.assertEqual(Normal(0.2, .6).sample((1,)).size(), (1,))
self.assertEqual(Normal(-0.7, 50.0).sample((1,)).size(), (1,))
# sample check for extreme value of mean, std
# (scale is 1e-5, so every draw must match loc_delta within atol=1e-4)
set_rng_seed(1)
self.assertEqual(Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
atol=1e-4, rtol=0)
# gradient checks of log_prob with tensor and float parameters mixed
self._gradcheck_log_prob(Normal, (loc, scale))
self._gradcheck_log_prob(Normal, (loc, 1.0))
self._gradcheck_log_prob(Normal, (0.0, scale))
# Save the RNG state, draw standard-normal noise, then restore the state
# before rsample(); the assertions below expect scale.grad to equal that
# noise and loc.grad to be all ones.
state = torch.get_rng_state()
eps = torch.normal(torch.zeros_like(loc), torch.ones_like(scale))
torch.set_rng_state(state)
z = Normal(loc, scale).rsample()
z.backward(torch.ones_like(z))
self.assertEqual(loc.grad, torch.ones_like(loc))
self.assertEqual(scale.grad, eps)
# clear the accumulated gradients on the leaf tensors
loc.grad.zero_()
scale.grad.zero_()
self.assertEqual(z.size(), (5, 5))
# Reference callback: log of the Gaussian density computed with `math`
# for the idx-th (flattened) loc/scale pair.
def ref_log_prob(idx, x, log_prob):
m = loc.view(-1)[idx]
s = scale.view(-1)[idx]
expected = (math.exp(-(x - m) ** 2 / (2 * s ** 2)) /
math.sqrt(2 * math.pi * s ** 2))
self.assertEqual(log_prob, math.log(expected), atol=1e-3, rtol=0)
self._check_log_prob(Normal(loc, scale), ref_log_prob)
# pass the functional and in-place normal samplers to _check_forward_ad
self._check_forward_ad(torch.normal)
self._check_forward_ad(lambda x: torch.normal(x, 0.5))
self._check_forward_ad(lambda x: torch.normal(0.2, x))
self._check_forward_ad(lambda x: torch.normal(x, x))
self._check_forward_ad(lambda x: x.normal_())
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_normal_sample(self):
    # Statistical comparison of Normal samples against scipy.stats.norm over
    # a grid of loc/scale values.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    locs = [-1.0, 0.0, 1.0]
    scales = [0.1, 1.0, 10.0]
    for loc in locs:
        for scale in scales:
            reference = scipy.stats.norm(loc=loc, scale=scale)
            message = 'Normal(mean={}, std={})'.format(loc, scale)
            self._check_sampler_sampler(Normal(loc, scale), reference, message)
def test_lowrank_multivariate_normal_shape(self):
    # Checks sample sizes of LowRankMultivariateNormal for several batch
    # layouts of the mean and covariance parts, then gradchecks log_prob.
    mean = torch.randn(5, 3, requires_grad=True)
    mean_no_batch = torch.randn(3, requires_grad=True)
    mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

    # construct PSD covariance
    cov_factor = torch.randn(3, 1, requires_grad=True)
    cov_diag = torch.randn(3).abs().requires_grad_()

    # construct batch of PSD covariances
    cov_factor_batched = torch.randn(6, 5, 3, 2, requires_grad=True)
    cov_diag_batched = torch.randn(6, 5, 3).abs().requires_grad_()

    # ensure that sample, batch, event shapes all handled correctly
    # (mean tensor, expected size of a single draw with unbatched covariance)
    mean_cases = [
        (mean, (5, 3)),
        (mean_no_batch, (3,)),
        (mean_multi_batch, (6, 5, 3)),
    ]
    for sample_shape in [torch.Size(), (2,), (2, 7)]:
        for loc, single_draw_size in mean_cases:
            dist = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
            self.assertEqual(dist.sample(sample_shape).size(),
                             tuple(sample_shape) + single_draw_size)
    # with batched covariance parts every mean broadcasts to batch (6, 5)
    for loc, _ in mean_cases:
        dist = LowRankMultivariateNormal(loc, cov_factor_batched, cov_diag_batched)
        self.assertEqual(dist.sample((2, 7)).size(), (2, 7, 6, 5, 3))

    # check gradients
    gradcheck_inputs = [
        (mean, cov_factor, cov_diag),
        (mean_multi_batch, cov_factor, cov_diag),
        (mean_multi_batch, cov_factor_batched, cov_diag_batched),
    ]
    for params in gradcheck_inputs:
        self._gradcheck_log_prob(LowRankMultivariateNormal, params)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_lowrank_multivariate_normal_log_prob(self):
    # Compares LowRankMultivariateNormal.log_prob with SciPy's
    # multivariate_normal.logpdf, then checks batched against unbatched.
    mean = torch.randn(3, requires_grad=True)
    cov_factor = torch.randn(3, 1, requires_grad=True)
    cov_diag = torch.randn(3).abs().requires_grad_()
    # dense covariance equivalent to the low-rank-plus-diagonal form
    cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()

    # check that logprob values match scipy logpdf
    dist1 = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
    ref_dist = scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy())

    x = dist1.sample((10,))
    expected = ref_dist.logpdf(x.numpy())
    squared_error = (dist1.log_prob(x).detach().numpy() - expected)**2
    self.assertEqual(0.0, np.mean(squared_error), atol=1e-3, rtol=0)

    # Double-check that batched versions behave the same as unbatched
    mean = torch.randn(5, 3, requires_grad=True)
    cov_factor = torch.randn(5, 3, 2, requires_grad=True)
    cov_diag = torch.randn(5, 3).abs().requires_grad_()

    dist_batched = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
    dist_unbatched = [LowRankMultivariateNormal(mean[i], cov_factor[i], cov_diag[i])
                      for i in range(mean.size(0))]

    x = dist_batched.sample((10,))
    batched_prob = dist_batched.log_prob(x)
    per_member = [member.log_prob(x[:, i]) for i, member in enumerate(dist_unbatched)]
    unbatched_prob = torch.stack(per_member).t()

    self.assertEqual(batched_prob.shape, unbatched_prob.shape)
    self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lowrank_multivariate_normal_sample(self):
    # Statistical comparison of LowRankMultivariateNormal samples against a
    # SciPy multivariate normal with the equivalent dense covariance.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    mean = torch.randn(5, requires_grad=True)
    cov_factor = torch.randn(5, 1, requires_grad=True)
    cov_diag = torch.randn(5).abs().requires_grad_()
    cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()

    torch_dist = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
    ref_dist = scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy())
    message = ('LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
               .format(mean, cov_factor, cov_diag))
    self._check_sampler_sampler(torch_dist, ref_dist, message, multivariate=True)
def test_lowrank_multivariate_normal_properties(self):
    # A LowRankMultivariateNormal and a MultivariateNormal built from the same
    # dense covariance must agree on their derived properties and entropy.
    loc = torch.randn(5)
    cov_factor = torch.randn(5, 2)
    cov_diag = torch.randn(5).abs()
    cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
    m1 = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
    m2 = MultivariateNormal(loc=loc, covariance_matrix=cov)
    shared_properties = ('mean', 'variance', 'covariance_matrix',
                         'scale_tril', 'precision_matrix')
    for name in shared_properties:
        self.assertEqual(getattr(m1, name), getattr(m2, name))
    self.assertEqual(m1.entropy(), m2.entropy())
def test_lowrank_multivariate_normal_moments(self):
    # Empirical mean and variance of 100000 reparameterized draws must match
    # the distribution's analytic mean and variance within tolerance.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    mean = torch.randn(5)
    cov_factor = torch.randn(5, 2)
    cov_diag = torch.randn(5).abs()
    dist = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
    draws = dist.rsample((100000,))
    self.assertEqual(dist.mean, draws.mean(0), atol=0.01, rtol=0)
    self.assertEqual(dist.variance, draws.var(0), atol=0.02, rtol=0)
def test_multivariate_normal_shape(self):
    # Checks sample sizes of MultivariateNormal across batch layouts and the
    # three parameterizations, then gradchecks log_prob with a symmetrizing
    # wrapper.
    mean = torch.randn(5, 3, requires_grad=True)
    mean_no_batch = torch.randn(3, requires_grad=True)
    mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

    # construct PSD covariance
    tmp = torch.randn(3, 10)
    cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
    prec = cov.inverse().requires_grad_()
    scale_tril = torch.linalg.cholesky(cov).requires_grad_()

    # construct batch of PSD covariances
    tmp = torch.randn(6, 5, 3, 10)
    cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
    prec_batched = cov_batched.inverse()
    scale_tril_batched = torch.linalg.cholesky(cov_batched)

    # ensure that sample, batch, event shapes all handled correctly
    # (mean tensor, expected size of a single draw with unbatched covariance)
    mean_cases = [
        (mean, (5, 3)),
        (mean_no_batch, (3,)),
        (mean_multi_batch, (6, 5, 3)),
    ]
    for sample_shape in [torch.Size(), (2,), (2, 7)]:
        for loc, single_draw_size in mean_cases:
            self.assertEqual(MultivariateNormal(loc, cov).sample(sample_shape).size(),
                             tuple(sample_shape) + single_draw_size)
    for loc, _ in mean_cases:
        self.assertEqual(MultivariateNormal(loc, cov_batched).sample((2, 7)).size(),
                         (2, 7, 6, 5, 3))
    # alternative parameterizations, unbatched and batched
    keyword_cases = [
        ({'precision_matrix': prec}, (2, 7, 5, 3)),
        ({'precision_matrix': prec_batched}, (2, 7, 6, 5, 3)),
        ({'scale_tril': scale_tril}, (2, 7, 5, 3)),
        ({'scale_tril': scale_tril_batched}, (2, 7, 6, 5, 3)),
    ]
    for kwargs, expected_size in keyword_cases:
        self.assertEqual(MultivariateNormal(mean, **kwargs).sample((2, 7)).size(),
                         expected_size)

    # check gradients
    # We write a custom gradcheck function to maintain the symmetry
    # of the perturbed covariances and their inverses (precision)
    def multivariate_normal_log_prob_gradcheck(mean, covariance=None, precision=None, scale_tril=None):
        mvn_samples = MultivariateNormal(mean, covariance, precision, scale_tril).sample().requires_grad_()

        def gradcheck_func(samples, mu, sigma, prec, scale_tril):
            if sigma is not None:
                sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
            if prec is not None:
                prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
            if scale_tril is not None:
                scale_tril = scale_tril.tril()
            return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples)
        gradcheck(gradcheck_func, (mvn_samples, mean, covariance, precision, scale_tril), raise_exception=True)

    gradcheck_inputs = [
        (mean, cov),
        (mean_multi_batch, cov),
        (mean_multi_batch, cov_batched),
        (mean, None, prec),
        (mean_no_batch, None, prec_batched),
        (mean, None, None, scale_tril),
        (mean_no_batch, None, None, scale_tril_batched),
    ]
    for args in gradcheck_inputs:
        multivariate_normal_log_prob_gradcheck(*args)
def test_multivariate_normal_stable_with_precision_matrix(self):
    # Constructing a MultivariateNormal from an RBF-kernel precision matrix
    # must not raise; nothing else is asserted.
    x = torch.randn(10)
    pairwise_sq_diff = (x - x.unsqueeze(-1)) ** 2
    P = torch.exp(-pairwise_sq_diff)  # RBF kernel
    MultivariateNormal(x.new_zeros(10), precision_matrix=P)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_multivariate_normal_log_prob(self):
    # Compares log_prob of all three MultivariateNormal parameterizations
    # with SciPy's logpdf, then checks batched against unbatched results.
    mean = torch.randn(3, requires_grad=True)
    tmp = torch.randn(3, 10)
    cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
    prec = cov.inverse().requires_grad_()
    scale_tril = torch.linalg.cholesky(cov).requires_grad_()

    # check that logprob values match scipy logpdf,
    # and that covariance and scale_tril parameters are equivalent
    parameterizations = [
        MultivariateNormal(mean, cov),
        MultivariateNormal(mean, precision_matrix=prec),
        MultivariateNormal(mean, scale_tril=scale_tril),
    ]
    ref_dist = scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy())

    # samples come from the covariance-parameterized distribution
    x = parameterizations[0].sample((10,))
    expected = ref_dist.logpdf(x.numpy())

    for dist in parameterizations:
        squared_error = (dist.log_prob(x).detach().numpy() - expected)**2
        self.assertEqual(0.0, np.mean(squared_error), atol=1e-3, rtol=0)

    # Double-check that batched versions behave the same as unbatched
    mean = torch.randn(5, 3, requires_grad=True)
    tmp = torch.randn(5, 3, 10)
    cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()

    dist_batched = MultivariateNormal(mean, cov)
    dist_unbatched = [MultivariateNormal(mean[i], cov[i]) for i in range(mean.size(0))]

    x = dist_batched.sample((10,))
    batched_prob = dist_batched.log_prob(x)
    per_member = [member.log_prob(x[:, i]) for i, member in enumerate(dist_unbatched)]
    unbatched_prob = torch.stack(per_member).t()

    self.assertEqual(batched_prob.shape, unbatched_prob.shape)
    self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_multivariate_normal_sample(self):
    # Statistical comparison of MultivariateNormal samples against SciPy for
    # the covariance, precision-matrix and scale_tril parameterizations.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    mean = torch.randn(3, requires_grad=True)
    tmp = torch.randn(3, 10)
    cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
    prec = cov.inverse().requires_grad_()
    scale_tril = torch.linalg.cholesky(cov).requires_grad_()

    self._check_sampler_sampler(MultivariateNormal(mean, cov),
                                scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
                                'MultivariateNormal(loc={}, cov={})'.format(mean, cov),
                                multivariate=True)
    # The label names the parameter that is actually passed
    # (precision_matrix); it previously read "atol", which mislabeled the
    # failing case in the message handed to _check_sampler_sampler.
    self._check_sampler_sampler(MultivariateNormal(mean, precision_matrix=prec),
                                scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
                                'MultivariateNormal(loc={}, precision_matrix={})'.format(mean, prec),
                                multivariate=True)
    self._check_sampler_sampler(MultivariateNormal(mean, scale_tril=scale_tril),
                                scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
                                'MultivariateNormal(loc={}, scale_tril={})'.format(mean, scale_tril),
                                multivariate=True)
def test_multivariate_normal_properties(self):
    # covariance_matrix, precision_matrix and scale_tril of a single
    # MultivariateNormal must be mutually consistent.
    loc = torch.randn(5)
    scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
    m = MultivariateNormal(loc=loc, scale_tril=scale_tril)
    gram = m.scale_tril.mm(m.scale_tril.t())
    self.assertEqual(m.covariance_matrix, gram)
    identity = torch.eye(m.event_shape[0])
    self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), identity)
    self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))
def test_multivariate_normal_moments(self):
    # Empirical mean and variance of 100000 reparameterized draws must match
    # the distribution's analytic mean and variance within tolerance.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    mean = torch.randn(5)
    scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
    dist = MultivariateNormal(mean, scale_tril=scale_tril)
    draws = dist.rsample((100000,))
    self.assertEqual(dist.mean, draws.mean(0), atol=0.01, rtol=0)
    self.assertEqual(dist.variance, draws.var(0), atol=0.05, rtol=0)
# The Wishart tests below apply the same checks as the MultivariateNormal tests above.
def test_wishart_shape(self):
    # Checks sample sizes of Wishart across df batch layouts and the three
    # scale parameterizations, then gradchecks log_prob with a symmetrizing
    # wrapper.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    ndim = 3

    df = torch.rand(5, requires_grad=True) + ndim
    df_no_batch = torch.rand([], requires_grad=True) + ndim
    df_multi_batch = torch.rand(6, 5, requires_grad=True) + ndim

    # construct PSD covariance
    tmp = torch.randn(ndim, 10)
    cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
    prec = cov.inverse().requires_grad_()
    scale_tril = torch.linalg.cholesky(cov).requires_grad_()

    # construct batch of PSD covariances
    tmp = torch.randn(6, 5, ndim, 10)
    cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
    prec_batched = cov_batched.inverse()
    scale_tril_batched = torch.linalg.cholesky(cov_batched)

    # ensure that sample, batch, event shapes all handled correctly
    event_size = (ndim, ndim)
    # (df tensor, expected batch size with unbatched covariance)
    df_cases = [
        (df, (5,)),
        (df_no_batch, ()),
        (df_multi_batch, (6, 5)),
    ]
    for sample_shape in [torch.Size(), (2,), (2, 7)]:
        for nu, batch_size in df_cases:
            self.assertEqual(Wishart(nu, cov).sample(sample_shape).size(),
                             tuple(sample_shape) + batch_size + event_size)
    for nu, _ in df_cases:
        self.assertEqual(Wishart(nu, cov_batched).sample((2, 7)).size(),
                         (2, 7, 6, 5, ndim, ndim))
    # alternative parameterizations, unbatched and batched
    keyword_cases = [
        ({'precision_matrix': prec}, (2, 7, 5, ndim, ndim)),
        ({'precision_matrix': prec_batched}, (2, 7, 6, 5, ndim, ndim)),
        ({'scale_tril': scale_tril}, (2, 7, 5, ndim, ndim)),
        ({'scale_tril': scale_tril_batched}, (2, 7, 6, 5, ndim, ndim)),
    ]
    for kwargs, expected_size in keyword_cases:
        self.assertEqual(Wishart(df, **kwargs).sample((2, 7)).size(), expected_size)

    # check gradients
    # Modified and applied the same tests for multivariate_normal
    def wishart_log_prob_gradcheck(df=None, covariance=None, precision=None, scale_tril=None):
        wishart_samples = Wishart(df, covariance, precision, scale_tril).sample().requires_grad_()

        def gradcheck_func(samples, nu, sigma, prec, scale_tril):
            if sigma is not None:
                sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
            if prec is not None:
                prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
            if scale_tril is not None:
                scale_tril = scale_tril.tril()
            return Wishart(nu, sigma, prec, scale_tril).log_prob(samples)
        gradcheck(gradcheck_func, (wishart_samples, df, covariance, precision, scale_tril), raise_exception=True)

    gradcheck_inputs = [
        (df, cov),
        (df_multi_batch, cov),
        (df_multi_batch, cov_batched),
        (df, None, prec),
        (df_no_batch, None, prec_batched),
        (df, None, None, scale_tril),
        (df_no_batch, None, None, scale_tril_batched),
    ]
    for args in gradcheck_inputs:
        wishart_log_prob_gradcheck(*args)
def test_wishart_stable_with_precision_matrix(self):
    # Constructing a Wishart from an RBF-kernel precision matrix must not
    # raise; nothing else is asserted.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    ndim = 10
    x = torch.randn(ndim)
    pairwise_sq_diff = (x - x.unsqueeze(-1)) ** 2
    P = torch.exp(-pairwise_sq_diff)  # RBF kernel
    Wishart(torch.tensor(ndim), precision_matrix=P)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_wishart_log_prob(self):
    # Compares log_prob of all three Wishart parameterizations with SciPy's
    # wishart.logpdf, then checks batched against unbatched results.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    ndim = 3
    df = torch.rand([], requires_grad=True) + ndim - 1
    # SciPy only accepts ndim - 1 < df < ndim for the Wishart distribution
    # from version 1.7.0 on, so shift df up by one for older releases.
    if version.parse(scipy.__version__) < version.parse("1.7.0"):
        df += 1.
    tmp = torch.randn(ndim, 10)
    cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
    prec = cov.inverse().requires_grad_()
    scale_tril = torch.linalg.cholesky(cov).requires_grad_()

    # check that logprob values match scipy logpdf,
    # and that covariance and scale_tril parameters are equivalent
    parameterizations = [
        Wishart(df, cov),
        Wishart(df, precision_matrix=prec),
        Wishart(df, scale_tril=scale_tril),
    ]
    ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())

    # samples come from the covariance-parameterized distribution; the sample
    # axis is swapped to the end before handing them to SciPy
    x = parameterizations[0].sample((1000,))
    expected = ref_dist.logpdf(x.transpose(0, 2).numpy())

    for dist in parameterizations:
        squared_error = (dist.log_prob(x).detach().numpy() - expected)**2
        self.assertEqual(0.0, np.mean(squared_error), atol=1e-3, rtol=0)

    # Double-check that batched versions behave the same as unbatched
    df = torch.rand(5, requires_grad=True) + ndim - 1
    # same SciPy version adjustment as above
    if version.parse(scipy.__version__) < version.parse("1.7.0"):
        df += 1.
    tmp = torch.randn(5, ndim, 10)
    cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()

    dist_batched = Wishart(df, cov)
    dist_unbatched = [Wishart(df[i], cov[i]) for i in range(df.size(0))]

    x = dist_batched.sample((1000,))
    batched_prob = dist_batched.log_prob(x)
    per_member = [member.log_prob(x[:, i]) for i, member in enumerate(dist_unbatched)]
    unbatched_prob = torch.stack(per_member).t()

    self.assertEqual(batched_prob.shape, unbatched_prob.shape)
    self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_wishart_sample(self):
    # Statistical comparison of Wishart samples against scipy.stats.wishart
    # for the covariance, precision-matrix and scale_tril parameterizations.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    ndim = 3
    df = torch.rand([], requires_grad=True) + ndim - 1
    # SciPy only accepts ndim - 1 < df < ndim for the Wishart distribution
    # from version 1.7.0 on, so shift df up by one for older releases.
    if version.parse(scipy.__version__) < version.parse("1.7.0"):
        df += 1.
    tmp = torch.randn(ndim, 10)
    cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
    prec = cov.inverse().requires_grad_()
    scale_tril = torch.linalg.cholesky(cov).requires_grad_()

    ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())

    # (keyword arguments for Wishart, message template, value shown in message)
    cases = [
        ({'covariance_matrix': cov}, 'Wishart(df={}, covariance_matrix={})', cov),
        ({'precision_matrix': prec}, 'Wishart(df={}, precision_matrix={})', prec),
        ({'scale_tril': scale_tril}, 'Wishart(df={}, scale_tril={})', scale_tril),
    ]
    for kwargs, template, shown in cases:
        self._check_sampler_sampler(Wishart(df, **kwargs),
                                    ref_dist,
                                    template.format(df, shown),
                                    multivariate=True)
def test_wishart_properties(self):
    # covariance_matrix, precision_matrix and scale_tril of a single Wishart
    # must be mutually consistent.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    ndim = 5
    df = torch.rand([]) + ndim - 1
    scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
    m = Wishart(df=df, scale_tril=scale_tril)
    gram = m.scale_tril.mm(m.scale_tril.t())
    self.assertEqual(m.covariance_matrix, gram)
    identity = torch.eye(m.event_shape[0])
    self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), identity)
    self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))
def test_wishart_moments(self):
    # Empirical mean and variance of ndim * ndim * 100000 reparameterized
    # draws must match the analytic mean and variance within tolerance.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    ndim = 3
    df = torch.rand([]) + ndim - 1
    scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
    dist = Wishart(df=df, scale_tril=scale_tril)
    num_draws = ndim * ndim * 100000
    draws = dist.rsample((num_draws,))
    self.assertEqual(dist.mean, draws.mean(0), atol=0.5, rtol=0)
    self.assertEqual(dist.variance, draws.var(0), atol=0.5, rtol=0)
def test_exponential(self):
# Exercises Exponential: sample sizes, gradient check of log_prob, the
# gradient of rsample() w.r.t. rate, log_prob values against a closed-form
# reference, and the self._check_forward_ad helper on Tensor.exponential_.
# abs() keeps the randomly drawn rates non-negative
rate = torch.randn(5, 5).abs().requires_grad_()
rate_1d = torch.randn(1).abs().requires_grad_()
# sample sizes for tensor and Python-float rates
self.assertEqual(Exponential(rate).sample().size(), (5, 5))
self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
self._gradcheck_log_prob(Exponential, (rate,))
# Save the RNG state, draw exponential noise, then restore the state before
# rsample(); the assertion below expects rate.grad == -eps / rate**2.
state = torch.get_rng_state()
eps = rate.new(rate.size()).exponential_()
torch.set_rng_state(state)
z = Exponential(rate).rsample()
z.backward(torch.ones_like(z))
self.assertEqual(rate.grad, -eps / rate**2)
# clear the accumulated gradient on the leaf tensor
rate.grad.zero_()
self.assertEqual(z.size(), (5, 5))
# Reference callback: log(rate) - rate * x for the idx-th (flattened) rate.
def ref_log_prob(idx, x, log_prob):
m = rate.view(-1)[idx]
expected = math.log(m) - m * x
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
self._check_log_prob(Exponential(rate), ref_log_prob)
# pass the in-place exponential sampler to _check_forward_ad
self._check_forward_ad(lambda x: x.exponential_())
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_exponential_sample(self):
    # Statistical comparison of Exponential samples against scipy.stats.expon
    # (whose scale is the reciprocal of the rate).
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    for rate in [1e-5, 1.0, 10.]:
        reference = scipy.stats.expon(scale=1. / rate)
        message = 'Exponential(rate={})'.format(rate)
        self._check_sampler_sampler(Exponential(rate), reference, message)
def test_laplace(self):
    # Exercises Laplace: sample sizes, sampling with a tiny scale, gradient
    # checks of log_prob, gradients of rsample() w.r.t. loc and scale, and
    # log_prob values against a closed-form reference.
    loc = torch.randn(5, 5, requires_grad=True)
    scale = torch.randn(5, 5).abs().requires_grad_()
    loc_1d = torch.randn(1, requires_grad=True)
    # A Laplace scale must be positive. Take abs() exactly like the 2-D
    # `scale` above; a raw randn() draw can be negative, which is an invalid
    # parameter.
    scale_1d = torch.randn(1).abs().requires_grad_()
    loc_delta = torch.tensor([1.0, 0.0])
    scale_delta = torch.tensor([1e-5, 1e-5])
    # sample sizes for tensor and Python-float parameters
    self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
    self.assertEqual(Laplace(loc, scale).sample((7,)).size(), (7, 5, 5))
    self.assertEqual(Laplace(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
    self.assertEqual(Laplace(loc_1d, scale_1d).sample().size(), (1,))
    self.assertEqual(Laplace(0.2, .6).sample((1,)).size(), (1,))
    self.assertEqual(Laplace(-0.7, 50.0).sample((1,)).size(), (1,))

    # sample check for extreme value of mean, std
    # (scale is 1e-5, so every draw must match loc_delta within atol=1e-4)
    set_rng_seed(0)
    self.assertEqual(Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
                     torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
                     atol=1e-4, rtol=0)

    # gradient checks of log_prob with tensor and float parameters mixed
    self._gradcheck_log_prob(Laplace, (loc, scale))
    self._gradcheck_log_prob(Laplace, (loc, 1.0))
    self._gradcheck_log_prob(Laplace, (0.0, scale))

    # Save the RNG state, draw uniform(-0.5, 0.5) noise, then restore the
    # state before rsample(); the assertions below expect loc.grad to be all
    # ones and scale.grad to equal -sign(eps) * log1p(-2 * |eps|).
    state = torch.get_rng_state()
    eps = torch.ones_like(loc).uniform_(-.5, .5)
    torch.set_rng_state(state)
    z = Laplace(loc, scale).rsample()
    z.backward(torch.ones_like(z))
    self.assertEqual(loc.grad, torch.ones_like(loc))
    self.assertEqual(scale.grad, -eps.sign() * torch.log1p(-2 * eps.abs()))
    # clear the accumulated gradients on the leaf tensors
    loc.grad.zero_()
    scale.grad.zero_()
    self.assertEqual(z.size(), (5, 5))

    # Reference callback: -log(2 * s) - |x - m| / s for the idx-th
    # (flattened) loc/scale pair.
    def ref_log_prob(idx, x, log_prob):
        m = loc.view(-1)[idx]
        s = scale.view(-1)[idx]
        expected = (-math.log(2 * s) - abs(x - m) / s)
        self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

    self._check_log_prob(Laplace(loc, scale), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_laplace_sample(self):
    # Statistical comparison of Laplace samples against scipy.stats.laplace
    # over a grid of loc/scale values.
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    locs = [-1.0, 0.0, 1.0]
    scales = [0.1, 1.0, 10.0]
    for loc in locs:
        for scale in scales:
            reference = scipy.stats.laplace(loc=loc, scale=scale)
            message = 'Laplace(loc={}, scale={})'.format(loc, scale)
            self._check_sampler_sampler(Laplace(loc, scale), reference, message)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma_shape(self):
    # Checks Gamma sample sizes and compares log_prob with
    # scipy.stats.gamma.logpdf.
    alpha = torch.randn(2, 3).exp().requires_grad_()
    beta = torch.randn(2, 3).exp().requires_grad_()
    alpha_1d = torch.randn(1).exp().requires_grad_()
    beta_1d = torch.randn(1).exp().requires_grad_()
    # (Gamma parameters, sample_shape, expected size)
    size_cases = [
        ((alpha, beta), torch.Size(), (2, 3)),
        ((alpha, beta), (5,), (5, 2, 3)),
        ((alpha_1d, beta_1d), (1,), (1, 1)),
        ((alpha_1d, beta_1d), torch.Size(), (1,)),
        ((0.5, 0.5), torch.Size(), ()),
        ((0.5, 0.5), (1,), (1,)),
    ]
    for params, sample_shape, expected_size in size_cases:
        self.assertEqual(Gamma(*params).sample(sample_shape).size(), expected_size)

    def ref_log_prob(idx, x, log_prob):
        concentration = alpha.view(-1)[idx].detach()
        rate = beta.view(-1)[idx].detach()
        # SciPy parameterizes by scale, the reciprocal of the rate
        expected = scipy.stats.gamma.logpdf(x, concentration, scale=1 / rate)
        self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

    self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma_gpu_shape(self):
    # CUDA variant: checks Gamma sample sizes and compares log_prob with
    # scipy.stats.gamma.logpdf after moving values back to the CPU.
    alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
    beta = torch.randn(2, 3).cuda().exp().requires_grad_()
    alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
    beta_1d = torch.randn(1).cuda().exp().requires_grad_()
    # (Gamma parameters, sample_shape, expected size)
    size_cases = [
        ((alpha, beta), torch.Size(), (2, 3)),
        ((alpha, beta), (5,), (5, 2, 3)),
        ((alpha_1d, beta_1d), (1,), (1, 1)),
        ((alpha_1d, beta_1d), torch.Size(), (1,)),
        ((0.5, 0.5), torch.Size(), ()),
        ((0.5, 0.5), (1,), (1,)),
    ]
    for params, sample_shape, expected_size in size_cases:
        self.assertEqual(Gamma(*params).sample(sample_shape).size(), expected_size)

    def ref_log_prob(idx, x, log_prob):
        concentration = alpha.view(-1)[idx].detach().cpu()
        rate = beta.view(-1)[idx].detach().cpu()
        # SciPy parameterizes by scale, the reciprocal of the rate
        expected = scipy.stats.gamma.logpdf(x.cpu(), concentration, scale=1 / rate)
        self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

    self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma_sample(self):
    # Statistical comparison of Gamma samples against scipy.stats.gamma over
    # a grid of concentration/rate values.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    concentrations = [0.1, 1.0, 5.0]
    rates = [0.1, 1.0, 10.0]
    for alpha in concentrations:
        for beta in rates:
            reference = scipy.stats.gamma(alpha, scale=1.0 / beta)
            message = 'Gamma(concentration={}, rate={})'.format(alpha, beta)
            self._check_sampler_sampler(Gamma(alpha, beta), reference, message)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_gamma_gpu_sample(self):
    # CUDA variant of the Gamma sampler comparison against scipy.stats.gamma,
    # run with failure_rate=1e-4.
    set_rng_seed(0)
    concentrations = [0.1, 1.0, 5.0]
    rates = [0.1, 1.0, 10.0]
    for alpha in concentrations:
        for beta in rates:
            alpha_cuda = torch.tensor([alpha]).cuda()
            beta_cuda = torch.tensor([beta]).cuda()
            reference = scipy.stats.gamma(alpha, scale=1.0 / beta)
            message = 'Gamma(alpha={}, beta={})'.format(alpha, beta)
            self._check_sampler_sampler(Gamma(alpha_cuda, beta_cuda), reference, message,
                                        failure_rate=1e-4)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_pareto(self):
    # Checks Pareto moments for alpha = 0.5, sample shapes, and log_prob against scipy.
    def positive(*shape):
        return torch.randn(*shape).abs().requires_grad_()

    # The draw order (scale, alpha, scale_1d, alpha_1d) fixes the parameter values.
    scale = positive(2, 3)
    alpha = positive(2, 3)
    scale_1d = positive(1)
    alpha_1d = positive(1)

    # Both moments are asserted to be infinite when alpha = 0.5.
    for moment in ('mean', 'variance'):
        self.assertEqual(getattr(Pareto(scale_1d, 0.5), moment), inf)

    # (constructor args, sample_shape, expected sample size)
    shape_cases = [
        ((scale, alpha), torch.Size(), (2, 3)),
        ((scale, alpha), (5,), (5, 2, 3)),
        ((scale_1d, alpha_1d), (1,), (1, 1)),
        ((scale_1d, alpha_1d), torch.Size(), (1,)),
        ((1.0, 1.0), torch.Size(), ()),
        ((1.0, 1.0), (1,), (1,)),
    ]
    for args, sample_shape, expected_size in shape_cases:
        self.assertEqual(Pareto(*args).sample(sample_shape).size(), expected_size)

    def ref_log_prob(idx, x, log_prob):
        s = scale.view(-1)[idx].detach()
        a = alpha.view(-1)[idx].detach()
        reference = scipy.stats.pareto.logpdf(x, a, scale=s)
        self.assertEqual(log_prob, reference, atol=1e-3, rtol=0)

    self._check_log_prob(Pareto(scale, alpha), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_pareto_sample(self):
    # Statistical comparison of Pareto samples with scipy.stats.pareto.
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    for scale in [0.1, 1.0, 5.0]:
        for alpha in [0.1, 1.0, 10.0]:
            torch_dist = Pareto(scale, alpha)
            ref_dist = scipy.stats.pareto(alpha, scale=scale)
            label = 'Pareto(scale={}, alpha={})'.format(scale, alpha)
            self._check_sampler_sampler(torch_dist, ref_dist, label)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gumbel(self):
    # Checks Gumbel sample shapes and log_prob against scipy.stats.gumbel_r.
    # The draw order (loc, scale, loc_1d, scale_1d) fixes the parameter values.
    loc = torch.randn(2, 3, requires_grad=True)
    scale = torch.randn(2, 3).abs().requires_grad_()
    loc_1d = torch.randn(1, requires_grad=True)
    scale_1d = torch.randn(1).abs().requires_grad_()

    # (constructor args, sample_shape, expected sample size)
    shape_cases = [
        ((loc, scale), torch.Size(), (2, 3)),
        ((loc, scale), (5,), (5, 2, 3)),
        ((loc_1d, scale_1d), torch.Size(), (1,)),
        ((loc_1d, scale_1d), (1,), (1, 1)),
        ((1.0, 1.0), torch.Size(), ()),
        ((1.0, 1.0), (1,), (1,)),
    ]
    for args, sample_shape, expected_size in shape_cases:
        self.assertEqual(Gumbel(*args).sample(sample_shape).size(), expected_size)

    def ref_log_prob(idx, x, log_prob):
        location = loc.view(-1)[idx].detach()
        spread = scale.view(-1)[idx].detach()
        reference = scipy.stats.gumbel_r.logpdf(x, loc=location, scale=spread)
        self.assertEqual(log_prob, reference, atol=1e-3, rtol=0)

    self._check_log_prob(Gumbel(loc, scale), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gumbel_sample(self):
    # Statistical comparison of Gumbel samples with scipy.stats.gumbel_r.
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    locations = [-5.0, -1.0, -0.1, 0.1, 1.0, 5.0]
    scales = [0.1, 1.0, 10.0]
    for loc in locations:
        for scale in scales:
            torch_dist = Gumbel(loc, scale)
            ref_dist = scipy.stats.gumbel_r(loc=loc, scale=scale)
            label = 'Gumbel(loc={}, scale={})'.format(loc, scale)
            self._check_sampler_sampler(torch_dist, ref_dist, label)
def test_kumaraswamy_shape(self):
    # Checks the sample sizes produced by Kumaraswamy for batched, 1-d and scalar parameters.
    def positive(*shape):
        return torch.randn(*shape).abs().requires_grad_()

    # The draw order below fixes the parameter values.
    concentration1 = positive(2, 3)
    concentration0 = positive(2, 3)
    concentration1_1d = positive(1)
    concentration0_1d = positive(1)

    # (constructor args, sample_shape, expected sample size)
    shape_cases = [
        ((concentration1, concentration0), torch.Size(), (2, 3)),
        ((concentration1, concentration0), (5,), (5, 2, 3)),
        ((concentration1_1d, concentration0_1d), torch.Size(), (1,)),
        ((concentration1_1d, concentration0_1d), (1,), (1, 1)),
        ((1.0, 1.0), torch.Size(), ()),
        ((1.0, 1.0), (1,), (1,)),
    ]
    for args, sample_shape, expected_size in shape_cases:
        self.assertEqual(Kumaraswamy(*args).sample(sample_shape).size(), expected_size)
# Kumaraswamy distribution is not implemented in SciPy
# Hence these tests are explicit
def test_kumaraswamy_mean_variance(self):
    # Compares the analytic mean/variance with empirical moments of 60000 samples.
    def positive(*shape):
        return torch.randn(*shape).abs().requires_grad_()

    def worst_finite_error(empirical, analytic):
        # `err == err` is False only for NaN entries, which are dropped before the max.
        err = (empirical - analytic).abs()
        return max(err[err == err])

    c1_1 = positive(2, 3)
    c0_1 = positive(2, 3)
    c1_2 = positive(4)
    c0_2 = positive(4)
    cases = [(c1_1, c0_1), (c1_2, c0_2)]
    for case_no, (a, b) in enumerate(cases, start=1):
        m = Kumaraswamy(a, b)
        samples = m.sample((60000, ))

        mean_error = worst_finite_error(samples.mean(0), m.mean)
        self.assertLess(mean_error, 0.01,
                        "Kumaraswamy example {}/{}, incorrect .mean".format(case_no, len(cases)))

        variance_error = worst_finite_error(samples.var(0), m.variance)
        self.assertLess(variance_error, 0.01,
                        "Kumaraswamy example {}/{}, incorrect .variance".format(case_no, len(cases)))
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_fishersnedecor(self):
    # Checks undefined moments, sample shapes, and log_prob against scipy.stats.f.
    # The draw order (df1, df2, df1_1d, df2_1d) fixes the parameter values.
    df1 = torch.randn(2, 3).abs().requires_grad_()
    df2 = torch.randn(2, 3).abs().requires_grad_()
    df1_1d = torch.randn(1).abs()
    df2_1d = torch.randn(1).abs()

    # These moments are asserted to be NaN for the given degrees of freedom.
    undefined_mean = FisherSnedecor(1, 2).mean
    self.assertTrue(is_all_nan(undefined_mean))
    undefined_variance = FisherSnedecor(1, 4).variance
    self.assertTrue(is_all_nan(undefined_variance))

    # (constructor args, sample_shape, expected sample size)
    shape_cases = [
        ((df1, df2), torch.Size(), (2, 3)),
        ((df1, df2), (5,), (5, 2, 3)),
        ((df1_1d, df2_1d), torch.Size(), (1,)),
        ((df1_1d, df2_1d), (1,), (1, 1)),
        ((1.0, 1.0), torch.Size(), ()),
        ((1.0, 1.0), (1,), (1,)),
    ]
    for args, sample_shape, expected_size in shape_cases:
        self.assertEqual(FisherSnedecor(*args).sample(sample_shape).size(), expected_size)

    def ref_log_prob(idx, x, log_prob):
        numerator_df = df1.view(-1)[idx].detach()
        denominator_df = df2.view(-1)[idx].detach()
        reference = scipy.stats.f.logpdf(x, numerator_df, denominator_df)
        self.assertEqual(log_prob, reference, atol=1e-3, rtol=0)

    self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_fishersnedecor_sample(self):
    # Statistical comparison of FisherSnedecor samples with scipy.stats.f over a
    # grid of (df1, df2) values.
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
        # The label names the parameters as df1/df2 (it previously said loc/scale,
        # which made failure messages misleading).
        self._check_sampler_sampler(FisherSnedecor(df1, df2),
                                    scipy.stats.f(df1, df2),
                                    'FisherSnedecor(df1={}, df2={})'.format(df1, df2))
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_chi2_shape(self):
    # Checks Chi2 sample shapes and log_prob against scipy.stats.chi2.
    df = torch.randn(2, 3).exp().requires_grad_()
    df_1d = torch.randn(1).exp().requires_grad_()

    # (df argument, sample_shape, expected sample size)
    shape_cases = [
        (df, torch.Size(), (2, 3)),
        (df, (5,), (5, 2, 3)),
        (df_1d, (1,), (1, 1)),
        (df_1d, torch.Size(), (1,)),
        (torch.tensor(0.5, requires_grad=True), torch.Size(), ()),
        (0.5, torch.Size(), ()),
        (0.5, (1,), (1,)),
    ]
    for dof, sample_shape, expected_size in shape_cases:
        self.assertEqual(Chi2(dof).sample(sample_shape).size(), expected_size)

    def ref_log_prob(idx, x, log_prob):
        dof = df.view(-1)[idx].detach()
        reference = scipy.stats.chi2.logpdf(x, dof)
        self.assertEqual(log_prob, reference, atol=1e-3, rtol=0)

    self._check_log_prob(Chi2(df), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_chi2_sample(self):
    # Statistical comparison of Chi2 samples with scipy.stats.chi2.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    degrees_of_freedom = [0.1, 1.0, 5.0]
    for df in degrees_of_freedom:
        torch_dist = Chi2(df)
        ref_dist = scipy.stats.chi2(df)
        label = 'Chi2(df={})'.format(df)
        self._check_sampler_sampler(torch_dist, ref_dist, label)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_studentT(self):
    # Checks undefined/infinite moments, sample shapes, and log_prob against scipy.stats.t.
    df = torch.randn(2, 3).exp().requires_grad_()
    df_1d = torch.randn(1).exp().requires_grad_()

    # With df = 1 both moments are asserted to be NaN; with df = 2 the variance is infinite.
    for moment in ('mean', 'variance'):
        self.assertTrue(is_all_nan(getattr(StudentT(1), moment)))
    self.assertEqual(StudentT(2).variance, inf)

    # (df argument, sample_shape, expected sample size)
    shape_cases = [
        (df, torch.Size(), (2, 3)),
        (df, (5,), (5, 2, 3)),
        (df_1d, (1,), (1, 1)),
        (df_1d, torch.Size(), (1,)),
        (torch.tensor(0.5, requires_grad=True), torch.Size(), ()),
        (0.5, torch.Size(), ()),
        (0.5, (1,), (1,)),
    ]
    for dof, sample_shape, expected_size in shape_cases:
        self.assertEqual(StudentT(dof).sample(sample_shape).size(), expected_size)

    def ref_log_prob(idx, x, log_prob):
        dof = df.view(-1)[idx].detach()
        reference = scipy.stats.t.logpdf(x, dof)
        self.assertEqual(log_prob, reference, atol=1e-3, rtol=0)

    self._check_log_prob(StudentT(df), ref_log_prob)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_studentT_sample(self):
    # Statistical comparison of StudentT samples with scipy.stats.t.
    set_rng_seed(11)  # see Note [Randomized statistical tests]
    dfs = [0.1, 1.0, 5.0, 10.0]
    locs = [-1.0, 0.0, 1.0]
    scales = [0.1, 1.0, 10.0]
    for df, loc, scale in product(dfs, locs, scales):
        params = dict(df=df, loc=loc, scale=scale)
        torch_dist = StudentT(**params)
        ref_dist = scipy.stats.t(**params)
        label = 'StudentT(df={}, loc={}, scale={})'.format(df, loc, scale)
        self._check_sampler_sampler(torch_dist, ref_dist, label)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_studentT_log_prob(self):
    # Compares StudentT.log_prob with scipy.stats.t.logpdf on freshly drawn samples.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    num_samples = 10
    dfs = [0.1, 1.0, 5.0, 10.0]
    locs = [-1.0, 0.0, 1.0]
    scales = [0.1, 1.0, 10.0]
    for df, loc, scale in product(dfs, locs, scales):
        dist = StudentT(df=df, loc=loc, scale=scale)
        x = dist.sample((num_samples,))
        actual_log_prob = dist.log_prob(x)
        # Iterating the 1-d tensors yields the same scalars as indexing with x[i].
        for point, actual in zip(x, actual_log_prob):
            reference = scipy.stats.t.logpdf(point, df=df, loc=loc, scale=scale)
            self.assertEqual(float(actual), float(reference), atol=1e-3, rtol=0)
def test_dirichlet_shape(self):
    # Checks the sample sizes produced by Dirichlet for batched and 1-d concentrations.
    alpha = torch.randn(2, 3).exp().requires_grad_()
    alpha_1d = torch.randn(4).exp().requires_grad_()

    # (concentration, sample_shape, expected sample size)
    shape_cases = [
        (alpha, torch.Size(), (2, 3)),
        (alpha, (5,), (5, 2, 3)),
        (alpha_1d, torch.Size(), (4,)),
        (alpha_1d, (1,), (1, 4)),
    ]
    for concentration, sample_shape, expected_size in shape_cases:
        self.assertEqual(Dirichlet(concentration).sample(sample_shape).size(), expected_size)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_dirichlet_log_prob(self):
    # Compares Dirichlet.log_prob with scipy.stats.dirichlet.logpdf sample by sample.
    num_samples = 10
    alpha = torch.randn(5).exp()
    dist = Dirichlet(alpha)
    x = dist.sample((num_samples,))
    actual_log_prob = dist.log_prob(x)
    for row in range(num_samples):
        reference = scipy.stats.dirichlet.logpdf(x[row].numpy(), alpha.numpy())
        self.assertEqual(actual_log_prob[row], reference, atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_dirichlet_sample(self):
    # Statistical comparison of Dirichlet samples with scipy.stats.dirichlet.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    alpha = torch.exp(torch.randn(3))
    torch_dist = Dirichlet(alpha)
    ref_dist = scipy.stats.dirichlet(alpha.numpy())
    label = 'Dirichlet(alpha={})'.format(list(alpha))
    self._check_sampler_sampler(torch_dist, ref_dist, label, multivariate=True)
def test_dirichlet_mode(self):
    # Test a few edge cases for the Dirichlet distribution mode. This also covers beta distributions.
    # Each entry pairs a concentration vector with the mode the distribution must report.
    expected_mode_by_case = (
        ([2, 2, 1], [.5, .5, 0.]),
        ([3, 2, 1], [2 / 3, 1 / 3, 0]),
        ([.5, .2, .2], [1., 0., 0.]),
        ([1, 1, 1], [nan, nan, nan]),
    )
    for concentration, expected_mode in expected_mode_by_case:
        actual_mode = Dirichlet(torch.tensor(concentration)).mode
        self.assertEqual(actual_mode, torch.tensor(expected_mode))
def test_beta_shape(self):
    # Checks the sample sizes produced by Beta for batched, 1-d and scalar concentrations.
    def positive(*shape):
        return torch.randn(*shape).exp().requires_grad_()

    # The draw order (con1, con0, con1_1d, con0_1d) fixes the parameter values.
    con1 = positive(2, 3)
    con0 = positive(2, 3)
    con1_1d = positive(4)
    con0_1d = positive(4)

    # (constructor args, sample_shape, expected sample size)
    shape_cases = [
        ((con1, con0), torch.Size(), (2, 3)),
        ((con1, con0), (5,), (5, 2, 3)),
        ((con1_1d, con0_1d), torch.Size(), (4,)),
        ((con1_1d, con0_1d), (1,), (1, 4)),
        ((0.1, 0.3), torch.Size(), ()),
        ((0.1, 0.3), (5,), (5,)),
    ]
    for args, sample_shape, expected_size in shape_cases:
        self.assertEqual(Beta(*args).sample(sample_shape).size(), expected_size)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_beta_log_prob(self):
    # Compares Beta.log_prob with scipy.stats.beta.logpdf for 100 random parameter pairs.
    num_trials = 100
    for _ in range(num_trials):
        # exp() of a normal draw gives a strictly positive concentration; con1 is drawn first.
        con1 = np.exp(np.random.normal())
        con0 = np.exp(np.random.normal())
        dist = Beta(con1, con0)
        x = dist.sample()
        actual = dist.log_prob(x).sum()
        reference = scipy.stats.beta.logpdf(x, con1, con0)
        self.assertEqual(float(actual), float(reference), atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_beta_sample(self):
    # Statistical comparison of Beta samples with scipy.stats.beta, plus a
    # finiteness check for very small concentrations.
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    grid = [0.1, 1.0, 10.0]
    for con1 in grid:
        for con0 in grid:
            torch_dist = Beta(con1, con0)
            ref_dist = scipy.stats.beta(con1, con0)
            label = 'Beta(alpha={}, beta={})'.format(con1, con0)
            self._check_sampler_sampler(torch_dist, ref_dist, label)

    # Check that small alphas do not cause NANs.
    for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
        tiny = Beta(Tensor([1e-6]), Tensor([1e-6]))
        x = tiny.sample()[0]
        self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))
def test_beta_underflow(self):
    # For low values of (alpha, beta), the gamma samples can underflow
    # with float32 and result in a spurious mode at 0.5. To prevent this,
    # torch._sample_dirichlet works with double precision for intermediate
    # calculations.
    set_rng_seed(1)
    num_samples = 50000
    for dtype in (torch.float, torch.double):
        conc = torch.tensor(1e-2, dtype=dtype)
        beta_samples = Beta(conc, conc).sample([num_samples])

        # No sample may land exactly on either boundary of the support.
        for boundary in (0, 1):
            self.assertEqual((beta_samples == boundary).sum(), 0)

        # assert support is concentrated around 0 and 1
        near_zero = float((beta_samples < 0.1).sum()) / num_samples
        near_one = float((beta_samples > 0.9).sum()) / num_samples
        self.assertEqual(near_zero, 0.5, atol=0.05, rtol=0)
        self.assertEqual(near_one, 0.5, atol=0.05, rtol=0)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
def test_beta_underflow_gpu(self):
    # CUDA counterpart of the Beta underflow check, with a looser tolerance.
    set_rng_seed(1)
    num_samples = 50000
    conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
    beta_samples = Beta(conc, conc).sample([num_samples])

    # No sample may land exactly on either boundary of the support.
    for boundary in (0, 1):
        self.assertEqual((beta_samples == boundary).sum(), 0)

    # assert support is concentrated around 0 and 1
    near_zero = float((beta_samples < 0.1).sum()) / num_samples
    near_one = float((beta_samples > 0.9).sum()) / num_samples
    # TODO: increase precision once imbalance on GPU is fixed.
    self.assertEqual(near_zero, 0.5, atol=0.12, rtol=0)
    self.assertEqual(near_one, 0.5, atol=0.12, rtol=0)
def test_continuous_bernoulli(self):
    # Checks sample shapes, gradients, log_prob and entropy of ContinuousBernoulli.
    p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
    r = torch.tensor(0.3, requires_grad=True)
    s = 0.3

    self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
    self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)

    # (probs argument, sample_shape, expected sample size)
    shape_cases = [
        (r, (8,), (8,)),
        (r, torch.Size(), ()),
        (r, (3, 2), (3, 2,)),
        (s, torch.Size(), ()),
    ]
    for probs, sample_shape, expected_size in shape_cases:
        self.assertEqual(ContinuousBernoulli(probs).sample(sample_shape).size(), expected_size)

    self._gradcheck_log_prob(ContinuousBernoulli, (p,))

    def ref_log_prob(idx, val, log_prob):
        prob = p[idx]
        if 0.499 < prob < 0.501:  # using default value of lim here
            # Polynomial in (prob - 0.5) used instead of the closed form near prob = 0.5.
            log_norm_const = math.log(2.) + 4. / 3. * math.pow(prob - 0.5, 2) + 104. / 45. * math.pow(prob - 0.5, 4)
        else:
            log_norm_const = math.log(2. * math.atanh(1. - 2. * prob) / (1. - 2.0 * prob))
        res = val * math.log(prob) + (1. - val) * math.log1p(-prob) + log_norm_const
        self.assertEqual(log_prob, res)

    self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
    logits = p.log() - (-p).log1p()
    self._check_log_prob(ContinuousBernoulli(logits=logits), ref_log_prob)

    # check entropy computation
    self.assertEqual(ContinuousBernoulli(p).entropy(), torch.tensor([-0.02938, -0.07641, -0.00682]), atol=1e-4, rtol=0)
    # entropy below corresponds to the clamped value of prob when using float 64
    # the value for float32 should be -1.76898
    self.assertEqual(ContinuousBernoulli(torch.tensor([0.0])).entropy(), torch.tensor([-2.58473]), atol=1e-5, rtol=0)
    self.assertEqual(ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), atol=1e-4, rtol=0)
def test_continuous_bernoulli_3d(self):
    # Checks sample sizes for a ContinuousBernoulli with a 3-d batch of probabilities.
    p = torch.full((2, 3, 5), 0.5).requires_grad_()
    batch = (2, 3, 5)

    unbatched = ContinuousBernoulli(p).sample()
    self.assertEqual(unbatched.size(), batch)

    nested = ContinuousBernoulli(p).sample(sample_shape=(2, 5))
    self.assertEqual(nested.size(), (2, 5) + batch)

    stacked = ContinuousBernoulli(p).sample((2,))
    self.assertEqual(stacked.size(), (2,) + batch)
def test_lkj_cholesky_log_prob(self):
    # Checks LKJCholesky.log_prob for concentration=1 after removing the Jacobian of
    # the map from Cholesky factors to correlation matrices, and that an
    # extra-row sample is rejected.
    def tril_cholesky_to_tril_corr(x):
        # Rebuild the full Cholesky factor from its strictly-lower part, then
        # return the strictly-lower part of factor @ factor.T.
        strict_lower = vec_to_tril_matrix(x, -1)
        diagonal = (1 - (strict_lower * strict_lower).sum(-1)).sqrt().diag_embed()
        factor = strict_lower + diagonal
        return tril_matrix_to_vec(factor @ factor.T, -1)

    for dim in range(2, 5):
        lkj = LKJCholesky(dim, concentration=1., validate_args=True)
        log_probs = []
        for _ in range(2):
            sample = lkj.sample()
            sample_tril = tril_matrix_to_vec(sample, diag=-1)
            log_prob = lkj.log_prob(sample)
            jac = jacobian(tril_cholesky_to_tril_corr, sample_tril)
            log_abs_det_jacobian = torch.slogdet(jac).logabsdet
            log_probs.append(log_prob - log_abs_det_jacobian)

        # for concentration=1., the density is uniform over the space of all
        # correlation matrices.
        if dim == 2:
            # for dim=2, pdf = 0.5 (jacobian adjustment factor is 0.)
            self.assertTrue(all(torch.allclose(x, torch.tensor(0.5).log(), atol=1e-10) for x in log_probs))
        self.assertEqual(log_probs[0], log_probs[1])

        # Appending a row of ones to the last sample must fail validation.
        invalid_sample = torch.cat([sample, sample.new_ones(1, dim)], dim=0)
        with self.assertRaises(ValueError):
            lkj.log_prob(invalid_sample)
def test_independent_shape(self):
    # For every example distribution, wrap it in Independent with each valid
    # reinterpreted_batch_ndims and check that shapes of log_prob, samples,
    # moments and entropy agree with the base distribution.
    for Dist, params in EXAMPLES:
        for param in params:
            base_dist = Dist(**param)
            x = base_dist.sample()
            base_log_prob_shape = base_dist.log_prob(x).shape
            for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                # The trailing `reinterpreted_batch_ndims` dims are dropped from
                # the expected log_prob shape.
                indep_log_prob_shape = base_log_prob_shape[:len(base_log_prob_shape) - reinterpreted_batch_ndims]
                self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
                self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
                self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
                if indep_dist.has_rsample:
                    # Exercise the reparameterized sampler itself; this branch
                    # previously repeated the sample() check above.
                    self.assertEqual(indep_dist.rsample().shape, base_dist.rsample().shape)
                try:
                    self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape)
                except NotImplementedError:
                    pass
                # The mean check has its own guard so that a NotImplementedError
                # from enumerate_support() no longer skips it.
                try:
                    self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
                except NotImplementedError:
                    pass
                try:
                    self.assertEqual(indep_dist.variance.shape, base_dist.variance.shape)
                except NotImplementedError:
                    pass
                try:
                    self.assertEqual(indep_dist.entropy().shape, indep_log_prob_shape)
                except NotImplementedError:
                    pass
def test_independent_expand(self):
    # Checks that Independent.expand yields the requested batch shape while keeping
    # the event shape and log_prob values.
    extra_batch_shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 3))]
    for Dist, params in EXAMPLES:
        for param in params:
            base_dist = Dist(**param)
            max_ndims = len(base_dist.batch_shape)
            for reinterpreted_batch_ndims in range(max_ndims + 1):
                for s in extra_batch_shapes:
                    indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                    expanded_shape = s + indep_dist.batch_shape
                    expanded = indep_dist.expand(expanded_shape)
                    expanded_sample = expanded.sample()

                    expected_shape = expanded_shape + indep_dist.event_shape
                    self.assertEqual(expanded_sample.shape, expected_shape)

                    log_prob_expanded = expanded.log_prob(expanded_sample)
                    log_prob_original = indep_dist.log_prob(expanded_sample)
                    self.assertEqual(log_prob_expanded, log_prob_original)

                    self.assertEqual(expanded.event_shape, indep_dist.event_shape)
                    self.assertEqual(expanded.batch_shape, expanded_shape)
def test_cdf_icdf_inverse(self):
    # Tests the invertibility property on the distributions
    for Dist, params in EXAMPLES:
        for example_no, param in enumerate(params, start=1):
            dist = Dist(**param)
            samples = dist.sample(sample_shape=(20,))
            try:
                cdf = dist.cdf(samples)
                actual = dist.icdf(cdf)
            except NotImplementedError:
                # Distributions without cdf or icdf are skipped.
                continue
            # The small constant keeps the ratio finite for samples equal to zero.
            rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
            report = '\n'.join([
                '{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, example_no, len(params)),
                'x = {}'.format(samples),
                'cdf(x) = {}'.format(cdf),
                'icdf(cdf(x)) = {}'.format(actual),
            ])
            self.assertLess(rel_error.max(), 1e-4, msg=report)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma_log_prob_at_boundary(self):
    # Checks Gamma.log_prob(0) against both a hard-coded value and scipy.
    # Each pair is (concentration, expected log density at zero).
    boundary_cases = [(.5, inf), (1, 0), (2, -inf)]
    for concentration, log_prob in boundary_cases:
        dist = Gamma(concentration, 1)
        scipy_dist = scipy.stats.gamma(concentration)
        for reference in (log_prob, scipy_dist.logpdf(0)):
            self.assertAlmostEqual(dist.log_prob(0), reference)
def test_cdf_log_prob(self):
    # Tests if the differentiation of the CDF gives the PDF at a given value
    for Dist, params in EXAMPLES:
        for example_no, param in enumerate(params, start=1):
            dist = Dist(**param)
            samples = dist.sample()
            is_continuous = not dist.support.is_discrete
            if is_continuous:
                samples.requires_grad_()
            try:
                cdfs = dist.cdf(samples)
                pdfs = dist.log_prob(samples).exp()
            except NotImplementedError:
                # Distributions without cdf or log_prob are skipped.
                continue
            cdfs_derivative = grad(cdfs.sum(), [samples])[0]  # this should not be wrapped in torch.abs()
            report = '\n'.join([
                '{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, example_no, len(params)),
                'x = {}'.format(samples),
                'cdf = {}'.format(cdfs),
                'pdf = {}'.format(pdfs),
                'grad(cdf) = {}'.format(cdfs_derivative),
            ])
            self.assertEqual(cdfs_derivative, pdfs, msg=report)
def test_valid_parameter_broadcasting(self):
    # Test correct broadcasting of parameter sizes for distributions that have multiple
    # parameters.
    tensor = torch.tensor  # local shorthand to keep the table readable

    # example type (distribution instance, expected sample shape)
    valid_examples = [
        # Normal
        (Normal(loc=tensor([0., 0.]), scale=1), (2,)),
        (Normal(loc=0, scale=tensor([1., 1.])), (2,)),
        (Normal(loc=tensor([0., 0.]), scale=tensor([1.])), (2,)),
        (Normal(loc=tensor([0., 0.]), scale=tensor([[1.], [1.]])), (2, 2)),
        (Normal(loc=tensor([0., 0.]), scale=tensor([[1.]])), (1, 2)),
        (Normal(loc=tensor([0.]), scale=tensor([[1.]])), (1, 1)),
        # FisherSnedecor
        (FisherSnedecor(df1=tensor([1., 1.]), df2=1), (2,)),
        (FisherSnedecor(df1=1, df2=tensor([1., 1.])), (2,)),
        (FisherSnedecor(df1=tensor([1., 1.]), df2=tensor([1.])), (2,)),
        (FisherSnedecor(df1=tensor([1., 1.]), df2=tensor([[1.], [1.]])), (2, 2)),
        (FisherSnedecor(df1=tensor([1., 1.]), df2=tensor([[1.]])), (1, 2)),
        (FisherSnedecor(df1=tensor([1.]), df2=tensor([[1.]])), (1, 1)),
        # Gamma
        (Gamma(concentration=tensor([1., 1.]), rate=1), (2,)),
        (Gamma(concentration=1, rate=tensor([1., 1.])), (2,)),
        (Gamma(concentration=tensor([1., 1.]), rate=tensor([[1.], [1.], [1.]])), (3, 2)),
        (Gamma(concentration=tensor([1., 1.]), rate=tensor([[1.], [1.]])), (2, 2)),
        (Gamma(concentration=tensor([1., 1.]), rate=tensor([[1.]])), (1, 2)),
        (Gamma(concentration=tensor([1.]), rate=tensor([[1.]])), (1, 1)),
        # Gumbel
        (Gumbel(loc=tensor([0., 0.]), scale=1), (2,)),
        (Gumbel(loc=0, scale=tensor([1., 1.])), (2,)),
        (Gumbel(loc=tensor([0., 0.]), scale=tensor([1.])), (2,)),
        (Gumbel(loc=tensor([0., 0.]), scale=tensor([[1.], [1.]])), (2, 2)),
        (Gumbel(loc=tensor([0., 0.]), scale=tensor([[1.]])), (1, 2)),
        (Gumbel(loc=tensor([0.]), scale=tensor([[1.]])), (1, 1)),
        # Kumaraswamy
        (Kumaraswamy(concentration1=tensor([1., 1.]), concentration0=1.), (2,)),
        (Kumaraswamy(concentration1=1, concentration0=tensor([1., 1.])), (2, )),
        (Kumaraswamy(concentration1=tensor([1., 1.]), concentration0=tensor([1.])), (2,)),
        (Kumaraswamy(concentration1=tensor([1., 1.]), concentration0=tensor([[1.], [1.]])), (2, 2)),
        (Kumaraswamy(concentration1=tensor([1., 1.]), concentration0=tensor([[1.]])), (1, 2)),
        (Kumaraswamy(concentration1=tensor([1.]), concentration0=tensor([[1.]])), (1, 1)),
        # Laplace
        (Laplace(loc=tensor([0., 0.]), scale=1), (2,)),
        (Laplace(loc=0, scale=tensor([1., 1.])), (2,)),
        (Laplace(loc=tensor([0., 0.]), scale=tensor([1.])), (2,)),
        (Laplace(loc=tensor([0., 0.]), scale=tensor([[1.], [1.]])), (2, 2)),
        (Laplace(loc=tensor([0., 0.]), scale=tensor([[1.]])), (1, 2)),
        (Laplace(loc=tensor([0.]), scale=tensor([[1.]])), (1, 1)),
        # Pareto
        (Pareto(scale=tensor([1., 1.]), alpha=1), (2,)),
        (Pareto(scale=1, alpha=tensor([1., 1.])), (2,)),
        (Pareto(scale=tensor([1., 1.]), alpha=tensor([1.])), (2,)),
        (Pareto(scale=tensor([1., 1.]), alpha=tensor([[1.], [1.]])), (2, 2)),
        (Pareto(scale=tensor([1., 1.]), alpha=tensor([[1.]])), (1, 2)),
        (Pareto(scale=tensor([1.]), alpha=tensor([[1.]])), (1, 1)),
        # StudentT
        (StudentT(df=tensor([1., 1.]), loc=1), (2,)),
        (StudentT(df=1, scale=tensor([1., 1.])), (2,)),
        (StudentT(df=tensor([1., 1.]), loc=tensor([1.])), (2,)),
        (StudentT(df=tensor([1., 1.]), scale=tensor([[1.], [1.]])), (2, 2)),
        (StudentT(df=tensor([1., 1.]), loc=tensor([[1.]])), (1, 2)),
        (StudentT(df=tensor([1.]), scale=tensor([[1.]])), (1, 1)),
        (StudentT(df=1., loc=torch.zeros(5, 1), scale=torch.ones(3)), (5, 3)),
    ]

    sample_shape = torch.Size((2,))
    for dist, expected_size in valid_examples:
        # First without a sample shape: the sample size is the broadcast batch shape.
        actual_size = dist.sample().size()
        self.assertEqual(actual_size, expected_size,
                         msg='{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))

        # Then with a leading sample dimension prepended.
        expected_size = sample_shape + expected_size
        actual_size = dist.sample(sample_shape).size()
        self.assertEqual(actual_size, expected_size,
                         msg='{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))
def test_invalid_parameter_broadcasting(self):
    # invalid broadcasting cases; should throw error
    # example type (distribution class, distribution params)
    invalid_examples = [
        (Normal, {'loc': torch.tensor([[0, 0]]), 'scale': torch.tensor([1, 1, 1, 1])}),
        (Normal, {'loc': torch.tensor([[[0, 0, 0], [0, 0, 0]]]), 'scale': torch.tensor([1, 1])}),
        (FisherSnedecor, {'df1': torch.tensor([1, 1]), 'df2': torch.tensor([1, 1, 1])}),
        (Gumbel, {'loc': torch.tensor([[0, 0]]), 'scale': torch.tensor([1, 1, 1, 1])}),
        (Gumbel, {'loc': torch.tensor([[[0, 0, 0], [0, 0, 0]]]), 'scale': torch.tensor([1, 1])}),
        (Gamma, {'concentration': torch.tensor([0, 0]), 'rate': torch.tensor([1, 1, 1])}),
        (Kumaraswamy, {'concentration1': torch.tensor([[1, 1]]), 'concentration0': torch.tensor([1, 1, 1, 1])}),
        (Kumaraswamy, {'concentration1': torch.tensor([[[1, 1, 1], [1, 1, 1]]]), 'concentration0': torch.tensor([1, 1])}),
        (Laplace, {'loc': torch.tensor([0, 0]), 'scale': torch.tensor([1, 1, 1])}),
        (Pareto, {'scale': torch.tensor([1, 1]), 'alpha': torch.tensor([1, 1, 1])}),
        (StudentT, {'df': torch.tensor([1., 1.]), 'scale': torch.tensor([1., 1., 1.])}),
        (StudentT, {'df': torch.tensor([1., 1.]), 'loc': torch.tensor([1., 1., 1.])}),
    ]
    for dist_cls, kwargs in invalid_examples:
        # Constructing the distribution is what must raise.
        with self.assertRaises(RuntimeError):
            dist_cls(**kwargs)
def _test_discrete_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
    # We cannot easily check the mode for discrete distributions, but we can look left and right
    # to ensure the log probability is smaller than at the mode.
    for step in (-1, 1):
        log_prob_mode = dist.log_prob(sanitized_mode)
        if isinstance(dist, OneHotCategorical):
            # One-hot events cannot be shifted by +/-1; the next category (cyclically) is used instead.
            num_classes = dist.probs.shape[-1]
            idx = (dist._categorical.mode + 1) % num_classes
            other = torch.nn.functional.one_hot(idx, num_classes=num_classes).to(dist.mode)
        else:
            other = dist.mode + step
        mask = batch_isfinite & dist.support.check(other)
        self.assertTrue(mask.any() or dist.mode.unique().numel() == 1)
        # Add a dimension to the right if the event shape is not a scalar, e.g. OneHotCategorical.
        if mask.ndim < other.ndim:
            condition = mask[..., None]
        else:
            condition = mask
        # Entries outside the mask are replaced by a fresh sample so log_prob can be evaluated.
        other = torch.where(condition, other, dist.sample())
        log_prob_other = dist.log_prob(other)
        delta = log_prob_mode - log_prob_other
        self.assertTrue((delta[mask].detach() > -1e-12).all())  # Allow up to 1e-12 rounding error.
def _test_continuous_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
    # Wishart is excluded from this check.
    if isinstance(dist, Wishart):
        return
    # We perturb the mode in the unconstrained space and expect the log probability to decrease.
    num_points = 10
    to_support = transform_to(dist.support)
    unconstrained_mode = to_support.inv(sanitized_mode)
    noise_shape = (num_points,) + unconstrained_mode.shape
    perturbation = 1e-5 * (torch.rand(noise_shape) - 0.5)
    perturbed_mode = to_support(perturbation + unconstrained_mode)

    log_prob_mode = dist.log_prob(sanitized_mode)
    log_prob_other = dist.log_prob(perturbed_mode)
    delta = log_prob_mode - log_prob_other

    # We pass the test with a small tolerance to allow for rounding and manually set the
    # difference to zero if both log probs are infinite with the same sign.
    same_value = log_prob_mode == log_prob_other
    mode_is_infinite = log_prob_mode.abs() == inf
    delta[same_value & mode_is_infinite] = 0.

    ordering = (delta > -1e-12).all(axis=0)
    self.assertTrue(ordering[batch_isfinite].all())
def test_mode(self):
    # Validates `.mode` for every example distribution: it must either raise
    # NotImplementedError or be a local maximum of log_prob.
    discrete_distributions = (
        Bernoulli, Binomial, Categorical, Geometric, NegativeBinomial, OneHotCategorical, Poisson,
    )
    no_mode_available = (
        ContinuousBernoulli, LKJCholesky, LogisticNormal, MixtureSameFamily, Multinomial,
        RelaxedBernoulli, RelaxedOneHotCategorical,
    )

    for dist_cls, params in EXAMPLES:
        for param in params:
            dist = dist_cls(**param)

            mode_unsupported = isinstance(dist, no_mode_available) or type(dist) is TransformedDistribution
            if mode_unsupported:
                with self.assertRaises(NotImplementedError):
                    dist.mode
                continue

            # Check that either all or no elements in the event shape are nan: the mode cannot be
            # defined for part of an event.
            flattened_event_shape = dist.batch_shape + (dist.event_shape.numel(),)
            isfinite = dist.mode.isfinite().reshape(flattened_event_shape)
            batch_isfinite = isfinite.all(axis=-1)
            self.assertTrue((batch_isfinite | ~isfinite.any(axis=-1)).all())

            # We sanitize undefined modes by sampling from the distribution.
            sanitized_mode = torch.where(~dist.mode.isnan(), dist.mode, dist.sample())

            if isinstance(dist, discrete_distributions):
                check_mode = self._test_discrete_distribution_mode
            else:
                check_mode = self._test_continuous_distribution_mode
            check_mode(dist, sanitized_mode, batch_isfinite)

            self.assertFalse(dist.log_prob(sanitized_mode).isnan().any())
# These tests are only needed for a few distributions that implement custom
# reparameterized gradients. Most .rsample() implementations simply rely on
# the reparameterization trick and do not need to be tested for accuracy.
class TestRsample(DistributionsTestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma(self):
    # Compares the reparameterized gradient dx/dalpha of Gamma.rsample with a
    # finite-difference estimate built from scipy's cdf and pdf.
    num_samples = 100
    for alpha in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
        alphas = torch.tensor([alpha] * num_samples, dtype=torch.float, requires_grad=True)
        betas = alphas.new_ones(num_samples)
        draws = Gamma(alphas, betas).rsample()
        draws.sum().backward()
        sorted_draws, order = draws.sort()
        x = sorted_draws.detach().numpy()
        actual_grad = alphas.grad[order].numpy()

        # Compare with expected gradient dx/dalpha along constant cdf(x,alpha).
        eps = 0.01 * alpha / (1.0 + alpha ** 0.5)
        cdf_hi = scipy.stats.gamma.cdf(x, alpha + eps)
        cdf_lo = scipy.stats.gamma.cdf(x, alpha - eps)
        cdf_alpha = (cdf_hi - cdf_lo) / (2 * eps)  # central difference in alpha
        cdf_x = scipy.stats.gamma.pdf(x, alpha)
        expected_grad = -cdf_alpha / cdf_x
        # NOTE: the denominator is expected_grad itself (not its magnitude), so the ratio is signed.
        rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)

        report = '\n'.join([
            'Bad gradient dx/alpha for x ~ Gamma({}, 1)'.format(alpha),
            'x {}'.format(x),
            'expected {}'.format(expected_grad),
            'actual {}'.format(actual_grad),
            'rel error {}'.format(rel_error),
            'max error {}'.format(rel_error.max()),
            'at alpha={}, x={}'.format(alpha, x[rel_error.argmax()]),
        ])
        self.assertLess(np.max(rel_error), 0.0005, report)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_chi2(self):
    # Compares the reparameterized gradient dx/ddf of Chi2.rsample with a
    # finite-difference estimate built from scipy's cdf and pdf.
    num_samples = 100
    for df in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
        dfs = torch.tensor([df] * num_samples, dtype=torch.float, requires_grad=True)
        draws = Chi2(dfs).rsample()
        draws.sum().backward()
        sorted_draws, order = draws.sort()
        x = sorted_draws.detach().numpy()
        actual_grad = dfs.grad[order].numpy()

        # Compare with expected gradient dx/ddf along constant cdf(x,df).
        eps = 0.01 * df / (1.0 + df ** 0.5)
        cdf_hi = scipy.stats.chi2.cdf(x, df + eps)
        cdf_lo = scipy.stats.chi2.cdf(x, df - eps)
        cdf_df = (cdf_hi - cdf_lo) / (2 * eps)  # central difference in df
        cdf_x = scipy.stats.chi2.pdf(x, df)
        expected_grad = -cdf_df / cdf_x
        # NOTE: the denominator is expected_grad itself (not its magnitude), so the ratio is signed.
        rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)

        report = '\n'.join([
            'Bad gradient dx/ddf for x ~ Chi2({})'.format(df),
            'x {}'.format(x),
            'expected {}'.format(expected_grad),
            'actual {}'.format(actual_grad),
            'rel error {}'.format(rel_error),
            'max error {}'.format(rel_error.max()),
        ])
        self.assertLess(np.max(rel_error), 0.001, report)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_dirichlet_on_diagonal(self):
    # Compares dx[0]/dalpha[0] of Dirichlet.rsample with a finite-difference
    # estimate from the Beta marginal of the first coordinate.
    num_samples = 20
    grid = [1e-1, 1e0, 1e1]
    for a0, a1, a2 in product(grid, grid, grid):
        alphas = torch.tensor([[a0, a1, a2]] * num_samples, dtype=torch.float, requires_grad=True)
        first_coord = Dirichlet(alphas).rsample()[:, 0]
        first_coord.sum().backward()
        sorted_coord, order = first_coord.sort()
        x = sorted_coord.detach().numpy()
        actual_grad = alphas.grad[order].numpy()[:, 0]

        # Compare with expected gradient dx/dalpha0 along constant cdf(x,alpha).
        # This reduces to a distribution Beta(alpha[0], alpha[1] + alpha[2]).
        alpha = a0
        beta = a1 + a2
        eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
        cdf_hi = scipy.stats.beta.cdf(x, alpha + eps, beta)
        cdf_lo = scipy.stats.beta.cdf(x, alpha - eps, beta)
        cdf_alpha = (cdf_hi - cdf_lo) / (2 * eps)  # central difference in alpha
        cdf_x = scipy.stats.beta.pdf(x, alpha, beta)
        expected_grad = -cdf_alpha / cdf_x
        # NOTE: the denominator is expected_grad itself (not its magnitude), so the ratio is signed.
        rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)

        report = '\n'.join([
            'Bad gradient dx[0]/dalpha[0] for Dirichlet([{}, {}, {}])'.format(a0, a1, a2),
            'x {}'.format(x),
            'expected {}'.format(expected_grad),
            'actual {}'.format(actual_grad),
            'rel error {}'.format(rel_error),
            'max error {}'.format(rel_error.max()),
            'at x={}'.format(x[rel_error.argmax()]),
        ])
        self.assertLess(np.max(rel_error), 0.001, report)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_beta_wrt_alpha(self):
    """Check dx/dcon1 of ``Beta(con1, con0).rsample()`` against scipy.

    The autograd gradient with respect to the first concentration is compared
    with ``-(dCDF/dcon1) / pdf`` of scipy's beta distribution, where the
    derivative of the CDF is a central finite difference in ``con1``.
    """
    batch = 20
    levels = [1e-2, 1e-1, 1e0, 1e1, 1e2]
    beta_cdf = scipy.stats.beta.cdf
    beta_pdf = scipy.stats.beta.pdf
    for con1, con0 in product(levels, repeat=2):
        # Only con1 tracks gradients; con0 is a plain tensor of the same dtype.
        con1s = torch.tensor([con1] * batch, dtype=torch.float, requires_grad=True)
        con0s = con1s.new_tensor([con0] * batch)
        draws = Beta(con1s, con0s).rsample()
        draws.sum().backward()
        # Sort the samples and reorder the gradients the same way.
        sorted_draws, order = draws.sort()
        x = sorted_draws.detach().numpy()
        actual_grad = con1s.grad[order].numpy()

        # Expected gradient dx/dcon1 along constant cdf(x, con1, con0).
        step = 0.01 * con1 / (1.0 + np.sqrt(con1))
        cdf_hi = beta_cdf(x, con1 + step, con0)
        cdf_lo = beta_cdf(x, con1 - step, con0)
        dcdf_dcon1 = (cdf_hi - cdf_lo) / (2 * step)
        density = beta_pdf(x, con1, con0)
        expected_grad = -dcdf_dcon1 / density

        # NOTE(review): the denominator is the signed expected_grad, so a negative
        # expected gradient would give a non-positive rel_error that always passes.
        rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
        report = [
            'Bad gradient dx/dcon1 for x ~ Beta({}, {})'.format(con1, con0),
            'x {}'.format(x),
            'expected {}'.format(expected_grad),
            'actual {}'.format(actual_grad),
            'rel error {}'.format(rel_error),
            'max error {}'.format(rel_error.max()),
            'at x = {}'.format(x[rel_error.argmax()]),
        ]
        self.assertLess(np.max(rel_error), 0.005, '\n'.join(report))
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_beta_wrt_beta(self):
    """Check dx/dcon0 of ``Beta(con1, con0).rsample()`` against scipy.

    The autograd gradient with respect to the second concentration is compared
    with ``-(dCDF/dcon0) / pdf`` of scipy's beta distribution, where the
    derivative of the CDF is a central finite difference in ``con0``.
    """
    batch = 20
    levels = [1e-2, 1e-1, 1e0, 1e1, 1e2]
    beta_cdf = scipy.stats.beta.cdf
    beta_pdf = scipy.stats.beta.pdf
    for con1, con0 in product(levels, repeat=2):
        # Only con0 tracks gradients; con1 is a plain tensor of the same dtype.
        con0s = torch.tensor([con0] * batch, dtype=torch.float, requires_grad=True)
        con1s = con0s.new_tensor([con1] * batch)
        draws = Beta(con1s, con0s).rsample()
        draws.sum().backward()
        # Sort the samples and reorder the gradients the same way.
        sorted_draws, order = draws.sort()
        x = sorted_draws.detach().numpy()
        actual_grad = con0s.grad[order].numpy()

        # Expected gradient dx/dcon0 along constant cdf(x, con1, con0).
        step = 0.01 * con0 / (1.0 + np.sqrt(con0))
        cdf_hi = beta_cdf(x, con1, con0 + step)
        cdf_lo = beta_cdf(x, con1, con0 - step)
        dcdf_dcon0 = (cdf_hi - cdf_lo) / (2 * step)
        density = beta_pdf(x, con1, con0)
        expected_grad = -dcdf_dcon0 / density

        # NOTE(review): the denominator is the signed expected_grad. dx/dcon0 is
        # presumably negative (a larger con0 should push samples toward 0 -- confirm),
        # in which case rel_error <= 0 and the assertLess below can never fail.
        # Dividing by np.abs(expected_grad) would make the check effective, but
        # verify float32 precision for samples near 1 before changing it.
        rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
        report = [
            'Bad gradient dx/dcon0 for x ~ Beta({}, {})'.format(con1, con0),
            'x {}'.format(x),
            'expected {}'.format(expected_grad),
            'actual {}'.format(actual_grad),
            'rel error {}'.format(rel_error),
            'max error {}'.format(rel_error.max()),
            'at x = {!r}'.format(x[rel_error.argmax()]),
        ]
        self.assertLess(np.max(rel_error), 0.005, '\n'.join(report))
def test_dirichlet_multivariate(self):
    """Compare a sampled gradient through ``Dirichlet.rsample()`` with a closed form.

    Samples are drawn from Dirichlet([alpha, alpha, 1]); the loss is the mean
    squared deviation of the third coordinate from 1 / (2 * alpha + 1). Its
    gradient with respect to alpha is checked against a hand-derived expression
    for several alpha values around ``alpha_crit``, which is the positive root
    of the numerator ``1 - 2 * alpha - 4 * alpha**2`` of that expression.
    """
    alpha_crit = 0.25 * (5.0 ** 0.5 - 1.0)
    num_samples = 100000
    for shift in (-0.1, -0.05, -0.01, 0.0, 0.01, 0.05, 0.10):
        alpha = torch.tensor([alpha_crit + shift], dtype=torch.float, requires_grad=True)
        concentration = torch.cat([alpha, alpha, alpha.new([1])])
        z = Dirichlet(concentration.expand(num_samples, 3)).rsample()
        mean_z3 = 1.0 / (2.0 * alpha + 1.0)
        loss = (z[:, 2] - mean_z3).pow(2.0).mean()
        actual_grad = grad(loss, [alpha])[0]

        # Expected gradient, computed by hand.
        numerator = 1.0 - 2.0 * alpha - 4.0 * alpha**2
        denominator = (1.0 + alpha)**2 * (1.0 + 2.0 * alpha)**3
        expected_grad = numerator / denominator

        details = [
            "alpha = alpha_c + %.2g" % shift,
            "expected_grad: %.5g" % expected_grad,
            "actual_grad: %.5g" % actual_grad,
            "error = %.2g" % torch.abs(expected_grad - actual_grad).max(),
        ]
        self.assertEqual(actual_grad, expected_grad, atol=0.002, rtol=0, msg='\n'.join(details))
def test_dirichlet_tangent_field(self):
    """Check that the Dirichlet reparameterized gradient obeys a continuity equation.

    For each concentration triple on the grid, the tangent field
    v = dx/dalpha[0] is obtained from autograd, cross-checked against
    ``_Dirichlet_backward``, and its divergence along the simplex is estimated
    by central finite differences. The residual of the continuity equation,
    written in terms of ``log_prob``, must stay below 0.005 in absolute value.
    """
    batch = 20
    levels = [0.5, 1.0, 2.0]

    def compute_v(x, alpha):
        # v = dx/dalpha[0] is the reparameterized gradient aka tangent field.
        # Column i back-propagates the i-th one-hot output gradient.
        columns = []
        for i in range(3):
            one_hot = torch.eye(3, 3)[i].expand_as(x)
            columns.append(_Dirichlet_backward(x, alpha, one_hot)[:, 0])
        return torch.stack(columns, dim=-1)

    # dx is an arbitrary orthonormal basis tangent to the simplex.
    dx = torch.tensor([[2., -1., -1.], [0., 1., -1.]])
    dx = dx / dx.norm(p=2, dim=-1, keepdim=True)

    for a1, a2, a3 in product(levels, repeat=3):
        alpha = torch.tensor([a1, a2, a3], requires_grad=True).expand(batch, 3)
        x = Dirichlet(alpha).rsample()

        logp_for_alpha = Dirichlet(alpha).log_prob(x.detach()).sum()
        dlogp_da = grad([logp_for_alpha], [alpha], retain_graph=True)[0][:, 0]
        logp_for_x = Dirichlet(alpha.detach()).log_prob(x).sum()
        dlogp_dx = grad([logp_for_x], [x], retain_graph=True)[0]

        v_columns = []
        for i in range(3):
            dxi_dalpha = grad([x[:, i].sum()], [alpha], retain_graph=True)[0]
            v_columns.append(dxi_dalpha[:, 0])
        v = torch.stack(v_columns, dim=-1)

        # Compute remaining properties by finite difference.
        self.assertEqual(compute_v(x, alpha), v, msg='Bug in compute_v() helper')

        eps = 1e-2 * x.min(dim=-1, keepdim=True)[0]  # avoid boundary

        def directional_derivative(direction):
            # Central difference of the tangent field along one basis direction.
            ahead = compute_v(x + eps * direction, alpha)
            behind = compute_v(x - eps * direction, alpha)
            return (ahead - behind) / (2 * eps)

        dv0 = directional_derivative(dx[0])
        dv1 = directional_derivative(dx[1])
        div_v = (dv0 * dx[0] + dv1 * dx[1]).sum(-1)

        # This is a modification of the standard continuity equation, using the product rule
        # to allow expression in terms of log_prob rather than the less numerically stable
        # log_prob.exp().
        error = dlogp_da + (dlogp_dx * v).sum(-1) + div_v
        report = [
            'Dirichlet([{}, {}, {}]) gradient violates continuity equation:'.format(a1, a2, a3),
            'error = {}'.format(error),
        ]
        self.assertLess(torch.abs(error).max(), 0.005, '\n'.join(report))
class TestDistributionShapes(DistributionsTestCase):
    """Shape checks: batch/event shapes and the sizes returned by sample() and log_prob()."""

    def setUp(self):
        super().setUp()
        # Values passed to log_prob(): a plain Python int and two all-ones tensors.
        self.scalar_sample = 1
        self.tensor_sample_1 = torch.ones(3, 2)
        self.tensor_sample_2 = torch.ones(3, 2, 3)

    def tearDown(self):
        super().tearDown()

    def _assert_shape(self, actual, *dims):
        """Assert that ``actual`` equals ``torch.Size(dims)``; no dims means an empty shape."""
        self.assertEqual(actual, torch.Size(dims))

    def test_entropy_shape(self):
        """entropy() of every example distribution has the distribution's batch shape."""
        for Dist, params in EXAMPLES:
            total = len(params)
            for index, param in enumerate(params, start=1):
                dist = Dist(validate_args=False, **param)
                try:
                    actual_shape = dist.entropy().size()
                    expected_shape = dist.batch_shape or torch.Size()
                    message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
                        Dist.__name__, index, total, expected_shape, actual_shape)
                    self.assertEqual(actual_shape, expected_shape, msg=message)
                except NotImplementedError:
                    # Examples that raise NotImplementedError are skipped.
                    continue

    def test_bernoulli_shape_scalar_params(self):
        """Bernoulli built from a Python float."""
        dist = Bernoulli(0.3)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_bernoulli_shape_tensor_params(self):
        """Bernoulli built from a (3, 2) tensor of probabilities."""
        dist = Bernoulli(torch.tensor([[0.6, 0.3]] * 3))
        self._assert_shape(dist._batch_shape, 3, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 3, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(3, 1, 1)).shape, 3, 3, 2)

    def test_geometric_shape_scalar_params(self):
        """Geometric built from a Python float."""
        dist = Geometric(0.3)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_geometric_shape_tensor_params(self):
        """Geometric built from a (3, 2) tensor of probabilities."""
        dist = Geometric(torch.tensor([[0.6, 0.3]] * 3))
        self._assert_shape(dist._batch_shape, 3, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 3, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(3, 1, 1)).shape, 3, 3, 2)

    def test_beta_shape_scalar_params(self):
        """Beta built from two Python floats."""
        dist = Beta(0.1, 0.1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_beta_shape_tensor_params(self):
        """Beta built from two (3, 2) concentration tensors."""
        values = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
        dist = Beta(torch.tensor(values), torch.tensor(values))
        self._assert_shape(dist._batch_shape, 3, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 3, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(3, 1, 1)).shape, 3, 3, 2)

    def test_binomial_shape(self):
        """Binomial with an integer count and a length-2 probability tensor."""
        dist = Binomial(10, torch.tensor([0.6, 0.3]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)

    def test_binomial_shape_vectorized_n(self):
        """Binomial with a (2, 3) count tensor and a length-3 probability tensor."""
        counts = torch.tensor([[10, 3, 1], [4, 8, 4]])
        dist = Binomial(counts, torch.tensor([0.6, 0.3, 0.1]))
        self._assert_shape(dist._batch_shape, 2, 3)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2, 3)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2, 3)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)

    def test_multinomial_shape(self):
        """Multinomial with a (3, 2) probability tensor."""
        dist = Multinomial(10, torch.tensor([[0.6, 0.3]] * 3))
        self._assert_shape(dist._batch_shape, 3)
        self._assert_shape(dist._event_shape, 2)
        self._assert_shape(dist.sample().shape, 3, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(3, 1, 2)).shape, 3, 3)

    def test_categorical_shape(self):
        """Categorical, first unbatched and then with a batch of three."""
        # unbatched
        dist = Categorical(torch.tensor([0.6, 0.3, 0.1]))
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)
        self._assert_shape(dist.log_prob(torch.ones(3, 1)).shape, 3, 1)
        # batched
        dist = Categorical(torch.tensor([[0.6, 0.3]] * 3))
        self._assert_shape(dist._batch_shape, 3)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 3)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 3)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)
        self._assert_shape(dist.log_prob(torch.ones(3, 1)).shape, 3, 3)

    def test_one_hot_categorical_shape(self):
        """OneHotCategorical, first unbatched and then with a batch of three."""
        # unbatched
        dist = OneHotCategorical(torch.tensor([0.6, 0.3, 0.1]))
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape, 3)
        self._assert_shape(dist.sample().shape, 3)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 3)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        one_hot = torch.tensor([0., 1., 0.]).expand(3, 2, 3)
        self._assert_shape(dist.log_prob(one_hot).shape, 3, 2)
        self._assert_shape(dist.log_prob(dist.enumerate_support()).shape, 3)
        one_hot = torch.eye(3)
        self._assert_shape(dist.log_prob(one_hot).shape, 3)
        # batched
        dist = OneHotCategorical(torch.tensor([[0.6, 0.3]] * 3))
        self._assert_shape(dist._batch_shape, 3)
        self._assert_shape(dist._event_shape, 2)
        self._assert_shape(dist.sample().shape, 3, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 3, 2)
        one_hot = torch.tensor([0., 1.])
        self._assert_shape(dist.log_prob(one_hot).shape, 3)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(dist.enumerate_support()).shape, 2, 3)
        one_hot = torch.tensor([0., 1.]).expand(3, 1, 2)
        self._assert_shape(dist.log_prob(one_hot).shape, 3, 3)

    def test_cauchy_shape_scalar_params(self):
        """Cauchy built from Python numbers."""
        dist = Cauchy(0, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample(torch.Size((3, 2))).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_cauchy_shape_tensor_params(self):
        """Cauchy built from length-2 parameter tensors."""
        dist = Cauchy(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample(torch.Size((3, 2))).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_halfcauchy_shape_scalar_params(self):
        """HalfCauchy built from a Python number."""
        dist = HalfCauchy(1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample(torch.Size((3, 2))).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_halfcauchy_shape_tensor_params(self):
        """HalfCauchy built from a length-2 scale tensor."""
        dist = HalfCauchy(torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample(torch.Size((3, 2))).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_dirichlet_shape(self):
        """Dirichlet with a (3, 2) concentration tensor; log_prob inputs are normalized first."""
        dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
        self._assert_shape(dist._batch_shape, 3)
        self._assert_shape(dist._event_shape, 2)
        self._assert_shape(dist.sample().shape, 3, 2)
        self._assert_shape(dist.sample((5, 4)).shape, 5, 4, 3, 2)
        # Divide by the last-dimension sum so that each row sums to one.
        simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True)
        self._assert_shape(dist.log_prob(simplex_sample).shape, 3)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        raw = torch.ones(3, 1, 2)
        simplex_sample = raw / raw.sum(-1, keepdim=True)
        self._assert_shape(dist.log_prob(simplex_sample).shape, 3, 3)

    def test_mixture_same_family_shape(self):
        """MixtureSameFamily of five Normal components with random parameters."""
        mixing = Categorical(torch.rand(5))
        components = Normal(torch.randn(5), torch.rand(5))
        dist = MixtureSameFamily(mixing, components)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((5, 4)).shape, 5, 4)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_gamma_shape_scalar_params(self):
        """Gamma built from Python numbers; log_prob of the scalar sample has empty size."""
        dist = Gamma(1, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.scalar_sample).shape)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_gamma_shape_tensor_params(self):
        """Gamma built from length-2 parameter tensors."""
        dist = Gamma(torch.tensor([1., 1.]), torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_chi2_shape_scalar_params(self):
        """Chi2 built from a Python number; log_prob of the scalar sample has empty size."""
        dist = Chi2(1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.scalar_sample).shape)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_chi2_shape_tensor_params(self):
        """Chi2 built from a length-2 tensor."""
        dist = Chi2(torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_studentT_shape_scalar_params(self):
        """StudentT built from a Python number."""
        dist = StudentT(1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_studentT_shape_tensor_params(self):
        """StudentT built from a length-2 tensor."""
        dist = StudentT(torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_pareto_shape_scalar_params(self):
        """Pareto built from Python numbers; log_prob inputs are the fixtures shifted by one."""
        dist = Pareto(1, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1 + 1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2 + 1).shape, 3, 2, 3)

    def test_gumbel_shape_scalar_params(self):
        """Gumbel built from Python numbers."""
        dist = Gumbel(1, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_kumaraswamy_shape_scalar_params(self):
        """Kumaraswamy built from Python numbers."""
        dist = Kumaraswamy(1, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_vonmises_shape_tensor_params(self):
        """VonMises built from length-2 parameter tensors."""
        dist = VonMises(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample(torch.Size((3, 2))).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_vonmises_shape_scalar_params(self):
        """VonMises built from Python floats."""
        dist = VonMises(0., 1.)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample(torch.Size((3, 2))).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_weibull_scale_scalar_params(self):
        """Weibull built from Python numbers."""
        dist = Weibull(1, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_wishart_shape_scalar_params(self):
        """Wishart with a 0-d first argument and a 1x1 matrix argument."""
        dist = Wishart(torch.tensor(1), torch.tensor([[1.]]))
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape, 1, 1)
        self._assert_shape(dist.sample().shape, 1, 1)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 1, 1)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)

    def test_wishart_shape_tensor_params(self):
        """Wishart with a batch of two 1x1 matrices."""
        dist = Wishart(torch.tensor([1., 1.]), torch.tensor([[[1.]], [[1.]]]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape, 1, 1)
        self._assert_shape(dist.sample().shape, 2, 1, 1)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2, 1, 1)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1, 1)).shape, 2)

    def test_normal_shape_scalar_params(self):
        """Normal built from Python numbers."""
        dist = Normal(0, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_normal_shape_tensor_params(self):
        """Normal built from length-2 parameter tensors."""
        dist = Normal(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_uniform_shape_scalar_params(self):
        """Uniform built from Python numbers."""
        dist = Uniform(0, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample(torch.Size((3, 2))).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_uniform_shape_tensor_params(self):
        """Uniform built from length-2 bound tensors."""
        dist = Uniform(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample(torch.Size((3, 2))).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_exponential_shape_scalar_param(self):
        """Exponential built from a Python float."""
        dist = Exponential(1.)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_exponential_shape_tensor_param(self):
        """Exponential built from a length-2 rate tensor."""
        dist = Exponential(torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 2)).shape, 2, 2)

    def test_laplace_shape_scalar_params(self):
        """Laplace built from Python numbers."""
        dist = Laplace(0, 1)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_laplace_shape_tensor_params(self):
        """Laplace built from length-2 parameter tensors."""
        dist = Laplace(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self._assert_shape(dist._batch_shape, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(2, 1)).shape, 2, 2)

    def test_continuous_bernoulli_shape_scalar_params(self):
        """ContinuousBernoulli built from a Python float."""
        dist = ContinuousBernoulli(0.3)
        self._assert_shape(dist._batch_shape)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_2).shape, 3, 2, 3)

    def test_continuous_bernoulli_shape_tensor_params(self):
        """ContinuousBernoulli built from a (3, 2) tensor of probabilities."""
        dist = ContinuousBernoulli(torch.tensor([[0.6, 0.3]] * 3))
        self._assert_shape(dist._batch_shape, 3, 2)
        self._assert_shape(dist._event_shape)
        self._assert_shape(dist.sample().shape, 3, 2)
        self._assert_shape(dist.sample((3, 2)).shape, 3, 2, 3, 2)
        self._assert_shape(dist.log_prob(self.tensor_sample_1).shape, 3, 2)
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self._assert_shape(dist.log_prob(torch.ones(3, 1, 1)).shape, 3, 3, 2)
class TestKL(DistributionsTestCase):
    """Check analytic ``kl_divergence`` and ``entropy`` results against reference values.

    KL results are compared with Monte Carlo estimates, with the generic
    exponential-family formula, and with hand-picked pairs whose divergence
    must be infinite or zero. Entropies are compared with Monte Carlo
    estimates and with ``ExponentialFamily.entropy``.
    """

    def setUp(self):
        """Create the distribution fixtures and Monte Carlo settings used by the tests."""
        super().setUp()

        class Binomial30(Binomial):
            """Binomial with ``total_count`` pinned to 30, leaving ``probs`` as the only parameter."""

            def __init__(self, probs):
                super().__init__(30, probs)

        # These are pairs of distributions with 4 x 4 parameters as specified.
        # The first of the pair e.g. bernoulli[0] varies column-wise and the second
        # e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
        bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
        binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
        binomial_vectorized_count = (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
                                     Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])))
        beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        categorical = pairwise(Categorical, [[0.4, 0.3, 0.3],
                                             [0.2, 0.7, 0.1],
                                             [0.33, 0.33, 0.34],
                                             [0.2, 0.2, 0.6]])
        cauchy = pairwise(Cauchy, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        chi2 = pairwise(Chi2, [1.0, 2.0, 2.5, 5.0])
        dirichlet = pairwise(Dirichlet, [[0.1, 0.2, 0.7],
                                         [0.5, 0.4, 0.1],
                                         [0.33, 0.33, 0.34],
                                         [0.2, 0.2, 0.4]])
        exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
        gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
        laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
        onehotcategorical = pairwise(OneHotCategorical, [[0.4, 0.3, 0.3],
                                                         [0.2, 0.7, 0.1],
                                                         [0.33, 0.33, 0.34],
                                                         [0.2, 0.2, 0.6]])
        pareto = (Pareto(torch.tensor([2.5, 4.0, 2.5, 4.0]).expand(4, 4),
                         torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4)),
                  Pareto(torch.tensor([2.25, 3.75, 2.25, 3.8]).expand(4, 4),
                         torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4)))
        poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
        uniform_within_unit = pairwise(Uniform, [0.1, 0.9, 0.2, 0.75], [0.15, 0.95, 0.25, 0.8])
        uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
        uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4])
        uniform_pareto = pairwise(Uniform, [6.5, 7.5, 6.5, 8.5], [7.5, 8.5, 9.5, 9.5])
        continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])

        # These tests should pass with precision = 0.01, but that makes tests very expensive.
        # Instead, we test with precision = 0.1 and only test with higher precision locally
        # when adding a new KL implementation.
        # The following pairs are not tested due to very high variance of the monte carlo
        # estimator; their implementations have been reviewed with extra care:
        # - (pareto, normal)
        self.precision = 0.1  # Set this to 0.01 when testing a new KL implementation.
        self.max_samples = int(1e07)  # Increase this when testing at smaller precision.
        self.samples_per_batch = int(1e04)

        # Each entry is (p_pair, q_pair); the tests use p_pair[0] as p and q_pair[1] as q.
        self.finite_examples = [
            (bernoulli, bernoulli),
            (bernoulli, poisson),
            (beta, beta),
            (beta, chi2),
            (beta, exponential),
            (beta, gamma),
            (beta, normal),
            (binomial30, binomial30),
            (binomial_vectorized_count, binomial_vectorized_count),
            (categorical, categorical),
            (cauchy, cauchy),
            (chi2, chi2),
            (chi2, exponential),
            (chi2, gamma),
            (chi2, normal),
            (dirichlet, dirichlet),
            (exponential, chi2),
            (exponential, exponential),
            (exponential, gamma),
            (exponential, gumbel),
            (exponential, normal),
            (gamma, chi2),
            (gamma, exponential),
            (gamma, gamma),
            (gamma, gumbel),
            (gamma, normal),
            (gumbel, gumbel),
            (gumbel, normal),
            (halfnormal, halfnormal),
            (independent, independent),
            (laplace, laplace),
            (lognormal, lognormal),
            (laplace, normal),
            (normal, gumbel),
            (normal, laplace),
            (normal, normal),
            (onehotcategorical, onehotcategorical),
            (pareto, chi2),
            (pareto, pareto),
            (pareto, exponential),
            (pareto, gamma),
            (poisson, poisson),
            (uniform_within_unit, beta),
            (uniform_positive, chi2),
            (uniform_positive, exponential),
            (uniform_positive, gamma),
            (uniform_real, gumbel),
            (uniform_real, normal),
            (uniform_pareto, pareto),
            (continuous_bernoulli, continuous_bernoulli),
            (continuous_bernoulli, exponential),
            (continuous_bernoulli, normal),
            (beta, continuous_bernoulli)
        ]

        # (p, q) pairs for which test_kl_infinite requires kl_divergence(p, q) == inf everywhere.
        self.infinite_examples = [
            (Bernoulli(0), Bernoulli(1)),
            (Bernoulli(1), Bernoulli(0)),
            (Categorical(torch.tensor([0.9, 0.1])), Categorical(torch.tensor([1., 0.]))),
            (Categorical(torch.tensor([[0.9, 0.1], [.9, .1]])), Categorical(torch.tensor([1., 0.]))),
            (Beta(1, 2), Uniform(0.25, 1)),
            (Beta(1, 2), Uniform(0, 0.75)),
            (Beta(1, 2), Uniform(0.25, 0.75)),
            (Beta(1, 2), Pareto(1, 2)),
            (Binomial(31, 0.7), Binomial(30, 0.3)),
            (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
             Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8]))),
            (Chi2(1), Beta(2, 3)),
            (Chi2(1), Pareto(2, 3)),
            (Chi2(1), Uniform(-2, 3)),
            (Exponential(1), Beta(2, 3)),
            (Exponential(1), Pareto(2, 3)),
            (Exponential(1), Uniform(-2, 3)),
            (Gamma(1, 2), Beta(3, 4)),
            (Gamma(1, 2), Pareto(3, 4)),
            (Gamma(1, 2), Uniform(-3, 4)),
            (Gumbel(-1, 2), Beta(3, 4)),
            (Gumbel(-1, 2), Chi2(3)),
            (Gumbel(-1, 2), Exponential(3)),
            (Gumbel(-1, 2), Gamma(3, 4)),
            (Gumbel(-1, 2), Pareto(3, 4)),
            (Gumbel(-1, 2), Uniform(-3, 4)),
            (Laplace(-1, 2), Beta(3, 4)),
            (Laplace(-1, 2), Chi2(3)),
            (Laplace(-1, 2), Exponential(3)),
            (Laplace(-1, 2), Gamma(3, 4)),
            (Laplace(-1, 2), Pareto(3, 4)),
            (Laplace(-1, 2), Uniform(-3, 4)),
            (Normal(-1, 2), Beta(3, 4)),
            (Normal(-1, 2), Chi2(3)),
            (Normal(-1, 2), Exponential(3)),
            (Normal(-1, 2), Gamma(3, 4)),
            (Normal(-1, 2), Pareto(3, 4)),
            (Normal(-1, 2), Uniform(-3, 4)),
            (Pareto(2, 1), Chi2(3)),
            (Pareto(2, 1), Exponential(3)),
            (Pareto(2, 1), Gamma(3, 4)),
            (Pareto(1, 2), Normal(-3, 4)),
            (Pareto(1, 2), Pareto(3, 4)),
            (Poisson(2), Bernoulli(0.5)),
            (Poisson(2.3), Binomial(10, 0.2)),
            (Uniform(-1, 1), Beta(2, 2)),
            (Uniform(0, 2), Beta(3, 4)),
            (Uniform(-1, 2), Beta(3, 4)),
            (Uniform(-1, 2), Chi2(3)),
            (Uniform(-1, 2), Exponential(3)),
            (Uniform(-1, 2), Gamma(3, 4)),
            (Uniform(-1, 2), Pareto(3, 4)),
            (ContinuousBernoulli(0.25), Uniform(0.25, 1)),
            (ContinuousBernoulli(0.25), Uniform(0, 0.75)),
            (ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
            (ContinuousBernoulli(0.25), Pareto(1, 2)),
            (Exponential(1), ContinuousBernoulli(0.75)),
            (Gamma(1, 2), ContinuousBernoulli(0.75)),
            (Gumbel(-1, 2), ContinuousBernoulli(0.75)),
            (Laplace(-1, 2), ContinuousBernoulli(0.75)),
            (Normal(-1, 2), ContinuousBernoulli(0.75)),
            (Uniform(-1, 1), ContinuousBernoulli(0.75)),
            (Uniform(0, 2), ContinuousBernoulli(0.75)),
            (Uniform(-1, 2), ContinuousBernoulli(0.75))
        ]

    def _assert_kl_matches_monte_carlo(self, p, q, description):
        """Assert that the analytic ``kl_divergence(p, q)`` agrees with a Monte Carlo estimate.

        Samples are drawn from ``p`` in batches of ``self.samples_per_batch`` and the
        running mean of ``p.log_prob(x) - q.log_prob(x)`` is used as the estimate.
        Sampling stops as soon as the largest non-NaN relative error falls below
        ``self.precision`` or once ``self.max_samples`` samples have been drawn.

        Args:
            p: distribution that is sampled from (first argument of the KL).
            q: second argument of the KL.
            description: first line of the failure message.
        """
        actual = kl_divergence(p, q)
        numerator = 0
        denominator = 0
        while denominator < self.max_samples:
            x = p.sample(sample_shape=(self.samples_per_batch,))
            numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
            denominator += x.size(0)
            expected = numerator / denominator
            error = torch.abs(expected - actual) / (1 + expected)
            # `error == error` is False only for NaN entries, so NaNs are excluded here.
            max_error = error[error == error].max()
            if max_error < self.precision:
                break
        self.assertLess(max_error, self.precision, '\n'.join([
            description,
            'Expected ({} Monte Carlo samples): {}'.format(denominator, expected),
            'Actual (analytic): {}'.format(actual),
        ]))

    def test_kl_monte_carlo(self):
        """Compare every finite example's analytic KL with a Monte Carlo estimate."""
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for (p, _), (_, q) in self.finite_examples:
            self._assert_kl_matches_monte_carlo(
                p, q, 'Incorrect KL({}, {}).'.format(type(p).__name__, type(q).__name__))

    # Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
    # positive (semi) definite matrices. n is set to 5, but can be increased during testing.
    def test_kl_multivariate_normal(self):
        """Compare analytic KL between random MultivariateNormals with a Monte Carlo estimate."""
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        n = 5  # Number of tests for multivariate_normal
        for i in range(n):
            loc = [torch.randn(4) for _ in range(2)]
            scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(4, 4)) for _ in range(2)]
            p = MultivariateNormal(loc=loc[0], scale_tril=scale_tril[0])
            q = MultivariateNormal(loc=loc[1], scale_tril=scale_tril[1])
            self._assert_kl_matches_monte_carlo(
                p, q,
                'Incorrect KL(MultivariateNormal, MultivariateNormal) instance {}/{}'.format(i + 1, n))

    def test_kl_multivariate_normal_batched(self):
        """Batched KL must equal the stack of KLs computed one batch element at a time."""
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(2)]
        scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)) for _ in range(2)]
        expected_kl = torch.stack([
            kl_divergence(MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
                          MultivariateNormal(loc[1][i], scale_tril=scale_tril[1][i])) for i in range(b)])
        actual_kl = kl_divergence(MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
                                  MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_multivariate_normal_batched_broadcasted(self):
        """Same as the batched test, but q uses a single unbatched ``scale_tril``."""
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(2)]
        scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
                      transform_to(constraints.lower_cholesky)(torch.randn(3, 3))]
        expected_kl = torch.stack([
            kl_divergence(MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
                          MultivariateNormal(loc[1][i], scale_tril=scale_tril[1])) for i in range(b)])
        actual_kl = kl_divergence(MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
                                  MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_lowrank_multivariate_normal(self):
        """KL involving LowRankMultivariateNormal must match the dense MultivariateNormal KL."""
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        n = 5  # Number of tests for lowrank_multivariate_normal
        for i in range(n):
            loc = [torch.randn(4) for _ in range(2)]
            cov_factor = [torch.randn(4, 3) for _ in range(2)]
            cov_diag = [transform_to(constraints.positive)(torch.randn(4)) for _ in range(2)]
            # Dense covariance (factor @ factor.T + diag) built from each low-rank parameterization.
            covariance_matrix = [factor.matmul(factor.t()) + diag.diag()
                                 for factor, diag in zip(cov_factor, cov_diag)]
            p = LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0])
            q = LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1])
            p_full = MultivariateNormal(loc[0], covariance_matrix[0])
            q_full = MultivariateNormal(loc[1], covariance_matrix[1])
            expected = kl_divergence(p_full, q_full)

            # All three mixed low-rank/dense combinations are computed before any is checked.
            candidates = [
                ('LowRankMultivariateNormal', 'LowRankMultivariateNormal', kl_divergence(p, q)),
                ('LowRankMultivariateNormal', 'MultivariateNormal', kl_divergence(p, q_full)),
                ('MultivariateNormal', 'LowRankMultivariateNormal', kl_divergence(p_full, q)),
            ]
            for p_name, q_name, actual in candidates:
                error = torch.abs(actual - expected).max()
                self.assertLess(error, self.precision, '\n'.join([
                    'Incorrect KL({}, {}) instance {}/{}'.format(p_name, q_name, i + 1, n),
                    'Expected (from KL MultivariateNormal): {}'.format(expected),
                    'Actual (analytic): {}'.format(actual),
                ]))

    def test_kl_lowrank_multivariate_normal_batched(self):
        """Batched low-rank KL must equal the stack of per-element low-rank KLs."""
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(2)]
        cov_factor = [torch.randn(b, 3, 2) for _ in range(2)]
        cov_diag = [transform_to(constraints.positive)(torch.randn(b, 3)) for _ in range(2)]
        expected_kl = torch.stack([
            kl_divergence(LowRankMultivariateNormal(loc[0][i], cov_factor[0][i], cov_diag[0][i]),
                          LowRankMultivariateNormal(loc[1][i], cov_factor[1][i], cov_diag[1][i]))
            for i in range(b)])
        actual_kl = kl_divergence(LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]),
                                  LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1]))
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_exponential_family(self):
        """For same-type exponential-family pairs, compare with ``_kl_expfamily_expfamily``."""
        for (p, _), (_, q) in self.finite_examples:
            if type(p) is not type(q) or not isinstance(p, ExponentialFamily):
                continue
            actual = kl_divergence(p, q)
            expected = _kl_expfamily_expfamily(p, q)
            self.assertEqual(actual, expected, msg='\n'.join([
                'Incorrect KL({}, {}).'.format(type(p).__name__, type(q).__name__),
                'Expected (using Bregman Divergence) {}'.format(expected),
                'Actual (analytic) {}'.format(actual),
                'max error = {}'.format(torch.abs(actual - expected).max())
            ]))

    def test_kl_infinite(self):
        """Every pair in ``self.infinite_examples`` must have KL equal to ``inf`` everywhere."""
        for p, q in self.infinite_examples:
            self.assertTrue((kl_divergence(p, q) == inf).all(),
                            'Incorrect KL({}, {})'.format(type(p).__name__, type(q).__name__))

    def test_kl_edgecases(self):
        """KL between identical degenerate distributions must be exactly zero."""
        self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
        self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
        self.assertEqual(kl_divergence(Categorical(torch.tensor([0., 1.])), Categorical(torch.tensor([0., 1.]))), 0)

    def test_kl_shape(self):
        """``kl_divergence(dist, dist)`` must have the distribution's batch shape."""
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    kl = kl_divergence(dist, dist)
                except NotImplementedError:
                    # No KL registered for this pair of types; nothing to check.
                    continue
                expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
                self.assertEqual(kl.shape, expected_shape, msg='\n'.join([
                    '{} example {}/{}'.format(Dist.__name__, i + 1, len(params)),
                    'Expected {}'.format(expected_shape),
                    'Actual {}'.format(kl.shape),
                ]))

    def test_kl_transformed(self):
        """KL of an Independent normal and of its affine transform must both have shape (2,)."""
        # Regression test for https://github.com/pytorch/pytorch/issues/34859
        scale = torch.ones(2, 3)
        loc = torch.zeros(2, 3)
        normal = Normal(loc=loc, scale=scale)
        diag_normal = Independent(normal, reinterpreted_batch_ndims=1)
        trans_dist = TransformedDistribution(diag_normal, AffineTransform(loc=0., scale=2.))
        self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,))
        self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,))

    def test_entropy_monte_carlo(self):
        """Compare analytic ``entropy()`` with the mean negative log-probability of 60000 samples."""
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    actual = dist.entropy()
                except NotImplementedError:
                    continue
                x = dist.sample(sample_shape=(60000,))
                expected = -dist.log_prob(x).mean(0)
                # Entries whose estimate is +/-inf are overwritten with the analytic
                # value, which removes them from the comparison below.
                ignore = (expected == inf) | (expected == -inf)
                expected[ignore] = actual[ignore]
                self.assertEqual(actual, expected, atol=0.2, rtol=0, msg='\n'.join([
                    '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
                    'Expected (monte carlo) {}'.format(expected),
                    'Actual (analytic) {}'.format(actual),
                    'max error = {}'.format(torch.abs(actual - expected).max()),
                ]))

    def test_entropy_exponential_family(self):
        """For ExponentialFamily subclasses, compare ``entropy()`` with ``ExponentialFamily.entropy``."""
        for Dist, params in EXAMPLES:
            if not issubclass(Dist, ExponentialFamily):
                continue
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    actual = dist.entropy()
                except NotImplementedError:
                    continue
                try:
                    expected = ExponentialFamily.entropy(dist)
                except NotImplementedError:
                    continue
                self.assertEqual(actual, expected, msg='\n'.join([
                    '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
                    'Expected (Bregman Divergence) {}'.format(expected),
                    'Actual (analytic) {}'.format(actual),
                    'max error = {}'.format(torch.abs(actual - expected).max())
                ]))
class TestConstraints(DistributionsTestCase):
    """Check each EXAMPLES entry against its declared parameter and support constraints."""

    def test_params_constraints(self):
        """Every constructor argument must satisfy the matching ``arg_constraints`` entry."""
        normalize_probs_dists = (
            Categorical,
            Multinomial,
            OneHotCategorical,
            OneHotCategoricalStraightThrough,
            RelaxedOneHotCategorical
        )

        for Dist, params in EXAMPLES:
            rescale_probs = Dist in normalize_probs_dists
            total = len(params)
            for index, param in enumerate(params, start=1):
                dist = Dist(**param)
                for name, raw in param.items():
                    # Plain Python numbers are lifted to one-element tensors.
                    value = torch.tensor([raw]) if isinstance(raw, numbers.Number) else raw
                    if rescale_probs and name == 'probs':
                        # These distributions accept positive probs, but elsewhere we
                        # use a stricter constraint to the simplex.
                        value = value / value.sum(-1, True)
                    try:
                        constraint = dist.arg_constraints[name]
                    except KeyError:
                        continue  # ignore optional parameters

                    # Check param shape is compatible with distribution shape.
                    event_dim = constraint.event_dim
                    self.assertGreaterEqual(value.dim(), event_dim)
                    batch_dims = value.dim() - event_dim
                    # Raises if the two shapes cannot be broadcast together.
                    torch.broadcast_shapes(dist.batch_shape, value.shape[:batch_dims])

                    if is_dependent(constraint):
                        continue
                    message = '{} example {}/{} parameter {} = {}'.format(
                        Dist.__name__, index, total, name, value)
                    self.assertTrue(constraint.check(value).all(), msg=message)

    def test_support_constraints(self):
        """A sample of every example must lie inside the distribution's ``support``."""
        for Dist, params in EXAMPLES:
            self.assertIsInstance(Dist.support, Constraint)
            total = len(params)
            for index, param in enumerate(params, start=1):
                dist = Dist(**param)
                value = dist.sample()
                support = dist.support
                message = '{} example {}/{} sample = {}'.format(
                    Dist.__name__, index, total, value)
                # The constraint's event_dim must agree with the distribution's event rank.
                self.assertEqual(support.event_dim, len(dist.event_shape), msg=message)
                within = support.check(value)
                self.assertEqual(within.shape, dist.batch_shape, msg=message)
                self.assertTrue(within.all(), msg=message)
class TestNumericalStability(DistributionsTestCase):
    """Check ``log_prob`` values and their gradients at extreme or degenerate parameters."""

    def _test_pdf_score(self,
                        dist_class,
                        x,
                        expected_value,
                        probs=None,
                        logits=None,
                        expected_gradient=None,
                        atol=1e-5):
        """Compare ``log_prob(x)`` and its gradient w.r.t. the parameter with expected values.

        Args:
            dist_class: distribution class; called as ``dist_class(p)`` when ``probs`` is
                given, otherwise as ``dist_class(logits=p)``.
            x: value at which ``log_prob`` is evaluated.
            expected_value: expected ``log_prob(x)``.
            probs: parameter tensor; takes precedence over ``logits`` when both are given.
            logits: parameter tensor used when ``probs`` is None.
            expected_gradient: expected gradient of ``log_prob(x).sum()`` w.r.t. the
                parameter; the gradient is not checked when this is None.
            atol: absolute tolerance for both comparisons (``rtol`` is 0).

        Raises:
            ValueError: if neither ``probs`` nor ``logits`` is provided.
        """
        if probs is None and logits is None:
            raise ValueError('_test_pdf_score requires either `probs` or `logits`.')
        if probs is not None:
            p = probs.detach().requires_grad_()
            dist = dist_class(p)
        else:
            p = logits.detach().requires_grad_()
            dist = dist_class(logits=p)
        log_pdf = dist.log_prob(x)
        log_pdf.sum().backward()
        self.assertEqual(log_pdf,
                         expected_value,
                         atol=atol,
                         rtol=0,
                         msg='Incorrect value for tensor type: {}. Expected = {}, Actual = {}'
                         .format(type(x), expected_value, log_pdf))
        if expected_gradient is not None:
            self.assertEqual(p.grad,
                             expected_gradient,
                             atol=atol,
                             rtol=0,
                             msg='Incorrect gradient for tensor type: {}. Expected = {}, Actual = {}'
                             .format(type(x), expected_gradient, p.grad))

    def test_bernoulli_gradient(self):
        """Bernoulli log_prob and gradient at probabilities on or near 0 and 1."""
        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
            self._test_pdf_score(dist_class=Bernoulli,
                                 probs=tensor_type([0]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([0]),
                                 expected_gradient=tensor_type([0]))

            # With probs == 0 and x == 1 the expected log_prob is log(eps) of the dtype.
            self._test_pdf_score(dist_class=Bernoulli,
                                 probs=tensor_type([0]),
                                 x=tensor_type([1]),
                                 expected_value=tensor_type([torch.finfo(tensor_type([]).dtype).eps]).log(),
                                 expected_gradient=tensor_type([0]))

            self._test_pdf_score(dist_class=Bernoulli,
                                 probs=tensor_type([1e-4]),
                                 x=tensor_type([1]),
                                 expected_value=tensor_type([math.log(1e-4)]),
                                 expected_gradient=tensor_type([10000]))

            # Lower precision due to:
            # >>> 1 / (1 - torch.FloatTensor([0.9999]))
            #  9998.3408
            # [torch.FloatTensor of size 1]
            self._test_pdf_score(dist_class=Bernoulli,
                                 probs=tensor_type([1 - 1e-4]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([math.log(1e-4)]),
                                 expected_gradient=tensor_type([-10000]),
                                 atol=2)

            self._test_pdf_score(dist_class=Bernoulli,
                                 logits=tensor_type([math.log(9999)]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([math.log(1e-4)]),
                                 expected_gradient=tensor_type([-1]),
                                 atol=1e-3)

    def test_bernoulli_with_logits_underflow(self):
        """Very negative logits with x == 0 must give log_prob 0 and gradient 0."""
        for tensor_type, lim in [(torch.FloatTensor, -1e38),
                                 (torch.DoubleTensor, -1e308)]:
            self._test_pdf_score(dist_class=Bernoulli,
                                 logits=tensor_type([lim]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([0]),
                                 expected_gradient=tensor_type([0]))

    def test_bernoulli_with_logits_overflow(self):
        """Very positive logits with x == 1 must give log_prob 0 and gradient 0."""
        for tensor_type, lim in [(torch.FloatTensor, 1e38),
                                 (torch.DoubleTensor, 1e308)]:
            self._test_pdf_score(dist_class=Bernoulli,
                                 logits=tensor_type([lim]),
                                 x=tensor_type([1]),
                                 expected_value=tensor_type([0]),
                                 expected_gradient=tensor_type([0]))

    def test_categorical_log_prob(self):
        """OneHotCategorical with probs [0, 1] must give log_prob 0 for the outcome [0, 1]."""
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(p)
            log_pdf = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf.item(), 0)

    def test_categorical_log_prob_with_logits(self):
        """OneHotCategorical with logits [-inf, 0]: log_prob is 0 for [0, 1] and -inf for [1, 0]."""
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(logits=p)
            log_pdf_prob_1 = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = categorical.log_prob(torch.tensor([1, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf)

    def test_multinomial_log_prob(self):
        """Multinomial(10, probs=[0, 1]) must give log_prob 0 for the counts [0, 10]."""
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            s = torch.tensor([0, 10], dtype=dtype)
            multinomial = Multinomial(10, p)
            log_pdf = multinomial.log_prob(s)
            self.assertEqual(log_pdf.item(), 0)

    def test_multinomial_log_prob_with_logits(self):
        """Multinomial(10, logits=[-inf, 0]): log_prob is 0 for [0, 10] and -inf for [10, 0]."""
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            multinomial = Multinomial(10, logits=p)
            log_pdf_prob_1 = multinomial.log_prob(torch.tensor([0, 10], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf)

    def test_continuous_bernoulli_gradient(self):
        """ContinuousBernoulli log_prob and gradient against scalar reference formulas."""

        def expec_val(x, probs=None, logits=None):
            """Reference log-likelihood computed with ``math`` on Python floats."""
            assert not (probs is None and logits is None)
            if logits is not None:
                probs = 1. / (1. + math.exp(-logits))
            bern_log_lik = x * math.log(probs) + (1. - x) * math.log1p(-probs)
            if probs < 0.499 or probs > 0.501:  # using default values of lims here
                log_norm_const = math.log(
                    math.fabs(math.atanh(1. - 2. * probs))) - math.log(math.fabs(1. - 2. * probs)) + math.log(2.)
            else:
                # Polynomial in (probs - 0.5)**2 used inside the (0.499, 0.501) band.
                aux = math.pow(probs - 0.5, 2)
                log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
            log_lik = bern_log_lik + log_norm_const
            return log_lik

        def expec_grad(x, probs=None, logits=None):
            """Reference gradient of the log-likelihood w.r.t. ``probs`` (or ``logits`` if given)."""
            assert not (probs is None and logits is None)
            if logits is not None:
                probs = 1. / (1. + math.exp(-logits))
            grad_bern_log_lik = x / probs - (1. - x) / (1. - probs)
            if probs < 0.499 or probs > 0.501:  # using default values of lims here
                grad_log_c = 2. * probs - 4. * (probs - 1.) * probs * math.atanh(1. - 2. * probs) - 1.
                grad_log_c /= 2. * (probs - 1.) * probs * (2. * probs - 1.) * math.atanh(1. - 2. * probs)
            else:
                grad_log_c = 8. / 3. * (probs - 0.5) + 416. / 45. * math.pow(probs - 0.5, 3)
            grad = grad_bern_log_lik + grad_log_c
            if logits is not None:
                # Chain rule: multiply by d(probs)/d(logits) = e / (1 + e)**2 with e = exp(logits).
                grad *= 1. / (1. + math.exp(logits)) - 1. / math.pow(1. + math.exp(logits), 2)
            return grad

        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 probs=tensor_type([0.1]),
                                 x=tensor_type([0.1]),
                                 expected_value=tensor_type([expec_val(0.1, probs=0.1)]),
                                 expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]))

            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 probs=tensor_type([0.1]),
                                 x=tensor_type([1.]),
                                 expected_value=tensor_type([expec_val(1., probs=0.1)]),
                                 expected_gradient=tensor_type([expec_grad(1., probs=0.1)]))

            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 probs=tensor_type([0.4999]),
                                 x=tensor_type([0.9]),
                                 expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
                                 expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]))

            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 probs=tensor_type([1e-4]),
                                 x=tensor_type([1]),
                                 expected_value=tensor_type([expec_val(1, probs=1e-4)]),
                                 expected_gradient=tensor_type([expec_grad(1, probs=1e-4)]),
                                 atol=1e-3)

            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 probs=tensor_type([1 - 1e-4]),
                                 x=tensor_type([0.1]),
                                 expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
                                 expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
                                 atol=2)

            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 logits=tensor_type([math.log(9999)]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
                                 expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
                                 atol=1e-3)

            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 logits=tensor_type([0.001]),
                                 x=tensor_type([0.5]),
                                 expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
                                 expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]))

    def test_continuous_bernoulli_with_logits_underflow(self):
        """Very negative logits with x == 0: expected log_prob per dtype and zero gradient."""
        for tensor_type, lim, expected in [(torch.FloatTensor, -1e38, 2.76898),
                                           (torch.DoubleTensor, -1e308, 3.58473)]:
            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 logits=tensor_type([lim]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([expected]),
                                 expected_gradient=tensor_type([0.]))

    def test_continuous_bernoulli_with_logits_overflow(self):
        """Very positive logits with x == 1: expected log_prob per dtype and zero gradient."""
        for tensor_type, lim, expected in [(torch.FloatTensor, 1e38, 2.76898),
                                           (torch.DoubleTensor, 1e308, 3.58473)]:
            self._test_pdf_score(dist_class=ContinuousBernoulli,
                                 logits=tensor_type([lim]),
                                 x=tensor_type([1]),
                                 expected_value=tensor_type([expected]),
                                 expected_gradient=tensor_type([0.]))
# TODO: make this a pytest parameterized test
class TestLazyLogitsInitialization(DistributionsTestCase):
    """Check that ``probs`` / ``logits`` are not stored on the instance unless needed."""

    def setUp(self):
        """Select the EXAMPLES entries whose distribution class is covered by these tests."""
        super(TestLazyLogitsInitialization, self).setUp()
        # ContinuousBernoulli is not tested because log_prob is not computed simply
        # from 'logits', but 'probs' is also needed
        covered = (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)
        self.examples = [example for example in EXAMPLES if example.Dist in covered]

    def _assert_attribute_absent(self, dist, attribute, message):
        """Assert ``attribute`` is missing from ``dist.__dict__`` before and after read-only use.

        The check is made three times: immediately, after ``enumerate_support()``
        (a NotImplementedError from it is tolerated), and after reading
        ``batch_shape`` and ``event_shape``.
        """
        self.assertNotIn(attribute, dist.__dict__, msg=message)
        try:
            dist.enumerate_support()
        except NotImplementedError:
            pass
        self.assertNotIn(attribute, dist.__dict__, msg=message)
        batch_shape, event_shape = dist.batch_shape, dist.event_shape
        self.assertNotIn(attribute, dist.__dict__, msg=message)

    def test_lazy_logits_initialization(self):
        """A distribution built from ``logits`` must not gain a ``probs`` entry in ``__dict__``."""
        for Dist, params in self.examples:
            first = params[0]
            if 'probs' not in first:
                continue
            # Replace the 'probs' argument by the equivalent 'logits' argument.
            kwargs = first.copy()
            kwargs['logits'] = probs_to_logits(kwargs.pop('probs'))
            dist = Dist(**kwargs)
            # Create new instance to generate a valid sample
            sample = Dist(**kwargs).sample()
            dist.log_prob(sample)
            message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
            self._assert_attribute_absent(dist, 'probs', message)

    def test_lazy_probs_initialization(self):
        """A distribution built from ``probs`` must not gain a ``logits`` entry in ``__dict__``."""
        for Dist, params in self.examples:
            first = params[0]
            if 'probs' not in first:
                continue
            dist = Dist(**first.copy())
            dist.sample()
            message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
            self._assert_attribute_absent(dist, 'logits', message)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
class TestAgainstScipy(DistributionsTestCase):
    """Compare moments, cdf and icdf of torch distributions with their scipy.stats counterparts."""

    def setUp(self):
        """Draw random parameters and build matching (torch distribution, scipy distribution) pairs."""
        super(TestAgainstScipy, self).setUp()
        positive_var = torch.randn(20).exp()
        positive_var2 = torch.randn(20).exp()
        random_var = torch.randn(20)
        simplex_tensor = softmax(torch.randn(20), dim=-1)
        cov_tensor = torch.randn(20, 20)
        cov_tensor = cov_tensor @ cov_tensor.mT

        # Tests fail 1e-5 threshold if scale > 3
        lognormal_scale = positive_var.clamp(max=3)
        diag_cov = torch.diag(positive_var2)
        # SciPy allowed ndim - 1 < df < ndim for the Wishart distribution after version 1.7.0
        wishart_df_base = 20 if version.parse(scipy.__version__) < version.parse("1.7.0") else 19

        self.distribution_pairs = [
            (Bernoulli(simplex_tensor),
             scipy.stats.bernoulli(simplex_tensor)),
            (Beta(positive_var, positive_var2),
             scipy.stats.beta(positive_var, positive_var2)),
            (Binomial(10, simplex_tensor),
             scipy.stats.binom(10 * np.ones(simplex_tensor.shape), simplex_tensor.numpy())),
            (Cauchy(random_var, positive_var),
             scipy.stats.cauchy(loc=random_var, scale=positive_var)),
            (Dirichlet(positive_var),
             scipy.stats.dirichlet(positive_var)),
            (Exponential(positive_var),
             scipy.stats.expon(scale=positive_var.reciprocal())),
            # var for df2<=4 is undefined
            (FisherSnedecor(positive_var, 4 + positive_var2),
             scipy.stats.f(positive_var, 4 + positive_var2)),
            (Gamma(positive_var, positive_var2),
             scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal())),
            (Geometric(simplex_tensor),
             scipy.stats.geom(simplex_tensor, loc=-1)),
            (Gumbel(random_var, positive_var2),
             scipy.stats.gumbel_r(random_var, positive_var2)),
            (HalfCauchy(positive_var),
             scipy.stats.halfcauchy(scale=positive_var)),
            (HalfNormal(positive_var2),
             scipy.stats.halfnorm(scale=positive_var2)),
            (Laplace(random_var, positive_var2),
             scipy.stats.laplace(random_var, positive_var2)),
            (LogNormal(random_var, lognormal_scale),
             scipy.stats.lognorm(s=lognormal_scale, scale=random_var.exp())),
            (LowRankMultivariateNormal(random_var, torch.zeros(20, 1), positive_var2),
             scipy.stats.multivariate_normal(random_var, diag_cov)),
            (Multinomial(10, simplex_tensor),
             scipy.stats.multinomial(10, simplex_tensor)),
            (MultivariateNormal(random_var, diag_cov),
             scipy.stats.multivariate_normal(random_var, diag_cov)),
            (MultivariateNormal(random_var, cov_tensor),
             scipy.stats.multivariate_normal(random_var, cov_tensor)),
            (Normal(random_var, positive_var2),
             scipy.stats.norm(random_var, positive_var2)),
            (OneHotCategorical(simplex_tensor),
             scipy.stats.multinomial(1, simplex_tensor)),
            (Pareto(positive_var, 2 + positive_var2),
             scipy.stats.pareto(2 + positive_var2, scale=positive_var)),
            (Poisson(positive_var),
             scipy.stats.poisson(positive_var)),
            (StudentT(2 + positive_var, random_var, positive_var2),
             scipy.stats.t(2 + positive_var, random_var, positive_var2)),
            (Uniform(random_var, random_var + positive_var),
             scipy.stats.uniform(random_var, positive_var)),
            (VonMises(random_var, positive_var),
             scipy.stats.vonmises(positive_var, loc=random_var)),
            # scipy var for Weibull only supports scalars
            (Weibull(positive_var[0], positive_var2[0]),
             scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])),
            # scipy var for Wishart only supports scalars
            (Wishart(wishart_df_base + positive_var[0], cov_tensor),
             scipy.stats.wishart(wishart_df_base + positive_var[0].item(), cov_tensor)),
        ]

    def test_mean(self):
        """Means must agree with scipy for every pair except Cauchy and HalfCauchy."""
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
                # Cauchy, HalfCauchy distributions' mean is nan, skipping check
                continue
            actual = pytorch_dist.mean
            if isinstance(pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)):
                # The frozen multivariate normal exposes `mean` as an attribute, not a method.
                reference = scipy_dist.mean
            else:
                reference = scipy_dist.mean()
            self.assertEqual(actual, reference, msg=pytorch_dist)

    def test_variance_stddev(self):
        """Variance and stddev must agree with scipy (stddev is checked as sqrt of scipy's variance)."""
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            if isinstance(pytorch_dist, (Cauchy, HalfCauchy, VonMises)):
                # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
                # VonMises variance is circular and scipy doesn't produce a correct result
                continue
            if isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                # Reference variance is the diagonal of scipy's covariance matrix.
                reference_var = np.diag(scipy_dist.cov())
            elif isinstance(pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)):
                # Here `cov` is an attribute of the frozen scipy distribution.
                reference_var = np.diag(scipy_dist.cov)
            else:
                reference_var = scipy_dist.var()
            self.assertEqual(pytorch_dist.variance, reference_var, msg=pytorch_dist)
            self.assertEqual(pytorch_dist.stddev, reference_var ** 0.5, msg=pytorch_dist)

    def test_cdf(self):
        """Where ``cdf`` is implemented, it must agree with scipy on five samples of the distribution."""
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            samples = pytorch_dist.sample((5,))
            try:
                cdf = pytorch_dist.cdf(samples)
            except NotImplementedError:
                continue
            else:
                self.assertEqual(cdf, scipy_dist.cdf(samples), msg=pytorch_dist)

    def test_icdf(self):
        """Where ``icdf`` is implemented, it must agree with scipy's ``ppf`` on uniform random inputs."""
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            samples = torch.rand((5,) + pytorch_dist.batch_shape)
            try:
                icdf = pytorch_dist.icdf(samples)
            except NotImplementedError:
                continue
            else:
                self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)
class TestFunctors(DistributionsTestCase):
    """Tests for the composite transforms ``CatTransform`` and ``StackTransform``.

    Each test builds a composite from simple component transforms and checks
    that the composite's domain/codomain checks, forward map, inverse map and
    ``log_abs_det_jacobian`` agree with applying the components slice by slice.
    """

    def test_cat_transform(self):
        """CatTransform with one row per component matches the per-row results."""
        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
        x3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
        dim = 0
        x = torch.cat([x1, x2, x3], dim=dim)
        t = CatTransform([t1, t2, t3], dim=dim)

        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.cat([t1.domain.check(x1),
                                        t2.domain.check(x2),
                                        t3.domain.check(x3)], dim=dim)
        self.assertEqual(expected_dom_check, actual_dom_check)

        actual = t(x)
        expected = torch.cat([t1(x1), t2(x2), t3(x3)], dim=dim)
        self.assertEqual(expected, actual)

        # Separate, strictly positive points are used on the codomain side.
        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y2 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y = torch.cat([y1, y2, y3], dim=dim)

        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.cat([t1.codomain.check(y1),
                                        t2.codomain.check(y2),
                                        t3.codomain.check(y3)], dim=dim)
        self.assertEqual(actual_cod_check, expected_cod_check)

        actual_inv = t.inv(y)
        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2), t3.inv(y3)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)

        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.cat([t1.log_abs_det_jacobian(x1, y1),
                                  t2.log_abs_det_jacobian(x2, y2),
                                  t3.log_abs_det_jacobian(x3, y3)], dim=dim)
        self.assertEqual(actual_jac, expected_jac)

    def test_cat_transform_non_uniform(self):
        """CatTransform with explicit ``lengths=[1, 2]`` and a nested CatTransform."""
        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        x2 = torch.cat([(torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100,
                        torch.arange(1, 101, dtype=torch.float).view(-1, 100)])
        t1 = ExpTransform()
        # The second component is itself a CatTransform covering two rows.
        t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
        dim = 0
        x = torch.cat([x1, x2], dim=dim)
        t = CatTransform([t1, t2], dim=dim, lengths=[1, 2])

        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.cat([t1.domain.check(x1),
                                        t2.domain.check(x2)], dim=dim)
        self.assertEqual(expected_dom_check, actual_dom_check)

        actual = t(x)
        expected = torch.cat([t1(x1), t2(x2)], dim=dim)
        self.assertEqual(expected, actual)

        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y2 = torch.cat([torch.arange(1, 101, dtype=torch.float).view(-1, 100),
                        torch.arange(1, 101, dtype=torch.float).view(-1, 100)])
        y = torch.cat([y1, y2], dim=dim)

        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.cat([t1.codomain.check(y1),
                                        t2.codomain.check(y2)], dim=dim)
        self.assertEqual(actual_cod_check, expected_cod_check)

        actual_inv = t.inv(y)
        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)

        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.cat([t1.log_abs_det_jacobian(x1, y1),
                                  t2.log_abs_det_jacobian(x2, y2)], dim=dim)
        self.assertEqual(actual_jac, expected_jac)

    def test_cat_event_dim(self):
        """CatTransform over ``event_dim=1`` components sums their Jacobians.

        Both components are built with ``event_dim=1`` and are concatenated
        along ``dim=1``; the expected log-abs-det-Jacobian is the sum of the
        two component Jacobians.
        """
        t1 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
        t2 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
        dim = 1
        bs = 16
        x1 = torch.randn(bs, 2)
        x2 = torch.randn(bs, 2)
        x = torch.cat([x1, x2], dim=dim)
        t = CatTransform([t1, t2], dim=dim, lengths=[2, 2])
        y1 = t1(x1)
        y2 = t2(x2)
        y = t(x)

        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = sum([t1.log_abs_det_jacobian(x1, y1),
                            t2.log_abs_det_jacobian(x2, y2)])
        # Previously both values were computed but never compared, so the
        # test could not fail on a wrong Jacobian.
        self.assertEqual(actual_jac, expected_jac)

    def test_stack_transform(self):
        """StackTransform matches stacking the per-slice component results."""
        x1 = -1 * torch.arange(1, 101, dtype=torch.float)
        x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100
        x3 = torch.arange(1, 101, dtype=torch.float)
        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
        dim = 0
        x = torch.stack([x1, x2, x3], dim=dim)
        t = StackTransform([t1, t2, t3], dim=dim)

        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.stack([t1.domain.check(x1),
                                          t2.domain.check(x2),
                                          t3.domain.check(x3)], dim=dim)
        self.assertEqual(expected_dom_check, actual_dom_check)

        actual = t(x)
        expected = torch.stack([t1(x1), t2(x2), t3(x3)], dim=dim)
        self.assertEqual(expected, actual)

        y1 = torch.arange(1, 101, dtype=torch.float)
        y2 = torch.arange(1, 101, dtype=torch.float)
        y3 = torch.arange(1, 101, dtype=torch.float)
        y = torch.stack([y1, y2, y3], dim=dim)

        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.stack([t1.codomain.check(y1),
                                          t2.codomain.check(y2),
                                          t3.codomain.check(y3)], dim=dim)
        self.assertEqual(actual_cod_check, expected_cod_check)

        # Invert codomain points (y), as the CatTransform tests do. Inverting
        # x would feed negative values to ExpTransform.inv and reduce that
        # slice to a NaN-vs-NaN comparison.
        actual_inv = t.inv(y)
        expected_inv = torch.stack([t1.inv(y1), t2.inv(y2), t3.inv(y3)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)

        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.stack([t1.log_abs_det_jacobian(x1, y1),
                                    t2.log_abs_det_jacobian(x2, y2),
                                    t3.log_abs_det_jacobian(x3, y3)], dim=dim)
        self.assertEqual(actual_jac, expected_jac)
class TestValidation(DistributionsTestCase):
    """Tests for argument validation and sample (support) validation."""

    def setUp(self):
        super().setUp()

    def test_valid(self):
        """Every EXAMPLES parameter set must construct with validation on."""
        for Dist, param_sets in EXAMPLES:
            for kwargs in param_sets:
                Dist(validate_args=True, **kwargs)

    def test_invalid_log_probs_arg(self):
        """Probe ``log_prob`` with mis-shaped, out-of-support and NaN values."""
        # Check that validation errors are indeed disabled,
        # but they might raise another error
        for Dist, param_sets in EXAMPLES:
            if Dist == TransformedDistribution:
                # TransformedDistribution has a distribution instance
                # as the argument, so we cannot do much about that
                continue
            total = len(param_sets)
            for index, kwargs in enumerate(param_sets):
                unchecked = Dist(validate_args=False, **kwargs)
                checked = Dist(validate_args=True, **kwargs)
                for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
                    # samples with incorrect shape must throw ValueError only
                    try:
                        checked.log_prob(v)
                    except ValueError:
                        pass

                    # get sample of correct shape
                    filled = torch.full(checked.batch_shape + checked.event_shape, v)
                    # check samples with incorrect support
                    try:
                        checked.log_prob(filled)
                    except ValueError as err:
                        support_error = bool(err.args) and 'must be within the support' in err.args[0]
                        if support_error:
                            # Without validation only a RuntimeError is tolerated.
                            try:
                                unchecked.log_prob(filled)
                            except RuntimeError:
                                pass

                    # check correct samples are ok
                    good = checked.sample()
                    checked.log_prob(good)

                    # check invalid values raise ValueError
                    if good.dtype == torch.long:
                        # Integer tensors cannot hold NaN, so switch to float first.
                        good = good.float()
                    bad = torch.full_like(good, math.nan)
                    try:
                        with self.assertRaisesRegex(
                            ValueError,
                            "Expected value argument .* to be within the support .*",
                        ):
                            checked.log_prob(bad)
                    except AssertionError as e:
                        fail_string = "Support ValueError not raised for {} example {}/{}"
                        raise AssertionError(
                            fail_string.format(Dist.__name__, index + 1, total)
                        ) from e

    @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
    def test_invalid(self):
        """Every BAD_EXAMPLES parameter set must raise ValueError when validated."""
        for Dist, param_sets in BAD_EXAMPLES:
            total = len(param_sets)
            for index, kwargs in enumerate(param_sets):
                try:
                    with self.assertRaises(ValueError):
                        Dist(validate_args=True, **kwargs)
                except AssertionError as e:
                    # Re-raise naming the distribution and the example number.
                    fail_string = "ValueError not raised for {} example {}/{}"
                    raise AssertionError(
                        fail_string.format(Dist.__name__, index + 1, total)
                    ) from e

    def test_warning_unimplemented_constraints(self):
        """Expect a UserWarning from a distribution that declares no constraints.

        ``Delta`` below defines neither ``arg_constraints`` nor ``support``;
        constructing it with validation enabled and calling ``log_prob`` are
        each expected to emit a ``UserWarning``.
        """
        class Delta(Distribution):
            def __init__(self, validate_args=True):
                super().__init__(validate_args=validate_args)

            def sample(self, sample_shape=torch.Size()):
                return torch.tensor(0.).expand(sample_shape)

            def log_prob(self, value):
                if self._validate_args:
                    self._validate_sample(value)
                # Overwrites ``value`` in place: -inf away from zero, 0 at zero.
                value[value != 0.] = -float('inf')
                value[value == 0.] = 0.
                return value

        with self.assertWarns(UserWarning):
            delta = Delta()
        drawn = delta.sample((2,))
        with self.assertWarns(UserWarning):
            delta.log_prob(drawn)

    def tearDown(self):
        super().tearDown()
class TestJit(DistributionsTestCase):
    """Check that ``torch.jit.trace`` of distribution methods matches eager mode.

    Most tests trace a small closure on one set of parameters, perturb the
    parameters, and then compare the traced and the eager result on the
    perturbed inputs.
    """

    def _examples(self):
        """Yield ``(Dist, keys, values, sample)`` for the all-tensor EXAMPLES.

        Parameter sets that contain any non-tensor value are skipped.
        ``sample`` is a draw from the distribution built with those parameters.
        """
        for Dist, params in EXAMPLES:
            for param in params:
                keys = param.keys()
                values = tuple(param[key] for key in keys)
                if not all(isinstance(x, torch.Tensor) for x in values):
                    continue
                sample = Dist(**param).sample()
                yield Dist, keys, values, sample

    def _perturb_tensor(self, value, constraint):
        """Return a modified copy of ``value`` chosen according to ``constraint``.

        Raises:
            NotImplementedError: if no rule applies to the constraint/dtype.
        """
        if isinstance(constraint, constraints._IntegerGreaterThan):
            return value + 1
        if isinstance(constraint, (constraints._PositiveDefinite, constraints._PositiveSemidefinite)):
            return value + torch.eye(value.shape[-1])
        if value.dtype in [torch.float, torch.double]:
            # Add Gaussian noise in the unconstrained space and map back
            # through the transform registered for the constraint.
            transform = transform_to(constraint)
            delta = value.new(value.shape).normal_()
            return transform(transform.inv(value) + delta)
        if value.dtype == torch.long:
            # Swap zeros and ones; masks are taken from the original tensor.
            result = value.clone()
            result[value == 0] = 1
            result[value == 1] = 0
            return result
        raise NotImplementedError

    def _perturb(self, Dist, keys, values, sample):
        """Return perturbed parameter values and a fresh sample drawn from them."""
        with torch.no_grad():
            if Dist is Uniform:
                # Widen the interval: lower ``low`` and raise ``high``.
                param = dict(zip(keys, values))
                param['low'] = param['low'] - torch.rand(param['low'].shape)
                param['high'] = param['high'] + torch.rand(param['high'].shape)
                values = [param[key] for key in keys]
            else:
                values = [self._perturb_tensor(value, Dist.arg_constraints.get(key, constraints.real))
                          for key, value in zip(keys, values)]
            param = dict(zip(keys, values))
            sample = Dist(**param).sample()
            return values, sample

    def test_sample(self):
        """Traced ``sample`` matches eager ``sample`` from the same RNG state."""
        for Dist, keys, values, sample in self._examples():
            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.sample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
                VonMises  # Variance is not Euclidean
            ]
            if Dist in xfail:
                continue

            # fork_rng restores the RNG state on exit, so the traced call
            # below starts from the same state as the eager call.
            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(any(n.isNondeterministic() for n in traced_f.graph.nodes()))

    def test_rsample(self):
        """Traced ``rsample`` matches eager ``rsample`` from the same RNG state."""
        for Dist, keys, values, sample in self._examples():
            if not Dist.has_rsample:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.rsample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
            ]
            if Dist in xfail:
                continue

            # Same RNG-state trick as in test_sample.
            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(any(n.isNondeterministic() for n in traced_f.graph.nodes()))

    def test_log_prob(self):
        """Traced ``log_prob`` matches eager ``log_prob`` on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.log_prob(sample)

            traced_f = torch.jit.trace(f, (sample,) + values)

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(expected, actual,
                             msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_enumerate_support(self):
        """Traced ``enumerate_support`` matches eager mode on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [Binomial]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.enumerate_support()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(expected, actual,
                             msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_mean(self):
        """Traced ``mean`` matches eager ``mean`` on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():
            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.mean

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            # Clone before masking (as test_variance does) so the in-place
            # writes below never touch a tensor that may alias the inputs.
            expected = f(*values).clone()
            actual = traced_f(*values).clone()
            # Infinite entries are zeroed on both sides before comparing.
            expected[expected == float('inf')] = 0.
            actual[actual == float('inf')] = 0.
            self.assertEqual(expected, actual,
                             msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_variance(self):
        """Traced ``variance`` matches eager ``variance`` on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():
            if Dist in [Cauchy, HalfCauchy]:
                continue  # infinite variance

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.variance

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values).clone()
            actual = traced_f(*values).clone()
            # Infinite entries are zeroed on both sides before comparing.
            expected[expected == float('inf')] = 0.
            actual[actual == float('inf')] = 0.
            self.assertEqual(expected, actual,
                             msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_entropy(self):
        """Traced ``entropy`` matches eager ``entropy`` on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.entropy()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(expected, actual,
                             msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_cdf(self):
        """Traced ``icdf(cdf(sample))`` round trip matches eager mode."""
        for Dist, keys, values, sample in self._examples():
            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                cdf = dist.cdf(sample)
                return dist.icdf(cdf)

            try:
                traced_f = torch.jit.trace(f, (sample,) + values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(expected, actual,
                             msg='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
# Run the suite only when executed as a script and only if this torch build
# reports LAPACK support.
if __name__ == '__main__':
    if torch._C.has_lapack:
        run_tests()