from typing import Tuple, Optional, Union, Callable
import torch as pt
import torch.nn as nn
from skwdro.base.costs_torch import Cost
from skwdro.base.cost_decoder import ParsedCost, cost_from_parse_torch, parse_code_torch
from skwdro.base.losses_torch import WrappedPrimalLoss
from skwdro.base.samplers.torch.base_samplers import BaseSampler
from skwdro.base.samplers.torch.cost_samplers import LabeledCostSampler, NoLabelsCostSampler
from skwdro.solvers._dual_interfaces import _DualLoss
from skwdro.solvers.oracle_torch import DualPostSampledLoss, DualPreSampledLoss
from skwdro.solvers.utils import Steps
# Default ratio sigma / rho, used to pick the sampling scale when ``sigma``
# is not hard coded (passed as ``sigma_factor`` to ``expert_hyperparams``).
SIGMA_FACTOR: float = 0.5
# Default ratio epsilon / sigma, used to pick the regularization strength
# when ``epsilon`` is not hard coded (``epsilon_sigma_factor``).
EPSILON_SIGMA_FACTOR: float = 0.01
# Default cost specification, presumably (k-norm, p-power) matching the
# ``(2, 2)`` default documented for ``cost_spec``; not referenced in this
# chunk -- TODO confirm its consumers.
DEFAULT_COST_SPEC: Tuple[float, float] = (2, 2)
[docs]
def expert_hyperparams(
    rho: Union[pt.Tensor, float],
    p: float,
    epsilon: Optional[float],
    epsilon_sigma_factor: float,
    sigma: Optional[float],
    sigma_factor: float,
) -> Tuple[pt.Tensor, pt.Tensor]:
    r"""
    Tuning of the hyperparameters for the dual loss.

    When a hyperparameter is not hard coded, it is chosen by the following
    heuristics (applied elementwise if ``rho`` is a batch of radii):

    * :math:`\sigma = \text{sigma\_factor} \cdot \rho` if :math:`\rho > 0`,
      and :math:`\sigma = \text{sigma\_factor}` otherwise;
    * :math:`\epsilon = \max\left(\text{epsilon\_sigma\_factor} \cdot
      \text{sigma\_factor}^p \cdot \rho^{p-1},\ 10^{-7}\right)`.

    Parameters
    ----------
    rho: Tensor, scalar or shape (n_samples,)
        Wasserstein radius (a Python float is also accepted).
    p: float
        power of norm
    epsilon: float
        Epsilon if hard coded, ``None`` to let the algo find it.
    epsilon_sigma_factor: float
        Estimated ratio :math:`\frac{\epsilon}{\sigma}`
    sigma: float
        Sigma if hard coded, ``None`` to let the algo find it.
    sigma_factor: float
        Estimated ratio :math:`\frac{\sigma}{\rho}`

    Returns
    -------
    expert_sigma: Tensor
        the (hard coded or estimated) sampling scale :math:`\sigma`
    expert_epsilon: Tensor
        the (hard coded or estimated) regularization :math:`\epsilon`
    """
    # No copy when rho already is a tensor (keeps autograd history); lets
    # plain floats through as well.
    rho = pt.as_tensor(rho)
    expert_sigma: pt.Tensor
    expert_epsilon: pt.Tensor
    # Sigma init
    if sigma is None:
        scaled_rho = rho * sigma_factor
        # Elementwise selection instead of ``if rho > 0.``: a Python branch on
        # a tensor only works for a single radius, while the documented shape
        # allows a batch. Non-positive radii fall back to ``sigma_factor``.
        expert_sigma = pt.where(
            rho > 0.,
            scaled_rho,
            pt.full_like(scaled_rho, sigma_factor)
        )
    else:
        expert_sigma = pt.as_tensor(sigma)
    # Epsilon init
    if epsilon is None:
        epsilon_factor = epsilon_sigma_factor * sigma_factor ** p
        # Floor at 1e-7 so that a null radius does not yield a null epsilon;
        # clamp_min keeps rho's dtype and device.
        expert_epsilon = (epsilon_factor * rho.pow(p - 1)).clamp_min(1e-7)
    else:
        expert_epsilon = pt.as_tensor(epsilon)
    return expert_sigma, expert_epsilon
[docs]
def power_from_parsed_spec(parsed_spec: Optional[ParsedCost]) -> float:
    """
    Extract the power :math:`p` of the cost from its parsed specification.

    Parameters
    ----------
    parsed_spec: ParsedCost|None
        parsed cost specification, or ``None`` when there is none

    Returns
    -------
    float
        ``parsed_spec.power`` when a specification is given, ``2.`` otherwise
    """
    # Without a specification, fall back to a quadratic power.
    return 2. if parsed_spec is None else parsed_spec.power
[docs]
def decide_on_impsamp(
    user_query: bool,
    cost: ParsedCost,
) -> bool:
    """
    Decide whether importance sampling should be enabled.

    It is enabled only if the user requested it AND the parsed cost reports
    that it supports it.

    Parameters
    ----------
    user_query: bool
        the user's request (the ``imp_samp`` flag)
    cost: ParsedCost
        parsed cost specification, queried through ``can_imp_samp()``

    Returns
    -------
    bool
        whether importance sampling should be used
    """
    if not user_query:
        # Short-circuit: the cost is not queried when the user opted out.
        return user_query
    return cost.can_imp_samp()
[docs]
def dualize_primal_loss(
    loss_: Union[nn.Module, Callable[..., pt.Tensor]],
    transform_: Optional[nn.Module],
    rho: pt.Tensor,
    xi_batchinit: pt.Tensor,
    xi_labels_batchinit: Optional[pt.Tensor],
    post_sample: bool = True,
    cost_spec: Optional[str] = None,
    n_samples: int = 10,
    seed: int = 42,
    *,
    reduction: Optional[str] = None,
    learning_rate: Optional[float] = None,
    epsilon: Optional[float] = None,
    sigma: Optional[float] = None,
    l2reg: Optional[float] = None,
    adapt: Optional[str] = "prodigy",
    n_iter: Optional[Steps] = None,
    imp_samp: bool = True,
    loss_reduces_spatial_dims: bool = False
) -> _DualLoss:
    r"""
    Wrap a primal loss into its SkWDRO dual (robust) counterpart.

    Parameters
    ----------
    loss_: nn.Module|Callable
        the primal loss :math:`L_\theta`, either a :py:class:`torch.nn.Module`
        or a functional callable.
    transform_: nn.Module|None
        transformation applied to the (non-label) data before it is fed to
        the loss; ``None`` (default) means identity.
    rho: Tensor, scalar tensor
        Wasserstein radius.
    xi_batchinit: Tensor, shape (n_samples, n_features)
        data points used to initialize the samplers and :math:`\lambda_0`.
    xi_labels_batchinit: Tensor|None
        labels used to initialize the samplers and :math:`\lambda_0`;
        ``None`` for label-free problems.
    post_sample: bool
        ``True`` builds a post-sampled dual loss
        (:py:class:`DualPostSampledLoss`), ``False`` a pre-sampled one
        (:py:class:`DualPreSampledLoss`).
    cost_spec: str|None
        cost specification ``(k, p)`` for a k-norm raised to the p-th power;
        ``None`` for the default ``(2, 2)``.
    n_samples: int
        number of :math:`\zeta` samples drawn before the gradient descent
        starts (may be changed between inferences).
    seed: int
        seed given to the samplers.
    reduction: str|None
        reduction of the outer expectation of the SkWDRO formula:
        ``'none'`` | ``'mean'`` | ``'sum'``.

        - ``'none'``: no reduction,
        - ``'mean'``: sum of the output divided by its number of elements,
        - ``'sum'``: sum of the output.

        ``None`` (default) translates to ``'mean'``.
    learning_rate: float|None
        step size of the default descent algorithm attached to the loss.
    epsilon: float|None
        hard-coded epsilon, or ``None`` to let the heuristic choose it.
    sigma: float|None
        hard-coded sigma, or ``None`` to let the heuristic choose it.
    l2reg: float|None
        optional L2 regularization forwarded to the wrapped primal loss.
    adapt: str|None
        adaptive step rule, ``"prodigy"`` (default) or ``"mechanic"``.
    n_iter: int|tuple[int, int]|None
        default number of iterations for the default solving routines
        (mostly internal). An int is the number of robust optimization
        steps; a pair is (ERM steps preceding the robust solve, robust
        steps). ``None`` selects ``(200, 2800)`` when ``post_sample`` is
        set and ``(100, 10)`` otherwise.
    imp_samp: bool
        request importance sampling; it is only enabled when the parsed cost
        supports it (documented to work only for ``(2, 2)`` costs).
    loss_reduces_spatial_dims: bool
        set to ``True`` if the primal loss, with its reduction set to
        ``'none'``, already reduces the last dimension of the losses batch
        (e.g. :py:class:`torch.nn.CrossEntropyLoss`, which treats one
        dimension as the channel axis). Defaults to ``False``.

    Returns
    -------
    _DualLoss
        the dual loss built from the wrapped primal loss and the cost.
    """
    labelled = xi_labels_batchinit is not None
    if labelled:
        assert isinstance(xi_labels_batchinit, pt.Tensor), (
            "Please provide a starting (mini/full)batch of labels "
            "to initialize the samplers"
        )

    parsed_cost = parse_code_torch(cost_spec, labelled)

    # Heuristic (or hard-coded) sigma / epsilon, moved to the dtype and
    # device of the initial data batch.
    sigma_0, epsilon_0 = (
        hyperparam.to(xi_batchinit)
        for hyperparam in expert_hyperparams(
            rho,
            power_from_parsed_spec(parsed_cost),
            epsilon,
            EPSILON_SIGMA_FACTOR,
            sigma,
            SIGMA_FACTOR,
        )
    )

    cost: Cost = cost_from_parse_torch(parsed_cost)
    sampler: BaseSampler = (
        LabeledCostSampler(
            cost, xi_batchinit, xi_labels_batchinit, sigma_0, seed
        )
        if labelled
        else NoLabelsCostSampler(cost, xi_batchinit, sigma_0, seed)
    )

    primal = WrappedPrimalLoss(
        loss_,
        transform_,
        sampler,
        labelled,
        l2reg=l2reg,
        # The wrapper reduces spatial dims only if the primal loss does not.
        reduce_spatial_dims=not loss_reduces_spatial_dims,
    )

    if n_iter is None:
        n_iter = (200, 2800) if post_sample else (100, 10)
    use_imp_samp = decide_on_impsamp(imp_samp, parsed_cost)

    dual_cls = DualPostSampledLoss if post_sample else DualPreSampledLoss
    return dual_cls(
        primal,
        cost,
        n_iter=n_iter,
        rho_0=rho,
        n_samples=n_samples,
        epsilon_0=epsilon_0,
        reduction=reduction,
        imp_samp=use_imp_samp,
        learning_rate=learning_rate,
        adapt=adapt,
    )