r"""
Simple Neural Network
=====================

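This example solves a simple binary classification problem using a small
fully connected neural network with two hidden layers. Two such networks are
trained: one with plain empirical risk minimization (ERM) and one with a
Wasserstein distributionally robust (WDRO) loss built by skwdro's
``robustify``.

The classification problem is generated by the ``make_moons`` dataset
generator from scikit-learn.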


"""
import matplotlib.pyplot as plt
from utils.plotting import plot_decision_boundary
from tqdm.auto import tqdm
import torch as pt
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import train_test_split

from skwdro.torch import robustify


# %%
# Problem setup
# ~~~~~~~~~~~~~

from sklearn.datasets import make_moons

# Total number of points: 256 for training plus 64 for testing.
n = 256 + 64

# Two interleaving half-moons with a little Gaussian noise; fixed seed.
X, y = make_moons(n_samples=n, noise=0.05, random_state=42)


# %%
# Visualize the data
# ~~~~~~~~~~~~~~~~~~

# Scatter the two input features, colouring each point by its class label.
plt.scatter(*X.T, c=y, cmap="RdYlBu")
plt.show()

# %%
# Preprocessing
# ~~~~~~~~~~~~~

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    train_size=256,
    test_size=64,
    random_state=42
)

device = "cuda" if pt.cuda.is_available() else "cpu"
# Seed torch's RNG so that results vary less from one run to the next
pt.manual_seed(42)

# Turn the training data into tensors. Labels become a column vector
# (unsqueeze(-1)) and are cast to the same dtype/device as the inputs via
# ``.to(tensor)``.
full_batch_x = pt.from_numpy(X_train).to(device)
full_batch_y = pt.from_numpy(y_train).unsqueeze(-1).to(full_batch_x)
# Note: despite its name, ``dataset`` is a DataLoader yielding
# (x, y) mini-batches of size 64 in a fixed order (no shuffling).
dataset = DataLoader(
    TensorDataset(
        full_batch_x,
        full_batch_y
    ),
    batch_size=64
)

# Held-out test tensors. They must come from the *test* split: building them
# from X_train/y_train would silently turn every "test" metric and plot into
# a training one.
batch_x_test = pt.from_numpy(X_test).to(device)
batch_y_test = pt.from_numpy(y_test).unsqueeze(-1).to(batch_x_test)


# %%
# Two-hidden-layer model
# ~~~~~~~~~~~~~~~~~~~~~~

class SimpleNN(nn.Module):
    """Fully connected network with two ReLU hidden layers and a linear head.

    Args:
        in_features: size of each input vector.
        out_features: size of each output vector (raw logits, no activation).
        hidden_units: width of both hidden layers.
    """

    def __init__(self, in_features, out_features, hidden_units):
        super().__init__()
        layers = [
            nn.Linear(in_features, hidden_units), nn.ReLU(),
            nn.Linear(hidden_units, hidden_units), nn.ReLU(),
            # No final activation: outputs are logits, to be paired with a
            # logits-based loss such as BCEWithLogitsLoss.
            nn.Linear(hidden_units, out_features),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits computed from input batch ``x``."""
        return self.linear_relu_stack(x)


# Binary cross-entropy on raw logits; reduction='none' returns one loss value
# per element instead of averaging them.
loss_fn = nn.BCEWithLogitsLoss(reduction='none')

# First mini-batch of the loader, used as an example batch for initialization.
first_batch = next(iter(dataset))
sample_batch_x, sample_batch_y = first_batch

# %%
# Set the models up
# ~~~~~~~~~~~~~~~~~
# First build the ERM model solving:
#
# .. math::
#    \min_\theta \frac{1}{N}\sum_{i=1}^NL_\theta(\xi_i)
#

erm_model = SimpleNN(in_features=2, out_features=1, hidden_units=10)
# Match the data's dtype and device (Module.to returns the module itself).
erm_model = erm_model.to(full_batch_x)
print(erm_model)

# ERM loss built through robustify so that it exposes the same interface as
# the robust one. The third argument is zero here; presumably it is the
# robustness radius (rho) — confirm against the skwdro docs.
erm_loss = robustify(
    loss_fn,
    erm_model,
    pt.tensor(0.),
    sample_batch_x,
    sample_batch_y,
)

# %%
# Then build the robust model, which solves:
#
# .. math::
#    \min_{\theta, \lambda\ge 0} \rho\lambda+\frac{1}{N}\sum_{i=1}^N\texttt{LogSumExp}_\varepsilon \{L_\theta(\cdot)-\lambda\|\cdot-\xi_i\|_2^2\}
#

# NOTE(review): this network is 10x wider than the ERM one (hidden_units=10),
# so the ERM vs. WDRO comparison also mixes in a capacity difference —
# confirm this is intended.
robust_model = SimpleNN(in_features=2, out_features=1, hidden_units=100)
# Match the data's dtype and device (Module.to returns the module itself).
robust_model = robust_model.to(full_batch_x)

print(robust_model)

# Robust loss: robustify wraps loss_fn and the model into the dual WDRO loss.
# The third argument is presumably the radius rho (confirm against the skwdro
# docs); cost_spec and n_samples are forwarded as skwdro options.
robust_loss = robustify(
    loss_fn,
    robust_model,
    pt.tensor(1e-3),
    sample_batch_x,
    sample_batch_y,
    cost_spec="t-NLC-2-2",
    n_samples=16,
)

# %%
# Training loop
# ~~~~~~~~~~~~~

def _evaluate_on_test(model):
    """Return ``(mean test loss, test accuracy)`` for the network inside ``model``.

    ``model.erm_mode`` is set to True for the forward pass and is always set
    back to False afterwards, even if the forward pass raises.
    Reads the module-level ``batch_x_test``, ``batch_y_test`` and ``loss_fn``.
    """
    model.eval()
    with pt.no_grad():
        model.erm_mode = True
        try:
            test_logits = model.primal_loss.transform(batch_x_test)
        finally:
            model.erm_mode = False
        test_loss = loss_fn(test_logits, batch_y_test).mean().item()
        # Threshold the predicted probabilities at 0.5 to get hard labels
        test_pred = pt.round(pt.sigmoid(test_logits))
        test_acc = (test_pred == batch_y_test).float().mean().item()
    return test_loss, test_acc


def train(model, epochs=300, lr=5e-3):
    """Optimize a robustified loss module and record its test loss per epoch.

    Args:
        model: loss module as returned by ``robustify``. It is called as
            ``model(x, y, reset_sampler=True)`` and must return a scalar
            training loss. It must also expose ``erm_mode`` and
            ``primal_loss.transform``, which are used for evaluation.
        epochs: number of passes over the module-level ``dataset`` loader.
        lr: AdamW learning rate.

    Returns:
        List with one float per epoch. Each value is the test loss measured
        after every optimizer step, averaged over that epoch's mini-batches.
    """
    optimizer = pt.optim.AdamW(params=model.parameters(), lr=lr)

    epoch_iterator = tqdm(range(epochs), position=0, desc='Epochs', leave=True)
    losses = []
    for _ in epoch_iterator:
        total_testloss = 0.
        test_acc = float('nan')
        for batch_x, batch_y in tqdm(dataset, position=1, desc='Sample', leave=False):
            # ## Training
            model.train()
            optimizer.zero_grad()
            loss = model(batch_x, batch_y, reset_sampler=True)
            loss.backward()
            optimizer.step()

            # ## Testing (after every step, averaged over the epoch below)
            batch_testloss, test_acc = _evaluate_on_test(model)
            total_testloss += batch_testloss

        avg_testloss = total_testloss / len(dataset)
        losses.append(avg_testloss)

        # set_postfix replaces the whole postfix, so report both metrics at once
        epoch_iterator.set_postfix({'loss': avg_testloss, 'acc': f"{test_acc:.1%}"})
    return losses


# %%
# Learn the ERM model
# ~~~~~~~~~~~~~~~~~~~
erm_losses = train(model=erm_loss)

# %%
# Learn the SkWDRO model
# ~~~~~~~~~~~~~~~~~~~~~~
dro_losses = train(model=robust_loss)

# %%
# Visuals
# ~~~~~~~
# First, the ERM:

# Plot decision boundaries for training and test sets
def plot_results(model, losses, title=None):
    """Plot a trained model's decision boundaries and its test-loss curve.

    The top row holds two panels: the decision boundary on the training set
    (left) and on the test set (right), both drawn with
    ``plot_decision_boundary``. A single wide panel at the bottom shows the
    per-epoch test losses on a log scale.

    Args:
        model: the trained network (e.g. ``erm_model`` or ``robust_model``).
        losses: sequence of per-epoch test losses, as returned by ``train``.
        title: optional figure-level title.
    """
    plt.figure(figsize=(12, 6))
    if title is not None:
        plt.suptitle(title)

    plt.subplot(2, 2, 1)
    plt.title("Train")
    plot_decision_boundary(model, full_batch_x, full_batch_y)

    plt.subplot(2, 2, 2)
    plt.title("Test")
    plot_decision_boundary(model, batch_x_test, batch_y_test)

    plt.subplot(2, 1, 2)
    plt.title("Test loss through epochs")
    plt.plot(losses)
    plt.xlabel("Epoch")
    plt.ylabel("Test loss")
    plt.yscale('log')

    plt.show()


plot_results(erm_model, erm_losses, title="ERM")

# %%
# Then the DRO model:

# Plot decision boundaries for training and test sets
plot_results(robust_model, dro_losses, title="WDRO")
