r"""
Simple Neural Network
=====================

This example solves a simple binary classification problem with a basic
neural network (two hidden layers), trained with a Wasserstein
distributionally robust (WDRO) loss built by ``skwdro``.

The classification data is generated with the ``make_moons`` dataset
generator from scikit-learn.


"""
import matplotlib.pyplot as plt
from utils.plotting import plot_decision_boundary
from tqdm import tqdm
import torch as pt
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import train_test_split

from skwdro.torch import robustify


# %%
# Problem setup
# ~~~~~~~~~~~~~
#
# Draw a noisy two-class "two moons" dataset; 512 points are later used for
# training and the remaining 64 are held out for testing.

from sklearn.datasets import make_moons

n = 512 + 64  # train + test sample count

X, y = make_moons(n_samples=n, noise=0.05, random_state=42)


# %%
# Visualize the data
# ~~~~~~~~~~~~~~~~~~

# One row of X per point: unpack its two coordinate columns.
feature_0, feature_1 = X.T
plt.scatter(feature_0, feature_1, c=y, cmap=plt.cm.RdYlBu)  # type: ignore
plt.show()

# %%
# Preprocessing
# ~~~~~~~~~~~~~

# Split the data into a 512-sample train set and a 64-sample test set.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    train_size=512,
    test_size=64,
    random_state=42
)

device = "cuda" if pt.cuda.is_available() else "cpu"

# Turn the training data into tensors. Labels get a trailing singleton axis
# so they line up with the model's single-logit output, and ``.to(tensor)``
# casts them to the features' dtype and device (BCE targets must be floats).
full_batch_x = pt.from_numpy(X_train).to(device)
full_batch_y = pt.from_numpy(y_train).unsqueeze(-1).to(full_batch_x)
dataset = DataLoader(
    TensorDataset(
        full_batch_x,
        full_batch_y
    ),
    batch_size=64
)

# Held-out test tensors. These must come from the *test* split: building
# them from the training split would turn every "test" metric and plot into
# a training-set measurement.
batch_x_test = pt.from_numpy(X_test).to(device)
batch_y_test = pt.from_numpy(y_test).unsqueeze(-1).to(batch_x_test)


# %%
# Two-hidden-layer model
# ~~~~~~~~~~~~~~~~~~~~~~

class SimpleNN(nn.Module):
    """Fully connected ReLU network that returns raw logits.

    Architecture: ``in_features -> hidden_units -> hidden_units ->
    out_features``, with a ReLU after each of the two hidden layers and no
    activation on the output layer (pair it with a logits-based loss such as
    ``BCEWithLogitsLoss``).

    Parameters
    ----------
    in_features : int
        Size of each input sample.
    out_features : int
        Number of output logits.
    hidden_units : int
        Width of both hidden layers.
    """

    def __init__(self, in_features, out_features, hidden_units):
        super().__init__()
        widths = [in_features, hidden_units, hidden_units, out_features]
        shapes = list(zip(widths[:-1], widths[1:]))

        layers = []
        for fan_in, fan_out in shapes[:-1]:
            # Hidden layer: affine map followed by a ReLU.
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output layer: plain affine map, i.e. logits.
        layers.append(nn.Linear(*shapes[-1]))

        # The attribute name determines the state_dict keys; keep it stable.
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of inputs ``x`` to unnormalized logits."""
        return self.linear_relu_stack(x)


# %%
# Set the model up
# ~~~~~~~~~~~~~~~~

# Seed *before* building the model so the initial weights are reproducible;
# a seed set only right before training would leave the initialization random.
pt.manual_seed(42)

model = SimpleNN(
    in_features=2,
    out_features=1,
    hidden_units=5
).to(full_batch_x)  # match the data's dtype and device


print(model)


# Per-sample losses (no reduction). Where a scalar is needed later, the mean
# is taken explicitly; the robust wrapper presumably aggregates the
# unreduced values itself — confirm against the skwdro documentation.
loss_fn = nn.BCEWithLogitsLoss(reduction='none')

# Define a sample batch for initialization
sample_batch_x, sample_batch_y = next(iter(dataset))


# Robust loss
robust_loss = robustify(
    loss_fn,
    model,
    pt.tensor(1e-4),  # presumably the Wasserstein radius rho — confirm
    sample_batch_x, sample_batch_y,
    cost_spec="t-NC-2-2",
    n_samples=16
)  # Replaces the loss of the model by the dual WDRO loss

# %%
# Training loop
# ~~~~~~~~~~~~~

pt.manual_seed(42)
epochs = 250

# Optimize the parameters of the robust loss, which wraps ``model``. For plain
# (non-robust) training one would instead optimize ``model.parameters()``
# against ``loss_fn`` directly.
optimizer = pt.optim.AdamW(params=robust_loss.parameters())


# Training loop
iterator = tqdm(range(epochs), position=0, desc='Epochs', leave=False)
losses = []  # robust training loss, one entry per mini-batch
for epoch in iterator:

    # ## Training
    model.train()
    for batch_x, batch_y in tqdm(dataset, position=1, desc='Sample', leave=False):
        optimizer.zero_grad()
        # Robust (dual WDRO) objective in place of the plain
        # ``loss_fn(model(batch_x), batch_y)``.
        loss = robust_loss(batch_x, batch_y, reset_sampler=True)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # ## Testing — once per epoch, after all of the epoch's updates.
    model.eval()
    with pt.no_grad():
        test_logits = model(batch_x_test)
        test_pred = pt.round(pt.sigmoid(test_logits))
        test_loss = loss_fn(test_logits, batch_y_test).mean().item()
        test_acc = (test_pred == batch_y_test).float().mean().item() * 100

    # Report both metrics in one call: separate ``set_postfix`` calls replace
    # each other, which previously hid the accuracy.
    iterator.set_postfix({'acc': f"{test_acc:.1f}%", 'loss': test_loss})


# %%
# Visuals
# ~~~~~~~

# Plot decision boundaries for training and test sets
plt.figure(figsize=(12, 6))

plt.subplot(2, 2, 1)
plt.title("Train")
plot_decision_boundary(model, full_batch_x, full_batch_y)

plt.subplot(2, 2, 2)
plt.title("Test")
plot_decision_boundary(model, batch_x_test, batch_y_test)

# ``losses`` holds the robust *training* loss recorded after every
# mini-batch, so label the curve accordingly.
plt.subplot(2, 1, 2)
plt.title("Robust training loss per mini-batch")
plt.plot(losses)
plt.xlabel("Mini-batch")
plt.ylabel("Loss")
plt.yscale('log')

plt.tight_layout()
plt.show()
