Source code for skwdro.wrap_problem

from typing import Tuple, Optional, Union, Callable

import torch as pt
import torch.nn as nn

from skwdro.base.costs_torch import Cost
from skwdro.base.cost_decoder import ParsedCost, cost_from_parse_torch, parse_code_torch
from skwdro.base.losses_torch import WrappedPrimalLoss
from skwdro.base.samplers.torch.base_samplers import BaseSampler
from skwdro.base.samplers.torch.cost_samplers import LabeledCostSampler, NoLabelsCostSampler
from skwdro.solvers._dual_interfaces import _DualLoss
from skwdro.solvers.oracle_torch import DualPostSampledLoss, DualPreSampledLoss
from skwdro.solvers.utils import Steps

# Ratio sigma / rho; passed to ``expert_hyperparams`` as ``sigma_factor``
# by ``dualize_primal_loss``.
SIGMA_FACTOR: float = .5
# Ratio epsilon / sigma; passed to ``expert_hyperparams`` as
# ``epsilon_sigma_factor`` by ``dualize_primal_loss``.
EPSILON_SIGMA_FACTOR: float = 1e-2
# NOTE(review): not referenced in the visible part of this module; presumably
# the ``(k, p)`` (norm, power) pair that a ``None`` cost spec stands for --
# confirm against the cost parser.
DEFAULT_COST_SPEC: Tuple[float, float] = (2, 2)


def expert_hyperparams(
    rho: pt.Tensor,
    p: float,
    epsilon: Optional[float],
    epsilon_sigma_factor: float,
    sigma: Optional[float],
    sigma_factor: float,
) -> Tuple[pt.Tensor, pt.Tensor]:
    r"""
    Pick the starting values of :math:`\sigma` and :math:`\epsilon` for the
    dual loss.

    A value supplied by the caller is wrapped in a tensor and returned as is.
    A value left to ``None`` is derived from the Wasserstein radius:

    * :math:`\sigma = \rho \times` ``sigma_factor`` when :math:`\rho > 0`,
      otherwise just ``sigma_factor``;
    * :math:`\epsilon = \max(` ``epsilon_sigma_factor`` :math:`\times`
      ``sigma_factor`` :math:`^p \times \rho^{p-1},\ 10^{-7})`.

    Parameters
    ----------
    rho: Tensor
        Wasserstein radius. It is truth-tested through ``rho > 0``, so it
        must hold a single element.
    p: float
        power of norm
    epsilon: float|None
        Epsilon if hard coded, ``None`` to let the algo find it.
    epsilon_sigma_factor: float
        Estimated ratio :math:`\frac{\epsilon}{\sigma}`
    sigma: float|None
        Sigma if hard coded, ``None`` to let the algo find it.
    sigma_factor: float
        Estimated ratio :math:`\frac{\sigma}{\rho}`

    Returns
    -------
    Tuple[Tensor, Tensor]
        The pair ``(sigma, epsilon)``, in that order.
    """
    # --- sigma ---------------------------------------------------------
    if sigma is not None:
        sigma_init = pt.tensor(sigma)
    elif rho > 0.:
        sigma_init = rho * sigma_factor
    else:
        # Non-positive radius: fall back on the bare factor.
        sigma_init = pt.tensor(sigma_factor)

    # --- epsilon -------------------------------------------------------
    if epsilon is not None:
        epsilon_init = pt.tensor(epsilon)
    else:
        scale = epsilon_sigma_factor * sigma_factor**p
        lower_bound = pt.tensor(1e-7)
        # epsilon ^ (p/q), floored so that it never reaches zero.
        epsilon_init = pt.max(scale * rho.pow(p - 1), lower_bound)

    return sigma_init, epsilon_init
def power_from_parsed_spec(parsed_spec: Optional[ParsedCost]) -> float:
    """
    Read the power of a parsed cost specification.

    Parameters
    ----------
    parsed_spec: ParsedCost|None
        the parsed cost, or ``None`` when no specification is available.

    Returns
    -------
    float
        ``parsed_spec.power``, or ``2.`` when ``parsed_spec`` is ``None``.
    """
    return parsed_spec.power if parsed_spec is not None else 2.
def decide_on_impsamp(
    user_query: bool,
    cost: ParsedCost,
) -> bool:
    """
    Tell whether importance sampling should be switched on.

    Parameters
    ----------
    user_query: bool
        what the user asked for.
    cost: ParsedCost
        the parsed cost; only queried (through ``can_imp_samp()``) when the
        user did ask for importance sampling.

    Returns
    -------
    bool
        ``user_query`` itself when it is falsy, otherwise the answer of
        ``cost.can_imp_samp()``.
    """
    if not user_query:
        # Nothing requested: the cost is not consulted at all.
        return user_query
    return cost.can_imp_samp()
def dualize_primal_loss(
    loss_: Union[
        nn.Module,
        Callable[..., pt.Tensor]
    ],
    transform_: Optional[nn.Module],
    rho: pt.Tensor,
    xi_batchinit: pt.Tensor,
    xi_labels_batchinit: Optional[pt.Tensor],
    post_sample: bool = True,
    cost_spec: Optional[str] = None,
    n_samples: int = 10,
    seed: int = 42,
    *,
    reduction: Optional[str] = None,
    learning_rate: Optional[float] = None,
    epsilon: Optional[float] = None,
    sigma: Optional[float] = None,
    l2reg: Optional[float] = None,
    adapt: Optional[str] = "prodigy",
    n_iter: Optional[Steps] = None,
    imp_samp: bool = True,
    loss_reduces_spatial_dims: bool = False
) -> _DualLoss:
    r"""
    Wrap a primal loss into its robust (dual) counterpart.

    The steps are, in order: parse the cost specification, derive
    :math:`\sigma` and :math:`\epsilon` with :py:func:`expert_hyperparams`,
    move both onto the device/dtype of ``xi_batchinit``, build the cost and
    the matching sampler, wrap the primal loss, and finally instantiate
    the dual loss.

    Parameters
    ----------
    loss_: nn.Module|Callable
        the primal loss :math:`L_\theta`. Can be given either as a
        :py:class:`torch.nn.Module` or as a (functional) callable.
    transform_: nn.Module|None
        the transformation to apply to the (non-label) data before feeding
        it to the loss. Identity if set to ``None``.
    rho: Tensor, scalar tensor
        Wasserstein radius
    xi_batchinit: Tensor, shape (n_samples, n_features)
        Data points to initialize the samplers and :math:`\lambda_0`
    xi_labels_batchinit: Tensor|None
        Labels to initialize the samplers and :math:`\lambda_0`. ``None``
        means that the problem has no labels, in which case a
        :py:class:`NoLabelsCostSampler` is used instead of a
        :py:class:`LabeledCostSampler`.
    post_sample: bool
        whether to use a post-sampled dual loss
        (:py:class:`DualPostSampledLoss`) rather than a pre-sampled one
        (:py:class:`DualPreSampledLoss`).
    cost_spec: str|None
        the cost specification in the format ``(k, p)`` for a sample
        k-norm and p-power. ``None`` to use the default ``(2, 2)``.
    n_samples: int
        number of :math:`\zeta` samples to draw before the gradient descent
        begins (can be changed if needed between inferences)
    seed: int
        the seed for the samplers
    reduction: str|None
        specifies the reduction to apply to the outer expectation of the
        SkWDRO formula: ``'none'`` | ``'mean'`` | ``'sum'``.

        - ``'none'``: no reduction will be applied,
        - ``'mean'``: the sum of the output will be divided by the number
          of elements in the output,
        - ``'sum'``: the output will be summed.

        Default: ``None`` which translates to ``'mean'``
    learning_rate: float|None
        the step size for the default descent algorithm linked to the loss
        function
    epsilon: float|None
        Epsilon if hard coded, ``None`` to let the algo find it.
    sigma: float|None
        Sigma if hard coded, ``None`` to let the algo find it.
    l2reg: float|None
        L2 regularization if needed
    adapt: str|None
        the adaptive step to use between ``"prodigy"`` and ``"mechanic"``.
    n_iter: int|tuple[int, int]|None
        can set the default number of iterations if used through the
        default solving routines. Mostly an internal parameter. If int, it
        is the number of internal robust optimization steps, if a 2-uple of
        ints, it is the number of erm steps preceding the robust solve then
        the number of robust steps. If ``None`` it is filled with
        ``(200, 2800)`` for a post-sampled loss and ``(100, 10)`` for a
        pre-sampled one.
    imp_samp: bool
        whether to use importance sampling. The request is only honoured
        when the parsed cost supports it (see
        :py:func:`decide_on_impsamp`).
    loss_reduces_spatial_dims: bool
        flag that can be set to ``True`` if the primal :py:attr:`loss`
        reduces the last dimension of the losses batch with its reduction
        set to ``'none'``, e.g. for :py:class:`torch.nn.CrossEntropyLoss`
        which will take one dimension as channel axis, defaults to
        ``False``. The wrapped loss receives the negation of this flag as
        its ``reduce_spatial_dims`` argument.

    Returns
    -------
    _DualLoss
        the dual loss wrapping ``loss_``.
    """
    labeled = xi_labels_batchinit is not None
    if labeled:
        assert isinstance(xi_labels_batchinit, pt.Tensor), ' '.join([
            "Please provide a starting",
            "(mini/full)batch of labels",
            "to initialize the samplers"
        ])

    # Cost specification and the hyperparameters that depend on its power.
    spec = parse_code_torch(cost_spec, labeled)
    sigma_0, epsilon_0 = expert_hyperparams(
        rho,
        power_from_parsed_spec(spec),
        epsilon,
        EPSILON_SIGMA_FACTOR,
        sigma,
        SIGMA_FACTOR
    )
    # Align device and dtype with the data.
    sigma_0 = sigma_0.to(xi_batchinit)
    epsilon_0 = epsilon_0.to(xi_batchinit)

    ground_cost: Cost = cost_from_parse_torch(spec)

    zeta_sampler: BaseSampler
    if not labeled:
        zeta_sampler = NoLabelsCostSampler(
            ground_cost, xi_batchinit, sigma_0, seed
        )
    else:
        # Redundant at runtime; narrows the Optional for type checkers.
        assert xi_labels_batchinit is not None
        zeta_sampler = LabeledCostSampler(
            ground_cost, xi_batchinit, xi_labels_batchinit, sigma_0, seed
        )

    primal = WrappedPrimalLoss(
        loss_,
        transform_,
        zeta_sampler,
        labeled,
        l2reg=l2reg,
        reduce_spatial_dims=not loss_reduces_spatial_dims
    )

    # Resolve the iteration budget; the default depends on the loss flavour.
    if n_iter is not None:
        steps = n_iter
    elif post_sample:
        steps = (200, 2800)
    else:
        steps = (100, 10)

    dual_cls = DualPostSampledLoss if post_sample else DualPreSampledLoss
    return dual_cls(
        primal,
        ground_cost,
        n_iter=steps,
        rho_0=rho,
        n_samples=n_samples,
        epsilon_0=epsilon_0,
        reduction=reduction,
        imp_samp=decide_on_impsamp(imp_samp, spec),
        learning_rate=learning_rate,
        adapt=adapt,
    )