Source code for skwdro.base.samplers.torch.base_samplers

from typing import Dict, Optional, Tuple, Union
from abc import ABC, abstractmethod

import random
import torch as pt

import skwdro.distributions as dst


class BaseSampler(ABC):
    seed: Optional[int]

    def __init__(self, seed: Optional[int] = None):
        """
        Base class for all samplers available in the library.
        One must subclass this in order to make their samplers comply
        with the interfaces of this library.

        .. note:: This class is iterable.

        Attributes
        ----------
        seed: int|None
            random seed for the torch and python ``random`` rngs.
        """
        self.seed = seed
        # Seed the generators only on request: building a sampler without
        # a seed leaves the global rng states untouched.
        if seed is not None:
            pt.manual_seed(seed)
            random.seed(seed)

    @abstractmethod
    def sample(
        self,
        n_samples: int
    ) -> Tuple[pt.Tensor, Optional[pt.Tensor]]:
        """
        Override this method to make a custom sampling mechanism
        from scratch. It should output a pair of tensors for ``zeta``
        and ``zeta_labels``.

        Parameters
        ----------
        n_samples: int
            number of samples to draw

        Returns
        -------
        zeta: torch.Tensor
            input samples drawn
        zeta_labels: torch.Tensor|None
            input targets drawn, if any
        """
        raise NotImplementedError()

    def __iter__(self):
        """
        Return the sampler itself, which acts as its own iterator.
        """
        return self

    def __next__(self):
        """
        Draw a single sample, i.e. the output of ``self.sample(1)``.

        ``StopIteration`` is never raised here, so iterating over a sampler
        does not terminate by itself.
        """
        return self.sample(1)

    @property
    @abstractmethod
    def produces_labels(self) -> bool:
        """
        Override to tell whether the sampler outputs targets
        alongside the input samples.
        """
        raise NotImplementedError()

    @abstractmethod
    def reset_mean(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor]
    ):
        """
        Reset the sampler instance parametrization.
        Must be overridden when a subclass is made to describe it.

        Parameters
        ----------
        xi: torch.Tensor
            part of the parametrization of the sampler that concerns
            the input variables
        xi_labels: torch.Tensor|None
            part of the parametrization of the sampler that concerns
            the labels variables
        """
        raise NotImplementedError()

    @abstractmethod
    def log_prob(
        self,
        zeta: pt.Tensor,
        zeta_labels: Optional[pt.Tensor]
    ) -> pt.Tensor:
        """
        Override to compute the log-probability of some samples.

        Parameters
        ----------
        zeta: torch.Tensor
            input samples to evaluate
        zeta_labels: torch.Tensor|None
            targets to evaluate, if any

        Returns
        -------
        lp: torch.Tensor
            log-probabilities of the samples
        """
        raise NotImplementedError()

    @abstractmethod
    def log_prob_recentered(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        zeta: pt.Tensor,
        zeta_labels: Optional[pt.Tensor]
    ) -> pt.Tensor:
        """
        Override to compute the log-probability of some samples with
        respect to the reference points ``xi`` (and ``xi_labels``).

        Parameters
        ----------
        xi: torch.Tensor
            reference input points
        xi_labels: torch.Tensor|None
            reference targets, if any
        zeta: torch.Tensor
            input samples to evaluate
        zeta_labels: torch.Tensor|None
            targets to evaluate, if any

        Returns
        -------
        lp: torch.Tensor
            log-probabilities of the samples
        """
        raise NotImplementedError()

    def _format_logprobs(
        self,
        dist: dst.Distribution,
        data: pt.Tensor
    ) -> pt.Tensor:
        """
        Evaluate ``dist.log_prob(data)`` and give the result a trailing
        axis of size one.

        Parameters
        ----------
        dist: Distribution
            distribution whose ``log_prob`` method is called
        data: torch.Tensor
            points at which the log-probability is evaluated

        Returns
        -------
        lp: torch.Tensor
            a 3-dimensional result is summed over its last axis (which is
            kept with size one), any other result gets a new last axis.
        """
        lp = dist.log_prob(data)
        if lp.dim() == 3:
            # One value per coordinate: collapse them into a single one.
            return lp.sum(dim=-1, keepdim=True)
        # Every other dimensionality only needs the extra trailing axis.
        return lp.unsqueeze(-1)
class NoLabelsSampler(BaseSampler, ABC):
    def __init__(self, data_sampler: dst.Distribution, seed: Optional[int]):
        """
        Base class for every sampler that draws input data only, and
        thus outputs ``None`` in place of the targets.

        Attributes
        ----------
        data_s: torch.distributions.Distribution
            torch distribution to sample the input data from
        """
        super().__init__(seed)
        self.data_s = data_sampler

    def sample(self, n_samples: int):
        """
        Draw ``n_samples`` input samples with ``data_s.rsample``.

        Returns
        -------
        zeta: torch.Tensor
            input samples drawn
        zeta_labels: None
            no targets are produced by this kind of sampler
        """
        batch_shape = pt.Size((n_samples,))
        zeta = self.data_s.rsample(batch_shape)
        return zeta, None

    @property
    def produces_labels(self):
        """
        Always ``False``: only input samples are drawn.
        """
        return False

    def log_prob(
        self,
        zeta: pt.Tensor,
        zeta_labels: Optional[pt.Tensor]
    ) -> pt.Tensor:
        """
        Log-probability of ``zeta`` under ``data_s``.

        ``zeta_labels`` must be ``None``.
        """
        assert zeta_labels is None
        # The shape of the result is normalized by ``_format_logprobs``.
        return self._format_logprobs(self.data_s, zeta)

    def log_prob_recentered(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        zeta: pt.Tensor,
        zeta_labels: Optional[pt.Tensor]
    ) -> pt.Tensor:
        """
        Log-probability under ``data_s`` of ``zeta`` translated by
        ``data_s.mean - xi``.

        Both ``xi_labels`` and ``zeta_labels`` must be ``None``.
        """
        assert xi_labels is None
        assert zeta_labels is None
        # Evaluate the density at the point moved from xi to the mean.
        shifted = zeta - xi + self.data_s.mean
        return self._format_logprobs(self.data_s, shifted)
class LabeledSampler(BaseSampler, ABC):
    def __init__(
        self,
        data_sampler: dst.Distribution,
        labels_sampler: dst.Distribution,
        seed: Optional[int]
    ) -> None:
        """
        Base class for all samplers that draw both the input data and
        the targets.

        Attributes
        ----------
        data_s: torch.distributions.Distribution
            torch distribution to sample the input data from
        labels_s: torch.distributions.Distribution
            torch distribution to sample the targets from
        """
        super().__init__(seed)
        self.data_s = data_sampler
        self.labels_s = labels_sampler

    def sample(self, n_samples: int):
        """
        Draw ``n_samples`` inputs, then ``n_samples`` targets.

        Returns
        -------
        zeta: torch.Tensor
            output of :py:meth:`sample_data`
        zeta_labels: torch.Tensor
            output of :py:meth:`sample_labels`
        """
        # Tuple items are evaluated left to right: data first, then labels.
        return self.sample_data(n_samples), self.sample_labels(n_samples)

    def sample_data(self, n_sample: int):
        """
        Draw ``n_sample`` input samples with ``data_s.rsample``.
        """
        batch_shape = pt.Size((n_sample,))
        return self.data_s.rsample(batch_shape)

    def sample_labels(self, n_sample: int):
        """
        Draw ``n_sample`` targets with ``labels_s.rsample``.
        """
        batch_shape = pt.Size((n_sample,))
        return self.labels_s.rsample(batch_shape)

    @property
    def produces_labels(self):
        """
        Always ``True``: targets are drawn alongside the inputs.
        """
        return True

    def log_prob(
        self,
        zeta: pt.Tensor,
        zeta_labels: Optional[pt.Tensor]
    ) -> pt.Tensor:
        """
        Sum of the log-probability of ``zeta`` under ``data_s`` and of
        ``zeta_labels`` under ``labels_s``.

        ``zeta_labels`` must not be ``None``.
        """
        assert zeta_labels is not None
        data_term = self._format_logprobs(self.data_s, zeta)
        labels_term = self._format_logprobs(self.labels_s, zeta_labels)
        return data_term + labels_term

    def log_prob_recentered(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        zeta: pt.Tensor,
        zeta_labels: Optional[pt.Tensor]
    ) -> pt.Tensor:
        """
        Sum of the log-probability of ``zeta`` under ``data_s`` and of
        ``zeta_labels`` under ``labels_s``.

        ``zeta_labels`` and ``xi_labels`` must not be ``None``.

        .. warning:: ``xi`` and ``xi_labels`` do not enter the computation
            yet, so no recentering is applied.
        """
        assert zeta_labels is not None
        assert xi_labels is not None
        # TODO: FIX (xi and xi_labels are not used below)
        data_term = self._format_logprobs(self.data_s, zeta)
        labels_term = self._format_logprobs(self.labels_s, zeta_labels)
        return data_term + labels_term
# Helper class ########################
class IsOptionalCovarianceSampler(ABC):
    def init_covar(
        self,
        d: int,
        sigma: Optional[Union[float, pt.Tensor]] = None,
        tril: Optional[pt.Tensor] = None,
        prec: Optional[pt.Tensor] = None,
        cov: Optional[pt.Tensor] = None
    ) -> Dict[str, pt.Tensor]:
        """
        Sets up the covariance matrix in the correct format to give as
        a kwarg to torch distributions.

        Order of importance for non-None values:

        * sigma: defines ``sigma * Id`` as the ``scale_tril`` factor L
          (covariance ``sigma**2 * Id`` if the distribution follows the
          ``C = L @ L^T`` convention of torch's multivariate normal)
        * tril: defines the ``scale_tril`` factor L directly
        * cov: defines the full C matrix
        * prec: defines the precision matrix C^-1, only useful for fast
          CDF computation and bad otherwise

        Parameters
        ----------
        d: int
            size of the identity matrix built when ``sigma`` is given
        sigma: float|torch.Tensor|None
            factor multiplying the identity matrix
        tril: torch.Tensor|None
            value for the ``scale_tril`` kwarg
        prec: torch.Tensor|None
            value for the ``precision_matrix`` kwarg
        cov: torch.Tensor|None
            value for the ``covariance_matrix`` kwarg

        Returns
        -------
        kwargs: dict
            a single-entry dictionary, to be unpacked in the constructor
            of the distribution

        Raises
        ------
        ValueError
            if ``sigma``, ``tril``, ``cov`` and ``prec`` are all ``None``
        """
        if sigma is not None:
            return {"scale_tril": pt.eye(d) * sigma}
        if tril is not None:
            return {"scale_tril": tril}
        if cov is not None:
            return {"covariance_matrix": cov}
        if prec is not None:
            return {"precision_matrix": prec}
        raise ValueError(
            "Please provide a valid covariance matrix for the constructor of "
            f"{type(self).__name__}"
        )