Source code for skwdro.solvers.utils

from typing import Union, Optional, Tuple
import numpy as np
import torch as pt


Steps = Union[int, Tuple[int, int]]


[docs] def detach_tensor(tensor: pt.Tensor) -> np.ndarray: out = tensor.detach().cpu().numpy().flatten() assert isinstance(out, np.ndarray) return out # float(out) if len(out) == 1 else out
[docs] def maybe_detach_tensor(tensor: Optional[pt.Tensor]) -> Optional[np.ndarray]: return None if tensor is None else detach_tensor(tensor)
[docs] def diff_opt_tensor( tensor: Optional[pt.Tensor], us_dim: Optional[int] = 0 ) -> Optional[pt.Tensor]: if tensor is None: return None else: return diff_tensor(tensor, us_dim)
[docs] def diff_tensor(tensor: pt.Tensor, us_dim: Optional[int] = 0) -> pt.Tensor: if us_dim is not None: return tensor.clone().unsqueeze(us_dim).requires_grad_(True) else: return tensor.clone().requires_grad_(True)
[docs] def maybe_unsqueeze( tensor: Optional[pt.Tensor], dim: int = 0 ) -> Optional[pt.Tensor]: return None if tensor is None else tensor.unsqueeze(dim)
[docs] def normalize_maybe_vects( tensor: Optional[pt.Tensor], threshold: float = 1., scaling: float = 1., dim: int = 0 ) -> Optional[pt.Tensor]: return None if tensor is None else normalize_just_vects( tensor, threshold, scaling, dim )
[docs] def normalize_just_vects( tensor: pt.Tensor, threshold: float = 1., scaling: float = 1., dim: int = 0 ) -> pt.Tensor: n = pt.linalg.norm(tensor, dim=dim, keepdims=True) assert isinstance(n, pt.Tensor) return tensor / n * pt.min(pt.tensor(threshold), n) / scaling
[docs] class NoneGradError(ValueError): pass
[docs] def maybe_flatten_grad_else_raise(tensor: pt.Tensor) -> pt.Tensor: if tensor.grad is None: raise NoneGradError(tensor.shape) else: return tensor.grad.flatten()
[docs] def check_tensor_validity(tensor: pt.Tensor) -> bool: return bool(tensor.isfinite().logical_not().any().item())
[docs] def interpret_steps_struct( steps_spec: Steps, default_split: float = .3 ) -> Tuple[int, int]: if isinstance(steps_spec, int): assert 0 <= default_split <= 1. pretrain_iters = int(steps_spec * default_split) return pretrain_iters, steps_spec else: # already tuple return steps_spec