Source code for skwdro.wrap_problem

from typing import Tuple, Optional

import torch as pt
import torch.nn as nn

from skwdro.base.costs_torch import Cost
from skwdro.base.cost_decoder import ParsedCost, cost_from_parse_torch, parse_code_torch
from skwdro.base.losses_torch import WrappedPrimalLoss
from skwdro.base.samplers.torch.base_samplers import BaseSampler
from skwdro.base.samplers.torch.cost_samplers import LabeledCostSampler, NoLabelsCostSampler
from skwdro.solvers._dual_interfaces import _DualLoss
from skwdro.solvers.oracle_torch import DualPostSampledLoss, DualPreSampledLoss

# Ratio sigma / rho; `dualize_primal_loss` forwards it to
# `expert_hyperparams` as its `sigma_factor` argument.
SIGMA_FACTOR: float = .5
# Ratio epsilon / sigma; `dualize_primal_loss` forwards it to
# `expert_hyperparams` as its `epsilon_sigma_factor` argument.
EPSILON_SIGMA_FACTOR: float = 1e-2
# Default cost specification. The `dualize_primal_loss` docstring describes
# the default as ``(2, 2)``, i.e. (k, p) for a k-norm raised to the power p.
# NOTE(review): not referenced by any code visible in this module.
DEFAULT_COST_SPEC: Tuple[float, float] = (2, 2)


def expert_hyperparams(
    rho: pt.Tensor,
    p: float,
    epsilon: Optional[float],
    epsilon_sigma_factor: float,
    sigma: Optional[float],
    sigma_factor: float,
) -> Tuple[pt.Tensor, pt.Tensor]:
    r"""
    Pick the starting values of sigma and epsilon for the dual loss.

    A hard-coded value always wins; otherwise the value is derived from
    the Wasserstein radius and the supplied ratios.

    Parameters
    ----------
    rho: Tensor
        Wasserstein radius. It is tested with ``rho > 0.`` inside an ``if``
        whenever ``sigma`` is ``None``, so it needs to hold a single value
        in that case.
    p: float
        power of norm
    epsilon: float
        Epsilon if hard coded, ``None`` to let the algo find it.
    epsilon_sigma_factor: float
        Estimated ratio :math:`\frac{\epsilon}{\sigma}`
    sigma: float
        Sigma if hard coded, ``None`` to let the algo find it.
    sigma_factor: float
        Estimated ratio :math:`\frac{\sigma}{\rho}`

    Returns
    -------
    (sigma, epsilon): Tuple[Tensor, Tensor]
        * sigma is ``tensor(sigma)`` when hard coded, else
          ``rho * sigma_factor`` for a positive radius, else
          ``tensor(sigma_factor)``.
        * epsilon is ``tensor(epsilon)`` when hard coded, else
          ``epsilon_sigma_factor * sigma_factor**p * rho**(p - 1)``
          floored at ``1e-7``.
    """
    # Sigma: explicit value first, then scale the radius, and fall back on
    # the bare factor when the radius is not strictly positive.
    if sigma is not None:
        sigma_init = pt.tensor(sigma)
    elif rho > 0.:
        sigma_init = rho * sigma_factor
    else:
        sigma_init = pt.tensor(sigma_factor)

    # Epsilon: an explicit value short-circuits the heuristic.
    if epsilon is not None:
        return sigma_init, pt.tensor(epsilon)

    scale = epsilon_sigma_factor * sigma_factor**p
    floor = pt.tensor(1e-7)  # keeps epsilon strictly positive when rho is 0
    epsilon_init = pt.max(scale * rho.pow(p - 1), floor)
    return sigma_init, epsilon_init
def power_from_parsed_spec(parsed_spec: Optional[ParsedCost]) -> float:
    """
    Read the cost power from a parsed cost specification.

    Parameters
    ----------
    parsed_spec: ParsedCost|None
        the parsed specification, or ``None`` when there is none.

    Returns
    -------
    float
        ``parsed_spec.power``, or ``2.`` when ``parsed_spec`` is ``None``.
    """
    return 2. if parsed_spec is None else parsed_spec.power
def dualize_primal_loss(
    loss_: nn.Module,
    transform_: Optional[nn.Module],
    rho: pt.Tensor,
    xi_batchinit: pt.Tensor,
    xi_labels_batchinit: Optional[pt.Tensor],
    post_sample: bool = True,
    cost_spec: Optional[str] = None,
    n_samples: int = 10,
    seed: int = 42,
    *,
    epsilon: Optional[float] = None,
    sigma: Optional[float] = None,
    l2reg: Optional[float] = None,
    adapt: Optional[str] = "prodigy",
    imp_samp: bool = True
) -> _DualLoss:
    r"""
    Wrap a primal loss into its dual (robust) counterpart.

    The cost is parsed from ``cost_spec``, sigma and epsilon are initialised
    through :func:`expert_hyperparams`, a sampler is built around the initial
    batch, and everything is handed to the pre- or post-sampled dual loss.

    Parameters
    ----------
    loss_: nn.Module
        the primal loss
    transform_: nn.Module
        the transformation to apply to the data before feeding it to the loss
    rho: Tensor
        Wasserstein radius
    xi_batchinit: Tensor, shape (n_samples, n_features)
        Data points to initialize the samplers and :math:`\lambda_0`.
        The initial sigma and epsilon are converted with
        ``.to(xi_batchinit)``.
    xi_labels_batchinit: Optional[Tensor]
        Labels to initialize the samplers and :math:`\lambda_0`. ``None``
        selects the label-free sampler; otherwise it must be a tensor.
    post_sample: bool
        whether to use a post-sampled dual loss
    cost_spec: str|None
        the cost specification in the format ``(k, p)`` for a sample k-norm
        and p-power. ``None`` to use the default ``(2, 2)``.
    n_samples: int
        number of :math:`\zeta` samples to draw before the gradient descent
        begins (can be changed if needed between inferences)
    seed: int
        the seed for the samplers
    epsilon: float|None
        Epsilon if hard coded, ``None`` to let the algo find it.
    sigma: float|None
        Sigma if hard coded, ``None`` to let the algo find it.
    l2reg: float|None
        L2 regularization if needed
    adapt: str|None
        the adaptative step to use between `"prodigy"` and `"mechanic"`.
    imp_samp: bool
        whether to use importance sampling (will work only for ``(2, 2)``
        costs); it is and-ed with ``parsed_cost.can_imp_samp()``.

    Returns
    -------
    _DualLoss
        a ``DualPostSampledLoss`` built with ``n_iter=(200, 2800)`` when
        ``post_sample`` is true, otherwise a ``DualPreSampledLoss`` built
        with ``n_iter=(100, 10)``.
    """
    sampler: BaseSampler
    cost: Cost

    has_labels = xi_labels_batchinit is not None
    if has_labels:
        assert isinstance(xi_labels_batchinit, pt.Tensor), ' '.join([
            "Please provide a starting",
            "(mini/full)batch of labels",
            "to initialize the samplers"
        ])

    parsed_cost = parse_code_torch(cost_spec, has_labels)
    power = power_from_parsed_spec(parsed_cost)

    # Heuristic (or hard-coded) starting points, aligned on the data batch.
    sigma_init, epsilon_init = expert_hyperparams(
        rho,
        power,
        epsilon,
        EPSILON_SIGMA_FACTOR,
        sigma,
        SIGMA_FACTOR
    )
    sigma_init = sigma_init.to(xi_batchinit)
    epsilon_init = epsilon_init.to(xi_batchinit)

    cost = cost_from_parse_torch(parsed_cost)

    if not has_labels:
        sampler = NoLabelsCostSampler(cost, xi_batchinit, sigma_init, seed)
    else:
        # Narrowing for the type checker; already verified above.
        assert xi_labels_batchinit is not None
        sampler = LabeledCostSampler(
            cost,
            xi_batchinit,
            xi_labels_batchinit,
            sigma_init,
            seed
        )

    wrapped_loss = WrappedPrimalLoss(
        loss_,
        transform_,
        sampler,
        has_labels,
        l2reg=l2reg
    )

    # Options shared by both dual-loss flavours; only n_iter differs.
    shared_options = dict(
        rho_0=rho,
        n_samples=n_samples,
        epsilon_0=epsilon_init,
        adapt=adapt,
        imp_samp=(imp_samp and parsed_cost.can_imp_samp()),
    )
    if post_sample:
        return DualPostSampledLoss(
            wrapped_loss, cost, n_iter=(200, 2800), **shared_options
        )
    return DualPreSampledLoss(
        wrapped_loss, cost, n_iter=(100, 10), **shared_options
    )
# if post_sample:
#     return DualPostSampledLoss(
#         loss,
#         cost,
#         n_iter=(200, 2800),
#         **kwargs
#     )
# else:
#     return DualPreSampledLoss(
#         loss,
#         cost,
#         n_iter=(100, 10),
#         **kwargs
#     )