PyTorch interface

Practice

In a general machine-learning setting, many practitioners turn to deep-learning techniques, which require very specific tools catering to the “big-data” setting with massively parallel operations. This calls for other computational architectures (e.g. GPUs, TPUs, etc.) and adapted codebases. Three Python libraries dominate the deep-learning community: PyTorch, Keras/TensorFlow, and JAX. They offer state-of-the-art performance and many utilities to build and manipulate both models and training data.

In this tutorial, we will see how to use the PyTorch interfaces of skwdro to robustify a simple model. We aim to transform any PyTorch-parametrized loss function \(L_\theta\) into its robust counterpart:

\[L_\theta^\texttt{robust}(\xi) := \lambda\rho + \varepsilon\log\mathbb{E}_{\zeta\sim\nu_\xi}\left[e^{\frac{L_\theta(\zeta)-\lambda c(\xi, \zeta)}{\varepsilon}}\right] \tag{1}\]
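To make (1) concrete, here is a minimal sketch in plain PyTorch. It is not SkWDRO's implementation. It estimates the robust loss for a single \(\xi\) by averaging over n samples \(\zeta\) drawn from \(\nu_\xi\), under two assumptions: \(\nu_\xi = \mathcal{N}(\xi, \sigma^2 I)\) and \(c(\xi, \zeta) = \|\xi-\zeta\|_2^2\). The helper name robust_loss_mc and the values of \(\lambda\), \(\rho\), \(\varepsilon\), and \(\sigma\) are illustrative only. The "log-avg-exp" is computed with torch.logsumexp for numerical stability.

```python
import math

import torch


def robust_loss_mc(loss_fn, xi, lam, rho, eps, sigma=0.1, n_samples=10, generator=None):
    """Hypothetical Monte Carlo estimate of (1) for one point xi.

    Assumes zeta ~ N(xi, sigma^2 I) and c(xi, zeta) = ||xi - zeta||_2^2.
    """
    # Draw n_samples perturbations zeta ~ N(xi, sigma^2 I), shape (n_samples, d)
    noise = torch.randn(n_samples, xi.shape[-1], generator=generator)
    zeta = xi.unsqueeze(0) + sigma * noise
    # Ground cost between xi and each zeta, shape (n_samples,)
    cost = ((zeta - xi.unsqueeze(0)) ** 2).sum(dim=-1)
    # Exponent of the "log-avg-exp", shape (n_samples,)
    exponent = (loss_fn(zeta) - lam * cost) / eps
    # log(mean(exp(a))) = logsumexp(a) - log(n), computed stably
    log_avg_exp = torch.logsumexp(exponent, dim=0) - math.log(n_samples)
    return lam * rho + eps * log_avg_exp


theta = torch.tensor([1.0, -2.0])


def quadratic_loss(zeta):
    # L_theta(zeta) = <theta, zeta>^2, one value per row of zeta
    return (zeta @ theta) ** 2


xi = torch.tensor([0.5, 0.5])
gen = torch.Generator().manual_seed(0)
print(robust_loss_mc(quadratic_loss, xi, lam=10.0, rho=0.01, eps=1e-2, generator=gen))
# With sigma = 0 every zeta equals xi, so (1) reduces to lam * rho + L(xi)
print(robust_loss_mc(quadratic_loss, xi, lam=10.0, rho=0.01, eps=1e-2, sigma=0.0))  # ≈ 0.35
```

The \(\sigma = 0\) call is a handy sanity check. With no perturbation, the cost term vanishes and the log-avg-exp of identical values returns that value.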

Model presentation

Start from a simple network with two hidden layers that classifies 2-dimensional samples:

import torch as pt
import torch.nn as nn

# Specify the model
class SimpleNN(nn.Module):
    def __init__(self, hidden_units):
        super().__init__()
        # Two hidden layers and a logit output
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(2, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
        )

    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return logits

# Instantiate it
model = SimpleNN(32)

Usually, one would train it with a simple training procedure, roughly as follows:

Training procedure: default ERM.
for sample, target in my_dataloader:
    # Clean the kitchen
    my_optimizer.zero_grad()

    # Forward pass
    inference = model(sample)
    loss = my_loss_function(inference, target)

    # Backward pass
    loss.backward()
    my_optimizer.step()

    # Testing
    if my_condition():
        with pt.no_grad():
            model.eval()
            print(my_loss_function(model(test_sample), test_target))
            model.train()

Thanks to PyTorch's ease of use, this is very simple. We now want to see how to use the SkWDRO interface to robustify this procedure.

SkWDRO’s interface for robustification

The main idea of the interface comes from the fact that machine-learning models can be split into two kinds:

  • Some link an input to a target, i.e. they learn a parametrized function \(f_\theta\) such that, on average over the available data, \(y\approx f_\theta(x)\). For the sake of clarity, we distinguish two subcases:
    • Classification, where y is categorical.

    • Regression, where y can be anything.

  • Others assign a cost to the available samples, which is to be minimized (or a reward, likelihood, etc., to be maximized, which is equivalent without loss of generality). In other words, they learn \(c_\theta\) such that \(c_\theta\) is small on average. We call this the estimation case.

This means that the loss functions \(L_\theta\) always take one of the following forms, for some function \(\ell\):

\[L_\theta(\xi):=\begin{cases} \ell(y-f_\theta(x)) & \text{in the regression case}\\ \ell(-y\,f_\theta(x)) & \text{in the classification case}\\ \ell(c_\theta(x)) & \text{in the estimation case.} \end{cases} \tag{2}\]
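As an illustration of (2), the following sketch builds one loss of each kind from plain PyTorch pieces. The choices of \(\ell\) are assumptions made for the example:

- For regression, \(\ell(u) = u^2\).
- For classification, \(\ell\) is softplus. Then \(\ell(-y f_\theta(x)) = \log(1 + e^{-y f_\theta(x)})\), which is the logistic loss for labels \(y \in \{-1, +1\}\).
- For estimation, \(\ell\) is the identity, applied to the hypothetical cost \(c_\theta(x) = \|x - \theta\|_2^2\).

```python
import torch
import torch.nn.functional as F


def regression_loss(f, x, y):
    # l(u) = u^2 applied to the residual y - f(x)
    return (y - f(x)) ** 2


def classification_loss(f, x, y):
    # l = softplus gives log(1 + exp(-y f(x))), the logistic loss for y in {-1, +1}
    return F.softplus(-y * f(x))


def estimation_loss(theta, x):
    # l = identity applied to the cost c_theta(x) = ||x - theta||_2^2
    return ((x - theta) ** 2).sum(dim=-1)


# Toy batch of 2-dimensional inputs and a linear score f(x) = <w, x>
w = torch.tensor([1.0, -1.0])


def f(x):
    return x @ w


x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
print(regression_loss(f, x, torch.tensor([1.0, 1.0])))      # tensor([0., 4.])
print(classification_loss(f, x, torch.tensor([1.0, -1.0])))  # both entries equal log(1 + e^-1)
print(estimation_loss(torch.zeros(2), x))                    # tensor([1., 1.])
```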

With that distinction in mind, SkWDRO assumes that the models it treats have the structure of \(L_\theta\) in its most general form: a function that, for each input \(\xi\in\Xi\) (possibly batched), outputs a scalar (batched accordingly). This is general enough to cover the three cases above. Without further assumptions, however, it leaves a choice to make at the level of the Python function call. Should we pass a single tuple, \(\xi=(x, y)\) for regression/classification and the 1-tuple \(\xi=(x,)\) otherwise, or is it better to assume some structure on \(\xi\) in the function signature itself? Most PyTorch models, and especially the native \(\ell\) functionals, are specified the second way (see the torch documentation on the topic, e.g. torch.nn.functional.binary_cross_entropy()). This second option is therefore the more idiomatic one, and SkWDRO assumes this input/target structure.

Assumption: Target model structure.

We aim for a PyTorch model specified as a torch.nn.Module subclass that holds the parameters \(\theta\) as attributes. It should override the forward() method with the following signature, where Model denotes the type of the module (the self instance in Python):

Forward call specification for SkWDRO models.
forward :: Model -> x -> Maybe y -> Float

The inputs (x and y) should be batchable, in which case the output is batched too. The Maybe on the y variable should wrap whole batches: no array of Maybes, use a Maybe [y] instead. In Python, this means one should be able to call:

model.forward(inputs, targets)

(whether or not inputs contains batch dimensions) in the first two cases of (2), and

model.forward(inputs, None)

in the third one.
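Here is a minimal sketch of two modules that follow this specification. The classes RegressionLoss and EstimationLoss are hypothetical examples, not part of SkWDRO. They assume a squared-error regression model and a quadratic estimation cost. Each forward returns one loss value per sample, and the estimation model accepts None as its target.

```python
from typing import Optional

import torch
import torch.nn as nn


class RegressionLoss(nn.Module):
    """Hypothetical L_theta(x, y) = (y - <theta, x>)^2, one value per sample."""

    def __init__(self, d: int):
        super().__init__()
        self.theta = nn.Parameter(torch.zeros(d))

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
        if y is None:
            raise ValueError("a supervised model needs targets")
        return (y - x @ self.theta) ** 2


class EstimationLoss(nn.Module):
    """Hypothetical L_theta(x) = ||x - theta||_2^2; the target must be None."""

    def __init__(self, d: int):
        super().__init__()
        self.theta = nn.Parameter(torch.zeros(d))

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        if y is not None:
            raise ValueError("an estimation model takes no target")
        return ((x - self.theta) ** 2).sum(dim=-1)


x = torch.randn(8, 3)                              # a batch of 8 inputs
print(RegressionLoss(3)(x, torch.randn(8)).shape)  # torch.Size([8])
print(EstimationLoss(3)(x, None).shape)            # torch.Size([8])
print(EstimationLoss(3)(torch.ones(3), None))      # unbatched input: scalar tensor(3., ...)
```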

Getting back to the training procedure described above, we see that it does not match this structure at all: the model and the loss function are separate objects. The SkWDRO team therefore provides meaningful yet easy-to-use interfaces to guide you through the transformation process.

  • The skwdro.torch.robustify() function is the main interface and the easiest to use. Read the explanations below to learn how to use it first. Keep the other interfaces for situations that this function does not cover.

  • skwdro.solvers.DualLoss and its variants are meant for people who already have a model complying with the assumptions above. They give much finer control over each building block of the robust loss.

robustify: the simplest method

Here comes the documentation of the main dish, the robustify() function.

skwdro.torch.robustify(loss_: Module | Callable[[...], Tensor], transform_: Module | None, rho: Tensor, xi_batchinit: Tensor, xi_labels_batchinit: Tensor | None, post_sample: bool = True, cost_spec: str | None = None, n_samples: int = 10, seed: int = 42, *, reduction: str | None = None, learning_rate: float | None = None, epsilon: float | None = None, sigma: float | None = None, l2reg: float | None = None, adapt: str | None = 'prodigy', n_iter: int | Tuple[int, int] | None = None, imp_samp: bool = True, loss_reduces_spatial_dims: bool = False) -> _DualFormulation

Provide the wrapped version of the primal loss.

Parameters:
loss_: nn.Module|Callable

the primal loss \(L_\theta\). Can be given either as a torch.nn.Module or as a (functional) callable.

transform_: nn.Module|None

the transformation to apply to the (non-label) data before feeding it to the loss. Identity if set to None (default).

rho: Tensor, scalar tensor

Wasserstein radius

xi_batchinit: Tensor, shape (n_samples, n_features)

Data points to initialize the samplers and \(\lambda_0\)

xi_labels_batchinit: Optional[Tensor], shape (n_samples, n_features)

Labels to initialize the samplers and \(\lambda_0\)

post_sample: bool

whether to use a post-sampled dual loss

cost_spec: str|None

the cost specification, encoding a pair (k, p) for the k-norm raised to the p-th power. None to use the default (2, 2).

n_samples: int

number of \(\zeta\) samples to draw before the gradient descent begins (can be changed if needed between inferences)

seed: int

the seed for the samplers

reduction: str | None

specifies the reduction to apply to the outer expectation of the SkWDRO formula: 'none' | 'mean' | 'sum'. With 'none', no reduction is applied. With 'mean', the sum of the output is divided by the number of elements in the output. With 'sum', the output is summed. Default: None, which translates to 'mean'.

learning_rate: float

the step size for the default descent algorithm linked to the loss function

epsilon: float|None

the \(\varepsilon\) regularization parameter if hard-coded; None to let the algorithm find it.

sigma: float|None

the \(\sigma\) noise level of the sampler if hard-coded; None to let the algorithm find it.

l2reg: float|None

L2 regularization if needed

adapt: str|None

the adaptive step-size strategy to use, either “prodigy” or “mechanic”.

n_iter: int|tuple[int, int]|None

sets the default number of iterations when the loss is used through the default solving routines; mostly an internal parameter. If an int, it is the number of internal robust optimization steps. If a 2-tuple of ints, it is the number of ERM steps preceding the robust solve, then the number of robust steps. If None, a default value is used.

imp_samp: bool

whether to use importance sampling (will work only for (2, 2) costs).

loss_reduces_spatial_dims: bool

flag to set to True if the primal loss, with its reduction set to 'none', reduces the last dimension of the loss batch. This happens e.g. with torch.nn.CrossEntropyLoss, which treats one dimension as the channel axis. Defaults to False.

This may seem daunting at first glance, so let’s dive in while focusing on the most important parts.

Diving into the robustify function

Here are the arguments and how to use them in the example training procedure:

  • The important ones:
    • loss_: this is the function that takes the output of your inference model \(f_\theta(x)\) and computes the mismatch to the target \(y\), whether in the sense of classification or regression.

      Warning

      The functional interface of PyTorch is available for this argument specifically, but it is recent and less tested, so use it at your own risk. If you believe that the interface is behind some bug you encounter, we encourage you to wrap the function in a Module instance, as follows:

      Translate the functional API to object-oriented
      import torch

      class MyOopLoss(torch.nn.Module):
          def forward(self, input, target, *args, **kwargs):
              # Here goes any reshaping necessary
              ...
              # Call the functional loss directly, keeping per-sample values
              return my_functional_loss(
                  input, target,
                  *args,
                  reduction='none',
                  **kwargs
              )
      
    • transform_ is the inference mechanism \(f_\theta\). In most cases, it contains all of the parameters \(\theta\). For example, it can be a linear model such as torch.nn.Linear or any kind of Module (neural nets, etc). No functional interface is available for this argument.

    • rho is the most important hyperparameter of the WDRO framework: it is the radius of uncertainty defining the (regularized) Wasserstein ambiguity set. If you have no prior idea of its value, you can use simple cross-validation to find a suitable one for your particular problem. You can also turn to tricks from the literature, e.g. [1].

    • xi_batchinit and xi_labels_batchinit should be a subset of the dataset (e.g. a first batch). They initialize the samplers and give the optimizers a good starting point \(\lambda_0\) for the dual parameter \(\lambda\ge 0\) of SkWDRO's magic formula (1).

    • epsilon is the regularization parameter determining how much we smooth the Wasserstein ambiguity region with the entropic regularization \(\varepsilon\mathcal{KL}(\pi\|\pi_0)\).

      One may use cross-validation to select a good value, but beware of numerical stability. Out-of-sample performance remains the main goal, yet very small values of epsilon may make the optimization process difficult.

    • sigma is the standard deviation of the (non-truncated) Gaussian distribution used as the “reference transport plan” \(\pi_0\). More precisely, \(\pi_0(\xi, \zeta):=\delta_\xi\otimes\mathcal{N}(\xi, \sigma^2I)\). Given a sample \(\xi\) (or a batch thereof), the “adversarial” samples in the “log-avg-exp” expression are therefore drawn from \(\mathcal{N}(\xi, \sigma^2I)\).

    • n_samples defines the number of \(\zeta\) samples drawn for each \(\xi\). The computational cost of a gradient step thus grows from \(\mathcal{O}(B\cdot d)\) for ERM, with batch size \(B\) and dimension \(d\), to \(\mathcal{O}(B\cdot d\cdot\texttt{n\_samples})\). This can be significant if you want precise gradients.

  • The not so important ones:
    • post_sample can be set to False to sample the adversarial \(\zeta\) samples only once, at the beginning of the optimization procedure. This suits full-batch optimization (i.e. GD, not SGD). It opens the door to more efficient algorithms such as torch.optim.LBFGS, at the expense of the statistical soundness of the estimate of the log-sum-exp expression.

    • cost_spec is a string specification defining the pair (k, p). The cost functional of the Wasserstein distance is then the p-th power distance \(\|\zeta-\xi\|_k^p\) in some \(\|\cdot\|_k\)-Banach space. See the sources of the skwdro.base.cost_decoder() function for more details.

    • imp_samp can be set to False to disable importance sampling on the inner expectation of the “log-avg-exp” expression when it would otherwise be enabled (when \(p=k=2\)).

    • adapt can be set to either "mechanic" or "prodigy" to set up an automatic learning-rate tuner, based on the Adam optimizer, for the built-in optimizer of the SkWDRO loss function. Otherwise, set it to None to get the regular AdamW implementation. In this case, learning_rate can be set to any positive floating-point number to specify the step size of AdamW.
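To see what sigma, n_samples, and cost_spec control, here is a plain-PyTorch sketch. It is not SkWDRO's internal code, and the helper names sample_reference_plan and kp_cost are hypothetical. The first helper draws n_samples adversarial samples \(\zeta \sim \mathcal{N}(\xi, \sigma^2 I)\) for each element of a batch. The second evaluates the cost \(\|\zeta-\xi\|_k^p\). The sample tensor has shape (n_samples, B, d), which is where the \(\mathcal{O}(B\cdot d\cdot\texttt{n\_samples})\) cost comes from.

```python
import torch


def sample_reference_plan(xi, sigma, n_samples, generator=None):
    """Hypothetical helper: zeta ~ N(xi, sigma^2 I) for each row of xi, shape (n_samples, B, d)."""
    noise = torch.randn((n_samples, *xi.shape), generator=generator)
    return xi.unsqueeze(0) + sigma * noise


def kp_cost(xi, zeta, k=2.0, p=2.0):
    """Hypothetical helper: c(xi, zeta) = ||zeta - xi||_k^p along the feature axis."""
    return torch.linalg.vector_norm(zeta - xi, ord=k, dim=-1) ** p


xi = torch.randn(4, 2)  # a batch of B=4 points in dimension d=2
zeta = sample_reference_plan(xi, sigma=0.1, n_samples=10)
print(zeta.shape)           # torch.Size([10, 4, 2])
print(kp_cost(xi, zeta).shape)  # torch.Size([10, 4]): one cost per (sample, point) pair
```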

This function performs two tasks at once: it merges your loss function with your inference model, if need be, and it builds the dual loss displayed in (1). The output is a skwdro.oracle_torch.DualLoss object/module that represents \(L_\theta^\texttt{robust}\) from (1).
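The "merge" step can be pictured with the following sketch. MergedLoss is a hypothetical stand-in, not SkWDRO's implementation. It composes an inference model \(f_\theta\) with an unreduced \(\ell\), so that the result follows the forward(x, y) specification above. We assume a binary classifier trained with torch.nn.BCEWithLogitsLoss(reduction='none') and labels in {0, 1}.

```python
import torch
import torch.nn as nn


class MergedLoss(nn.Module):
    """Hypothetical L_theta(x, y) = loss(transform(x), y), one value per sample."""

    def __init__(self, loss: nn.Module, transform: nn.Module):
        super().__init__()
        self.loss = loss            # must keep per-sample values (reduction='none')
        self.transform = transform  # holds the parameters theta

    def forward(self, x, y):
        return self.loss(self.transform(x), y)


model = nn.Linear(2, 1)  # produces logits of shape (B, 1)
merged = MergedLoss(nn.BCEWithLogitsLoss(reduction="none"), model)
x, y = torch.randn(5, 2), torch.randint(0, 2, (5, 1)).float()
per_sample = merged(x, y)
print(per_sample.shape)  # torch.Size([5, 1])
```

Because the model is registered as a submodule, merged.parameters() exposes \(\theta\) (here the weight and bias of the linear layer). The merged module can therefore be optimized as a single object.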

Before/after comparison for robustify

Training procedure: SkWDRO with robustify.
import torch as pt
from skwdro.torch import robustify

robust_model = robustify(
    my_loss_function,
    model,
    pt.tensor(0.01),  # the radius you picked for the Wasserstein ball
    *next(iter(my_dataloader)),  # warm-up batch: xi_batchinit, xi_labels_batchinit
    # Optionally set those keyword-only HPs:
    epsilon=1e-3,
    sigma=0.1,
)

for sample, target in my_dataloader:
    # Clean the kitchen
    my_optimizer.zero_grad()

    # Forward pass
    # Before (ERM): inference = model(sample)
    #               sample_loss = my_loss_function(inference, target)
    # After (SkWDRO): the robust loss applies the model internally
    sample_loss = robust_model(sample, target)

    # Backward pass
    sample_loss.backward()
    my_optimizer.step()

    # Testing
    if my_condition():
        with pt.no_grad():
            model.eval()
            print(my_loss_function(model(test_sample), test_target))
            # Note: to perform forward inference, use robust_model.primal_loss.transform
            print(robust_model.primal_loss(test_sample, test_target))
            model.train()

Dual losses: tune everything you want

Conceptually, (1) is specified in its most general form by the following building blocks:

| Building block | Math notation | Notation in the codebase (with link if relevant) | Examples |
|---|---|---|---|
| A loss function | \(L_\theta(\zeta)\) | primal_loss/loss | \((\zeta_y - \left\langle\theta\mid\zeta_x\right\rangle)^2\), \(\log\left(1+e^{-\zeta_y\left\langle\theta\mid\zeta_x\right\rangle}\right)\), etc. |
| A cost functional | \(c(a, b)\) | Cost (see tutorial) | \(\Vert b-a\Vert_k^p\), \(\begin{cases}0&\text{if }a=b\\ 1&\text{otherwise}\end{cases}\), etc. |
| A reference transport plan | \(\nu_\xi(\zeta):=\pi_0(\xi, \zeta)\) | BaseSampler (see tutorial) | \(\mathcal{N}(\zeta \mid \xi, \sigma^2I)\), \(\mathcal{U}_{\left[\xi-\frac{\sigma}2, \xi+\frac{\sigma}2\right]}(\zeta)\), \(\mathcal{U}_{\{0, \dots, 255\}}(\zeta)\), \(\delta_\xi(\zeta)\), etc. |
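Not all the examples in the table come from a Euclidean geometry. The following plain-PyTorch sketch uses hypothetical helper names and is not the Cost or BaseSampler API. It shows the 0-1 cost, which suits categorical data, together with two of the listed reference plans: the box-uniform perturbation and the uniform distribution on \(\{0, \dots, 255\}\) (e.g. pixel intensities).

```python
import torch


def zero_one_cost(a, b):
    """Hypothetical helper: c(a, b) = 0 if a == b else 1, elementwise on categorical tensors."""
    return (a != b).float()


def sample_box_uniform(xi, sigma, n_samples):
    """Hypothetical helper: zeta ~ U[xi - sigma/2, xi + sigma/2], shape (n_samples, *xi.shape)."""
    u = torch.rand((n_samples, *xi.shape))  # U[0, 1)
    return xi.unsqueeze(0) + sigma * (u - 0.5)


def sample_uniform_bytes(shape, n_samples):
    """Hypothetical helper: zeta ~ U{0, ..., 255}, independently of xi."""
    return torch.randint(0, 256, (n_samples, *shape))


labels = torch.tensor([0, 2, 1])
print(zero_one_cost(labels, torch.tensor([0, 1, 1])))  # tensor([0., 1., 0.])
```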

skwdro lets you build your custom robust loss function representing the dual formula (1) through its second main interface: skwdro.solvers.DualLoss.

Diving into the DualLoss class(es)

The main one, aliased as DualLoss, is the following.

skwdro.solvers.DualPostSampledLoss(loss: Loss, cost: TorchCost, n_samples: int, epsilon_0: Tensor, rho_0: Tensor, n_iter: int | Tuple[int, int] = 10000, *, reduction: str = None, gradient_hypertuning: bool = False, learning_rate: float | None = None, imp_samp: bool = True, adapt: str | None = 'prodigy') -> None

Dual loss implementing a sampling of the \(\zeta\) vectors at each forward pass.

Parameters:
loss: Loss

the (primal) loss of interest \(L_\theta\)

cost: Cost

ground-distance function

n_samples: int

number of \(\zeta\) samples to draw at each forward pass

epsilon_0: torch.Tensor

scalar tensor containing the \(\varepsilon\) regularization hyperparameter

rho_0: torch.Tensor

scalar tensor containing the \(\rho\) (regularized) Wasserstein radius hyperparameter

n_iter: Steps

either a tuple (number of ERM iterations, number of DRO iterations), of type (int, int), or an integer for the number of DRO iterations

reduction: str | None

specifies the reduction to apply to the outer expectation of the SkWDRO formula: 'none' | 'mean' | 'sum'. With 'none', no reduction is applied. With 'mean', the sum of the output is divided by the number of elements in the output. With 'sum', the output is summed. Default: None, which translates to 'mean'.

gradient_hypertuning: bool

set to True to accumulate gradients in rho and epsilon. Tip: this should almost always be kept to False.

learning_rate: Optional[float]

set the stepsize of the torch.optim.AdamW algorithm. Defaults to None which will be parsed as 5e-2

imp_samp: bool

set to True to enable importance sampling

Warning

Unlike the skwdro.torch.robustify() interface, there is no protection against mistakes here. So please do not attempt to set importance sampling for now if:

  • your target is categorical

  • your model is non-differentiable

  • your model includes parts that use the regular .backward() torch interface for inner autodiff utilities instead of the functional API

  • your cost functional does not implement the right functions (see appropriate tutorials).

adapt: Optional[str]

set to either:

  • None to use torch.optim.AdamW.

    Tip

    Set the learning rate with the above parameter learning_rate.

  • "prodigy" or "mechanic" to get automatic learning rate tuning

It has a sister class, DualPreSampledLoss, which keeps the same sampled \(\zeta\) values for the inner expectation. The samples are thus drawn only once, at the expense of statistical soundness; this idea comes from [2].

To get the building blocks for the first two arguments of the constructor described above, here is a simple recipe:
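The difference between the two classes can be sketched in plain PyTorch. A post-sampled estimate redraws \(\zeta\) at every forward pass. A pre-sampled one draws the samples once and reuses them, which makes the objective deterministic but ties it to one particular draw. The classes PostSampled and PreSampled below are hypothetical illustrations, not the SkWDRO classes. They assume a Gaussian reference plan and the squared Euclidean cost, and only compute the \(\varepsilon\log\mathbb{E}[\cdot]\) term of (1).

```python
import math

import torch


def log_avg_exp_term(loss_fn, xi, zeta, lam, eps):
    """eps * log mean_j exp((L(zeta_j) - lam * ||xi - zeta_j||^2) / eps)."""
    cost = ((zeta - xi) ** 2).sum(dim=-1)
    exponent = (loss_fn(zeta) - lam * cost) / eps
    return eps * (torch.logsumexp(exponent, dim=0) - math.log(zeta.shape[0]))


class PostSampled:
    """Hypothetical sketch: redraws zeta ~ N(xi, sigma^2 I) at every call."""

    def __init__(self, sigma, n_samples):
        self.sigma, self.n_samples = sigma, n_samples

    def __call__(self, loss_fn, xi, lam, eps):
        zeta = xi + self.sigma * torch.randn(self.n_samples, *xi.shape)
        return log_avg_exp_term(loss_fn, xi, zeta, lam, eps)


class PreSampled:
    """Hypothetical sketch: draws zeta once at construction, then reuses it."""

    def __init__(self, xi, sigma, n_samples):
        self.zeta = xi + sigma * torch.randn(n_samples, *xi.shape)

    def __call__(self, loss_fn, xi, lam, eps):
        return log_avg_exp_term(loss_fn, xi, self.zeta, lam, eps)


def squared_norm(zeta):
    return (zeta ** 2).sum(dim=-1)


xi = torch.tensor([0.5, -0.5])
pre, post = PreSampled(xi, sigma=0.1, n_samples=16), PostSampled(sigma=0.1, n_samples=16)
print(pre(squared_norm, xi, 1.0, 0.1), pre(squared_norm, xi, 1.0, 0.1))    # identical values
print(post(squared_norm, xi, 1.0, 0.1), post(squared_norm, xi, 1.0, 0.1))  # two random estimates
```

The deterministic pre-sampled objective is what makes full-batch methods such as torch.optim.LBFGS usable.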

  • build your own \(L_\theta\) loss, either,
    • by combining the \(\ell\) functional with your inference model \(f_\theta\)/\(c_\theta\), using the skwdro.base.losses_torch.WrappedPrimalLoss helper class,

    • or by reusing any loss already available in skwdro.base.losses_torch,

  • then get yourself a cost functional tailored to the geometry and properties of your space of interest,
    • by subclassing skwdro.base.costs_torch.Cost,

    • or by using any of the already available ones in skwdro.base.costs_torch,

  • get a good sampling strategy to explore adversarial samples, according to your cost, the geometry of your space, and your prior knowledge,
    • build your own sampler \(\nu_\xi\) by subclassing skwdro.base.samplers.torch.BaseSampler,

    • use the sampler generated by the geometry of your space, i.e. linked to the cost you chose previously, using the skwdro.base.samplers.torch.cost_samplers module’s helpers,

    • fetch the sampler best suited to the problem you are solving (according to the chef’s menu, if available, see the skwdro.base.losses_torch module) by calling your loss's skwdro.base.losses_torch.Loss.default_sampler() method.

Note

To build your loss function, you need to have chosen a sampler, and to pick one you may want to use the skwdro.base.losses_torch.Loss.default_sampler() utility. This may sound a bit circular, but one may set the starting sampler of any loss to None and then overwrite it dynamically with the setter:

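Set a sampler after initialization (Logistic regression).
import torch
from skwdro.base.losses_torch import LogisticLoss

# xi, xi_labels: a warm-up batch of inputs and labels from your dataset
sigma = torch.tensor(.01)
loss = LogisticLoss(None, d=10)
loss.sampler = loss.default_sampler(xi, xi_labels, sigma)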

Before/after comparison for DualLosses

Training procedure: SkWDRO with DualLoss.
import torch as pt
from skwdro.base.costs_torch import NormLabelCost
from skwdro.base.losses_torch import WrappedPrimalLoss
from skwdro.base.samplers.torch.cost_samplers import LabeledCostSampler
from skwdro.solvers import DualLoss

SEED = 42
xi_warmup, xi_labels_warmup = next(iter(my_dataloader))

# Ingredient 2: the cost functional
# Pick \|x-y\|_2^2
cost = NormLabelCost(
    2., 2., 1000.
)

# Ingredient 3: the sampling distribution
sampler = LabeledCostSampler(
    cost,
    xi_warmup, xi_labels_warmup,
    sigma=0.1,
    seed=SEED
)

# Ingredient 1: the loss function
loss_functional = WrappedPrimalLoss(
    my_loss_function,  # as a torch.nn.Module!
    model,
    sampler,
    has_labels=True
)

# Mix together, and let it rest
robust_model = DualLoss(
    loss_functional,
    cost,
    10,               # n_samples
    pt.tensor(1e-3),  # epsilon_0: the regularization parameter
    pt.tensor(0.01),  # rho_0: the radius you picked for the Wasserstein ball
    # Optionally set some keyword-only HPs:
    adapt="prodigy",
)

for sample, target in my_dataloader:
    # Clean the kitchen
    my_optimizer.zero_grad()

    # Forward pass
    # Before (ERM): inference = model(sample)
    #               loss = my_loss_function(inference, target)
    # After (SkWDRO): feed the raw samples, the model is applied internally
    loss = robust_model(sample, target)

    # Backward pass
    loss.backward()
    my_optimizer.step()

    # Testing
    if my_condition():
        with pt.no_grad():
            model.eval()
            print(my_loss_function(model(test_sample), test_target))
            # Note: to perform forward inference, use robust_model.primal_loss.transform
            print(loss_functional(test_sample, test_target))
            model.train()

Conclusion

The takeaway message of this short tutorial is that you have two options. You can use the simple skwdro.torch.robustify() interface and change only a few lines of your codebase. Alternatively, you can get more control over the details of the algorithm with the skwdro.solvers.DualLoss interfaces. We advise new users to start with the former. Users who want to study the behavior of SkWDRO in more detail may explore the various strategies offered by the latter through the aforementioned modules (for costs, samplers, losses, etc).


References