# jittor.weightnorm 源代码

import jittor as jt
from jittor import nn

def _weight_norm(v, g, dim):
    r'''
    Weight normalization: re-parameterize a weight as a direction ``v`` rescaled
    by a magnitude ``g``. Computes
    :math:`v_{norm} = v \cdot \frac{g}{\lVert v \rVert_2}`,
    where the L2 norm of ``v`` is taken along ``dim``.

    Args:
        v (jittor.Var): the weight (direction) to normalize.
        g (jittor.Var): scale factor applied to the normalized weight.
        dim (int): dimension along which the L2 norm of ``v`` is computed.

    Returns:
        jittor.Var: the normalized, rescaled weight.

    Example:
        >>> import jittor as jt
        >>> from jittor import weightnorm
        >>> v = jt.array([1.0, 2.0, 3.0])
        >>> g = jt.array([1.0])
        >>> weightnorm._weight_norm(v, g, 0)
        jt.Var([0.26726124 0.5345225  0.8017837 ], dtype=float32)
    '''
    # keepdim=True keeps the reduced axis so the norm broadcasts back against v.
    return v * (g / jt.norm(v, 2, dim, keepdim=True))

class WeightNorm(object):
    '''
    Weight-normalization hook. Weight normalization decouples a weight's
    magnitude (stored as ``<name>_g``) from its direction (stored as
    ``<name>_v``), which can make training more stable and the network easier
    to optimize. Use ``apply`` to add the re-parameterization to a Module and
    ``remove`` to take it off again.

    Args:
        name (str): name of the weight attribute to normalize.
        dim (int): dimension along which the L2 norm is computed; ``None`` is
            treated as ``-1`` (the last dimension).

    Shape:
        Normalization does not change the shape of the weight.

    Example:
        >>> import jittor as jt
        >>> from jittor import weightnorm
        >>> linear_layer = jt.nn.Linear(3, 4)
        >>> wn = weightnorm.WeightNorm.apply(linear_layer, "weight", -1)
        >>> hasattr(linear_layer, 'weight_g')
        True
        >>> wn.remove(linear_layer)
        >>> hasattr(linear_layer, 'weight_g')
        False
    '''

    def __init__(self, name: str, dim: int) -> None:
        # None means "last dimension", the same convention apply() uses.
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: nn.Module):
        '''Rebuild the weight from ``module.<name>_g`` and ``module.<name>_v``.'''
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name: str, dim: int):
        '''
        Re-parameterize ``module.<name>`` as ``<name>_g`` (magnitude) and
        ``<name>_v`` (direction), and register this hook so ``module.<name>``
        is recomputed before every forward().

        Args:
            module: module owning the weight; it is modified in place.
            name (str): name of the weight attribute.
            dim (int): normalization dimension; ``None`` means ``-1``.

        Returns:
            WeightNorm: the hook registered on ``module``.

        Raises:
            RuntimeError: if ``module`` already has a WeightNorm pre-forward
                hook (raised regardless of which parameter that hook covers,
                since only one such hook per module is supported).
        '''
        if hasattr(module, '__fhook2__') and isinstance(module.__fhook2__, WeightNorm):
            raise RuntimeError("Cannot register two weight_norm hooks on "
                               "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        # todo: add check
        # Drop the original weight; it is re-derived from g and v below.
        delattr(module, name)

        # g = ||w||_2 along dim and v = w, so g / ||v|| * v initially equals w.
        setattr(module, name + '_g', jt.norm(weight, 2, dim, keepdim=True).detach())
        setattr(module, name + '_v', weight.detach())
        setattr(module, name, fn.compute_weight(module))

        # Recompute the weight before every forward().
        # todo: support multiple hooks in a module (the check above only
        # looks at the single __fhook2__ slot).
        module.register_pre_forward_hook(fn)
        return fn

    def remove(self, module: nn.Module) -> None:
        '''
        Undo the re-parameterization on ``module``.

        The current normalized weight is stored back as a plain, detached
        ``module.<name>``, and ``<name>_g`` and ``<name>_v`` are deleted. A
        matching WeightNorm pre-forward hook is unregistered as well.
        '''
        weight = self.compute_weight(module)
        delattr(module, self.name)
        delattr(module, self.name + '_g')
        delattr(module, self.name + '_v')
        setattr(module, self.name, weight.detach())
        # A hook left registered would keep calling compute_weight() before
        # each forward and fail on the <name>_g attribute deleted above. Match
        # by name rather than identity, since callers may hold a different
        # WeightNorm instance than the one apply() registered.
        hook = getattr(module, '__fhook2__', None)
        if isinstance(hook, WeightNorm) and hook.name == self.name:
            delattr(module, '__fhook2__')

    def __call__(self, module: nn.Module, inputs, *hook_args) -> None:
        '''Pre-forward hook: recompute ``module.<name>`` from g and v.'''
        # *hook_args tolerates any extra arguments the hook caller passes
        # (e.g. keyword arguments of the forward call).
        # NOTE(review): confirm against Module.register_pre_forward_hook.
        setattr(module, self.name, self.compute_weight(module))
def weight_norm(module, name='weight', dim=-1):
    r'''
    Add weight normalization to ``module.<name>``.

    The weight is re-parameterized as a magnitude ``<name>_g`` and a direction
    ``<name>_v``, and is recomputed before every forward pass as
    :math:`W = g \cdot \frac{V}{\lVert V \rVert_2}`. The L2 norm is taken
    along ``dim``.

    Args:
        module (jittor.nn.Module): module that owns the weight; it is
            modified in place.
        name (str): name of the weight attribute. Default: ``'weight'``.
        dim (int): dimension along which the L2 norm is computed; ``-1`` (or
            ``None``) means the last dimension. Default: ``-1``.

    Returns:
        The same ``module`` object, with weight normalization applied.

    Raises:
        RuntimeError: if ``module`` already has a weight-norm hook.

    Example:
        >>> import jittor as jt
        >>> from jittor import weightnorm
        >>> class jt_module(jt.nn.Module):
        ...     def __init__(self, weight):
        ...         super().__init__()
        ...         self.linear = jt.array(weight)
        ...
        ...     def execute(self, x):
        ...         return jt.matmul(self.linear, x)
        ...
        >>> jm = jt_module([[1.0, 2.0], [3.0, 4.0]])
        >>> jm = weightnorm.weight_norm(jm, 'linear', -1)
        >>> hasattr(jm, 'linear_g')
        True
    '''
    WeightNorm.apply(module, name, dim)
    return module

def remove_weight_norm(module, name: str = 'weight'):
    '''
    Remove weight normalization from ``module``.

    Whether weight normalization was applied is detected through the
    ``__fhook2__`` attribute holding a WeightNorm hook. If one is found, its
    re-parameterization is undone: the current normalized weight is written
    back as a plain attribute, and the ``_g`` / ``_v`` attributes are deleted.
    The hook itself is then removed.

    Args:
        module (Module): module to remove weight normalization from.
        name (str, optional): name of the weight attribute; only used in the
            error message (the registered hook's own name decides which
            weight is restored). Default: ``'weight'``.

    Returns:
        The same ``module`` object.

    Raises:
        ValueError: if no weight-normalization hook is registered on ``module``.

    Example:
        >>> from jittor import weightnorm
        >>> from jittor import nn
        >>> model = nn.Linear(20, 40)
        >>> model = weightnorm.weight_norm(model, 'weight', -1)
        >>> hasattr(model, "__fhook2__")
        True
        >>> model = weightnorm.remove_weight_norm(model, 'weight')
        >>> hasattr(model, "__fhook2__")
        False
        >>> hasattr(model, "weight_g")
        False
    '''
    hook = getattr(module, "__fhook2__", None)
    if not isinstance(hook, WeightNorm):
        raise ValueError("weight_norm of '{}' not found in {}"
                         .format(name, module))
    # Restore a plain weight and drop <name>_g / <name>_v. Deleting only the
    # hook would leave module.<name> frozen at its last computed value, plus
    # g / v attributes that no longer drive it.
    hook.remove(module)
    # hook.remove() may already have unregistered the hook.
    if hasattr(module, "__fhook2__"):
        delattr(module, "__fhook2__")
    return module