Source code for skwdro.solvers.entropic_dual_torch

import math
from warnings import warn
from typing import List, Optional
import torch as pt
from skwdro.base.samplers.torch.base_samplers import BaseSampler
from skwdro.base.problems import Distribution
from skwdro.solvers.optim_cond import OptCondTorch
from skwdro.solvers.utils import detach_tensor, interpret_steps_struct
from skwdro.solvers.result import wrap_solver_result
from skwdro.solvers.oracle_torch import _DualLoss, CompositeOptimizer


[docs] def extract_data(dist: Distribution): """ Get torch tensors out of empirical distribution. Parameters ---------- dist: Distribution Empirical distribution of data and optionally labels Returns ------- xi: pt.Tensor data tensor xi_labels: Optional[pt.Tensor] label tensor if the distribution yields them, else ``None`` .. admonition:: Shapes Shapes of input/output tensors. - xi: (m, d) - xi_labels: None or (m, d') """ if dist.with_labels: xi = pt.Tensor(dist.samples_x) xi_labels = pt.Tensor(dist.samples_y) return xi, xi_labels else: xi = pt.Tensor(dist.samples) return xi, None
[docs] @wrap_solver_result def solve_dual_wdro(loss: _DualLoss, p_hat: Distribution, opt: OptCondTorch): r""" Solve the dual problem with the loss-dependant grandient descent algorithm. Parameters ---------- loss: _DualLoss Dual loss p_hat: Distribution Empirical distribution opt: OptCond Optimality conditions Returns ------- theta: np.ndarray Concatenated array of the parameters of the model, except the intercept if there is one intercept: Optional[np.ndarray] If the model has specificaly an intercept as one of its parameters, it is stacked in this output tensor lambd: Union[np.ndarray, float] Dual variable :math:`\lambda` of the problem .. admonition:: Shapes of input/output tensors. - theta: (n_params,) - intercept: (n_intercepts,) or None - lambd: (1,) """ # Cast our raw data into tensors xi, xi_labels = extract_data(p_hat) # If user provides a numpy loss, fail. assert loss is not None assert isinstance(loss, _DualLoss) # Initialize sampler. assert isinstance(loss.sampler, BaseSampler) # If user wants to specify a custom optimizer, they need to register # an instance of a subclass of torch optimizers in the # relevant attribute. optimizer: pt.optim.Optimizer = loss.optimizer # _DualLoss.presample determines the way the optimization is performed optim_ = optim_presample if loss.presample else optim_postsample opt_cond: OptCondTorch = opt losses = optim_(optimizer, xi, xi_labels, loss, opt_cond) theta = detach_tensor(loss.theta) intercept = loss.intercept if intercept is not None: intercept = detach_tensor(intercept) lambd = detach_tensor(loss.lam) if loss.rho > 0. else [0.] robust_loss = losses[-1] return theta, intercept, lambd, robust_loss
[docs] def optim_presample( optimizer: pt.optim.Optimizer, xi: pt.Tensor, xi_labels: Optional[pt.Tensor], loss: _DualLoss, opt_cond: OptCondTorch ) -> List[float]: r""" Optimize the dual loss by sampling the :math:`zeta` values once at the begining of the optimization, then performing a deterministic gradient descent (e.g. BFGS style algorithm). Parameters ---------- optimizer : pt.optim.Optimizer loss-dependant optimizer, can be customized if needed xi : pt.Tensor data tensor xi_labels : Optional[pt.Tensor] target tensor loss : _DualLoss dual loss instance Returns ------- losses: list[float] .. admonition:: Shapes of input/output tensors. - xi: (m, d) - xi_labels: (m, d') """ if loss.primal_loss.sampler is None: # Means were never set in the first place, so initialize the sampler # Warning: no seed is set, so set the sampler yourself if needed loss.init_sampler(xi, xi_labels, loss.epsilon.detach().item(), None) zeta, zeta_labels = loss.generate_zetas() def closure(back=True) -> float: """ Loss evaluation function, performing the forward pass for the autograd engine. """ optimizer.zero_grad() if back: # Forward pass objective = loss(xi, xi_labels, zeta, zeta_labels) # Backward pass objective.backward() else: loss.eval() objective = loss(xi, xi_labels, zeta, zeta_labels) loss.train() assert isinstance(objective, pt.Tensor) return objective.item() losses = [] pretrain_iters, train_iters = interpret_steps_struct(loss.n_iter) # Pretrain ERM loss.erm_mode = True for _ in range(pretrain_iters): optimizer.step(closure) # Init lambda loss.get_initial_guess_at_dual(xi, xi_labels) loss.erm_mode = False if hasattr(optimizer, "reset_lbd_state"): optimizer.reset_lbd_state() # type: ignore # Train WDRO for iteration in loss.iterations: # Do not resample, only step according to BFGS-style algo optimizer.step(closure) if opt_cond.metric == 'grad': _is = loss.imp_samp losses.append(closure()) loss.imp_samp = _is # Put it back on if it used to be. del _is else: with pt.no_grad(): _is = loss.imp_samp loss.imp_samp = False # Shut down IS if it is on. losses.append(closure(False)) loss.imp_samp = _is # Put it back on if it used to be. del _is if math.isnan(losses[-1]): warn( "Last iteration loss yielded NaN, the default algorithm diverged" ) break if opt_cond(loss, iteration): break return losses
[docs] def optim_postsample( optimizer: pt.optim.Optimizer, xi: pt.Tensor, xi_labels: Optional[pt.Tensor], loss: _DualLoss, opt_cond: OptCondTorch ) -> List[float]: r""" Optimize the dual loss by resampling the :math:`\zeta` values at each gradient descent step. Note that the descent is performed fullbatch on :math:`\xi`. Parameters ---------- n_iter : int number of gradient descent iterations to perform optimizer : pt.optim.Optimizer loss-dependant optimizer, can be customized if needed xi : pt.Tensor data tensor xi_labels : Optional[pt.Tensor] target tensor loss : _DualLoss dual loss instance Returns ------- List[float] .. admonition:: Shapes of input/output tensors. - xi: (m, d) - xi_labels: (m, d') """ losses = [] pretrain_iters, train_iters = interpret_steps_struct(loss.n_iter) # Pretrain ERM loss.erm_mode = False for _ in range(pretrain_iters): optimizer.zero_grad() # Resamples zetas at forward pass objective = loss(xi, xi_labels) assert isinstance(objective, pt.Tensor) objective.backward() # Perform the stochastic step optimizer.step() losses.append(objective.item()) # Init lambda loss.get_initial_guess_at_dual(xi, xi_labels) if isinstance(optimizer, CompositeOptimizer) and loss.erm_mode: optimizer.reset_lbd_state() # Train WDRO loss.erm_mode = False for iteration in loss.iterations: optimizer.zero_grad() # Resamples zetas at forward pass objective = loss(xi, xi_labels) assert isinstance(objective, pt.Tensor) objective.backward() # Perform the stochastic step optimizer.step() losses.append(objective.item()) if math.isnan(losses[-1]): warn( "Last iteration loss yielded NaN, the default algorithm diverged" ) break if opt_cond(loss, iteration): break return losses