jittor.weightnorm

这里是Jittor的weightnorm模块的API文档,您可以通过from jittor import weightnorm来获取该模块。

class jittor.weightnorm.WeightNorm(name: str, dim: int)[源代码]

权重标准化的类。权重标准化是一种标准化技术,可以使模型的训练过程更稳定。它通过将权重参数的幅值和方向进行解耦,使得网络更容易优化。它可以用来对某个Module添加或者移除权重标准化模块。

参数:
  • name (str): 要标准化的权重变量名称;

  • dim (int): 对权重进行标准化的维度。-1表示对最后一维进行权重标准化。

形状:

使用该类,可以对一个神经网络层的某个张量作权重归一化。归一化对张量的形状没有影响。

代码示例:
>>> import jittor as jt
>>> from jittor import weightnorm
>>> wn = weightnorm.WeightNorm("weight", -1)
>>> linear_layer = jt.nn.Linear(3,4)
>>> wn.apply(linear_layer, "weight", -1) # 对linear_layer的weight变量作归一化
<jittor.weightnorm.WeightNorm object at 0x0000012FB9C3C400>
>>> hasattr(linear_layer, 'weight_g')
True
>>> wn.remove(linear_layer)
>>> hasattr(linear_layer, 'weight_g')
False
compute_weight(module: Module)[源代码]
remove(module: Module) None[源代码]
jittor.weightnorm.remove_weight_norm(module, name: str = 'weight')[源代码]

移除模块的权重归一化。该函数通过检查模块是否有权重归一化相关的属性 __fhook2__ 来判断是否对该模块进行过权重归一化操作。如果存在,则删除该属性,用于消除对模块的权重归一化操作。

参数:
  • module (Module): 需要移除权重归一化的模块

  • name (str, 可选): 权重属性的名称. 默认值: ‘weight’

返回值:

若模块存在权重归一化属性,返回删除权重归一化属性后的模块。如果模块中没有找到指定的权重归一化属性,会抛出ValueError。

代码示例:
>>> from jittor import weightnorm
>>> from jittor import nn
>>> model = nn.Linear(20, 40)
>>> model = weightnorm.weight_norm(model, 'weight', -1)
>>> hasattr(model,"__fhook2__")
True
>>> model = weightnorm.remove_weight_norm(model, 'weight')
>>> hasattr(model,"__fhook2__")
False
jittor.weightnorm.weight_norm(module, name, dim)[源代码]

对模块增加一个权重归一化操作。通过WeightNorm在权重矩阵每一个切片(在维度dim上)进行 \(L_2\) 范数归一化,数学公式描述如下:设权重矩阵为W,经过归一化后的权重矩阵为W’,则有 \(W' = \frac{W}{||W||_2}\)

参数:
  • module: 输入的模型,类型为模型对象

  • name (Jittor.Var): 指定的参数的名称

  • dim (int): 进行权重归一化的维度

返回值:

处理后的模型对象。

代码示例:
>>> import jittor as jt
>>> from jittor import weightnorm
>>> class jt_module(jt.nn.Module):
>>>         def __init__(self, weight):
>>>             super().__init__()
>>>             self.linear = jt.array(weight)
>>> 
>>>         def execute(self, x):
>>>             return jt.matmul(self.linear, x)
>>> 
>>> jm = jt_module(weight)
>>> weightnorm.weight_norm(jm, 'linear', -1)