from typing import List, Optional, Dict, Tuple, Union
import torch as pt
import torch.distributions as dst
import torch.distributions.constraints as cstr
Shapeoid = Union[pt.Size, List[int], Tuple[int, ...]]
# [docs] (leftover documentation-link marker, not code)
class Dirac(dst.ExponentialFamily):
    """Dirac delta (point-mass) distribution concentrated at ``loc``.

    The leading ``n_batch_dims`` dimensions of ``loc`` are used as the batch
    shape and the remaining dimensions as the event shape.  Sampling always
    returns ``loc`` (expanded to the requested shape), so samples are
    differentiable with respect to ``loc``.

    Args:
        loc: Location of the point mass.
        n_batch_dims: Number of leading dimensions of ``loc`` that index
            independent distributions (default ``0``: everything is event).
        validate_args: Forwarded unchanged to
            ``torch.distributions.Distribution.__init__``.
    """

    # Samples are a (view of a) deterministic function of ``loc``.
    has_rsample = True

    @property
    def arg_constraints(self) -> Dict[str, cstr.Constraint]:
        """Constraint on the ``loc`` parameter."""
        # NOTE(review): ``real_vector`` reinterprets one trailing dimension,
        # so with argument validation enabled ``loc`` is expected to have at
        # least one dimension -- confirm this is intended for scalar events.
        return {"loc": cstr.real_vector}

    @property
    def support(self) -> Optional[cstr.Constraint]:
        """Constraint describing the values the distribution is declared over."""
        return cstr.real_vector  # type: ignore

    def __init__(
            self,
            loc: pt.Tensor,
            n_batch_dims: int = 0,
            validate_args: Optional[bool] = None):
        locshape = loc.size()
        # Split the parameter shape into (batch dims | event dims).
        batch_shape = locshape[:n_batch_dims]
        event_shape = locshape[n_batch_dims:]
        # ``loc`` must be set before the base initializer runs, because the
        # base class reads it when argument validation is enabled.
        self.loc: pt.Tensor = loc
        super().__init__(batch_shape, event_shape, validate_args)

    def _result_dtype(self) -> pt.dtype:
        """Floating dtype for density-like outputs.

        Uses ``loc``'s dtype when it is floating point, otherwise the
        default floating dtype (``-inf`` is not representable in integers).
        """
        if self.loc.is_floating_point():
            return self.loc.dtype
        return pt.get_default_dtype()

    def expand(
            self,
            batch_shape: "Shapeoid",
            _instance=None):
        """Return a new ``Dirac`` whose batch dimensions are expanded.

        Args:
            batch_shape: Desired batch shape; ``loc`` is expanded (not
                copied) to ``batch_shape + event_shape``.
            _instance: Optional pre-allocated instance, as used by
                ``Distribution._get_checked_instance``.

        Returns:
            The expanded distribution, with the same validation flag as
            ``self``.
        """
        new: Dirac = self._get_checked_instance(Dirac, _instance)
        batch_shape = cast_to_size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        # Narrowing for static type checkers only.
        assert isinstance(new, Dirac)
        # Initialise the base class directly; validation is skipped here and
        # the original flag is restored afterwards.
        super(Dirac, new).__init__(  # type: ignore
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self) -> pt.Tensor:
        """Mean of the distribution, i.e. ``loc``."""
        return self.loc

    @property
    def mode(self) -> pt.Tensor:
        """Mode of the distribution, i.e. ``loc``."""
        return self.loc

    @property
    def variance(self) -> pt.Tensor:
        """Variance of the distribution: zeros shaped like ``loc``."""
        return pt.zeros_like(self.loc)

    def rsample(self, sample_shape: "Shapeoid" = pt.Size()) -> pt.Tensor:
        """Return ``loc`` expanded to ``sample_shape + batch_shape + event_shape``.

        The result is an expanded view of ``loc`` rather than a copy.
        """
        sample_shape = cast_to_size(sample_shape)
        return self.loc.expand(self._extended_shape(sample_shape))

    def log_prob(self, value: pt.Tensor) -> pt.Tensor:
        """Return ``0`` where ``value`` equals ``loc`` and ``-inf`` elsewhere.

        ``value`` is broadcast against ``loc``.  Equality is tested
        elementwise and then reduced over the event dimensions only, so the
        result has one entry per batch (and sample) element instead of a
        single scalar for the whole tensor.

        Args:
            value: Points at which to evaluate the log-probability.

        Returns:
            Tensor of shape ``broadcast(value, loc).shape`` with the event
            dimensions removed, on the same device as the comparison result.
        """
        # Direct comparison instead of ``value - loc``: a subtraction yields
        # NaN for equal infinite entries and would report a mismatch.
        matches = value == self.loc
        # Collapse the trailing event dimensions one at a time; this also
        # works on torch versions without tuple-``dim`` support in ``all``.
        for _ in range(len(self.event_shape)):
            matches = matches.all(dim=-1)
        log_p = pt.zeros(
            matches.shape, dtype=self._result_dtype(), device=matches.device
        )
        return log_p.masked_fill(~matches, float("-inf"))

    def enumerate_support(self, expand: bool = True) -> pt.Tensor:
        """Return the single support point with a leading dimension of size 1.

        Raises:
            NotImplementedError: If ``expand`` is ``False``.
        """
        if not expand:
            raise NotImplementedError(
                "Dirac.enumerate_support is only implemented for expand=True"
            )
        return self.rsample(pt.Size((1,)))

    def entropy(self) -> pt.Tensor:
        """Return ``-inf`` for every batch element (shape ``batch_shape``)."""
        return pt.full(
            self.batch_shape,
            float("-inf"),
            dtype=self._result_dtype(),
            device=self.loc.device,
        )

    def perplexity(self) -> pt.Tensor:
        """Return ``0`` (``exp`` of the entropy) with shape ``batch_shape``."""
        return pt.zeros(
            self.batch_shape,
            dtype=self._result_dtype(),
            device=self.loc.device,
        )
# [docs] (leftover documentation-link marker, not code)
def cast_to_size(shape: Shapeoid) -> pt.Size:
    """Coerce a shape-like value to ``torch.Size``.

    Args:
        shape: A ``torch.Size``, or a list/tuple of ints.

    Returns:
        ``shape`` itself when it is already a ``torch.Size``; otherwise a
        new ``torch.Size`` built from its elements.
    """
    already_size = isinstance(shape, pt.Size)
    return shape if already_size else pt.Size(shape)