Source code for skwdro.base.samplers.torch.classif_sampler

from typing import Optional, Union
import torch as pt

import skwdro.distributions as dst
from .base_samplers import IsOptionalCovarianceSampler, LabeledSampler


[docs] class ClassificationNormalNormalSampler(LabeledSampler, IsOptionalCovarianceSampler): data_s: dst.MultivariateNormal labels_s: dst.MultivariateNormal def __init__( self, xi: pt.Tensor, xi_labels: pt.Tensor, *, sigma: Optional[Union[float, pt.Tensor]] = None, tril: Optional[pt.Tensor] = None, prec: Optional[pt.Tensor] = None, cov: Optional[pt.Tensor] = None, l_sigma: Optional[Union[float, pt.Tensor]] = None, l_tril: Optional[pt.Tensor] = None, l_prec: Optional[pt.Tensor] = None, l_cov: Optional[pt.Tensor] = None, seed: Optional[int] = None, ): """ Example of an available sampler for classification problems. - inputs are sampled from a gaussian distribution - targets from another one. Specify the parameters of the distributions as keywords arguments. Parameters ---------- xi: pt.Tensor mean for inputs xi_labels: pt.Tensor mean for targets sigma: float|Tensor scalar standard deviation shared through dimensions, for inputs. l_sigma: float|Tensor scalar standard deviation shared through dimensions, for targets. See :py:class:`skwdro.base.samplers.torch.base_samplers.IsOptionalCovarianceSampler` for other arguments. """ assert len(xi.size()) >= 2 assert len(xi_labels.size()) >= 2 covar = self.init_covar(xi.size(-1), sigma, tril, prec, cov) labels_covar = self.init_covar( xi_labels.size(-1), l_sigma, l_tril, l_prec, l_cov ) super(ClassificationNormalNormalSampler, self).__init__( dst.MultivariateNormal( loc=xi, **covar # type: ignore ), dst.MultivariateNormal( loc=xi_labels, **labels_covar # type: ignore ), seed )
[docs] def reset_mean( self, xi: pt.Tensor, xi_labels: Optional[pt.Tensor] ): assert xi_labels is not None ClassificationNormalNormalSampler.__init__( self, xi, xi_labels, seed=self.seed, tril=self.data_s._unbroadcasted_scale_tril, l_tril=self.labels_s._unbroadcasted_scale_tril )
[docs] class ClassificationNormalIdSampler(LabeledSampler, IsOptionalCovarianceSampler): data_s: dst.MultivariateNormal labels_s: dst.Dirac def __init__( self, xi: pt.Tensor, xi_labels: pt.Tensor, *, sigma: Optional[Union[float, pt.Tensor]] = None, tril: Optional[pt.Tensor] = None, prec: Optional[pt.Tensor] = None, cov: Optional[pt.Tensor] = None, seed: Optional[int], ): """ Example of an available sampler for classification problems. - inputs are sampled from a gaussian distribution - targets are not noisy. Specify the parameters of the distributions as keywords arguments. Parameters ---------- xi: pt.Tensor mean for inputs xi_labels: pt.Tensor mean for targets sigma: float|Tensor scalar standard deviation shared through dimensions, for inputs. See :py:class:`skwdro.base.samplers.torch.base_samplers.IsOptionalCovarianceSampler` for other arguments. """ covar = self.init_covar(xi.size(-1), sigma, tril, prec, cov) super(ClassificationNormalIdSampler, self).__init__( dst.MultivariateNormal( loc=xi, **covar # type: ignore ), dst.Dirac( loc=xi_labels, n_batch_dims=1, validate_args=True ), seed )
[docs] def sample_labels(self, n_sample: int) -> pt.Tensor: """ Just get as many labels as data points (n_sample). """ return self.labels_s.mean.unsqueeze(0).expand(n_sample, -1, -1)
[docs] def reset_mean( self, xi: pt.Tensor, xi_labels: Optional[pt.Tensor] ): assert xi_labels is not None ClassificationNormalIdSampler.__init__( self, xi, xi_labels, seed=self.seed, tril=self.data_s._unbroadcasted_scale_tril, )
[docs] class ClassificationNormalBernouilliSampler(LabeledSampler, IsOptionalCovarianceSampler): data_s: dst.MultivariateNormal labels_s: dst.TransformedDistribution def __init__( self, p: float, xi: pt.Tensor, xi_labels: pt.Tensor, *, sigma: Optional[Union[float, pt.Tensor]] = None, tril: Optional[pt.Tensor] = None, prec: Optional[pt.Tensor] = None, cov: Optional[pt.Tensor] = None, seed: Optional[int], ): r""" Example of an available sampler for classification problems. - inputs are sampled from a gaussian distribution - targets are sampled randomly in :math:`\{0, 1\}`. Specify the parameters of the distributions as keywords arguments. Parameters ---------- p: float|Tensor Probability of switch for bernouilli. xi: pt.Tensor mean for inputs xi_labels: pt.Tensor mean for targets sigma: float|Tensor scalar standard deviation shared through dimensions, for inputs. See :py:class:`skwdro.base.samplers.torch.base_samplers.IsOptionalCovarianceSampler` for other arguments. """ assert 0. <= p <= 1. covar = self.init_covar(xi.size(-1), sigma, tril, prec, cov) self.p = p super(ClassificationNormalBernouilliSampler, self).__init__( dst.MultivariateNormal( loc=xi, **covar # type: ignore ), dst.TransformedDistribution( dst.Bernoulli( p ), dst.transforms.AffineTransform( loc=-xi_labels, scale=2 * xi_labels ) ), seed )
[docs] def sample_labels(self, n_sample: int) -> pt.Tensor: """ Overrides w/ ``sample`` to prevent ``rsample`` from crashing since bernouilli isn't reparametrizeable. """ zeta_labels = self.labels_s.sample(pt.Size((n_sample,))) assert isinstance(zeta_labels, pt.Tensor) return zeta_labels
[docs] def reset_mean( self, xi: pt.Tensor, xi_labels: Optional[pt.Tensor] ): assert xi_labels is not None ClassificationNormalBernouilliSampler.__init__( self, self.p, xi, xi_labels, seed=self.seed, tril=self.data_s._unbroadcasted_scale_tril, )