# Source code for skwdro.base.costs_torch.base_cost

from abc import ABC, abstractmethod
from typing import Optional, Tuple, overload
import torch as pt
import torch.nn as nn
import skwdro.distributions as dst

# Human-readable description of each supported array engine, keyed by the
# short engine tag (looked up from ``TorchCost.engine`` when printing a cost).
ENGINES_NAMES = dict(
    pt="PyTorch tensors",
    jx="Jax arrays",
)


class TorchCost(nn.Module, ABC):
    """Base class for transport cost functions built on PyTorch.

    Concrete costs must implement :py:meth:`value`,
    :py:meth:`_sampler_data`, :py:meth:`_sampler_labels` and
    :py:meth:`solve_max_series_exp`, and are expected to overwrite the
    default :py:attr:`power` attribute (``1.0``).

    Parameters
    ----------
    name : str
        Human-readable name of the cost, reported by :py:meth:`__str__`.
    engine : str
        Short tag of the array engine. The tags that have a description
        are the keys of ``ENGINES_NAMES`` (``"pt"`` and ``"jx"``).
    """

    def __init__(
        self,
        name: str = "",
        engine: str = ""
    ):
        super().__init__()
        self.name = name
        self.engine = engine
        # Default power needs to be overwritten by subclasses.
        self.power: float = 1.0

    def forward(
        self,
        xi: pt.Tensor,
        zeta: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None,
        zeta_labels: Optional[pt.Tensor] = None
    ) -> pt.Tensor:
        """
        This function is called by default when using the ``__call__``
        dunder of pytorch modules: it sends directly to the
        :py:meth:`value` method.
        """
        return self.value(xi, zeta, xi_labels, zeta_labels)

    @abstractmethod
    def value(
        self,
        xi: pt.Tensor,
        zeta: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None,
        zeta_labels: Optional[pt.Tensor] = None
    ) -> pt.Tensor:
        """Compute the transport cost between ``xi`` and ``zeta``.

        Must be overridden by subclasses. ``xi_labels`` and ``zeta_labels``
        are optional label tensors (presumably the labels attached to
        ``xi`` and ``zeta`` respectively).

        NOTE(review): the expected tensor shapes are fixed by the concrete
        subclasses and are not visible here — confirm against them.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        del xi, zeta, xi_labels, zeta_labels
        raise NotImplementedError("Please Implement this method")

    def sampler(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        epsilon: Optional[pt.Tensor]
    ) -> Tuple[dst.Distribution, Optional[dst.Distribution]]:
        """Build the sampling distributions for the data and the labels.

        Delegates to :py:meth:`_sampler_data` and :py:meth:`_sampler_labels`,
        forwarding ``epsilon`` to both.
        NOTE(review): ``epsilon`` presumably scales the sampling noise —
        confirm against the concrete subclasses.

        Returns
        -------
        Tuple[Distribution, Optional[Distribution]]
            The data distribution and the labels distribution. Per the typed
            overloads of :py:meth:`_sampler_labels`, the latter is ``None``
            when ``xi_labels`` is ``None``.
        """
        return (
            self._sampler_data(xi, epsilon),
            self._sampler_labels(xi_labels, epsilon)
        )

    @abstractmethod
    def _sampler_data(
        self,
        xi: pt.Tensor,
        epsilon: Optional[pt.Tensor]
    ) -> dst.Distribution:
        """Return the sampling distribution for the data ``xi``.

        Must be overridden by subclasses; used by :py:meth:`sampler`.
        """
        del xi, epsilon
        raise NotImplementedError()

    @overload
    @abstractmethod
    def _sampler_labels(
        self,
        xi_labels: pt.Tensor,
        epsilon: Optional[pt.Tensor]
    ) -> dst.Distribution:
        ...

    @overload
    @abstractmethod
    def _sampler_labels(
        self,
        xi_labels: None,
        epsilon: Optional[pt.Tensor]
    ) -> None:
        ...

    @abstractmethod
    def _sampler_labels(
        self,
        xi_labels: Optional[pt.Tensor],
        epsilon: Optional[pt.Tensor]
    ) -> Optional[dst.Distribution]:
        """Return the sampling distribution for the labels ``xi_labels``.

        Must be overridden by subclasses; used by :py:meth:`sampler`.
        Implementations should return ``None`` when ``xi_labels`` is
        ``None`` (see the typed overloads above).
        """
        del xi_labels, epsilon
        raise NotImplementedError()

    def __str__(self) -> str:
        """Describe the cost by its name and the engine it works with."""
        # ``.get`` instead of indexing: the default ``engine=""`` (or any tag
        # missing from ENGINES_NAMES) used to make ``str(cost)`` raise a
        # KeyError. Known tags produce exactly the same text as before.
        engine_desc = ENGINES_NAMES.get(
            self.engine, f"unknown engine {self.engine!r}"
        )
        return ' '.join([
            "Cost named",
            self.name,
            "using as data:",
            engine_desc
        ])

    @abstractmethod
    def solve_max_series_exp(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        rhs: pt.Tensor,
        rhs_labels: Optional[pt.Tensor]
    ) -> Tuple[pt.Tensor, Optional[pt.Tensor]]:
        r"""
        Override this method to provide an explicit solution to the first
        order expansion of the inner supremum one would wish to solve if they
        were solving the usual WDRO approach:

        .. math::

            \zeta^{\texttt{imp\_samp}} := \arg\max_{\zeta}
            \left\langle\nabla_\xi L_\theta(\xi)\mid\zeta-\xi\right\rangle
            - \lambda c(\xi, \zeta).

        NOTE(review): ``rhs`` / ``rhs_labels`` presumably carry the
        (possibly rescaled) gradient term of the expansion for the data and
        the labels — confirm against the concrete implementations. The
        return value is the pair of maximizers for the data and (optionally)
        the labels.

        .. important:: This is an unconstrained first-order approximation of
            the supremum, which can be ill-posed or intractable, but is
            usually cheap enough for efficient importance sampling. One may
            attempt to implement higher-order approximations and add
            constraints if cheap enough solutions are available, for
            reasonably small models, if desired.
        """
        del xi, rhs, xi_labels, rhs_labels
        raise NotImplementedError()