Source code for skwdro.solvers.result

from collections.abc import Iterable
from functools import wraps


class SolverResult(Iterable):
    """
    Container for the values returned by a solver.

    The named results are exposed as attributes, and any extra keyword
    argument is attached as an attribute of the same name. Iterating over
    an instance yields the attributes listed in ``_iter_attrs``, in that
    order, so that legacy code can still unpack a result like a tuple.

    Parameters
    ----------
    coef : optional
        Stored as ``self.coef``.
    intercept : optional
        Stored as ``self.intercept``.
    dual_var : optional
        Stored as ``self.dual_var``.
    robust_loss : optional
        Stored as ``self.robust_loss``. It is not part of the default
        iteration order.
    _iter_attrs : iterable of str
        Names of the attributes yielded by ``__iter__``, in order.
    **kwargs
        Additional results, each set as an attribute on the instance.
    """

    def __init__(self, coef=None, intercept=None, dual_var=None,
                 robust_loss=None,
                 _iter_attrs=('coef', 'intercept', 'dual_var'),
                 **kwargs):
        self.coef = coef
        self.intercept = intercept
        self.dual_var = dual_var
        self.robust_loss = robust_loss
        # Keep a private copy: the default must not be shared between
        # instances, and a caller-owned list or dict view must not change
        # the iteration order after construction.
        self._iter_attrs = list(_iter_attrs)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __iter__(self):
        """
        Iterator for legacy behaviour
        """
        # Attributes are looked up lazily, as the iterator is consumed.
        return iter(getattr(self, attr) for attr in self._iter_attrs)
def wrap_solver_result(solver_func):
    """
    Decorator to wrap the return of a legacy solver

    The decorated solver always returns a ``SolverResult``:

    * a ``SolverResult`` is returned unchanged;
    * a tuple is mapped by position, depending on its length:
      ``(coef,)``, ``(coef, dual_var)``, ``(coef, intercept, dual_var)``
      or ``(coef, intercept, dual_var, robust_loss)``;
    * a dict is expanded into keyword arguments.

    For tuples and dicts, the resulting object iterates over exactly the
    fields that were provided, in the order they were provided.

    Raises
    ------
    AssertionError
        If the solver returns a tuple whose length is not between 1 and 4.
    RuntimeError
        If the solver returns anything other than a ``SolverResult``,
        a tuple or a dict.
    """
    # `wraps` copies the metadata of `solver_func` onto the wrapper, which
    # lets the solver's documentation go through.
    @wraps(solver_func)
    def wrapper(*args, **kwargs):
        legacy_res = solver_func(*args, **kwargs)

        if isinstance(legacy_res, SolverResult):
            return legacy_res

        if isinstance(legacy_res, tuple):
            n_fields = len(legacy_res)
            # Explicit raise rather than `assert`, so the check still runs
            # under `python -O`. The exception type is kept so that
            # existing handlers keep matching.
            if not 1 <= n_fields <= 4:
                raise AssertionError(
                    "Legacy solver tuples must hold between 1 and 4 "
                    "values, got %d" % n_fields
                )
            fields = {'coef': legacy_res[0]}
            if n_fields == 2:
                fields['dual_var'] = legacy_res[1]
            elif n_fields >= 3:
                fields['intercept'] = legacy_res[1]
                fields['dual_var'] = legacy_res[2]
            if n_fields == 4:
                fields['robust_loss'] = legacy_res[3]
            return SolverResult(**fields, _iter_attrs=list(fields))

        if isinstance(legacy_res, dict):
            # Snapshot the keys: a live `keys()` view would follow any
            # later mutation of the solver's dict.
            return SolverResult(**legacy_res, _iter_attrs=list(legacy_res))

        raise RuntimeError("Invalid return from solver")

    return wrapper