from typing import Dict

import torch
from torch.distributions import Categorical, constraints
from torch.distributions.distribution import Distribution

__all__ = ["MixtureSameFamily"]


class MixtureSameFamily(Distribution):
| r""" |
| The `MixtureSameFamily` distribution implements a (batch of) mixture |
| distribution where all component are from different parameterizations of |
| the same distribution type. It is parameterized by a `Categorical` |
| "selecting distribution" (over `k` component) and a component |
| distribution, i.e., a `Distribution` with a rightmost batch shape |
| (equal to `[k]`) which indexes each (batch of) component. |
| |
    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...     torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 randomly weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...     torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)
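
        >>> # Illustrative usage (shapes assume the batch of 3 GMMs above):
        >>> # log-density has shape [3]; a single draw has shape [3, 2]
        >>> gmm.log_prob(torch.zeros(3, 2))
        >>> gmm.sample()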

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`.
        component_distribution: `torch.distributions.Distribution`-like
            instance. Right-most batch dimension indexes components.
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                "The Mixture distribution needs to be an "
                "instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The Component distribution needs to be an "
                "instance of torch.distributions.Distribution"
            )

        # Check that batch size matches
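        # Trailing batch dims must be broadcasting-compatible: each pair of
        # sizes must be equal, or one of them must be 1.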
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
                    "compatible with `component_distribution."
                    f"batch_shape` ({cdbs})"
                )

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                f"`mixture_distribution` component count ({km}) does not"
                " equal `component_distribution.batch_shape[-1]`"
                f" ({kc})"
            )
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
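        # Expand both members: the component distribution keeps its extra
        # rightmost [k] dim, the mixture distribution gets the plain shape.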
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
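        # E[Y] = sum_k pi_k * E[Y | component k], reduced over the component
        # dim, which sits just left of the event dims.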
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
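        # with X the component index: `mean_cond_var` below is E[Var(Y|X)]
        # and `var_cond_mean` is Var(E[Y|X]).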
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
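        # Mixture CDF is the probability-weighted sum of component CDFs:
        # F(x) = sum_k pi_k * F_k(x)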
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs

        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        if self._validate_args:
            self._validate_sample(x)
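        # log p(x) = logsumexp_k(log pi_k + log p_k(x)), evaluated in a
        # numerically stable way via log_softmax + logsumexp.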
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # mixture samples [n, B]
            mix_sample = self.mixture_distribution.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1))
            )
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
            )

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
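            # The gathered k dim is now a singleton; drop it to get [n, B, E]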
            return samples.squeeze(gather_dim)

    def _pad(self, x):
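        # Insert a singleton component dim so x broadcasts against the
        # [B, k, E]-shaped component parameters.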
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
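        # Reshape mixture probs [B', k] to [B', 1, ..., 1, k, 1, ..., 1] so
        # they broadcast against component statistics of shape [B, k, E].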
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims
        xs = x.shape
        x = x.reshape(
            xs[:-1]
            + torch.Size(pad_ndims * [1])
            + xs[-1:]
            + torch.Size(self._event_ndims * [1])
        )
        return x

    def __repr__(self):
        args_string = (
            f"\n {self.mixture_distribution},\n {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"