Note
Go to the end to download the full example code.
Spatial perturbations and logistic regression
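This example illustrates the use of the `skwdro.linear_models.LogisticRegression` class on datasets whose distribution shifts at test time: the standard deviations of the two Gaussian blobs used for training are swapped when generating the test set.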
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs, make_moons
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from skwdro.linear_models import LogisticRegression
from utils.classifier_comparison import plot_classifier_comparison
Setup
n = 500  # Total number of samples
n_train = (3 * n) // 4  # Number of training samples (3/4 of the data)
n_test = n - n_train  # Number of test samples (the remainder)

# Per-cluster standard deviations (cluster 1, cluster 2) used for TRAINING.
# The pair is reversed for the TEST set, which creates the distribution shift.
sdevs = [(2.5, 5), (1, 5)]

# Fix centers for blobs dataset: two clusters at (-pos, -pos) and (pos, pos).
pos = 4
centers = [np.array([-pos, -pos]), np.array([pos, pos])]

# One shared, seeded generator makes the example reproducible. Each
# make_blobs call advances it, so the train and test sets get independent
# draws. Passing the same integer seed to both calls would correlate their noise.
data_rng = np.random.RandomState(0)

# Create datasets with variance that is shifted at test time:
# each entry of `datasets` is a (train_dataset, test_dataset) pair,
# where each dataset is the (X, y) tuple returned by make_blobs.
datasets = []
for sdev_1, sdev_2 in sdevs:
    train_dataset = make_blobs(n_samples=n_train, centers=centers, cluster_std=(sdev_1, sdev_2), random_state=data_rng)  # type: ignore
    test_dataset = make_blobs(n_samples=n_test, centers=centers, cluster_std=(sdev_2, sdev_1), random_state=data_rng)  # type: ignore
    datasets.append((train_dataset, test_dataset))
WDRO classifiers
# Radii (rho) of the Wasserstein ambiguity set; one classifier is built per value.
# rho = 0 presumably serves as the non-robust baseline.
# The second value is "chosen analytically".
# NOTE(review): the 4 in 2 * 4**2 likely mirrors the blob offset `pos` (= 4).
# Confirm this before changing either constant independently.
rhos = [0, 2 * 4**2]

# Kappa: weight of label shift in the transport cost.
kappa = 1000

# Cost specification, a dash-separated string:
#   t     : torch backend
#   NLC   : norm cost that takes labels into account
#   2-2   : squared 2-norm
#   kappa : weight of label shift (the value defined above)
cost = f"t-NLC-2-2-{kappa}"

# WDRO classifiers: one LogisticRegression per radius, all sharing the same cost.
classifiers = [LogisticRegression(rho=rho, cost=cost) for rho in rhos]
Make plot

0%| | 0/2 [00:00<?, ?it/s]
0%| | 0/2 [00:02<?, ?it/s, 0.292/0.292]
50%|█████ | 1/2 [00:02<00:02, 2.81s/it, 0.292/0.292]
50%|█████ | 1/2 [00:06<00:02, 2.81s/it, 0.693/0.292]
100%|██████████| 2/2 [00:06<00:00, 3.29s/it, 0.693/0.292]
100%|██████████| 2/2 [00:06<00:00, 3.22s/it, 0.693/0.292]
0%| | 0/2 [00:00<?, ?it/s]
0%| | 0/2 [00:00<?, ?it/s, 0.494/0.494]
50%|█████ | 1/2 [00:00<00:00, 1.34it/s, 0.494/0.494]
50%|█████ | 1/2 [00:09<00:00, 1.34it/s, 0.691/0.494]
100%|██████████| 2/2 [00:09<00:00, 5.57s/it, 0.691/0.494]
100%|██████████| 2/2 [00:09<00:00, 4.84s/it, 0.691/0.494]
Total running time of the script: (0 minutes 16.590 seconds)