Source code for skwdro.base.costs_torch.normlabelcost

from typing import Optional, Tuple, overload

import torch as pt

import skwdro.distributions as dst
from .normcost import NormCost


class NormLabelCost(NormCost):
    r"""
    p-norm ground cost acting jointly on the data and on the labels.

    The norm is used to add a cost to switching labels:

    .. math::
        d_\kappa\left(
            \left[\begin{array}{c}\bm{X}\\y\end{array}\right],
            \left[\begin{array}{c}\bm{X'}\\y'\end{array}\right]
        \right) := \|\bm{X}-\bm{X'}\|+\kappa |y-y'|

    Parameters
    ----------
    p : float, default=2.
        Order of the norm applied to both the data and the label differences.
    power : float, default=1.
        Exponent applied to the resulting distance.
    kappa : float, default=1e4
        Non-negative weight of the label displacement relative to the data
        displacement. ``float("inf")`` and ``0.`` are handled as special
        cases by :py:meth:`value`.
    name : str, optional
        Name forwarded to the parent cost; ``"Kappa-norm"`` when omitted.
    """

    def __init__(
        self,
        p: float = 2.,
        power: float = 1.,
        kappa: float = 1e4,
        name: Optional[str] = None
    ):
        r"""
        Constructor: forward ``p``, ``power`` and ``name`` to the parent
        class, then store ``kappa``.

        Raises
        ------
        AssertionError
            If ``kappa`` is negative.
        """
        super().__init__(
            power=power,
            p=p,
            name="Kappa-norm" if name is None else name
        )
        self.kappa = kappa
        # NOTE: kept as an assertion (not a ValueError) so that the exception
        # type seen by existing callers does not change.
        assert kappa >= 0, ' '.join([
            f"Input kappa={kappa}<0",
            "is illicit since it 'encourages'",
            "flipping labels in the database,",
            "and thus makes no sense wrt the database",
            "in terms of 'trust' to the labels."
        ])

    @classmethod
    def _label_penalty(
        cls,
        y: pt.Tensor,
        y_prime: pt.Tensor,
        p: float
    ) -> pt.Tensor:
        """
        p-norm of ``y - y_prime`` over the last dimension, which is kept
        with size one.
        """
        return pt.norm(y - y_prime, p=p, dim=-1, keepdim=True)

    @classmethod
    def _data_penalty(
        cls,
        x: pt.Tensor,
        x_prime: pt.Tensor,
        p: float
    ) -> pt.Tensor:
        """
        p-norm of ``x - x_prime`` over the last dimension, which is kept
        with size one.
        """
        diff = x - x_prime
        return pt.norm(diff, p=p, dim=-1, keepdim=True)

    # Typing-only stubs: their bodies are never executed at runtime.
    @overload
    def value(
        self,
        xi: pt.Tensor,
        zeta: pt.Tensor,
        xi_labels: pt.Tensor,
        zeta_labels: pt.Tensor
    ) -> pt.Tensor:
        ...

    @overload
    def value(
        self,
        xi: pt.Tensor,
        zeta: pt.Tensor,
        xi_labels: None = None,
        zeta_labels: None = None
    ) -> pt.Tensor:
        ...

    @overload
    def value(
        self,
        xi: pt.Tensor,
        zeta: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None,
        zeta_labels: Optional[pt.Tensor] = None
    ) -> pt.Tensor:
        ...

    def value(
        self,
        xi: pt.Tensor,
        zeta: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None,
        zeta_labels: Optional[pt.Tensor] = None
    ) -> pt.Tensor:
        r"""
        Cost to displace
        :math:`\xi:=\left[\begin{array}{c}\bm{X}\\y\end{array}\right]`
        to
        :math:`\zeta:=\left[\begin{array}{c}\bm{X'}\\y'\end{array}\right]`
        in :math:`\mathbb{R}^n`.

        For a finite, non-zero ``kappa`` the returned value is

        .. math::
            \left(
                \frac{\|\bm{X}-\bm{X'}\|_p+\kappa\|y-y'\|_p}{1+\kappa}
            \right)^{\mathrm{power}}

        Parameters
        ----------
        xi : Tensor, shape (n_samples, n_features)
            Data point to be displaced (without the label)
        zeta : Tensor, shape (n_samples, n_features)
            Data point towards which ``xi`` is displaced
        xi_labels : Tensor, shape (n_samples, n_features_y)
            Label or target for the problem/loss
        zeta_labels : Tensor, shape (n_samples, n_features_y)
            Label or target in the dataset

        Returns
        -------
        Tensor
            The cost, with the last dimension reduced to size one.

        Raises
        ------
        AssertionError
            If either ``xi_labels`` or ``zeta_labels`` is ``None``.
        """
        assert xi_labels is not None and zeta_labels is not None

        if float(self.kappa) == float("inf"):
            # Writing convention: if kappa=+oo we put all cost on switching
            # labels so the cost is reported on y.
            # To provide a tractable computation, we yield the y-penalty
            # alone.
            return self._label_penalty(
                xi_labels, zeta_labels, self.p
            )**self.power

        if self.kappa == 0.:
            # Writing convention: if kappa is null we put all cost on moving
            # the data itself, so the worst-case distribution is free to
            # switch the labels.
            # Warning : this usecase should not make sense anyway.
            return self._data_penalty(xi, zeta, self.p)**self.power

        # General case: weighted sum of both penalties, normalised by
        # (1 + kappa) before being raised to the requested power.
        weighted_sum = self._data_penalty(xi, zeta, self.p) \
            + self.kappa * self._label_penalty(
                xi_labels, zeta_labels, self.p
            )
        distance = weighted_sum / (1. + self.kappa)
        return distance**self.power

    # Typing-only stubs: their bodies are never executed at runtime.
    @overload
    def _sampler_labels(
        self,
        xi_labels: pt.Tensor,
        epsilon: Optional[pt.Tensor]
    ) -> dst.Distribution:
        ...

    @overload
    def _sampler_labels(
        self,
        xi_labels: None,
        epsilon: Optional[pt.Tensor]
    ) -> None:
        ...

    def _sampler_labels(
        self,
        xi_labels,
        epsilon
    ) -> Optional[pt.distributions.Distribution]:
        """
        Build the distribution used to sample labels around ``xi_labels``.

        Parameters
        ----------
        xi_labels : Tensor or None
            Location of the distribution. ``None`` yields ``None``, as
            declared by the typing overloads.
        epsilon : Tensor, float or None
            Scale parameter; defaults to ``1e-3`` when ``None`` and is
            wrapped in a tensor when it is not one already.

        Returns
        -------
        Distribution or None
            * ``Dirac`` at ``xi_labels`` when ``kappa`` is infinite;
            * ``Laplace`` (``p == 1``) or ``Normal`` (``p == 2`` or sup norm)
              with scale ``epsilon / kappa`` when ``power == 1``;
            * ``Normal`` with scale ``epsilon`` when ``power == 2`` and
              ``p == 2``.

        Raises
        ------
        NotImplementedError
            For any other ``(p, power)`` combination.
        """
        # Local import: keeps this block self-contained with respect to the
        # module-level import list.
        import warnings

        if xi_labels is None:
            # Nothing to sample around.
            return None

        if epsilon is None:
            epsilon = pt.tensor(1e-3)
        elif not isinstance(epsilon, pt.Tensor):
            epsilon = pt.tensor(epsilon)

        if self.kappa == float('inf'):
            # Infinite label cost: labels are never moved.
            return dst.Dirac(xi_labels)

        if self.power == 1:
            if self.p == 1:
                family = dst.Laplace
            elif self.p == 2:
                family = dst.Normal
            elif self.p == pt.inf:
                # Actually emit the warning; merely instantiating ``Warning``
                # would be a silent no-op.
                warnings.warn(
                    "For sup norm, we use a gaussian sampler by default."
                )
                family = dst.Normal
            else:
                raise NotImplementedError()
            # ``epsilon.to(xi_labels)`` aligns dtype and device with the
            # labels before scaling.
            return family(
                loc=xi_labels,
                scale=epsilon.to(xi_labels) / self.kappa
            )

        if self.power == 2:
            if self.p == 2:
                # NOTE(review): unlike the power-1 branches, the scale is not
                # divided by kappa here — confirm this is intended.
                return dst.Normal(
                    loc=xi_labels,
                    scale=epsilon.to(xi_labels)
                )
            raise NotImplementedError()

        raise NotImplementedError()

    # Typing-only stubs: their bodies are never executed at runtime.
    @overload
    def solve_max_series_exp(
        self,
        xi: pt.Tensor,
        xi_labels: pt.Tensor,
        rhs: pt.Tensor,
        rhs_labels: pt.Tensor
    ) -> Tuple[pt.Tensor, pt.Tensor]:
        ...

    @overload
    def solve_max_series_exp(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        rhs: pt.Tensor,
        rhs_labels: Optional[pt.Tensor]
    ) -> Tuple[pt.Tensor, Optional[pt.Tensor]]:
        ...

    def solve_max_series_exp(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        rhs: pt.Tensor,
        rhs_labels: Optional[pt.Tensor]
    ) -> Tuple[pt.Tensor, Optional[pt.Tensor]]:
        """
        Shift ``xi`` and ``xi_labels`` by half of ``rhs`` and ``rhs_labels``,
        the label shift being additionally divided by ``kappa``.

        This closed form is only available for ``p == 2`` and ``power == 2``.
        When either label tensor is ``None`` the call is delegated to the
        parent class.

        Parameters
        ----------
        xi : Tensor
            Data points.
        xi_labels : Tensor, optional
            Labels attached to ``xi``.
        rhs : Tensor
            Shift direction for the data; presumably a gradient with respect
            to ``xi`` (to be confirmed against the callers).
        rhs_labels : Tensor, optional
            Shift direction for the labels.

        Returns
        -------
        tuple of Tensor
            ``(xi + rhs / 2, xi_labels + rhs_labels / (2 * kappa))`` when
            both label tensors are given.

        Raises
        ------
        NotImplementedError
            If labels are given but ``p`` or ``power`` differs from 2.
        """
        if xi_labels is None or rhs_labels is None:
            return super().solve_max_series_exp(
                xi, xi_labels, rhs, rhs_labels
            )

        if self.p == 2 == self.power:
            shifted_xi = xi + .5 * rhs
            shifted_labels = xi_labels + .5 * rhs_labels / self.kappa
            return shifted_xi, shifted_labels

        raise NotImplementedError()