# Source code for skwdro.base.losses_torch.wrapper

from typing import Callable, Optional, Union
from itertools import chain
import torch as pt
import torch.nn as nn

from skwdro.base.losses_torch import Loss
from skwdro.base.samplers.torch.base_samplers import BaseSampler


class WrappingError(ValueError):
    """Error raised by the loss-wrapping machinery.

    For instance, it is raised when a wrapped loss is asked for a default
    sampler, which it cannot provide.
    """


class WrappedPrimalLoss(Loss):
    r"""Wrap an arbitrary PyTorch loss, plus an optional transform, as a primal loss.

    The wrapped loss is evaluated as ``loss(transform(xi), xi_labels)``, or as
    ``loss(transform(xi))`` when the loss has no labels. The non-reduced
    output is then reshaped by :py:meth:`value`.
    """

    has_labels: bool
    loss_oop_interface: bool = True
    reduce_spatial_dims: bool = True

    def __init__(
        self,
        loss: Union[
            nn.Module,
            Callable[..., pt.Tensor]
        ],
        transform: Optional[nn.Module],
        sampler: BaseSampler,
        has_labels: bool,
        reduce_spatial_dims: bool = True,
        *,
        l2reg: Optional[float] = None
    ) -> None:
        r"""
        Provide the wrapped version of the primal loss.

        Parameters
        ----------
        loss: nn.Module|Callable
            the primal loss :math:`L_\theta`. Can be given either as a
            :py:class:`torch.nn.Module` or as a (functional) callable.
            If it is a module exposing a ``reduction`` attribute, that
            attribute must be ``"none"``.
        transform: nn.Module|None
            the transformation to apply to the (non-label) data before feeding
            it to the loss. Identity if set to ``None``. The argument has no
            default value and must always be passed explicitly.
        sampler: :py:class:`skwdro.base.samplers.torch.base_samplers.BaseSampler`
            the sampling object that defines the way :math:`\zeta` samples are
            drawn
        has_labels: bool
            set to ``True`` if the loss accepts two inputs: a prediction and
            some kind of target. Otherwise, set to ``False``.

            .. warning:: It is *your* job to check that the :py:attr:`loss`,
                :py:attr:`_sampler`, and :py:attr:`has_labels` parameters are
                compatible with one another.
        reduce_spatial_dims: bool
            set to ``False`` if the loss reduces by default the last dimension
            of a batch of non-reduced losses. This may be useful for example if
            using :py:class:`~torch.nn.CrossEntropyLoss` which will consider
            your last axis as a channel axis (which might not be the case
            outside of computer vision applications).
        l2reg: float|None
            L2 regularization if needed. It is forwarded to the base
            :py:class:`Loss` constructor.

        Attributes
        ----------
        loss: nn.Module|Callable
            the wrapped primal loss, as given.
        transform: nn.Module
            the given transform, or :py:class:`torch.nn.Identity` if ``None``
            was passed.
        loss_oop_interface: bool
            ``True`` if :py:attr:`loss` is a :py:class:`torch.nn.Module`, which
            is called without a ``reduction`` argument. ``False`` if it is a
            functional callable, which is called with ``reduction='none'``.
        reduce_spatial_dims: bool
            as given; see :py:meth:`_reduce_flat_spatial_dims_loss`.
        """
        super().__init__(sampler, has_labels, l2reg=l2reg)
        self.loss = loss
        if isinstance(loss, pt.nn.Module):
            # Built-in torch losses expose ``reduction``. Custom per-sample
            # loss modules may not have it: they cannot be checked here, so
            # they are accepted instead of crashing with an AttributeError.
            assert getattr(loss, 'reduction', 'none') == 'none', " ".join([
                'If you are using the OOP interface of PyTorch to define the',
                'main loss functional, please set its reduction method to',
                '\"none\"'
            ])
            self.loss_oop_interface = True
        else:
            assert callable(loss), " ".join([
                'If you are not using the OOP interface of PyTorch to define',
                'the main loss functional, please use the functional interface',
                'so that loss is at least callable, with a signature accepting',
                'either: my_loss(input: Tensor, target: Tensor, reduction: str)',
                'or my_loss(input: Tensor, reduction: str).'
            ])
            self.loss_oop_interface = False
        self.reduce_spatial_dims = reduce_spatial_dims
        self.transform = transform if transform is not None else nn.Identity()

    @classmethod
    def default_sampler(
        cls,
        xi,
        xi_labels,
        epsilon,
        seed: Optional[int]
    ) -> BaseSampler:
        """Always fail: a wrapped loss cannot choose a sampler by itself.

        The sampler has to be given to the constructor instead.

        Raises
        ------
        WrappingError
            unconditionally.
        """
        del xi, xi_labels, epsilon, seed
        raise WrappingError(
            "No default sampler can be attributed by default by a wrapped loss.")

    @property
    def theta(self) -> pt.Tensor:
        """Flattened concatenation of the parameters of the wrapped objects.

        For a module loss, the loss's own parameters come first, followed by
        those of :py:attr:`transform`. For a functional loss, only the
        transform's parameters are gathered. When there are no parameters at
        all, an empty tensor is returned, because ``pt.concat`` rejects an
        empty list.
        """
        if self.loss_oop_interface:
            assert isinstance(self.loss, nn.Module)
            params = chain(self.loss.parameters(), self.transform.parameters())
        else:
            assert callable(self.loss)
            params = self.transform.parameters()
        flat_params = [pt.flatten(p) for p in params]
        if not flat_params:
            # e.g. a functional loss with ``transform=None`` (nn.Identity).
            return pt.zeros(0)
        return pt.concat(flat_params)

    @property
    def intercept(self) -> pt.Tensor:
        """Wrapped losses expose no separate intercept: always ``0.``."""
        return pt.tensor(0.)

    def _flat_value_w_labels(self, xi, xi_labels):
        """Evaluate the loss on a flat labelled batch.

        The result is passed through the inherited :py:meth:`regularize`.
        """
        if self.loss_oop_interface:
            # Module losses are called as-is; their reduction was checked
            # in __init__.
            return self.regularize(self.loss(
                self.transform(xi),
                xi_labels
            ))
        else:
            # Functional losses need the reduction passed explicitly.
            return self.regularize(self.loss(
                self.transform(xi),
                xi_labels,
                reduction='none'
            ))

    def _flat_value_wo_labels(self, xi):
        """Evaluate the loss on a flat unlabelled batch.

        The result is passed through the inherited :py:meth:`regularize`.
        """
        if self.loss_oop_interface:
            return self.regularize(self.loss(
                self.transform(xi)
            ))
        else:
            return self.regularize(self.loss(
                self.transform(xi),
                reduction='none'
            ))

    def _reduce_flat_spatial_dims_loss(self, losses: pt.Tensor) -> pt.Tensor:
        """Bring a non-reduced loss tensor to a trailing singleton axis.

        If :py:attr:`reduce_spatial_dims` is set, the last axis is averaged
        and kept with size 1. Otherwise, a trailing axis of size 1 is
        appended.
        """
        if self.reduce_spatial_dims:
            return losses.mean(dim=-1, keepdim=True)
        else:
            return losses.unsqueeze(-1)

    def value(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None
    ) -> pt.Tensor:
        r"""Evaluate the wrapped loss on samples or on stacked ``zeta`` samples.

        Inputs with more than two dimensions are treated as stacked samples
        ("zetas"). All axes but the last are flattened into one batch axis
        before the loss is called. The result is then viewed back as
        ``(*xi.shape[:-1], 1)``. Two-dimensional inputs are treated as a batch
        of ``xi`` samples and one-dimensional inputs as a single sample. The
        exact output shape of those two cases depends on the branch; see the
        inline comments.

        Parameters
        ----------
        xi: pt.Tensor
            data samples; presumably ``(..., N, d)`` for zetas and ``(N, d)``
            for xis. TODO confirm against callers.
        xi_labels: pt.Tensor|None
            labels. Must be given if and only if :py:attr:`has_labels`.

        Returns
        -------
        pt.Tensor
            the (regularized) per-sample losses.

        Raises
        ------
        NotImplementedError
            for unsupported combinations of ``xi`` / ``xi_labels`` ranks.
        """
        if self.has_labels:
            assert xi_labels is not None
            if xi.dim() > 2 and xi_labels.dim() > 2:
                # Forwarding zetas
                *b, _ = xi.size()
                flat_loss = self._flat_value_w_labels(
                    xi.flatten(start_dim=0, end_dim=-2),
                    xi_labels.flatten(start_dim=0, end_dim=-2)
                )
                return self._reduce_flat_spatial_dims_loss(flat_loss).view(*b, 1)
            elif xi.dim() > 2 and xi_labels.dim() == 2:
                # Forwarding zetas. The labels are NOT flattened.
                # NOTE(review): assumes xi_labels is already compatible with
                # the flattened zeta batch. TODO confirm.
                *b, _ = xi.size()
                flat_loss = self._flat_value_w_labels(
                    xi.flatten(start_dim=0, end_dim=-2),
                    xi_labels
                )
                return self._reduce_flat_spatial_dims_loss(flat_loss).view(*b, 1)
            elif xi.dim() == 2 and xi_labels.dim() <= 2:
                # Forwarding xis. The loss output is squeezed *before* the
                # last-axis reduction.
                # NOTE(review): with an (N, 1) loss output and
                # reduce_spatial_dims=True, the squeeze gives (N,), so the mean
                # below runs over the batch and returns shape (1,) rather than
                # per-sample losses. Confirm whether this is intended.
                flat_loss = self._flat_value_w_labels(
                    xi, xi_labels
                ).squeeze()
                return self._reduce_flat_spatial_dims_loss(flat_loss)
            elif xi.dim() == xi_labels.dim() == 1:
                # Forwarding xis (no batch dim): add a batch axis, evaluate,
                # then squeeze all singleton axes away.
                b_xi = xi.unsqueeze(0)
                b_xi_labels = xi_labels.unsqueeze(0)
                flat_loss = self._flat_value_w_labels(b_xi, b_xi_labels)
                return self._reduce_flat_spatial_dims_loss(flat_loss).squeeze()
            else:
                raise NotImplementedError()
        else:
            assert xi_labels is None
            if xi.dim() > 2:
                # Forwarding zetas
                *b, _ = xi.size()
                flat_loss = self._flat_value_wo_labels(
                    xi.flatten(start_dim=0, end_dim=-2)
                )
                return self._reduce_flat_spatial_dims_loss(flat_loss).view(*b, 1)
            elif xi.dim() == 2:
                # Forwarding xis: the reduced losses are squeezed, dropping
                # every singleton axis.
                return self._reduce_flat_spatial_dims_loss(
                    self._flat_value_wo_labels(xi)
                ).squeeze()
            elif xi.dim() == 1:
                # Forwarding xis (no batch dim): add a batch axis, evaluate,
                # then squeeze all singleton axes away.
                b_xi = xi.unsqueeze(0)
                return self._reduce_flat_spatial_dims_loss(
                    self._flat_value_wo_labels(b_xi)
                ).squeeze()
            else:
                raise NotImplementedError()