jittor.optim source code

# ***************************************************************
# Copyright (c) 2022 Jittor. All Rights Reserved. 
# Maintainers:
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Wenyang Zhou <576825820@qq.com>
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
# 
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np

class Optimizer(object):
    """ Basic class of Optimizer.

    Example::

        optimizer = nn.SGD(model.parameters(), lr)
        optimizer.step(loss)
    """
    def __init__(self, params, lr, param_sync_iter=10000):
        self.param_groups = []
        self.lr = lr
        self.param_sync_iter = param_sync_iter

        assert len(params) > 0, "Length of parameters should not be zero"

        if not isinstance(params[0], dict):
            params = [{'params': params}]
        for pg in params:
            assert isinstance(pg, dict)
            self.param_groups.append(pg)
        self.n_step = 0
        # __zero_grad is a flag for quickly determining whether the grad is zero or not,
        # so we can omit 0+x
        self.__zero_grad = True
        self._grad_map = {}

    def add_param_group(self, group):
        self.param_groups.append(group)

    def clip_grad_norm(self, max_norm:float, norm_type:int=2):
        r"""Clips gradient norm of this optimizer.

        The norm is computed over all gradients together.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (int): 1-norm or 2-norm

        Example::

            a = jt.ones(2)
            opt = jt.optim.SGD([a], 0.1)
            loss = a*a
            opt.zero_grad()
            opt.backward(loss)
            print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83
            opt.clip_grad_norm(0.01, 2)
            print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01
            opt.step()
        """
        if self.__zero_grad: return
        grads = []
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                grads.append(g.flatten())
        if len(grads) == 0: return
        total_norm = jt.norm(jt.concat(grads), norm_type)
        clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0)
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                g.update(g*clip_coef)

    @property
    def defaults(self):
        exclude = set(("defaults", "param_groups", "n_step", "pre_step", "step"))
        return { k:v for k, v in self.__dict__.items()
                 if k[0] != '_' and k not in exclude and not callable(v) }

    def state_dict(self):
        state = {"defaults": self.defaults}
        return state

    def load_state_dict(self, state):
        for k, v in state["defaults"].items():
            setattr(self, k, v)

    def zero_grad(self):
        self.__zero_grad = True

    def backward(self, loss, retain_graph=False):
        '''
        optimizer.backward(loss) is used to accumulate gradients over multiple steps.
        It can be used as follows:

        Original source code::

            n_iter = 10000
            batch_size = 100
            ...
            for i in range(n_iter):
                ...
                loss = calc_loss()
                optimizer.step(loss)

        Accumulation version::

            n_iter = 10000
            batch_size = 100
            accumulation_steps = 10
            n_iter *= accumulation_steps
            batch_size //= accumulation_steps
            ...
            for i in range(n_iter):
                ...
                loss = calc_loss()
                # if loss is a mean across the batch, we need to divide it by accumulation_steps
                optimizer.backward(loss / accumulation_steps)
                if (i+1) % accumulation_steps == 0:
                    optimizer.step()
        '''
        # clean prev grads
        params = []
        params_has_grad = []
        for pg in self.param_groups:
            for p in pg['params']:
                params.append(p)
                if not p.is_stop_grad():
                    params_has_grad.append(p)

        # sync prev params
        jt.sync(params_has_grad)

        # get gradient
        grads = jt.grad(loss, params_has_grad, retain_graph)

        # sync grads and model if in mpi
        if jt.in_mpi:
            dep = []
            def add_dep(v):
                nonlocal dep
                v._add_dependency(dep)
                dep = [v]

            for g in grads:
                g.assign(g.mpi_all_reduce("mean"))
                add_dep(g._input(0))
            if self.n_step % self.param_sync_iter == 0:
                for p in params:
                    p.assign(p.mpi_broadcast())
                    add_dep(p)
        self.n_step += 1

        # set up grads in param_groups
        pid = 0
        for pg in self.param_groups:
            if "grads" not in pg:
                pg["grads"] = [ jt.zeros_like(p).stop_grad().stop_fuse() for p in pg['params'] ]
            pg_grads = pg["grads"]
            for i, p in enumerate(pg['params']):
                if not p.is_stop_grad():
                    # accumulate grad and stop grad of grad
                    g = grads[pid].stop_grad()
                    if not self.__zero_grad:
                        g = g + pg_grads[i]
                    pg_grads[i].update(g)
                    pid += 1
        self.__zero_grad = False

    def pre_step(self, loss, retain_graph=False):
        """ Things that should be done before the step, such as calculating gradients, MPI sync, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
                    self.post_step()
        """
        if loss is not None:
            self.backward(loss, retain_graph)
        jt.flags.node_order = 1

    def post_step(self):
        """ Things that should be done after the step, such as zeroing grads, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
                    self.post_step()
        """
        jt.flags.node_order = 0
        self.zero_grad()

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                p.update(p - g * lr)
        self.post_step()

    def _build_grad_map(self):
        _grad_map = {}
        for pg in self.param_groups:
            for p, g in zip(pg["params"], pg["grads"]):
                _grad_map[id(p)] = g
        self._grad_map = _grad_map

    def find_grad(self, v:jt.Var) -> jt.Var:
        if id(v) not in self._grad_map:
            self._build_grad_map()
            if id(v) not in self._grad_map:
                raise RuntimeError("This variable is not managed by this optimizer")
        return self._grad_map[id(v)]

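# --- Illustrative sketch (not part of the original source) ---
# A concrete, runnable version of the MyOptimizer skeleton shown in the
# pre_step/post_step docstrings above. The class name _PlainGD is a
# hypothetical example; its update rule is the same plain gradient descent
# performed by Optimizer.step.
class _PlainGD(Optimizer):
    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)   # backward pass, MPI sync, fills pg["grads"]
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                p.update(p - lr * g)        # plain gradient-descent update
        self.post_step()                    # reset node_order and zero grads
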
def opt_grad(v:jt.Var, opt:Optimizer):
    ''' Get the grad of a certain variable in the optimizer.

    Example::

        model = Model()
        optimizer = SGD(model.parameters())
        ...
        optimizer.backward(loss)

        for p in model.parameters():
            grad = p.opt_grad(optimizer)
    '''
    return opt.find_grad(v)

jt.Var.opt_grad = opt_grad
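
# --- Illustrative sketch (not part of the original source) ---
# Minimal end-to-end use of backward() and opt_grad(): two backward() calls
# accumulate gradients, the accumulated grad is read back through
# jt.Var.opt_grad, and step() applies the update. Variable names are placeholders.
def _example_opt_grad():
    a = jt.ones(2)
    opt = Optimizer([a], lr=0.1)
    for _ in range(2):
        loss = (a * a).sum()
        opt.backward(loss)           # accumulates into param_groups[0]["grads"]
    print(a.opt_grad(opt).numpy())   # accumulated gradient: 2*a twice -> [4. 4.]
    opt.step()                       # applies the update, then zeroes grads
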
class SGD(Optimizer):
    """ SGD Optimizer.

    Example::

        optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
        optimizer.step(loss)
    """
    def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
        super().__init__(params, lr)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

        # initialize required arguments
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        jt.flags.node_order = 1
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            momentum = pg.get("momentum", self.momentum)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            dampening = pg.get("dampening", self.dampening)
            nesterov = pg.get("nesterov", self.nesterov)

            # optimize main body
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                dp = p * weight_decay + g
                v.update(momentum * v + dp * (1 - dampening))
                if nesterov:
                    p.update(p - (dp + momentum * v) * lr)
                else:
                    p.update(p - v * lr)
        self.post_step()

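# --- Illustrative sketch (not part of the original source) ---
# Toy usage of the momentum SGD implemented above. The update performed by
# step() is: dp = g + weight_decay*p, v <- momentum*v + (1-dampening)*dp,
# then p <- p - lr*v (or p <- p - lr*(dp + momentum*v) when nesterov=True).
def _example_sgd():
    w = jt.array([1.0, -2.0])
    opt = SGD([w], lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
    for _ in range(3):
        loss = (w * w).sum()
        opt.step(loss)
    print(w.numpy())
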
class RMSprop(Optimizer):
    """ RMSprop Optimizer.

    Args:
        params(list): parameters of model.
        lr(float): learning rate.
        eps(float): term added to the denominator to avoid division by zero, default 1e-8.
        alpha(float): smoothing constant, default 0.99.

    Example::

        optimizer = nn.RMSprop(model.parameters(), lr)
        optimizer.step(loss)
    """
    def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
        super().__init__(params, lr)
        self.eps = eps
        self.alpha = alpha

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            alpha = pg.get("alpha", self.alpha)
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                v.update(alpha * v + (1-alpha) * g * g)
                p.update(p - lr * g / (jt.sqrt(v) + eps))
        self.post_step()

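# --- Illustrative sketch (not part of the original source) ---
# RMSprop keeps a running average of squared gradients,
# v <- alpha*v + (1-alpha)*g*g, and scales each step by 1/(sqrt(v)+eps).
# Toy usage on a placeholder parameter:
def _example_rmsprop():
    w = jt.array([1.0, 2.0])
    opt = RMSprop([w], lr=0.01, alpha=0.99, eps=1e-8)
    for _ in range(5):
        loss = (w * w).sum()
        opt.step(loss)
    print(w.numpy())
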
class Adam(Optimizer):
    """ Adam Optimizer.

    Example::

        optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
        optimizer.step(loss)
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        # assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        m = group["m"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        jt.flags.node_order = 1
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                g = p * weight_decay + g
                m.update(b0 * m + (1-b0) * g)
                v.update(b1 * v + (1-b1) * g * g)
                step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
                p.update(p - m * step_size / (jt.sqrt(v) + eps))
        self.post_step()

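# --- Illustrative sketch (not part of the original source) ---
# Adam combines first and second moment estimates (m, v) with bias correction:
# step_size = lr * sqrt(1 - b1**n) / (1 - b0**n), where n is n_step, the number
# of backward() calls so far, so early steps are rescaled correctly. Note that
# weight_decay here is folded into the gradient (L2-style), unlike AdamW below.
def _example_adam():
    w = jt.array([0.5, -0.5])
    opt = Adam([w], lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
    for _ in range(3):
        loss = (w * w).sum()
        opt.step(loss)
    print(w.numpy())
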
class AdamW(Optimizer):
    """ AdamW Optimizer.

    Example::

        optimizer = nn.AdamW(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
        optimizer.step(loss)
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        # assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

    def add_param_group(self, group):
        values = group["values"] = []
        m = group["m"] = []
        for p in group["params"]:
            values.append(jt.zeros(p.shape, p.dtype).stop_grad())
            m.append(jt.zeros(p.shape, p.dtype).stop_grad())
        self.param_groups.append(group)

    def step(self, loss=None, retain_graph=False):
        self.pre_step(loss, retain_graph)
        n = float(self.n_step)
        for pg in self.param_groups:
            # get arguments from each param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                p.update(p * (1 - lr * weight_decay))
                bias_correction1 = 1 - b0 ** n
                bias_correction2 = 1 - b1 ** n
                m.update(b0 * m + (1-b0) * g)        # exp_avg
                v.update(b1 * v + (1-b1) * g * g)    # exp_avg_sq
                denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps
                step_size = lr / bias_correction1
                p.update(p - step_size * m / denom)
        self.post_step()

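# --- Illustrative sketch (not part of the original source) ---
# AdamW decouples weight decay from the gradient: the parameter is first shrunk
# by (1 - lr*weight_decay), then the bias-corrected moment update is applied,
# whereas Adam above folds the decay into the gradient (g = p*weight_decay + g).
# Toy usage on a placeholder parameter:
def _example_adamw():
    w = jt.array([0.5, -0.5])
    opt = AdamW([w], lr=1e-3, weight_decay=0.01)
    for _ in range(3):
        loss = (w * w).sum()
        opt.step(loss)
    print(w.numpy())
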
class LRScheduler:
    def __init__(self, optimizer, last_epoch=-1):
        assert isinstance(optimizer, Optimizer)
        self.optimizer = optimizer

        if last_epoch == -1:
            for gp in optimizer.param_groups:
                gp.setdefault('initial_lr', gp.get('lr', optimizer.lr))
        else:
            for gp in optimizer.param_groups:
                assert 'initial_lr' in gp

        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.last_epoch = last_epoch
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def get_lr(self):
        raise NotImplementedError

    def get_last_lr(self):
        return self._last_lr

    def step(self, epoch=None):
        self._step_count += 1
        if epoch is None:
            self.last_epoch += 1
            values = self.get_lr()
        else:
            self.last_epoch = epoch
            values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

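# --- Illustrative sketch (not part of the original source) ---
# A hypothetical step-decay scheduler built on LRScheduler: a subclass only
# needs to implement get_lr(), returning one learning rate per param group;
# step() then writes those values back into the optimizer. The class name
# _StepDecay and its parameters are examples, not Jittor APIs.
class _StepDecay(LRScheduler):
    def __init__(self, optimizer, step_size=10, gamma=0.1, last_epoch=-1):
        self.step_size = step_size   # decay every `step_size` epochs
        self.gamma = gamma           # multiplicative decay factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]
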
class LambdaLR(LRScheduler):
    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]

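# --- Illustrative sketch (not part of the original source) ---
# LambdaLR scales each group's initial_lr by a user-supplied function of the
# epoch; here a 5%-per-epoch exponential decay on a toy SGD optimizer.
def _example_lambda_lr():
    w = jt.ones(2)
    opt = SGD([w], lr=0.1, momentum=0.9)
    sched = LambdaLR(opt, lr_lambda=lambda epoch: 0.95 ** epoch)
    for epoch in range(3):
        loss = (w * w).sum()
        opt.step(loss)
        sched.step()
        print(epoch, sched.get_last_lr())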