# Source code for skwdro.base.losses_torch.weber
import torch as pt
import torch.nn as nn
[docs]
class SimpleWeber(nn.Module):
    """Weighted Euclidean-distance loss around a single learnable point.

    The module owns one trainable vector ``pos`` of length ``d``
    (initialised to zeros). Its forward pass measures how far each input
    row lies from ``pos`` and scales that distance by the supplied labels.
    """

    # The forward pass returns one value per entry without aggregating them;
    # presumably this attribute advertises that to callers -- TODO confirm.
    reduction = 'none'

    def __init__(self, d: int) -> None:
        """Create the loss with a zero-initialised position of size ``d``.

        Parameters
        ----------
        d:
            Length of the learnable position vector ``pos``.
        """
        super().__init__()
        self.d = d
        self.pos = nn.Parameter(pt.zeros(d))

    def forward(self, xi: pt.Tensor, xi_labels: pt.Tensor) -> pt.Tensor:
        """Return the label-weighted distances between ``xi`` and ``pos``.

        Parameters
        ----------
        xi:
            Points whose last axis is compared against ``pos``
            (assumes that axis has length ``d`` -- TODO confirm with caller).
        xi_labels:
            Multiplicative weights, broadcast against the distances. Must
            have at least two dimensions because ``shape[1]`` is read.

        Returns
        -------
        pt.Tensor
            ``xi_labels * ||xi - pos||_2 * xi_labels.shape[1]``, where the
            norm is taken over the last axis and that axis is kept with
            size 1 before broadcasting.
        """
        # Leading singleton axis lets ``pos`` broadcast over the rows of xi.
        offsets = xi - self.pos.unsqueeze(0)
        distances = pt.linalg.norm(offsets, dim=-1, keepdim=True)

        # The result is additionally scaled by the size of the labels'
        # second axis.
        scale = xi_labels.shape[1]
        val = xi_labels * distances * scale

        # Narrowing for static type checkers only.
        assert isinstance(val, pt.Tensor)
        return val