Source code for jittor.weightnorm

import jittor as jt
from jittor import nn

def _weight_norm(v, g, dim):
    '''
    Weight normalization. Weight normalization is a common reparameterization method for neural networks that improves generalization. The formula is :math:`v_{norm} = v * (g / || v ||_2)`

        Args:
            - v (Jittor.Var): the weight to be normalized.
            - g (Jittor.Var): the scaling factor used to adjust the magnitude of the normalized weight.
            - dim (int): the dimension along which normalization is performed.

        Returns:
            Jittor.Var: the normalized weight.

        Example:
            >>> import jittor as jt
            >>> from jittor import weightnorm
            >>> v = jt.array([1.0, 2.0, 3.0])
            >>> g = jt.array([1.0])
            >>> weightnorm._weight_norm(v, g, 0)
            jt.Var([0.26726124 0.5345225  0.8017837 ], dtype=float32)
    '''
    return v * (g / jt.norm(v, 2, dim, keepdim=True))
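
# Illustrative note (not part of the original module): with a 2-D weight and dim=-1,
# jt.norm(v, 2, -1, keepdim=True) has shape (rows, 1), so each row of v is divided by
# its own L2 norm and then scaled by the matching entry of g. The values below are
# made-up example data:
#
#     v = jt.array([[3.0, 4.0], [6.0, 8.0]])   # row norms: 5.0 and 10.0
#     g = jt.array([[1.0], [2.0]])              # per-row scale
#     _weight_norm(v, g, -1)                    # -> [[0.6, 0.8], [1.2, 1.6]]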

class WeightNorm(object):
    '''
    Weight normalization class. Weight normalization is a normalization technique that makes model training more stable. It decouples the magnitude and the direction of a weight parameter, which makes the network easier to optimize. This class can be used to add weight normalization to a Module or to remove it.

        Args:
            - name (str): name of the weight variable to be normalized;
            - dim (int): dimension along which the weight is normalized. -1 means the last dimension.

        Shape:
            This class applies weight normalization to one tensor of a network layer. Normalization does not change the shape of the tensor.

        Example:
            >>> import jittor as jt
            >>> from jittor import weightnorm
            >>> wn = weightnorm.WeightNorm("weight", -1)
            >>> linear_layer = jt.nn.Linear(3, 4)
            >>> wn.apply(linear_layer, "weight", -1)  # normalize the weight variable of linear_layer
            <jittor.weightnorm.WeightNorm object at 0x0000012FB9C3C400>
            >>> hasattr(linear_layer, 'weight_g')
            True
            >>> wn.remove(linear_layer)
            >>> hasattr(linear_layer, 'weight_g')
            False
    '''
    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: nn.Module):
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name: str, dim: int):
        if hasattr(module, '__fhook2__') and isinstance(module.__fhook2__, WeightNorm):
            raise RuntimeError("Cannot register two weight_norm hooks on "
                               "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        # todo: add check

        # remove w from parameter list
        # del module._parameters[name]
        delattr(module, name)

        # add g and v as new parameters and express w as g/||v|| * v
        module.__setattr__(name + '_g', jt.norm(weight, 2, dim, keepdim=True).detach())
        module.__setattr__(name + '_v', weight.detach())
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        # todo: support multiple hook in a module
        module.register_pre_forward_hook(fn)

        return fn

    def remove(self, module: nn.Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        delattr(module, self.name + '_g')
        delattr(module, self.name + '_v')
        setattr(module, self.name, weight.detach())

    def __call__(self, module: nn.Module, inputs) -> None:
        setattr(module, self.name, self.compute_weight(module))


def weight_norm(module, name, dim):
    '''
    Adds weight normalization to a module. WeightNorm applies an :math:`L_2`-norm normalization to each slice of the weight matrix along dimension ``dim``. With the weight matrix W and the normalized weight matrix W', the formula is :math:`W' = \\frac{W}{||W||_2}`

        Args:
            - module: the input model, a Module object
            - name (str): name of the parameter to be normalized
            - dim (int): dimension along which weight normalization is applied

        Returns:
            The processed model object.

        Example:
            >>> import jittor as jt
            >>> from jittor import weightnorm
            >>> class jt_module(jt.nn.Module):
            >>>     def __init__(self, weight):
            >>>         super().__init__()
            >>>         self.linear = jt.array(weight)
            >>>
            >>>     def execute(self, x):
            >>>         return jt.matmul(self.linear, x)
            >>>
            >>> weight = [[1.0, 2.0], [3.0, 4.0]]
            >>> jm = jt_module(weight)
            >>> weightnorm.weight_norm(jm, 'linear', -1)
    '''
    WeightNorm.apply(module, name, dim)
    return module


def remove_weight_norm(module, name: str = 'weight'):
    '''
    Removes weight normalization from a module. The function checks whether the module has the weight-normalization hook attribute ``__fhook2__`` to determine whether weight normalization has been applied to the module. If the attribute exists, it is deleted, which removes the weight normalization from the module.

        Args:
            - module (Module): the module from which to remove weight normalization
            - name (str, optional): name of the weight attribute. Default: ``'weight'``

        Returns:
            If the module has the weight-normalization attribute, returns the module with that attribute removed. If the attribute is not found on the module, a ValueError is raised.

        Example:
            >>> from jittor import weightnorm
            >>> from jittor import nn
            >>> model = nn.Linear(20, 40)
            >>> model = weightnorm.weight_norm(model, 'weight', -1)
            >>> hasattr(model, "__fhook2__")
            True
            >>> model = weightnorm.remove_weight_norm(model, 'weight')
            >>> hasattr(model, "__fhook2__")
            False
    '''
    if hasattr(module, "__fhook2__") and isinstance(module.__fhook2__, WeightNorm):
        delattr(module, "__fhook2__")
        return module
    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
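
# Illustrative usage sketch (an addition for documentation purposes, not part of the
# original module): apply weight_norm to a standard nn.Linear layer, run a forward pass
# (the registered pre-forward hook recomputes `weight` from `weight_g` and `weight_v`),
# then remove the hook again.
if __name__ == "__main__":
    layer = nn.Linear(3, 4)
    layer = weight_norm(layer, 'weight', -1)
    assert hasattr(layer, 'weight_g') and hasattr(layer, 'weight_v')

    x = jt.randn(2, 3)
    y = layer(x)        # the hook sets weight = g / ||v||_2 * v before execute()
    print(y.shape)      # expected: [2, 4]

    layer = remove_weight_norm(layer, 'weight')
    assert not hasattr(layer, '__fhook2__')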