Adding all project files
This commit is contained in:
parent
6c9e127bdc
commit
cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions
172
venv/Lib/site-packages/torch/distributions/__init__.py
Normal file
172
venv/Lib/site-packages/torch/distributions/__init__.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
r"""
|
||||
The ``distributions`` package contains parameterizable probability distributions
|
||||
and sampling functions. This allows the construction of stochastic computation
|
||||
graphs and stochastic gradient estimators for optimization. This package
|
||||
generally follows the design of the `TensorFlow Distributions`_ package.
|
||||
|
||||
.. _`TensorFlow Distributions`:
|
||||
https://arxiv.org/abs/1711.10604
|
||||
|
||||
It is not possible to directly backpropagate through random samples. However,
|
||||
there are two main methods for creating surrogate functions that can be
|
||||
backpropagated through. These are the score function estimator/likelihood ratio
|
||||
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
|
||||
seen as the basis for policy gradient methods in reinforcement learning, and the
|
||||
pathwise derivative estimator is commonly seen in the reparameterization trick
|
||||
in variational autoencoders. Whilst the score function only requires the value
|
||||
of samples :math:`f(x)`, the pathwise derivative requires the derivative
|
||||
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
|
||||
example. For more details see
|
||||
`Gradient Estimation Using Stochastic Computation Graphs`_ .
|
||||
|
||||
.. _`Gradient Estimation Using Stochastic Computation Graphs`:
|
||||
https://arxiv.org/abs/1506.05254
|
||||
|
||||
Score function
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
When the probability density function is differentiable with respect to its
|
||||
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
|
||||
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:
|
||||
|
||||
.. math::
|
||||
|
||||
\Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}
|
||||
|
||||
where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
|
||||
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
|
||||
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.
|
||||
|
||||
In practice we would sample an action from the output of a network, apply this
|
||||
action in an environment, and then use ``log_prob`` to construct an equivalent
|
||||
loss function. Note that we use a negative because optimizers use gradient
|
||||
descent, whilst the rule above assumes gradient ascent. With a categorical
|
||||
policy, the code for implementing REINFORCE would be as follows::
|
||||
|
||||
probs = policy_network(state)
|
||||
# Note that this is equivalent to what used to be called multinomial
|
||||
m = Categorical(probs)
|
||||
action = m.sample()
|
||||
next_state, reward = env.step(action)
|
||||
loss = -m.log_prob(action) * reward
|
||||
loss.backward()
|
||||
|
||||
Pathwise derivative
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
The other way to implement these stochastic/policy gradients would be to use the
|
||||
reparameterization trick from the
|
||||
:meth:`~torch.distributions.Distribution.rsample` method, where the
|
||||
parameterized random variable can be constructed via a parameterized
|
||||
deterministic function of a parameter-free random variable. The reparameterized
|
||||
sample therefore becomes differentiable. The code for implementing the pathwise
|
||||
derivative would be as follows::
|
||||
|
||||
params = policy_network(state)
|
||||
m = Normal(*params)
|
||||
# Any distribution with .has_rsample == True could work based on the application
|
||||
action = m.rsample()
|
||||
next_state, reward = env.step(action) # Assuming that reward is differentiable
|
||||
loss = -reward
|
||||
loss.backward()
|
||||
"""
|
||||
|
||||
from . import transforms
|
||||
from .bernoulli import Bernoulli
|
||||
from .beta import Beta
|
||||
from .binomial import Binomial
|
||||
from .categorical import Categorical
|
||||
from .cauchy import Cauchy
|
||||
from .chi2 import Chi2
|
||||
from .constraint_registry import biject_to, transform_to
|
||||
from .continuous_bernoulli import ContinuousBernoulli
|
||||
from .dirichlet import Dirichlet
|
||||
from .distribution import Distribution
|
||||
from .exp_family import ExponentialFamily
|
||||
from .exponential import Exponential
|
||||
from .fishersnedecor import FisherSnedecor
|
||||
from .gamma import Gamma
|
||||
from .geometric import Geometric
|
||||
from .gumbel import Gumbel
|
||||
from .half_cauchy import HalfCauchy
|
||||
from .half_normal import HalfNormal
|
||||
from .independent import Independent
|
||||
from .inverse_gamma import InverseGamma
|
||||
from .kl import _add_kl_info, kl_divergence, register_kl
|
||||
from .kumaraswamy import Kumaraswamy
|
||||
from .laplace import Laplace
|
||||
from .lkj_cholesky import LKJCholesky
|
||||
from .log_normal import LogNormal
|
||||
from .logistic_normal import LogisticNormal
|
||||
from .lowrank_multivariate_normal import LowRankMultivariateNormal
|
||||
from .mixture_same_family import MixtureSameFamily
|
||||
from .multinomial import Multinomial
|
||||
from .multivariate_normal import MultivariateNormal
|
||||
from .negative_binomial import NegativeBinomial
|
||||
from .normal import Normal
|
||||
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
|
||||
from .pareto import Pareto
|
||||
from .poisson import Poisson
|
||||
from .relaxed_bernoulli import RelaxedBernoulli
|
||||
from .relaxed_categorical import RelaxedOneHotCategorical
|
||||
from .studentT import StudentT
|
||||
from .transformed_distribution import TransformedDistribution
|
||||
from .transforms import * # noqa: F403
|
||||
from .uniform import Uniform
|
||||
from .von_mises import VonMises
|
||||
from .weibull import Weibull
|
||||
from .wishart import Wishart
|
||||
|
||||
|
||||
_add_kl_info()
|
||||
del _add_kl_info
|
||||
|
||||
__all__ = [
|
||||
"Bernoulli",
|
||||
"Beta",
|
||||
"Binomial",
|
||||
"Categorical",
|
||||
"Cauchy",
|
||||
"Chi2",
|
||||
"ContinuousBernoulli",
|
||||
"Dirichlet",
|
||||
"Distribution",
|
||||
"Exponential",
|
||||
"ExponentialFamily",
|
||||
"FisherSnedecor",
|
||||
"Gamma",
|
||||
"Geometric",
|
||||
"Gumbel",
|
||||
"HalfCauchy",
|
||||
"HalfNormal",
|
||||
"Independent",
|
||||
"InverseGamma",
|
||||
"Kumaraswamy",
|
||||
"LKJCholesky",
|
||||
"Laplace",
|
||||
"LogNormal",
|
||||
"LogisticNormal",
|
||||
"LowRankMultivariateNormal",
|
||||
"MixtureSameFamily",
|
||||
"Multinomial",
|
||||
"MultivariateNormal",
|
||||
"NegativeBinomial",
|
||||
"Normal",
|
||||
"OneHotCategorical",
|
||||
"OneHotCategoricalStraightThrough",
|
||||
"Pareto",
|
||||
"RelaxedBernoulli",
|
||||
"RelaxedOneHotCategorical",
|
||||
"StudentT",
|
||||
"Poisson",
|
||||
"Uniform",
|
||||
"VonMises",
|
||||
"Weibull",
|
||||
"Wishart",
|
||||
"TransformedDistribution",
|
||||
"biject_to",
|
||||
"kl_divergence",
|
||||
"register_kl",
|
||||
"transform_to",
|
||||
]
|
||||
__all__.extend(transforms.__all__)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
132
venv/Lib/site-packages/torch/distributions/bernoulli.py
Normal file
132
venv/Lib/site-packages/torch/distributions/bernoulli.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import nan, Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import (
|
||||
broadcast_all,
|
||||
lazy_property,
|
||||
logits_to_probs,
|
||||
probs_to_logits,
|
||||
)
|
||||
from torch.nn.functional import binary_cross_entropy_with_logits
|
||||
from torch.types import _Number
|
||||
|
||||
|
||||
__all__ = ["Bernoulli"]
|
||||
|
||||
|
||||
class Bernoulli(ExponentialFamily):
|
||||
r"""
|
||||
Creates a Bernoulli distribution parameterized by :attr:`probs`
|
||||
or :attr:`logits` (but not both).
|
||||
|
||||
Samples are binary (0 or 1). They take the value `1` with probability `p`
|
||||
and `0` with probability `1 - p`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Bernoulli(torch.tensor([0.3]))
|
||||
>>> m.sample() # 30% chance 1; 70% chance 0
|
||||
tensor([ 0.])
|
||||
|
||||
Args:
|
||||
probs (Number, Tensor): the probability of sampling `1`
|
||||
logits (Number, Tensor): the log-odds of sampling `1`
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||
support = constraints.boolean
|
||||
has_enumerate_support = True
|
||||
_mean_carrier_measure = 0
|
||||
|
||||
def __init__(self, probs=None, logits=None, validate_args=None):
|
||||
if (probs is None) == (logits is None):
|
||||
raise ValueError(
|
||||
"Either `probs` or `logits` must be specified, but not both."
|
||||
)
|
||||
if probs is not None:
|
||||
is_scalar = isinstance(probs, _Number)
|
||||
(self.probs,) = broadcast_all(probs)
|
||||
else:
|
||||
is_scalar = isinstance(logits, _Number)
|
||||
(self.logits,) = broadcast_all(logits)
|
||||
self._param = self.probs if probs is not None else self.logits
|
||||
if is_scalar:
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self._param.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Bernoulli, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
if "probs" in self.__dict__:
|
||||
new.probs = self.probs.expand(batch_shape)
|
||||
new._param = new.probs
|
||||
if "logits" in self.__dict__:
|
||||
new.logits = self.logits.expand(batch_shape)
|
||||
new._param = new.logits
|
||||
super(Bernoulli, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._param.new(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.probs
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
mode = (self.probs >= 0.5).to(self.probs)
|
||||
mode[self.probs == 0.5] = nan
|
||||
return mode
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.probs * (1 - self.probs)
|
||||
|
||||
@lazy_property
|
||||
def logits(self) -> Tensor:
|
||||
return probs_to_logits(self.probs, is_binary=True)
|
||||
|
||||
@lazy_property
|
||||
def probs(self) -> Tensor:
|
||||
return logits_to_probs(self.logits, is_binary=True)
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._param.size()
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
with torch.no_grad():
|
||||
return torch.bernoulli(self.probs.expand(shape))
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
logits, value = broadcast_all(self.logits, value)
|
||||
return -binary_cross_entropy_with_logits(logits, value, reduction="none")
|
||||
|
||||
def entropy(self):
|
||||
return binary_cross_entropy_with_logits(
|
||||
self.logits, self.probs, reduction="none"
|
||||
)
|
||||
|
||||
def enumerate_support(self, expand=True):
|
||||
values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
|
||||
values = values.view((-1,) + (1,) * len(self._batch_shape))
|
||||
if expand:
|
||||
values = values.expand((-1,) + self._batch_shape)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _natural_params(self) -> tuple[Tensor]:
|
||||
return (torch.logit(self.probs),)
|
||||
|
||||
def _log_normalizer(self, x):
|
||||
return torch.log1p(torch.exp(x))
|
110
venv/Lib/site-packages/torch/distributions/beta.py
Normal file
110
venv/Lib/site-packages/torch/distributions/beta.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.dirichlet import Dirichlet
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["Beta"]
|
||||
|
||||
|
||||
class Beta(ExponentialFamily):
|
||||
r"""
|
||||
Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
|
||||
>>> m.sample() # Beta distributed with concentration concentration1 and concentration0
|
||||
tensor([ 0.1046])
|
||||
|
||||
Args:
|
||||
concentration1 (float or Tensor): 1st concentration parameter of the distribution
|
||||
(often referred to as alpha)
|
||||
concentration0 (float or Tensor): 2nd concentration parameter of the distribution
|
||||
(often referred to as beta)
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"concentration1": constraints.positive,
|
||||
"concentration0": constraints.positive,
|
||||
}
|
||||
support = constraints.unit_interval
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, concentration1, concentration0, validate_args=None):
|
||||
if isinstance(concentration1, _Number) and isinstance(concentration0, _Number):
|
||||
concentration1_concentration0 = torch.tensor(
|
||||
[float(concentration1), float(concentration0)]
|
||||
)
|
||||
else:
|
||||
concentration1, concentration0 = broadcast_all(
|
||||
concentration1, concentration0
|
||||
)
|
||||
concentration1_concentration0 = torch.stack(
|
||||
[concentration1, concentration0], -1
|
||||
)
|
||||
self._dirichlet = Dirichlet(
|
||||
concentration1_concentration0, validate_args=validate_args
|
||||
)
|
||||
super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Beta, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new._dirichlet = self._dirichlet.expand(batch_shape)
|
||||
super(Beta, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.concentration1 / (self.concentration1 + self.concentration0)
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self._dirichlet.mode[..., 0]
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
total = self.concentration1 + self.concentration0
|
||||
return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))
|
||||
|
||||
def rsample(self, sample_shape: _size = ()) -> Tensor:
|
||||
return self._dirichlet.rsample(sample_shape).select(-1, 0)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
heads_tails = torch.stack([value, 1.0 - value], -1)
|
||||
return self._dirichlet.log_prob(heads_tails)
|
||||
|
||||
def entropy(self):
|
||||
return self._dirichlet.entropy()
|
||||
|
||||
@property
|
||||
def concentration1(self) -> Tensor:
|
||||
result = self._dirichlet.concentration[..., 0]
|
||||
if isinstance(result, _Number):
|
||||
return torch.tensor([result])
|
||||
else:
|
||||
return result
|
||||
|
||||
@property
|
||||
def concentration0(self) -> Tensor:
|
||||
result = self._dirichlet.concentration[..., 1]
|
||||
if isinstance(result, _Number):
|
||||
return torch.tensor([result])
|
||||
else:
|
||||
return result
|
||||
|
||||
@property
|
||||
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
||||
return (self.concentration1, self.concentration0)
|
||||
|
||||
def _log_normalizer(self, x, y):
|
||||
return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
|
169
venv/Lib/site-packages/torch/distributions/binomial.py
Normal file
169
venv/Lib/site-packages/torch/distributions/binomial.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import (
|
||||
broadcast_all,
|
||||
lazy_property,
|
||||
logits_to_probs,
|
||||
probs_to_logits,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Binomial"]
|
||||
|
||||
|
||||
def _clamp_by_zero(x):
|
||||
# works like clamp(x, min=0) but has grad at 0 is 0.5
|
||||
return (x.clamp(min=0) + x - x.clamp(max=0)) / 2
|
||||
|
||||
|
||||
class Binomial(Distribution):
|
||||
r"""
|
||||
Creates a Binomial distribution parameterized by :attr:`total_count` and
|
||||
either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be
|
||||
broadcastable with :attr:`probs`/:attr:`logits`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))
|
||||
>>> x = m.sample()
|
||||
tensor([ 0., 22., 71., 100.])
|
||||
|
||||
>>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
|
||||
>>> x = m.sample()
|
||||
tensor([[ 4., 5.],
|
||||
[ 7., 6.]])
|
||||
|
||||
Args:
|
||||
total_count (int or Tensor): number of Bernoulli trials
|
||||
probs (Tensor): Event probabilities
|
||||
logits (Tensor): Event log-odds
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"total_count": constraints.nonnegative_integer,
|
||||
"probs": constraints.unit_interval,
|
||||
"logits": constraints.real,
|
||||
}
|
||||
has_enumerate_support = True
|
||||
|
||||
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
|
||||
if (probs is None) == (logits is None):
|
||||
raise ValueError(
|
||||
"Either `probs` or `logits` must be specified, but not both."
|
||||
)
|
||||
if probs is not None:
|
||||
(
|
||||
self.total_count,
|
||||
self.probs,
|
||||
) = broadcast_all(total_count, probs)
|
||||
self.total_count = self.total_count.type_as(self.probs)
|
||||
else:
|
||||
(
|
||||
self.total_count,
|
||||
self.logits,
|
||||
) = broadcast_all(total_count, logits)
|
||||
self.total_count = self.total_count.type_as(self.logits)
|
||||
|
||||
self._param = self.probs if probs is not None else self.logits
|
||||
batch_shape = self._param.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Binomial, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.total_count = self.total_count.expand(batch_shape)
|
||||
if "probs" in self.__dict__:
|
||||
new.probs = self.probs.expand(batch_shape)
|
||||
new._param = new.probs
|
||||
if "logits" in self.__dict__:
|
||||
new.logits = self.logits.expand(batch_shape)
|
||||
new._param = new.logits
|
||||
super(Binomial, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._param.new(*args, **kwargs)
|
||||
|
||||
@constraints.dependent_property(is_discrete=True, event_dim=0)
|
||||
def support(self):
|
||||
return constraints.integer_interval(0, self.total_count)
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.total_count * self.probs
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count)
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.total_count * self.probs * (1 - self.probs)
|
||||
|
||||
@lazy_property
|
||||
def logits(self) -> Tensor:
|
||||
return probs_to_logits(self.probs, is_binary=True)
|
||||
|
||||
@lazy_property
|
||||
def probs(self) -> Tensor:
|
||||
return logits_to_probs(self.logits, is_binary=True)
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._param.size()
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
with torch.no_grad():
|
||||
return torch.binomial(
|
||||
self.total_count.expand(shape), self.probs.expand(shape)
|
||||
)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
log_factorial_n = torch.lgamma(self.total_count + 1)
|
||||
log_factorial_k = torch.lgamma(value + 1)
|
||||
log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
|
||||
# k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
|
||||
# (case logit < 0) = k * logit - n * log1p(e^logit)
|
||||
# (case logit > 0) = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
|
||||
# = k * logit - n * logit - n * log1p(e^-logit)
|
||||
# (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
|
||||
normalize_term = (
|
||||
self.total_count * _clamp_by_zero(self.logits)
|
||||
+ self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits)))
|
||||
- log_factorial_n
|
||||
)
|
||||
return (
|
||||
value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
|
||||
)
|
||||
|
||||
def entropy(self):
|
||||
total_count = int(self.total_count.max())
|
||||
if not self.total_count.min() == total_count:
|
||||
raise NotImplementedError(
|
||||
"Inhomogeneous total count not supported by `entropy`."
|
||||
)
|
||||
|
||||
log_prob = self.log_prob(self.enumerate_support(False))
|
||||
return -(torch.exp(log_prob) * log_prob).sum(0)
|
||||
|
||||
def enumerate_support(self, expand=True):
|
||||
total_count = int(self.total_count.max())
|
||||
if not self.total_count.min() == total_count:
|
||||
raise NotImplementedError(
|
||||
"Inhomogeneous total count not supported by `enumerate_support`."
|
||||
)
|
||||
values = torch.arange(
|
||||
1 + total_count, dtype=self._param.dtype, device=self._param.device
|
||||
)
|
||||
values = values.view((-1,) + (1,) * len(self._batch_shape))
|
||||
if expand:
|
||||
values = values.expand((-1,) + self._batch_shape)
|
||||
return values
|
158
venv/Lib/site-packages/torch/distributions/categorical.py
Normal file
158
venv/Lib/site-packages/torch/distributions/categorical.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import nan, Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits
|
||||
|
||||
|
||||
__all__ = ["Categorical"]
|
||||
|
||||
|
||||
class Categorical(Distribution):
|
||||
r"""
|
||||
Creates a categorical distribution parameterized by either :attr:`probs` or
|
||||
:attr:`logits` (but not both).
|
||||
|
||||
.. note::
|
||||
It is equivalent to the distribution that :func:`torch.multinomial`
|
||||
samples from.
|
||||
|
||||
Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.
|
||||
|
||||
If `probs` is 1-dimensional with length-`K`, each element is the relative probability
|
||||
of sampling the class at that index.
|
||||
|
||||
If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
|
||||
relative probability vectors.
|
||||
|
||||
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
|
||||
will return this normalized value.
|
||||
The `logits` argument will be interpreted as unnormalized log probabilities
|
||||
and can therefore be any real number. It will likewise be normalized so that
|
||||
the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
|
||||
will return this normalized value.
|
||||
|
||||
See also: :func:`torch.multinomial`
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
|
||||
>>> m.sample() # equal probability of 0, 1, 2, 3
|
||||
tensor(3)
|
||||
|
||||
Args:
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log probabilities (unnormalized)
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||
has_enumerate_support = True
|
||||
|
||||
def __init__(self, probs=None, logits=None, validate_args=None):
|
||||
if (probs is None) == (logits is None):
|
||||
raise ValueError(
|
||||
"Either `probs` or `logits` must be specified, but not both."
|
||||
)
|
||||
if probs is not None:
|
||||
if probs.dim() < 1:
|
||||
raise ValueError("`probs` parameter must be at least one-dimensional.")
|
||||
self.probs = probs / probs.sum(-1, keepdim=True)
|
||||
else:
|
||||
if logits.dim() < 1:
|
||||
raise ValueError("`logits` parameter must be at least one-dimensional.")
|
||||
# Normalize
|
||||
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
|
||||
self._param = self.probs if probs is not None else self.logits
|
||||
self._num_events = self._param.size()[-1]
|
||||
batch_shape = (
|
||||
self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
|
||||
)
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Categorical, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
param_shape = batch_shape + torch.Size((self._num_events,))
|
||||
if "probs" in self.__dict__:
|
||||
new.probs = self.probs.expand(param_shape)
|
||||
new._param = new.probs
|
||||
if "logits" in self.__dict__:
|
||||
new.logits = self.logits.expand(param_shape)
|
||||
new._param = new.logits
|
||||
new._num_events = self._num_events
|
||||
super(Categorical, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._param.new(*args, **kwargs)
|
||||
|
||||
@constraints.dependent_property(is_discrete=True, event_dim=0)
|
||||
def support(self):
|
||||
return constraints.integer_interval(0, self._num_events - 1)
|
||||
|
||||
@lazy_property
|
||||
def logits(self) -> Tensor:
|
||||
return probs_to_logits(self.probs)
|
||||
|
||||
@lazy_property
|
||||
def probs(self) -> Tensor:
|
||||
return logits_to_probs(self.logits)
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._param.size()
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return torch.full(
|
||||
self._extended_shape(),
|
||||
nan,
|
||||
dtype=self.probs.dtype,
|
||||
device=self.probs.device,
|
||||
)
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.probs.argmax(dim=-1)
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return torch.full(
|
||||
self._extended_shape(),
|
||||
nan,
|
||||
dtype=self.probs.dtype,
|
||||
device=self.probs.device,
|
||||
)
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
if not isinstance(sample_shape, torch.Size):
|
||||
sample_shape = torch.Size(sample_shape)
|
||||
probs_2d = self.probs.reshape(-1, self._num_events)
|
||||
samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
|
||||
return samples_2d.reshape(self._extended_shape(sample_shape))
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
value = value.long().unsqueeze(-1)
|
||||
value, log_pmf = torch.broadcast_tensors(value, self.logits)
|
||||
value = value[..., :1]
|
||||
return log_pmf.gather(-1, value).squeeze(-1)
|
||||
|
||||
def entropy(self):
|
||||
min_real = torch.finfo(self.logits.dtype).min
|
||||
logits = torch.clamp(self.logits, min=min_real)
|
||||
p_log_p = logits * self.probs
|
||||
return -p_log_p.sum(-1)
|
||||
|
||||
def enumerate_support(self, expand=True):
|
||||
num_events = self._num_events
|
||||
values = torch.arange(num_events, dtype=torch.long, device=self._param.device)
|
||||
values = values.view((-1,) + (1,) * len(self._batch_shape))
|
||||
if expand:
|
||||
values = values.expand((-1,) + self._batch_shape)
|
||||
return values
|
93
venv/Lib/site-packages/torch/distributions/cauchy.py
Normal file
93
venv/Lib/site-packages/torch/distributions/cauchy.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import inf, nan, Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["Cauchy"]
|
||||
|
||||
|
||||
class Cauchy(Distribution):
|
||||
r"""
|
||||
Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of
|
||||
independent normally distributed random variables with means `0` follows a
|
||||
Cauchy distribution.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
|
||||
>>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1
|
||||
tensor([ 2.3214])
|
||||
|
||||
Args:
|
||||
loc (float or Tensor): mode or median of the distribution.
|
||||
scale (float or Tensor): half width at half maximum.
|
||||
"""
|
||||
|
||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||
support = constraints.real
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, loc, scale, validate_args=None):
|
||||
self.loc, self.scale = broadcast_all(loc, scale)
|
||||
if isinstance(loc, _Number) and isinstance(scale, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self.loc.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Cauchy, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.loc = self.loc.expand(batch_shape)
|
||||
new.scale = self.scale.expand(batch_shape)
|
||||
super(Cauchy, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return torch.full(
|
||||
self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device
|
||||
)
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return torch.full(
|
||||
self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device
|
||||
)
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
eps = self.loc.new(shape).cauchy_()
|
||||
return self.loc + eps * self.scale
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return (
|
||||
-math.log(math.pi)
|
||||
- self.scale.log()
|
||||
- (((value - self.loc) / self.scale) ** 2).log1p()
|
||||
)
|
||||
|
||||
def cdf(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5
|
||||
|
||||
def icdf(self, value):
|
||||
return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc
|
||||
|
||||
def entropy(self):
|
||||
return math.log(4 * math.pi) + self.scale.log()
|
37
venv/Lib/site-packages/torch/distributions/chi2.py
Normal file
37
venv/Lib/site-packages/torch/distributions/chi2.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.gamma import Gamma
|
||||
|
||||
|
||||
__all__ = ["Chi2"]
|
||||
|
||||
|
||||
class Chi2(Gamma):
|
||||
r"""
|
||||
Creates a Chi-squared distribution parameterized by shape parameter :attr:`df`.
|
||||
This is exactly equivalent to ``Gamma(alpha=0.5*df, beta=0.5)``
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Chi2(torch.tensor([1.0]))
|
||||
>>> m.sample() # Chi2 distributed with shape df=1
|
||||
tensor([ 0.1046])
|
||||
|
||||
Args:
|
||||
df (float or Tensor): shape parameter of the distribution
|
||||
"""
|
||||
|
||||
arg_constraints = {"df": constraints.positive}
|
||||
|
||||
def __init__(self, df, validate_args=None):
|
||||
super().__init__(0.5 * df, 0.5, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Chi2, _instance)
|
||||
return super().expand(batch_shape, new)
|
||||
|
||||
@property
|
||||
def df(self) -> Tensor:
|
||||
return self.concentration * 2
|
|
@ -0,0 +1,291 @@
|
|||
# mypy: allow-untyped-defs
|
||||
r"""
|
||||
PyTorch provides two global :class:`ConstraintRegistry` objects that link
|
||||
:class:`~torch.distributions.constraints.Constraint` objects to
|
||||
:class:`~torch.distributions.transforms.Transform` objects. These objects both
|
||||
input constraints and return transforms, but they have different guarantees on
|
||||
bijectivity.
|
||||
|
||||
1. ``biject_to(constraint)`` looks up a bijective
|
||||
:class:`~torch.distributions.transforms.Transform` from ``constraints.real``
|
||||
to the given ``constraint``. The returned transform is guaranteed to have
|
||||
``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
|
||||
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
|
||||
:class:`~torch.distributions.transforms.Transform` from ``constraints.real``
|
||||
to the given ``constraint``. The returned transform is not guaranteed to
|
||||
implement ``.log_abs_det_jacobian()``.
|
||||
|
||||
The ``transform_to()`` registry is useful for performing unconstrained
|
||||
optimization on constrained parameters of probability distributions, which are
|
||||
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
|
||||
overparameterize a space in order to avoid rotation; they are thus more
|
||||
suitable for coordinate-wise optimization algorithms like Adam::
|
||||
|
||||
loc = torch.zeros(100, requires_grad=True)
|
||||
unconstrained = torch.zeros(100, requires_grad=True)
|
||||
scale = transform_to(Normal.arg_constraints["scale"])(unconstrained)
|
||||
loss = -Normal(loc, scale).log_prob(data).sum()
|
||||
|
||||
The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
|
||||
samples from a probability distribution with constrained ``.support`` are
|
||||
propagated in an unconstrained space, and algorithms are typically rotation
|
||||
invariant.::
|
||||
|
||||
dist = Exponential(rate)
|
||||
unconstrained = torch.zeros(100, requires_grad=True)
|
||||
sample = biject_to(dist.support)(unconstrained)
|
||||
potential_energy = -dist.log_prob(sample).sum()
|
||||
|
||||
.. note::
|
||||
|
||||
An example where ``transform_to`` and ``biject_to`` differ is
|
||||
``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
|
||||
:class:`~torch.distributions.transforms.SoftmaxTransform` that simply
|
||||
exponentiates and normalizes its inputs; this is a cheap and mostly
|
||||
coordinate-wise operation appropriate for algorithms like SVI. In
|
||||
contrast, ``biject_to(constraints.simplex)`` returns a
|
||||
:class:`~torch.distributions.transforms.StickBreakingTransform` that
|
||||
bijects its input down to a one-fewer-dimensional space; this a more
|
||||
expensive less numerically stable transform but is needed for algorithms
|
||||
like HMC.
|
||||
|
||||
The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
|
||||
constraints and transforms using their ``.register()`` method either as a
|
||||
function on singleton constraints::
|
||||
|
||||
transform_to.register(my_constraint, my_transform)
|
||||
|
||||
or as a decorator on parameterized constraints::
|
||||
|
||||
@transform_to.register(MyConstraintClass)
|
||||
def my_factory(constraint):
|
||||
assert isinstance(constraint, MyConstraintClass)
|
||||
return MyTransform(constraint.param1, constraint.param2)
|
||||
|
||||
You can create your own registry by creating a new :class:`ConstraintRegistry`
|
||||
object.
|
||||
"""
|
||||
|
||||
from torch.distributions import constraints, transforms
|
||||
from torch.types import _Number
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ConstraintRegistry",
|
||||
"biject_to",
|
||||
"transform_to",
|
||||
]
|
||||
|
||||
|
||||
class ConstraintRegistry:
|
||||
"""
|
||||
Registry to link constraints to transforms.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._registry = {}
|
||||
super().__init__()
|
||||
|
||||
def register(self, constraint, factory=None):
|
||||
"""
|
||||
Registers a :class:`~torch.distributions.constraints.Constraint`
|
||||
subclass in this registry. Usage::
|
||||
|
||||
@my_registry.register(MyConstraintClass)
|
||||
def construct_transform(constraint):
|
||||
assert isinstance(constraint, MyConstraint)
|
||||
return MyTransform(constraint.arg_constraints)
|
||||
|
||||
Args:
|
||||
constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
|
||||
A subclass of :class:`~torch.distributions.constraints.Constraint`, or
|
||||
a singleton object of the desired class.
|
||||
factory (Callable): A callable that inputs a constraint object and returns
|
||||
a :class:`~torch.distributions.transforms.Transform` object.
|
||||
"""
|
||||
# Support use as decorator.
|
||||
if factory is None:
|
||||
return lambda factory: self.register(constraint, factory)
|
||||
|
||||
# Support calling on singleton instances.
|
||||
if isinstance(constraint, constraints.Constraint):
|
||||
constraint = type(constraint)
|
||||
|
||||
if not isinstance(constraint, type) or not issubclass(
|
||||
constraint, constraints.Constraint
|
||||
):
|
||||
raise TypeError(
|
||||
f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
|
||||
)
|
||||
|
||||
self._registry[constraint] = factory
|
||||
return factory
|
||||
|
||||
def __call__(self, constraint):
|
||||
"""
|
||||
Looks up a transform to constrained space, given a constraint object.
|
||||
Usage::
|
||||
|
||||
constraint = Normal.arg_constraints["scale"]
|
||||
scale = transform_to(constraint)(torch.zeros(1)) # constrained
|
||||
u = transform_to(constraint).inv(scale) # unconstrained
|
||||
|
||||
Args:
|
||||
constraint (:class:`~torch.distributions.constraints.Constraint`):
|
||||
A constraint object.
|
||||
|
||||
Returns:
|
||||
A :class:`~torch.distributions.transforms.Transform` object.
|
||||
|
||||
Raises:
|
||||
`NotImplementedError` if no transform has been registered.
|
||||
"""
|
||||
# Look up by Constraint subclass.
|
||||
try:
|
||||
factory = self._registry[type(constraint)]
|
||||
except KeyError:
|
||||
raise NotImplementedError(
|
||||
f"Cannot transform {type(constraint).__name__} constraints"
|
||||
) from None
|
||||
return factory(constraint)
|
||||
|
||||
|
||||
biject_to = ConstraintRegistry()
|
||||
transform_to = ConstraintRegistry()
|
||||
|
||||
|
||||
################################################################################
|
||||
# Registration Table
|
||||
################################################################################
|
||||
|
||||
|
||||
@biject_to.register(constraints.real)
|
||||
@transform_to.register(constraints.real)
|
||||
def _transform_to_real(constraint):
|
||||
return transforms.identity_transform
|
||||
|
||||
|
||||
@biject_to.register(constraints.independent)
|
||||
def _biject_to_independent(constraint):
|
||||
base_transform = biject_to(constraint.base_constraint)
|
||||
return transforms.IndependentTransform(
|
||||
base_transform, constraint.reinterpreted_batch_ndims
|
||||
)
|
||||
|
||||
|
||||
@transform_to.register(constraints.independent)
|
||||
def _transform_to_independent(constraint):
|
||||
base_transform = transform_to(constraint.base_constraint)
|
||||
return transforms.IndependentTransform(
|
||||
base_transform, constraint.reinterpreted_batch_ndims
|
||||
)
|
||||
|
||||
|
||||
@biject_to.register(constraints.positive)
|
||||
@biject_to.register(constraints.nonnegative)
|
||||
@transform_to.register(constraints.positive)
|
||||
@transform_to.register(constraints.nonnegative)
|
||||
def _transform_to_positive(constraint):
|
||||
return transforms.ExpTransform()
|
||||
|
||||
|
||||
@biject_to.register(constraints.greater_than)
|
||||
@biject_to.register(constraints.greater_than_eq)
|
||||
@transform_to.register(constraints.greater_than)
|
||||
@transform_to.register(constraints.greater_than_eq)
|
||||
def _transform_to_greater_than(constraint):
|
||||
return transforms.ComposeTransform(
|
||||
[
|
||||
transforms.ExpTransform(),
|
||||
transforms.AffineTransform(constraint.lower_bound, 1),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@biject_to.register(constraints.less_than)
|
||||
@transform_to.register(constraints.less_than)
|
||||
def _transform_to_less_than(constraint):
|
||||
return transforms.ComposeTransform(
|
||||
[
|
||||
transforms.ExpTransform(),
|
||||
transforms.AffineTransform(constraint.upper_bound, -1),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@biject_to.register(constraints.interval)
|
||||
@biject_to.register(constraints.half_open_interval)
|
||||
@transform_to.register(constraints.interval)
|
||||
@transform_to.register(constraints.half_open_interval)
|
||||
def _transform_to_interval(constraint):
|
||||
# Handle the special case of the unit interval.
|
||||
lower_is_0 = (
|
||||
isinstance(constraint.lower_bound, _Number) and constraint.lower_bound == 0
|
||||
)
|
||||
upper_is_1 = (
|
||||
isinstance(constraint.upper_bound, _Number) and constraint.upper_bound == 1
|
||||
)
|
||||
if lower_is_0 and upper_is_1:
|
||||
return transforms.SigmoidTransform()
|
||||
|
||||
loc = constraint.lower_bound
|
||||
scale = constraint.upper_bound - constraint.lower_bound
|
||||
return transforms.ComposeTransform(
|
||||
[transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
|
||||
)
|
||||
|
||||
|
||||
@biject_to.register(constraints.simplex)
|
||||
def _biject_to_simplex(constraint):
|
||||
return transforms.StickBreakingTransform()
|
||||
|
||||
|
||||
@transform_to.register(constraints.simplex)
|
||||
def _transform_to_simplex(constraint):
|
||||
return transforms.SoftmaxTransform()
|
||||
|
||||
|
||||
# TODO define a bijection for LowerCholeskyTransform
|
||||
@transform_to.register(constraints.lower_cholesky)
|
||||
def _transform_to_lower_cholesky(constraint):
|
||||
return transforms.LowerCholeskyTransform()
|
||||
|
||||
|
||||
@transform_to.register(constraints.positive_definite)
|
||||
@transform_to.register(constraints.positive_semidefinite)
|
||||
def _transform_to_positive_definite(constraint):
|
||||
return transforms.PositiveDefiniteTransform()
|
||||
|
||||
|
||||
@biject_to.register(constraints.corr_cholesky)
|
||||
@transform_to.register(constraints.corr_cholesky)
|
||||
def _transform_to_corr_cholesky(constraint):
|
||||
return transforms.CorrCholeskyTransform()
|
||||
|
||||
|
||||
@biject_to.register(constraints.cat)
|
||||
def _biject_to_cat(constraint):
|
||||
return transforms.CatTransform(
|
||||
[biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
|
||||
)
|
||||
|
||||
|
||||
@transform_to.register(constraints.cat)
|
||||
def _transform_to_cat(constraint):
|
||||
return transforms.CatTransform(
|
||||
[transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
|
||||
)
|
||||
|
||||
|
||||
@biject_to.register(constraints.stack)
|
||||
def _biject_to_stack(constraint):
|
||||
return transforms.StackTransform(
|
||||
[biject_to(c) for c in constraint.cseq], constraint.dim
|
||||
)
|
||||
|
||||
|
||||
@transform_to.register(constraints.stack)
|
||||
def _transform_to_stack(constraint):
|
||||
return transforms.StackTransform(
|
||||
[transform_to(c) for c in constraint.cseq], constraint.dim
|
||||
)
|
689
venv/Lib/site-packages/torch/distributions/constraints.py
Normal file
689
venv/Lib/site-packages/torch/distributions/constraints.py
Normal file
|
@ -0,0 +1,689 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
r"""
|
||||
The following constraints are implemented:
|
||||
|
||||
- ``constraints.boolean``
|
||||
- ``constraints.cat``
|
||||
- ``constraints.corr_cholesky``
|
||||
- ``constraints.dependent``
|
||||
- ``constraints.greater_than(lower_bound)``
|
||||
- ``constraints.greater_than_eq(lower_bound)``
|
||||
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
|
||||
- ``constraints.integer_interval(lower_bound, upper_bound)``
|
||||
- ``constraints.interval(lower_bound, upper_bound)``
|
||||
- ``constraints.less_than(upper_bound)``
|
||||
- ``constraints.lower_cholesky``
|
||||
- ``constraints.lower_triangular``
|
||||
- ``constraints.multinomial``
|
||||
- ``constraints.nonnegative``
|
||||
- ``constraints.nonnegative_integer``
|
||||
- ``constraints.one_hot``
|
||||
- ``constraints.positive_integer``
|
||||
- ``constraints.positive``
|
||||
- ``constraints.positive_semidefinite``
|
||||
- ``constraints.positive_definite``
|
||||
- ``constraints.real_vector``
|
||||
- ``constraints.real``
|
||||
- ``constraints.simplex``
|
||||
- ``constraints.symmetric``
|
||||
- ``constraints.stack``
|
||||
- ``constraints.square``
|
||||
- ``constraints.symmetric``
|
||||
- ``constraints.unit_interval``
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Constraint",
|
||||
"boolean",
|
||||
"cat",
|
||||
"corr_cholesky",
|
||||
"dependent",
|
||||
"dependent_property",
|
||||
"greater_than",
|
||||
"greater_than_eq",
|
||||
"independent",
|
||||
"integer_interval",
|
||||
"interval",
|
||||
"half_open_interval",
|
||||
"is_dependent",
|
||||
"less_than",
|
||||
"lower_cholesky",
|
||||
"lower_triangular",
|
||||
"multinomial",
|
||||
"nonnegative",
|
||||
"nonnegative_integer",
|
||||
"one_hot",
|
||||
"positive",
|
||||
"positive_semidefinite",
|
||||
"positive_definite",
|
||||
"positive_integer",
|
||||
"real",
|
||||
"real_vector",
|
||||
"simplex",
|
||||
"square",
|
||||
"stack",
|
||||
"symmetric",
|
||||
"unit_interval",
|
||||
]
|
||||
|
||||
|
||||
class Constraint:
|
||||
"""
|
||||
Abstract base class for constraints.
|
||||
|
||||
A constraint object represents a region over which a variable is valid,
|
||||
e.g. within which a variable can be optimized.
|
||||
|
||||
Attributes:
|
||||
is_discrete (bool): Whether constrained space is discrete.
|
||||
Defaults to False.
|
||||
event_dim (int): Number of rightmost dimensions that together define
|
||||
an event. The :meth:`check` method will remove this many dimensions
|
||||
when computing validity.
|
||||
"""
|
||||
|
||||
is_discrete = False # Default to continuous.
|
||||
event_dim = 0 # Default to univariate.
|
||||
|
||||
def check(self, value):
|
||||
"""
|
||||
Returns a byte tensor of ``sample_shape + batch_shape`` indicating
|
||||
whether each event in value satisfies this constraint.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__[1:] + "()"
|
||||
|
||||
|
||||
class _Dependent(Constraint):
|
||||
"""
|
||||
Placeholder for variables whose support depends on other variables.
|
||||
These variables obey no simple coordinate-wise constraints.
|
||||
|
||||
Args:
|
||||
is_discrete (bool): Optional value of ``.is_discrete`` in case this
|
||||
can be computed statically. If not provided, access to the
|
||||
``.is_discrete`` attribute will raise a NotImplementedError.
|
||||
event_dim (int): Optional value of ``.event_dim`` in case this
|
||||
can be computed statically. If not provided, access to the
|
||||
``.event_dim`` attribute will raise a NotImplementedError.
|
||||
"""
|
||||
|
||||
def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
|
||||
self._is_discrete = is_discrete
|
||||
self._event_dim = event_dim
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def is_discrete(self) -> bool: # type: ignore[override]
|
||||
if self._is_discrete is NotImplemented:
|
||||
raise NotImplementedError(".is_discrete cannot be determined statically")
|
||||
return self._is_discrete
|
||||
|
||||
@property
|
||||
def event_dim(self) -> int: # type: ignore[override]
|
||||
if self._event_dim is NotImplemented:
|
||||
raise NotImplementedError(".event_dim cannot be determined statically")
|
||||
return self._event_dim
|
||||
|
||||
def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
|
||||
"""
|
||||
Support for syntax to customize static attributes::
|
||||
|
||||
constraints.dependent(is_discrete=True, event_dim=1)
|
||||
"""
|
||||
if is_discrete is NotImplemented:
|
||||
is_discrete = self._is_discrete
|
||||
if event_dim is NotImplemented:
|
||||
event_dim = self._event_dim
|
||||
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
|
||||
|
||||
def check(self, x):
|
||||
raise ValueError("Cannot determine validity of dependent constraint")
|
||||
|
||||
|
||||
def is_dependent(constraint):
|
||||
"""
|
||||
Checks if ``constraint`` is a ``_Dependent`` object.
|
||||
|
||||
Args:
|
||||
constraint : A ``Constraint`` object.
|
||||
|
||||
Returns:
|
||||
``bool``: True if ``constraint`` can be refined to the type ``_Dependent``, False otherwise.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from torch.distributions import Bernoulli
|
||||
>>> from torch.distributions.constraints import is_dependent
|
||||
|
||||
>>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True))
|
||||
>>> constraint1 = dist.arg_constraints["probs"]
|
||||
>>> constraint2 = dist.arg_constraints["logits"]
|
||||
|
||||
>>> for constraint in [constraint1, constraint2]:
|
||||
>>> if is_dependent(constraint):
|
||||
>>> continue
|
||||
"""
|
||||
return isinstance(constraint, _Dependent)
|
||||
|
||||
|
||||
class _DependentProperty(property, _Dependent):
|
||||
"""
|
||||
Decorator that extends @property to act like a `Dependent` constraint when
|
||||
called on a class and act like a property when called on an object.
|
||||
|
||||
Example::
|
||||
|
||||
class Uniform(Distribution):
|
||||
def __init__(self, low, high):
|
||||
self.low = low
|
||||
self.high = high
|
||||
|
||||
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
||||
def support(self):
|
||||
return constraints.interval(self.low, self.high)
|
||||
|
||||
Args:
|
||||
fn (Callable): The function to be decorated.
|
||||
is_discrete (bool): Optional value of ``.is_discrete`` in case this
|
||||
can be computed statically. If not provided, access to the
|
||||
``.is_discrete`` attribute will raise a NotImplementedError.
|
||||
event_dim (int): Optional value of ``.event_dim`` in case this
|
||||
can be computed statically. If not provided, access to the
|
||||
``.event_dim`` attribute will raise a NotImplementedError.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn: Optional[Callable[..., Any]] = None,
|
||||
*,
|
||||
is_discrete: Optional[bool] = NotImplemented,
|
||||
event_dim: Optional[int] = NotImplemented,
|
||||
) -> None:
|
||||
super().__init__(fn)
|
||||
self._is_discrete = is_discrete
|
||||
self._event_dim = event_dim
|
||||
|
||||
def __call__(self, fn: Callable[..., Any]) -> "_DependentProperty": # type: ignore[override]
|
||||
"""
|
||||
Support for syntax to customize static attributes::
|
||||
|
||||
@constraints.dependent_property(is_discrete=True, event_dim=1)
|
||||
def support(self): ...
|
||||
"""
|
||||
return _DependentProperty(
|
||||
fn, is_discrete=self._is_discrete, event_dim=self._event_dim
|
||||
)
|
||||
|
||||
|
||||
class _IndependentConstraint(Constraint):
|
||||
"""
|
||||
Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
|
||||
dims in :meth:`check`, so that an event is valid only if all its
|
||||
independent entries are valid.
|
||||
"""
|
||||
|
||||
def __init__(self, base_constraint, reinterpreted_batch_ndims):
|
||||
assert isinstance(base_constraint, Constraint)
|
||||
assert isinstance(reinterpreted_batch_ndims, int)
|
||||
assert reinterpreted_batch_ndims >= 0
|
||||
self.base_constraint = base_constraint
|
||||
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def is_discrete(self) -> bool: # type: ignore[override]
|
||||
return self.base_constraint.is_discrete
|
||||
|
||||
@property
|
||||
def event_dim(self) -> int: # type: ignore[override]
|
||||
return self.base_constraint.event_dim + self.reinterpreted_batch_ndims
|
||||
|
||||
def check(self, value):
|
||||
result = self.base_constraint.check(value)
|
||||
if result.dim() < self.reinterpreted_batch_ndims:
|
||||
expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
|
||||
raise ValueError(
|
||||
f"Expected value.dim() >= {expected} but got {value.dim()}"
|
||||
)
|
||||
result = result.reshape(
|
||||
result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
|
||||
)
|
||||
result = result.all(-1)
|
||||
return result
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
|
||||
|
||||
|


class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`.
    """

    is_discrete = True

    def check(self, value):
        return (value == 0) | (value == 1)


class _OneHot(Constraint):
    """
    Constrain to one-hot vectors.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        is_boolean = (value == 0) | (value == 1)
        is_normalized = value.sum(-1).eq(1)
        return is_boolean.all(-1) & is_normalized


class _IntegerInterval(Constraint):
    """
    Constrain to an integer interval `[lower_bound, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (
            (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
        )

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _IntegerLessThan(Constraint):
    """
    Constrain to an integer interval `(-inf, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (value % 1 == 0) & (value <= self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string


class _IntegerGreaterThan(Constraint):
    """
    Constrain to an integer interval `[lower_bound, inf)`.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return (value % 1 == 0) & (value >= self.lower_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`.
    """

    def check(self, value):
        return value == value  # False for NANs.


class _GreaterThan(Constraint):
    """
    Constrain to a real half line `(lower_bound, inf]`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound < value

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _GreaterThanEq(Constraint):
    """
    Constrain to a real half line `[lower_bound, inf)`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound <= value

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _LessThan(Constraint):
    """
    Constrain to a real half line `[-inf, upper_bound)`.
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return value < self.upper_bound

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string


class _Interval(Constraint):
    """
    Constrain to a real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (self.lower_bound <= value) & (value <= self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _HalfOpenInterval(Constraint):
    """
    Constrain to a real interval `[lower_bound, upper_bound)`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (self.lower_bound <= value) & (value < self.upper_bound)

    def __repr__(self):
        fmt_string = self.__class__.__name__[1:]
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _Simplex(Constraint):
    """
    Constrain to the unit simplex in the innermost (rightmost) dimension.
    Specifically: `x >= 0` and `x.sum(-1) == 1`.
    """

    event_dim = 1

    def check(self, value):
        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)


class _Multinomial(Constraint):
    """
    Constrain to nonnegative integer values summing to at most an upper bound.

    Note due to limitations of the Multinomial distribution, this currently
    checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
    this may be strengthened to ``value.sum(-1) == upper_bound``.
    """

    is_discrete = True
    event_dim = 1

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, x):
        return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)


class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.
    """

    event_dim = 2

    def check(self, value):
        value_tril = value.tril()
        return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]


class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.
    """

    event_dim = 2

    def check(self, value):
        value_tril = value.tril()
        lower_triangular = (
            (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
        )

        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
        return lower_triangular & positive_diagonal


class _CorrCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals and each
    row vector being of unit length.
    """

    event_dim = 2

    def check(self, value):
        tol = (
            torch.finfo(value.dtype).eps * value.size(-1) * 10
        )  # 10 is an adjustable fudge factor
        row_norm = torch.linalg.norm(value.detach(), dim=-1)
        unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1)
        return _LowerCholesky().check(value) & unit_row_norm


class _Square(Constraint):
    """
    Constrain to square matrices.
    """

    event_dim = 2

    def check(self, value):
        return torch.full(
            size=value.shape[:-2],
            fill_value=(value.shape[-2] == value.shape[-1]),
            dtype=torch.bool,
            device=value.device,
        )


class _Symmetric(_Square):
    """
    Constrain to Symmetric square matrices.
    """

    def check(self, value):
        square_check = super().check(value)
        if not square_check.all():
            return square_check
        return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)


class _PositiveSemidefinite(_Symmetric):
    """
    Constrain to positive-semidefinite matrices.
    """

    def check(self, value):
        sym_check = super().check(value)
        if not sym_check.all():
            return sym_check
        return torch.linalg.eigvalsh(value).ge(0).all(-1)


class _PositiveDefinite(_Symmetric):
    """
    Constrain to positive-definite matrices.
    """

    def check(self, value):
        sym_check = super().check(value)
        if not sym_check.all():
            return sym_check
        return torch.linalg.cholesky_ex(value).info.eq(0)


class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self) -> bool:  # type: ignore[override]
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self) -> int:  # type: ignore[override]
        return max(c.event_dim for c in self.cseq)

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length  # avoid += for jit compat
        return torch.cat(checks, self.dim)


class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    in a way compatible with :func:`torch.stack`.
    """

    def __init__(self, cseq, dim=0):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self) -> bool:  # type: ignore[override]
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self) -> int:  # type: ignore[override]
        dim = max(c.event_dim for c in self.cseq)
        if self.dim + dim < 0:
            dim += 1
        return dim

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        return torch.stack(
            [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
        )


# Public interface.
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
real_vector = independent(real, 1)
positive = _GreaterThan(0.0)
nonnegative = _GreaterThanEq(0.0)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
multinomial = _Multinomial
unit_interval = _Interval(0.0, 1.0)
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack
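
A brief orientation to this public interface (an illustrative sketch, not part
of the file): parameterless constraints are exposed as singleton instances and
parameterized ones as classes, and every constraint answers membership queries
through ``check``::

    import torch
    from torch.distributions import constraints

    print(constraints.positive.check(torch.tensor([0.5, -1.0])))      # [True, False]
    print(constraints.simplex.check(torch.tensor([0.2, 0.3, 0.5])))   # True
    # Parameterized constraints are constructed on demand:
    print(constraints.interval(-1.0, 1.0).check(torch.tensor([0.0, 2.0])))  # [True, False]
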
239
venv/Lib/site-packages/torch/distributions/continuous_bernoulli.py
Normal file
@@ -0,0 +1,239 @@
# mypy: allow-untyped-defs
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import (
|
||||
broadcast_all,
|
||||
clamp_probs,
|
||||
lazy_property,
|
||||
logits_to_probs,
|
||||
probs_to_logits,
|
||||
)
|
||||
from torch.nn.functional import binary_cross_entropy_with_logits
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["ContinuousBernoulli"]
|
||||
|
||||
|
||||
class ContinuousBernoulli(ExponentialFamily):
|
||||
r"""
|
||||
Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
|
||||
or :attr:`logits` (but not both).
|
||||
|
||||
The distribution is supported in [0, 1] and parameterized by 'probs' (in
|
||||
(0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
|
||||
does not correspond to a probability and 'logits' does not correspond to
|
||||
log-odds, but the same names are used due to the similarity with the
|
||||
Bernoulli. See [1] for more details.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = ContinuousBernoulli(torch.tensor([0.3]))
|
||||
>>> m.sample()
|
||||
tensor([ 0.2538])
|
||||
|
||||
Args:
|
||||
probs (Number, Tensor): (0,1) valued parameters
|
||||
logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
|
||||
|
||||
[1] The continuous Bernoulli: fixing a pervasive error in variational
|
||||
autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
|
||||
https://arxiv.org/abs/1907.06845
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||
support = constraints.unit_interval
|
||||
_mean_carrier_measure = 0
|
||||
has_rsample = True
|
||||
|
||||
def __init__(
|
||||
self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
|
||||
) -> None:
|
||||
if (probs is None) == (logits is None):
|
||||
raise ValueError(
|
||||
"Either `probs` or `logits` must be specified, but not both."
|
||||
)
|
||||
if probs is not None:
|
||||
is_scalar = isinstance(probs, _Number)
|
||||
(self.probs,) = broadcast_all(probs)
|
||||
# validate 'probs' here if necessary as it is later clamped for numerical stability
|
||||
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
|
||||
if validate_args is not None:
|
||||
if not self.arg_constraints["probs"].check(self.probs).all():
|
||||
raise ValueError("The parameter probs has invalid values")
|
||||
self.probs = clamp_probs(self.probs)
|
||||
else:
|
||||
is_scalar = isinstance(logits, _Number)
|
||||
(self.logits,) = broadcast_all(logits)
|
||||
self._param = self.probs if probs is not None else self.logits
|
||||
if is_scalar:
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self._param.size()
|
||||
self._lims = lims
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(ContinuousBernoulli, _instance)
|
||||
new._lims = self._lims
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
if "probs" in self.__dict__:
|
||||
new.probs = self.probs.expand(batch_shape)
|
||||
new._param = new.probs
|
||||
if "logits" in self.__dict__:
|
||||
new.logits = self.logits.expand(batch_shape)
|
||||
new._param = new.logits
|
||||
super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._param.new(*args, **kwargs)
|
||||
|
||||
def _outside_unstable_region(self):
|
||||
return torch.max(
|
||||
torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1])
|
||||
)
|
||||
|
||||
def _cut_probs(self):
|
||||
return torch.where(
|
||||
self._outside_unstable_region(),
|
||||
self.probs,
|
||||
self._lims[0] * torch.ones_like(self.probs),
|
||||
)
|
||||
|
||||
def _cont_bern_log_norm(self):
|
||||
"""computes the log normalizing constant as a function of the 'probs' parameter"""
|
||||
cut_probs = self._cut_probs()
|
||||
cut_probs_below_half = torch.where(
|
||||
torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs)
|
||||
)
|
||||
cut_probs_above_half = torch.where(
|
||||
torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs)
|
||||
)
|
||||
log_norm = torch.log(
|
||||
torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))
|
||||
) - torch.where(
|
||||
torch.le(cut_probs, 0.5),
|
||||
torch.log1p(-2.0 * cut_probs_below_half),
|
||||
torch.log(2.0 * cut_probs_above_half - 1.0),
|
||||
)
|
||||
x = torch.pow(self.probs - 0.5, 2)
|
||||
taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
|
||||
return torch.where(self._outside_unstable_region(), log_norm, taylor)
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
cut_probs = self._cut_probs()
|
||||
mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
|
||||
torch.log1p(-cut_probs) - torch.log(cut_probs)
|
||||
)
|
||||
x = self.probs - 0.5
|
||||
taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
|
||||
return torch.where(self._outside_unstable_region(), mus, taylor)
|
||||
|
||||
@property
|
||||
def stddev(self) -> Tensor:
|
||||
return torch.sqrt(self.variance)
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
cut_probs = self._cut_probs()
|
||||
vars = cut_probs * (cut_probs - 1.0) / torch.pow(
|
||||
1.0 - 2.0 * cut_probs, 2
|
||||
) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
|
||||
x = torch.pow(self.probs - 0.5, 2)
|
||||
taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x
|
||||
return torch.where(self._outside_unstable_region(), vars, taylor)
|
||||
|
||||
@lazy_property
|
||||
def logits(self) -> Tensor:
|
||||
return probs_to_logits(self.probs, is_binary=True)
|
||||
|
||||
@lazy_property
|
||||
def probs(self) -> Tensor:
|
||||
return clamp_probs(logits_to_probs(self.logits, is_binary=True))
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._param.size()
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
|
||||
with torch.no_grad():
|
||||
return self.icdf(u)
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
|
||||
return self.icdf(u)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
logits, value = broadcast_all(self.logits, value)
|
||||
return (
|
||||
-binary_cross_entropy_with_logits(logits, value, reduction="none")
|
||||
+ self._cont_bern_log_norm()
|
||||
)
|
||||
|
||||
def cdf(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
cut_probs = self._cut_probs()
|
||||
cdfs = (
|
||||
torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
|
||||
+ cut_probs
|
||||
- 1.0
|
||||
) / (2.0 * cut_probs - 1.0)
|
||||
unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
|
||||
return torch.where(
|
||||
torch.le(value, 0.0),
|
||||
torch.zeros_like(value),
|
||||
torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs),
|
||||
)
|
||||
|
||||
def icdf(self, value):
|
||||
cut_probs = self._cut_probs()
|
||||
return torch.where(
|
||||
self._outside_unstable_region(),
|
||||
(
|
||||
torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
|
||||
- torch.log1p(-cut_probs)
|
||||
)
|
||||
/ (torch.log(cut_probs) - torch.log1p(-cut_probs)),
|
||||
value,
|
||||
)
|
||||
|
||||
def entropy(self):
|
||||
log_probs0 = torch.log1p(-self.probs)
|
||||
log_probs1 = torch.log(self.probs)
|
||||
return (
|
||||
self.mean * (log_probs0 - log_probs1)
|
||||
- self._cont_bern_log_norm()
|
||||
- log_probs0
|
||||
)
|
||||
|
||||
@property
|
||||
def _natural_params(self) -> tuple[Tensor]:
|
||||
return (self.logits,)
|
||||
|
||||
def _log_normalizer(self, x):
|
||||
"""computes the log normalizing constant as a function of the natural parameter"""
|
||||
out_unst_reg = torch.max(
|
||||
torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5)
|
||||
)
|
||||
cut_nat_params = torch.where(
|
||||
out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
|
||||
)
|
||||
log_norm = torch.log(
|
||||
torch.abs(torch.special.expm1(cut_nat_params))
|
||||
) - torch.log(torch.abs(cut_nat_params))
|
||||
taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
|
||||
return torch.where(out_unst_reg, log_norm, taylor)
|
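
Because ``rsample`` above pushes a uniform draw through ``icdf``, gradients
flow from samples back to the parameter. A minimal sketch (illustrative only,
not part of the file; the numbers are arbitrary)::

    import torch
    from torch.distributions import ContinuousBernoulli

    probs = torch.tensor([0.3], requires_grad=True)
    m = ContinuousBernoulli(probs)
    x = m.rsample((1000,))   # reparameterized: icdf(u) with u ~ Uniform(0, 1)
    x.mean().backward()      # gradient reaches `probs` through the samples
    print(probs.grad)
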
128
venv/Lib/site-packages/torch/distributions/dirichlet.py
Normal file
@@ -0,0 +1,128 @@
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["Dirichlet"]
|
||||
|
||||
|
||||
# This helper is exposed for testing.
|
||||
def _Dirichlet_backward(x, concentration, grad_output):
|
||||
total = concentration.sum(-1, True).expand_as(concentration)
|
||||
grad = torch._dirichlet_grad(x, concentration, total)
|
||||
return grad * (grad_output - (x * grad_output).sum(-1, True))
|
||||
|
||||
|
||||
class _Dirichlet(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, concentration):
|
||||
x = torch._sample_dirichlet(concentration)
|
||||
ctx.save_for_backward(x, concentration)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
x, concentration = ctx.saved_tensors
|
||||
return _Dirichlet_backward(x, concentration, grad_output)
|
||||
|
||||
|
||||
class Dirichlet(ExponentialFamily):
|
||||
r"""
|
||||
Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Dirichlet(torch.tensor([0.5, 0.5]))
|
||||
>>> m.sample() # Dirichlet distributed with concentration [0.5, 0.5]
|
||||
tensor([ 0.1046, 0.8954])
|
||||
|
||||
Args:
|
||||
concentration (Tensor): concentration parameter of the distribution
|
||||
(often referred to as alpha)
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"concentration": constraints.independent(constraints.positive, 1)
|
||||
}
|
||||
support = constraints.simplex
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, concentration, validate_args=None):
|
||||
if concentration.dim() < 1:
|
||||
raise ValueError(
|
||||
"`concentration` parameter must be at least one-dimensional."
|
||||
)
|
||||
self.concentration = concentration
|
||||
batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
|
||||
super().__init__(batch_shape, event_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Dirichlet, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.concentration = self.concentration.expand(batch_shape + self.event_shape)
|
||||
super(Dirichlet, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def rsample(self, sample_shape: _size = ()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
concentration = self.concentration.expand(shape)
|
||||
return _Dirichlet.apply(concentration)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return (
|
||||
torch.xlogy(self.concentration - 1.0, value).sum(-1)
|
||||
+ torch.lgamma(self.concentration.sum(-1))
|
||||
- torch.lgamma(self.concentration).sum(-1)
|
||||
)
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.concentration / self.concentration.sum(-1, True)
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
concentrationm1 = (self.concentration - 1).clamp(min=0.0)
|
||||
mode = concentrationm1 / concentrationm1.sum(-1, True)
|
||||
mask = (self.concentration < 1).all(dim=-1)
|
||||
mode[mask] = torch.nn.functional.one_hot(
|
||||
mode[mask].argmax(dim=-1), concentrationm1.shape[-1]
|
||||
).to(mode)
|
||||
return mode
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
con0 = self.concentration.sum(-1, True)
|
||||
return (
|
||||
self.concentration
|
||||
* (con0 - self.concentration)
|
||||
/ (con0.pow(2) * (con0 + 1))
|
||||
)
|
||||
|
||||
def entropy(self):
|
||||
k = self.concentration.size(-1)
|
||||
a0 = self.concentration.sum(-1)
|
||||
return (
|
||||
torch.lgamma(self.concentration).sum(-1)
|
||||
- torch.lgamma(a0)
|
||||
- (k - a0) * torch.digamma(a0)
|
||||
- ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
|
||||
)
|
||||
|
||||
@property
|
||||
def _natural_params(self) -> tuple[Tensor]:
|
||||
return (self.concentration,)
|
||||
|
||||
def _log_normalizer(self, x):
|
||||
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
|
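
A small sketch (illustrative only, not part of the file) showing that
``Dirichlet.rsample`` is differentiable with respect to ``concentration``
through the custom ``_Dirichlet`` backward defined above; the objective and
values are arbitrary::

    import torch
    from torch.distributions import Dirichlet

    concentration = torch.tensor([2.0, 5.0], requires_grad=True)
    d = Dirichlet(concentration)
    x = d.rsample((8,))
    # Toy objective: gradients reach `concentration` via _Dirichlet.backward
    x[..., 0].mean().backward()
    print(concentration.grad)
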
341
venv/Lib/site-packages/torch/distributions/distribution.py
Normal file
@@ -0,0 +1,341 @@
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.utils import lazy_property
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["Distribution"]
|
||||
|
||||
|
||||
class Distribution:
|
||||
r"""
|
||||
Distribution is the abstract base class for probability distributions.
|
||||
"""
|
||||
|
||||
has_rsample = False
|
||||
has_enumerate_support = False
|
||||
_validate_args = __debug__
|
||||
|
||||
@staticmethod
|
||||
def set_default_validate_args(value: bool) -> None:
|
||||
"""
|
||||
Sets whether validation is enabled or disabled.
|
||||
|
||||
The default behavior mimics Python's ``assert`` statement: validation
|
||||
is on by default, but is disabled if Python is run in optimized mode
|
||||
(via ``python -O``). Validation may be expensive, so you may want to
|
||||
disable it once a model is working.
|
||||
|
||||
Args:
|
||||
value (bool): Whether to enable validation.
|
||||
"""
|
||||
if value not in [True, False]:
|
||||
raise ValueError
|
||||
Distribution._validate_args = value
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_shape: torch.Size = torch.Size(),
|
||||
event_shape: torch.Size = torch.Size(),
|
||||
validate_args: Optional[bool] = None,
|
||||
):
|
||||
self._batch_shape = batch_shape
|
||||
self._event_shape = event_shape
|
||||
if validate_args is not None:
|
||||
self._validate_args = validate_args
|
||||
if self._validate_args:
|
||||
try:
|
||||
arg_constraints = self.arg_constraints
|
||||
except NotImplementedError:
|
||||
arg_constraints = {}
|
||||
warnings.warn(
|
||||
f"{self.__class__} does not define `arg_constraints`. "
|
||||
+ "Please set `arg_constraints = {}` or initialize the distribution "
|
||||
+ "with `validate_args=False` to turn off validation."
|
||||
)
|
||||
for param, constraint in arg_constraints.items():
|
||||
if constraints.is_dependent(constraint):
|
||||
continue # skip constraints that cannot be checked
|
||||
if param not in self.__dict__ and isinstance(
|
||||
getattr(type(self), param), lazy_property
|
||||
):
|
||||
continue # skip checking lazily-constructed args
|
||||
value = getattr(self, param)
|
||||
valid = constraint.check(value)
|
||||
if not torch._is_all_true(valid):
|
||||
raise ValueError(
|
||||
f"Expected parameter {param} "
|
||||
f"({type(value).__name__} of shape {tuple(value.shape)}) "
|
||||
f"of distribution {repr(self)} "
|
||||
f"to satisfy the constraint {repr(constraint)}, "
|
||||
f"but found invalid values:\n{value}"
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def expand(self, batch_shape: _size, _instance=None):
|
||||
"""
|
||||
Returns a new distribution instance (or populates an existing instance
|
||||
provided by a derived class) with batch dimensions expanded to
|
||||
`batch_shape`. This method calls :class:`~torch.Tensor.expand` on
|
||||
the distribution's parameters. As such, this does not allocate new
|
||||
memory for the expanded distribution instance. Additionally,
|
||||
this does not repeat any args checking or parameter broadcasting in
|
||||
`__init__.py`, when an instance is first created.
|
||||
|
||||
Args:
|
||||
batch_shape (torch.Size): the desired expanded size.
|
||||
_instance: new instance provided by subclasses that
|
||||
need to override `.expand`.
|
||||
|
||||
Returns:
|
||||
New distribution instance with batch dimensions expanded to
|
||||
`batch_size`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def batch_shape(self) -> torch.Size:
|
||||
"""
|
||||
Returns the shape over which parameters are batched.
|
||||
"""
|
||||
return self._batch_shape
|
||||
|
||||
@property
|
||||
def event_shape(self) -> torch.Size:
|
||||
"""
|
||||
Returns the shape of a single sample (without batching).
|
||||
"""
|
||||
return self._event_shape
|
||||
|
||||
@property
|
||||
def arg_constraints(self) -> dict[str, constraints.Constraint]:
|
||||
"""
|
||||
Returns a dictionary from argument names to
|
||||
:class:`~torch.distributions.constraints.Constraint` objects that
|
||||
should be satisfied by each argument of this distribution. Args that
|
||||
are not tensors need not appear in this dict.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def support(self) -> Optional[constraints.Constraint]:
|
||||
"""
|
||||
Returns a :class:`~torch.distributions.constraints.Constraint` object
|
||||
representing this distribution's support.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
"""
|
||||
Returns the mean of the distribution.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
"""
|
||||
Returns the mode of the distribution.
|
||||
"""
|
||||
raise NotImplementedError(f"{self.__class__} does not implement mode")
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
"""
|
||||
Returns the variance of the distribution.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def stddev(self) -> Tensor:
|
||||
"""
|
||||
Returns the standard deviation of the distribution.
|
||||
"""
|
||||
return self.variance.sqrt()
|
||||
|
||||
def sample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
"""
|
||||
Generates a sample_shape shaped sample or sample_shape shaped batch of
|
||||
samples if the distribution parameters are batched.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self.rsample(sample_shape)
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
"""
|
||||
Generates a sample_shape shaped reparameterized sample or sample_shape
|
||||
shaped batch of reparameterized samples if the distribution parameters
|
||||
are batched.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@deprecated(
|
||||
"`sample_n(n)` will be deprecated. Use `sample((n,))` instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
def sample_n(self, n: int) -> Tensor:
|
||||
"""
|
||||
Generates n samples or n batches of samples if the distribution
|
||||
parameters are batched.
|
||||
"""
|
||||
return self.sample(torch.Size((n,)))
|
||||
|
||||
def log_prob(self, value: Tensor) -> Tensor:
|
||||
"""
|
||||
Returns the log of the probability density/mass function evaluated at
|
||||
`value`.
|
||||
|
||||
Args:
|
||||
value (Tensor):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cdf(self, value: Tensor) -> Tensor:
|
||||
"""
|
||||
Returns the cumulative density/mass function evaluated at
|
||||
`value`.
|
||||
|
||||
Args:
|
||||
value (Tensor):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def icdf(self, value: Tensor) -> Tensor:
|
||||
"""
|
||||
Returns the inverse cumulative density/mass function evaluated at
|
||||
`value`.
|
||||
|
||||
Args:
|
||||
value (Tensor):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def enumerate_support(self, expand: bool = True) -> Tensor:
|
||||
"""
|
||||
Returns tensor containing all values supported by a discrete
|
||||
distribution. The result will enumerate over dimension 0, so the shape
|
||||
of the result will be `(cardinality,) + batch_shape + event_shape`
|
||||
(where `event_shape = ()` for univariate distributions).
|
||||
|
||||
Note that this enumerates over all batched tensors in lock-step
|
||||
`[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
|
||||
along dim 0, but with the remaining batch dimensions being
|
||||
singleton dimensions, `[[0], [1], ..`.
|
||||
|
||||
To iterate over the full Cartesian product use
|
||||
`itertools.product(m.enumerate_support())`.
|
||||
|
||||
Args:
|
||||
expand (bool): whether to expand the support over the
|
||||
batch dims to match the distribution's `batch_shape`.
|
||||
|
||||
Returns:
|
||||
Tensor iterating over dimension 0.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def entropy(self) -> Tensor:
|
||||
"""
|
||||
Returns entropy of distribution, batched over batch_shape.
|
||||
|
||||
Returns:
|
||||
Tensor of shape batch_shape.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def perplexity(self) -> Tensor:
|
||||
"""
|
||||
Returns perplexity of distribution, batched over batch_shape.
|
||||
|
||||
Returns:
|
||||
Tensor of shape batch_shape.
|
||||
"""
|
||||
return torch.exp(self.entropy())
|
||||
|
||||
def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size:
|
||||
"""
|
||||
Returns the size of the sample returned by the distribution, given
|
||||
a `sample_shape`. Note, that the batch and event shapes of a distribution
|
||||
instance are fixed at the time of construction. If this is empty, the
|
||||
returned shape is upcast to (1,).
|
||||
|
||||
Args:
|
||||
sample_shape (torch.Size): the size of the sample to be drawn.
|
||||
"""
|
||||
if not isinstance(sample_shape, torch.Size):
|
||||
sample_shape = torch.Size(sample_shape)
|
||||
return torch.Size(sample_shape + self._batch_shape + self._event_shape)
|
||||
|
||||
def _validate_sample(self, value: Tensor) -> None:
|
||||
"""
|
||||
Argument validation for distribution methods such as `log_prob`,
|
||||
`cdf` and `icdf`. The rightmost dimensions of a value to be
|
||||
scored via these methods must agree with the distribution's batch
|
||||
and event shapes.
|
||||
|
||||
Args:
|
||||
value (Tensor): the tensor whose log probability is to be
|
||||
computed by the `log_prob` method.
|
||||
Raises
|
||||
ValueError: when the rightmost dimensions of `value` do not match the
|
||||
distribution's batch and event shapes.
|
||||
"""
|
||||
if not isinstance(value, torch.Tensor):
|
||||
raise ValueError("The value argument to log_prob must be a Tensor")
|
||||
|
||||
event_dim_start = len(value.size()) - len(self._event_shape)
|
||||
if value.size()[event_dim_start:] != self._event_shape:
|
||||
raise ValueError(
|
||||
f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
|
||||
)
|
||||
|
||||
actual_shape = value.size()
|
||||
expected_shape = self._batch_shape + self._event_shape
|
||||
for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
|
||||
if i != 1 and j != 1 and i != j:
|
||||
raise ValueError(
|
||||
f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
|
||||
)
|
||||
try:
|
||||
support = self.support
|
||||
except NotImplementedError:
|
||||
warnings.warn(
|
||||
f"{self.__class__} does not define `support` to enable "
|
||||
+ "sample validation. Please initialize the distribution with "
|
||||
+ "`validate_args=False` to turn off validation."
|
||||
)
|
||||
return
|
||||
assert support is not None
|
||||
valid = support.check(value)
|
||||
if not torch._is_all_true(valid):
|
||||
raise ValueError(
|
||||
"Expected value argument "
|
||||
f"({type(value).__name__} of shape {tuple(value.shape)}) "
|
||||
f"to be within the support ({repr(support)}) "
|
||||
f"of the distribution {repr(self)}, "
|
||||
f"but found invalid values:\n{value}"
|
||||
)
|
||||
|
||||
def _get_checked_instance(self, cls, _instance=None):
|
||||
if _instance is None and type(self).__init__ != cls.__init__:
|
||||
raise NotImplementedError(
|
||||
f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
|
||||
"must also define a custom .expand() method."
|
||||
)
|
||||
return self.__new__(type(self)) if _instance is None else _instance
|
||||
|
||||
def __repr__(self) -> str:
|
||||
param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
|
||||
args_string = ", ".join(
|
||||
[
|
||||
f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
|
||||
for p in param_names
|
||||
]
|
||||
)
|
||||
return self.__class__.__name__ + "(" + args_string + ")"
|
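
The shape contract documented above (``batch_shape``, ``event_shape``, and
``_extended_shape`` returning ``sample_shape + batch_shape + event_shape``)
can be seen concretely; an illustrative sketch using the Dirichlet
distribution from this same commit::

    import torch
    from torch.distributions import Dirichlet

    d = Dirichlet(torch.ones(3, 4))     # batch_shape=(3,), event_shape=(4,)
    print(d.batch_shape, d.event_shape)
    x = d.sample((2,))
    print(x.shape)                      # torch.Size([2, 3, 4])
    print(d.log_prob(x).shape)          # torch.Size([2, 3]): one value per event
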
65
venv/Lib/site-packages/torch/distributions/exp_family.py
Normal file
@@ -0,0 +1,65 @@
# mypy: allow-untyped-defs
import torch
from torch import Tensor
from torch.distributions.distribution import Distribution


__all__ = ["ExponentialFamily"]


class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form defined below

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
    :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
    measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which belong
        to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL
        divergence methods. We use this class to compute the entropy and KL divergence using the AD
        framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
        Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self) -> tuple[Tensor, ...]:
        """
        Abstract method for natural parameters. Returns a tuple of Tensors based
        on the distribution
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method for log normalizer function. Returns a log normalizer based on
        the distribution and input
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self) -> float:
        """
        Abstract method for expected carrier measure, which is required for computing
        entropy.
        """
        raise NotImplementedError

    def entropy(self):
        """
        Method to compute the entropy using Bregman divergence of the log normalizer.
        """
        result = -self._mean_carrier_measure
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
        result += lg_normal
        for np, g in zip(nparams, gradients):
            result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
        return result
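
The Bregman-divergence route above can be sanity-checked against a closed
form. An illustrative sketch (not part of the file) using the Exponential
distribution defined later in this commit, whose natural parameter is
:math:`\theta = -\text{rate}` with log normalizer :math:`F(\theta) = -\log(-\theta)`;
calling the base-class ``entropy`` explicitly is only a demonstration trick::

    import torch
    from torch.distributions import Exponential

    rate = torch.tensor([0.5, 1.0, 2.0])
    d = Exponential(rate)
    print(d.entropy())                      # closed form: 1 - log(rate)
    print(1.0 - torch.log(rate))            # same values
    print(super(Exponential, d).entropy())  # generic Bregman-divergence path
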
87
venv/Lib/site-packages/torch/distributions/exponential.py
Normal file
@@ -0,0 +1,87 @@
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["Exponential"]
|
||||
|
||||
|
||||
class Exponential(ExponentialFamily):
|
||||
r"""
|
||||
Creates a Exponential distribution parameterized by :attr:`rate`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Exponential(torch.tensor([1.0]))
|
||||
>>> m.sample() # Exponential distributed with rate=1
|
||||
tensor([ 0.1046])
|
||||
|
||||
Args:
|
||||
rate (float or Tensor): rate = 1 / scale of the distribution
|
||||
"""
|
||||
|
||||
arg_constraints = {"rate": constraints.positive}
|
||||
support = constraints.nonnegative
|
||||
has_rsample = True
|
||||
_mean_carrier_measure = 0
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.rate.reciprocal()
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return torch.zeros_like(self.rate)
|
||||
|
||||
@property
|
||||
def stddev(self) -> Tensor:
|
||||
return self.rate.reciprocal()
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.rate.pow(-2)
|
||||
|
||||
def __init__(self, rate, validate_args=None):
|
||||
(self.rate,) = broadcast_all(rate)
|
||||
batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Exponential, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.rate = self.rate.expand(batch_shape)
|
||||
super(Exponential, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
return self.rate.new(shape).exponential_() / self.rate
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return self.rate.log() - self.rate * value
|
||||
|
||||
def cdf(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return 1 - torch.exp(-self.rate * value)
|
||||
|
||||
def icdf(self, value):
|
||||
return -torch.log1p(-value) / self.rate
|
||||
|
||||
def entropy(self):
|
||||
return 1.0 - torch.log(self.rate)
|
||||
|
||||
@property
|
||||
def _natural_params(self) -> tuple[Tensor]:
|
||||
return (-self.rate,)
|
||||
|
||||
def _log_normalizer(self, x):
|
||||
return -torch.log(-x)
|
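
A short sketch (illustrative only, not part of the file) exercising the
methods above: the sample mean approaches ``1 / rate``, and ``cdf`` and
``icdf`` are inverses of each other::

    import torch
    from torch.distributions import Exponential

    d = Exponential(torch.tensor(2.0))
    x = d.rsample((10000,))
    print(x.mean())              # close to 1 / rate = 0.5
    v = torch.tensor([0.1, 1.0, 3.0])
    print(d.icdf(d.cdf(v)))      # approximately v
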
100
venv/Lib/site-packages/torch/distributions/fishersnedecor.py
Normal file
@@ -0,0 +1,100 @@
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import nan, Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.gamma import Gamma
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["FisherSnedecor"]
|
||||
|
||||
|
||||
class FisherSnedecor(Distribution):
|
||||
r"""
|
||||
Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
|
||||
>>> m.sample() # Fisher-Snedecor-distributed with df1=1 and df2=2
|
||||
tensor([ 0.2453])
|
||||
|
||||
Args:
|
||||
df1 (float or Tensor): degrees of freedom parameter 1
|
||||
df2 (float or Tensor): degrees of freedom parameter 2
|
||||
"""
|
||||
|
||||
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
|
||||
support = constraints.positive
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, df1, df2, validate_args=None):
|
||||
self.df1, self.df2 = broadcast_all(df1, df2)
|
||||
self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
|
||||
self._gamma2 = Gamma(self.df2 * 0.5, self.df2)
|
||||
|
||||
if isinstance(df1, _Number) and isinstance(df2, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self.df1.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(FisherSnedecor, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.df1 = self.df1.expand(batch_shape)
|
||||
new.df2 = self.df2.expand(batch_shape)
|
||||
new._gamma1 = self._gamma1.expand(batch_shape)
|
||||
new._gamma2 = self._gamma2.expand(batch_shape)
|
||||
super(FisherSnedecor, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
df2 = self.df2.clone(memory_format=torch.contiguous_format)
|
||||
df2[df2 <= 2] = nan
|
||||
return df2 / (df2 - 2)
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
mode = (self.df1 - 2) / self.df1 * self.df2 / (self.df2 + 2)
|
||||
mode[self.df1 <= 2] = nan
|
||||
return mode
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
df2 = self.df2.clone(memory_format=torch.contiguous_format)
|
||||
df2[df2 <= 4] = nan
|
||||
return (
|
||||
2
|
||||
* df2.pow(2)
|
||||
* (self.df1 + df2 - 2)
|
||||
/ (self.df1 * (df2 - 2).pow(2) * (df2 - 4))
|
||||
)
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size(())) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
# X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
|
||||
# Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
|
||||
X1 = self._gamma1.rsample(sample_shape).view(shape)
|
||||
X2 = self._gamma2.rsample(sample_shape).view(shape)
|
||||
tiny = torch.finfo(X2.dtype).tiny
|
||||
X2.clamp_(min=tiny)
|
||||
Y = X1 / X2
|
||||
Y.clamp_(min=tiny)
|
||||
return Y
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
ct1 = self.df1 * 0.5
|
||||
ct2 = self.df2 * 0.5
|
||||
ct3 = self.df1 / self.df2
|
||||
t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
|
||||
t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value)
|
||||
t3 = (ct1 + ct2) * torch.log1p(ct3 * value)
|
||||
return t1 + t2 - t3
|
111
venv/Lib/site-packages/torch/distributions/gamma.py
Normal file
@@ -0,0 +1,111 @@
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["Gamma"]
|
||||
|
||||
|
||||
def _standard_gamma(concentration):
|
||||
return torch._standard_gamma(concentration)
|
||||
|
||||
|
||||
class Gamma(ExponentialFamily):
|
||||
r"""
|
||||
Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
|
||||
>>> m.sample() # Gamma distributed with concentration=1 and rate=1
|
||||
tensor([ 0.1046])
|
||||
|
||||
Args:
|
||||
concentration (float or Tensor): shape parameter of the distribution
|
||||
(often referred to as alpha)
|
||||
rate (float or Tensor): rate parameter of the distribution
|
||||
(often referred to as beta), rate = 1 / scale
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"concentration": constraints.positive,
|
||||
"rate": constraints.positive,
|
||||
}
|
||||
support = constraints.nonnegative
|
||||
has_rsample = True
|
||||
_mean_carrier_measure = 0
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.concentration / self.rate
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return ((self.concentration - 1) / self.rate).clamp(min=0)
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.concentration / self.rate.pow(2)
|
||||
|
||||
def __init__(self, concentration, rate, validate_args=None):
|
||||
self.concentration, self.rate = broadcast_all(concentration, rate)
|
||||
if isinstance(concentration, _Number) and isinstance(rate, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self.concentration.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Gamma, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.concentration = self.concentration.expand(batch_shape)
|
||||
new.rate = self.rate.expand(batch_shape)
|
||||
super(Gamma, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(
|
||||
shape
|
||||
)
|
||||
value.detach().clamp_(
|
||||
min=torch.finfo(value.dtype).tiny
|
||||
) # do not record in autograd graph
|
||||
return value
|
||||
|
||||
def log_prob(self, value):
|
||||
value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return (
|
||||
torch.xlogy(self.concentration, self.rate)
|
||||
+ torch.xlogy(self.concentration - 1, value)
|
||||
- self.rate * value
|
||||
- torch.lgamma(self.concentration)
|
||||
)
|
||||
|
||||
def entropy(self):
|
||||
return (
|
||||
self.concentration
|
||||
- torch.log(self.rate)
|
||||
+ torch.lgamma(self.concentration)
|
||||
+ (1.0 - self.concentration) * torch.digamma(self.concentration)
|
||||
)
|
||||
|
||||
@property
|
||||
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
||||
return (self.concentration - 1, -self.rate)
|
||||
|
||||
def _log_normalizer(self, x, y):
|
||||
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
|
||||
|
||||
def cdf(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return torch.special.gammainc(self.concentration, self.rate * value)
|
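
``Gamma`` sets ``has_rsample = True`` because ``torch._standard_gamma``
provides a pathwise derivative with respect to the concentration. A minimal
sketch (illustrative only, not part of the file); the printed gradients only
approximate the analytic derivatives of the mean, ``1/rate`` and
``-concentration/rate**2``::

    import torch
    from torch.distributions import Gamma

    concentration = torch.tensor(3.0, requires_grad=True)
    rate = torch.tensor(2.0, requires_grad=True)
    d = Gamma(concentration, rate)
    x = d.rsample((5000,))
    x.mean().backward()   # gradients flow through the gamma sampler
    print(concentration.grad, rate.grad)
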
131
venv/Lib/site-packages/torch/distributions/geometric.py
Normal file
@@ -0,0 +1,131 @@
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import (
|
||||
broadcast_all,
|
||||
lazy_property,
|
||||
logits_to_probs,
|
||||
probs_to_logits,
|
||||
)
|
||||
from torch.nn.functional import binary_cross_entropy_with_logits
|
||||
from torch.types import _Number
|
||||
|
||||
|
||||
__all__ = ["Geometric"]
|
||||
|
||||
|
||||
class Geometric(Distribution):
|
||||
r"""
|
||||
Creates a Geometric distribution parameterized by :attr:`probs`,
|
||||
where :attr:`probs` is the probability of success of Bernoulli trials.
|
||||
|
||||
.. math::
|
||||
|
||||
P(X=k) = (1-p)^{k} p, k = 0, 1, ...
|
||||
|
||||
.. note::
|
||||
:func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success
|
||||
hence draws samples in :math:`\{0, 1, \ldots\}`, whereas
|
||||
:func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Geometric(torch.tensor([0.3]))
|
||||
>>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0
|
||||
tensor([ 2.])
|
||||
|
||||
Args:
|
||||
probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1]
|
||||
logits (Number, Tensor): the log-odds of sampling `1`.
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||
support = constraints.nonnegative_integer
|
||||
|
||||
def __init__(self, probs=None, logits=None, validate_args=None):
|
||||
if (probs is None) == (logits is None):
|
||||
raise ValueError(
|
||||
"Either `probs` or `logits` must be specified, but not both."
|
||||
)
|
||||
if probs is not None:
|
||||
(self.probs,) = broadcast_all(probs)
|
||||
else:
|
||||
(self.logits,) = broadcast_all(logits)
|
||||
probs_or_logits = probs if probs is not None else logits
|
||||
if isinstance(probs_or_logits, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = probs_or_logits.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
if self._validate_args and probs is not None:
|
||||
# Add an extra check beyond unit_interval
|
||||
value = self.probs
|
||||
valid = value > 0
|
||||
if not valid.all():
|
||||
invalid_value = value.data[~valid]
|
||||
raise ValueError(
|
||||
"Expected parameter probs "
|
||||
f"({type(value).__name__} of shape {tuple(value.shape)}) "
|
||||
f"of distribution {repr(self)} "
|
||||
f"to be positive but found invalid values:\n{invalid_value}"
|
||||
)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Geometric, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
if "probs" in self.__dict__:
|
||||
new.probs = self.probs.expand(batch_shape)
|
||||
if "logits" in self.__dict__:
|
||||
new.logits = self.logits.expand(batch_shape)
|
||||
super(Geometric, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return 1.0 / self.probs - 1.0
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return torch.zeros_like(self.probs)
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return (1.0 / self.probs - 1.0) / self.probs
|
||||
|
||||
@lazy_property
|
||||
def logits(self) -> Tensor:
|
||||
return probs_to_logits(self.probs, is_binary=True)
|
||||
|
||||
@lazy_property
|
||||
def probs(self) -> Tensor:
|
||||
return logits_to_probs(self.logits, is_binary=True)
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
tiny = torch.finfo(self.probs.dtype).tiny
|
||||
with torch.no_grad():
|
||||
if torch._C._get_tracing_state():
|
||||
# [JIT WORKAROUND] lack of support for .uniform_()
|
||||
u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
|
||||
u = u.clamp(min=tiny)
|
||||
else:
|
||||
u = self.probs.new(shape).uniform_(tiny, 1)
|
||||
return (u.log() / (-self.probs).log1p()).floor()
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
value, probs = broadcast_all(value, self.probs)
|
||||
probs = probs.clone(memory_format=torch.contiguous_format)
|
||||
probs[(probs == 1) & (value == 0)] = 0
|
||||
return value * (-probs).log1p() + self.probs.log()
|
||||
|
||||
def entropy(self):
|
||||
return (
|
||||
binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none")
|
||||
/ self.probs
|
||||
)
|
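
The note in the docstring above about the two counting conventions can be seen
directly; a small sketch (illustrative only, not part of the file)::

    import torch
    from torch.distributions import Geometric

    g = Geometric(torch.tensor(0.3))
    print(g.sample((5,)))                  # failures before first success: {0, 1, 2, ...}
    print(torch.empty(5).geometric_(0.3))  # trials up to first success: {1, 2, 3, ...}
    print(g.mean)                          # (1 - p) / p ~= 2.33
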
85
venv/Lib/site-packages/torch/distributions/gumbel.py
Normal file
@@ -0,0 +1,85 @@
# mypy: allow-untyped-defs
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import AffineTransform, ExpTransform
|
||||
from torch.distributions.uniform import Uniform
|
||||
from torch.distributions.utils import broadcast_all, euler_constant
|
||||
from torch.types import _Number
|
||||
|
||||
|
||||
__all__ = ["Gumbel"]
|
||||
|
||||
|
||||
class Gumbel(TransformedDistribution):
|
||||
r"""
|
||||
Samples from a Gumbel Distribution.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
|
||||
>>> m.sample() # sample from Gumbel distribution with loc=1, scale=2
|
||||
tensor([ 1.0124])
|
||||
|
||||
Args:
|
||||
loc (float or Tensor): Location parameter of the distribution
|
||||
scale (float or Tensor): Scale parameter of the distribution
|
||||
"""
|
||||
|
||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||
support = constraints.real
|
||||
|
||||
def __init__(self, loc, scale, validate_args=None):
|
||||
self.loc, self.scale = broadcast_all(loc, scale)
|
||||
finfo = torch.finfo(self.loc.dtype)
|
||||
if isinstance(loc, _Number) and isinstance(scale, _Number):
|
||||
base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
|
||||
else:
|
||||
base_dist = Uniform(
|
||||
torch.full_like(self.loc, finfo.tiny),
|
||||
torch.full_like(self.loc, 1 - finfo.eps),
|
||||
validate_args=validate_args,
|
||||
)
|
||||
transforms = [
|
||||
ExpTransform().inv,
|
||||
AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
|
||||
ExpTransform().inv,
|
||||
AffineTransform(loc=loc, scale=-self.scale),
|
||||
]
|
||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Gumbel, _instance)
|
||||
new.loc = self.loc.expand(batch_shape)
|
||||
new.scale = self.scale.expand(batch_shape)
|
||||
return super().expand(batch_shape, _instance=new)
|
||||
|
||||
# Explicitly defining the log probability function for Gumbel due to precision issues
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
y = (self.loc - value) / self.scale
|
||||
return (y - y.exp()) - self.scale.log()
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.loc + self.scale * euler_constant
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def stddev(self) -> Tensor:
|
||||
return (math.pi / math.sqrt(6)) * self.scale
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.stddev.pow(2)
|
||||
|
||||
def entropy(self):
|
||||
return self.scale.log() + (1 + euler_constant)
|
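Because ``Gumbel`` is built as a ``TransformedDistribution`` over a ``Uniform`` base,
``rsample`` is reparameterized and gradients flow back to ``loc`` and ``scale``. A
small sketch of that property (the sample count is an illustrative assumption)::

    import torch
    from torch.distributions import Gumbel

    loc = torch.tensor([1.0], requires_grad=True)
    scale = torch.tensor([2.0], requires_grad=True)
    m = Gumbel(loc, scale)
    x = m.rsample((50_000,))
    # Empirical mean approaches loc + scale * Euler-Mascheroni constant.
    x.mean().backward()  # gradients reach loc and scale through the transforms
    print(x.mean().item(), loc.grad, scale.grad)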
85
venv/Lib/site-packages/torch/distributions/half_cauchy.py
Normal file
@@ -0,0 +1,85 @@
# mypy: allow-untyped-defs
import math

import torch
from torch import inf, Tensor
from torch.distributions import constraints
from torch.distributions.cauchy import Cauchy
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AbsTransform


__all__ = ["HalfCauchy"]


class HalfCauchy(TransformedDistribution):
    r"""
    Creates a half-Cauchy distribution parameterized by `scale` where::

        X ~ Cauchy(0, scale)
        Y = |X| ~ HalfCauchy(scale)

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = HalfCauchy(torch.tensor([1.0]))
        >>> m.sample()  # half-cauchy distributed with scale=1
        tensor([ 2.3214])

    Args:
        scale (float or Tensor): scale of the full Cauchy distribution
    """

    arg_constraints = {"scale": constraints.positive}
    support = constraints.nonnegative
    has_rsample = True

    def __init__(self, scale, validate_args=None):
        base_dist = Cauchy(0, scale, validate_args=False)
        super().__init__(base_dist, AbsTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(HalfCauchy, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def scale(self) -> Tensor:
        return self.base_dist.scale

    @property
    def mean(self) -> Tensor:
        return torch.full(
            self._extended_shape(),
            math.inf,
            dtype=self.scale.dtype,
            device=self.scale.device,
        )

    @property
    def mode(self) -> Tensor:
        return torch.zeros_like(self.scale)

    @property
    def variance(self) -> Tensor:
        return self.base_dist.variance

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        value = torch.as_tensor(
            value, dtype=self.base_dist.scale.dtype, device=self.base_dist.scale.device
        )
        log_prob = self.base_dist.log_prob(value) + math.log(2)
        log_prob = torch.where(value >= 0, log_prob, -inf)
        return log_prob

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return 2 * self.base_dist.cdf(value) - 1

    def icdf(self, prob):
        return self.base_dist.icdf((prob + 1) / 2)

    def entropy(self):
        return self.base_dist.entropy() - math.log(2)
77
venv/Lib/site-packages/torch/distributions/half_normal.py
Normal file
@@ -0,0 +1,77 @@
# mypy: allow-untyped-defs
import math

import torch
from torch import inf, Tensor
from torch.distributions import constraints
from torch.distributions.normal import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AbsTransform


__all__ = ["HalfNormal"]


class HalfNormal(TransformedDistribution):
    r"""
    Creates a half-normal distribution parameterized by `scale` where::

        X ~ Normal(0, scale)
        Y = |X| ~ HalfNormal(scale)

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = HalfNormal(torch.tensor([1.0]))
        >>> m.sample()  # half-normal distributed with scale=1
        tensor([ 0.1046])

    Args:
        scale (float or Tensor): scale of the full Normal distribution
    """

    arg_constraints = {"scale": constraints.positive}
    support = constraints.nonnegative
    has_rsample = True

    def __init__(self, scale, validate_args=None):
        base_dist = Normal(0, scale, validate_args=False)
        super().__init__(base_dist, AbsTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(HalfNormal, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def scale(self) -> Tensor:
        return self.base_dist.scale

    @property
    def mean(self) -> Tensor:
        return self.scale * math.sqrt(2 / math.pi)

    @property
    def mode(self) -> Tensor:
        return torch.zeros_like(self.scale)

    @property
    def variance(self) -> Tensor:
        return self.scale.pow(2) * (1 - 2 / math.pi)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        log_prob = self.base_dist.log_prob(value) + math.log(2)
        log_prob = torch.where(value >= 0, log_prob, -inf)
        return log_prob

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return 2 * self.base_dist.cdf(value) - 1

    def icdf(self, prob):
        return self.base_dist.icdf((prob + 1) / 2)

    def entropy(self):
        return self.base_dist.entropy() - math.log(2)
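Both half-distributions above fold a zero-centered base through ``AbsTransform``:
densities pick up a factor of two, the CDF becomes ``2 * F(x) - 1``, and ``icdf``
inverts that mapping. A quick round-trip check (the probe quantiles are an
illustrative assumption)::

    import torch
    from torch.distributions import HalfNormal

    m = HalfNormal(torch.tensor([2.0]))
    q = torch.tensor([0.1, 0.5, 0.9])
    x = m.icdf(q)
    print(m.cdf(x))             # reproduces q up to floating-point error
    print(m.log_prob(x).exp())  # density at those quantiles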
129
venv/Lib/site-packages/torch/distributions/independent.py
Normal file
@@ -0,0 +1,129 @@
# mypy: allow-untyped-defs

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost
from torch.types import _size


__all__ = ["Independent"]


class Independent(Distribution):
    r"""
    Reinterprets some of the batch dims of a distribution as event dims.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`. For example to create a diagonal Normal distribution with
    the same shape as a Multivariate Normal distribution (so they are
    interchangeable), you can::

        >>> from torch.distributions.multivariate_normal import MultivariateNormal
        >>> from torch.distributions.normal import Normal
        >>> loc = torch.zeros(3)
        >>> scale = torch.ones(3)
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
        >>> [mvn.batch_shape, mvn.event_shape]
        [torch.Size([]), torch.Size([3])]
        >>> normal = Normal(loc, scale)
        >>> [normal.batch_shape, normal.event_shape]
        [torch.Size([3]), torch.Size([])]
        >>> diagn = Independent(normal, 1)
        >>> [diagn.batch_shape, diagn.event_shape]
        [torch.Size([]), torch.Size([3])]

    Args:
        base_distribution (torch.distributions.distribution.Distribution): a
            base distribution
        reinterpreted_batch_ndims (int): the number of batch dims to
            reinterpret as event dims
    """

    arg_constraints: dict[str, constraints.Constraint] = {}

    def __init__(
        self, base_distribution, reinterpreted_batch_ndims, validate_args=None
    ):
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
            raise ValueError(
                "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
            )
        shape = base_distribution.batch_shape + base_distribution.event_shape
        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
        batch_shape = shape[: len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim :]
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Independent, _instance)
        batch_shape = torch.Size(batch_shape)
        new.base_dist = self.base_dist.expand(
            batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
        )
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
        super(Independent, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def has_rsample(self) -> bool:  # type: ignore[override]
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self) -> bool:  # type: ignore[override]
        if self.reinterpreted_batch_ndims > 0:
            return False
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        result = self.base_dist.support
        if self.reinterpreted_batch_ndims:
            result = constraints.independent(result, self.reinterpreted_batch_ndims)
        return result

    @property
    def mean(self) -> Tensor:
        return self.base_dist.mean

    @property
    def mode(self) -> Tensor:
        return self.base_dist.mode

    @property
    def variance(self) -> Tensor:
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()) -> Tensor:
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        log_prob = self.base_dist.log_prob(value)
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)

    def enumerate_support(self, expand=True):
        if self.reinterpreted_batch_ndims > 0:
            raise NotImplementedError(
                "Enumeration over cartesian product is not implemented"
            )
        return self.base_dist.enumerate_support(expand=expand)

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
        )
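``Independent.log_prob`` simply sums the base distribution's log-densities over the
reinterpreted dimensions, so a diagonal Normal built this way agrees with a manual
sum. A short sketch (the shapes are chosen purely for illustration)::

    import torch
    from torch.distributions import Independent, Normal

    base = Normal(torch.zeros(4, 3), torch.ones(4, 3))  # batch_shape [4, 3]
    diag = Independent(base, 1)                         # batch [4], event [3]
    x = diag.sample()
    assert torch.allclose(diag.log_prob(x), base.log_prob(x).sum(-1))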
83
venv/Lib/site-packages/torch/distributions/inverse_gamma.py
Normal file
@@ -0,0 +1,83 @@
# mypy: allow-untyped-defs
import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.gamma import Gamma
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import PowerTransform


__all__ = ["InverseGamma"]


class InverseGamma(TransformedDistribution):
    r"""
    Creates an inverse gamma distribution parameterized by :attr:`concentration` and :attr:`rate`
    where::

        X ~ Gamma(concentration, rate)
        Y = 1 / X ~ InverseGamma(concentration, rate)

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
        >>> m.sample()
        tensor([ 1.2953])

    Args:
        concentration (float or Tensor): shape parameter of the distribution
            (often referred to as alpha)
        rate (float or Tensor): rate = 1 / scale of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.positive
    has_rsample = True

    def __init__(self, concentration, rate, validate_args=None):
        base_dist = Gamma(concentration, rate, validate_args=validate_args)
        neg_one = -base_dist.rate.new_ones(())
        super().__init__(
            base_dist, PowerTransform(neg_one), validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(InverseGamma, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def concentration(self) -> Tensor:
        return self.base_dist.concentration

    @property
    def rate(self) -> Tensor:
        return self.base_dist.rate

    @property
    def mean(self) -> Tensor:
        result = self.rate / (self.concentration - 1)
        return torch.where(self.concentration > 1, result, torch.inf)

    @property
    def mode(self) -> Tensor:
        return self.rate / (self.concentration + 1)

    @property
    def variance(self) -> Tensor:
        result = self.rate.square() / (
            (self.concentration - 1).square() * (self.concentration - 2)
        )
        return torch.where(self.concentration > 2, result, torch.inf)

    def entropy(self):
        return (
            self.concentration
            + self.rate.log()
            + self.concentration.lgamma()
            - (1 + self.concentration) * self.concentration.digamma()
        )
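``InverseGamma`` is a ``Gamma`` pushed through ``PowerTransform(-1)``, so a draw is
the reciprocal of a Gamma draw and the mean ``rate / (concentration - 1)`` only
exists for ``concentration > 1``. A small consistency sketch (the sample size is an
illustrative assumption)::

    import torch
    from torch.distributions import Gamma, InverseGamma

    conc, rate = torch.tensor([3.0]), torch.tensor([2.0])
    inv_gamma = InverseGamma(conc, rate)
    print(inv_gamma.mean)  # rate / (conc - 1) = 1.0
    # Reciprocal of Gamma samples should have roughly the same mean.
    print((1.0 / Gamma(conc, rate).sample((100_000,))).mean())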
972
venv/Lib/site-packages/torch/distributions/kl.py
Normal file
@@ -0,0 +1,972 @@
# mypy: allow-untyped-defs
import math
import warnings
from functools import total_ordering
from typing import Callable

import torch
from torch import inf, Tensor

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma


_KL_REGISTRY: dict[
    tuple[type, type], Callable
] = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE: dict[
    tuple[type, type], Callable
] = {}  # Memoized version mapping many specific (type, type) pairs to functions.

__all__ = ["register_kl", "kl_divergence"]


def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type,type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is raised. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
    """
    if not isinstance(type_p, type) and issubclass(type_p, Distribution):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not isinstance(type_q, type) and issubclass(type_q, Distribution):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator
||||
|
||||
|
||||
@total_ordering
|
||||
class _Match:
|
||||
__slots__ = ["types"]
|
||||
|
||||
def __init__(self, *types):
|
||||
self.types = types
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.types == other.types
|
||||
|
||||
def __le__(self, other):
|
||||
for x, y in zip(self.types, other.types):
|
||||
if not issubclass(x, y):
|
||||
return False
|
||||
if x is not y:
|
||||
break
|
||||
return True
|
||||
|
||||
|
||||
def _dispatch_kl(type_p, type_q):
|
||||
"""
|
||||
Find the most specific approximate match, assuming single inheritance.
|
||||
"""
|
||||
matches = [
|
||||
(super_p, super_q)
|
||||
for super_p, super_q in _KL_REGISTRY
|
||||
if issubclass(type_p, super_p) and issubclass(type_q, super_q)
|
||||
]
|
||||
if not matches:
|
||||
return NotImplemented
|
||||
# Check that the left- and right- lexicographic orders agree.
|
||||
# mypy isn't smart enough to know that _Match implements __lt__
|
||||
# see: https://github.com/python/typing/issues/760#issuecomment-710670503
|
||||
left_p, left_q = min(_Match(*m) for m in matches).types # type: ignore[type-var]
|
||||
right_q, right_p = min(_Match(*reversed(m)) for m in matches).types # type: ignore[type-var]
|
||||
left_fun = _KL_REGISTRY[left_p, left_q]
|
||||
right_fun = _KL_REGISTRY[right_p, right_q]
|
||||
if left_fun is not right_fun:
|
||||
warnings.warn(
|
||||
f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). "
|
||||
f"Please register_kl({left_p.__name__}, {right_q.__name__})",
|
||||
RuntimeWarning,
|
||||
)
|
||||
return left_fun
|
||||
|
||||
|
||||
def _infinite_like(tensor):
|
||||
"""
|
||||
Helper function for obtaining infinite KL Divergence throughout
|
||||
"""
|
||||
return torch.full_like(tensor, inf)
|
||||
|
||||
|
||||
def _x_log_x(tensor):
|
||||
"""
|
||||
Utility function for calculating x log x
|
||||
"""
|
||||
return torch.special.xlogy(tensor, tensor) # produces correct result for x=0
|
||||
|
||||
|
||||
def _batch_trace_XXT(bmat):
|
||||
"""
|
||||
Utility function for calculating the trace of XX^{T} with X having arbitrary trailing batch dimensions
|
||||
"""
|
||||
n = bmat.size(-1)
|
||||
m = bmat.size(-2)
|
||||
flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
|
||||
return flat_trace.reshape(bmat.shape[:-2])
|
||||
|
||||
|
||||
def kl_divergence(p: Distribution, q: Distribution) -> Tensor:
|
||||
r"""
|
||||
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
|
||||
|
||||
.. math::
|
||||
|
||||
KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx
|
||||
|
||||
Args:
|
||||
p (Distribution): A :class:`~torch.distributions.Distribution` object.
|
||||
q (Distribution): A :class:`~torch.distributions.Distribution` object.
|
||||
|
||||
Returns:
|
||||
Tensor: A batch of KL divergences of shape `batch_shape`.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the distribution types have not been registered via
|
||||
:meth:`register_kl`.
|
||||
"""
|
||||
try:
|
||||
fun = _KL_MEMOIZE[type(p), type(q)]
|
||||
except KeyError:
|
||||
fun = _dispatch_kl(type(p), type(q))
|
||||
_KL_MEMOIZE[type(p), type(q)] = fun
|
||||
if fun is NotImplemented:
|
||||
raise NotImplementedError(
|
||||
f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
|
||||
)
|
||||
return fun(p, q)
|
||||
|
||||
|
||||
################################################################################
|
||||
# KL Divergence Implementations
|
||||
################################################################################
|
||||
|
||||
# Same distributions
|
||||
|
||||
|
||||
@register_kl(Bernoulli, Bernoulli)
|
||||
def _kl_bernoulli_bernoulli(p, q):
|
||||
t1 = p.probs * (
|
||||
torch.nn.functional.softplus(-q.logits)
|
||||
- torch.nn.functional.softplus(-p.logits)
|
||||
)
|
||||
t1[q.probs == 0] = inf
|
||||
t1[p.probs == 0] = 0
|
||||
t2 = (1 - p.probs) * (
|
||||
torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
|
||||
)
|
||||
t2[q.probs == 1] = inf
|
||||
t2[p.probs == 1] = 0
|
||||
return t1 + t2
|
||||
|
||||
|
||||
@register_kl(Beta, Beta)
|
||||
def _kl_beta_beta(p, q):
|
||||
sum_params_p = p.concentration1 + p.concentration0
|
||||
sum_params_q = q.concentration1 + q.concentration0
|
||||
t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
|
||||
t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
|
||||
t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
|
||||
t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
|
||||
t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
|
||||
return t1 - t2 + t3 + t4 + t5
|
||||
|
||||
|
||||
@register_kl(Binomial, Binomial)
|
||||
def _kl_binomial_binomial(p, q):
|
||||
# from https://math.stackexchange.com/questions/2214993/
|
||||
# kullback-leibler-divergence-for-binomial-distributions-p-and-q
|
||||
if (p.total_count < q.total_count).any():
|
||||
raise NotImplementedError(
|
||||
"KL between Binomials where q.total_count > p.total_count is not implemented"
|
||||
)
|
||||
kl = p.total_count * (
|
||||
p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
|
||||
)
|
||||
inf_idxs = p.total_count > q.total_count
|
||||
kl[inf_idxs] = _infinite_like(kl[inf_idxs])
|
||||
return kl
|
||||
|
||||
|
||||
@register_kl(Categorical, Categorical)
|
||||
def _kl_categorical_categorical(p, q):
|
||||
t = p.probs * (p.logits - q.logits)
|
||||
t[(q.probs == 0).expand_as(t)] = inf
|
||||
t[(p.probs == 0).expand_as(t)] = 0
|
||||
return t.sum(-1)
|
||||
|
||||
|
||||
@register_kl(ContinuousBernoulli, ContinuousBernoulli)
|
||||
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
|
||||
t1 = p.mean * (p.logits - q.logits)
|
||||
t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
|
||||
t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
|
||||
return t1 + t2 + t3
|
||||
|
||||
|
||||
@register_kl(Dirichlet, Dirichlet)
|
||||
def _kl_dirichlet_dirichlet(p, q):
|
||||
# From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
|
||||
sum_p_concentration = p.concentration.sum(-1)
|
||||
sum_q_concentration = q.concentration.sum(-1)
|
||||
t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
|
||||
t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
|
||||
t3 = p.concentration - q.concentration
|
||||
t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
|
||||
return t1 - t2 + (t3 * t4).sum(-1)
|
||||
|
||||
|
||||
@register_kl(Exponential, Exponential)
|
||||
def _kl_exponential_exponential(p, q):
|
||||
rate_ratio = q.rate / p.rate
|
||||
t1 = -rate_ratio.log()
|
||||
return t1 + rate_ratio - 1
|
||||
|
||||
|
||||
@register_kl(ExponentialFamily, ExponentialFamily)
|
||||
def _kl_expfamily_expfamily(p, q):
|
||||
if not type(p) == type(q):
|
||||
raise NotImplementedError(
|
||||
"The cross KL-divergence between different exponential families cannot \
|
||||
be computed using Bregman divergences"
|
||||
)
|
||||
p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
|
||||
q_nparams = q._natural_params
|
||||
lg_normal = p._log_normalizer(*p_nparams)
|
||||
gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
|
||||
result = q._log_normalizer(*q_nparams) - lg_normal
|
||||
for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
|
||||
term = (qnp - pnp) * g
|
||||
result -= _sum_rightmost(term, len(q.event_shape))
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Gamma, Gamma)
|
||||
def _kl_gamma_gamma(p, q):
|
||||
t1 = q.concentration * (p.rate / q.rate).log()
|
||||
t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
|
||||
t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
|
||||
t4 = (q.rate - p.rate) * (p.concentration / p.rate)
|
||||
return t1 + t2 + t3 + t4
|
||||
|
||||
|
||||
@register_kl(Gumbel, Gumbel)
|
||||
def _kl_gumbel_gumbel(p, q):
|
||||
ct1 = p.scale / q.scale
|
||||
ct2 = q.loc / q.scale
|
||||
ct3 = p.loc / q.scale
|
||||
t1 = -ct1.log() - ct2 + ct3
|
||||
t2 = ct1 * _euler_gamma
|
||||
t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
|
||||
return t1 + t2 + t3 - (1 + _euler_gamma)
|
||||
|
||||
|
||||
@register_kl(Geometric, Geometric)
|
||||
def _kl_geometric_geometric(p, q):
|
||||
return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits
|
||||
|
||||
|
||||
@register_kl(HalfNormal, HalfNormal)
|
||||
def _kl_halfnormal_halfnormal(p, q):
|
||||
return _kl_normal_normal(p.base_dist, q.base_dist)
|
||||
|
||||
|
||||
@register_kl(Laplace, Laplace)
|
||||
def _kl_laplace_laplace(p, q):
|
||||
# From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
|
||||
scale_ratio = p.scale / q.scale
|
||||
loc_abs_diff = (p.loc - q.loc).abs()
|
||||
t1 = -scale_ratio.log()
|
||||
t2 = loc_abs_diff / q.scale
|
||||
t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
|
||||
return t1 + t2 + t3 - 1
|
||||
|
||||
|
||||
@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
|
||||
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
|
||||
if p.event_shape != q.event_shape:
|
||||
raise ValueError(
|
||||
"KL-divergence between two Low Rank Multivariate Normals with\
|
||||
different event shapes cannot be computed"
|
||||
)
|
||||
|
||||
term1 = _batch_lowrank_logdet(
|
||||
q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
|
||||
) - _batch_lowrank_logdet(
|
||||
p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
|
||||
)
|
||||
term3 = _batch_lowrank_mahalanobis(
|
||||
q._unbroadcasted_cov_factor,
|
||||
q._unbroadcasted_cov_diag,
|
||||
q.loc - p.loc,
|
||||
q._capacitance_tril,
|
||||
)
|
||||
# Expands term2 according to
|
||||
# inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
|
||||
# = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
|
||||
qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
|
||||
A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
|
||||
term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
|
||||
term22 = _batch_trace_XXT(
|
||||
p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
|
||||
)
|
||||
term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
|
||||
term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
|
||||
term2 = term21 + term22 - term23 - term24
|
||||
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
|
||||
|
||||
|
||||
@register_kl(MultivariateNormal, LowRankMultivariateNormal)
|
||||
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
|
||||
if p.event_shape != q.event_shape:
|
||||
raise ValueError(
|
||||
"KL-divergence between two (Low Rank) Multivariate Normals with\
|
||||
different event shapes cannot be computed"
|
||||
)
|
||||
|
||||
term1 = _batch_lowrank_logdet(
|
||||
q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
|
||||
) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
|
||||
term3 = _batch_lowrank_mahalanobis(
|
||||
q._unbroadcasted_cov_factor,
|
||||
q._unbroadcasted_cov_diag,
|
||||
q.loc - p.loc,
|
||||
q._capacitance_tril,
|
||||
)
|
||||
# Expands term2 according to
|
||||
# inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
|
||||
# = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
|
||||
qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
|
||||
A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
|
||||
term21 = _batch_trace_XXT(
|
||||
p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
|
||||
)
|
||||
term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
|
||||
term2 = term21 - term22
|
||||
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
|
||||
|
||||
|
||||
@register_kl(LowRankMultivariateNormal, MultivariateNormal)
|
||||
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
|
||||
if p.event_shape != q.event_shape:
|
||||
raise ValueError(
|
||||
"KL-divergence between two (Low Rank) Multivariate Normals with\
|
||||
different event shapes cannot be computed"
|
||||
)
|
||||
|
||||
term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
|
||||
-1
|
||||
) - _batch_lowrank_logdet(
|
||||
p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
|
||||
)
|
||||
term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
|
||||
# Expands term2 according to
|
||||
# inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
|
||||
combined_batch_shape = torch._C._infer_size(
|
||||
q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
|
||||
)
|
||||
n = p.event_shape[0]
|
||||
q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
|
||||
p_cov_factor = p._unbroadcasted_cov_factor.expand(
|
||||
combined_batch_shape + (n, p.cov_factor.size(-1))
|
||||
)
|
||||
p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
|
||||
combined_batch_shape + (n, n)
|
||||
)
|
||||
term21 = _batch_trace_XXT(
|
||||
torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
|
||||
)
|
||||
term22 = _batch_trace_XXT(
|
||||
torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
|
||||
)
|
||||
term2 = term21 + term22
|
||||
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
|
||||
|
||||
|
||||
@register_kl(MultivariateNormal, MultivariateNormal)
|
||||
def _kl_multivariatenormal_multivariatenormal(p, q):
|
||||
# From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
|
||||
if p.event_shape != q.event_shape:
|
||||
raise ValueError(
|
||||
"KL-divergence between two Multivariate Normals with\
|
||||
different event shapes cannot be computed"
|
||||
)
|
||||
|
||||
half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
|
||||
-1
|
||||
) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
|
||||
combined_batch_shape = torch._C._infer_size(
|
||||
q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
|
||||
)
|
||||
n = p.event_shape[0]
|
||||
q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
|
||||
p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
|
||||
term2 = _batch_trace_XXT(
|
||||
torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
|
||||
)
|
||||
term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
|
||||
return half_term1 + 0.5 * (term2 + term3 - n)
|
||||
|
||||
|
||||
@register_kl(Normal, Normal)
|
||||
def _kl_normal_normal(p, q):
|
||||
var_ratio = (p.scale / q.scale).pow(2)
|
||||
t1 = ((p.loc - q.loc) / q.scale).pow(2)
|
||||
return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
|
||||
|
||||
|
||||
@register_kl(OneHotCategorical, OneHotCategorical)
|
||||
def _kl_onehotcategorical_onehotcategorical(p, q):
|
||||
return _kl_categorical_categorical(p._categorical, q._categorical)
|
||||
|
||||
|
||||
@register_kl(Pareto, Pareto)
|
||||
def _kl_pareto_pareto(p, q):
|
||||
# From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
|
||||
scale_ratio = p.scale / q.scale
|
||||
alpha_ratio = q.alpha / p.alpha
|
||||
t1 = q.alpha * scale_ratio.log()
|
||||
t2 = -alpha_ratio.log()
|
||||
result = t1 + t2 + alpha_ratio - 1
|
||||
result[p.support.lower_bound < q.support.lower_bound] = inf
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Poisson, Poisson)
|
||||
def _kl_poisson_poisson(p, q):
|
||||
return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)
|
||||
|
||||
|
||||
@register_kl(TransformedDistribution, TransformedDistribution)
|
||||
def _kl_transformed_transformed(p, q):
|
||||
if p.transforms != q.transforms:
|
||||
raise NotImplementedError
|
||||
if p.event_shape != q.event_shape:
|
||||
raise NotImplementedError
|
||||
return kl_divergence(p.base_dist, q.base_dist)
|
||||
|
||||
|
||||
@register_kl(Uniform, Uniform)
|
||||
def _kl_uniform_uniform(p, q):
|
||||
result = ((q.high - q.low) / (p.high - p.low)).log()
|
||||
result[(q.low > p.low) | (q.high < p.high)] = inf
|
||||
return result
|
||||
|
||||
|
||||
# Different distributions
|
||||
@register_kl(Bernoulli, Poisson)
|
||||
def _kl_bernoulli_poisson(p, q):
|
||||
return -p.entropy() - (p.probs * q.rate.log() - q.rate)
|
||||
|
||||
|
||||
@register_kl(Beta, ContinuousBernoulli)
|
||||
def _kl_beta_continuous_bernoulli(p, q):
|
||||
return (
|
||||
-p.entropy()
|
||||
- p.mean * q.logits
|
||||
- torch.log1p(-q.probs)
|
||||
- q._cont_bern_log_norm()
|
||||
)
|
||||
|
||||
|
||||
@register_kl(Beta, Pareto)
|
||||
def _kl_beta_infinity(p, q):
|
||||
return _infinite_like(p.concentration1)
|
||||
|
||||
|
||||
@register_kl(Beta, Exponential)
|
||||
def _kl_beta_exponential(p, q):
|
||||
return (
|
||||
-p.entropy()
|
||||
- q.rate.log()
|
||||
+ q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
|
||||
)
|
||||
|
||||
|
||||
@register_kl(Beta, Gamma)
|
||||
def _kl_beta_gamma(p, q):
|
||||
t1 = -p.entropy()
|
||||
t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
|
||||
t3 = (q.concentration - 1) * (
|
||||
p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
|
||||
)
|
||||
t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
|
||||
return t1 + t2 - t3 + t4
|
||||
|
||||
|
||||
# TODO: Add Beta-Laplace KL Divergence
|
||||
|
||||
|
||||
@register_kl(Beta, Normal)
|
||||
def _kl_beta_normal(p, q):
|
||||
E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
|
||||
var_normal = q.scale.pow(2)
|
||||
t1 = -p.entropy()
|
||||
t2 = 0.5 * (var_normal * 2 * math.pi).log()
|
||||
t3 = (
|
||||
E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
|
||||
+ E_beta.pow(2)
|
||||
) * 0.5
|
||||
t4 = q.loc * E_beta
|
||||
t5 = q.loc.pow(2) * 0.5
|
||||
return t1 + t2 + (t3 - t4 + t5) / var_normal
|
||||
|
||||
|
||||
@register_kl(Beta, Uniform)
|
||||
def _kl_beta_uniform(p, q):
|
||||
result = -p.entropy() + (q.high - q.low).log()
|
||||
result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
|
||||
return result
|
||||
|
||||
|
||||
# Note that the KL between a ContinuousBernoulli and Beta has no closed form
|
||||
|
||||
|
||||
@register_kl(ContinuousBernoulli, Pareto)
|
||||
def _kl_continuous_bernoulli_infinity(p, q):
|
||||
return _infinite_like(p.probs)
|
||||
|
||||
|
||||
@register_kl(ContinuousBernoulli, Exponential)
|
||||
def _kl_continuous_bernoulli_exponential(p, q):
|
||||
return -p.entropy() - torch.log(q.rate) + q.rate * p.mean
|
||||
|
||||
|
||||
# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
|
||||
# TODO: Add ContinuousBernoulli-Laplace KL Divergence
|
||||
|
||||
|
||||
@register_kl(ContinuousBernoulli, Normal)
|
||||
def _kl_continuous_bernoulli_normal(p, q):
|
||||
t1 = -p.entropy()
|
||||
t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
|
||||
q.scale
|
||||
)
|
||||
t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
|
||||
2.0 * torch.square(q.scale)
|
||||
)
|
||||
return t1 + t2 + t3
|
||||
|
||||
|
||||
@register_kl(ContinuousBernoulli, Uniform)
|
||||
def _kl_continuous_bernoulli_uniform(p, q):
|
||||
result = -p.entropy() + (q.high - q.low).log()
|
||||
return torch.where(
|
||||
torch.max(
|
||||
torch.ge(q.low, p.support.lower_bound),
|
||||
torch.le(q.high, p.support.upper_bound),
|
||||
),
|
||||
torch.ones_like(result) * inf,
|
||||
result,
|
||||
)
|
||||
|
||||
|
||||
@register_kl(Exponential, Beta)
|
||||
@register_kl(Exponential, ContinuousBernoulli)
|
||||
@register_kl(Exponential, Pareto)
|
||||
@register_kl(Exponential, Uniform)
|
||||
def _kl_exponential_infinity(p, q):
|
||||
return _infinite_like(p.rate)
|
||||
|
||||
|
||||
@register_kl(Exponential, Gamma)
|
||||
def _kl_exponential_gamma(p, q):
|
||||
ratio = q.rate / p.rate
|
||||
t1 = -q.concentration * torch.log(ratio)
|
||||
return (
|
||||
t1
|
||||
+ ratio
|
||||
+ q.concentration.lgamma()
|
||||
+ q.concentration * _euler_gamma
|
||||
- (1 + _euler_gamma)
|
||||
)
|
||||
|
||||
|
||||
@register_kl(Exponential, Gumbel)
|
||||
def _kl_exponential_gumbel(p, q):
|
||||
scale_rate_prod = p.rate * q.scale
|
||||
loc_scale_ratio = q.loc / q.scale
|
||||
t1 = scale_rate_prod.log() - 1
|
||||
t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
|
||||
t3 = scale_rate_prod.reciprocal()
|
||||
return t1 - loc_scale_ratio + t2 + t3
|
||||
|
||||
|
||||
# TODO: Add Exponential-Laplace KL Divergence
|
||||
|
||||
|
||||
@register_kl(Exponential, Normal)
|
||||
def _kl_exponential_normal(p, q):
|
||||
var_normal = q.scale.pow(2)
|
||||
rate_sqr = p.rate.pow(2)
|
||||
t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
|
||||
t2 = rate_sqr.reciprocal()
|
||||
t3 = q.loc / p.rate
|
||||
t4 = q.loc.pow(2) * 0.5
|
||||
return t1 - 1 + (t2 - t3 + t4) / var_normal
|
||||
|
||||
|
||||
@register_kl(Gamma, Beta)
|
||||
@register_kl(Gamma, ContinuousBernoulli)
|
||||
@register_kl(Gamma, Pareto)
|
||||
@register_kl(Gamma, Uniform)
|
||||
def _kl_gamma_infinity(p, q):
|
||||
return _infinite_like(p.concentration)
|
||||
|
||||
|
||||
@register_kl(Gamma, Exponential)
|
||||
def _kl_gamma_exponential(p, q):
|
||||
return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate
|
||||
|
||||
|
||||
@register_kl(Gamma, Gumbel)
|
||||
def _kl_gamma_gumbel(p, q):
|
||||
beta_scale_prod = p.rate * q.scale
|
||||
loc_scale_ratio = q.loc / q.scale
|
||||
t1 = (
|
||||
(p.concentration - 1) * p.concentration.digamma()
|
||||
- p.concentration.lgamma()
|
||||
- p.concentration
|
||||
)
|
||||
t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
|
||||
t3 = (
|
||||
torch.exp(loc_scale_ratio)
|
||||
* (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
|
||||
- loc_scale_ratio
|
||||
)
|
||||
return t1 + t2 + t3
|
||||
|
||||
|
||||
# TODO: Add Gamma-Laplace KL Divergence
|
||||
|
||||
|
||||
@register_kl(Gamma, Normal)
|
||||
def _kl_gamma_normal(p, q):
|
||||
var_normal = q.scale.pow(2)
|
||||
beta_sqr = p.rate.pow(2)
|
||||
t1 = (
|
||||
0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
|
||||
- p.concentration
|
||||
- p.concentration.lgamma()
|
||||
)
|
||||
t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
|
||||
t3 = q.loc * p.concentration / p.rate
|
||||
t4 = 0.5 * q.loc.pow(2)
|
||||
return (
|
||||
t1
|
||||
+ (p.concentration - 1) * p.concentration.digamma()
|
||||
+ (t2 - t3 + t4) / var_normal
|
||||
)
|
||||
|
||||
|
||||
@register_kl(Gumbel, Beta)
|
||||
@register_kl(Gumbel, ContinuousBernoulli)
|
||||
@register_kl(Gumbel, Exponential)
|
||||
@register_kl(Gumbel, Gamma)
|
||||
@register_kl(Gumbel, Pareto)
|
||||
@register_kl(Gumbel, Uniform)
|
||||
def _kl_gumbel_infinity(p, q):
|
||||
return _infinite_like(p.loc)
|
||||
|
||||
|
||||
# TODO: Add Gumbel-Laplace KL Divergence
|
||||
|
||||
|
||||
@register_kl(Gumbel, Normal)
|
||||
def _kl_gumbel_normal(p, q):
|
||||
param_ratio = p.scale / q.scale
|
||||
t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
|
||||
t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
|
||||
t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
|
||||
return -t1 + t2 + t3 - (_euler_gamma + 1)
|
||||
|
||||
|
||||
@register_kl(Laplace, Beta)
|
||||
@register_kl(Laplace, ContinuousBernoulli)
|
||||
@register_kl(Laplace, Exponential)
|
||||
@register_kl(Laplace, Gamma)
|
||||
@register_kl(Laplace, Pareto)
|
||||
@register_kl(Laplace, Uniform)
|
||||
def _kl_laplace_infinity(p, q):
|
||||
return _infinite_like(p.loc)
|
||||
|
||||
|
||||
@register_kl(Laplace, Normal)
|
||||
def _kl_laplace_normal(p, q):
|
||||
var_normal = q.scale.pow(2)
|
||||
scale_sqr_var_ratio = p.scale.pow(2) / var_normal
|
||||
t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
|
||||
t2 = 0.5 * p.loc.pow(2)
|
||||
t3 = p.loc * q.loc
|
||||
t4 = 0.5 * q.loc.pow(2)
|
||||
return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1
|
||||
|
||||
|
||||
@register_kl(Normal, Beta)
|
||||
@register_kl(Normal, ContinuousBernoulli)
|
||||
@register_kl(Normal, Exponential)
|
||||
@register_kl(Normal, Gamma)
|
||||
@register_kl(Normal, Pareto)
|
||||
@register_kl(Normal, Uniform)
|
||||
def _kl_normal_infinity(p, q):
|
||||
return _infinite_like(p.loc)
|
||||
|
||||
|
||||
@register_kl(Normal, Gumbel)
|
||||
def _kl_normal_gumbel(p, q):
|
||||
mean_scale_ratio = p.loc / q.scale
|
||||
var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
|
||||
loc_scale_ratio = q.loc / q.scale
|
||||
t1 = var_scale_sqr_ratio.log() * 0.5
|
||||
t2 = mean_scale_ratio - loc_scale_ratio
|
||||
t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
|
||||
return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))
|
||||
|
||||
|
||||
@register_kl(Normal, Laplace)
|
||||
def _kl_normal_laplace(p, q):
|
||||
loc_diff = p.loc - q.loc
|
||||
scale_ratio = p.scale / q.scale
|
||||
loc_diff_scale_ratio = loc_diff / p.scale
|
||||
t1 = torch.log(scale_ratio)
|
||||
t2 = (
|
||||
math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
|
||||
)
|
||||
t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
|
||||
return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))
|
||||
|
||||
|
||||
@register_kl(Pareto, Beta)
|
||||
@register_kl(Pareto, ContinuousBernoulli)
|
||||
@register_kl(Pareto, Uniform)
|
||||
def _kl_pareto_infinity(p, q):
|
||||
return _infinite_like(p.scale)
|
||||
|
||||
|
||||
@register_kl(Pareto, Exponential)
|
||||
def _kl_pareto_exponential(p, q):
|
||||
scale_rate_prod = p.scale * q.rate
|
||||
t1 = (p.alpha / scale_rate_prod).log()
|
||||
t2 = p.alpha.reciprocal()
|
||||
t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
|
||||
result = t1 - t2 + t3 - 1
|
||||
result[p.alpha <= 1] = inf
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Pareto, Gamma)
|
||||
def _kl_pareto_gamma(p, q):
|
||||
common_term = p.scale.log() + p.alpha.reciprocal()
|
||||
t1 = p.alpha.log() - common_term
|
||||
t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
|
||||
t3 = (1 - q.concentration) * common_term
|
||||
t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
|
||||
result = t1 + t2 + t3 + t4 - 1
|
||||
result[p.alpha <= 1] = inf
|
||||
return result
|
||||
|
||||
|
||||
# TODO: Add Pareto-Laplace KL Divergence
|
||||
|
||||
|
||||
@register_kl(Pareto, Normal)
|
||||
def _kl_pareto_normal(p, q):
|
||||
var_normal = 2 * q.scale.pow(2)
|
||||
common_term = p.scale / (p.alpha - 1)
|
||||
t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
|
||||
t2 = p.alpha.reciprocal()
|
||||
t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
|
||||
t4 = (p.alpha * common_term - q.loc).pow(2)
|
||||
result = t1 - t2 + (t3 + t4) / var_normal - 1
|
||||
result[p.alpha <= 2] = inf
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Poisson, Bernoulli)
|
||||
@register_kl(Poisson, Binomial)
|
||||
def _kl_poisson_infinity(p, q):
|
||||
return _infinite_like(p.rate)
|
||||
|
||||
|
||||
@register_kl(Uniform, Beta)
|
||||
def _kl_uniform_beta(p, q):
|
||||
common_term = p.high - p.low
|
||||
t1 = torch.log(common_term)
|
||||
t2 = (
|
||||
(q.concentration1 - 1)
|
||||
* (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
|
||||
/ common_term
|
||||
)
|
||||
t3 = (
|
||||
(q.concentration0 - 1)
|
||||
* (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
|
||||
/ common_term
|
||||
)
|
||||
t4 = (
|
||||
q.concentration1.lgamma()
|
||||
+ q.concentration0.lgamma()
|
||||
- (q.concentration1 + q.concentration0).lgamma()
|
||||
)
|
||||
result = t3 + t4 - t1 - t2
|
||||
result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Uniform, ContinuousBernoulli)
|
||||
def _kl_uniform_continuous_bernoulli(p, q):
|
||||
result = (
|
||||
-p.entropy()
|
||||
- p.mean * q.logits
|
||||
- torch.log1p(-q.probs)
|
||||
- q._cont_bern_log_norm()
|
||||
)
|
||||
return torch.where(
|
||||
torch.max(
|
||||
torch.ge(p.high, q.support.upper_bound),
|
||||
torch.le(p.low, q.support.lower_bound),
|
||||
),
|
||||
torch.ones_like(result) * inf,
|
||||
result,
|
||||
)
|
||||
|
||||
|
||||
@register_kl(Uniform, Exponential)
|
||||
def _kl_uniform_exponetial(p, q):
|
||||
result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
|
||||
result[p.low < q.support.lower_bound] = inf
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Uniform, Gamma)
|
||||
def _kl_uniform_gamma(p, q):
|
||||
common_term = p.high - p.low
|
||||
t1 = common_term.log()
|
||||
t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
|
||||
t3 = (
|
||||
(1 - q.concentration)
|
||||
* (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
|
||||
/ common_term
|
||||
)
|
||||
t4 = q.rate * (p.high + p.low) / 2
|
||||
result = -t1 + t2 + t3 + t4
|
||||
result[p.low < q.support.lower_bound] = inf
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Uniform, Gumbel)
|
||||
def _kl_uniform_gumbel(p, q):
|
||||
common_term = q.scale / (p.high - p.low)
|
||||
high_loc_diff = (p.high - q.loc) / q.scale
|
||||
low_loc_diff = (p.low - q.loc) / q.scale
|
||||
t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
|
||||
t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
|
||||
return t1 - t2
|
||||
|
||||
|
||||
# TODO: Uniform-Laplace KL Divergence
|
||||
|
||||
|
||||
@register_kl(Uniform, Normal)
|
||||
def _kl_uniform_normal(p, q):
|
||||
common_term = p.high - p.low
|
||||
t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
|
||||
t2 = (common_term).pow(2) / 12
|
||||
t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
|
||||
return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)
|
||||
|
||||
|
||||
@register_kl(Uniform, Pareto)
|
||||
def _kl_uniform_pareto(p, q):
|
||||
support_uniform = p.high - p.low
|
||||
t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
|
||||
t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
|
||||
result = t2 * (q.alpha + 1) - t1
|
||||
result[p.low < q.support.lower_bound] = inf
|
||||
return result
|
||||
|
||||
|
||||
@register_kl(Independent, Independent)
|
||||
def _kl_independent_independent(p, q):
|
||||
if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
|
||||
raise NotImplementedError
|
||||
result = kl_divergence(p.base_dist, q.base_dist)
|
||||
return _sum_rightmost(result, p.reinterpreted_batch_ndims)
|
||||
|
||||
|
||||
@register_kl(Cauchy, Cauchy)
|
||||
def _kl_cauchy_cauchy(p, q):
|
||||
# From https://arxiv.org/abs/1905.10965
|
||||
t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
|
||||
t2 = (4 * p.scale * q.scale).log()
|
||||
return t1 - t2
|
||||
|
||||
|
||||
def _add_kl_info():
|
||||
"""Appends a list of implemented KL functions to the doc for kl_divergence."""
|
||||
rows = [
|
||||
"KL divergence is currently implemented for the following distribution pairs:"
|
||||
]
|
||||
for p, q in sorted(
|
||||
_KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
|
||||
):
|
||||
rows.append(
|
||||
f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
|
||||
)
|
||||
kl_info = "\n\t".join(rows)
|
||||
if kl_divergence.__doc__:
|
||||
kl_divergence.__doc__ += kl_info
|
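``kl_divergence`` dispatches on the pair of distribution types registered via
``register_kl``; the Normal/Normal entry, for instance, can be cross-checked against
a Monte Carlo estimate built from ``log_prob``. A brief sketch (sample size and
parameters are illustrative assumptions)::

    import torch
    from torch.distributions import Normal, kl_divergence

    p = Normal(torch.tensor(0.0), torch.tensor(1.0))
    q = Normal(torch.tensor(1.0), torch.tensor(2.0))
    analytic = kl_divergence(p, q)
    x = p.sample((200_000,))
    monte_carlo = (p.log_prob(x) - q.log_prob(x)).mean()
    print(analytic.item(), monte_carlo.item())  # should agree to about 2 decimals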
99
venv/Lib/site-packages/torch/distributions/kumaraswamy.py
Normal file
@@ -0,0 +1,99 @@
# mypy: allow-untyped-defs
import torch
from torch import nan, Tensor
from torch.distributions import constraints
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, PowerTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all, euler_constant


__all__ = ["Kumaraswamy"]


def _moments(a, b, n):
    """
    Computes nth moment of Kumaraswamy using torch.lgamma
    """
    arg1 = 1 + n / a
    log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b)
    return b * torch.exp(log_value)


class Kumaraswamy(TransformedDistribution):
    r"""
    Samples from a Kumaraswamy distribution.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
        tensor([ 0.1729])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.concentration1, self.concentration0 = broadcast_all(
            concentration1, concentration0
        )
        base_dist = Uniform(
            torch.full_like(self.concentration0, 0),
            torch.full_like(self.concentration0, 1),
            validate_args=validate_args,
        )
        transforms = [
            PowerTransform(exponent=self.concentration0.reciprocal()),
            AffineTransform(loc=1.0, scale=-1.0),
            PowerTransform(exponent=self.concentration1.reciprocal()),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Kumaraswamy, _instance)
        new.concentration1 = self.concentration1.expand(batch_shape)
        new.concentration0 = self.concentration0.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    @property
    def mean(self) -> Tensor:
        return _moments(self.concentration1, self.concentration0, 1)

    @property
    def mode(self) -> Tensor:
        # Evaluate in log-space for numerical stability.
        log_mode = (
            self.concentration0.reciprocal() * (-self.concentration0).log1p()
            - (-self.concentration0 * self.concentration1).log1p()
        )
        log_mode[(self.concentration0 < 1) | (self.concentration1 < 1)] = nan
        return log_mode.exp()

    @property
    def variance(self) -> Tensor:
        return _moments(self.concentration1, self.concentration0, 2) - torch.pow(
            self.mean, 2
        )

    def entropy(self):
        t1 = 1 - self.concentration1.reciprocal()
        t0 = 1 - self.concentration0.reciprocal()
        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
        return (
            t0
            + t1 * H0
            - torch.log(self.concentration1)
            - torch.log(self.concentration0)
        )
97
venv/Lib/site-packages/torch/distributions/laplace.py
Normal file
97
venv/Lib/site-packages/torch/distributions/laplace.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["Laplace"]
|
||||
|
||||
|
||||
class Laplace(Distribution):
|
||||
r"""
|
||||
Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
|
||||
>>> m.sample() # Laplace distributed with loc=0, scale=1
|
||||
tensor([ 0.1046])
|
||||
|
||||
Args:
|
||||
loc (float or Tensor): mean of the distribution
|
||||
scale (float or Tensor): scale of the distribution
|
||||
"""
|
||||
|
||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||
support = constraints.real
|
||||
has_rsample = True
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return 2 * self.scale.pow(2)
|
||||
|
||||
@property
|
||||
def stddev(self) -> Tensor:
|
||||
return (2**0.5) * self.scale
|
||||
|
||||
def __init__(self, loc, scale, validate_args=None):
|
||||
self.loc, self.scale = broadcast_all(loc, scale)
|
||||
if isinstance(loc, _Number) and isinstance(scale, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self.loc.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Laplace, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.loc = self.loc.expand(batch_shape)
|
||||
new.scale = self.scale.expand(batch_shape)
|
||||
super(Laplace, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
finfo = torch.finfo(self.loc.dtype)
|
||||
if torch._C._get_tracing_state():
|
||||
# [JIT WORKAROUND] lack of support for .uniform_()
|
||||
u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1
|
||||
return self.loc - self.scale * u.sign() * torch.log1p(
|
||||
-u.abs().clamp(min=finfo.tiny)
|
||||
)
|
||||
u = self.loc.new(shape).uniform_(finfo.eps - 1, 1)
|
||||
# TODO: If we ever implement tensor.nextafter, below is what we want ideally.
|
||||
# u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5)
|
||||
return self.loc - self.scale * u.sign() * torch.log1p(-u.abs())
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale
|
||||
|
||||
def cdf(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1(
|
||||
-(value - self.loc).abs() / self.scale
|
||||
)
|
||||
|
||||
def icdf(self, value):
|
||||
term = value - 0.5
|
||||
return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs())
|
||||
|
||||
def entropy(self):
|
||||
return 1 + torch.log(2 * self.scale)
|
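A minimal usage sketch for the `Laplace` class added above (not part of the committed file; shapes and values are illustrative): because `has_rsample` is True, gradients can flow through samples back into the parameters.

import torch
from torch.distributions import Laplace

loc = torch.zeros(3)
scale = torch.ones(3, requires_grad=True)
d = Laplace(loc, scale)

x = d.rsample()              # pathwise sample, differentiable w.r.t. scale
loss = -d.log_prob(x).sum()  # negative log-likelihood of the drawn sample
loss.backward()              # gradients reach `scale` through rsample and log_prob
print(scale.grad)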
145
venv/Lib/site-packages/torch/distributions/lkj_cholesky.py
Normal file
@@ -0,0 +1,145 @@
|
|||
# mypy: allow-untyped-defs
|
||||
"""
|
||||
This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).
|
||||
|
||||
Original copyright notice:
|
||||
|
||||
# Copyright: Contributors to the Pyro project.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.distributions import Beta, constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import broadcast_all
|
||||
|
||||
|
||||
__all__ = ["LKJCholesky"]
|
||||
|
||||
|
||||
class LKJCholesky(Distribution):
|
||||
r"""
|
||||
LKJ distribution for lower Cholesky factor of correlation matrices.
|
||||
The distribution is controlled by ``concentration`` parameter :math:`\eta`
|
||||
to make the probability of the correlation matrix :math:`M` generated from
|
||||
a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that,
|
||||
when ``concentration == 1``, we have a uniform distribution over Cholesky
|
||||
factors of correlation matrices::
|
||||
|
||||
L ~ LKJCholesky(dim, concentration)
|
||||
X = L @ L' ~ LKJCorr(dim, concentration)
|
||||
|
||||
Note that this distribution samples the
|
||||
Cholesky factor of correlation matrices and not the correlation matrices
|
||||
themselves and thereby differs slightly from the derivations in [1] for
|
||||
the `LKJCorr` distribution. For sampling, this uses the Onion method from
|
||||
[1] Section 3.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> l = LKJCholesky(3, 0.5)
|
||||
>>> l.sample() # l @ l.T is a sample of a correlation 3x3 matrix
|
||||
tensor([[ 1.0000, 0.0000, 0.0000],
|
||||
[ 0.3516, 0.9361, 0.0000],
|
||||
[-0.1899, 0.4748, 0.8593]])
|
||||
|
||||
Args:
|
||||
dim (int): dimension of the matrices
|
||||
concentration (float or Tensor): concentration/shape parameter of the
|
||||
distribution (often referred to as eta)
|
||||
|
||||
**References**
|
||||
|
||||
[1] `Generating random correlation matrices based on vines and extended onion method` (2009),
|
||||
Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
|
||||
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
|
||||
"""
|
||||
|
||||
arg_constraints = {"concentration": constraints.positive}
|
||||
support = constraints.corr_cholesky
|
||||
|
||||
def __init__(self, dim, concentration=1.0, validate_args=None):
|
||||
if dim < 2:
|
||||
raise ValueError(
|
||||
f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."
|
||||
)
|
||||
self.dim = dim
|
||||
(self.concentration,) = broadcast_all(concentration)
|
||||
batch_shape = self.concentration.size()
|
||||
event_shape = torch.Size((dim, dim))
|
||||
# This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
|
||||
marginal_conc = self.concentration + 0.5 * (self.dim - 2)
|
||||
offset = torch.arange(
|
||||
self.dim - 1,
|
||||
dtype=self.concentration.dtype,
|
||||
device=self.concentration.device,
|
||||
)
|
||||
offset = torch.cat([offset.new_zeros((1,)), offset])
|
||||
beta_conc1 = offset + 0.5
|
||||
beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
|
||||
self._beta = Beta(beta_conc1, beta_conc0)
|
||||
super().__init__(batch_shape, event_shape, validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(LKJCholesky, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.dim = self.dim
|
||||
new.concentration = self.concentration.expand(batch_shape)
|
||||
new._beta = self._beta.expand(batch_shape + (self.dim,))
|
||||
super(LKJCholesky, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
# This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
|
||||
# - This vectorizes the for loop and also works for heterogeneous eta.
|
||||
# - Same algorithm generalizes to n=1.
|
||||
# - The procedure is simplified since we are sampling the cholesky factor of
|
||||
# the correlation matrix instead of the correlation matrix itself. As such,
|
||||
# we only need to generate `w`.
|
||||
y = self._beta.sample(sample_shape).unsqueeze(-1)
|
||||
u_normal = torch.randn(
|
||||
self._extended_shape(sample_shape), dtype=y.dtype, device=y.device
|
||||
).tril(-1)
|
||||
u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
|
||||
# Replace NaNs in first row
|
||||
u_hypersphere[..., 0, :].fill_(0.0)
|
||||
w = torch.sqrt(y) * u_hypersphere
|
||||
# Fill diagonal elements; clamp for numerical stability
|
||||
eps = torch.finfo(w.dtype).tiny
|
||||
diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
|
||||
w += torch.diag_embed(diag_elems)
|
||||
return w
|
||||
|
||||
def log_prob(self, value):
|
||||
# See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
|
||||
# The probability of a correlation matrix is proportional to
|
||||
# determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
|
||||
# Additionally, the Jacobian of the transformation from Cholesky factor to
|
||||
# correlation matrix is:
|
||||
# prod(L_ii ^ (D - i))
|
||||
# So the probability of a Cholesky factor is proportional to
|
||||
# prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
|
||||
# with order_i = 2 * concentration - 2 + D - i
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
|
||||
order = torch.arange(2, self.dim + 1, device=self.concentration.device)
|
||||
order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
|
||||
unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
|
||||
# Compute normalization constant (page 1999 of [1])
|
||||
dm1 = self.dim - 1
|
||||
alpha = self.concentration + 0.5 * dm1
|
||||
denominator = torch.lgamma(alpha) * dm1
|
||||
numerator = torch.mvlgamma(alpha - 0.5, dm1)
|
||||
# pi_constant in [1] is D * (D - 1) / 4 * log(pi)
|
||||
# pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
|
||||
# hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
|
||||
pi_constant = 0.5 * dm1 * math.log(math.pi)
|
||||
normalize_term = pi_constant + numerator - denominator
|
||||
return unnormalized_log_pdf - normalize_term
|
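A hedged usage sketch for `LKJCholesky` (illustrative only, not part of the file above): a sampled factor `L` is lower triangular, and `L @ L.T` is a correlation matrix with unit diagonal.

import torch
from torch.distributions import LKJCholesky

d = LKJCholesky(dim=4, concentration=2.0)
L = d.sample()                 # (4, 4) lower-triangular Cholesky factor
corr = L @ L.mT                # corresponding correlation matrix
print(torch.allclose(corr.diagonal(), torch.ones(4), atol=1e-5))  # unit diagonal
print(d.log_prob(L))           # log density of the sampled factor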
66
venv/Lib/site-packages/torch/distributions/log_normal.py
Normal file
@@ -0,0 +1,66 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.normal import Normal
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import ExpTransform
|
||||
|
||||
|
||||
__all__ = ["LogNormal"]
|
||||
|
||||
|
||||
class LogNormal(TransformedDistribution):
|
||||
r"""
|
||||
Creates a log-normal distribution parameterized by
|
||||
:attr:`loc` and :attr:`scale` where::
|
||||
|
||||
X ~ Normal(loc, scale)
|
||||
Y = exp(X) ~ LogNormal(loc, scale)
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
|
||||
>>> m.sample() # log-normal distributed with mean=0 and stddev=1
|
||||
tensor([ 0.1046])
|
||||
|
||||
Args:
|
||||
loc (float or Tensor): mean of log of distribution
|
||||
scale (float or Tensor): standard deviation of log of the distribution
|
||||
"""
|
||||
|
||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||
support = constraints.positive
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, loc, scale, validate_args=None):
|
||||
base_dist = Normal(loc, scale, validate_args=validate_args)
|
||||
super().__init__(base_dist, ExpTransform(), validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(LogNormal, _instance)
|
||||
return super().expand(batch_shape, _instance=new)
|
||||
|
||||
@property
|
||||
def loc(self) -> Tensor:
|
||||
return self.base_dist.loc
|
||||
|
||||
@property
|
||||
def scale(self) -> Tensor:
|
||||
return self.base_dist.scale
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return (self.loc + self.scale.pow(2) / 2).exp()
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return (self.loc - self.scale.square()).exp()
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
scale_sq = self.scale.pow(2)
|
||||
return scale_sq.expm1() * (2 * self.loc + scale_sq).exp()
|
||||
|
||||
def entropy(self):
|
||||
return self.base_dist.entropy() + self.loc
|
|
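An illustrative check (not part of the committed file) that the closed-form `mean` above matches a Monte Carlo estimate:

import torch
from torch.distributions import LogNormal

d = LogNormal(torch.tensor(0.0), torch.tensor(0.5))
samples = d.sample((100_000,))
print(samples.mean().item(), d.mean.item())   # empirical mean vs. exp(loc + scale**2 / 2)
print(d.log_prob(torch.tensor(1.0)))          # log density at y = 1.0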
@@ -0,0 +1,58 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.normal import Normal
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import StickBreakingTransform
|
||||
|
||||
|
||||
__all__ = ["LogisticNormal"]
|
||||
|
||||
|
||||
class LogisticNormal(TransformedDistribution):
|
||||
r"""
|
||||
Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
|
||||
that define the base `Normal` distribution transformed with the
|
||||
`StickBreakingTransform` such that::
|
||||
|
||||
X ~ LogisticNormal(loc, scale)
|
||||
Y = log(X / (1 - X.cumsum(-1)))[..., :-1] ~ Normal(loc, scale)
|
||||
|
||||
Args:
|
||||
loc (float or Tensor): mean of the base distribution
|
||||
scale (float or Tensor): standard deviation of the base distribution
|
||||
|
||||
Example::
|
||||
|
||||
>>> # logistic-normal distributed with mean=(0, 0, 0) and stddev=(1, 1, 1)
|
||||
>>> # of the base Normal distribution
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = LogisticNormal(torch.tensor([0.0] * 3), torch.tensor([1.0] * 3))
|
||||
>>> m.sample()
|
||||
tensor([ 0.7653, 0.0341, 0.0579, 0.1427])
|
||||
|
||||
"""
|
||||
|
||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||
support = constraints.simplex
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, loc, scale, validate_args=None):
|
||||
base_dist = Normal(loc, scale, validate_args=validate_args)
|
||||
if not base_dist.batch_shape:
|
||||
base_dist = base_dist.expand([1])
|
||||
super().__init__(
|
||||
base_dist, StickBreakingTransform(), validate_args=validate_args
|
||||
)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(LogisticNormal, _instance)
|
||||
return super().expand(batch_shape, _instance=new)
|
||||
|
||||
@property
|
||||
def loc(self) -> Tensor:
|
||||
return self.base_dist.base_dist.loc
|
||||
|
||||
@property
|
||||
def scale(self) -> Tensor:
|
||||
return self.base_dist.base_dist.scale
|
|
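A small sanity sketch for `LogisticNormal` (illustrative, not part of the file above): samples live on the probability simplex and have one more entry than the base Normal has dimensions.

import torch
from torch.distributions import LogisticNormal

d = LogisticNormal(torch.zeros(3), torch.ones(3))
x = d.sample()
print(x)           # 4 positive entries
print(x.sum())     # sums to 1 (up to floating-point error)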
@@ -0,0 +1,244 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
|
||||
from torch.distributions.utils import _standard_normal, lazy_property
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["LowRankMultivariateNormal"]
|
||||
|
||||
|
||||
def _batch_capacitance_tril(W, D):
|
||||
r"""
|
||||
Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
|
||||
and a batch of vectors :math:`D`.
|
||||
"""
|
||||
m = W.size(-1)
|
||||
Wt_Dinv = W.mT / D.unsqueeze(-2)
|
||||
K = torch.matmul(Wt_Dinv, W).contiguous()
|
||||
K.view(-1, m * m)[:, :: m + 1] += 1 # add identity matrix to K
|
||||
return torch.linalg.cholesky(K)
|
||||
|
||||
|
||||
def _batch_lowrank_logdet(W, D, capacitance_tril):
|
||||
r"""
|
||||
Uses "matrix determinant lemma"::
|
||||
log|W @ W.T + D| = log|C| + log|D|,
|
||||
where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
|
||||
the log determinant.
|
||||
"""
|
||||
return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
|
||||
-1
|
||||
)
|
||||
|
||||
|
||||
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
|
||||
r"""
|
||||
Uses "Woodbury matrix identity"::
|
||||
inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
|
||||
where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
|
||||
Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
|
||||
"""
|
||||
Wt_Dinv = W.mT / D.unsqueeze(-2)
|
||||
Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
|
||||
mahalanobis_term1 = (x.pow(2) / D).sum(-1)
|
||||
mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
|
||||
return mahalanobis_term1 - mahalanobis_term2
|
||||
|
||||
|
||||
class LowRankMultivariateNormal(Distribution):
|
||||
r"""
|
||||
Creates a multivariate normal distribution with covariance matrix having a low-rank form
|
||||
parameterized by :attr:`cov_factor` and :attr:`cov_diag`::
|
||||
|
||||
covariance_matrix = cov_factor @ cov_factor.T + cov_diag
|
||||
|
||||
Example:
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = LowRankMultivariateNormal(
|
||||
... torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2)
|
||||
... )
|
||||
>>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
|
||||
tensor([-0.2102, -0.5429])
|
||||
|
||||
Args:
|
||||
loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
|
||||
cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
|
||||
`batch_shape + event_shape + (rank,)`
|
||||
cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
|
||||
`batch_shape + event_shape`
|
||||
|
||||
Note:
|
||||
The computation of the determinant and inverse of the covariance matrix is avoided when
|
||||
`cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
|
||||
<https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
|
||||
`matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
|
||||
Thanks to these formulas, we just need to compute the determinant and inverse of
|
||||
the small size "capacitance" matrix::
|
||||
|
||||
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"loc": constraints.real_vector,
|
||||
"cov_factor": constraints.independent(constraints.real, 2),
|
||||
"cov_diag": constraints.independent(constraints.positive, 1),
|
||||
}
|
||||
support = constraints.real_vector
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
|
||||
if loc.dim() < 1:
|
||||
raise ValueError("loc must be at least one-dimensional.")
|
||||
event_shape = loc.shape[-1:]
|
||||
if cov_factor.dim() < 2:
|
||||
raise ValueError(
|
||||
"cov_factor must be at least two-dimensional, "
|
||||
"with optional leading batch dimensions"
|
||||
)
|
||||
if cov_factor.shape[-2:-1] != event_shape:
|
||||
raise ValueError(
|
||||
f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
|
||||
)
|
||||
if cov_diag.shape[-1:] != event_shape:
|
||||
raise ValueError(
|
||||
f"cov_diag must be a batch of vectors with shape {event_shape}"
|
||||
)
|
||||
|
||||
loc_ = loc.unsqueeze(-1)
|
||||
cov_diag_ = cov_diag.unsqueeze(-1)
|
||||
try:
|
||||
loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
|
||||
loc_, cov_factor, cov_diag_
|
||||
)
|
||||
except RuntimeError as e:
|
||||
raise ValueError(
|
||||
f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
|
||||
) from e
|
||||
self.loc = loc_[..., 0]
|
||||
self.cov_diag = cov_diag_[..., 0]
|
||||
batch_shape = self.loc.shape[:-1]
|
||||
|
||||
self._unbroadcasted_cov_factor = cov_factor
|
||||
self._unbroadcasted_cov_diag = cov_diag
|
||||
self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
|
||||
super().__init__(batch_shape, event_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
loc_shape = batch_shape + self.event_shape
|
||||
new.loc = self.loc.expand(loc_shape)
|
||||
new.cov_diag = self.cov_diag.expand(loc_shape)
|
||||
new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
|
||||
new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
|
||||
new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
|
||||
new._capacitance_tril = self._capacitance_tril
|
||||
super(LowRankMultivariateNormal, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@lazy_property
|
||||
def variance(self) -> Tensor: # type: ignore[override]
|
||||
return (
|
||||
self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
|
||||
).expand(self._batch_shape + self._event_shape)
|
||||
|
||||
@lazy_property
|
||||
def scale_tril(self) -> Tensor:
|
||||
# The following identity is used to increase the numerical stability of the computation
|
||||
# for Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
|
||||
# W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
|
||||
# The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
|
||||
# hence it is well-conditioned and safe to take Cholesky decomposition.
|
||||
n = self._event_shape[0]
|
||||
cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
|
||||
Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
|
||||
K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
|
||||
K.view(-1, n * n)[:, :: n + 1] += 1 # add identity matrix to K
|
||||
scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
|
||||
return scale_tril.expand(
|
||||
self._batch_shape + self._event_shape + self._event_shape
|
||||
)
|
||||
|
||||
@lazy_property
|
||||
def covariance_matrix(self) -> Tensor:
|
||||
covariance_matrix = torch.matmul(
|
||||
self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
|
||||
) + torch.diag_embed(self._unbroadcasted_cov_diag)
|
||||
return covariance_matrix.expand(
|
||||
self._batch_shape + self._event_shape + self._event_shape
|
||||
)
|
||||
|
||||
@lazy_property
|
||||
def precision_matrix(self) -> Tensor:
|
||||
# We use "Woodbury matrix identity" to take advantage of low rank form::
|
||||
# inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
|
||||
# where :math:`C` is the capacitance matrix.
|
||||
Wt_Dinv = (
|
||||
self._unbroadcasted_cov_factor.mT
|
||||
/ self._unbroadcasted_cov_diag.unsqueeze(-2)
|
||||
)
|
||||
A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
|
||||
precision_matrix = (
|
||||
torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
|
||||
)
|
||||
return precision_matrix.expand(
|
||||
self._batch_shape + self._event_shape + self._event_shape
|
||||
)
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
W_shape = shape[:-1] + self.cov_factor.shape[-1:]
|
||||
eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
|
||||
eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
|
||||
return (
|
||||
self.loc
|
||||
+ _batch_mv(self._unbroadcasted_cov_factor, eps_W)
|
||||
+ self._unbroadcasted_cov_diag.sqrt() * eps_D
|
||||
)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
diff = value - self.loc
|
||||
M = _batch_lowrank_mahalanobis(
|
||||
self._unbroadcasted_cov_factor,
|
||||
self._unbroadcasted_cov_diag,
|
||||
diff,
|
||||
self._capacitance_tril,
|
||||
)
|
||||
log_det = _batch_lowrank_logdet(
|
||||
self._unbroadcasted_cov_factor,
|
||||
self._unbroadcasted_cov_diag,
|
||||
self._capacitance_tril,
|
||||
)
|
||||
return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)
|
||||
|
||||
def entropy(self):
|
||||
log_det = _batch_lowrank_logdet(
|
||||
self._unbroadcasted_cov_factor,
|
||||
self._unbroadcasted_cov_diag,
|
||||
self._capacitance_tril,
|
||||
)
|
||||
H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
|
||||
if len(self._batch_shape) == 0:
|
||||
return H
|
||||
else:
|
||||
return H.expand(self._batch_shape)
|
|
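A hedged usage sketch (not part of the committed file) comparing the low-rank parameterization against a dense `MultivariateNormal` built from the same covariance; the log densities should agree up to numerical error.

import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal

loc = torch.zeros(5)
cov_factor = torch.randn(5, 2)        # rank-2 factor W
cov_diag = torch.rand(5) + 0.1        # positive diagonal part D
low_rank = LowRankMultivariateNormal(loc, cov_factor, cov_diag)

full = MultivariateNormal(loc, covariance_matrix=low_rank.covariance_matrix)
x = low_rank.rsample()
print(torch.allclose(low_rank.log_prob(x), full.log_prob(x), atol=1e-5))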
@@ -0,0 +1,220 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import Categorical, constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
|
||||
|
||||
__all__ = ["MixtureSameFamily"]
|
||||
|
||||
|
||||
class MixtureSameFamily(Distribution):
|
||||
r"""
|
||||
The `MixtureSameFamily` distribution implements a (batch of) mixture
|
||||
distribution where all components are from different parameterizations of
|
||||
the same distribution type. It is parameterized by a `Categorical`
|
||||
"selecting distribution" (over `k` component) and a component
|
||||
distribution, i.e., a `Distribution` with a rightmost batch shape
|
||||
(equal to `[k]`) which indexes each (batch of) component.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # xdoctest: +SKIP("undefined vars")
|
||||
>>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
|
||||
>>> # weighted normal distributions
|
||||
>>> mix = D.Categorical(torch.ones(5,))
|
||||
>>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
|
||||
>>> gmm = MixtureSameFamily(mix, comp)
|
||||
|
||||
>>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
|
||||
>>> # weighted bivariate normal distributions
|
||||
>>> mix = D.Categorical(torch.ones(5,))
|
||||
>>> comp = D.Independent(D.Normal(
|
||||
... torch.randn(5,2), torch.rand(5,2)), 1)
|
||||
>>> gmm = MixtureSameFamily(mix, comp)
|
||||
|
||||
>>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
|
||||
>>> # consisting of 5 random weighted bivariate normal distributions
|
||||
>>> mix = D.Categorical(torch.rand(3,5))
|
||||
>>> comp = D.Independent(D.Normal(
|
||||
... torch.randn(3,5,2), torch.rand(3,5,2)), 1)
|
||||
>>> gmm = MixtureSameFamily(mix, comp)
|
||||
|
||||
Args:
|
||||
mixture_distribution: `torch.distributions.Categorical`-like
|
||||
instance. Manages the probability of selecting a component.
|
||||
The number of categories must match the rightmost batch
|
||||
dimension of the `component_distribution`. Must have either
|
||||
scalar `batch_shape` or `batch_shape` matching
|
||||
`component_distribution.batch_shape[:-1]`
|
||||
component_distribution: `torch.distributions.Distribution`-like
|
||||
instance. Right-most batch dimension indexes components.
|
||||
"""
|
||||
|
||||
arg_constraints: dict[str, constraints.Constraint] = {}
|
||||
has_rsample = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mixture_distribution: Categorical,
|
||||
component_distribution: Distribution,
|
||||
validate_args=None,
|
||||
) -> None:
|
||||
self._mixture_distribution = mixture_distribution
|
||||
self._component_distribution = component_distribution
|
||||
|
||||
if not isinstance(self._mixture_distribution, Categorical):
|
||||
raise ValueError(
|
||||
" The Mixture distribution needs to be an "
|
||||
" instance of torch.distributions.Categorical"
|
||||
)
|
||||
|
||||
if not isinstance(self._component_distribution, Distribution):
|
||||
raise ValueError(
|
||||
"The Component distribution need to be an "
|
||||
"instance of torch.distributions.Distribution"
|
||||
)
|
||||
|
||||
# Check that batch size matches
|
||||
mdbs = self._mixture_distribution.batch_shape
|
||||
cdbs = self._component_distribution.batch_shape[:-1]
|
||||
for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
|
||||
if size1 != 1 and size2 != 1 and size1 != size2:
|
||||
raise ValueError(
|
||||
f"`mixture_distribution.batch_shape` ({mdbs}) is not "
|
||||
"compatible with `component_distribution."
|
||||
f"batch_shape`({cdbs})"
|
||||
)
|
||||
|
||||
# Check that the number of mixture components matches
|
||||
km = self._mixture_distribution.logits.shape[-1]
|
||||
kc = self._component_distribution.batch_shape[-1]
|
||||
if km is not None and kc is not None and km != kc:
|
||||
raise ValueError(
|
||||
f"`mixture_distribution component` ({km}) does not"
|
||||
" equal `component_distribution.batch_shape[-1]`"
|
||||
f" ({kc})"
|
||||
)
|
||||
self._num_component = km
|
||||
|
||||
event_shape = self._component_distribution.event_shape
|
||||
self._event_ndims = len(event_shape)
|
||||
super().__init__(
|
||||
batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
|
||||
)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
batch_shape_comp = batch_shape + (self._num_component,)
|
||||
new = self._get_checked_instance(MixtureSameFamily, _instance)
|
||||
new._component_distribution = self._component_distribution.expand(
|
||||
batch_shape_comp
|
||||
)
|
||||
new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
|
||||
new._num_component = self._num_component
|
||||
new._event_ndims = self._event_ndims
|
||||
event_shape = new._component_distribution.event_shape
|
||||
super(MixtureSameFamily, new).__init__(
|
||||
batch_shape=batch_shape, event_shape=event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@constraints.dependent_property
|
||||
def support(self):
|
||||
# FIXME this may have the wrong shape when support contains batched
|
||||
# parameters
|
||||
return self._component_distribution.support
|
||||
|
||||
@property
|
||||
def mixture_distribution(self) -> Categorical:
|
||||
return self._mixture_distribution
|
||||
|
||||
@property
|
||||
def component_distribution(self) -> Distribution:
|
||||
return self._component_distribution
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
|
||||
return torch.sum(
|
||||
probs * self.component_distribution.mean, dim=-1 - self._event_ndims
|
||||
) # [B, E]
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
# Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
|
||||
probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
|
||||
mean_cond_var = torch.sum(
|
||||
probs * self.component_distribution.variance, dim=-1 - self._event_ndims
|
||||
)
|
||||
var_cond_mean = torch.sum(
|
||||
probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
|
||||
dim=-1 - self._event_ndims,
|
||||
)
|
||||
return mean_cond_var + var_cond_mean
|
||||
|
||||
def cdf(self, x):
|
||||
x = self._pad(x)
|
||||
cdf_x = self.component_distribution.cdf(x)
|
||||
mix_prob = self.mixture_distribution.probs
|
||||
|
||||
return torch.sum(cdf_x * mix_prob, dim=-1)
|
||||
|
||||
def log_prob(self, x):
|
||||
if self._validate_args:
|
||||
self._validate_sample(x)
|
||||
x = self._pad(x)
|
||||
log_prob_x = self.component_distribution.log_prob(x) # [S, B, k]
|
||||
log_mix_prob = torch.log_softmax(
|
||||
self.mixture_distribution.logits, dim=-1
|
||||
) # [B, k]
|
||||
return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B]
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
with torch.no_grad():
|
||||
sample_len = len(sample_shape)
|
||||
batch_len = len(self.batch_shape)
|
||||
gather_dim = sample_len + batch_len
|
||||
es = self.event_shape
|
||||
|
||||
# mixture samples [n, B]
|
||||
mix_sample = self.mixture_distribution.sample(sample_shape)
|
||||
mix_shape = mix_sample.shape
|
||||
|
||||
# component samples [n, B, k, E]
|
||||
comp_samples = self.component_distribution.sample(sample_shape)
|
||||
|
||||
# Gather along the k dimension
|
||||
mix_sample_r = mix_sample.reshape(
|
||||
mix_shape + torch.Size([1] * (len(es) + 1))
|
||||
)
|
||||
mix_sample_r = mix_sample_r.repeat(
|
||||
torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
|
||||
)
|
||||
|
||||
samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
|
||||
return samples.squeeze(gather_dim)
|
||||
|
||||
def _pad(self, x):
|
||||
return x.unsqueeze(-1 - self._event_ndims)
|
||||
|
||||
def _pad_mixture_dimensions(self, x):
|
||||
dist_batch_ndims = len(self.batch_shape)
|
||||
cat_batch_ndims = len(self.mixture_distribution.batch_shape)
|
||||
pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims
|
||||
xs = x.shape
|
||||
x = x.reshape(
|
||||
xs[:-1]
|
||||
+ torch.Size(pad_ndims * [1])
|
||||
+ xs[-1:]
|
||||
+ torch.Size(self._event_ndims * [1])
|
||||
)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
args_string = (
|
||||
f"\n {self.mixture_distribution},\n {self.component_distribution}"
|
||||
)
|
||||
return "MixtureSameFamily" + "(" + args_string + ")"
|
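A short illustrative sketch (not part of the file above): a one-dimensional Gaussian mixture, its log density, and the moments computed via the law of total variance implemented in `variance`.

import torch
from torch import distributions as D

mix = D.Categorical(logits=torch.randn(5))            # mixing weights over 5 components
comp = D.Normal(torch.randn(5), torch.rand(5) + 0.1)  # 5 univariate Normal components
gmm = D.MixtureSameFamily(mix, comp)

x = gmm.sample((1000,))
print(x.shape)                          # torch.Size([1000])
print(gmm.log_prob(torch.tensor(0.0)))  # mixture log density at 0
print(gmm.mean, gmm.variance)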
138
venv/Lib/site-packages/torch/distributions/multinomial.py
Normal file
@@ -0,0 +1,138 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import inf, Tensor
|
||||
from torch.distributions import Categorical, constraints
|
||||
from torch.distributions.binomial import Binomial
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import broadcast_all
|
||||
|
||||
|
||||
__all__ = ["Multinomial"]
|
||||
|
||||
|
||||
class Multinomial(Distribution):
|
||||
r"""
|
||||
Creates a Multinomial distribution parameterized by :attr:`total_count` and
|
||||
either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
|
||||
:attr:`probs` indexes over categories. All other dimensions index over batches.
|
||||
|
||||
Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
|
||||
called (see the example below).
|
||||
|
||||
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
|
||||
will return this normalized value.
|
||||
The `logits` argument will be interpreted as unnormalized log probabilities
|
||||
and can therefore be any real number. It will likewise be normalized so that
|
||||
the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
|
||||
will return this normalized value.
|
||||
|
||||
- :meth:`sample` requires a single shared `total_count` for all
|
||||
parameters and samples.
|
||||
- :meth:`log_prob` allows different `total_count` for each parameter and
|
||||
sample.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +SKIP("FIXME: found invalid values")
|
||||
>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
|
||||
>>> x = m.sample() # equal probability of 0, 1, 2, 3
|
||||
tensor([ 21., 24., 30., 25.])
|
||||
|
||||
>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
|
||||
tensor([-4.1338])
|
||||
|
||||
Args:
|
||||
total_count (int): number of trials
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log probabilities (unnormalized)
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||
total_count: int
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.probs * self.total_count
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.total_count * self.probs * (1 - self.probs)
|
||||
|
||||
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
|
||||
if not isinstance(total_count, int):
|
||||
raise NotImplementedError("inhomogeneous total_count is not supported")
|
||||
self.total_count = total_count
|
||||
self._categorical = Categorical(probs=probs, logits=logits)
|
||||
self._binomial = Binomial(total_count=total_count, probs=self.probs)
|
||||
batch_shape = self._categorical.batch_shape
|
||||
event_shape = self._categorical.param_shape[-1:]
|
||||
super().__init__(batch_shape, event_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Multinomial, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.total_count = self.total_count
|
||||
new._categorical = self._categorical.expand(batch_shape)
|
||||
super(Multinomial, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._categorical._new(*args, **kwargs)
|
||||
|
||||
@constraints.dependent_property(is_discrete=True, event_dim=1)
|
||||
def support(self):
|
||||
return constraints.multinomial(self.total_count)
|
||||
|
||||
@property
|
||||
def logits(self) -> Tensor:
|
||||
return self._categorical.logits
|
||||
|
||||
@property
|
||||
def probs(self) -> Tensor:
|
||||
return self._categorical.probs
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._categorical.param_shape
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
sample_shape = torch.Size(sample_shape)
|
||||
samples = self._categorical.sample(
|
||||
torch.Size((self.total_count,)) + sample_shape
|
||||
)
|
||||
# samples.shape is (total_count, sample_shape, batch_shape), need to change it to
|
||||
# (sample_shape, batch_shape, total_count)
|
||||
shifted_idx = list(range(samples.dim()))
|
||||
shifted_idx.append(shifted_idx.pop(0))
|
||||
samples = samples.permute(*shifted_idx)
|
||||
counts = samples.new(self._extended_shape(sample_shape)).zero_()
|
||||
counts.scatter_add_(-1, samples, torch.ones_like(samples))
|
||||
return counts.type_as(self.probs)
|
||||
|
||||
def entropy(self):
|
||||
n = torch.tensor(self.total_count)
|
||||
|
||||
cat_entropy = self._categorical.entropy()
|
||||
term1 = n * cat_entropy - torch.lgamma(n + 1)
|
||||
|
||||
support = self._binomial.enumerate_support(expand=False)[1:]
|
||||
binomial_probs = torch.exp(self._binomial.log_prob(support))
|
||||
weights = torch.lgamma(support + 1)
|
||||
term2 = (binomial_probs * weights).sum([0, -1])
|
||||
|
||||
return term1 + term2
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
logits, value = broadcast_all(self.logits, value)
|
||||
logits = logits.clone(memory_format=torch.contiguous_format)
|
||||
log_factorial_n = torch.lgamma(value.sum(-1) + 1)
|
||||
log_factorial_xs = torch.lgamma(value + 1).sum(-1)
|
||||
logits[(value == 0) & (logits == -inf)] = 0
|
||||
log_powers = (logits * value).sum(-1)
|
||||
return log_factorial_n - log_factorial_xs + log_powers
|
|
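An illustrative draw from `Multinomial` (not part of the committed file): a sample is a count vector over the categories that sums to `total_count`.

import torch
from torch.distributions import Multinomial

m = Multinomial(total_count=10, probs=torch.tensor([0.2, 0.3, 0.5]))
counts = m.sample()
print(counts, counts.sum())   # three counts summing to 10
print(m.log_prob(counts))     # log-probability of the drawn count vector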
@@ -0,0 +1,267 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import _standard_normal, lazy_property
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["MultivariateNormal"]
|
||||
|
||||
|
||||
def _batch_mv(bmat, bvec):
|
||||
r"""
|
||||
Performs a batched matrix-vector product, with compatible but different batch shapes.
|
||||
|
||||
This function takes as input `bmat`, containing :math:`n \times n` matrices, and
|
||||
`bvec`, containing length :math:`n` vectors.
|
||||
|
||||
Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
|
||||
to a batch shape. They are not necessarily assumed to have the same batch shape,
|
||||
just ones that can be broadcast together.
|
||||
"""
|
||||
return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
|
||||
|
||||
|
||||
def _batch_mahalanobis(bL, bx):
|
||||
r"""
|
||||
Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
|
||||
for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.
|
||||
|
||||
Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
|
||||
shape, but the batch shape of `bL` must be broadcastable to that of `bx`.
|
||||
"""
|
||||
n = bx.size(-1)
|
||||
bx_batch_shape = bx.shape[:-1]
|
||||
|
||||
# Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
|
||||
# we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
|
||||
bx_batch_dims = len(bx_batch_shape)
|
||||
bL_batch_dims = bL.dim() - 2
|
||||
outer_batch_dims = bx_batch_dims - bL_batch_dims
|
||||
old_batch_dims = outer_batch_dims + bL_batch_dims
|
||||
new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
|
||||
# Reshape bx with the shape (..., 1, i, j, 1, n)
|
||||
bx_new_shape = bx.shape[:outer_batch_dims]
|
||||
for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
|
||||
bx_new_shape += (sx // sL, sL)
|
||||
bx_new_shape += (n,)
|
||||
bx = bx.reshape(bx_new_shape)
|
||||
# Permute bx to make it have shape (..., 1, j, i, 1, n)
|
||||
permute_dims = (
|
||||
list(range(outer_batch_dims))
|
||||
+ list(range(outer_batch_dims, new_batch_dims, 2))
|
||||
+ list(range(outer_batch_dims + 1, new_batch_dims, 2))
|
||||
+ [new_batch_dims]
|
||||
)
|
||||
bx = bx.permute(permute_dims)
|
||||
|
||||
flat_L = bL.reshape(-1, n, n) # shape = b x n x n
|
||||
flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n
|
||||
flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c
|
||||
M_swap = (
|
||||
torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
|
||||
) # shape = b x c
|
||||
M = M_swap.t() # shape = c x b
|
||||
|
||||
# Now we revert the above reshape and permute operators.
|
||||
permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1)
|
||||
permute_inv_dims = list(range(outer_batch_dims))
|
||||
for i in range(bL_batch_dims):
|
||||
permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
|
||||
reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1)
|
||||
return reshaped_M.reshape(bx_batch_shape)
|
||||
|
||||
|
||||
def _precision_to_scale_tril(P):
|
||||
# Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
|
||||
Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
|
||||
L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
|
||||
Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
|
||||
L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
|
||||
return L
|
||||
|
||||
|
||||
class MultivariateNormal(Distribution):
|
||||
r"""
|
||||
Creates a multivariate normal (also called Gaussian) distribution
|
||||
parameterized by a mean vector and a covariance matrix.
|
||||
|
||||
The multivariate normal distribution can be parameterized either
|
||||
in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
|
||||
or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
|
||||
or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
|
||||
diagonal entries, such that
|
||||
:math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
|
||||
can be obtained via e.g. Cholesky decomposition of the covariance.
|
||||
|
||||
Example:
|
||||
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
|
||||
>>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
|
||||
tensor([-0.2102, -0.5429])
|
||||
|
||||
Args:
|
||||
loc (Tensor): mean of the distribution
|
||||
covariance_matrix (Tensor): positive-definite covariance matrix
|
||||
precision_matrix (Tensor): positive-definite precision matrix
|
||||
scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
|
||||
|
||||
Note:
|
||||
Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
|
||||
:attr:`scale_tril` can be specified.
|
||||
|
||||
Using :attr:`scale_tril` will be more efficient: all computations internally
|
||||
are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
|
||||
:attr:`precision_matrix` is passed instead, it is only used to compute
|
||||
the corresponding lower triangular matrices using a Cholesky decomposition.
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"loc": constraints.real_vector,
|
||||
"covariance_matrix": constraints.positive_definite,
|
||||
"precision_matrix": constraints.positive_definite,
|
||||
"scale_tril": constraints.lower_cholesky,
|
||||
}
|
||||
support = constraints.real_vector
|
||||
has_rsample = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loc,
|
||||
covariance_matrix=None,
|
||||
precision_matrix=None,
|
||||
scale_tril=None,
|
||||
validate_args=None,
|
||||
):
|
||||
if loc.dim() < 1:
|
||||
raise ValueError("loc must be at least one-dimensional.")
|
||||
if (covariance_matrix is not None) + (scale_tril is not None) + (
|
||||
precision_matrix is not None
|
||||
) != 1:
|
||||
raise ValueError(
|
||||
"Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
|
||||
)
|
||||
|
||||
if scale_tril is not None:
|
||||
if scale_tril.dim() < 2:
|
||||
raise ValueError(
|
||||
"scale_tril matrix must be at least two-dimensional, "
|
||||
"with optional leading batch dimensions"
|
||||
)
|
||||
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
|
||||
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
|
||||
elif covariance_matrix is not None:
|
||||
if covariance_matrix.dim() < 2:
|
||||
raise ValueError(
|
||||
"covariance_matrix must be at least two-dimensional, "
|
||||
"with optional leading batch dimensions"
|
||||
)
|
||||
batch_shape = torch.broadcast_shapes(
|
||||
covariance_matrix.shape[:-2], loc.shape[:-1]
|
||||
)
|
||||
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
|
||||
else:
|
||||
if precision_matrix.dim() < 2:
|
||||
raise ValueError(
|
||||
"precision_matrix must be at least two-dimensional, "
|
||||
"with optional leading batch dimensions"
|
||||
)
|
||||
batch_shape = torch.broadcast_shapes(
|
||||
precision_matrix.shape[:-2], loc.shape[:-1]
|
||||
)
|
||||
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
|
||||
self.loc = loc.expand(batch_shape + (-1,))
|
||||
|
||||
event_shape = self.loc.shape[-1:]
|
||||
super().__init__(batch_shape, event_shape, validate_args=validate_args)
|
||||
|
||||
if scale_tril is not None:
|
||||
self._unbroadcasted_scale_tril = scale_tril
|
||||
elif covariance_matrix is not None:
|
||||
self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
|
||||
else: # precision_matrix is not None
|
||||
self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(MultivariateNormal, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
loc_shape = batch_shape + self.event_shape
|
||||
cov_shape = batch_shape + self.event_shape + self.event_shape
|
||||
new.loc = self.loc.expand(loc_shape)
|
||||
new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
|
||||
if "covariance_matrix" in self.__dict__:
|
||||
new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
|
||||
if "scale_tril" in self.__dict__:
|
||||
new.scale_tril = self.scale_tril.expand(cov_shape)
|
||||
if "precision_matrix" in self.__dict__:
|
||||
new.precision_matrix = self.precision_matrix.expand(cov_shape)
|
||||
super(MultivariateNormal, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@lazy_property
|
||||
def scale_tril(self) -> Tensor:
|
||||
return self._unbroadcasted_scale_tril.expand(
|
||||
self._batch_shape + self._event_shape + self._event_shape
|
||||
)
|
||||
|
||||
@lazy_property
|
||||
def covariance_matrix(self) -> Tensor:
|
||||
return torch.matmul(
|
||||
self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
|
||||
).expand(self._batch_shape + self._event_shape + self._event_shape)
|
||||
|
||||
@lazy_property
|
||||
def precision_matrix(self) -> Tensor:
|
||||
return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
|
||||
self._batch_shape + self._event_shape + self._event_shape
|
||||
)
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return (
|
||||
self._unbroadcasted_scale_tril.pow(2)
|
||||
.sum(-1)
|
||||
.expand(self._batch_shape + self._event_shape)
|
||||
)
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
|
||||
return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
diff = value - self.loc
|
||||
M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
|
||||
half_log_det = (
|
||||
self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
|
||||
)
|
||||
return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
|
||||
|
||||
def entropy(self):
|
||||
half_log_det = (
|
||||
self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
|
||||
)
|
||||
H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
|
||||
if len(self._batch_shape) == 0:
|
||||
return H
|
||||
else:
|
||||
return H.expand(self._batch_shape)
|
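A hedged sketch (not part of the file above) showing that the `covariance_matrix` and `scale_tril` parameterizations describe the same distribution; per the note in the docstring, passing `scale_tril` skips the internal Cholesky factorization.

import torch
from torch.distributions import MultivariateNormal

loc = torch.zeros(3)
cov = torch.tensor([[2.0, 0.5, 0.0],
                    [0.5, 1.0, 0.3],
                    [0.0, 0.3, 1.5]])
mvn_cov = MultivariateNormal(loc, covariance_matrix=cov)
mvn_tril = MultivariateNormal(loc, scale_tril=torch.linalg.cholesky(cov))

x = mvn_tril.rsample()
print(torch.allclose(mvn_cov.log_prob(x), mvn_tril.log_prob(x), atol=1e-5))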
138
venv/Lib/site-packages/torch/distributions/negative_binomial.py
Normal file
@@ -0,0 +1,138 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.gamma import Gamma
|
||||
from torch.distributions.utils import (
|
||||
broadcast_all,
|
||||
lazy_property,
|
||||
logits_to_probs,
|
||||
probs_to_logits,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["NegativeBinomial"]
|
||||
|
||||
|
||||
class NegativeBinomial(Distribution):
|
||||
r"""
|
||||
Creates a Negative Binomial distribution, i.e. distribution
|
||||
of the number of successful independent and identical Bernoulli trials
|
||||
before :attr:`total_count` failures are achieved. The probability
|
||||
of success of each Bernoulli trial is :attr:`probs`.
|
||||
|
||||
Args:
|
||||
total_count (float or Tensor): non-negative number of negative Bernoulli
|
||||
trials to stop, although the distribution is still valid for real
|
||||
valued counts
|
||||
probs (Tensor): Event probabilities of success in the half open interval [0, 1)
|
||||
logits (Tensor): Event log-odds for probabilities of success
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"total_count": constraints.greater_than_eq(0),
|
||||
"probs": constraints.half_open_interval(0.0, 1.0),
|
||||
"logits": constraints.real,
|
||||
}
|
||||
support = constraints.nonnegative_integer
|
||||
|
||||
def __init__(self, total_count, probs=None, logits=None, validate_args=None):
|
||||
if (probs is None) == (logits is None):
|
||||
raise ValueError(
|
||||
"Either `probs` or `logits` must be specified, but not both."
|
||||
)
|
||||
if probs is not None:
|
||||
(
|
||||
self.total_count,
|
||||
self.probs,
|
||||
) = broadcast_all(total_count, probs)
|
||||
self.total_count = self.total_count.type_as(self.probs)
|
||||
else:
|
||||
(
|
||||
self.total_count,
|
||||
self.logits,
|
||||
) = broadcast_all(total_count, logits)
|
||||
self.total_count = self.total_count.type_as(self.logits)
|
||||
|
||||
self._param = self.probs if probs is not None else self.logits
|
||||
batch_shape = self._param.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(NegativeBinomial, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.total_count = self.total_count.expand(batch_shape)
|
||||
if "probs" in self.__dict__:
|
||||
new.probs = self.probs.expand(batch_shape)
|
||||
new._param = new.probs
|
||||
if "logits" in self.__dict__:
|
||||
new.logits = self.logits.expand(batch_shape)
|
||||
new._param = new.logits
|
||||
super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._param.new(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.total_count * torch.exp(self.logits)
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.0)
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.mean / torch.sigmoid(-self.logits)
|
||||
|
||||
@lazy_property
|
||||
def logits(self) -> Tensor:
|
||||
return probs_to_logits(self.probs, is_binary=True)
|
||||
|
||||
@lazy_property
|
||||
def probs(self) -> Tensor:
|
||||
return logits_to_probs(self.logits, is_binary=True)
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._param.size()
|
||||
|
||||
@lazy_property
|
||||
def _gamma(self) -> Gamma:
|
||||
# Note we avoid validating because self.total_count can be zero.
|
||||
return Gamma(
|
||||
concentration=self.total_count,
|
||||
rate=torch.exp(-self.logits),
|
||||
validate_args=False,
|
||||
)
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
with torch.no_grad():
|
||||
rate = self._gamma.sample(sample_shape=sample_shape)
|
||||
return torch.poisson(rate)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
|
||||
log_unnormalized_prob = self.total_count * F.logsigmoid(
|
||||
-self.logits
|
||||
) + value * F.logsigmoid(self.logits)
|
||||
|
||||
log_normalization = (
|
||||
-torch.lgamma(self.total_count + value)
|
||||
+ torch.lgamma(1.0 + value)
|
||||
+ torch.lgamma(self.total_count)
|
||||
)
|
||||
# The case self.total_count == 0 and value == 0 has probability 1 but
|
||||
# lgamma(0) is infinite. Handle this case separately using a function
|
||||
# that does not modify tensors in place to allow Jit compilation.
|
||||
log_normalization = log_normalization.masked_fill(
|
||||
self.total_count + value == 0.0, 0.0
|
||||
)
|
||||
|
||||
return log_unnormalized_prob - log_normalization
|
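An illustrative sketch for `NegativeBinomial` (not part of the committed file); with the parameterization above, `mean = total_count * probs / (1 - probs)` and `variance = mean / (1 - probs)`.

import torch
from torch.distributions import NegativeBinomial

d = NegativeBinomial(total_count=torch.tensor(5.0), probs=torch.tensor(0.3))
print(d.mean, d.variance)             # 5 * 0.3 / 0.7 and mean / 0.7
print(d.sample((4,)))                 # non-negative integer-valued counts
print(d.log_prob(torch.tensor(2.0)))  # log pmf at 2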
115
venv/Lib/site-packages/torch/distributions/normal.py
Normal file
@@ -0,0 +1,115 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import _standard_normal, broadcast_all
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["Normal"]
|
||||
|
||||
|
||||
class Normal(ExponentialFamily):
|
||||
r"""
|
||||
Creates a normal (also called Gaussian) distribution parameterized by
|
||||
:attr:`loc` and :attr:`scale`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
|
||||
>>> m.sample() # normally distributed with loc=0 and scale=1
|
||||
tensor([ 0.1046])
|
||||
|
||||
Args:
|
||||
loc (float or Tensor): mean of the distribution (often referred to as mu)
|
||||
scale (float or Tensor): standard deviation of the distribution
|
||||
(often referred to as sigma)
|
||||
"""
|
||||
|
||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||
support = constraints.real
|
||||
has_rsample = True
|
||||
_mean_carrier_measure = 0
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def stddev(self) -> Tensor:
|
||||
return self.scale
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.stddev.pow(2)
|
||||
|
||||
def __init__(self, loc, scale, validate_args=None):
|
||||
self.loc, self.scale = broadcast_all(loc, scale)
|
||||
if isinstance(loc, _Number) and isinstance(scale, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self.loc.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Normal, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.loc = self.loc.expand(batch_shape)
|
||||
new.scale = self.scale.expand(batch_shape)
|
||||
super(Normal, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
with torch.no_grad():
|
||||
return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
|
||||
return self.loc + eps * self.scale
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
# compute the variance
|
||||
var = self.scale**2
|
||||
log_scale = (
|
||||
math.log(self.scale)
|
||||
if isinstance(self.scale, _Number)
|
||||
else self.scale.log()
|
||||
)
|
||||
return (
|
||||
-((value - self.loc) ** 2) / (2 * var)
|
||||
- log_scale
|
||||
- math.log(math.sqrt(2 * math.pi))
|
||||
)
|
||||
|
||||
def cdf(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
return 0.5 * (
|
||||
1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2))
|
||||
)
|
||||
|
||||
def icdf(self, value):
|
||||
return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
|
||||
|
||||
def entropy(self):
|
||||
return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
|
||||
|
||||
@property
|
||||
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
||||
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
|
||||
|
||||
def _log_normalizer(self, x, y):
|
||||
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
|
|
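A brief illustrative sketch (not part of the file above) of the pathwise gradient that `rsample` enables for `Normal`: differentiating a Monte Carlo estimate of E[x^2] = mu^2 + sigma^2 with respect to both parameters.

import torch
from torch.distributions import Normal

mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)
d = Normal(mu, sigma)

x = d.rsample((100_000,))     # x = mu + sigma * eps with eps ~ N(0, 1)
loss = (x ** 2).mean()        # Monte Carlo estimate of mu**2 + sigma**2
loss.backward()
print(mu.grad, sigma.grad)    # approximately 2 * mu and 2 * sigma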
@@ -0,0 +1,135 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.categorical import Categorical
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["OneHotCategorical", "OneHotCategoricalStraightThrough"]
|
||||
|
||||
|
||||
class OneHotCategorical(Distribution):
|
||||
r"""
|
||||
Creates a one-hot categorical distribution parameterized by :attr:`probs` or
|
||||
:attr:`logits`.
|
||||
|
||||
Samples are one-hot coded vectors of size ``probs.size(-1)``.
|
||||
|
||||
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
|
||||
will return this normalized value.
|
||||
The `logits` argument will be interpreted as unnormalized log probabilities
|
||||
and can therefore be any real number. It will likewise be normalized so that
|
||||
the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
|
||||
will return this normalized value.
|
||||
|
||||
See also: :func:`torch.distributions.Categorical` for specifications of
|
||||
:attr:`probs` and :attr:`logits`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
|
||||
>>> m.sample() # equal probability of 0, 1, 2, 3
|
||||
tensor([ 0., 0., 0., 1.])
|
||||
|
||||
Args:
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log probabilities (unnormalized)
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||
support = constraints.one_hot
|
||||
has_enumerate_support = True
|
||||
|
||||
def __init__(self, probs=None, logits=None, validate_args=None):
|
||||
self._categorical = Categorical(probs, logits)
|
||||
batch_shape = self._categorical.batch_shape
|
||||
event_shape = self._categorical.param_shape[-1:]
|
||||
super().__init__(batch_shape, event_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(OneHotCategorical, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new._categorical = self._categorical.expand(batch_shape)
|
||||
super(OneHotCategorical, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._categorical._new(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def _param(self) -> Tensor:
|
||||
return self._categorical._param
|
||||
|
||||
@property
|
||||
def probs(self) -> Tensor:
|
||||
return self._categorical.probs
|
||||
|
||||
@property
|
||||
def logits(self) -> Tensor:
|
||||
return self._categorical.logits
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self._categorical.probs
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
probs = self._categorical.probs
|
||||
mode = probs.argmax(dim=-1)
|
||||
return torch.nn.functional.one_hot(mode, num_classes=probs.shape[-1]).to(probs)
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self._categorical.probs * (1 - self._categorical.probs)
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._categorical.param_shape
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
sample_shape = torch.Size(sample_shape)
|
||||
probs = self._categorical.probs
|
||||
num_events = self._categorical._num_events
|
||||
indices = self._categorical.sample(sample_shape)
|
||||
return torch.nn.functional.one_hot(indices, num_events).to(probs)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
indices = value.max(-1)[1]
|
||||
return self._categorical.log_prob(indices)
|
||||
|
||||
def entropy(self):
|
||||
return self._categorical.entropy()
|
||||
|
||||
def enumerate_support(self, expand=True):
|
||||
n = self.event_shape[0]
|
||||
values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
|
||||
values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
|
||||
if expand:
|
||||
values = values.expand((n,) + self.batch_shape + (n,))
|
||||
return values
|
||||
|
||||
|
||||
class OneHotCategoricalStraightThrough(OneHotCategorical):
|
||||
r"""
|
||||
Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight-
|
||||
through gradient estimator from [1].
|
||||
|
||||
[1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
|
||||
(Bengio et al., 2013)
|
||||
"""
|
||||
|
||||
has_rsample = True
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
samples = self.sample(sample_shape)
|
||||
probs = self._categorical.probs # cached via @lazy_property
|
||||
return samples + (probs - probs.detach())
|
70
venv/Lib/site-packages/torch/distributions/pareto.py
Normal file
70
venv/Lib/site-packages/torch/distributions/pareto.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
from typing import Optional
|
||||
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exponential import Exponential
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import AffineTransform, ExpTransform
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["Pareto"]
|
||||
|
||||
|
||||
class Pareto(TransformedDistribution):
|
||||
r"""
|
||||
Samples from a Pareto Type 1 distribution.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
|
||||
>>> m.sample() # sample from a Pareto distribution with scale=1 and alpha=1
|
||||
tensor([ 1.5623])
|
||||
|
||||
Args:
|
||||
scale (float or Tensor): Scale parameter of the distribution
|
||||
alpha (float or Tensor): Shape parameter of the distribution
|
||||
"""
|
||||
|
||||
arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
|
||||
|
||||
def __init__(
|
||||
self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None
|
||||
) -> None:
|
||||
self.scale, self.alpha = broadcast_all(scale, alpha)
|
||||
base_dist = Exponential(self.alpha, validate_args=validate_args)
|
||||
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
|
||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||
|
||||
def expand(
|
||||
self, batch_shape: _size, _instance: Optional["Pareto"] = None
|
||||
) -> "Pareto":
|
||||
new = self._get_checked_instance(Pareto, _instance)
|
||||
new.scale = self.scale.expand(batch_shape)
|
||||
new.alpha = self.alpha.expand(batch_shape)
|
||||
return super().expand(batch_shape, _instance=new)
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
# mean is inf for alpha <= 1
|
||||
a = self.alpha.clamp(min=1)
|
||||
return a * self.scale / (a - 1)
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.scale
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
# var is inf for alpha <= 2
|
||||
a = self.alpha.clamp(min=2)
|
||||
return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))
|
||||
|
||||
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
||||
def support(self) -> constraints.Constraint:
|
||||
return constraints.greater_than_eq(self.scale)
|
||||
|
||||
def entropy(self) -> Tensor:
|
||||
return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())
|
80
venv/Lib/site-packages/torch/distributions/poisson.py
Normal file
80
venv/Lib/site-packages/torch/distributions/poisson.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exp_family import ExponentialFamily
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _Number
|
||||
|
||||
|
||||
__all__ = ["Poisson"]
|
||||
|
||||
|
||||
class Poisson(ExponentialFamily):
|
||||
r"""
|
||||
Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
|
||||
|
||||
Samples are nonnegative integers, with a pmf given by
|
||||
|
||||
.. math::
|
||||
\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
|
||||
>>> m = Poisson(torch.tensor([4]))
|
||||
>>> m.sample()
|
||||
tensor([ 3.])
|
||||
|
||||
Args:
|
||||
rate (Number, Tensor): the rate parameter
|
||||
"""
|
||||
|
||||
arg_constraints = {"rate": constraints.nonnegative}
|
||||
support = constraints.nonnegative_integer
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return self.rate
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.rate.floor()
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return self.rate
|
||||
|
||||
def __init__(self, rate, validate_args=None):
|
||||
(self.rate,) = broadcast_all(rate)
|
||||
if isinstance(rate, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self.rate.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Poisson, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.rate = self.rate.expand(batch_shape)
|
||||
super(Poisson, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
shape = self._extended_shape(sample_shape)
|
||||
with torch.no_grad():
|
||||
return torch.poisson(self.rate.expand(shape))
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
rate, value = broadcast_all(self.rate, value)
|
||||
return value.xlogy(rate) - rate - (value + 1).lgamma()
|
||||
|
||||
@property
|
||||
def _natural_params(self) -> tuple[Tensor]:
|
||||
return (torch.log(self.rate),)
|
||||
|
||||
def _log_normalizer(self, x):
|
||||
return torch.exp(x)
|
153
venv/Lib/site-packages/torch/distributions/relaxed_bernoulli.py
Normal file
153
venv/Lib/site-packages/torch/distributions/relaxed_bernoulli.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import SigmoidTransform
|
||||
from torch.distributions.utils import (
|
||||
broadcast_all,
|
||||
clamp_probs,
|
||||
lazy_property,
|
||||
logits_to_probs,
|
||||
probs_to_logits,
|
||||
)
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"]
|
||||
|
||||
|
||||
class LogitRelaxedBernoulli(Distribution):
|
||||
r"""
|
||||
Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs`
|
||||
or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli
|
||||
distribution.
|
||||
|
||||
Samples are logits of values in (0, 1). See [1] for more details.
|
||||
|
||||
Args:
|
||||
temperature (Tensor): relaxation temperature
|
||||
probs (Number, Tensor): the probability of sampling `1`
|
||||
logits (Number, Tensor): the log-odds of sampling `1`
|
||||
|
||||
[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
|
||||
Variables (Maddison et al., 2017)
|
||||
|
||||
[2] Categorical Reparametrization with Gumbel-Softmax
|
||||
(Jang et al., 2017)
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||
support = constraints.real
|
||||
|
||||
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
|
||||
self.temperature = temperature
|
||||
if (probs is None) == (logits is None):
|
||||
raise ValueError(
|
||||
"Either `probs` or `logits` must be specified, but not both."
|
||||
)
|
||||
if probs is not None:
|
||||
is_scalar = isinstance(probs, _Number)
|
||||
(self.probs,) = broadcast_all(probs)
|
||||
else:
|
||||
is_scalar = isinstance(logits, _Number)
|
||||
(self.logits,) = broadcast_all(logits)
|
||||
self._param = self.probs if probs is not None else self.logits
|
||||
if is_scalar:
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self._param.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(LogitRelaxedBernoulli, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.temperature = self.temperature
|
||||
if "probs" in self.__dict__:
|
||||
new.probs = self.probs.expand(batch_shape)
|
||||
new._param = new.probs
|
||||
if "logits" in self.__dict__:
|
||||
new.logits = self.logits.expand(batch_shape)
|
||||
new._param = new.logits
|
||||
super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._param.new(*args, **kwargs)
|
||||
|
||||
@lazy_property
|
||||
def logits(self) -> Tensor:
|
||||
return probs_to_logits(self.probs, is_binary=True)
|
||||
|
||||
@lazy_property
|
||||
def probs(self) -> Tensor:
|
||||
return logits_to_probs(self.logits, is_binary=True)
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._param.size()
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
probs = clamp_probs(self.probs.expand(shape))
|
||||
uniforms = clamp_probs(
|
||||
torch.rand(shape, dtype=probs.dtype, device=probs.device)
|
||||
)
|
||||
return (
|
||||
uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p()
|
||||
) / self.temperature
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
logits, value = broadcast_all(self.logits, value)
|
||||
diff = logits - value.mul(self.temperature)
|
||||
return self.temperature.log() + diff - 2 * diff.exp().log1p()
|
||||
|
||||
|
||||
class RelaxedBernoulli(TransformedDistribution):
|
||||
r"""
|
||||
Creates a RelaxedBernoulli distribution, parametrized by
|
||||
:attr:`temperature`, and either :attr:`probs` or :attr:`logits`
|
||||
(but not both). This is a relaxed version of the `Bernoulli` distribution,
|
||||
so the values are in (0, 1), and has reparametrizable samples.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = RelaxedBernoulli(torch.tensor([2.2]),
|
||||
... torch.tensor([0.1, 0.2, 0.3, 0.99]))
|
||||
>>> m.sample()
|
||||
tensor([ 0.2951, 0.3442, 0.8918, 0.9021])
|
||||
|
||||
Args:
|
||||
temperature (Tensor): relaxation temperature
|
||||
probs (Number, Tensor): the probability of sampling `1`
|
||||
logits (Number, Tensor): the log-odds of sampling `1`
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||
support = constraints.unit_interval
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
|
||||
base_dist = LogitRelaxedBernoulli(temperature, probs, logits)
|
||||
super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(RelaxedBernoulli, _instance)
|
||||
return super().expand(batch_shape, _instance=new)
|
||||
|
||||
@property
|
||||
def temperature(self) -> Tensor:
|
||||
return self.base_dist.temperature
|
||||
|
||||
@property
|
||||
def logits(self) -> Tensor:
|
||||
return self.base_dist.logits
|
||||
|
||||
@property
|
||||
def probs(self) -> Tensor:
|
||||
return self.base_dist.probs
|
|
@ -0,0 +1,145 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.categorical import Categorical
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import ExpTransform
|
||||
from torch.distributions.utils import broadcast_all, clamp_probs
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]
|
||||
|
||||
|
||||
class ExpRelaxedCategorical(Distribution):
|
||||
r"""
|
||||
Creates a ExpRelaxedCategorical parameterized by
|
||||
:attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
|
||||
Returns the log of a point in the simplex. Based on the interface to
|
||||
:class:`OneHotCategorical`.
|
||||
|
||||
Implementation based on [1].
|
||||
|
||||
See also: :func:`torch.distributions.OneHotCategorical`
|
||||
|
||||
Args:
|
||||
temperature (Tensor): relaxation temperature
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): unnormalized log probability for each event
|
||||
|
||||
[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
|
||||
(Maddison et al., 2017)
|
||||
|
||||
[2] Categorical Reparametrization with Gumbel-Softmax
|
||||
(Jang et al., 2017)
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||
support = (
|
||||
constraints.real_vector
|
||||
) # The true support is actually a submanifold of this.
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
|
||||
self._categorical = Categorical(probs, logits)
|
||||
self.temperature = temperature
|
||||
batch_shape = self._categorical.batch_shape
|
||||
event_shape = self._categorical.param_shape[-1:]
|
||||
super().__init__(batch_shape, event_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.temperature = self.temperature
|
||||
new._categorical = self._categorical.expand(batch_shape)
|
||||
super(ExpRelaxedCategorical, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def _new(self, *args, **kwargs):
|
||||
return self._categorical._new(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def param_shape(self) -> torch.Size:
|
||||
return self._categorical.param_shape
|
||||
|
||||
@property
|
||||
def logits(self) -> Tensor:
|
||||
return self._categorical.logits
|
||||
|
||||
@property
|
||||
def probs(self) -> Tensor:
|
||||
return self._categorical.probs
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
uniforms = clamp_probs(
|
||||
torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
|
||||
)
|
||||
gumbels = -((-(uniforms.log())).log())
|
||||
scores = (self.logits + gumbels) / self.temperature
|
||||
return scores - scores.logsumexp(dim=-1, keepdim=True)
|
||||
|
||||
def log_prob(self, value):
|
||||
K = self._categorical._num_events
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
logits, value = broadcast_all(self.logits, value)
|
||||
log_scale = torch.full_like(
|
||||
self.temperature, float(K)
|
||||
).lgamma() - self.temperature.log().mul(-(K - 1))
|
||||
score = logits - value.mul(self.temperature)
|
||||
score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
|
||||
return score + log_scale
|
||||
|
||||
|
||||
class RelaxedOneHotCategorical(TransformedDistribution):
|
||||
r"""
|
||||
Creates a RelaxedOneHotCategorical distribution parametrized by
|
||||
:attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
|
||||
This is a relaxed version of the :class:`OneHotCategorical` distribution, so
|
||||
its samples are on simplex, and are reparametrizable.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
|
||||
... torch.tensor([0.1, 0.2, 0.3, 0.4]))
|
||||
>>> m.sample()
|
||||
tensor([ 0.1294, 0.2324, 0.3859, 0.2523])
|
||||
|
||||
Args:
|
||||
temperature (Tensor): relaxation temperature
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): unnormalized log probability for each event
|
||||
"""
|
||||
|
||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||
support = constraints.simplex
|
||||
has_rsample = True
|
||||
|
||||
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
|
||||
base_dist = ExpRelaxedCategorical(
|
||||
temperature, probs, logits, validate_args=validate_args
|
||||
)
|
||||
super().__init__(base_dist, ExpTransform(), validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
|
||||
return super().expand(batch_shape, _instance=new)
|
||||
|
||||
@property
|
||||
def temperature(self) -> Tensor:
|
||||
return self.base_dist.temperature
|
||||
|
||||
@property
|
||||
def logits(self) -> Tensor:
|
||||
return self.base_dist.logits
|
||||
|
||||
@property
|
||||
def probs(self) -> Tensor:
|
||||
return self.base_dist.probs
|
120
venv/Lib/site-packages/torch/distributions/studentT.py
Normal file
120
venv/Lib/site-packages/torch/distributions/studentT.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import inf, nan, Tensor
|
||||
from torch.distributions import Chi2, constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import _standard_normal, broadcast_all
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["StudentT"]
|
||||
|
||||
|
||||
class StudentT(Distribution):
|
||||
r"""
|
||||
Creates a Student's t-distribution parameterized by degree of
|
||||
freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = StudentT(torch.tensor([2.0]))
|
||||
>>> m.sample() # Student's t-distributed with degrees of freedom=2
|
||||
tensor([ 0.1046])
|
||||
|
||||
Args:
|
||||
df (float or Tensor): degrees of freedom
|
||||
loc (float or Tensor): mean of the distribution
|
||||
scale (float or Tensor): scale of the distribution
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"df": constraints.positive,
|
||||
"loc": constraints.real,
|
||||
"scale": constraints.positive,
|
||||
}
|
||||
support = constraints.real
|
||||
has_rsample = True
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
m = self.loc.clone(memory_format=torch.contiguous_format)
|
||||
m[self.df <= 1] = nan
|
||||
return m
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
m = self.df.clone(memory_format=torch.contiguous_format)
|
||||
m[self.df > 2] = (
|
||||
self.scale[self.df > 2].pow(2)
|
||||
* self.df[self.df > 2]
|
||||
/ (self.df[self.df > 2] - 2)
|
||||
)
|
||||
m[(self.df <= 2) & (self.df > 1)] = inf
|
||||
m[self.df <= 1] = nan
|
||||
return m
|
||||
|
||||
def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
|
||||
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
|
||||
self._chi2 = Chi2(self.df)
|
||||
batch_shape = self.df.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(StudentT, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.df = self.df.expand(batch_shape)
|
||||
new.loc = self.loc.expand(batch_shape)
|
||||
new.scale = self.scale.expand(batch_shape)
|
||||
new._chi2 = self._chi2.expand(batch_shape)
|
||||
super(StudentT, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
# NOTE: This does not agree with scipy implementation as much as other distributions.
|
||||
# (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
|
||||
# parameters seems to help.
|
||||
|
||||
# X ~ Normal(0, 1)
|
||||
# Z ~ Chi2(df)
|
||||
# Y = X / sqrt(Z / df) ~ StudentT(df)
|
||||
shape = self._extended_shape(sample_shape)
|
||||
X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
|
||||
Z = self._chi2.rsample(sample_shape)
|
||||
Y = X * torch.rsqrt(Z / self.df)
|
||||
return self.loc + self.scale * Y
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
y = (value - self.loc) / self.scale
|
||||
Z = (
|
||||
self.scale.log()
|
||||
+ 0.5 * self.df.log()
|
||||
+ 0.5 * math.log(math.pi)
|
||||
+ torch.lgamma(0.5 * self.df)
|
||||
- torch.lgamma(0.5 * (self.df + 1.0))
|
||||
)
|
||||
return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z
|
||||
|
||||
def entropy(self):
|
||||
lbeta = (
|
||||
torch.lgamma(0.5 * self.df)
|
||||
+ math.lgamma(0.5)
|
||||
- torch.lgamma(0.5 * (self.df + 1))
|
||||
)
|
||||
return (
|
||||
self.scale.log()
|
||||
+ 0.5
|
||||
* (self.df + 1)
|
||||
* (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
|
||||
+ 0.5 * self.df.log()
|
||||
+ lbeta
|
||||
)
|
|
@ -0,0 +1,217 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.independent import Independent
|
||||
from torch.distributions.transforms import ComposeTransform, Transform
|
||||
from torch.distributions.utils import _sum_rightmost
|
||||
from torch.types import _size
|
||||
|
||||
|
||||
__all__ = ["TransformedDistribution"]
|
||||
|
||||
|
||||
class TransformedDistribution(Distribution):
|
||||
r"""
|
||||
Extension of the Distribution class, which applies a sequence of Transforms
|
||||
to a base distribution. Let f be the composition of transforms applied::
|
||||
|
||||
X ~ BaseDistribution
|
||||
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
|
||||
log p(Y) = log p(X) + log |det (dX/dY)|
|
||||
|
||||
Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
|
||||
maximum shape of its base distribution and its transforms, since transforms
|
||||
can introduce correlations among events.
|
||||
|
||||
An example for the usage of :class:`TransformedDistribution` would be::
|
||||
|
||||
# Building a Logistic Distribution
|
||||
# X ~ Uniform(0, 1)
|
||||
# f = a + b * logit(X)
|
||||
# Y ~ f(X) ~ Logistic(a, b)
|
||||
base_distribution = Uniform(0, 1)
|
||||
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
|
||||
logistic = TransformedDistribution(base_distribution, transforms)
|
||||
|
||||
For more examples, please look at the implementations of
|
||||
:class:`~torch.distributions.gumbel.Gumbel`,
|
||||
:class:`~torch.distributions.half_cauchy.HalfCauchy`,
|
||||
:class:`~torch.distributions.half_normal.HalfNormal`,
|
||||
:class:`~torch.distributions.log_normal.LogNormal`,
|
||||
:class:`~torch.distributions.pareto.Pareto`,
|
||||
:class:`~torch.distributions.weibull.Weibull`,
|
||||
:class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
|
||||
:class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
|
||||
"""
|
||||
|
||||
arg_constraints: dict[str, constraints.Constraint] = {}
|
||||
|
||||
def __init__(self, base_distribution, transforms, validate_args=None):
|
||||
if isinstance(transforms, Transform):
|
||||
self.transforms = [
|
||||
transforms,
|
||||
]
|
||||
elif isinstance(transforms, list):
|
||||
if not all(isinstance(t, Transform) for t in transforms):
|
||||
raise ValueError(
|
||||
"transforms must be a Transform or a list of Transforms"
|
||||
)
|
||||
self.transforms = transforms
|
||||
else:
|
||||
raise ValueError(
|
||||
f"transforms must be a Transform or list, but was {transforms}"
|
||||
)
|
||||
|
||||
# Reshape base_distribution according to transforms.
|
||||
base_shape = base_distribution.batch_shape + base_distribution.event_shape
|
||||
base_event_dim = len(base_distribution.event_shape)
|
||||
transform = ComposeTransform(self.transforms)
|
||||
if len(base_shape) < transform.domain.event_dim:
|
||||
raise ValueError(
|
||||
f"base_distribution needs to have shape with size at least {transform.domain.event_dim}, but got {base_shape}."
|
||||
)
|
||||
forward_shape = transform.forward_shape(base_shape)
|
||||
expanded_base_shape = transform.inverse_shape(forward_shape)
|
||||
if base_shape != expanded_base_shape:
|
||||
base_batch_shape = expanded_base_shape[
|
||||
: len(expanded_base_shape) - base_event_dim
|
||||
]
|
||||
base_distribution = base_distribution.expand(base_batch_shape)
|
||||
reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
|
||||
if reinterpreted_batch_ndims > 0:
|
||||
base_distribution = Independent(
|
||||
base_distribution, reinterpreted_batch_ndims
|
||||
)
|
||||
self.base_dist = base_distribution
|
||||
|
||||
# Compute shapes.
|
||||
transform_change_in_event_dim = (
|
||||
transform.codomain.event_dim - transform.domain.event_dim
|
||||
)
|
||||
event_dim = max(
|
||||
transform.codomain.event_dim, # the transform is coupled
|
||||
base_event_dim + transform_change_in_event_dim, # the base dist is coupled
|
||||
)
|
||||
assert len(forward_shape) >= event_dim
|
||||
cut = len(forward_shape) - event_dim
|
||||
batch_shape = forward_shape[:cut]
|
||||
event_shape = forward_shape[cut:]
|
||||
super().__init__(batch_shape, event_shape, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(TransformedDistribution, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
shape = batch_shape + self.event_shape
|
||||
for t in reversed(self.transforms):
|
||||
shape = t.inverse_shape(shape)
|
||||
base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
|
||||
new.base_dist = self.base_dist.expand(base_batch_shape)
|
||||
new.transforms = self.transforms
|
||||
super(TransformedDistribution, new).__init__(
|
||||
batch_shape, self.event_shape, validate_args=False
|
||||
)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@constraints.dependent_property(is_discrete=False)
|
||||
def support(self):
|
||||
if not self.transforms:
|
||||
return self.base_dist.support
|
||||
support = self.transforms[-1].codomain
|
||||
if len(self.event_shape) > support.event_dim:
|
||||
support = constraints.independent(
|
||||
support, len(self.event_shape) - support.event_dim
|
||||
)
|
||||
return support
|
||||
|
||||
@property
|
||||
def has_rsample(self) -> bool: # type: ignore[override]
|
||||
return self.base_dist.has_rsample
|
||||
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
"""
|
||||
Generates a sample_shape shaped sample or sample_shape shaped batch of
|
||||
samples if the distribution parameters are batched. Samples first from
|
||||
base distribution and applies `transform()` for every transform in the
|
||||
list.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
x = self.base_dist.sample(sample_shape)
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
return x
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
"""
|
||||
Generates a sample_shape shaped reparameterized sample or sample_shape
|
||||
shaped batch of reparameterized samples if the distribution parameters
|
||||
are batched. Samples first from base distribution and applies
|
||||
`transform()` for every transform in the list.
|
||||
"""
|
||||
x = self.base_dist.rsample(sample_shape)
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
return x
|
||||
|
||||
def log_prob(self, value):
|
||||
"""
|
||||
Scores the sample by inverting the transform(s) and computing the score
|
||||
using the score of the base distribution and the log abs det jacobian.
|
||||
"""
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
event_dim = len(self.event_shape)
|
||||
log_prob = 0.0
|
||||
y = value
|
||||
for transform in reversed(self.transforms):
|
||||
x = transform.inv(y)
|
||||
event_dim += transform.domain.event_dim - transform.codomain.event_dim
|
||||
log_prob = log_prob - _sum_rightmost(
|
||||
transform.log_abs_det_jacobian(x, y),
|
||||
event_dim - transform.domain.event_dim,
|
||||
)
|
||||
y = x
|
||||
|
||||
log_prob = log_prob + _sum_rightmost(
|
||||
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
|
||||
)
|
||||
return log_prob
|
||||
|
||||
def _monotonize_cdf(self, value):
|
||||
"""
|
||||
This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
|
||||
monotone increasing.
|
||||
"""
|
||||
sign = 1
|
||||
for transform in self.transforms:
|
||||
sign = sign * transform.sign
|
||||
if isinstance(sign, int) and sign == 1:
|
||||
return value
|
||||
return sign * (value - 0.5) + 0.5
|
||||
|
||||
def cdf(self, value):
|
||||
"""
|
||||
Computes the cumulative distribution function by inverting the
|
||||
transform(s) and computing the score of the base distribution.
|
||||
"""
|
||||
for transform in self.transforms[::-1]:
|
||||
value = transform.inv(value)
|
||||
if self._validate_args:
|
||||
self.base_dist._validate_sample(value)
|
||||
value = self.base_dist.cdf(value)
|
||||
value = self._monotonize_cdf(value)
|
||||
return value
|
||||
|
||||
def icdf(self, value):
|
||||
"""
|
||||
Computes the inverse cumulative distribution function using
|
||||
transform(s) and computing the score of the base distribution.
|
||||
"""
|
||||
value = self._monotonize_cdf(value)
|
||||
value = self.base_dist.icdf(value)
|
||||
for transform in self.transforms:
|
||||
value = transform(value)
|
||||
return value
|
1260
venv/Lib/site-packages/torch/distributions/transforms.py
Normal file
1260
venv/Lib/site-packages/torch/distributions/transforms.py
Normal file
File diff suppressed because it is too large
Load diff
101
venv/Lib/site-packages/torch/distributions/uniform.py
Normal file
101
venv/Lib/site-packages/torch/distributions/uniform.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import nan, Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["Uniform"]
|
||||
|
||||
|
||||
class Uniform(Distribution):
|
||||
r"""
|
||||
Generates uniformly distributed random samples from the half-open interval
|
||||
``[low, high)``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
|
||||
>>> m.sample() # uniformly distributed in the range [0.0, 5.0)
|
||||
>>> # xdoctest: +SKIP
|
||||
tensor([ 2.3418])
|
||||
|
||||
Args:
|
||||
low (float or Tensor): lower range (inclusive).
|
||||
high (float or Tensor): upper range (exclusive).
|
||||
"""
|
||||
|
||||
# TODO allow (loc,scale) parameterization to allow independent constraints.
|
||||
arg_constraints = {
|
||||
"low": constraints.dependent(is_discrete=False, event_dim=0),
|
||||
"high": constraints.dependent(is_discrete=False, event_dim=0),
|
||||
}
|
||||
has_rsample = True
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
return (self.high + self.low) / 2
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return nan * self.high
|
||||
|
||||
@property
|
||||
def stddev(self) -> Tensor:
|
||||
return (self.high - self.low) / 12**0.5
|
||||
|
||||
@property
|
||||
def variance(self) -> Tensor:
|
||||
return (self.high - self.low).pow(2) / 12
|
||||
|
||||
def __init__(self, low, high, validate_args=None):
|
||||
self.low, self.high = broadcast_all(low, high)
|
||||
|
||||
if isinstance(low, _Number) and isinstance(high, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self.low.size()
|
||||
super().__init__(batch_shape, validate_args=validate_args)
|
||||
|
||||
if self._validate_args and not torch.lt(self.low, self.high).all():
|
||||
raise ValueError("Uniform is not defined when low>= high")
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Uniform, _instance)
|
||||
batch_shape = torch.Size(batch_shape)
|
||||
new.low = self.low.expand(batch_shape)
|
||||
new.high = self.high.expand(batch_shape)
|
||||
super(Uniform, new).__init__(batch_shape, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
||||
def support(self):
|
||||
return constraints.interval(self.low, self.high)
|
||||
|
||||
def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
|
||||
shape = self._extended_shape(sample_shape)
|
||||
rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device)
|
||||
return self.low + rand * (self.high - self.low)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
lb = self.low.le(value).type_as(self.low)
|
||||
ub = self.high.gt(value).type_as(self.low)
|
||||
return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
|
||||
|
||||
def cdf(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
result = (value - self.low) / (self.high - self.low)
|
||||
return result.clamp(min=0, max=1)
|
||||
|
||||
def icdf(self, value):
|
||||
result = value * (self.high - self.low) + self.low
|
||||
return result
|
||||
|
||||
def entropy(self):
|
||||
return torch.log(self.high - self.low)
|
216
venv/Lib/site-packages/torch/distributions/utils.py
Normal file
216
venv/Lib/site-packages/torch/distributions/utils.py
Normal file
|
@ -0,0 +1,216 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from functools import update_wrapper
|
||||
from typing import Any, Callable, Generic, overload, Union
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.overrides import is_tensor_like
|
||||
from torch.types import _Number
|
||||
|
||||
|
||||
euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant
|
||||
|
||||
__all__ = [
|
||||
"broadcast_all",
|
||||
"logits_to_probs",
|
||||
"clamp_probs",
|
||||
"probs_to_logits",
|
||||
"lazy_property",
|
||||
"tril_matrix_to_vec",
|
||||
"vec_to_tril_matrix",
|
||||
]
|
||||
|
||||
|
||||
def broadcast_all(*values):
|
||||
r"""
|
||||
Given a list of values (possibly containing numbers), returns a list where each
|
||||
value is broadcasted based on the following rules:
|
||||
- `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`.
|
||||
- Number instances (scalars) are upcast to tensors having
|
||||
the same size and type as the first tensor passed to `values`. If all the
|
||||
values are scalars, then they are upcasted to scalar Tensors.
|
||||
|
||||
Args:
|
||||
values (list of `Number`, `torch.*Tensor` or objects implementing __torch_function__)
|
||||
|
||||
Raises:
|
||||
ValueError: if any of the values is not a `Number` instance,
|
||||
a `torch.*Tensor` instance, or an instance implementing __torch_function__
|
||||
"""
|
||||
if not all(is_tensor_like(v) or isinstance(v, _Number) for v in values):
|
||||
raise ValueError(
|
||||
"Input arguments must all be instances of Number, "
|
||||
"torch.Tensor or objects implementing __torch_function__."
|
||||
)
|
||||
if not all(is_tensor_like(v) for v in values):
|
||||
options: dict[str, Any] = dict(dtype=torch.get_default_dtype())
|
||||
for value in values:
|
||||
if isinstance(value, torch.Tensor):
|
||||
options = dict(dtype=value.dtype, device=value.device)
|
||||
break
|
||||
new_values = [
|
||||
v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
|
||||
]
|
||||
return torch.broadcast_tensors(*new_values)
|
||||
return torch.broadcast_tensors(*values)
|
||||
|
||||
|
||||
def _standard_normal(shape, dtype, device):
|
||||
if torch._C._get_tracing_state():
|
||||
# [JIT WORKAROUND] lack of support for .normal_()
|
||||
return torch.normal(
|
||||
torch.zeros(shape, dtype=dtype, device=device),
|
||||
torch.ones(shape, dtype=dtype, device=device),
|
||||
)
|
||||
return torch.empty(shape, dtype=dtype, device=device).normal_()
|
||||
|
||||
|
||||
def _sum_rightmost(value, dim):
|
||||
r"""
|
||||
Sum out ``dim`` many rightmost dimensions of a given tensor.
|
||||
|
||||
Args:
|
||||
value (Tensor): A tensor of ``.dim()`` at least ``dim``.
|
||||
dim (int): The number of rightmost dims to sum out.
|
||||
"""
|
||||
if dim == 0:
|
||||
return value
|
||||
required_shape = value.shape[:-dim] + (-1,)
|
||||
return value.reshape(required_shape).sum(-1)
|
||||
|
||||
|
||||
def logits_to_probs(logits, is_binary=False):
|
||||
r"""
|
||||
Converts a tensor of logits into probabilities. Note that for the
|
||||
binary case, each value denotes log odds, whereas for the
|
||||
multi-dimensional case, the values along the last dimension denote
|
||||
the log probabilities (possibly unnormalized) of the events.
|
||||
"""
|
||||
if is_binary:
|
||||
return torch.sigmoid(logits)
|
||||
return F.softmax(logits, dim=-1)
|
||||
|
||||
|
||||
def clamp_probs(probs):
|
||||
"""Clamps the probabilities to be in the open interval `(0, 1)`.
|
||||
|
||||
The probabilities would be clamped between `eps` and `1 - eps`,
|
||||
and `eps` would be the smallest representable positive number for the input data type.
|
||||
|
||||
Args:
|
||||
probs (Tensor): A tensor of probabilities.
|
||||
|
||||
Returns:
|
||||
Tensor: The clamped probabilities.
|
||||
|
||||
Examples:
|
||||
>>> probs = torch.tensor([0.0, 0.5, 1.0])
|
||||
>>> clamp_probs(probs)
|
||||
tensor([1.1921e-07, 5.0000e-01, 1.0000e+00])
|
||||
|
||||
>>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
|
||||
>>> clamp_probs(probs)
|
||||
tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64)
|
||||
|
||||
"""
|
||||
eps = torch.finfo(probs.dtype).eps
|
||||
return probs.clamp(min=eps, max=1 - eps)
|
||||
|
||||
|
||||
def probs_to_logits(probs, is_binary=False):
|
||||
r"""
|
||||
Converts a tensor of probabilities into logits. For the binary case,
|
||||
this denotes the probability of occurrence of the event indexed by `1`.
|
||||
For the multi-dimensional case, the values along the last dimension
|
||||
denote the probabilities of occurrence of each of the events.
|
||||
"""
|
||||
ps_clamped = clamp_probs(probs)
|
||||
if is_binary:
|
||||
return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
|
||||
return torch.log(ps_clamped)
|
||||
|
||||
|
||||
T = TypeVar("T", contravariant=True)
|
||||
R = TypeVar("R", covariant=True)
|
||||
|
||||
|
||||
class lazy_property(Generic[T, R]):
|
||||
r"""
|
||||
Used as a decorator for lazy loading of class attributes. This uses a
|
||||
non-data descriptor that calls the wrapped method to compute the property on
|
||||
first call; thereafter replacing the wrapped method into an instance
|
||||
attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapped: Callable[[T], R]) -> None:
|
||||
self.wrapped: Callable[[T], R] = wrapped
|
||||
update_wrapper(self, wrapped) # type:ignore[arg-type]
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self, instance: None, obj_type: Any = None
|
||||
) -> "_lazy_property_and_property[T, R]": ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: T, obj_type: Any = None) -> R: ...
|
||||
|
||||
def __get__(
|
||||
self, instance: Union[T, None], obj_type: Any = None
|
||||
) -> "R | _lazy_property_and_property[T, R]":
|
||||
if instance is None:
|
||||
return _lazy_property_and_property(self.wrapped)
|
||||
with torch.enable_grad():
|
||||
value = self.wrapped(instance)
|
||||
setattr(instance, self.wrapped.__name__, value)
|
||||
return value
|
||||
|
||||
|
||||
class _lazy_property_and_property(lazy_property[T, R], property):
|
||||
"""We want lazy properties to look like multiple things.
|
||||
|
||||
* property when Sphinx autodoc looks
|
||||
* lazy_property when Distribution validate_args looks
|
||||
"""
|
||||
|
||||
def __init__(self, wrapped: Callable[[T], R]) -> None:
|
||||
property.__init__(self, wrapped)
|
||||
|
||||
|
||||
def tril_matrix_to_vec(mat: Tensor, diag: int = 0) -> Tensor:
|
||||
r"""
|
||||
Convert a `D x D` matrix or a batch of matrices into a (batched) vector
|
||||
which comprises of lower triangular elements from the matrix in row order.
|
||||
"""
|
||||
n = mat.shape[-1]
|
||||
if not torch._C._get_tracing_state() and (diag < -n or diag >= n):
|
||||
raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n - 1}].")
|
||||
arange = torch.arange(n, device=mat.device)
|
||||
tril_mask = arange < arange.view(-1, 1) + (diag + 1)
|
||||
vec = mat[..., tril_mask]
|
||||
return vec
|
||||
|
||||
|
||||
def vec_to_tril_matrix(vec: Tensor, diag: int = 0) -> Tensor:
|
||||
r"""
|
||||
Convert a vector or a batch of vectors into a batched `D x D`
|
||||
lower triangular matrix containing elements from the vector in row order.
|
||||
"""
|
||||
# +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
|
||||
n = (
|
||||
-(1 + 2 * diag)
|
||||
+ ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5
|
||||
) / 2
|
||||
eps = torch.finfo(vec.dtype).eps
|
||||
if not torch._C._get_tracing_state() and (round(n) - n > eps):
|
||||
raise ValueError(
|
||||
f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as "
|
||||
+ "the lower triangular part of a square D x D matrix."
|
||||
)
|
||||
n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
|
||||
mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
|
||||
arange = torch.arange(n, device=vec.device)
|
||||
tril_mask = arange < arange.view(-1, 1) + (diag + 1)
|
||||
mat[..., tril_mask] = vec
|
||||
return mat
|
212
venv/Lib/site-packages/torch/distributions/von_mises.py
Normal file
212
venv/Lib/site-packages/torch/distributions/von_mises.py
Normal file
|
@ -0,0 +1,212 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import broadcast_all, lazy_property
|
||||
|
||||
|
||||
__all__ = ["VonMises"]
|
||||
|
||||
|
||||
def _eval_poly(y, coef):
|
||||
coef = list(coef)
|
||||
result = coef.pop()
|
||||
while coef:
|
||||
result = coef.pop() + y * result
|
||||
return result
|
||||
|
||||
|
||||
_I0_COEF_SMALL = [
|
||||
1.0,
|
||||
3.5156229,
|
||||
3.0899424,
|
||||
1.2067492,
|
||||
0.2659732,
|
||||
0.360768e-1,
|
||||
0.45813e-2,
|
||||
]
|
||||
_I0_COEF_LARGE = [
|
||||
0.39894228,
|
||||
0.1328592e-1,
|
||||
0.225319e-2,
|
||||
-0.157565e-2,
|
||||
0.916281e-2,
|
||||
-0.2057706e-1,
|
||||
0.2635537e-1,
|
||||
-0.1647633e-1,
|
||||
0.392377e-2,
|
||||
]
|
||||
_I1_COEF_SMALL = [
|
||||
0.5,
|
||||
0.87890594,
|
||||
0.51498869,
|
||||
0.15084934,
|
||||
0.2658733e-1,
|
||||
0.301532e-2,
|
||||
0.32411e-3,
|
||||
]
|
||||
_I1_COEF_LARGE = [
|
||||
0.39894228,
|
||||
-0.3988024e-1,
|
||||
-0.362018e-2,
|
||||
0.163801e-2,
|
||||
-0.1031555e-1,
|
||||
0.2282967e-1,
|
||||
-0.2895312e-1,
|
||||
0.1787654e-1,
|
||||
-0.420059e-2,
|
||||
]
|
||||
|
||||
_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
|
||||
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
|
||||
|
||||
|
||||
def _log_modified_bessel_fn(x, order=0):
|
||||
"""
|
||||
Returns ``log(I_order(x))`` for ``x > 0``,
|
||||
where `order` is either 0 or 1.
|
||||
"""
|
||||
assert order == 0 or order == 1
|
||||
|
||||
# compute small solution
|
||||
y = x / 3.75
|
||||
y = y * y
|
||||
small = _eval_poly(y, _COEF_SMALL[order])
|
||||
if order == 1:
|
||||
small = x.abs() * small
|
||||
small = small.log()
|
||||
|
||||
# compute large solution
|
||||
y = 3.75 / x
|
||||
large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
|
||||
|
||||
result = torch.where(x < 3.75, small, large)
|
||||
return result
|
||||
|
||||
|
||||
@torch.jit.script_if_tracing
|
||||
def _rejection_sample(loc, concentration, proposal_r, x):
|
||||
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
|
||||
while not done.all():
|
||||
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
|
||||
u1, u2, u3 = u.unbind()
|
||||
z = torch.cos(math.pi * u1)
|
||||
f = (1 + proposal_r * z) / (proposal_r + z)
|
||||
c = concentration * (proposal_r - f)
|
||||
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
|
||||
if accept.any():
|
||||
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
|
||||
done = done | accept
|
||||
return (x + math.pi + loc) % (2 * math.pi) - math.pi
|
||||
|
||||
|
||||
class VonMises(Distribution):
|
||||
"""
|
||||
A circular von Mises distribution.
|
||||
|
||||
This implementation uses polar coordinates. The ``loc`` and ``value`` args
|
||||
can be any real number (to facilitate unconstrained optimization), but are
|
||||
interpreted as angles modulo 2 pi.
|
||||
|
||||
Example::
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
|
||||
>>> m.sample() # von Mises distributed with loc=1 and concentration=1
|
||||
tensor([1.9777])
|
||||
|
||||
:param torch.Tensor loc: an angle in radians.
|
||||
:param torch.Tensor concentration: concentration parameter
|
||||
"""
|
||||
|
||||
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
|
||||
support = constraints.real
|
||||
has_rsample = False
|
||||
|
||||
def __init__(self, loc, concentration, validate_args=None):
|
||||
self.loc, self.concentration = broadcast_all(loc, concentration)
|
||||
batch_shape = self.loc.shape
|
||||
event_shape = torch.Size()
|
||||
super().__init__(batch_shape, event_shape, validate_args)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
self._validate_sample(value)
|
||||
log_prob = self.concentration * torch.cos(value - self.loc)
|
||||
log_prob = (
|
||||
log_prob
|
||||
- math.log(2 * math.pi)
|
||||
- _log_modified_bessel_fn(self.concentration, order=0)
|
||||
)
|
||||
return log_prob
|
||||
|
||||
@lazy_property
|
||||
def _loc(self) -> Tensor:
|
||||
return self.loc.to(torch.double)
|
||||
|
||||
@lazy_property
|
||||
def _concentration(self) -> Tensor:
|
||||
return self.concentration.to(torch.double)
|
||||
|
||||
@lazy_property
|
||||
def _proposal_r(self) -> Tensor:
|
||||
kappa = self._concentration
|
||||
tau = 1 + (1 + 4 * kappa**2).sqrt()
|
||||
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
|
||||
_proposal_r = (1 + rho**2) / (2 * rho)
|
||||
# second order Taylor expansion around 0 for small kappa
|
||||
_proposal_r_taylor = 1 / kappa + kappa
|
||||
return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
"""
|
||||
The sampling algorithm for the von Mises distribution is based on the
|
||||
following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
|
||||
von Mises distribution." Applied Statistics (1979): 152-157.
|
||||
|
||||
Sampling is always done in double precision internally to avoid a hang
|
||||
in _rejection_sample() for small values of the concentration, which
|
||||
starts to happen for single precision around 1e-4 (see issue #88443).
|
||||
"""
|
||||
shape = self._extended_shape(sample_shape)
|
||||
x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
|
||||
return _rejection_sample(
|
||||
self._loc, self._concentration, self._proposal_r, x
|
||||
).to(self.loc.dtype)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
try:
|
||||
return super().expand(batch_shape)
|
||||
except NotImplementedError:
|
||||
validate_args = self.__dict__.get("_validate_args")
|
||||
loc = self.loc.expand(batch_shape)
|
||||
concentration = self.concentration.expand(batch_shape)
|
||||
return type(self)(loc, concentration, validate_args=validate_args)
|
||||
|
||||
@property
|
||||
def mean(self) -> Tensor:
|
||||
"""
|
||||
The provided mean is the circular one.
|
||||
"""
|
||||
return self.loc
|
||||
|
||||
@property
|
||||
def mode(self) -> Tensor:
|
||||
return self.loc
|
||||
|
||||
@lazy_property
|
||||
def variance(self) -> Tensor: # type: ignore[override]
|
||||
"""
|
||||
The provided variance is the circular one.
|
||||
"""
|
||||
return (
|
||||
1
|
||||
- (
|
||||
_log_modified_bessel_fn(self.concentration, order=1)
|
||||
- _log_modified_bessel_fn(self.concentration, order=0)
|
||||
).exp()
|
||||
)
|
87
venv/Lib/site-packages/torch/distributions/weibull.py
Normal file
87
venv/Lib/site-packages/torch/distributions/weibull.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.exponential import Exponential
|
||||
from torch.distributions.gumbel import euler_constant
|
||||
from torch.distributions.transformed_distribution import TransformedDistribution
|
||||
from torch.distributions.transforms import AffineTransform, PowerTransform
|
||||
from torch.distributions.utils import broadcast_all
|
||||
|
||||
|
||||
__all__ = ["Weibull"]
|
||||
|
||||
|
||||
class Weibull(TransformedDistribution):
|
||||
r"""
|
||||
Samples from a two-parameter Weibull distribution.
|
||||
|
||||
Example:
|
||||
|
||||
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
||||
>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
|
||||
>>> m.sample() # sample from a Weibull distribution with scale=1, concentration=1
|
||||
tensor([ 0.4784])
|
||||
|
||||
Args:
|
||||
scale (float or Tensor): Scale parameter of distribution (lambda).
|
||||
concentration (float or Tensor): Concentration parameter of distribution (k/shape).
|
||||
"""
|
||||
|
||||
arg_constraints = {
|
||||
"scale": constraints.positive,
|
||||
"concentration": constraints.positive,
|
||||
}
|
||||
support = constraints.positive
|
||||
|
||||
def __init__(self, scale, concentration, validate_args=None):
|
||||
self.scale, self.concentration = broadcast_all(scale, concentration)
|
||||
self.concentration_reciprocal = self.concentration.reciprocal()
|
||||
base_dist = Exponential(
|
||||
torch.ones_like(self.scale), validate_args=validate_args
|
||||
)
|
||||
transforms = [
|
||||
PowerTransform(exponent=self.concentration_reciprocal),
|
||||
AffineTransform(loc=0, scale=self.scale),
|
||||
]
|
||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||
|
||||
def expand(self, batch_shape, _instance=None):
|
||||
new = self._get_checked_instance(Weibull, _instance)
|
||||
new.scale = self.scale.expand(batch_shape)
|
||||
new.concentration = self.concentration.expand(batch_shape)
|
||||
new.concentration_reciprocal = new.concentration.reciprocal()
|
||||
base_dist = self.base_dist.expand(batch_shape)
|
||||
transforms = [
|
||||
PowerTransform(exponent=new.concentration_reciprocal),
|
||||
AffineTransform(loc=0, scale=new.scale),
|
||||
]
|
||||
super(Weibull, new).__init__(base_dist, transforms, validate_args=False)
|
||||
new._validate_args = self._validate_args
|
||||
return new
|
||||
|
||||
    @property
    def mean(self) -> Tensor:
        return self.scale * torch.exp(torch.lgamma(1 + self.concentration_reciprocal))

    @property
    def mode(self) -> Tensor:
        return (
            self.scale
            * ((self.concentration - 1) / self.concentration)
            ** self.concentration.reciprocal()
        )

    @property
    def variance(self) -> Tensor:
        return self.scale.pow(2) * (
            torch.exp(torch.lgamma(1 + 2 * self.concentration_reciprocal))
            - torch.exp(2 * torch.lgamma(1 + self.concentration_reciprocal))
        )

    def entropy(self):
        return (
            euler_constant * (1 - self.concentration_reciprocal)
            + torch.log(self.scale * self.concentration_reciprocal)
            + 1
        )
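
The closed-form mean, variance, and entropy above can be sanity-checked against Monte Carlo estimates. A sketch with arbitrary parameters, assuming only this package::

    import torch
    from torch.distributions import Weibull

    d = Weibull(torch.tensor(2.0), torch.tensor(0.8))
    x = d.sample((200_000,))
    print(d.mean, x.mean())                    # analytic vs. empirical mean
    print(d.variance, x.var())                 # analytic vs. empirical variance
    print(d.entropy(), -d.log_prob(x).mean())  # analytic vs. Monte Carlo entropy
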
340
venv/Lib/site-packages/torch/distributions/wishart.py
Normal file
@@ -0,0 +1,340 @@
# mypy: allow-untyped-defs
import math
import warnings
from typing import Optional, Union

import torch
from torch import nan, Tensor
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.multivariate_normal import _precision_to_scale_tril
from torch.distributions.utils import lazy_property
from torch.types import _Number, _size, Number


__all__ = ["Wishart"]

_log_2 = math.log(2)


def _mvdigamma(x: Tensor, p: int) -> Tensor:
    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
    return torch.digamma(
        x.unsqueeze(-1)
        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
    ).sum(-1)


def _clamp_above_eps(x: Tensor) -> Tensor:
    # We assume positive input for this function
    return x.clamp(min=torch.finfo(x.dtype).eps)
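
`_mvdigamma` computes the multivariate digamma psi_p(x) = sum_{j=0}^{p-1} psi(x - j/2), used by `entropy` below; for p = 1 it reduces to `torch.digamma`. An illustrative standalone re-implementation of the same formula (not the module's helper)::

    import torch

    def mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
        # psi_p(x) = sum_{j=0}^{p-1} psi(x - j/2)
        return torch.digamma(x.unsqueeze(-1) - torch.arange(p, dtype=x.dtype) / 2).sum(-1)

    x = torch.tensor(3.5)
    print(mvdigamma(x, 1), torch.digamma(x))                           # identical for p = 1
    print(mvdigamma(x, 2), torch.digamma(x) + torch.digamma(x - 0.5))  # p = 2 expansion
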
class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`.

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of the square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix`, :attr:`precision_matrix`, or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        `torch.distributions.LKJCholesky` is a restricted Wishart distribution [1].

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """

    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0
    def __init__(
        self,
        df: Union[Tensor, Number],
        covariance_matrix: Optional[Tensor] = None,
        precision_matrix: Optional[Tensor] = None,
        scale_tril: Optional[Tensor] = None,
        validate_args=None,
    ):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, (
            "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
        )

        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, _Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1] - 1}."
            )

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )
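
As the docstring note and the constructor above indicate, exactly one of `covariance_matrix`, `precision_matrix`, or `scale_tril` is given, and everything is reduced internally to a Cholesky factor. A sketch showing that two of the parameterizations agree (values are arbitrary)::

    import torch
    from torch.distributions import Wishart

    cov = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
    d_cov = Wishart(df=torch.tensor(4.0), covariance_matrix=cov)
    d_tril = Wishart(df=torch.tensor(4.0), scale_tril=torch.linalg.cholesky(cov))
    x = d_tril.rsample()
    print(d_cov.log_prob(x), d_tril.log_prob(x))  # same density either way
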
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new
    @lazy_property
    def scale_tril(self) -> Tensor:
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self) -> Tensor:
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self) -> Tensor:
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )
    @property
    def mean(self) -> Tensor:
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def mode(self) -> Tensor:
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self) -> Tensor:
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )
    def _bartlett_sampling(self, sample_shape=torch.Size()):
        p = self._event_shape[-1]  # has singleton shape

        # Sampling is implemented using the Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)
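
`_bartlett_sampling` draws W = (L A)(L A)^T, where L is the Cholesky factor of the scale matrix and A is lower triangular with sqrt(chi2(df - i + 1)) entries (1-indexed) on the diagonal and standard normals below it. A quick Monte Carlo check that this yields E[W] = df * Sigma, with arbitrary illustrative values::

    import torch
    from torch.distributions import Wishart

    d = Wishart(df=torch.tensor(5.0), covariance_matrix=torch.eye(3))
    w = d.rsample((20_000,))
    print(w.mean(0))  # should be close to 5 * I, i.e. df * Sigma
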
    def rsample(
        self, sample_shape: _size = torch.Size(), max_try_correction=None
    ) -> Tensor:
        r"""
        .. warning::
            In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may still end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust the `max_try_correction` argument of `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # The part below temporarily improves numerical stability and should be removed in the future
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample
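
Per the warning in `rsample`, samples should be validated when `df` is close to the dimension. A sketch of the kind of check the docstring suggests; the parameter values are hypothetical and chosen to make singular draws plausible::

    import torch
    from torch.distributions import Wishart

    # df only slightly above ndim - 1, so singular samples are possible
    d = Wishart(df=torch.tensor(2.1), covariance_matrix=torch.eye(3))
    sample = d.rsample((1000,))
    is_valid = d.support.check(sample)  # True where the draw is positive definite
    # If any draw fails, raise df or pass a larger max_try_correction to rsample.
    print(is_valid.all())
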
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )
    def entropy(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )
    @property
    def _natural_params(self) -> tuple[Tensor, Tensor]:
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)
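
Finally, an end-to-end usage sketch tying the statistics above together (mean = df * Sigma, mode = (df - p - 1) * Sigma, and diagonal variances 2 * df * Sigma_ii**2); the values are arbitrary::

    import torch
    from torch.distributions import Wishart

    sigma = torch.tensor([[1.0, 0.3], [0.3, 2.0]])
    d = Wishart(df=torch.tensor(6.0), covariance_matrix=sigma)
    print(d.mean)                                 # 6 * sigma
    print(d.mode)                                 # (6 - 2 - 1) * sigma = 3 * sigma
    print(d.variance.diagonal(dim1=-2, dim2=-1))  # 2 * 6 * sigma_ii**2 on the diagonal
    print(d.log_prob(d.rsample()))
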