Source code for skwdro.solvers.optim_cond

from skwdro.solvers.utils import maybe_flatten_grad_else_raise, NoneGradError
from skwdro.solvers._dual_interfaces import _DualLoss as BaseDualLoss
import torch as pt
from typing import Callable, Optional, Union, Tuple


LazyTensor = Callable[[], pt.Tensor]
ValidAndError = Tuple[bool, float]

L_AND_T = {"both", "theta_and_lambda", "t&l", "lambda_and_theta", "l&t"}
L_OR_T = {"one", "theta_or_lambda", "tUl", "lambda_or_theta", "lUt"}
JUST_T = {"theta", "t"}
JUST_L = {"lambda", "l"}


def combine_intersect(
    a: "ValidAndError",
    b: "ValidAndError"
) -> "ValidAndError":
    """
    Combine two ``(valid, error)`` pairs with a logical *and*.

    Parameters
    ----------
    a: ValidAndError
        first ``(valid, error)`` pair
    b: ValidAndError
        second ``(valid, error)`` pair

    Returns
    -------
    combined: ValidAndError
        ``(a[0] and b[0], a[1] + b[1])``: the stopping flag is raised only
        if both flags are, and the two errors are added together.
    """
    valid = a[0] and b[0]
    error = a[1] + b[1]
    return valid, error
def combine_union(
    a: "ValidAndError",
    b: "ValidAndError"
) -> "ValidAndError":
    """
    Combine two ``(valid, error)`` pairs with a logical *or*.

    Parameters
    ----------
    a: ValidAndError
        first ``(valid, error)`` pair
    b: ValidAndError
        second ``(valid, error)`` pair

    Returns
    -------
    combined: ValidAndError
        ``(a[0] or b[0], a[1] + b[1])``: the stopping flag is raised as soon
        as one of the flags is, and the two errors are added together.
    """
    valid = a[0] or b[0]
    error = a[1] + b[1]
    return valid, error
def wrap(b: bool) -> "ValidAndError":
    """
    Lift a bare stopping flag into a ``(valid, error)`` pair.

    Parameters
    ----------
    b: bool
        the stopping flag

    Returns
    -------
    pair: ValidAndError
        ``(b, 0.)``, i.e. the flag with a null error attached.
    """
    return b, 0.
[docs] class OptCondTorch: r""" Callable object representing some optimality conditions May track two different expression of the error: * the relative error: :math:`\|u_n\| < tol \|u_0\|` * the absolute error: :math:`\|u_n\| < tol` Those equations are evaluated for three possible metrics :math:`u_n`: * the progress in the gradient of the dual loss with respect to the parameter of interest :math:`\nabla_{\theta ,\lambda} J_{\theta_n}(\zeta_n)` * the progress of the parameters themselves :math:`(\theta_n-\theta_{n-1} , \lambda_n-\lambda_{n-1})` To evaluate the above metrics, one may chose to monitor the convergence in: * only :math:`\theta` * only :math:`\lambda` * both * or either Parameters ---------- order: int|str norm type to use tol_theta: float if positive, the tolerance (relative or absolute) to allow for the parameters, if <=0 ignores it tol_lambda: float if positive, the tolerance (relative or absolute) to allow for the dual parameter, if <=0 ignores it monitoring: str see the global variables :py:data:`L_OR_T` (for either convergence to allow stop), :py:data:`L_AND_T` (for joint convergence to allow stop), :py:data:`JUST_L` (for only :math:`\lambda`), :py:data:`JUST_T` (for only :math:`\theta`) to have the allowed options mode: str either ``"rel"`` for relative progress or ``"abs"`` for absolute progress. 
Not checked if the metric is the gradient value metric: either ``"grad"`` for gradient improvement/change over time, or ``"param"`` for parameter-space improvement/change over time """ def __init__( self, order: Union[int, str], tol_theta: float = 1e-8, tol_lambda: float = 1e-8, *, monitoring: str = "theta", mode: str = "rel", metric: str = "grad", verbose: bool = False ): """ """ if isinstance(order, str): assert order == 'inf' self.order: float = float(order) assert self.order > 0., ' '.join([ "Please provide a UINT", "order for the parameters", "grad norm, or the 'inf' string" ]) self.tol_theta = tol_theta self.tol_lambda = tol_lambda self.monitoring = monitoring self.mode = mode self.metric = metric self.max_iter: int = 0 self.verbose = verbose self.l_mem: Optional[pt.Tensor] = None self.t_mem: Optional[pt.Tensor] = None self.l_0: Optional[pt.Tensor] = None self.t_0: Optional[pt.Tensor] = None self.l_grad_0: Optional[pt.Tensor] = None self.t_grad_0: Optional[pt.Tensor] = None self.delta_l_1: Optional[pt.Tensor] = None self.delta_t_1: Optional[pt.Tensor] = None def __call__(self, dual_loss: BaseDualLoss, it_number: int) -> bool: r""" This object can be called on a dual loss object to check its current convergence with respect to its allowed max number of iterations. 
Parameters ---------- dual_loss: BaseDualLoss the loss to monitor it_number: int the looping iteration Returns ------- cond: bool green light to stop algorithm """ self.max_iter = ( dual_loss.n_iter if isinstance(dual_loss.n_iter, int) else dual_loss.n_iter[1] ) def flattheta() -> pt.Tensor: return self.get_flat_param(dual_loss.primal_loss) def flatgrad() -> pt.Tensor: return self.get_flat_grad(dual_loss.primal_loss) def lam() -> pt.Tensor: return dual_loss.lam def lamgrad() -> pt.Tensor: return maybe_flatten_grad_else_raise(dual_loss._lam) ci = self.check_iter(it_number) cp, err = self.check_all_params(lam, lamgrad, flattheta, flatgrad) if self.verbose and it_number % 100 == 0: print(f"[{it_number = }] {ci = } {cp = } {err = }") return ci or cp
[docs] def check_all_params( self, lam: LazyTensor, lamgrad: LazyTensor, flattheta: LazyTensor, flatgrad: LazyTensor) -> ValidAndError: r""" Checks the dual and primal parameters for convergence by using functional monads on the tensors, see :py:func:`~OptCondTorch.check_t` and :py:func:`~OptCondTorch.check_l`. Parameter --------- lam: LazyTensor the dual multiplier lam_grad: LazyTensor its scalar gradient flat_theta: LazyTensor the flattened concatenation of all the optimizeable parameters of the primal model flat_theta_grad: LazyTensor the flattened concatenation of the gradients of those parameters Returns ------- cond: bool green light to stop algorithm """ if self.monitoring in L_AND_T: return combine_intersect( self.check_l(lam, lamgrad), self.check_t(flattheta, flatgrad) ) elif self.monitoring in L_OR_T: return combine_union( self.check_l(lam, lamgrad), self.check_t(flattheta, flatgrad) ) elif self.monitoring in JUST_L: return self.check_l(lam, lamgrad) elif self.monitoring in JUST_T: return self.check_t(flattheta, flatgrad) else: raise ValueError("Please provide a valid value for the monitoring")
[docs] def check_t( self, flat_theta: LazyTensor, flat_theta_grad: LazyTensor ) -> ValidAndError: r""" Check the convergence of the theta parameter, either in gradient or in parameter value. The parameters are ``LazyTensor``s which means that they must be called as functions to be evaluated. Parameter --------- flat_theta: LazyTensor the flattened concatenation of all the optimizeable parameters of the primal model flat_theta_grad: LazyTensor the flattened concatenation of the gradients of those parameters Returns ------- cond: bool green light to stop algorithm """ cond = False if self.tol_theta <= 0.: return cond, 0. else: if self.metric == "grad": mem = self.t_grad_0 # nabla_theta first iteration if mem is None: # Compute nabla_theta because we are at nabla theta # right now. # Wait for next iteration to verify convergence. self.t_grad_0 = pt.linalg.norm( flat_theta_grad(), self.order) return wrap(False) else: # Compute nabla_theta at current iteration and compare it # to first iterate new = pt.linalg.norm(flat_theta_grad(), self.order) return self.check_metric(new, mem, self.tol_theta) elif self.metric == "param": mem0 = self.t_0 # first param vector theta_0 mem1 = self.delta_t_1 # |theta_1 - theta_0| if mem0 is None: # Define theta_0 self.t_0 = flat_theta() return wrap(False) elif mem1 is None: assert mem0 is not None # Define theta_1 (=theta_k) self.t_mem = flat_theta() # |theta_1 - theta_0| self.delta_t_1 = pt.linalg.norm( self.t_mem - mem0, self.order) return wrap(False) else: # theta_k mem = self.t_mem assert mem is not None # theta_{k+1} new = flat_theta() # make memory advance one step of k self.t_mem = new # |theta_{k+1} - theta_k| delta = pt.linalg.norm(new - mem, self.order) assert self.delta_t_1 is not None # Check current diff wrt first iterate diff return self.check_metric( delta, self.delta_t_1, self.tol_theta ) else: return wrap(False)
[docs] def check_l(self, lam: LazyTensor, lam_grad: LazyTensor) -> ValidAndError: r""" Check the convergence of the theta parameter, either in gradient or in parameter value. The parameters are ``LazyTensor``s which means that they must be called as functions to be evaluated Parameter --------- lam: LazyTensor the dual multiplier lam_grad: LazyTensor its scalar gradient Returns ------- cond: bool green light to stop algorithm """ if self.tol_lambda <= 0.: return wrap(False) else: if self.metric == "grad": mem = self.l_grad_0 if mem is None: # nabla_lambda at first iterate self.l_grad_0 = pt.abs(lam_grad()) return wrap(False) else: # New nabla_lambda new = pt.abs(lam_grad()) return self.check_metric(new, mem, self.tol_theta) elif self.metric == "param": mem0 = self.l_0 mem1 = self.delta_l_1 if mem0 is None: # first lambda self.l_0 = lam() return wrap(False) elif mem1 is None: assert mem0 is not None # lambda_1 (=lambda_k) self.l_mem = lam() # first diff in lambda self.delta_l_1 = pt.abs(self.l_mem - mem0) return wrap(False) else: # lambda_k mem = self.l_mem assert mem is not None # lambda_{k+1} new = lam() # Memory advances one step self.l_mem = new # |lambda_{k+1} - lambda_k| delta = pt.abs(new - mem) assert self.delta_l_1 is not None return self.check_metric( delta, self.delta_l_1, self.tol_lambda ) else: return wrap(False)
[docs] def check_metric( self, new_obs: pt.Tensor, memory: pt.Tensor, tol: float ) -> ValidAndError: r""" Helper function to get the tolerance check in both the relative and absolute error cases. Parameters ---------- new_obs: pt.Tensor current step metric memory: pt.Tensor same metric at last step -- initialized at None, so a check must be performed before call to this function tol: float the positive tolerance rate allowed (same for absolute and relative tolerance) Returns ------- cond: bool green light to stop algorithm """ assert tol > 0. if self.mode == "rel": return ( new_obs.sum().item() < tol * memory.sum().item(), new_obs.sum().item() ) elif self.mode == "abs": return new_obs.sum().item() < tol, new_obs.sum().item() else: raise ValueError(' '.join([ "Please set the optcond mode to either 'rel'", "for relative tolerance or 'abs'" ]))
[docs] def check_iter(self, it_number: int) -> bool: r""" Checks if the maximum number of iterations has been crossed Returns ------- cond: bool green light to stop algorithm """ return False if self.max_iter <= 0 else it_number >= self.max_iter
[docs] @classmethod def get_flat_param(cls, module: pt.nn.Module) -> pt.Tensor: """ Helper function to get a flat vector containing all the primal parameters. """ return pt.concat([*map( lambda t: t.flatten(), module.parameters() )])
[docs] @classmethod def get_flat_grad(cls, module: pt.nn.Module) -> pt.Tensor: """ Helper function to get a flat vector containing all the gradients of the primal model. """ try: return pt.concat([*map( maybe_flatten_grad_else_raise, module.parameters() )]) except NoneGradError as e: raise ValueError(' '.join([ "The module provided as the primal loss", "yields None grads for some of its", "parameters, preventing the solver from computing", "optimality conditions.", f"\nShape of original tensor: {e.args[0]}", "Please investigate.", ]))