import torch
from torch.optim import SGD, Adam
# Rule registries filled by the ``prerule`` / ``postrule`` decorators below.
# Each maps a param-group key (e.g. 'mwu') to a callable that modifies a
# single parameter tensor in place.  ``HybridOpt.step`` runs the matching
# pre-rule before the wrapped optimizer's update and the post-rule after it.
PRERULES = dict()
POSTRULES = dict()

# Elementwise gradient-clipping magnitude used by the 'bound' pre-rule.
BOUND = 100
[docs]
def prerule(name):
    """Build a decorator that registers a function as the pre-step rule ``name``.

    The decorated function is stored in ``PRERULES[name]`` (replacing any
    earlier rule registered under that name) and returned unchanged, so it
    can be stacked with other decorators such as ``postrule``.
    """
    def register(fn):
        PRERULES.update({name: fn})
        return fn

    return register
[docs]
def postrule(name):
    """Build a decorator that registers a function as the post-step rule ``name``.

    The decorated function is stored in ``POSTRULES[name]`` (replacing any
    earlier rule registered under that name) and returned unchanged, so it
    can be stacked with other decorators such as ``prerule``.
    """
    def register(fn):
        POSTRULES.update({name: fn})
        return fn

    return register
[docs]
@prerule('mwu')
def prerule_mwu(p):
    """Move ``p`` to log space, in place, before the optimizer update.

    Paired with the 'mwu' post-rule (which exponentiates back), a plain
    gradient step on ``log p`` -- e.g. SGD without momentum or weight decay --
    amounts to the multiplicative update ``p <- p * exp(-lr * grad)``.

    Raises:
        ValueError: if any entry of ``p`` is not strictly positive (its
            logarithm would be -inf or NaN).  An explicit raise is used
            instead of ``assert`` so the check survives ``python -O``.
    """
    if not (p > 0).all():
        raise ValueError("'mwu' parameters must be strictly positive")
    p.log_()
[docs]
@postrule('mwu')
def postrule_mwu(p):
    """Map ``p`` back from log space after the optimizer update (in place).

    Undoes the ``log_`` applied by the 'mwu' pre-rule.
    """
    p.exp_()
[docs]
@prerule('mwu_simplex')
def prerule_mwu_simplex(p):
    """Move ``p`` to log space, in place, before the optimizer update.

    Identical to the 'mwu' pre-rule; the matching 'mwu_simplex' post-rule
    additionally renormalises ``p`` to sum to one after exponentiating.

    Raises:
        ValueError: if any entry of ``p`` is not strictly positive (its
            logarithm would be -inf or NaN).  An explicit raise is used
            instead of ``assert`` so the check survives ``python -O``.
    """
    if not (p > 0).all():
        raise ValueError("'mwu_simplex' parameters must be strictly positive")
    p.log_()
[docs]
@postrule('mwu_simplex')
def postrule_mwu_simplex(p):
    """Exponentiate ``p`` back from log space, then rescale it to sum to one.

    Both steps modify ``p`` in place.  The division is by the sum over all
    entries of ``p``, i.e. the whole tensor is projected onto one simplex.
    """
    p.exp_()
    p.div_(p.sum())
[docs]
@postrule('non_neg')
def postrule_non_neg(p):
    """Project ``p`` onto the non-negative orthant by zeroing negative entries (in place)."""
    p.clamp_(min=0)
[docs]
@prerule('max')
@postrule('max')
def rule_max(p):
    """Flip the sign of ``p`` in place; registered as both 'max' pre- and post-rule.

    Negating before and after a descent step turns it into an ascent step on
    the original parameter, so groups tagged 'max' are maximised.
    """
    p.neg_()
[docs]
@prerule('bound')
def prerule_bound(p):
    """Clip the gradient of ``p`` elementwise to ``[-BOUND, BOUND]``, in place.

    Runs right before the optimizer update.  Parameters without a gradient
    (``p.grad is None``, e.g. not used in the last backward pass) are left
    alone: the previous code crashed on ``None.clip_``, while torch's
    SGD/Adam simply skip such parameters during their update.
    """
    grad = p.grad
    if grad is None:
        return
    grad.clip_(-BOUND, BOUND)
[docs]
class HybridOpt(object):
    """Mixin that wraps an optimizer's ``step`` with per-param-group rules.

    List it *before* a concrete ``torch.optim`` optimizer in the bases
    (e.g. ``class HybridSGD(HybridOpt, SGD)``) so the cooperative ``super()``
    calls reach that optimizer.

    A param group opts into a rule by containing a key registered in
    ``PRERULES`` / ``POSTRULES`` (e.g. ``{'params': ..., 'mwu': True}``).
    Only the key's presence matters; its value is not inspected.  Each
    group may match at most one rule per registry.
    """

    def __init__(self, params, *args, **kwargs):
        # Forward positional hyperparameters as well (``HybridSGD(ps, 0.1)``);
        # accepting only keywords made the positional form a TypeError.
        super(HybridOpt, self).__init__(params, *args, **kwargs)

    def _apply_rules(self, rules):
        """Apply, in place and without autograd tracking, each group's matching rule.

        Args:
            rules: mapping from a param-group key to a callable that takes a
                single parameter tensor and modifies it in place.

        Raises:
            ValueError: if a param group contains more than one key of
                ``rules`` (an explicit raise, so the check survives -O).
        """
        with torch.no_grad():
            for group in self.param_groups:
                matched = group.keys() & rules.keys()
                if not matched:
                    continue
                if len(matched) > 1:
                    raise ValueError(
                        "param group matches more than one rule: %s"
                        % sorted(matched)
                    )
                (key,) = matched
                rule = rules[key]
                for p in group['params']:
                    rule(p)

    def step(self, closure=None, **kwargs):
        """Run pre-rules, the wrapped optimizer's update, then post-rules.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.  It is evaluated here, *before* the pre-rules run,
                so the forward/backward pass sees the parameters in their
                natural space rather than e.g. log space ('mwu') or negated
                ('max').
            **kwargs: forwarded to the wrapped optimizer's ``step``.

        Returns:
            The closure's loss, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # Torch's SGD/Adam evaluate the closure with grad enabled too.
            with torch.enable_grad():
                loss = closure()
        self._apply_rules(PRERULES)
        super(HybridOpt, self).step(**kwargs)
        self._apply_rules(POSTRULES)
        return loss
[docs]
class HybridSGD(HybridOpt, SGD):
    """``torch.optim.SGD`` combined with the per-param-group rules of ``HybridOpt``.

    Every constructor argument is forwarded unchanged along the cooperative
    ``__init__`` chain (``HybridOpt`` first, then ``SGD``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
[docs]
class HybridAdam(HybridOpt, Adam):
    """``torch.optim.Adam`` combined with the per-param-group rules of ``HybridOpt``.

    Every constructor argument is forwarded unchanged along the cooperative
    ``__init__`` chain (``HybridOpt`` first, then ``Adam``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)