Recipe for a good ground-cost for Wasserstein-DRO

Practice

Tip

Read the tutorial on SkWDRO first to better understand this part.

Recall the formula for SkWDRO:

\[L_\theta^\texttt{robust}(\xi) := \lambda\rho + \varepsilon\log\mathbb{E}_{\zeta\sim\nu_\xi}\left[e^{\frac{L_\theta(\zeta)-\lambda c(\xi, \zeta)}{\varepsilon}}\right].\]

It includes a cost function \(c\) that imposes a notion of geometry on the space of samples \(\xi,\zeta\in\Xi\). This section answers the question of how to pick this crucial hyperparameter.
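
To see what the formula computes, here is a minimal NumPy sketch of a Monte Carlo estimate of \(L_\theta^\texttt{robust}(\xi)\) for fixed \(\lambda\), \(\rho\) and \(\varepsilon\). For illustration only, it assumes that \(\nu_\xi\) is a Gaussian \(\mathcal{N}(\xi,\sigma^2 I)\). The function robust_loss_mc and its parameters are hypothetical and not part of SkWDRO. The log-mean-exp is computed stably by factoring out the largest exponent.

```python
import numpy as np


def robust_loss_mc(loss, cost, xi, lam, rho, eps, sigma, n_samples=10_000, seed=0):
    """Hypothetical Monte Carlo estimate of
        lam * rho + eps * log E_{zeta ~ nu_xi}[exp((loss(zeta) - lam * cost(xi, zeta)) / eps)],
    assuming (for illustration only) that nu_xi = N(xi, sigma^2 I).
    """
    rng = np.random.default_rng(seed)
    zeta = xi + sigma * rng.standard_normal((n_samples, xi.shape[-1]))
    exponents = (loss(zeta) - lam * cost(xi, zeta)) / eps
    # Stable log-mean-exp: factor out the largest exponent before exponentiating
    top = exponents.max()
    log_mean_exp = top + np.log(np.mean(np.exp(exponents - top)))
    return lam * rho + eps * log_mean_exp


def quadratic_loss(z):
    # Toy loss, evaluated row-wise on a batch of samples
    return 0.5 * np.sum(z ** 2, axis=-1)


def euclidean_cost(xi, zeta):
    return np.linalg.norm(zeta - xi, axis=-1)


xi = np.array([1.0, -1.0])
print(quadratic_loss(xi),
      robust_loss_mc(quadratic_loss, euclidean_cost, xi, lam=2.0, rho=0.1, eps=0.01, sigma=0.5))
```

As \(\varepsilon\) shrinks, the term \(\varepsilon\log\) mean \(\exp(\cdot/\varepsilon)\) approaches the maximum over the drawn samples. The estimate then gets closer to the non-regularized inner supremum of WDRO.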

Distance structure

Many of the robustness examples treated in the literature use costs of the form

\[c(\xi, \zeta) = d(\xi, \zeta)^p\]

where \(d\) is a distance on \(\Xi\) and \(p\ge 1\) is some power. These costs are especially well studied in the literature because their distance properties carry over to the optimal transport cost, whose \(p\)-th root is the so-called Wasserstein distance [2] (see the last section for implementation details).

Simple cases: turn to Euclidean geometry

If your problem prescribes no particular structure on the space of samples \(\Xi\), your cost should look like the Euclidean norm:

\[c(\xi, \zeta) := \|\zeta-\xi\|_2\]

This opens up a whole family of Wasserstein distances, called the \(W_p\) distances in the literature. For these, you raise this norm to some power \(p\) and adjust the allowed radius \(\rho\) accordingly:

\[c(\xi, \zeta) := \|\zeta-\xi\|_2^p\]

The choice of \(p\) can be guided by the behaviour of the loss function \(L_\theta\) in its “worst regions”. To learn more about the growth criteria that make the most sense in this regard, see [3].

SkWDRO makes this simple. In the robustification interface, you may directly specify such a structure with a specification string that follows a simple grammar (given further below).

Here, this is done with the following cost specification:

specification of the cost power
p: float = 1.  # pick to your liking
cost_spec = f"t-NC-2-{p}"

Tip

Notice that if you choose \(p=2\), you gain access to an efficient importance-sampling algorithm. More on this in another tutorial…

Other norm-based costs

You may want to impose a different geometry to interpret your results in another way, e.g. to compare with the robustness induced by FGSA. In that case, change the norm by specifying a \(k\) parameter, so that the cost uses the \(\ell_k\) norm \(\|\cdot\|_k\).

specification of the norm type
p: float = 1.
k: float = 1.  # pick to your liking
cost_spec = f"t-NC-{k}-{p}"
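
To make the roles of \(k\) and \(p\) concrete, here is a minimal NumPy sketch of the cost \(c(\xi,\zeta)=\|\zeta-\xi\|_k^p\) described by such a string. The helper norm_cost is hypothetical and only mirrors the formula. It is not SkWDRO's implementation.

```python
import numpy as np


def norm_cost(xi, zeta, k=2.0, p=1.0):
    """Hypothetical helper mirroring c(xi, zeta) = ||zeta - xi||_k ** p along the last axis."""
    # ord=k selects the l_k vector norm; axis=-1 keeps any leading batch dimensions
    return np.linalg.norm(zeta - xi, ord=k, axis=-1) ** p


xi = np.array([0.0, 0.0])
zeta = np.array([3.0, 4.0])
print(norm_cost(xi, zeta, k=2, p=1))  # 5.0: Euclidean distance (k=2, p=1)
print(norm_cost(xi, zeta, k=2, p=2))  # 25.0: squared Euclidean distance (k=2, p=2)
print(norm_cost(xi, zeta, k=1, p=1))  # 7.0: l_1 distance (k=1, p=1)
```

The norm order \(k\) changes which displacements count as "short". The power \(p\) changes how fast the cost grows with the size of the displacement.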

What about targets?

If the loss function relates to a classification or regression task, one may allow the targets to be subject to uncertainty as well. The cost specification interface accounts for this: switch from the usual “Norm-Cost” (NC) to a “Norm-with-Label-Cost” (NLC):

specification of the norm type
p: float = 1.
k: float = 1.  # pick to your liking
cost_spec = f"t-NLC-{k}-{p}"

Tip

Again, picking the 2-2 combination (\(k=p=2\)) unlocks the importance-sampling algorithm.

The WDRO literature (see [1], Remark 8) describes a range of ways to interpolate between allowing no transport of the targets at all and charging changes of targets as much as changes of inputs. We introduce a hyperparameter \(\kappa\) that weights the contribution of a target change to the transport cost:

\[c(\xi, \zeta) = \|\xi^\texttt{input} - \zeta^\texttt{input}\|_k^p + \mathbf{\kappa}\|\xi^\texttt{target} - \zeta^\texttt{target}\|_k^p\]

It can be passed to the NLC cost parser as the last, optional parameter:

specification of the norm type
p: float = 1.
k: float = 1.  # pick to your liking
cost_spec = f"t-NLC-{k}-{p}-10.0"

Tip

As a guideline, setting \(\kappa=\infty\) amounts to using NC. The rule of thumb: moving a target becomes infinitely costly, so it is “not allowed”. Setting \(\kappa\) to a small value instead makes it comparatively easier (“cheaper”) to move a target than an input. The formula is stabilized numerically, so you may try various values of \(\kappa\) without unbalancing the transport cost. It is implemented as follows:

\[c(\xi, \zeta) = \left(\frac1{1+\kappa}\|\xi^\texttt{input} - \zeta^\texttt{input}\|_k + \frac\kappa{1+\kappa}\|\xi^\texttt{target} - \zeta^\texttt{target}\|_k\right)^p\]

Also note that for LabeledCostSamplers, the variance of the label samplers is tailored to your choice of \(\kappa\).
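
To see how \(\kappa\) trades input moves against target moves, here is a minimal NumPy sketch of the weighted cost \(\|\xi^\texttt{input} - \zeta^\texttt{input}\|_k^p + \kappa\|\xi^\texttt{target} - \zeta^\texttt{target}\|_k^p\). It includes the \(\kappa=\infty\) rule of thumb. The helper norm_label_cost is hypothetical. Its handling of \(\kappa=\infty\) is an assumption for illustration: the cost is infinite whenever the target moves, and the NC cost applies otherwise. It does not reproduce SkWDRO's numerically stabilized formula.

```python
import numpy as np


def norm_label_cost(x, y, x_new, y_new, k=2.0, p=1.0, kappa=1.0):
    """Hypothetical helper for ||x - x_new||_k^p + kappa * ||y - y_new||_k^p (last axis)."""
    input_cost = np.linalg.norm(x_new - x, ord=k, axis=-1) ** p
    target_cost = np.linalg.norm(y_new - y, ord=k, axis=-1) ** p
    if np.isinf(kappa):
        # Assumed convention for kappa = inf: any change of target is "not allowed"
        # (infinite cost); otherwise only the input term remains, as with NC.
        return np.where(target_cost > 0, np.inf, input_cost)
    return input_cost + kappa * target_cost


x, x_new = np.array([0.0, 0.0]), np.array([3.0, 4.0])  # input moved by a distance 5
y, y_new = np.array([0.0]), np.array([1.0])            # target moved by a distance 1
for kappa in (0.1, 10.0, np.inf):
    print(kappa, norm_label_cost(x, y, x_new, y_new, kappa=kappa))
```

With \(\kappa=0.1\), the target move adds little to the cost. With \(\kappa=10\), it dominates the input move. With \(\kappa=\infty\), it is forbidden.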

For reference, here is the grammar of the specification strings, where FLOAT stands for a Python floating-point number interpolated into the string:

Grammar for the cost-specification strings.
// Entry point
// engine: backend tag, written 't' in the examples above
// The first FLOAT is the norm order k, the second one is the power p
spec: engine DASH type DASH FLOAT DASH FLOAT kappa? ;

DASH: '-' ;

FLOAT: .* ; // Python-parseable positive floating point number

// NC for simple p-powered k-norm cost
// NLC for same with a penalization of label switches with weight kappa
type: 'NC' | 'NLC' ;

kappa: DASH FLOAT ; // Python-parseable positive floating point number
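
To illustrate how such a string decomposes, here is a minimal regex-based sketch of a parser for this grammar. It is not SkWDRO's own parser. Several choices are assumptions, since the grammar leaves them open: the engine is a run of letters, floats use Python's usual decimal/exponent notation (or inf), and non-positive values are rejected. Note that an exponent such as 1e-05 contains a dash, which a naive split("-") would mishandle.

```python
import re

# Assumed FLOAT syntax: digits with optional decimal part and exponent, or "inf"
_FLOAT = r"(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf"
_SPEC = re.compile(
    rf"^(?P<engine>[A-Za-z]+)-(?P<type>NLC|NC)-(?P<k>{_FLOAT})-(?P<p>{_FLOAT})"
    rf"(?:-(?P<kappa>{_FLOAT}))?$"
)


def parse_cost_spec(spec: str) -> dict:
    """Hypothetical parser for 'engine-type-k-p[-kappa]' cost specification strings."""
    match = _SPEC.match(spec)
    if match is None:
        raise ValueError(f"Malformed cost specification: {spec!r}")
    parsed = {
        "engine": match["engine"],
        "type": match["type"],
        "k": float(match["k"]),  # norm order
        "p": float(match["p"]),  # power applied to the norm
        "kappa": float(match["kappa"]) if match["kappa"] is not None else None,
    }
    for name in ("k", "p", "kappa"):
        if parsed[name] is not None and parsed[name] <= 0:
            raise ValueError(f"{name} must be positive, got {parsed[name]}")
    return parsed


p: float = 1.
print(parse_cost_spec(f"t-NC-2-{p}"))
print(parse_cost_spec("t-NLC-1.0-1.0-10.0"))
```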

Building your own cost function

Many applications of WDRO do not fall into the setting above. From optics to discrete optimisation, they need to impose other kinds of structure via the cost function. So instead of trying to cover every case by hand, we allow users to subclass skwdro.base.costs_torch.Cost in order to implement their own.

The documentation of this useful abstract class will guide you through the methods you need to implement:

class skwdro.base.costs_torch.TorchCost(name: str = '', engine: str = '')

Base class for transport functions

forward(xi: Tensor, zeta: Tensor, xi_labels: Tensor | None = None, zeta_labels: Tensor | None = None) → Tensor

This method is called by default through the __call__ dunder of PyTorch modules: it forwards its arguments directly to the value() method.

abstractmethod solve_max_series_exp(xi: Tensor, xi_labels: Tensor | None, rhs: Tensor, rhs_labels: Tensor | None) → Tuple[Tensor, Tensor | None]

Override this method to provide an explicit solution to the first-order expansion of the inner supremum that one would solve in the usual WDRO approach:

\[\zeta^\texttt{imp_samp}:=\text{arg}\max_{\zeta} \left\langle\nabla_\xi L_\theta(\xi)\mid\zeta-\xi\right\rangle - \lambda c(\xi, \zeta).\]

Important

This is an unconstrained first-order approximation of the supremum. It can be ill-posed or intractable, but it is usually cheap enough to solve for efficient importance sampling. For reasonably small models, one may also implement higher-order approximations, or add constraints, if cheap enough solutions are available.

The methods you should override are the following:

  • the skwdro.base.costs_torch.Cost.value() method, to which forward() sends its inputs. It takes couples \((\xi, \zeta)\), specified as four arguments in the same order as forward(), namely (xi, zeta, xi_labels, zeta_labels), with the label arguments allowed to be None. It should return the cost incurred for transporting one unit of mass from \(\xi\) to \(\zeta\).

  • the skwdro.base.costs_torch.Cost._sampler_data() method is useful if you wish to build a skwdro.base.samplers.torch.NoLabelsCostSampler or skwdro.base.samplers.torch.LabeledCostSampler. Let it return the torch.distributions.Distribution instance from which to sample data points, given \(\xi\).

  • same goes for the skwdro.base.costs_torch.Cost._sampler_labels() method if your model has targets. It can return a None value instead of a distribution if your whole setup handles None as labels.

  • the skwdro.base.costs_torch.Cost.solve_max_series_exp() method lets you use importance sampling if you believe that it is well defined for your cost function and the structure of the problem studied. If not, make it raise an exception and set all importance-sampling flags to False when creating the loss function.
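
As an example of the kind of explicit solution meant here, take the squared Euclidean cost \(c(\xi,\zeta)=\|\zeta-\xi\|_2^2\). Maximizing \(\langle g\mid\zeta-\xi\rangle-\lambda\|\zeta-\xi\|_2^2\) over \(\zeta\), with \(g=\nabla_\xi L_\theta(\xi)\), is a strictly concave problem. Setting its gradient to zero gives \(\zeta^\star=\xi+g/(2\lambda)\), with optimal value \(\|g\|_2^2/(4\lambda)\). The NumPy sketch below checks this numerically. Its function names are hypothetical and do not match the actual arguments of solve_max_series_exp.

```python
import numpy as np


def first_order_objective(zeta, xi, grad, lam):
    """<grad | zeta - xi> - lam * ||zeta - xi||_2^2 (linearized loss minus the cost)."""
    delta = zeta - xi
    return delta @ grad - lam * (delta @ delta)


def first_order_argmax(xi, grad, lam):
    """Closed-form maximizer for the squared Euclidean cost: xi + grad / (2 * lam)."""
    return xi + grad / (2.0 * lam)


rng = np.random.default_rng(0)
xi = rng.standard_normal(3)
grad = rng.standard_normal(3)  # stands in for the loss gradient at xi
lam = 0.5
zeta_star = first_order_argmax(xi, grad, lam)
best = first_order_objective(zeta_star, xi, grad, lam)
# Random perturbations of the maximizer never do better, since the objective is strictly concave
print(all(first_order_objective(zeta_star + 0.1 * rng.standard_normal(3), xi, grad, lam) <= best
          for _ in range(100)))
```

By contrast, for the plain norm (\(p=1\)), \(\sup_\zeta \langle g\mid\zeta-\xi\rangle-\lambda\|\zeta-\xi\|_2\) has only two possible outcomes. If \(\|g\|_2\le\lambda\), the supremum is \(0\), attained at \(\zeta=\xi\). Otherwise it is \(+\infty\). In neither case is there an informative maximizer, which is consistent with the tips above singling out \(p=2\).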

Illustration on a new example

Consider a problem stemming from a Gaussian curvature prescription model, or from unbalanced WDRO [4]:

  • the space of available samples is the sphere \(\Xi = \mathcal{S}^{d-1}\),

  • the cost function is log-bilinear for samples pointing into the same half-space, and infinite otherwise:

    \[c(\xi, \zeta) = -\log(\texttt{ReLU}[\left\langle\zeta, \xi\right\rangle]) + \chi_{\{(x, y)|\langle x, y\rangle > 0\}}(\xi, \zeta).\]

To build a WDRO model on this structure, you may implement the cost functional as shown below. As discussed previously, it is strongly advised to design the sampling distribution \(\nu_\xi\) at the same time as the cost function, since the interplay between the two matters a lot empirically.

Build a relevant sampler tailored to our cost function
import torch
from torch.distributions import Distribution
from skwdro.base.costs_torch import Cost


class HalfSphereUniform(Distribution):
    """
    Proposition of a torch ``Distribution`` that samples uniformly on the half-sphere
    pointing in the same direction as the "center" ``xi``. The expectation of this
    distribution is thus colinear with ``xi`` (though shorter: ``xi / 2`` in dimension 3).
    """
    def __init__(self, xi):
        # Last dimension: ambient space of the sphere; leading dimensions: batch of centers
        super().__init__(xi.shape[:-1], xi.shape[-1:], validate_args=False)
        self.center = xi

    def rsample(self, sample_shape = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        Samples an isotropic gaussian, then flips each sample into the half-space of
        positive scalar product with the "center" ``xi``, and finally projects them on the
        sphere.
        """
        noise = torch.randn(torch.Size(sample_shape) + self.center.size(),
                            dtype=self.center.dtype, device=self.center.device)
        # Scalar products are taken vector-wise, along the last dimension only
        sign = torch.sign(torch.sum(self.center * noise, dim=-1, keepdim=True))
        projected_noise = noise * sign
        return projected_noise / torch.linalg.norm(projected_noise, dim=-1, keepdim=True)
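
The sampling scheme in the docstring above is: flip an isotropic Gaussian vector into the half-space of positive scalar product with \(\xi\), then normalize. The NumPy transcription below uses a hypothetical helper, sample_half_sphere, to check its properties outside of SkWDRO. Every sample lies on the unit sphere, in the right half-space. In dimension 3, the empirical mean is close to \(\xi/2\), the centroid of a hemispherical shell. So the mean is colinear with \(\xi\) rather than equal to it.

```python
import numpy as np


def sample_half_sphere(xi, n_samples, rng):
    """Hypothetical helper: uniform samples on {zeta : ||zeta|| = 1, <zeta, xi> >= 0}."""
    noise = rng.standard_normal((n_samples, xi.shape[-1]))
    # Flip each gaussian vector into the half-space pointing towards xi
    noise *= np.sign(noise @ xi)[:, None]
    # Normalizing an isotropic gaussian vector gives a uniform point on the sphere
    return noise / np.linalg.norm(noise, axis=-1, keepdims=True)


rng = np.random.default_rng(0)
xi = np.array([0.0, 0.0, 1.0])
zeta = sample_half_sphere(xi, 200_000, rng)
print(np.allclose(np.linalg.norm(zeta, axis=-1), 1.0))  # on the sphere
print((zeta @ xi >= 0).all())                          # in the right half-space
print(zeta.mean(axis=0))                               # approximately [0, 0, 0.5]
```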

Now that we have a nice distribution to associate with our cost function, we can specify it with the relevant helper class.

Cost function for Gaussian curvature prescription
class GaussianPrescriptionCost(Cost):
    def __init__(self):
        # This initialization procedure is not important
        super().__init__(
           "Gaussian-curvature-prescription cost functional",
           "pt"
        )
        # Here is the important line: the homogeneity for the radius
        # is set below (see next section if you are curious).
        self.power: float = 1.

    def value(self, xi, zeta, xi_labels=None, zeta_labels=None):
        r"""
        This value function computes the cost of a pair (``xi``, ``zeta``).

        .. math::
            c(\xi, \zeta):=-\log\left([\langle\zeta,\xi\rangle]_+\right)
        """
        assert xi_labels is None
        assert zeta_labels is None
        # Write the cost function here, using only pytorch functions to
        #    allow all the internal machinery to go smoothly. Here e.g.
        #    we leverage the relu function to compute max(0, <x,y>): the
        #    cost is then -log(0) = +inf outside of the half-space, which
        #    encodes the indicator term of the formula.
        scalar_prod = (xi * zeta).sum(dim=-1)
        return -torch.log(torch.nn.functional.relu(scalar_prod))

    def _sampler_data(self, xi, epsilon):
        # The half-sphere sampler has no scale parameter
        del epsilon
        return HalfSphereUniform(xi)

    def _sampler_labels(self, xi_labels, epsilon):
        assert xi_labels is None
        return None

    def solve_max_series_exp(self, xi, xi_labels, rhs, rhs_labels):
        # No explicit first-order solution here: stay at xi
        assert xi_labels is None
        assert rhs_labels is None
        del rhs
        return xi, xi_labels

The definition of the half-sphere sampler is not mandatory! You may use any placeholder you want for the _sampler_data method overload. Follow the user guide to see how to implement your own custom sampler and integrate it into the dual-loss formulation.

Testing our class on synthetic data
>>> c = GaussianPrescriptionCost()
>>> xi = torch.rand((2, 3))
>>> xi = xi / torch.norm(xi, dim=-1, keepdim=True)  # project on sphere
>>> # Try our sampler for validation: the cost should be > 0
>>> zeta = c._sampler_data(xi, torch.tensor(0.1)).sample(torch.Size((5,)))
>>> print(c(xi, zeta))
tensor([[3.7459, 0.5499],
        [0.5176, 2.2018],
        [0.1327, 0.0773],
        [0.0441, 0.6632],
        [0.7007, 0.0457]])
>>> print(c(xi, xi).abs().max())
tensor(1.1921e-07)

This cost function is not trivially associated with any learning problem, but it showcases how you can impose geometric structure on the SkWDRO framework in general.

To go further

Going back to the first WDRO tutorial, recall that the transport cost constraint was cast as follows:

\[W(\hat{\mathbb{P}}^N, \mathbb{Q}) \le \rho.\]

But not every cost function yields a true distance! For instance, with distances raised to some power \(p\), as in NormCost, the power \(p\) must be taken into account in the radius. In the literature, the Wasserstein distance is defined in a straightforward way for any distance \(d\):

\[W_p(\hat{\mathbb{P}}^N, \mathbb{Q}) := \sqrt[p]{\inf_{\pi\in\Pi(\hat{\mathbb{P}}^N, \mathbb{Q})} \int_{\Xi^2}d(\xi, \zeta)^p d\pi(\xi, \zeta)} \le \rho\]

To avoid changing all the theoretical derivations behind the duality results of the SkWDRO framework, we simply raise both sides of the inequality to the power \(p\). This translates into appropriate adjustments inside the library's solvers.
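
The equivalence \(W_p\le\rho \iff W_p^p\le\rho^p\) can be checked on a toy example. The sketch below uses the Euclidean distance as \(d\) and two uniform empirical measures with the same number of atoms. In that case, by the Birkhoff-von Neumann theorem, an optimal coupling can be chosen among permutations, so brute force over permutations is exact. The helper names are hypothetical, and the brute force is only viable for a handful of points.

```python
import itertools

import numpy as np


def transport_cost(x, y, p):
    """min over permutations s of mean_i ||x_i - y_s(i)||_2^p (uniform weights, same size)."""
    return min(
        np.mean(np.linalg.norm(x - y[list(perm)], axis=-1) ** p)
        for perm in itertools.permutations(range(len(x)))
    )


def wasserstein(x, y, p):
    """W_p is the p-th root of the optimal transport cost with ground cost d^p."""
    return transport_cost(x, y, p) ** (1.0 / p)


x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
y = x + np.array([0.0, 2.0])  # every atom is shifted by a distance 2
rho, p = 2.5, 2.0
print(wasserstein(x, y, p))                                                # 2.0
print(wasserstein(x, y, p) <= rho, transport_cost(x, y, p) <= rho ** p)   # True True
print(transport_cost(x, y, p) <= rho)                                     # False
```

The last line shows what would go wrong if the raw transport cost with ground cost \(d^p\) were compared with \(\rho\) instead of \(\rho^p\).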

Tip

In fact, the interface is flexible enough to specify this behavior for any cost! Set the skwdro.base.costs_torch.TorchCost.power attribute to any positive floating-point number. It defines the power to which both sides of the inequality are raised, so that, on average, the cost is of the same order of magnitude as \(\rho\) raised to that power. See the power attribute in the example above.

References