skwdro.base.samplers.torch module

API

This module exposes a base class BaseSampler that you can subclass to build your own reference transport plan \(\pi_0\). The plan is defined through its right-conditional \(\nu_\xi(\zeta)=\pi_0(\zeta|\xi)\) (by a slight abuse of notation), since its first marginal is fixed to be \(\hat{\mathbb{P}}^N\), the empirical distribution of the dataset. Formally, the conditional is defined through the disintegration lemma, and the marginal property is required to meet some technical feasibility conditions; see [1] for a clear explanation.

These classes have a BaseSampler.reset_mean() method to dynamically change the mean(s) of the generating distributions.
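To make the conditional \(\nu_\xi\) concrete, here is a minimal sketch that uses only torch.distributions, not skwdro itself. It assumes an isotropic Gaussian conditional \(\nu_\xi=\mathcal{N}(\xi,\sigma^2 I)\); the toy dataset, the value of sigma, and all variable names are illustrative assumptions.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)

# Hypothetical toy dataset: N = 4 points xi in dimension d = 2.
xi = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
sigma = 0.1  # assumed isotropic standard deviation of the conditional

# nu_xi = N(xi, sigma^2 I): one Gaussian per data point (batch of 4).
nu = MultivariateNormal(loc=xi, covariance_matrix=sigma**2 * torch.eye(2))

# Draw 5 samples zeta per data point: shape (n_samples, N, d) = (5, 4, 2).
zeta = nu.sample((5,))

# Conditional log-density log nu_xi(zeta), one value per (sample, data point).
log_density = nu.log_prob(zeta)

# Changing the mean amounts to rebuilding the distribution around new
# centers while keeping the same covariance.
new_xi = xi + 1.0
nu_shifted = MultivariateNormal(loc=new_xi, covariance_matrix=sigma**2 * torch.eye(2))
```

Because the first marginal is the dataset itself, only the conditional has to be sampled: each data point gets its own cloud of perturbed points zeta.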

class skwdro.base.samplers.torch.BaseSampler(seed: int | None = None)[source]

Bases: ABC

abstractmethod log_prob(zeta: Tensor, zeta_labels: Tensor | None) Tensor[source]
abstractmethod log_prob_recentered(xi: Tensor, xi_labels: Tensor | None, zeta: Tensor, zeta_labels: Tensor | None) Tensor[source]
abstract property produces_labels: bool
abstractmethod reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

abstractmethod sample(n_samples: int) Tuple[Tensor, Tensor | None][source]

Override this method to make a custom sampling mechanism from scratch. It should return a pair of tensors (zeta, zeta_labels); the second is None when the sampler produces no labels.

Parameters:
n_samples: int

number of samples to draw

Returns:
zeta: torch.Tensor

input samples drawn

zeta_labels: torch.Tensor|None

input targets drawn, if any

seed: int | None
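The following sketch shows what implementing the abstract members listed above could look like. To stay self-contained it does not import skwdro: SamplerInterface is a hypothetical stand-in that only mirrors the documented signatures, and IsotropicGaussianSampler is an invented example class. The semantics given to log_prob_recentered (density of zeta under the distribution re-centered at xi) and the seeding strategy are assumptions, not the library's documented behaviour.

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.distributions import MultivariateNormal


class SamplerInterface(ABC):
    """Hypothetical stand-in mirroring the abstract members of BaseSampler."""

    @property
    @abstractmethod
    def produces_labels(self) -> bool: ...

    @abstractmethod
    def sample(self, n_samples: int) -> Tuple[Tensor, Optional[Tensor]]: ...

    @abstractmethod
    def reset_mean(self, xi: Tensor, xi_labels: Optional[Tensor]) -> None: ...

    @abstractmethod
    def log_prob(self, zeta: Tensor, zeta_labels: Optional[Tensor]) -> Tensor: ...

    @abstractmethod
    def log_prob_recentered(self, xi, xi_labels, zeta, zeta_labels) -> Tensor: ...


class IsotropicGaussianSampler(SamplerInterface):
    """Unlabeled sampler: zeta ~ N(xi, sigma^2 I) for each data point xi."""

    def __init__(self, xi: Tensor, sigma: float, seed: Optional[int] = None):
        if seed is not None:
            torch.manual_seed(seed)  # assumption: global seeding is enough here
        self.sigma = sigma
        self.reset_mean(xi, None)

    @property
    def produces_labels(self) -> bool:
        return False

    def _dist(self, xi: Tensor) -> MultivariateNormal:
        d = xi.size(-1)
        # scale_tril = sigma * I gives covariance sigma^2 * I.
        return MultivariateNormal(xi, scale_tril=self.sigma * torch.eye(d))

    def reset_mean(self, xi, xi_labels):
        # Rebuild the generating distribution around the new centers.
        self.data_s = self._dist(xi)

    def sample(self, n_samples):
        # No labels are produced, so the second element is None.
        return self.data_s.sample((n_samples,)), None

    def log_prob(self, zeta, zeta_labels):
        return self.data_s.log_prob(zeta)

    def log_prob_recentered(self, xi, xi_labels, zeta, zeta_labels):
        # Assumed semantics: density of zeta under the plan centered at xi.
        return self._dist(xi).log_prob(zeta)


sampler = IsotropicGaussianSampler(torch.zeros(3, 2), sigma=0.5, seed=0)
zeta, zeta_labels = sampler.sample(10)
```

With three data points in dimension two, sample(10) returns a tensor of shape (10, 3, 2) and None for the labels. A real subclass would inherit from BaseSampler and follow its constructor instead.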
class skwdro.base.samplers.torch.ClassificationNormalBernouilliSampler(p: float, xi: Tensor, xi_labels: Tensor, *, sigma: float | Tensor | None = None, tril: Tensor | None = None, prec: Tensor | None = None, cov: Tensor | None = None, seed: int | None)[source]

Bases: LabeledSampler, IsOptionalCovarianceSampler

data_s: MultivariateNormal
labels_s: TransformedDistribution
reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

sample_labels(n_sample: int) Tensor[source]

Overrides the parent method to use sample instead of rsample, which would crash since the Bernoulli distribution is not reparametrizable.
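The reason rsample cannot be used can be checked directly on torch.distributions, independently of skwdro. In this small sketch the probability p and the variable names are illustrative assumptions and are not tied to the constructor argument p above.

```python
import torch
from torch.distributions import Bernoulli, Normal

torch.manual_seed(0)

p = torch.tensor(0.1)  # hypothetical success probability
noise = Bernoulli(probs=p)
gauss = Normal(torch.tensor(0.0), torch.tensor(1.0))

# Normal supports the reparametrization trick, Bernoulli does not:
# gauss.has_rsample is True while noise.has_rsample is False.

# Calling rsample() on the discrete distribution raises NotImplementedError.
try:
    noise.rsample((4,))
    rsample_failed = False
except NotImplementedError:
    rsample_failed = True

# Plain sample() works, but no gradient flows through the draw.
flips = noise.sample((4,))  # tensor of 0./1. values, shape (4,)
```

This is why a sampler with discrete label noise has to fall back on sample for its labels, even if the continuous data part is drawn with rsample.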

class skwdro.base.samplers.torch.ClassificationNormalIdSampler(xi: Tensor, xi_labels: Tensor, *, sigma: float | Tensor | None = None, tril: Tensor | None = None, prec: Tensor | None = None, cov: Tensor | None = None, seed: int | None)[source]

Bases: LabeledSampler, IsOptionalCovarianceSampler

data_s: MultivariateNormal
labels_s: Dirac
reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

sample_labels(n_sample: int) Tensor[source]

Just get as many labels as data points (n_sample).

class skwdro.base.samplers.torch.ClassificationNormalNormalSampler(xi: Tensor, xi_labels: Tensor, *, sigma: float | Tensor | None = None, tril: Tensor | None = None, prec: Tensor | None = None, cov: Tensor | None = None, l_sigma: float | Tensor | None = None, l_tril: Tensor | None = None, l_prec: Tensor | None = None, l_cov: Tensor | None = None, seed: int | None = None)[source]

Bases: LabeledSampler, IsOptionalCovarianceSampler

data_s: MultivariateNormal
labels_s: MultivariateNormal
reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

class skwdro.base.samplers.torch.IsOptionalCovarianceSampler[source]

Bases: ABC

init_covar(d: int, sigma: float | Tensor | None = None, tril: Tensor | None = None, prec: Tensor | None = None, cov: Tensor | None = None) Dict[str, Tensor][source]

Sets up the covariance matrix in the correct format to give as a kwarg to torch distributions. Order of importance for non-None values:

* sigma: defines Id/sigma**2 as cov matrix, given as L^T@L
* tril: defines L s.t. C=L^TL
* cov: defines the full C matrix
* prec: defines the precision matrix C^-1, only useful for fast CDF computation and bad otherwise
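As an illustration of the keyword arguments such a dictionary can target, torch.distributions.MultivariateNormal accepts exactly one of scale_tril, covariance_matrix or precision_matrix. The sketch below does not call init_covar; the matrix values are made up, and it follows torch's own convention that the covariance is L @ L.T for a lower-triangular factor L.

```python
import torch
from torch.distributions import MultivariateNormal

mean = torch.zeros(2)

# Hypothetical lower-triangular factor L (torch convention: cov = L @ L.T).
tril = torch.tensor([[0.5, 0.0],
                     [0.1, 0.3]])
cov = tril @ tril.T            # full covariance matrix C
prec = torch.linalg.inv(cov)   # precision matrix C^-1

# Three equivalent parametrizations of the same Gaussian.
by_tril = MultivariateNormal(mean, scale_tril=tril)
by_cov = MultivariateNormal(mean, covariance_matrix=cov)
by_prec = MultivariateNormal(mean, precision_matrix=prec)

# All three assign the same log-density to a test point.
x = torch.tensor([0.3, -0.2])
values = [dist.log_prob(x) for dist in (by_tril, by_cov, by_prec)]
```

An isotropic covariance with scale sigma corresponds to the special case tril = sigma * torch.eye(d). Passing scale_tril avoids a Cholesky factorization inside the distribution, which is one plausible reason to prefer it when several options are available.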

class skwdro.base.samplers.torch.LabeledCostSampler(cost: TorchCost, xi: Tensor, xi_labels: Tensor, sigma, seed: int | None = None)[source]

Bases: LabeledSampler

reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

class skwdro.base.samplers.torch.LabeledSampler(data_sampler: Distribution, labels_sampler: Distribution, seed: int | None)[source]

Bases: BaseSampler, ABC

log_prob(zeta: Tensor, zeta_labels: Tensor | None) Tensor[source]
log_prob_recentered(xi: Tensor, xi_labels: Tensor | None, zeta: Tensor, zeta_labels: Tensor | None) Tensor[source]
property produces_labels
sample(n_samples: int)[source]

Override this method to make a custom sampling mechanism from scratch. It should return a pair of tensors (zeta, zeta_labels); the second is None when the sampler produces no labels.

Parameters:
n_samples: int

number of samples to draw

Returns:
zeta: torch.Tensor

input samples drawn

zeta_labels: torch.Tensor|None

input targets drawn, if any

sample_data(n_sample: int)[source]
sample_labels(n_sample: int)[source]
class skwdro.base.samplers.torch.NewsVendorNormalSampler(xi, *, sigma: float | Tensor | None = None, tril: Tensor | None = None, prec: Tensor | None = None, cov: Tensor | None = None, seed: int | None = None)[source]

Bases: NoLabelsSampler, IsOptionalCovarianceSampler

data_s: MultivariateNormal
reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

class skwdro.base.samplers.torch.NoLabelsCostSampler(cost: TorchCost, xi: Tensor, sigma, seed: int | None = None)[source]

Bases: NoLabelsSampler

reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

class skwdro.base.samplers.torch.NoLabelsSampler(data_sampler: Distribution, seed: int | None)[source]

Bases: BaseSampler, ABC

log_prob(zeta: Tensor, zeta_labels: Tensor | None) Tensor[source]
log_prob_recentered(xi: Tensor, xi_labels: Tensor | None, zeta: Tensor, zeta_labels: Tensor | None) Tensor[source]
property produces_labels
sample(n_samples: int)[source]

Override this method to make a custom sampling mechanism from scratch. It should return a pair of tensors (zeta, zeta_labels); the second is None when the sampler produces no labels.

Parameters:
n_samples: int

number of samples to draw

Returns:
zeta: torch.Tensor

input samples drawn

zeta_labels: torch.Tensor|None

input targets drawn, if any

class skwdro.base.samplers.torch.PortfolioLaplaceSampler(xi, *, sigma: float | Tensor | None = None, tril: Tensor | None = None, prec: Tensor | None = None, cov: Tensor | None = None, seed: int | None = None)[source]

Bases: NoLabelsSampler, IsOptionalCovarianceSampler

data_s: MultivariateNormal
reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

class skwdro.base.samplers.torch.PortfolioNormalSampler(xi, *, sigma: float | Tensor | None = None, tril: Tensor | None = None, prec: Tensor | None = None, cov: Tensor | None = None, seed: int | None = None)[source]

Bases: NoLabelsSampler, IsOptionalCovarianceSampler

data_s: MultivariateNormal
reset_mean(xi: Tensor, xi_labels: Tensor | None)[source]

Reset the sampler instance parametrization. Each subclass must override this method to describe how it is done.

Parameters:
xi: torch.Tensor

part of the parametrization of the sampler that concerns the input variables

xi_labels: torch.Tensor|None

part of the parametrization of the sampler that concerns the labels variables

References