# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from functools import partial
#TODO:full_matrices=1
[文档]
def svd(x):
    """
    Compute the reduced singular value decomposition of a matrix.

    The forward pass calls ``np.linalg.svd(x, full_matrices=0)``; the backward
    pass provides the gradient with respect to ``x`` for each of the three
    outputs. The factors satisfy ``x = u @ diag(s) @ v``.

    Args:
        x (jittor.Var): input of shape ``(..., M, N)``.

    Returns:
        tuple(jittor.Var, jittor.Var, jittor.Var): ``(u, s, v)`` with shapes
        ``(..., M, K)``, ``(..., K)`` and ``(..., K, N)`` where
        ``K = min(M, N)``. ``v`` is the third array returned by
        ``np.linalg.svd`` (the already-transposed right factor).

    Example:
        >>> import jittor as jt
        >>> from jittor.linalg import svd
        >>> X = jt.array([[1., 0., 0., 0., 2.],
        ...               [0., 0., 3., 0., 0.],
        ...               [0., 0., 0., 0., 0.],
        ...               [0., 2., 0., 0., 0.]])
        >>> u, s, v = svd(X)
        >>> # u has shape (4, 4), s has shape (4,), v has shape (4, 5)
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        u, s, v = data["outputs"]
        # TODO: remove copyto
        tu, ts, tv = np.linalg.svd(a, full_matrices=0)
        np.copyto(u, tu)
        np.copyto(s, ts)
        np.copyto(v, tv)

    def backward_code(np, data):
        def T(mat):
            # Transpose of the trailing two (matrix) axes.
            return np.swapaxes(mat, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        u, s, v = data["f_outputs"]
        # The forward pass stores V^T; the formulas below are written for V.
        v = T(v)
        m, n = inp.shape[-2:]
        k = np.min((m, n))

        def batched_eye(size):
            # Identity reshaped to (1, ..., 1, size, size) so that it
            # broadcasts against the leading batch axes of the input.
            lead = np.ones(inp.ndim - 2, dtype=int)
            return np.reshape(np.eye(size), np.concatenate((lead, (size, size))))

        i = batched_eye(k)
        s_row = s[..., np.newaxis, :]
        s_col = s[..., :, np.newaxis]

        if out_index == 0:
            # Adding the identity keeps the diagonal (s_i^2 - s_i^2) from
            # producing a division by zero.
            f = 1 / (s_row ** 2 - s_col ** 2 + i)
            gu = dout
            utgu = _dot(T(u), gu)
            t = (f * (utgu - T(utgu))) * s_row
            t = _dot(_dot(u, t), T(v))
            if m > n:
                i_minus_uut = batched_eye(m) - _dot(u, np.conj(T(u)))
                t = t + T(_dot(_dot(v / s_row, T(gu)), i_minus_uut))
            np.copyto(out, t)
        elif out_index == 1:
            gs = dout
            t = i * gs[..., :, np.newaxis]
            t = _dot(_dot(u, t), T(v))
            np.copyto(out, t)
        elif out_index == 2:
            f = 1 / (s_row ** 2 - s_col ** 2 + i)
            # dout is the gradient w.r.t. the stored V^T (shape (..., K, N));
            # the formula needs the gradient w.r.t. V (shape (..., N, K)).
            gv = T(dout)
            vtgv = _dot(T(v), gv)
            t = s_col * (f * (vtgv - T(vtgv)))
            t = _dot(_dot(u, t), T(v))
            if m < n:
                i_minus_vvt = batched_eye(n) - _dot(v, np.conj(T(v)))
                # (M, K) @ (K, N) @ (N, N) already has the input's shape, so
                # no extra transpose is applied here.
                t = t + _dot(_dot(u / s_row, T(gv)), i_minus_vvt)
            np.copyto(out, t)

    m, n = x.shape[-2:]
    k = min(m, n)
    # Output shapes: u -> (..., M, K), s -> (..., K), v -> (..., K, N).
    s1 = list(x.shape)
    s1[-1] = k
    s2 = list(x.shape)
    s2[-2] = k
    s3 = list(x.shape)[:-2]
    s3.append(k)
    u, s, v = jt.numpy_code(
        [s1, s3, s2],
        [x.dtype, x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return u, s, v
[文档]
def eigh(x):
    """
    Compute eigenvalues and eigenvectors with ``np.linalg.eigh``.

    The results satisfy ``A v = lambda v`` where ``A`` is the input matrix,
    ``v`` an eigenvector and ``lambda`` the matching eigenvalue. The forward
    pass calls ``np.linalg.eigh(x, UPLO='L')``.

    Args:
        x (jittor.Var): input of shape ``(..., M, M)``.

    Returns:
        tuple(jittor.Var, jittor.Var): ``(w, v)`` where ``w`` holds the
        eigenvalues with shape ``(..., M)`` and ``v`` the eigenvectors with
        shape ``(..., M, M)``.

    Example:
        >>> import jittor as jt
        >>> x = jt.array([[2.0, 1.0], [1.0, 2.0]])
        >>> w, v = jt.linalg.eigh(x)
        >>> # w is approximately [1., 3.]
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        w, v = data["outputs"]
        tw, tv = np.linalg.eigh(a, UPLO='L')
        np.copyto(w, tw)
        np.copyto(v, tv)

    def backward_code(np, data):
        def T(mat):
            # Transpose of the trailing two (matrix) axes.
            return np.swapaxes(mat, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        w, v = data["f_outputs"]
        k = int(inp.shape[-1])
        if out_index == 0:
            # Gradient flowing through the eigenvalues.
            t = _dot(v * dout[..., np.newaxis, :], T(v))
            np.copyto(out, t)
        elif out_index == 1:
            # Gradient flowing through the eigenvectors.
            if np.any(dout):
                w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1)
                off_diag = np.ones((k, k)) - np.eye(k)
                # The identity on the diagonal avoids 0/0; off_diag then
                # zeroes those diagonal entries.
                F = off_diag / (T(w_repeated) - w_repeated + np.eye(k))
                t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
                np.copyto(out, t)
            else:
                # A zero upstream gradient yields a zero gradient; write it
                # explicitly so the output buffer is never left unset.
                np.copyto(out, 0)

    sw = x.shape[:-2] + x.shape[-1:]
    sv = x.shape
    w, v = jt.numpy_code(
        [sw, sv],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return w, v
[文档]
def inv(x):
    """
    Compute the inverse of the input matrix.

    The forward pass calls ``np.linalg.inv``; the backward pass computes
    ``-inv(x)^T @ dout @ inv(x)^T``.

    Args:
        x (jittor.Var): input matrix; the output has the same shape and dtype.

    Returns:
        jittor.Var: the inverse of ``x``.

    Example:
        >>> import jittor as jt
        >>> x = jt.array([[1.0, 2.0], [3.0, 4.0]])
        >>> y = jt.linalg.inv(x)
        >>> # y is approximately [[-2., 1.], [1.5, -0.5]]
    """
    def forward_code(np, data):
        source = data["inputs"][0]
        target = data["outputs"][0]
        np.copyto(target, np.linalg.inv(source))

    def backward_code(np, data):
        matmul = partial(np.einsum, '...ij,...jk->...ik')
        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        # Transposed forward result (the inverse), reused on both sides.
        inv_t = np.swapaxes(data["f_outputs"][0], -1, -2)
        np.copyto(grad_in, -matmul(matmul(inv_t, grad_out), inv_t))

    results = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return results[0]
[文档]
def pinv(x):
    """
    Compute the pseudo-inverse of the input matrix, with gradient support.

    The forward pass calls ``np.linalg.pinv``. The last two axes are treated
    as the matrix; any leading axes are kept as batch axes.

    Args:
        x (jittor.Var): input of shape ``(..., M, N)``.

    Returns:
        jittor.Var: pseudo-inverse of ``x`` with shape ``(..., N, M)``.

    Example:
        >>> import jittor as jt
        >>> x = jt.array([[1.0, 2.0], [3.0, 4.0]])
        >>> y = jt.linalg.pinv(x)
        >>> # y is approximately [[-2., 1.], [1.5, -0.5]]
    """
    def forward_code(np, data):
        source = data["inputs"][0]
        target = data["outputs"][0]
        np.copyto(target, np.linalg.pinv(source))

    def backward_code(np, data):
        matmul = partial(np.einsum, '...ij,...jk->...ik')

        def transpose(mat):
            return np.swapaxes(mat, -1, -2)

        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        source = data["inputs"][0]
        pinv_x = data["f_outputs"][0]

        # The gradient is the transpose of the sum of three terms.
        term_a = -matmul(matmul(pinv_x, transpose(grad_out)), pinv_x)
        term_b = matmul(
            matmul(matmul(pinv_x, transpose(pinv_x)), grad_out),
            np.eye(source.shape[-2]) - matmul(source, pinv_x),
        )
        term_c = matmul(
            matmul(
                matmul(np.eye(pinv_x.shape[-2]) - matmul(pinv_x, source), grad_out),
                transpose(pinv_x),
            ),
            pinv_x,
        )
        np.copyto(grad_in, transpose(term_a + term_b + term_c))

    # Output shape is the input shape with the last two axes swapped.
    out_shape = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
    results = jt.numpy_code(
        [out_shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return results[0]
[文档]
def det(x):
    """
    Compute the determinant of the input matrix.

    The forward pass calls ``np.linalg.det``; the backward pass computes
    ``dout * det(x) * inv(x)^T``.

    Args:
        x (jittor.Var): input of shape ``(..., M, M)``; the last two axes are
            treated as the matrix.

    Returns:
        jittor.Var: the determinant. Its shape is ``x.shape[:-2]``, except
        for a 2-D input where the result has shape ``(1,)``.

    Example:
        >>> import jittor as jt
        >>> x = jt.array([[1.0, 2.0], [3.0, 4.0]])
        >>> print(jt.linalg.det(x))
        jt.Var([-2.], dtype=float32)
    """
    def forward_code(np, data):
        source = data["inputs"][0]
        target = data["outputs"][0]
        np.copyto(target, np.linalg.det(source))

    def backward_code(np, data):
        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        det_value = data["f_outputs"][0]
        source = data["inputs"][0]
        # Append two singleton axes so the per-matrix scalars broadcast
        # against the (M, M) matrices.
        grad_b = np.reshape(grad_out, np.shape(grad_out) + (1, 1))
        det_b = np.reshape(det_value, np.shape(det_value) + (1, 1))
        inv_t = np.swapaxes(np.linalg.inv(source), -1, -2)
        np.copyto(grad_in, grad_b * det_b * inv_t)

    full_shape = x.shape
    out_shape = full_shape[:-2]
    if len(full_shape) == 2:
        # A single matrix would give an empty shape; use one element instead.
        out_shape.append(1)
    results = jt.numpy_code(
        [out_shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return results[0]
[文档]
def slogdet(x):
    """
    Compute the sign and the natural logarithm of the determinant.

    The forward pass calls ``np.linalg.slogdet``, whose results satisfy
    ``det(x) = sign * exp(logdet)``. In the backward pass the sign output
    contributes a zero gradient and the log-determinant output contributes
    ``dout * inv(x)^T``.

    Args:
        x (jittor.Var): input of shape ``(..., M, M)``.

    Returns:
        tuple(jittor.Var, jittor.Var): ``(sign, logdet)``. Both have shape
        ``x.shape[:-2]``, or ``(1,)`` when ``x`` is a single 2-D matrix.

    Example:
        >>> import jittor as jt
        >>> x = jt.array([[1.0, 2.0], [3.0, 4.0]])
        >>> sign, logdet = jt.linalg.slogdet(x)
        >>> print(f"Sign: {sign.data}, LogDet: {logdet.data}")
        Sign: [-1.], LogDet: [0.6931472]
    """
    def forward_code(np, data):
        source = data["inputs"][0]
        sign_buf, logdet_buf = data["outputs"]
        sign_value, logdet_value = np.linalg.slogdet(source)
        np.copyto(logdet_buf, logdet_value)
        np.copyto(sign_buf, sign_value)

    def backward_code(np, data):
        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        source = data["inputs"][0]
        which = data["out_index"]
        if which == 0:
            # The sign output contributes no gradient.
            np.copyto(grad_in, 0)
        elif which == 1:
            grad_b = np.reshape(grad_out, np.shape(grad_out) + (1, 1))
            inv_t = np.swapaxes(np.linalg.inv(source), -1, -2)
            np.copyto(grad_in, grad_b * inv_t)

    out_shape = x.shape[:-2]
    if len(out_shape) == 0:
        # A single matrix would give an empty shape; use one element instead.
        out_shape.append(1)
    sign, logdet = jt.numpy_code(
        [out_shape, out_shape],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return sign, logdet
[文档]
def cholesky(x):
    """
    Compute the Cholesky factor ``L`` of the input, where ``x = L L^T``.

    The forward pass calls ``np.linalg.cholesky``, which expects ``x`` to be
    Hermitian (equal to its conjugate transpose) and positive definite and
    returns a lower-triangular ``L``.

    Args:
        x (jittor.Var): input of shape ``(..., M, M)``.

    Returns:
        jittor.Var: the lower-triangular factor ``L`` with the same shape and
        dtype as ``x``.

    Example:
        >>> import jittor as jt
        >>> a = jt.array([[4.0, 12, -16], [12, 37, -43], [-16, -43, 98]])
        >>> print(jt.linalg.cholesky(a))
        jt.Var([[ 2.  0.  0.]
                [ 6.  1.  0.]
                [-8.  5.  3.]], dtype=float32)
    """
    def forward_code(np, data):
        source = data["inputs"][0]
        target = data["outputs"][0]
        np.copyto(target, np.linalg.cholesky(source))

    def backward_code(np, data):
        def transpose(mat):
            return np.swapaxes(mat, -1, -2)

        def solve_transposed(lhs, rhs):
            # Solve lhs^T @ y = rhs for y.
            return np.linalg.solve(transpose(lhs), rhs)

        def halved_diag_tril(mat):
            # Lower triangle of mat with its diagonal divided by two.
            return np.tril(mat) / (1. + np.eye(mat.shape[-1]))

        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        chol = data["f_outputs"][0]

        lt_grad = np.einsum('...ki,...kj->...ij', chol, grad_out)
        # Two transposed solves against the factor, one on each side.
        inner = solve_transposed(chol, transpose(halved_diag_tril(lt_grad)))
        solved = solve_transposed(chol, transpose(inner))
        # Symmetrise the result before writing it out.
        np.copyto(grad_in, (solved + transpose(solved)) / 2.)

    results = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return results[0]
[文档]
def solve(a,b):
    """
    Solve the linear system ``A x = B``, i.e. compute ``x = A^{-1} B``.

    The forward pass calls ``np.linalg.solve``. The backward pass provides
    gradients for both inputs: ``-(A^{-T} dout) x^T`` for ``a`` and
    ``A^{-T} dout`` for ``b``.

    NOTE(review): the original documentation warned that the backward pass
    may give incorrect results when ``a`` or ``b`` contain zeros; nothing in
    this code shows why -- confirm whether that caveat still applies.

    Args:
        a (jittor.Var): coefficient matrix ``A`` (expected to be
            non-singular).
        b (jittor.Var): right-hand side ``B``, a vector or a matrix.

    Returns:
        jittor.Var: the solution, with the shape and dtype of ``b``.

    Example:
        >>> import jittor as jt
        >>> from jittor.linalg import solve
        >>> a = jt.array([[1., 2.], [3., 4.]])
        >>> b = jt.array([5., 6.])
        >>> print(solve(a, b))
        jt.Var([-4.   4.5], dtype=float32)
    """
    def forward_code(np, data):
        a, b = data["inputs"]
        L = data["outputs"][0]
        ans = np.linalg.solve(a, b)
        np.copyto(L, ans)

    def backward_code1(np, data):
        # Gradient with respect to the coefficient matrix ``a``.
        def T(mat):
            return np.swapaxes(mat, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]
        # Promote vectors to single-column matrices so the outer product
        # below can be written as a matrix product.
        updim = lambda x: x if x.ndim == a.ndim else x[..., None]
        t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
        np.copyto(out, t)

    def backward_code2(np, data):
        # Gradient with respect to the right-hand side ``b``: since
        # x = A^{-1} b, the upstream gradient maps back through A^{-T}.
        dout = data["dout"]
        out = data["outputs"][0]
        coeff = data["inputs"][0]
        np.copyto(out, np.linalg.solve(np.swapaxes(coeff, -1, -2), dout))

    l_ans = jt.numpy_code(
        [b.shape],
        [b.dtype],
        [a, b],
        forward_code,
        [backward_code1, backward_code2],
    )
    ans = l_ans[0]
    return ans
[文档]
def qr(x):
    """
    Compute the QR decomposition of the input matrix.

    The forward pass calls ``np.linalg.qr``. Both outputs are allocated with
    the shape of ``x``.

    Args:
        x (jittor.Var): matrix to decompose, documented as shape ``(M, M)``.

    Returns:
        tuple(jittor.Var, jittor.Var): ``(q, r)``, each with the shape and
        dtype of ``x``.

    Example:
        >>> import jittor as jt
        >>> x = jt.random((2, 2))
        >>> q, r = jt.linalg.qr(x)
        >>> # q and r both have shape (2, 2)
    """
    def forward_code(np, data):
        source = data["inputs"][0]
        q_buf, r_buf = data["outputs"]
        q_value, r_value = np.linalg.qr(source)
        np.copyto(q_buf, q_value)
        np.copyto(r_buf, r_value)

    def backward_code(np, data):
        matmul = partial(np.einsum, '...ij,...jk->...ik')

        def transpose(mat):
            return np.swapaxes(mat, -1, -2)

        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        q, r = data["f_outputs"]

        def skew_tril_solve(mat):
            # Take the strictly lower triangle of (mat - mat^T), transpose
            # it, solve against r, and transpose the solution back.
            skew = mat - transpose(mat)
            rhs = transpose(np.tril(skew, -1))
            return transpose(np.linalg.solve(r, rhs))

        if data["out_index"] == 0:
            # Contribution of the gradient flowing through q.
            result = matmul(q, skew_tril_solve(matmul(transpose(q), grad_out)))
        else:
            # Contribution of the gradient flowing through r.
            correction = skew_tril_solve(matmul(r, transpose(grad_out)))
            result = matmul(q, correction + grad_out)
        np.copyto(grad_in, result)

    q, r = jt.numpy_code(
        [x.shape, x.shape],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return q, r
[文档]
def einsum(string, *args):
    """
    Evaluate an Einstein-summation expression on the given operands.

    The expression ``"i,j->ij"`` is handled directly with ``broadcast`` and
    ``multiply``; every other expression is forwarded to ``np.einsum``
    through ``jt.numpy_code``, with one backward function per operand.

    Args:
        string (str): the einsum subscripts describing inputs, output and
            the axes to sum over.
        *args (jittor.Var): one or more operands.

    Returns:
        jittor.Var: the result; its dtype is that of the first operand.

    Example:
        >>> import jittor as jt
        >>> a = jt.array([[1., 2.], [3., 4.]])
        >>> c = jt.linalg.einsum('ij,jk->ik', a, a)
        >>> # c is [[7., 10.], [15., 22.]]
    """
    import numpy as np_cpu
    if string == "i,j->ij":
        # Outer product of two vectors, done without the NumPy round trip.
        return args[0].broadcast((args[0].shape[0], args[1].shape[0]), dims=[1]).multiply(args[1])

    def forward_code(np, data):
        out = data["outputs"][0]
        npout = np.einsum(string, *data["inputs"])
        np.copyto(out, npout)

    def backward_code(np, data, argnum=0):
        # Only the leading entries of data["inputs"] are the einsum operands;
        # the last two are extras supplied for the backward call (presumably
        # the forward output and dout -- TODO confirm).
        real_len = len(data["inputs"]) - 2
        operands = data["inputs"][:real_len]
        _ops = operands
        if np_cpu is not np:
            # fake a numpy array: only the number of dimensions matters for
            # parsing the subscripts
            _ops = [np_cpu.zeros((1,)*o.ndim) for o in _ops]
        in_subs, out_subs, _ = np_cpu.core.einsumfunc._parse_einsum_input([string] + _ops)
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][argnum]
        in_subs_list = in_subs.split(',')
        op_num = argnum
        subs_wrt = in_subs_list[op_num]
        rest_of_ops = operands[:op_num] + operands[op_num+1:]
        rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num+1:]
        other_named_subs = set(''.join([out_subs] + rest_of_subs))
        # Subscripts that appear only in this operand are summed away in the
        # forward pass; a ones array restores those axes in the gradient.
        naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt)
                        if sub not in other_named_subs]
        if naked_summed:
            naked_summed_dims, ones_subs = zip(*naked_summed)
            ones_subs = ''.join(ones_subs)
            ones = np_cpu.ones(np_cpu.array(operands[op_num].shape)[list(naked_summed_dims)])
            new_input_subs = ','.join([out_subs, ones_subs] + rest_of_subs)
            new_operands = [dout, ones] + rest_of_ops
        else:
            new_input_subs = ','.join([out_subs] + rest_of_subs)
            new_operands = [dout] + rest_of_ops
        new_subscripts = new_input_subs + '->' + subs_wrt
        x = np.einsum(new_subscripts, *new_operands)
        # Reduce the gradient back to the operand's shape. Extra leading axes
        # are summed away from the front (the original referenced an
        # undefined name here).
        while np.ndim(x) > np.ndim(inp):
            x = np.sum(x, axis=0)
        # Axes of size 1 in the operand are summed with keepdims.
        for axis, size in enumerate(inp.shape):
            if size == 1:
                x = np.sum(x, axis=axis, keepdims=True)
        np.copyto(out, x)

    def einsum_outshape(einsum_expr, inputs):
        # Derive the output shape from the expression and the input shapes.
        shps = np_cpu.concatenate([in_.shape for in_ in inputs])
        p = einsum_expr.replace(" ", "").split(',')
        s = p[:-1] + p[-1].split('->')
        ellip_expr = None
        const_rep = '1234567890' # assume tensor shape no more than 10 dimensions
        for idx, expr in enumerate(s[:-1]):
            if "..." in expr:
                assert "..." in s[-1]
            else:
                continue
            shp = inputs[idx].shape
            ellipsis_pos = len(expr.replace("...", ""))
            # Replace the ellipsis with one digit label per covered axis.
            nellip_expr = const_rep[0 : len(shp) - ellipsis_pos]
            if ellip_expr is None:
                ellip_expr = nellip_expr
            else:
                assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis."
            s[idx] = expr.replace("...", ellip_expr)
        if ellip_expr:
            s[-1] = s[-1].replace("...", ellip_expr)
        if s[-1]=='':
            return ()
        else:
            inop = list(map(list,s))
            # For each output label, locate its first occurrence among the
            # concatenated input labels and take that axis' size.
            return tuple(shps[(np_cpu.concatenate(inop[:-1])[:,None]==inop[-1]).argmax(0)].astype(np_cpu.int64))

    output_shape = [int(x) for x in einsum_outshape(string, args)]
    backwards = [partial(backward_code, argnum=idx) for idx in range(len(args))]
    a = jt.numpy_code(
        [output_shape],
        [args[0].dtype],
        args,
        forward_code,
        backwards,
    )[0]
    return a