Effect of the epsilon Sinkhorn regularization parameter
This example illustrates the use of the skwdro.linear_models.LogisticRegression class on datasets that are shifted at test time.
It uses this setting to study the (small) impact of the Sinkhorn regularization hyperparameter \(\varepsilon\) on classification accuracy.
import numpy as np
from sklearn.datasets import make_blobs
from skwdro.linear_models import LogisticRegression
from skwdro.solvers.optim_cond import OptCondTorch
from utils.classifier_comparison_utils import plot_classifier_comparison
Setup
n = 50 # Total number of samples
n_train = (3 * n) // 4 # Number of training samples
n_test = n - n_train # Number of test samples
sdevs = [(2.5, 5), (1, 5)]
# Fix centers for blobs dataset
pos = 4
centers = [np.array([-pos, -pos]), np.array([pos, pos])]
# Create datasets with variance that is shifted at test time
datasets = []
for (sdev_1, sdev_2) in sdevs:
    train_dataset = make_blobs(n_samples=n_train, centers=centers, cluster_std=(sdev_1, sdev_2)) # type: ignore
    test_dataset = make_blobs(n_samples=n_test, centers=centers, cluster_std=(sdev_2, sdev_1)) # type: ignore
    datasets.append((train_dataset, test_dataset))
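The shift built above consists of swapping the two cluster standard deviations between training and test time. The following is a minimal sketch that checks this empirically for the first pair in `sdevs`. It assumes a much larger sample size than the example (20000 instead of 50) and fixed `random_state` values, both chosen here only to make the estimates stable; the helper `per_class_std` is hypothetical and not part of the original script.

```python
import numpy as np
from sklearn.datasets import make_blobs

centers = [np.array([-4, -4]), np.array([4, 4])]
sdev_1, sdev_2 = 2.5, 5  # first pair from `sdevs`

# Many more samples than in the example, so the empirical spreads are stable.
X_train, y_train = make_blobs(n_samples=20000, centers=centers,
                              cluster_std=(sdev_1, sdev_2), random_state=0)
X_test, y_test = make_blobs(n_samples=20000, centers=centers,
                            cluster_std=(sdev_2, sdev_1), random_state=1)


def per_class_std(X, y):
    """Empirical standard deviation of each class, averaged over features."""
    return [float(X[y == k].std(axis=0).mean()) for k in (0, 1)]


train_std = per_class_std(X_train, y_train)
test_std = per_class_std(X_test, y_test)
print("train std per class:", np.round(train_std, 1))
print("test std per class: ", np.round(test_std, 1))
```

The class that is tightly clustered during training is the widely spread one at test time, and vice versa, while the centers stay fixed.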
WDRO classifiers
We build several SkWDRO estimators, one for each value of the regularization parameter \(\varepsilon\), alongside a baseline with rho=0.
# Rho chosen analytically
rho = 1e-0 # 2*4**2
# Entropic regularization: test various ones
e0, e1 = -3, 1
regs = np.logspace(e0, e1, base=10, num=5)
# Kappa: weight of label shift
kappa = 100000
# Cost:
# t: torch backend
# NLC: norm cost that takes labels into account
# 2 2 : squared 2-norm
# kappa: weight of label shift
cost = f"t-NLC-2-2-{kappa}"
# WDRO classifier
classifiers = [
    LogisticRegression(rho=0.),
    *(LogisticRegression(
        rho=rho,
        cost=cost,
        solver_reg=eps,
        n_zeta_samples=100,
        opt_cond=OptCondTorch(order='inf', tol_theta=1e-9, mode='abs')
    ) for eps in regs)
]
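To make the hyperparameters above more concrete, here is a small NumPy-only sketch. It first reproduces the \(\varepsilon\) grid, which contains the five values 1e-3, 1e-2, 1e-1, 1 and 10. It then evaluates a label-aware cost of the kind described in the comments. The function `label_aware_cost` is a hypothetical stand-in written for illustration: it assumes the cost is the squared 2-norm between features plus kappa times the squared label difference, which may differ in detail from what SkWDRO computes internally.

```python
import numpy as np

# Same grid as in the example: 5 values, log-spaced from 1e-3 to 1e1.
regs = np.logspace(-3, 1, base=10, num=5)
print(regs)

kappa = 100000


def label_aware_cost(x, y, x_prime, y_prime, kappa=kappa):
    """Hypothetical cost: squared 2-norm on features + kappa * squared label gap.

    This is an assumption based on the comments in the example,
    not a copy of the SkWDRO implementation.
    """
    feature_term = np.sum((np.asarray(x) - np.asarray(x_prime)) ** 2)
    label_term = kappa * abs(y - y_prime) ** 2
    return float(feature_term + label_term)


# Moving a point by (1, 1) while keeping its label.
same_label = label_aware_cost([0.0, 0.0], 1, [1.0, 1.0], 1)
# Keeping the point fixed but flipping its label from 1 to 0.
flipped = label_aware_cost([0.0, 0.0], 1, [0.0, 0.0], 0)
print(same_label, flipped)
```

Under this assumed form, a single label flip costs kappa, which is several orders of magnitude above rho = 1. This suggests that, with such a large kappa, the adversarial perturbations considered are essentially restricted to the features.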
Make plot
Observe that the accuracy changes with \(\varepsilon\).
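The plotting helper used by the example is not shown here, so the following is only a rough sketch of how train and test accuracy could be compared numerically under the same kind of shift. It uses scikit-learn's own `LogisticRegression` as a stand-in for the rho=0 baseline, not the SkWDRO class, and fixes `random_state` values for reproducibility; both are assumptions made for this illustration. Any estimator exposing `fit` and `score` could be added to the `estimators` dictionary.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression  # stand-in, NOT skwdro

centers = [np.array([-4, -4]), np.array([4, 4])]

# Train and test sets with swapped cluster spreads, sized as in the Setup
# section (n = 50, so 37 training samples and 13 test samples).
X_train, y_train = make_blobs(n_samples=37, centers=centers,
                              cluster_std=(2.5, 5), random_state=0)
X_test, y_test = make_blobs(n_samples=13, centers=centers,
                            cluster_std=(5, 2.5), random_state=1)

# Any scikit-learn-style estimator (fit/score) can be placed in this dict.
estimators = {"baseline": LogisticRegression()}

scores = {}
for name, clf in estimators.items():
    clf.fit(X_train, y_train)
    train_acc = clf.score(X_train, y_train)
    test_acc = clf.score(X_test, y_test)
    scores[name] = (train_acc, test_acc)
    print(f"{name}: train accuracy={train_acc:.3f}, test accuracy={test_acc:.3f}")
```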

100%|██████████| 6/6 [00:44<00:00, 7.44s/it, 0.492/0.946]
100%|██████████| 6/6 [00:45<00:00, 7.51s/it, 0.432/0.957]
Total running time of the script: (1 minutes 30.103 seconds)