Recipe for a good ground-cost for Wasserstein-DRO#
Practice
Tip
Read the tutorial on SkWDRO to better understand this part.
Recall the formula for SkWDRO:
It includes a cost function \(c\) that imposes a notion of geometry in the design space of the samples \(\xi,\zeta\in\Xi\). We will answer the question of how to pick this crucial hyperparameter.
Distance structure#
Many of the robustness examples treated in the literature use costs of the form
for some power \(p\ge 1\). These costs are popular because their distance properties carry over to the transport cost, making it the so-called Wasserstein distance [2] (see the last section for implementation details).
Simple cases: turn to Euclidean geometry#
In simple cases where no structure is prescribed on the space of samples \(\Xi\), your cost should look like the Euclidean norm:
This opens up a wide range of Wasserstein distances, called the \(W_p\) distances in the literature, for which you raise this norm to some power \(p\) and change the allowed radius \(\rho\) accordingly.
The choice of \(p\) can be made in accordance with the behaviour of the loss function \(L_\theta\) in its “worst regions”. To learn more about the growth criteria that make most sense regarding this remark, take a look at [3].
The SkWDRO interface makes this simple: in the robustification interface you may specify such a structure directly by passing a specification string that follows our simple grammar.
Here, this is done with the following cost specification:
p: float = 1. # pick to your liking
cost_spec = f"t-NC-2-{p}"
Tip
Notice that if you choose \(p=2\), you gain access to an efficient importance-sampling algorithm. More on this in another tutorial…
Other norm-based costs#
If you want to impose a different geometry, for instance to compare your results with the robustness induced by FGSA or similar methods, you can change the norm by specifying a \(k\) parameter.
p: float = 1.
k: float = 1. # pick to your liking
cost_spec = f"t-NC-{k}-{p}"
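To make the roles of \(k\) and \(p\) concrete, here is a minimal pure-Python sketch of a p-powered k-norm cost, \(\|\xi-\zeta\|_k^p\). The function name norm_cost is hypothetical and this is not the library's implementation, only an illustration of what the two numbers in the specification string control.

```python
def norm_cost(xi, zeta, k=2.0, p=1.0):
    """Illustrative cost ||xi - zeta||_k ** p for two equal-length sequences."""
    if len(xi) != len(zeta):
        raise ValueError("xi and zeta must have the same length")
    # k-norm of the difference vector
    norm = sum(abs(a - b) ** k for a, b in zip(xi, zeta)) ** (1.0 / k)
    return norm ** p


xi, zeta = [0.0, 0.0], [3.0, 4.0]
print(norm_cost(xi, zeta, k=2.0, p=1.0))  # Euclidean distance
print(norm_cost(xi, zeta, k=2.0, p=2.0))  # squared Euclidean distance
print(norm_cost(xi, zeta, k=1.0, p=1.0))  # Manhattan distance
```

Changing \(k\) changes the shape of the balls around each sample, while \(p\) changes how fast the cost grows with the distance.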
What about targets?#
If the loss function relates to a classification or regression task, then one may allow for targets to be subject to uncertainty as well.
The cost specification interface lets you account for that as well, by switching from the usual “Norm-Cost” (NC) to a “Norm-with-Label-Cost” (NLC):
p: float = 1.
k: float = 1. # pick to your liking
cost_spec = f"t-NLC-{k}-{p}"
Tip
Again, picking the 2-2 combination unlocks the importance-sampling algorithm.
Checking the literature on WDRO (see [1] Remark 8), one may see that there is a range of ways to interpolate between allowing no target transport at all and charging the same cost for changes of targets as for changes of inputs. We can introduce a hyperparameter \(\kappa\) that lets us weight the contribution of a target change to the transport cost.
This can easily be specified to the NLC cost parser as the last optional parameter.
p: float = 1.
k: float = 1. # pick to your liking
cost_spec = f"t-NLC-{k}-{p}-10.0"
Tip
As a guideline, using \(\kappa=\infty\) amounts to using NC (rule of thumb: moving a target is then infinitely costly, so it is “not allowed”), while setting it to a small value makes it comparatively easier/“cheaper” to move a target than an input.
The formula is stabilized numerically, so you may try various values of \(\kappa\) without unbalancing the transport cost. It ends up being implemented as follows:
Also note that for LabeledCostSamplers, the variance of the label samplers is tailored to your choice of \(\kappa\).
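To build intuition for \(\kappa\), the sketch below uses one textbook form of a label-aware cost, \((\|\xi-\zeta\|_k + \kappa\,|y_\xi - y_\zeta|)^p\). This form is an assumption made for illustration: it is not the numerically stabilized formula used by the library, and the name labeled_norm_cost is hypothetical.

```python
import math


def labeled_norm_cost(xi, y_xi, zeta, y_zeta, k=2.0, p=1.0, kappa=1.0):
    """Illustrative cost (||xi - zeta||_k + kappa * |y_xi - y_zeta|) ** p."""
    input_move = sum(abs(a - b) ** k for a, b in zip(xi, zeta)) ** (1.0 / k)
    label_move = abs(y_xi - y_zeta)
    if label_move == 0.0:
        # No label change: the cost reduces to the plain norm cost (NC),
        # and we avoid the undefined product inf * 0 when kappa is infinite.
        return input_move ** p
    return (input_move + kappa * label_move) ** p


xi, zeta = [0.0, 0.0], [3.0, 4.0]
for kappa in (0.1, 10.0, math.inf):
    same = labeled_norm_cost(xi, 1.0, zeta, 1.0, kappa=kappa)
    flipped = labeled_norm_cost(xi, 1.0, zeta, -1.0, kappa=kappa)
    print(f"kappa={kappa}: same label -> {same}, flipped label -> {flipped}")
```

In this toy form, a small \(\kappa\) barely penalizes the label flip, while \(\kappa=\infty\) gives it an infinite cost, in line with the rule of thumb above.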
For reference, the grammar of this specification string is the following, with FLOAT representing a Python floating point number interpolated in a string:
// Entry point
spec: engine DASH type DASH FLOAT DASH FLOAT kappa? ;

DASH: '-' ;

FLOAT: .* ; // Python-parseable positive floating point number

// NC for simple p-powered k-norm cost
// NLC for same with a penalization of label switches with weight kappa
type: 'NC' | 'NLC' ;

kappa: DASH FLOAT ; // Python-parseable positive floating point number
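As a convenience, the specification string can be assembled by a small helper. The sketch below is hypothetical (make_cost_spec is not part of SkWDRO) and only follows the grammar above, with the engine prefix "t" copied from the examples of this page. It assumes that dashes act purely as separators, so it rejects numbers that Python would render in scientific notation with a negative exponent (such as 1e-05), since those would introduce an extra dash.

```python
def make_cost_spec(kind, k, p, kappa=None, engine="t"):
    """Assemble a spec string such as 't-NC-2.0-1.0' (hypothetical helper)."""
    if kind not in ("NC", "NLC"):
        raise ValueError("kind must be 'NC' or 'NLC'")
    if kappa is not None and kind != "NLC":
        raise ValueError("kappa only makes sense for the 'NLC' cost")
    numbers = [k, p] if kappa is None else [k, p, kappa]
    fields = []
    for x in numbers:
        if x <= 0:
            raise ValueError("k, p and kappa must be positive")
        text = repr(float(x))
        if "-" in text:
            # e.g. 1e-05: the dash would be confused with a field separator
            raise ValueError(f"{text} cannot be written without a dash")
        fields.append(text)
    return "-".join([engine, kind] + fields)


print(make_cost_spec("NC", 2, 1))
print(make_cost_spec("NLC", 2, 2, kappa=10.0))
```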
Building your own cost function#
Many applications of WDRO do not fall into the setting above.
From optics to discrete optimisation, they need to impose other kinds of structure via the cost function.
So instead of trying to cover every case by hand, we allow users to subclass skwdro.base.costs_torch.Cost in order to implement their own.
The documentation of this useful abstract class will guide you through the methods you need to implement:
- class skwdro.base.costs_torch.TorchCost(name: str = '', engine: str = '')
Base class for transport functions
- forward(xi: Tensor, zeta: Tensor, xi_labels: Tensor | None = None, zeta_labels: Tensor | None = None) -> Tensor
This function is called by default when using the __call__ dunder of pytorch modules: it forwards directly to the value() method.
- abstractmethod solve_max_series_exp(xi: Tensor, xi_labels: Tensor | None, rhs: Tensor, rhs_labels: Tensor | None) -> Tuple[Tensor, Tensor | None]
Override this method to provide an explicit solution to the expansion of the inner supremum one would wish to solve if they were solving the usual WDRO approach:
\[\zeta^\texttt{imp_samp}:=\arg\max_{\zeta} \left\langle\nabla_\xi L_\theta(\xi)\mid\zeta-\xi\right\rangle - \lambda c(\xi, \zeta).\]
Important
This is an unconstrained first-order approximation of the supremum, which can be ill-posed or intractable, but is usually cheap enough for efficient importance sampling. For reasonably small models, one may implement higher-order approximations and add constraints if cheap enough solutions are available.
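As an illustration of why this can be cheap, consider the squared Euclidean cost \(c(\xi,\zeta)=\|\xi-\zeta\|_2^2\). The linearized objective is then a concave quadratic in \(\zeta\), maximized at \(\zeta=\xi+\nabla_\xi L_\theta(\xi)/(2\lambda)\). The pure-Python sketch below checks this closed form numerically; the function names are hypothetical and the code is independent of the library's own implementation.

```python
def linearized_objective(zeta, xi, grad, lam):
    """<grad, zeta - xi> - lam * ||zeta - xi||_2^2."""
    diff = [z - x for z, x in zip(zeta, xi)]
    inner = sum(g * d for g, d in zip(grad, diff))
    return inner - lam * sum(d * d for d in diff)


def closed_form_maximizer(xi, grad, lam):
    """Stationary point of the concave quadratic: xi + grad / (2 * lam)."""
    return [x + g / (2.0 * lam) for x, g in zip(xi, grad)]


xi, grad, lam = [1.0, -2.0], [0.5, 3.0], 2.0
zeta_star = closed_form_maximizer(xi, grad, lam)
best = linearized_objective(zeta_star, xi, grad, lam)

# Any small perturbation of zeta_star should not improve the objective.
for shift in ([0.1, 0.0], [0.0, -0.1], [-0.05, 0.05]):
    perturbed = [z + s for z, s in zip(zeta_star, shift)]
    assert linearized_objective(perturbed, xi, grad, lam) <= best
print(zeta_star, best)
```

For costs without such a closed form, the same check can still be used to validate a candidate solution before enabling importance sampling.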
The methods you should override are the following:
- the skwdro.base.costs_torch.Cost.value() forwarding method, which should take couples \((\xi, \zeta)\), specified as four arguments (xi, xi_labels, zeta, zeta_labels) (with the _label arguments allowed to be set to None). It should return the cost incurred for transporting one unit of mass from \(\xi\) to \(\zeta\).
- the skwdro.base.costs_torch.Cost._sampler_data() method is useful if you wish to build a skwdro.base.samplers.torch.NoLabelsCostSampler or skwdro.base.samplers.torch.LabeledCostSampler. Let it return the torch.distributions.Distribution instance from which to sample data points, given \(\xi\).
- the same goes for the skwdro.base.costs_torch.Cost._sampler_labels() method if your model has targets. It can return a None value instead of a distribution if your whole setup handles None as labels.
- the skwdro.base.costs_torch.Cost.solve_max_series_exp() method lets you use importance sampling if you believe that it is well defined for your cost function and the structure of the problem studied. If not, make it raise and set all importance sampling flags to False at the creation of the loss function.
Illustration on some new example#
Consider a problem stemming from some Gaussian curvature prescription model, or unbalanced WDRO [4]:
- the space of available samples is the sphere \(\Xi = \mathcal{S}^{d-1}\),
- the cost function is log-bilinear for samples pointing in the same half-space
\[c(\xi, \zeta) = -\log(\texttt{ReLU}[\left\langle\zeta, \xi\right\rangle]) + \chi_{\{(x, y)|\langle x, y\rangle > 0\}}(\xi, \zeta).\]
If you want to use this structure to build some WDRO model, you may implement this cost functional as is done below. As discussed previously, it is strongly advised to think immediately about the generating distribution \(\nu_\xi\) at the same time as the cost function, as the interplay between them is crucial empirically.
import torch
from skwdro.distributions import Distribution
from skwdro.base.costs_torch import Cost


class HalfSphereUniform(Distribution):
    """
    Proposition of a torch ``Distribution`` that samples uniformly on the half-sphere
    pointing in the same direction as the "center" ``xi``. The expectation of this
    distribution is thus proportional to ``xi``.
    """
    def __init__(self, xi):
        super().__init__(torch.Size(), xi.shape, False)
        self.center = xi

    def rsample(self, sample_shape = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        Samples an isotropic gaussian, then projects the samples to the right half-space of
        positive scalar product with the "center" ``xi``, and finally projects them on the
        sphere.
        """
        noise = torch.randn(sample_shape + self.center.size())
        # Scalar product with the center along the feature (last) dimension,
        # so that each point of a batch is handled independently.
        scalar_prod = torch.sum(self.center * noise, dim=-1, keepdim=True)
        # Flip the samples that point away from the center.
        projected_noise = noise * torch.sign(scalar_prod)
        # Normalize each sample to land on the unit sphere.
        return projected_noise / torch.linalg.norm(projected_noise, dim=-1, keepdim=True)
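The sampling trick described in the docstring above (draw an isotropic Gaussian, flip it into the half-space of the center, then normalize) can be checked without any dependency. The following is a minimal pure-Python sketch with a hypothetical function name, not a replacement for the torch distribution.

```python
import math
import random


def sample_half_sphere(center, rng):
    """Draw one point uniformly on the unit half-sphere facing ``center``."""
    noise = [rng.gauss(0.0, 1.0) for _ in center]
    dot = sum(c * n for c, n in zip(center, noise))
    if dot < 0.0:
        # Reflect through the origin to land in the half-space of ``center``.
        noise = [-n for n in noise]
    norm = math.sqrt(sum(n * n for n in noise))
    return [n / norm for n in noise]


rng = random.Random(0)
center = [0.0, 0.0, 1.0]
samples = [sample_half_sphere(center, rng) for _ in range(1000)]

# Every sample has unit norm and a non-negative scalar product with the center.
assert all(abs(sum(x * x for x in s) - 1.0) < 1e-9 for s in samples)
assert all(sum(c * x for c, x in zip(center, s)) >= 0.0 for s in samples)
```

The reflection keeps the distribution uniform because the isotropic Gaussian is symmetric with respect to the origin.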
Now that we have a nice distribution to associate with our cost function, we can specify it with the relevant helper class.
class GaussianPrescriptionCost(Cost):
    def __init__(self):
        # This initialization procedure is not important
        super().__init__(
            "Gaussian-curvature-prescription cost functional",
            "pt"
        )
        # Here is the important line: the homogeneity for the radius
        # is set below (see next section if you are curious).
        self.power: float = 1.

    def value(self, xi, xi_labels, zeta, zeta_labels):
        r"""
        This value function computes the cost of a pair (``xi``, ``zeta``).

        .. math::
            c(\xi, \zeta):=-\log\left([\langle\zeta,\xi\rangle]_+\right)
        """
        assert xi_labels is None
        assert zeta_labels is None
        # Write the cost function here, using only pytorch functions to
        # allow all the internal machinery to go smoothly. Here e.g.
        # we leverage the relu function to compute max(0, <x,y>)
        scalar_prod = (xi * zeta).sum(dim=-1)
        # relu clips non-positive products to 0, so -log returns +inf there,
        # which matches the indicator term of the cost (no branching needed,
        # and this works elementwise on batched tensors).
        return -torch.log(torch.nn.functional.relu(scalar_prod))

    def _sampler_data(self, xi, epsilon):
        # The sampler does not depend on epsilon here
        del epsilon
        return HalfSphereUniform(xi)

    def _sampler_labels(self, xi_labels, epsilon):
        # No targets in this model
        assert xi_labels is None
        return None

    def solve_max_series_exp(self, xi, xi_labels, rhs, rhs_labels):
        assert xi_labels is None
        assert rhs_labels is None
        del rhs
        return xi, xi_labels
The definition of the half-sphere sampler is not mandatory! You may use any placeholder you want for the _sampler_data method overload, and follow the user guide to see how to implement your own custom sampler and integrate it into the dual-loss formulation.
>>> c = GaussianPrescriptionCost()
>>> xi = torch.rand((2, 3))
>>> xi = xi / torch.norm(xi, dim=-1, keepdim=True) # project on sphere
>>> # Try our sampler for validation: the cost should be > 0
>>> zeta = c._sampler_data(xi, torch.tensor(0.1)).sample(torch.Size((5,)))
>>> print(c(xi, zeta))
tensor([[3.7459, 0.5499],
[0.5176, 2.2018],
[0.1327, 0.0773],
[0.0441, 0.6632],
[0.7007, 0.0457]])
>>> print(c(xi, xi).abs().max())
tensor(1.1921e-07)
This cost function is not trivially associated with any learning problem, but it showcases how you can impose geometric structure within the SkWDRO framework in general.
To go further#
Getting back to the first WDRO tutorial, you may recall that the transport cost constraint was cast as follows
But to get a true distance, one cannot use just any cost function!
So e.g. in the case of distances raised to some power \(p\) as in NormCost, the power \(p\) must be taken into account in the radius.
For example, in the literature, the distance is defined in a straightforward way for any distance \(d\)
To avoid changing all the theoretical derivations related to the duality results in the SkWDRO framework, we just raise both sides of the equation to the power \(p\), which translates to appropriate tricks inside the library's solvers.
Tip
In fact the interface is flexible enough to specify this behavior for any cost! Set the skwdro.base.costs_torch.TorchCost.p attribute to any positive floating point number to allow for a power parameter, defining to what power you want to raise both sides of the equation so that the cost is of the same order of magnitude on average as \(\rho\). See the example above.
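The effect of raising both sides to the power \(p\) can be seen on a toy case. For two Dirac masses at \(\xi\) and \(\zeta\), the only transport plan moves all the mass from one point to the other, so the transport cost with \(c=d^p\) is \(d(\xi,\zeta)^p\). The sketch below (hypothetical names, plain Python, Euclidean \(d\)) checks that comparing this cost to \(\rho^p\) is equivalent to comparing the distance to \(\rho\).

```python
import math


def dirac_transport_cost(xi, zeta, p):
    """Transport cost between two Dirac masses for the cost d(xi, zeta) ** p."""
    return math.dist(xi, zeta) ** p


xi, zeta = (0.0, 0.0), (3.0, 4.0)  # Euclidean distance 5
for p in (1.0, 2.0, 3.0):
    for rho in (4.0, 6.0):
        in_ball_distance = math.dist(xi, zeta) <= rho
        in_ball_cost = dirac_transport_cost(xi, zeta, p) <= rho ** p
        # Both formulations of the constraint agree
        assert in_ball_distance == in_ball_cost
```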