from typing import Callable, Optional, Union
from itertools import chain
import torch as pt
import torch.nn as nn
from skwdro.base.losses_torch import Loss
from skwdro.base.samplers.torch.base_samplers import BaseSampler
class WrappingError(ValueError):
    """Error raised when a wrapped loss cannot provide a requested feature."""
# [docs] -- stray documentation-link marker (looks like an HTML extraction artifact); not code.
class WrappedPrimalLoss(Loss):
    r"""
    Wrap a user-provided PyTorch loss so that it can be used as a primal loss.

    The wrapped loss is either an :py:class:`torch.nn.Module` (OOP interface,
    which must expose ``reduction == 'none'``) or a plain callable (functional
    interface, which is called with ``reduction='none'``). An optional
    ``transform`` module is applied to the non-label data before the loss.
    """
    # Whether the loss is called with (prediction, target) or (prediction,).
    has_labels: bool
    # True when ``loss`` is an ``nn.Module``, False for a functional callable.
    loss_oop_interface: bool = True
    # True to average the non-reduced losses over their last dimension.
    reduce_spatial_dims: bool = True

    def __init__(
        self,
        loss: Union[
            nn.Module,
            Callable[..., pt.Tensor]
        ],
        transform: Optional[nn.Module],
        sampler: BaseSampler,
        has_labels: bool,
        reduce_spatial_dims: bool = True,
        *,
        l2reg: Optional[float] = None
    ) -> None:
        r"""
        Provide the wrapped version of the primal loss.

        Parameters
        ----------
        loss: nn.Module|Callable
            the primal loss :math:`L_\theta`. Can be given either as a
            :py:class:`torch.nn.Module` or as a (functional) callable.
        transform: nn.Module|None
            the transformation to apply to the (non-label) data before feeding
            it to the loss. Replaced by :py:class:`torch.nn.Identity` if set
            to ``None``.
        sampler: :py:class:`skwdro.base.samplers.torch.base_samplers.BaseSampler`
            the sampling object that defines the way :math:`\zeta` samples are
            drawn
        has_labels: bool
            set to ``True`` if the loss accepts two inputs: a prediction and
            some kind of target. Otherwise, set to ``False``.

            .. warning:: It is *your* job to check that the :py:attr:`loss`,
                :py:attr:`sampler`, and :py:attr:`has_labels` parameters are
                compatible with one another.
        reduce_spatial_dims: bool
            set to ``False`` if the loss reduces by default the last dimension
            of a batch of non-reduced losses. This may be useful for example
            if using :py:class:`~torch.nn.CrossEntropyLoss` which will
            consider your last axis as a channel axis (which might not be the
            case outside of computer vision applications).
        l2reg: float|None
            L2 regularization if needed

        Attributes
        ----------
        loss: nn.Module|Callable
            the wrapped loss, as given.
        loss_oop_interface: bool
            ``True`` if :py:attr:`loss` is an :py:class:`torch.nn.Module`.
        reduce_spatial_dims: bool
            copy of the ``reduce_spatial_dims`` parameter.
        transform: nn.Module
            the given transform, or an identity module if ``None`` was given.

        Raises
        ------
        AssertionError
            if a module loss has a ``reduction`` different from ``'none'``, or
            if a non-module loss is not callable.
        """
        super().__init__(sampler, has_labels, l2reg=l2reg)
        self.loss = loss
        if isinstance(loss, nn.Module):
            # The wrapper needs per-sample losses, so reduction must be off.
            assert loss.reduction == 'none', " ".join([
                'If you are using the OOP interface of PyTorch to define the',
                'main loss functional, please set its reduction method to',
                '\"none\"'
            ])
            self.loss_oop_interface = True
        else:
            assert callable(loss), " ".join([
                'If you are not using the OOP interface of PyTorch to define',
                'the main loss functional, please use the functional interface',
                'so that loss is at least callable, with a signature accepting',
                'either: my_loss(input: Tensor, target: Tensor, reduction: str)',
                'or my_loss(input: Tensor, reduction: str).'
            ])
            self.loss_oop_interface = False
        self.reduce_spatial_dims = reduce_spatial_dims
        self.transform = transform if transform is not None else nn.Identity()

    @classmethod
    def default_sampler(
        cls,
        xi, xi_labels,
        epsilon, seed: Optional[int]
    ) -> BaseSampler:
        """
        Refuse to build a default sampler: a wrapped loss has none.

        All arguments are ignored.

        Raises
        ------
        WrappingError
            always.
        """
        del xi, xi_labels, epsilon, seed
        raise WrappingError(
            "No default sampler can be attributed by default by a wrapped loss.")

    @property
    def theta(self) -> pt.Tensor:
        """
        Flat 1-D concatenation of the wrapped parameters.

        The parameters of the loss (OOP interface only) come first, followed
        by those of the transform. An empty 1-D tensor is returned when there
        is no parameter at all (e.g. a parameter-free loss combined with the
        identity transform), instead of failing in the concatenation.
        """
        if self.loss_oop_interface:
            assert isinstance(self.loss, nn.Module)
            params = chain(self.loss.parameters(), self.transform.parameters())
        else:
            assert callable(self.loss)
            params = self.transform.parameters()
        flat_params = [pt.flatten(p) for p in params]
        if not flat_params:
            # Concatenating an empty list raises a RuntimeError in PyTorch.
            return pt.empty(0)
        return pt.concat(flat_params)

    @property
    def intercept(self) -> pt.Tensor:
        """Constant scalar tensor ``0.``: no intercept is exposed."""
        return pt.tensor(0.)

    def _flat_value_w_labels(self, xi, xi_labels):
        """
        Evaluate the non-reduced loss on transformed inputs and their labels.

        The result goes through ``self.regularize`` (defined outside of this
        class). The functional interface is called with ``reduction='none'``.
        """
        if self.loss_oop_interface:
            return self.regularize(self.loss(
                self.transform(xi),
                xi_labels
            ))
        else:
            return self.regularize(self.loss(
                self.transform(xi),
                xi_labels,
                reduction='none'
            ))

    def _flat_value_wo_labels(self, xi):
        """
        Evaluate the non-reduced loss on transformed inputs, without labels.

        The result goes through ``self.regularize`` (defined outside of this
        class). The functional interface is called with ``reduction='none'``.
        """
        if self.loss_oop_interface:
            return self.regularize(self.loss(
                self.transform(xi)
            ))
        else:
            return self.regularize(self.loss(
                self.transform(xi),
                reduction='none'
            ))

    def _reduce_flat_spatial_dims_loss(self, losses: pt.Tensor) -> pt.Tensor:
        """
        Give the losses a trailing dimension of size one.

        If :py:attr:`reduce_spatial_dims` is set, the last dimension is
        averaged (and kept); otherwise a new trailing dimension is appended.
        """
        if self.reduce_spatial_dims:
            return losses.mean(dim=-1, keepdim=True)
        else:
            return losses.unsqueeze(-1)

    def value(
        self,
        xi: pt.Tensor, xi_labels: Optional[pt.Tensor] = None
    ) -> pt.Tensor:
        """
        Compute the loss on a batch of inputs, dispatching on tensor ranks.

        Inputs with more than two dimensions have all their leading
        dimensions flattened into one before calling the loss, and the result
        is viewed back as ``(*leading_dims, 1)``.

        Parameters
        ----------
        xi: pt.Tensor
            the (non-label) data.
        xi_labels: pt.Tensor|None
            the labels; must be given if and only if :py:attr:`has_labels`.

        Returns
        -------
        pt.Tensor
            the loss values; their shape depends on the branch taken (see
            the inline comments).

        Raises
        ------
        AssertionError
            if the presence of ``xi_labels`` disagrees with
            :py:attr:`has_labels`.
        NotImplementedError
            if the ranks of the inputs match none of the supported cases.
        """
        if self.has_labels:
            assert xi_labels is not None
            if xi.dim() > 2 and xi_labels.dim() > 2:
                # Forwarding zetas: both tensors carry extra leading dims.
                *b, _ = xi.size()
                flat_loss = self._flat_value_w_labels(
                    xi.flatten(start_dim=0, end_dim=-2),
                    xi_labels.flatten(start_dim=0, end_dim=-2)
                )
                return self._reduce_flat_spatial_dims_loss(flat_loss).view(*b, 1)
            elif xi.dim() > 2 and xi_labels.dim() == 2:
                # Forwarding zetas: labels are passed to the loss as they are.
                *b, _ = xi.size()
                flat_loss = self._flat_value_w_labels(
                    xi.flatten(start_dim=0, end_dim=-2),
                    xi_labels
                )
                return self._reduce_flat_spatial_dims_loss(flat_loss).view(*b, 1)
            elif xi.dim() == 2 and xi_labels.dim() <= 2:
                # Forwarding xis: size-one dims are squeezed before reduction.
                flat_loss = self._flat_value_w_labels(
                    xi, xi_labels
                ).squeeze()
                return self._reduce_flat_spatial_dims_loss(flat_loss)
            elif xi.dim() == xi_labels.dim() == 1:
                # Forwarding xis (no batch dim): add a leading dim for the
                # call, then squeeze the result.
                b_xi = xi.unsqueeze(0)
                b_xi_labels = xi_labels.unsqueeze(0)
                flat_loss = self._flat_value_w_labels(b_xi, b_xi_labels)
                return self._reduce_flat_spatial_dims_loss(flat_loss).squeeze()
            else:
                raise NotImplementedError()
        else:
            assert xi_labels is None
            if xi.dim() > 2:
                # Forwarding zetas
                *b, _ = xi.size()
                flat_loss = self._flat_value_wo_labels(
                    xi.flatten(start_dim=0, end_dim=-2)
                )
                return self._reduce_flat_spatial_dims_loss(flat_loss).view(*b, 1)
            elif xi.dim() == 2:
                # Forwarding xis
                return self._reduce_flat_spatial_dims_loss(
                    self._flat_value_wo_labels(xi)
                ).squeeze()
            elif xi.dim() == 1:
                # Forwarding xis (no batch dim): add a leading dim for the
                # call, then squeeze the result.
                b_xi = xi.unsqueeze(0)
                return self._reduce_flat_spatial_dims_loss(
                    self._flat_value_wo_labels(b_xi)
                ).squeeze()
            else:
                raise NotImplementedError()