# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import NanoVector, Var
import numpy as np
import math
import warnings
[文档]
def eye(shape, dtype="float32"):
"""
创建一个形状为 ``shape`` 的单位矩阵。
单位矩阵的定义为:
对于任何 :math:`i, j (0 \le i, j < dim)` ,当 :math:`i` 等于 :math:`j` 时,输出为 ``1`` ;否则为 ``0`` 。
.. math::
eye(i, j) =
\\begin{cases}
1, & \\text{if } i = j \\\\
0, & \\text{else}
\\end{cases}
参数:
- shape (``int, Tuple[int]``): 规定返回矩阵的形状,如果 ``shape`` 为整数, 那么返回的矩阵的形状为 ``(shape, shape)``。
- dtype (``dtype``, 可选): 数据类型,默认为 ``float32``
代码示例:
>>> init.eye(2)
jt.Var([[1. 0.]
[0. 1.]], dtype=float32)
>>> init.eye((2, 3), dtype='int32')
jt.Var([[1 0 0]
[0 1 0]], dtype=int32)
返回值:
返回一个单位矩阵
注意事项:
- ``shape`` 如果是整数序列,只能传入两个维度的长度,否则会引起异常
"""
if isinstance(shape, int):
shape = (shape,shape)
assert len(shape)==2, f"len of shape should be 2, but got {shape}"
index = jt.index(shape)
return (index[0]==index[1]).unary(dtype)
[文档]
def eye_(var):
"""
原地修改矩阵 ``var``,将其成为单位矩阵。
单位矩阵的定义为:
对于任何 :math:`i, j (0 \le i, j < dim)` ,当 :math:`i` 等于 :math:`j` 时,输出为 ``1`` ;否则为 ``0`` 。
.. math::
eye(i, j) =
\\begin{cases}
1, & \\text{if } i = j \\\\
0, & \\text{else}
\\end{cases}
参数:
- var (``Var``): ``Var``类型的张量
代码示例:
>>> x = jt.randn((3,2))
>>> x
jt.Var([[-0.582047 1.1068922 ]
[ 0.08440255 -0.86142904]
[ 0.06269896 -0.979008 ]], dtype=float32)
>>> init.eye_(x)
>>> x
jt.Var([[1. 0.]
[0. 1.]
[0. 0.]], dtype=float32)
返回值:
返回一个形状和 ``var`` 相同的单位矩阵
"""
return var.assign(eye(var.shape, var.dtype))
Var.eye_ = eye_
[文档]
def constant(shape, dtype="float32", value=0.0):
"""
创建一个数值均为 ``value``, 形状由可变参数 ``shape`` 确定, 默认填充 ``float32`` 类型的零值。
参数:
- shape (``Tuple[int]``): 整数序列,定义了输出的形状
- dtype (``var.dtype``, optional): 数据类型,默认为 ``float32``
- value (``int`` or ``float``): 填充的数值
代码示例:
>>> init.constant(2)
jt.Var([0. 0.], dtype=float32)
>>> init.constant((2,3), \"float32\", 3.8)
>>> x
jt.Var([[3.8 3.8 3.8]
[3.8 3.8 3.8]], dtype=float32)
返回值:
返回一个填充值为 ``value`` ,形状大小为 ``shape`` 的张量(Var)
"""
return jt.array(value).unary(dtype).broadcast(NanoVector(shape))
[文档]
def constant_(var, value=0.0):
"""
原地修改张量 ``var`` ,将其修改为数值均为 ``value`` 的张量,默认填充零值。
参数:
- var (``Var``): Var类型的张量
- value (``int`` or ``float``): 填充的数值
代码示例:
>>> x = jt.randn((3,2))
>>> x
jt.Var([[ 0.0335454 -0.18658346]
[-0.53283346 -1.2073938 ]
[-1.344916 -1.6093305 ]], dtype=float32)
>>> init.constant_(x, 3.8)
>>> x
>jt.Var([[3.8 3.8]
[3.8 3.8]
[3.8 3.8]], dtype=float32)
返回值:
就地修改 ``var`` ,返回一个填充值为 ``value`` ,形状和 ``var`` 相同的张量(Var)
"""
return var.assign(constant(var.shape, var.dtype, value))
Var.constant_ = constant_
fill = Var.fill_ = constant_
[文档]
def zero(shape, dtype="float32"):
"""
返回一个全为 0 的张量(``Var``),形状由可变参数 ``shape`` 定义,
数据类型由 ``dtype`` 定义。如果不给定 ``dtype`` , 默认类型为 ``float32``。
参数:
- shape (``Tuple[int]``): 整数序列,定义了输出的形状
- dtype (``var.dtype``,可选): 数据类型,默认为 ``float32``
代码示例:
>>> init.zero((3,2), dtype='int32')
jt.Var([[0 0]
[0 0]
[0 0]], dtype=int32)
返回值:
全为 0 的张量(``Var``)
注意事项:
- 该函数和 ``jt.zeros()`` 用处一致
"""
return constant(shape, dtype, 0)
[文档]
def zero_(var):
"""
原地修改张量 ``var`` ,将其修改为数值均为0的张量。
参数:
- var (``Var``): 输入的 ``Var``
代码示例:
>>> x = jt.randn((3,2))
>>> x
jt.Var([[-1.1730903 0.21458259]
[ 1.0399616 0.07660236]
[-1.8453276 -0.95629567]], dtype=float32)
>>> init.zero_(x)
>>> x
jt.Var([[0. 0.]
[0. 0.]
[0. 0.]], dtype=float32)
返回值:
全为 0 的张量(``Var``)
"""
return var.assign(zero(var.shape, var.dtype))
Var.zero_ = zero_
[文档]
def random_(var):
'''
该函数将输入变量(``var``)重新赋值为随机值,生成的随机数范围在 0 到 1 之间。
参数:
- var(``Var``): 需要被重新赋值的变量
返回值:
``Var``: 重新赋值后的变量
代码示例:
>>> import jittor as jt
>>> x = jt.init.one(5)
>>> jt.init.random_(x)
>>> print(x)
jt.Var([0.9079071 0.1955278 0.2359613 0.8015607 0.83047885], dtype=float32)
'''
return var.assign(jt.rand(var.shape, var.dtype))
Var.random_ = random_
[文档]
def one(shape, dtype="float32"):
"""
返回一个全为 1 的张量(``Var``),形状由可变参数 ``shape`` 定义, 数据类型由可变参数 ``dtype`` 定义,默认类型 ``float32``。
参数:
- shape (``Tuple[int]``): 整型序列,定义了输出的形状
- dtype (``var.dtype``, 可选): 数据类型,默认为 ``float32``
代码示例:
>>> init.one((3, 4))
jt.Var([[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]], dtype=float32)
>>> init.one(5)
jt.Var([1. 1. 1. 1. 1.], dtype=float32)
返回值:
全为 1 的张量(``Var``)
"""
return constant(shape, dtype, 1)
[文档]
def one_(var):
"""
原地修改张量 ``var`` ,将其数值全部填充为1,数据类型保持不变。
参数:
- var (``Var``): Var类型的张量
代码示例:
>>> x = jt.randn((3, 2))
>>> x
jt.Var([[ 0.8584159 -1.1204817 ]
[ 0.5418147 -0.62170196]
[-0.91137475 -0.13982968]], dtype=float32)
>>> init.one_(x)
>>> x
jt.Var([[1. 1.]
[1. 1.]
[1. 1.]], dtype=float32)
返回值:
就地修改张量 ``var`` ,返回一个数值全为1的张量
"""
return var.assign(one(var.shape, var.dtype))
Var.one_ = one_
Var.uniform_ = uniform_
[文档]
def gauss(shape, dtype="float32", mean=0.0, std=1.0):
"""
创建一个 Var,其形状由 ``shape`` 指定,数据类型为 ``dtype`` ,
元素值遵循均值为 ``mean`` 和标准差为 ``std`` 的高斯分布(正态分布)的张量。
.. math::
out_i \\sim \\mathcal{N}(\\text{mean}, \\text{std}^2)
参数:
- shape (``Tuple[int]``): 整型序列,定义了输出的形状
- dtype (``dtype``): 数据类型,默认为 ``float32``
- mean (``float`` or ``Var``, 可选): 高斯分布的均值, 默认值为 ``0.0``
- std (``float`` or ``Var``, 可选): 高斯分布的标准差,默认值为 ``1.0``
代码示例:
>>> init.gauss((3,2))
jt.Var([[-1.09277 -0.22924843]
[-0.5264394 -0.13242662]
[-1.1316705 1.2506602 ]], dtype=float32)
>>> init.gauss((2,2), 'float32', 0, 10)
jt.Var([[ 3.457805 5.8171034]
[-1.6440934 2.1744032]], dtype=float32)
返回值:
返回一个数值符合高斯分布的 Var
注意事项:
- 确保指定的 ``std`` 是非负的,因为标准差不能是负数
"""
return jt.random(NanoVector(shape), dtype, "normal") * std + mean
[文档]
def gauss_(var, mean=0.0, std=1.0):
"""
原地修改张量 ``var`` ,将其数值填充满足以均值为 ``mean`` 和标准差为 ``std`` 的高斯分布(正态分布)的随机数。
.. math::
out_i \\sim \\mathcal{N}(\\text{mean}, \\text{std}^2)
参数:
- var (``Var``): 改变的 var
- mean (``float`` or ``Var``, 可选): 高斯分布的均值, 默认值为 ``0.0``
- std (``float`` or ``Var``, 可选): 高斯分布的标准差,默认值为 ``1.0``
代码示例:
>>> x = init.zero((3,3))
>>> x
jt.Var([[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]], dtype=float32)
>>> init.gauss_(x, 0, 5.0)
>>> x
jt.Var([[-0.34221977 -8.056475 2.6251674 ]
[-0.81275284 2.00393 -2.4397573 ]
[ 0.7867035 -8.159389 7.100675 ]], dtype=float32)
返回值:
就地修改张量 ``var`` ,返回一个数值符合高斯分布的张量
"""
return var.assign(gauss(var.shape, var.dtype, mean, std))
Var.gauss_ = gauss_
Var.normal_ = gauss_
Var.invariant_uniform_ = invariant_uniform_
[文档]
def relu_invariant_gauss(shape, dtype="float32", mode="fan_in"):
'''
返回由 relu_invariant_gauss 初始化的 Var。
参数:
- shape (``int`` or ``Tuple[int]``): 输出Var的形状
- dtype (``str``): 输出Var的 ``dtype`` ,默认 ``float32``
- mode (``str``): 模式选择,应为 ``fan_in`` 或 ``fan_out``。选择 ``'fan_in'`` 保留正向传递中权重方差的大小。选择 ``'fan_out'`` 保留反向传递中的大小。
代码示例:
>>> from jittor import init
>>> from jittor import nn
>>> a = init.relu_invariant_gauss((2,2))
>>> print(a)
jt.Var([[ 0.30814755 -0.1328245 ]
[-0.10410424 -0.01558159]], dtype=float32)
返回值:
由 relu_invariant_gauss 初始化的 Jittor Var
'''
assert len(shape)>1
assert mode=="fan_in" or mode=="fan_out"
matsize=1
for i in shape[2:]:
matsize *= i
fan = (shape[1] * matsize) if mode=="fan_in" else (shape[0] * matsize)
std = math.sqrt(2.0/fan)
return gauss(shape, dtype, 0, std)
[文档]
def relu_invariant_gauss_(var, mode="fan_in"):
'''
用随机的 relu 不变高斯初始化 Var。
参数:
- var (``Var``): 要用随机 relu 不变高斯初始化的Var
- mode (``str``): 模式选择,应为 ``fan_in`` 或 ``fan_out``。选择 ``'fan_in'`` 保留正向传递中权重方差的大小。选择 ``'fan_out'`` 保留反向传递中的大小。
代码示例:
>>> from jittor import init
>>> from jittor import nn
>>> linear = nn.Linear(2,2)
>>> init.relu_invariant_gauss_(linear.weight)
>>> print(linear.weight)
jt.Var([[ 0.74033755 -0.74033755]
[-0.74033755 0.74033755]], dtype=float32)
>>> linear.weight.relu_invariant_gauss_() # 这样也可以
返回值:
由随机 relu 不变高斯初始化的Var
'''
return var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))
Var.relu_invariant_gauss_ = relu_invariant_gauss_
[文档]
def calculate_std(var, mode, nonlinearity, param=0.01):
'''
计算标准差。
参数:
- var (``Var``): 输入Var
- mode (``str``): 模式,可选值为 ``fan_in`` 和 ``fan_out``。默认值:``fan_in``
- nonlinearity (``str``): 非线性函数,可选值为 ``linear``、``conv1d``、``conv2d``、``conv3d``、``conv_transpose1d``、``conv_transpose2d``、``conv_transpose3d``、``sigmoid``、``tanh``、``relu`` 和 ``leaky_relu``。默认值:``linear``
- param (``float``): 非线性函数的参数。默认值:``0.01``
代码示例:
>>> x = jt.random((2, 2))
jt.Var([[ 2.520024 -0.4921519 ]
[-1.1624513 -0.62531066]], dtype=float32)
>>> jt.calculate_std(x)
0.7071067811865476
返回值:
标准差(``float``)
'''
mode = mode.lower()
assert isinstance(param,(int,float))
assert var.ndim>=2
assert mode in ['fan_in', 'fan_out']
fan = var.shape[1] if mode == 'fan_in' else var.shape[0]
fan *= var[0][0].numel()
gains = {
'linear':1,
'conv1d':1,
'conv2d':1,
'conv3d':1,
'conv_transpose1d':1,
'conv_transpose2d':1,
'conv_transpose3d':1,
'sigmoid':1,
'tanh':5.0/3,
'relu':math.sqrt(2.0),
'leaky_relu':math.sqrt(2.0 / (1 + param ** 2)),
}
gain = gains[nonlinearity]
std = gain/math.sqrt(fan)
return std
Var.kaiming_uniform_ = kaiming_uniform_
[文档]
def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
'''
将 ``Var`` 通过 kaiming normal 随机初始化。
参数:
- var (``Var``):需要被 kaiming normal 初始化的 Var
- a (``float``):此层后使用的整流器的负斜率(仅在非线性函数是 ``leaky_relu`` 时使用)。
- mode (``str``): 模式选择,应为 ``fan_in`` 或 ``fan_out``。选择 ``'fan_in'`` 保留正向传递中权重方差的大小。选择 ``'fan_out'`` 保留反向传递中的大小。
- nonlinearity (``str``):此层后使用的非线性函数。默认值:``leaky_ relu`` 。
代码示例:
>>> from jittor import init
>>> from jittor import nn
>>> linear = nn.Linear(2,2)
>>> init.kaiming_normal_(linear.weight)
>>> linear.weight
返回值:
由 kaiming uniform 随机初始化的 Var
'''
std = calculate_std(var,mode,nonlinearity,a)
return gauss_(var,0, std)
Var.kaiming_normal_ = kaiming_normal_
Var.xavier_uniform_ = xavier_uniform_
[文档]
def xavier_gauss(shape, dtype="float32", gain=1.0):
'''
返回由 xavier_gauss 初始化的 ``Var``。结果 ``Var`` 的值将从 :math:`\mathcal N(-a, a)` 中采样,其中
.. math::
\\text{std} = \\text{gain} \\times \\sqrt{\\frac{2}{\\text{fan_in} + \\text{fan_out}}}
参数:
- shape (``int`` or ``Tuple[int]``): 输出Var的形状
- dtype (``str``): 输出 ``Var`` 的 ``dtype`` ,默认 ``float32``
- gain (``float``): 可选的缩放因子。
代码示例:
>>> from jittor import init
>>> from jittor import nn
>>> linear = nn.Linear(2,2)
>>> init.xavier_gauss_(linear.weight, init.calculate_gain('relu'))
>>> print(linear.weight)
jt.Var([[ 0.27429324 -0.15574329]
[-0.15574329 -0.27429324]], dtype=float32)
>>> linear.weight.xavier_gauss_() # This is ok too
jt.Var([[ 0.27429324 -0.15574329]
[-0.15574329 -0.27429324]], dtype=float32)
返回值:
由 xavier_gauss 初始化的 ``Var``
'''
assert len(shape)>1
matsize=1
for i in shape[2:]:
matsize *= i
fan = (shape[1] * matsize) + (shape[0] * matsize)
std = gain * math.sqrt(2.0/fan)
return gauss(shape, dtype, 0, std)
[文档]
def xavier_gauss_(var, gain=1.0):
'''
返回由 xavier_gauss 初始化的 ``Var``。结果 ``Var`` 的值将从 :math:`\mathcal N(-a, a)` 中采样,其中
.. math::
\\text{std} = \\text{gain} \\times \\sqrt{\\frac{2}{\\text{fan_in} + \\text{fan_out}}}
参数:
- var (``Var``): 通过 xavier_guass 随机初始化的变量
- gain (``float``): 可选的缩放因子。
代码示例:
>>> from jittor import init
>>> from jittor import nn
>>> linear = nn.Linear(2,2)
>>> init.xavier_gauss_(linear.weight, init.calculate_gain('relu'))
>>> print(linear.weight)
jt.Var([[ 0.27429324 -0.15574329]
[-0.15574329 -0.27429324]], dtype=float32)
>>> linear.weight.xavier_gauss_() # This is ok too
jt.Var([[ 0.27429324 -0.15574329]
[-0.15574329 -0.27429324]], dtype=float32)
返回值:
由 xavier_gauss 初始化的 Var
'''
return var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
Var.xavier_gauss_ = xavier_gauss_
[文档]
def calculate_gain(nonlinearity, param=None):
r"""返回给定非线性函数的推荐增益值,值如下:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\displaystyle \frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\displaystyle \sqrt{\frac{2}{1 + \text{negative_slope}^2}}`
SELU :math:`\displaystyle \frac{3}{4}`
================= ====================================================
参数:
nonlinearity: 非线性函数(`nn.functional` 名称)
param: 非线性函数的可选参数
代码示例:
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
elif nonlinearity == 'tanh':
return 5.0 / 3
elif nonlinearity == 'relu':
return math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
elif nonlinearity == 'selu':
return 3.0 / 4
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
[文档]
def trunc_normal_(var, mean=0., std=1., a=-2., b=2.):
# type: (jt.jittor_core.Var, float, float, float, float) -> jt.jittor_core.Var
"""
将输入的 Var 用一个截断正态分布中的值填充。
具体而言,`trunc_normal` 函数采样的方法如下:
首先,根据指定的均值 ``mean`` 和标准差 ``std`` ,从正态分布中采样一个随机值 :math:`x` 。
然后,对采样的随机值 :math:`x` 进行截断操作,将其限制在指定的范围 ``[a, b]`` 内。具体而言,如果 :math:`x` 小于下界 ``a`` ,则将它替换为 ``a`` ;如果x大于上界 ``b``,则将它替换为 ``b``。最后,将截断后的值作为初始化值赋给张量 ``var`` 的对应元素。
参数:
- var (``Var``): 一个 ``n`` 维的变量
- mean (``float``,可选): 正态分布的均值,默认值:``0.0``
- std (``float``,可选): 正态分布的标准差,默认值:``1.0``
- a (``float``,可选): 最小截断值,默认值:``-2.0``
- b (``float``,可选): 最大截断值,默认值: ``2.0``
代码示例:
>>> from jittor import init
>>> from jittor import nn
>>> linear = nn.Linear(2,2)
>>> init.trunc_normal_(linear.weight, std=.02)
>>> linear.weight
返回值:
使用截断正态分布填充的输入变量(``Var``)
"""
return var.assign(_no_grad_trunc_normal_(var, mean, std, a, b))
Var.trunc_normal_ = trunc_normal_
def _no_grad_trunc_normal_(var, mean, std, a, b):
"""
使用截断正态分布初始化给定变量的值。与 :math:`trunc\\_normal\\_` 函数不同的是,该函数在执行初始化时不会进行梯度计算,即不会自动记录梯度。
该函数使用均值为 ``mean`` , 标准差为 ``std`` , 上下限为 ``a`` , ``b`` 的截断正态分布初始化给定的变量 ``var`` .当均值 ``mean`` 超出 ``[a-2*std, b+2*std]`` 的范围时,会弹出警告提示。
在数学上,截断正态分布的累积分布函数为 :math:`F(x)=\\frac{1+\\text{erf}(x / \\sqrt{2})}{2}`。
参数:
- var (jt.Var): 要初始化的变量。
- mean (float): 截断正态分布的均值。
- std (float): 截断正态分布的标准差。
- a (float): 截断正态分布的下界。
- b (float): 截断正态分布的上界。
返回值:
初始化后的变量(Var)。
代码示例:
>>> import jittor as jt
>>> var = jt.zeros([3, 3])
>>> var = jt._no_grad_trunc_normal_(var, mean=0., std=1., a=-2., b=2.)
"""
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
# var.uniform(2 * l - 1, 2 * u - 1)
var.uniform_(low=2 * l - 1, high=2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
var = var.erfinv()
# Transform to proper mean, std
var = var.multiply(std * math.sqrt(2.))
var = var.add(mean)
# Clamp to ensure it's in the proper range
var = var.clamp(min_v=a, max_v=b)
return var