Source code for skwdro.base.samplers.torch.cost_samplers

from typing import Optional
import torch as pt

from skwdro.base.samplers.torch.base_samplers import LabeledSampler, NoLabelsSampler
from skwdro.base.costs_torch import Cost


class NoLabelsCostSampler(NoLabelsSampler):
    def __init__(
        self,
        cost: Cost,
        xi: pt.Tensor,
        sigma,
        seed: Optional[int] = None,
    ):
        """
        Parent class of all samplers that only sample inputs, with a
        specification drawn from a cost functional
        (:py:class:`skwdro.base.costs_torch.Cost`).

        Parameters
        ----------
        cost: Cost
            cost functional specifying the sampling behaviour: the input
            sampler is obtained from ``cost._sampler_data(xi, sigma)``.
        xi: pt.Tensor
            mean for inputs
        sigma: float|Tensor
            scalar standard deviation shared through dimensions, for inputs.

        See
        :py:class:`skwdro.base.samplers.torch.base_samplers.IsOptionalCovarianceSampler`
        for other arguments.
        """
        # The cost decides which distribution the inputs are drawn from.
        data_sampler = cost._sampler_data(xi, sigma)
        super().__init__(data_sampler, seed)

        # Kept on the instance so that ``reset_mean`` can rebuild the
        # sampler later with the same cost and scale.
        self.generating_cost = cost
        self.sigma = sigma

    def reset_mean(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor]
    ):
        """
        Re-initialise this sampler in place around a new input mean.

        The stored cost, ``sigma`` and ``seed`` are reused; only the mean
        changes.

        Parameters
        ----------
        xi: pt.Tensor
            new mean for inputs
        xi_labels: Optional[pt.Tensor]
            ignored, this sampler does not draw labels.
        """
        # Explicitly discarded: nothing below reads the labels.
        del xi_labels

        # Call this class' initializer by name rather than through ``self``.
        # NOTE(review): ``self.seed`` is not assigned in this class; it is
        # presumably set by the parent initializer -- confirm.
        NoLabelsCostSampler.__init__(
            self,
            cost=self.generating_cost,
            xi=xi,
            sigma=self.sigma,
            seed=self.seed,
        )
class LabeledCostSampler(LabeledSampler):
    def __init__(
        self,
        cost: Cost,
        xi: pt.Tensor,
        xi_labels: pt.Tensor,
        sigma,
        seed: Optional[int] = None
    ):
        """
        Parent class of all samplers that sample both inputs and labels,
        with a specification drawn from a cost functional
        (:py:class:`skwdro.base.costs_torch.Cost`).

        Parameters
        ----------
        cost: Cost
            cost functional specifying the sampling behaviour: the input
            sampler is obtained from ``cost._sampler_data(xi, sigma)`` and
            the label sampler from
            ``cost._sampler_labels(xi_labels, sigma)``.
        xi: pt.Tensor
            mean for inputs
        xi_labels: pt.Tensor
            mean for targets
        sigma: float|Tensor
            scalar standard deviation shared through dimensions, for inputs.
            The same value is also forwarded to the label sampler.

        See
        :py:class:`skwdro.base.samplers.torch.base_samplers.IsOptionalCovarianceSampler`
        for other arguments.

        Raises
        ------
        ValueError
            if ``cost._sampler_labels`` returns ``None``, i.e. the chosen
            cost cannot sample labels.
        """
        sd, sl = (
            cost._sampler_data(xi, sigma),
            cost._sampler_labels(xi_labels, sigma)
        )
        if sl is None:
            raise ValueError("Please choose a cost that can sample labels")
        super().__init__(sd, sl, seed)

        # Kept on the instance so that ``reset_mean`` can rebuild the
        # sampler later with the same cost and scale.
        self.generating_cost = cost
        self.sigma = sigma

    def reset_mean(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor]
    ):
        """
        Re-initialise this sampler in place around new input/label means.

        The stored cost, ``sigma`` and ``seed`` are reused; only the means
        change.

        Parameters
        ----------
        xi: pt.Tensor
            new mean for inputs
        xi_labels: Optional[pt.Tensor]
            new mean for targets; must not be ``None`` for this sampler.

        Raises
        ------
        ValueError
            if ``xi_labels`` is ``None``.
        """
        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, which would let ``None`` reach the cost.
        if xi_labels is None:
            raise ValueError(
                "xi_labels must be provided to reset the mean of a "
                "labeled sampler"
            )

        # Call this class' initializer by name rather than through ``self``.
        # NOTE(review): ``self.seed`` is not assigned in this class; it is
        # presumably set by the parent initializer -- confirm.
        LabeledCostSampler.__init__(
            self,
            self.generating_cost,
            xi,
            xi_labels,
            self.sigma,
            self.seed
        )