from typing import Callable, Dict, Tuple, Optional, overload
from itertools import chain
import torch as pt
from prodigyopt import Prodigy
from mechanic_pytorch import mechanize
from skwdro.base.costs_torch import Cost
from skwdro.base.losses_torch import Loss
from skwdro.solvers._dual_interfaces import _DualLoss
from skwdro.solvers.utils import Steps, interpret_steps_struct
#: Default value of the ``imp_samp`` keyword accepted by the dual losses.
IMP_SAMP: bool = True
class CompositeOptimizer(pt.optim.Optimizer):
    r"""
    Optimizer stepping the primal parameters and the dual multiplier
    :math:`\lambda` with two separate sub-optimizers of the same kind.

    Parameters
    ----------
    params : iterable of pt.Tensor
        primal parameters; any iterable is accepted, including a one-shot
        generator such as ``module.parameters()``
    lbd : pt.Tensor
        dual multiplier, optimized on its own and clamped to be
        non-negative after every :meth:`step`
    n_iter : Steps
        step budget; only read when ``optimizer == 'prodigy'``, where
        ``interpret_steps_struct`` splits it into pretraining and training
        iterations to size the cosine-annealing schedules
    optimizer : str
        ``'mechanic'`` (``mechanize``-wrapped Adam) or ``'prodigy'``

    Raises
    ------
    NotImplementedError
        if ``optimizer`` is not one of the supported names
    """

    def __init__(self, params, lbd, n_iter, optimizer):
        # Materialize once: a generator would be exhausted by the first
        # sub-optimizer, leaving only ``lbd`` in the base-class param groups.
        params = list(params)
        self.lbd = lbd

        def make_optim(group):
            # Both sub-optimizers use lr=1.0 and no weight decay.
            if optimizer == 'mechanic':
                return mechanize(
                    pt.optim.Adam
                )(group, lr=1.0, weight_decay=0.)
            elif optimizer == 'prodigy':
                return Prodigy(
                    group,
                    lr=1.0,
                    weight_decay=0,
                    safeguard_warmup=True,
                    use_bias_correction=True
                )
            else:
                raise NotImplementedError(
                    "No composite optimizer by that name"
                )

        self.opts = {
            'params': make_optim(params),
            'lbd': make_optim([lbd])
        }
        if optimizer == 'prodigy':
            pretrain_iters, train_iters = interpret_steps_struct(n_iter)
            # The primal schedule spans pretraining + training, the lbd
            # schedule only the training iterations.
            T = {'params': pretrain_iters + train_iters, 'lbd': train_iters}
            self.schedulers = {
                k: pt.optim.lr_scheduler.CosineAnnealingLR(
                    opt, T_max=T[k]
                ) for (k, opt) in self.opts.items()
            }
        else:
            self.schedulers = {}
        # Snapshot of the freshly built lbd optimizer (taken after the
        # schedulers are attached), restored by ``reset_lbd_state``.
        self.init_state_lbd = self.opts['lbd'].state_dict()
        super(CompositeOptimizer, self).__init__(chain(params, [lbd]), {})

    def __getstate__(self) -> Dict[str, object]:
        """
        Return a dict holding each sub-optimizer's own ``__getstate__()``
        (keys ``'params'`` and ``'lbd'``), the initial lbd optimizer state
        and empty defaults.
        """
        # NOTE(review): ``opts``, ``schedulers`` and the ``lbd`` tensor are
        # not included (the ``'lbd'`` key holds optimizer state instead), so
        # restoring an instance from this dict presumably does not give a
        # working optimizer -- confirm before relying on pickle/deepcopy.
        s = {key: val.__getstate__() for key, val in self.opts.items()}
        s['init_state_lbd'] = self.init_state_lbd
        s["defaults"] = {}
        return s

    @overload
    def step(self, closure: None = None) -> None:
        ...

    @overload
    def step(self, closure: Callable) -> float:
        ...

    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """
        Step both sub-optimizers, then their schedulers (if any), then clamp
        ``lbd`` to ``[0, +inf)`` in place.

        ``closure`` is not supported: it is ignored, and the method always
        returns ``None``.
        """
        del closure
        for opt in self.opts.values():
            opt.step()
        for scheduler in self.schedulers.values():
            scheduler.step()
        # Keep the dual multiplier feasible (non-negative) without
        # recording the projection in the autograd graph.
        with pt.no_grad():
            self.lbd.clamp_(0., None)
        return None

    def zero_grad(self, *args, **kwargs):
        """
        Reset the gradients of both sub-optimizers.

        Any arguments (e.g. ``set_to_none``) are forwarded unchanged to each
        sub-optimizer's ``zero_grad``.
        """
        for opt in self.opts.values():
            opt.zero_grad(*args, **kwargs)

    def state_dict(self):
        """Return ``{'params': ..., 'lbd': ...}`` sub-optimizer state dicts."""
        return {k: opt.state_dict() for (k, opt) in self.opts.items()}

    def load_state_dict(self, state_dict):
        """Load a dict produced by :meth:`state_dict` into both sub-optimizers."""
        for (k, opt) in self.opts.items():
            opt.load_state_dict(state_dict[k])

    def reset_lbd_state(self):
        """Restore the lbd sub-optimizer to its state at construction time."""
        self.opts['lbd'].load_state_dict(self.init_state_lbd)
class DualPostSampledLoss(_DualLoss):
    r"""
    Dual loss implementing a sampling of the :math:`\zeta` vectors at
    each forward pass.

    Parameters
    ----------
    loss : Loss
        the loss of interest :math:`L_\theta`
    cost : Cost
        ground-distance function
    n_samples : int
        number of :math:`\zeta` samples to draw at each forward pass
    epsilon_0 : pt.Tensor
        forwarded to ``_DualLoss`` (presumably the initial regularization
        :math:`\varepsilon`)
    rho_0 : pt.Tensor
        forwarded to ``_DualLoss`` (presumably the initial radius
        :math:`\rho`)
    n_iter : Steps
        step budget, forwarded to ``_DualLoss`` and to the composite
        optimizer (which uses it to size its schedules for ``"prodigy"``)
    gradient_hypertuning : bool
        forwarded to ``_DualLoss``
    imp_samp : bool
        keyword-only, forwarded to ``_DualLoss``; defaults to ``IMP_SAMP``
    adapt : Optional[str]
        keyword-only; ``"mechanic"`` or ``"prodigy"`` selects a
        ``CompositeOptimizer`` over the primal parameters and ``lam``, while
        any falsy value selects AdamW (``lr=5e-2``, AMSGrad) over all the
        module parameters
    """

    def __init__(
        self,
        loss: Loss,
        cost: Cost,
        n_samples: int,
        epsilon_0: pt.Tensor,
        rho_0: pt.Tensor,
        n_iter: Steps = 10000,
        gradient_hypertuning: bool = False,
        *,
        imp_samp: bool = IMP_SAMP,
        adapt: Optional[str] = "prodigy",
    ) -> None:
        super(DualPostSampledLoss, self).__init__(
            loss,
            cost,
            n_samples,
            epsilon_0,
            rho_0,
            n_iter,
            gradient_hypertuning,
            imp_samp=imp_samp
        )
        if adapt:
            assert adapt in ("mechanic", "prodigy"), \
                f"Unknown adaptive optimizer: {adapt!r}"
            # ``parameters()`` is a one-shot generator; hand over a list so
            # every consumer inside the optimizer sees the primal parameters.
            self._opti = CompositeOptimizer(
                list(self.primal_loss.parameters()),
                self.lam,
                n_iter,
                adapt
            )
        else:
            self._opti = pt.optim.AdamW(
                self.parameters(),
                lr=5e-2,
                betas=(.99, .999),
                weight_decay=0.,
                amsgrad=True,
                foreach=True
            )

    def reset_sampler_mean(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None
    ):
        r""" Prepare the sampler for a new batch of :math:`\xi` data.

        Parameters
        ----------
        xi : pt.Tensor
            new data batch
        xi_labels : Optional[pt.Tensor]
            new labels batch
        """
        self.primal_loss.sampler.reset_mean(xi, xi_labels)

    @overload
    def forward(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None,
        zeta: None = None,
        zeta_labels: None = None,
        reset_sampler: bool = False
    ) -> pt.Tensor:
        ...

    @overload
    def forward(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        zeta: pt.Tensor,
        zeta_labels: Optional[pt.Tensor] = None,
        reset_sampler: bool = False
    ) -> pt.Tensor:
        ...

    def forward(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None,
        zeta: Optional[pt.Tensor] = None,
        zeta_labels: Optional[pt.Tensor] = None,
        reset_sampler: bool = False
    ) -> Optional[pt.Tensor]:
        r"""
        Forward pass for the dual loss, with the sampling of the
        adversarial samples.

        Parameters
        ----------
        xi : pt.Tensor
            data batch
        xi_labels : Optional[pt.Tensor]
            labels batch
        zeta : None
            ignored: this class always draws fresh samples
        zeta_labels : None
            ignored
        reset_sampler : bool
            defaults to ``False``, if set resets the batch saved in the sampler

        Returns
        -------
        dl : pt.Tensor
            with ``rho == 0`` this is ``rho * lam`` plus the mean primal loss
            on ``xi``; otherwise the value of ``compute_dual`` on
            ``n_samples`` freshly generated :math:`\zeta`

        Raises
        ------
        ValueError
            if ``rho`` is negative

        Shapes
        ------
        xi : (m, d)
        xi_labels : (m, d')
        dl : (1,)
        """
        del zeta, zeta_labels
        if reset_sampler:
            self.reset_sampler_mean(xi, xi_labels)
        if self.rho < 0.:
            raise ValueError(' '.join([
                "Rho < 0 detected: ->",
                str(self.rho.item()),
                ", please provide a positive rho value"
            ]))
        elif self.rho == 0.:
            # rho * lam is zero here but keeps lam in the returned expression.
            first_term = self.rho * self.lam
            _pl: pt.Tensor = self.primal_loss(
                xi.unsqueeze(0),  # (1, m, d)
                # (1, m, d') or None
                xi_labels.unsqueeze(0) if xi_labels is not None else None
            ).mean()  # (1,)
            return first_term + _pl
        else:
            zeta_, zeta_labels_ = self.generate_zetas(self.n_samples)
            return self.compute_dual(xi, xi_labels, zeta_, zeta_labels_)

    def __str__(self):
        header = "Dual loss (sample IN for loop)\n" + 10 * "-" + "\n"
        return header + "\n".join(map(str, self.parameters()))

    @property
    def presample(self):
        """``False``: samples are drawn inside each forward pass."""
        return False
class DualPreSampledLoss(_DualLoss):
    r""" Dual loss implementing a forward pass without resampling the
    :math:`\zeta` vectors.

    Parameters
    ----------
    loss : Loss
        the loss of interest :math:`L_\theta`
    cost : Cost
        ground-distance function
    n_samples : int
        number of :math:`\zeta` samples to draw before the gradient
        descent begins (can be changed if needed between inferences).
    epsilon_0 : pt.Tensor
        forwarded to ``_DualLoss`` (presumably the initial regularization
        :math:`\varepsilon`)
    rho_0 : pt.Tensor
        forwarded to ``_DualLoss`` (presumably the initial radius
        :math:`\rho`)
    n_iter : Steps
        step budget forwarded to ``_DualLoss``
    gradient_hypertuning : bool
        forwarded to ``_DualLoss``
    imp_samp : bool
        keyword-only, forwarded to ``_DualLoss``; defaults to ``IMP_SAMP``
    adapt : Optional[str]
        ignored; accepted so that both dual losses share the same
        constructor signature. An L-BFGS optimizer is always used.
    """
    # Samples registered by the last forward call that received them.
    zeta: Optional[pt.Tensor]
    zeta_labels: Optional[pt.Tensor]

    def __init__(
        self,
        loss: Loss,
        cost: Cost,
        n_samples: int,
        epsilon_0: pt.Tensor,
        rho_0: pt.Tensor,
        n_iter: Steps = 50,
        gradient_hypertuning: bool = False,
        *,
        imp_samp: bool = IMP_SAMP,
        adapt: Optional[str] = "prodigy",
    ) -> None:
        del adapt
        super(DualPreSampledLoss, self).__init__(
            loss,
            cost,
            n_samples,
            epsilon_0,
            rho_0,
            n_iter,
            gradient_hypertuning,
            imp_samp=imp_samp
        )
        # One L-BFGS iteration per ``step(closure)`` call.
        self._opti = pt.optim.LBFGS(
            self.parameters(),
            lr=1.,
            max_iter=1,
            max_eval=10,
            tolerance_grad=1e-4,
            tolerance_change=1e-6,
            history_size=30
        )
        self.zeta = None
        self.zeta_labels = None

    @overload
    def forward(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None,
        zeta: None = None,
        zeta_labels: None = None,
        reset_sampler: bool = False
    ) -> pt.Tensor:
        ...

    @overload
    def forward(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
        zeta: pt.Tensor,
        zeta_labels: Optional[pt.Tensor] = None,
        reset_sampler: bool = False
    ):
        ...

    def forward(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor] = None,
        zeta: Optional[pt.Tensor] = None,
        zeta_labels: Optional[pt.Tensor] = None,
        reset_sampler: bool = False
    ) -> pt.Tensor:
        r""" Forward pass for the dual loss, wrt the already sampled
        :math:`\zeta` values.

        Parameters
        ----------
        xi : pt.Tensor
            data batch
        xi_labels : Optional[pt.Tensor]
            labels batch
        zeta : Optional[pt.Tensor]
            adversarial samples :math:`\zeta`; when given they are stored
            (together with ``zeta_labels``) and reused by later calls that
            omit them
        zeta_labels : Optional[pt.Tensor]
            labels of the adversarial samples; ignored when ``zeta`` is None
        reset_sampler : bool
            ignored

        Returns
        -------
        dl : pt.Tensor
            value of ``compute_dual`` on the given or stored samples

        Raises
        ------
        ValueError
            if ``zeta`` is None and no samples were stored by a previous call

        Shapes
        ------
        xi : (m, d)
        xi_labels : (m, d')
        dl : (1,)
        """
        del reset_sampler
        if zeta is None:
            if self.zeta is None:
                # No previously registered samples, fail
                raise ValueError(' '.join([
                    "Please provide a zeta value for the forward pass of",
                    "DualPreSampledLoss, else switch to",
                    "an instance of DualPostSampledLoss."
                ]))
            else:
                # Reuse the same samples as last forward pass
                return self.compute_dual(
                    xi,
                    xi_labels,
                    self.zeta,
                    self.zeta_labels
                )
        else:
            self.zeta = zeta
            self.zeta_labels = zeta_labels
            return self.compute_dual(xi, xi_labels, zeta, zeta_labels)

    def __str__(self):
        header = "Dual loss (sample BEFORE for loop)\n" + 10 * "-" + "\n"
        return header + "\n".join(map(str, self.parameters()))

    @property
    def presample(self):
        """``True``: samples are provided to (or stored by) forward."""
        return True

    @property
    def current_samples(
        self
    ) -> Tuple[
        Optional[pt.Tensor],
        Optional[pt.Tensor]
    ]:
        """The stored ``(zeta, zeta_labels)`` pair (``None`` until set)."""
        return self.zeta, self.zeta_labels
#: ``DualLoss`` is an alias for the "post sampled" loss
#: (:class:`DualPostSampledLoss`), which resamples the :math:`\zeta`
#: vectors at every forward pass.
DualLoss = DualPostSampledLoss