Source code for skwdro.base.losses_torch.quadratic

from typing import Optional


import torch as pt
import torch.nn as nn

from .base_loss import Loss
from skwdro.base.samplers.torch.base_samplers import (
    BaseSampler
)
from skwdro.base.samplers.torch.classif_sampler import (
    ClassificationNormalNormalSampler
)
from skwdro.base.samplers.torch.base_samplers import LabeledSampler


[docs] class QuadraticLoss(Loss): def __init__( self, sampler: LabeledSampler, *, d: int = 0, l2reg: Optional[float] = None, fit_intercept: bool = False ) -> None: super(QuadraticLoss, self).__init__(sampler, True, l2reg=l2reg) assert d > 0, "Please provide a valid data dimension d>0" self.linear = nn.Linear(d, 1, bias=fit_intercept) self.L = nn.MSELoss(reduction='none')
[docs] def regression(self, X) -> pt.Tensor: coefs = self.linear(X) assert isinstance(coefs, pt.Tensor) return coefs
[docs] def value( self, xi: pt.Tensor, xi_labels: Optional[pt.Tensor] ) -> pt.Tensor: assert xi_labels is not None coefs = self.regression(xi) return self.regularize(self.L(coefs, xi_labels))
[docs] @classmethod def default_sampler( cls, xi, xi_labels, epsilon, seed: Optional[int] ) -> BaseSampler: return ClassificationNormalNormalSampler( xi, xi_labels, seed=seed, sigma=epsilon, l_sigma=epsilon )
@property def theta(self) -> pt.Tensor: return self.linear.weight @property def intercept(self) -> pt.Tensor: return self.linear.bias