Source code for jittor.linalg

# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
#     Haoyang Peng <2247838039@qq.com>
#     Guowei Yang <471184555@qq.com>
#     Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from functools import partial

# TODO: full_matrices=1
def svd(x):
    """
    Compute the singular value decomposition of the input matrix, following
    :math:`x = u s v`, where :math:`u`, :math:`s` and :math:`v` are the left
    singular vectors, the singular values and the right singular vectors of
    the input matrix, respectively.

    Args:
        - x (ndarray, jittor.Var): the matrix to decompose.

    Returns:
        The factors of the decomposition, u, s, v (ndarray).

    Example:
        >>> from jittor.linalg import svd
        >>> X = jt.Var([[1., 0., 0., 0., 2.],
        ...             [0., 0., 3., 0., 0.],
        ...             [0., 0., 0., 0., 0.],
        ...             [0., 2., 0., 0., 0.]])
        >>> u, s, v = svd(X)
        >>> print(u)
        jt.Var([[-0.  1.  0.  0.]
         [ 0.  0.  1.  0.]
         [-1.  0.  0.  0.]
         [ 0.  0.  0.  1.]], dtype=float32)
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        u, s, v = data["outputs"]
        # TODO: remove copyto
        tu, ts, tv = np.linalg.svd(a, full_matrices=0)
        np.copyto(u, tu)
        np.copyto(s, ts)
        np.copyto(v, tv)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        u, s, v = data["f_outputs"]
        v = T(v)
        m, n = inp.shape[-2:]
        k = np.min((m, n))
        i = np.reshape(np.eye(k), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k))))
        if out_index == 0:
            f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
            gu = dout
            utgu = _dot(T(u), gu)
            t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :]
            t = _dot(_dot(u, t), T(v))
            if m > n:
                i_minus_uut = (np.reshape(np.eye(m), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (m, m)))) -
                               _dot(u, np.conj(T(u))))
                t = t + T(_dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut))
            np.copyto(out, t)
        elif out_index == 1:
            gs = dout
            t = i * gs[..., :, np.newaxis]
            t = _dot(_dot(u, t), T(v))
            np.copyto(out, t)
        elif out_index == 2:
            f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
            gv = dout
            vtgv = _dot(T(v), gv)
            t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv)))
            t = _dot(_dot(u, t), T(v))
            if m < n:
                i_minus_vvt = (np.reshape(np.eye(n), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (n, n)))) -
                               _dot(v, np.conj(T(v))))
                t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
            np.copyto(out, t)

    m, n = x.shape[-2:]
    k = min(m, n)
    s1 = list(x.shape)
    s1[-1] = k
    s2 = list(x.shape)
    s2[-2] = k
    s3 = list(x.shape)[:-2]
    s3.append(k)
    u, s, v = jt.numpy_code(
        [s1, s3, s2],
        [x.dtype, x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return u, s, v

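# --- Illustrative sketch, not part of the original module: a minimal
# reconstruction check for svd(). The helper name _check_svd_reconstruction is
# hypothetical; it only assumes jt.random, Var.data and numpy.
def _check_svd_reconstruction():
    import numpy as np
    x = jt.random((4, 5))
    u, s, v = svd(x)
    # With full_matrices=0, the returned v already holds V^H, so x ~= u @ diag(s) @ v.
    rec = u.data @ np.diag(s.data) @ v.data
    assert np.allclose(rec, x.data, atol=1e-4)
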
def eigh(x):
    """
    Compute the eigenvalues and eigenvectors of the input matrix, defined by

    .. math::
        A v = \\lambda v

    where :math:`A` is the input matrix, :math:`v` is an eigenvector and
    :math:`\\lambda` is the corresponding eigenvalue.

    Args:
        - x (Var): input matrix of shape (..., :math:`M`, :math:`M`), where
          :math:`M` is the matrix size.

    Returns:
        A tuple (:math:`w`, :math:`v`) (tuple ( ``Var`` , ``Var`` )), where
        :math:`w` holds the eigenvalues with shape (..., :math:`M`) and
        :math:`v` holds the normalized eigenvectors with shape
        (..., :math:`M`, :math:`M`).

    Example:
        >>> x = jt.random((2, 2))
        jt.Var([[ 0.9814584  -0.1916754 ]
         [-0.8806686  -0.47373292]], dtype=float32)
        >>> jt.linalg.eigh(x)
        (jt.Var([-0.54820526  1.4303665 ], dtype=float32),
         jt.Var([[-0.1916754  -0.9814584 ]
          [-0.47373292  0.8806686 ]], dtype=float32))
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        w, v = data["outputs"]
        tw, tv = np.linalg.eigh(a, UPLO='L')
        np.copyto(w, tw)
        np.copyto(v, tv)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        w, v = data["f_outputs"]
        k = int(inp.shape[-1])
        w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1)
        if out_index == 0:
            t = _dot(v * dout[..., np.newaxis, :], T(v))
            np.copyto(out, t)
        elif out_index == 1:
            if np.any(dout):
                off_diag = np.ones((k, k)) - np.eye(k)
                F = off_diag / (T(w_repeated) - w_repeated + np.eye(k))
                t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
                np.copyto(out, t)

    sw = x.shape[:-2] + x.shape[-1:]
    sv = x.shape
    w, v = jt.numpy_code(
        [sw, sv],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return w, v

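# --- Illustrative sketch, not part of the original module: eigh() returns
# eigenvectors as columns of v, so A @ v_i == w_i * v_i should hold for a
# symmetric input. The helper name _check_eigh is hypothetical.
def _check_eigh():
    import numpy as np
    a_np = np.random.rand(3, 3).astype(np.float32)
    a = jt.array(a_np + a_np.T)          # symmetrize the input
    w, v = eigh(a)
    # Broadcasting w over the last axis scales each column v_i by w_i.
    assert np.allclose(a.data @ v.data, v.data * w.data, atol=1e-4)
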
def inv(x):
    """
    Compute the inverse of the input matrix.

    Args:
        - x (Var): input Var.

    Returns:
        The inverse matrix (Var).

    Example:
        >>> x = jt.random((2, 2))
        jt.Var([[ 0.9814584  -0.1916754 ]
         [-0.8806686  -0.47373292]], dtype=float32)
        >>> jt.linalg.inv(x)
        jt.Var([[ 2.520024   -0.4921519 ]
         [-1.1624513  -0.62531066]], dtype=float32)
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        m_a = data["outputs"][0]
        t_a = np.linalg.inv(a)
        np.copyto(m_a, t_a)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        lmx = data["f_outputs"]
        mx = lmx[0]
        # Gradient of the inverse: -inv(x)^T @ dout @ inv(x)^T
        t = -_dot(_dot(T(mx), dout), T(mx))
        np.copyto(out, t)

    lmx = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    mx = lmx[0]
    return mx

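# --- Illustrative sketch, not part of the original module: checking the
# backward formula of inv() against -inv(x)^T @ dout @ inv(x)^T with dout of
# all ones. The helper name _check_inv_grad is hypothetical and assumes
# jt.grad(loss, [x]) returns a list of gradients.
def _check_inv_grad():
    import numpy as np
    x_np = np.random.rand(3, 3).astype(np.float32) + 3 * np.eye(3, dtype=np.float32)
    x = jt.array(x_np)
    g = jt.grad(inv(x).sum(), [x])[0]
    ix = np.linalg.inv(x_np)
    expected = -ix.T @ np.ones((3, 3), dtype=np.float32) @ ix.T
    assert np.allclose(g.data, expected, atol=1e-4)
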
def pinv(x):
    """
    Compute the pseudo-inverse of the input matrix; the operator supports
    gradient backpropagation.

    Args:
        - x (numpy.ndarray): array or matrix of shape
          (..., :math:`M`, :math:`N`).

    Returns:
        The pseudo-inverse of :math:`x` (numpy.ndarray) with shape
        (..., :math:`N`, :math:`M`). The computation is applied to the last
        two dimensions, which are treated as matrices; any leading dimensions
        are treated as batch dimensions.

    Example:
        >>> import jittor as jt
        >>> x = jt.array([[1.0,2.0],[3.0,4.0]])
        >>> y = jt.linalg.pinv(x)
        >>> print(y)
        jt.Var([[-1.9999999  1.       ]
         [ 1.5       -0.5      ]], dtype=float32)
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        m_a = data["outputs"][0]
        t_a = np.linalg.pinv(a)
        np.copyto(m_a, t_a)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        lmx = data["f_outputs"]
        mx = lmx[0]
        t = T(
            -_dot(_dot(mx, T(dout)), mx)
            + _dot(_dot(_dot(mx, T(mx)), dout), np.eye(inp.shape[-2]) - _dot(inp, mx))
            + _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx)
        )
        np.copyto(out, t)

    sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
    lmx = jt.numpy_code(
        [sw],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    mx = lmx[0]
    return mx

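# --- Illustrative sketch, not part of the original module: the Moore-Penrose
# condition x @ pinv(x) @ x == x on a full-rank rectangular input. The helper
# name _check_pinv is hypothetical.
def _check_pinv():
    import numpy as np
    x_np = np.random.rand(4, 2).astype(np.float32)
    p = pinv(jt.array(x_np))
    assert np.allclose(x_np @ p.data @ x_np, x_np, atol=1e-4)
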
def det(x):
    """
    Compute the determinant of the input matrix.

    Args:
        - x (Var): input tensor of shape ``(..., M, M)``, where :math:`M` is
          the matrix size. The tensor may have any number of leading
          dimensions; the last two are treated as 2D matrices.

    Returns:
        The determinant of :math:`x` as a ``Var``: shape ``(1,)`` for a single
        matrix, and the batch shape ``(...)`` for batched input.

    Example:
        >>> import jittor as jt
        >>> x = jt.array([[1.0,2.0],[3.0,4.0]])
        >>> print(jt.linalg.det(x))
        jt.Var([-2.], dtype=float32)
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        L = data["outputs"][0]
        tL = np.linalg.det(a)
        np.copyto(L, tL)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]
        n_d = np.reshape(dout, np.shape(dout) + (1, 1))
        n_o = np.reshape(f_out, np.shape(f_out) + (1, 1))
        # Jacobi's formula: d det(x)/dx = det(x) * inv(x)^T
        s = n_d * n_o * T(np.linalg.inv(inp))
        np.copyto(out, s)

    s = x.shape
    x_s = s[:-2]
    if len(s) == 2:
        x_s.append(1)
    l_det = jt.numpy_code(
        [x_s],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    det = l_det[0]
    return det

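# --- Illustrative sketch, not part of the original module: the gradient of
# det(x) should match Jacobi's formula det(x) * inv(x)^T. The helper name
# _check_det_grad is hypothetical and assumes jt.grad(loss, [x]) returns a
# list of gradients.
def _check_det_grad():
    import numpy as np
    x_np = np.random.rand(3, 3).astype(np.float32) + 2 * np.eye(3, dtype=np.float32)
    x = jt.array(x_np)
    g = jt.grad(det(x).sum(), [x])[0]
    expected = np.linalg.det(x_np) * np.linalg.inv(x_np).T
    assert np.allclose(g.data, expected, atol=1e-3)
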
def slogdet(x):
    """
    Compute the sign and the natural logarithm of the determinant of the
    input matrix, using an ``LU`` factorization:

    .. math::
        det(a) = sign * exp(logdet)

    where ``sign`` is the sign of the determinant and ``logdet`` is the
    natural logarithm of its absolute value.

    Args:
        - x (Var): tensor of shape ``(..., M, M)`` holding the matrices whose
          determinants are required, where :math:`M` is the matrix size.

    Returns:
        Two tensors: the first holds the sign of the determinant, the second
        its natural logarithm.

    Example:
        >>> import jittor as jt
        >>> x = jt.array([[1.0,2.0],[3.0,4.0]])
        >>> sign, logdet = jt.linalg.slogdet(x)
        >>> print(f"Sign: {sign.data}, LogDet: {logdet.data}")
        Sign: [-1.], LogDet: [0.6931472]
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        sign, m_a = data["outputs"]
        sign_, t_a = np.linalg.slogdet(a)
        np.copyto(m_a, t_a)
        np.copyto(sign, sign_)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        if out_index == 0:
            # The sign is piecewise constant, so its gradient is zero.
            np.copyto(out, 0)
        if out_index == 1:
            t = np.reshape(dout, np.shape(dout) + (1, 1))
            t = t * T(np.linalg.inv(inp))
            np.copyto(out, t)

    s = x.shape
    det_s = s[:-2]
    if len(det_s) == 0:
        det_s.append(1)
    sign, mx = jt.numpy_code(
        [det_s, det_s],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return sign, mx

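# --- Illustrative sketch, not part of the original module: sign * exp(logdet)
# from slogdet() should reproduce the plain determinant. The helper name
# _check_slogdet is hypothetical.
def _check_slogdet():
    import numpy as np
    x_np = np.random.rand(3, 3).astype(np.float32) + 2 * np.eye(3, dtype=np.float32)
    sign, logdet = slogdet(jt.array(x_np))
    assert np.allclose(sign.data * np.exp(logdet.data), np.linalg.det(x_np), atol=1e-3)
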
def cholesky(x):
    """
    Compute the Cholesky decomposition of the input matrix:

    .. math::
        x = LL^T

    where :math:`x` must be a Hermitian, positive-definite matrix and
    :math:`L` is a lower-triangular matrix.

    Args:
        - x (Var): input matrix of shape ``(..., M, M)``. :math:`x` must be
          positive definite and Hermitian (i.e. equal to its conjugate
          transpose); it may be real or complex.

    Returns:
        The lower-triangular Cholesky factor :math:`L` of :math:`x`
        ( ``Var`` ), with the same shape ``(..., M, M)`` as :math:`x`.

    Example:
        >>> import jittor as jt
        >>> a = jt.array([[4.0, 12, -16], [12, 37, -43], [-16, -43, 98]])
        >>> print(jt.linalg.cholesky(a))
        jt.Var([[ 2.  0.  0.]
         [ 6.  1.  0.]
         [-8.  5.  3.]], dtype=float32)
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        L = data["outputs"][0]
        tL = np.linalg.cholesky(a)
        np.copyto(L, tL)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        solve_trans = lambda a, b: np.linalg.solve(T(a), b)
        phi = lambda X: np.tril(X) / (1. + np.eye(X.shape[-1]))

        def conjugate_solve(L, X):
            # Compute L^{-T} X L^{-1}.
            return solve_trans(L, T(solve_trans(L, T(X))))

        s = conjugate_solve(f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
        s = (s + T(s)) / 2.
        np.copyto(out, s)

    lL = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    L = lL[0]
    return L

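# --- Illustrative sketch, not part of the original module: the factor L from
# cholesky() should satisfy L @ L^T == x for a symmetric positive-definite
# input. The helper name _check_cholesky is hypothetical.
def _check_cholesky():
    import numpy as np
    a_np = np.random.rand(3, 3).astype(np.float32)
    spd = a_np @ a_np.T + 3 * np.eye(3, dtype=np.float32)   # symmetric positive definite
    L = cholesky(jt.array(spd))
    assert np.allclose(L.data @ L.data.T, spd, atol=1e-4)
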
def solve(a, b):
    """
    Solve the linear matrix equation :math:`Ax = B` via

    .. math::
        x = A^{-1}B

    where ``A`` is a non-singular matrix and ``B`` is a vector or matrix. The
    forward pass is computed with numpy's ``linalg.solve``; the backward pass
    applies the analytic gradient formula. The backward pass may produce
    incorrect results when ``A`` or ``B`` contains zeros.

    Args:
        - a (Var): coefficient matrix :math:`A` of the linear system.
        - b (Var): constant vector :math:`B` of the linear system.

    Returns:
        The solution vector of the linear system (Var).

    Example:
        >>> a = jt.array([[1., 2.], [3., 4.]])
        >>> b = jt.array([5., 6.])
        >>> from jittor.linalg import solve
        >>> print(solve(a, b))
        jt.Var([-4.   4.5], dtype=float32)
    """
    def forward_code(np, data):
        a, b = data["inputs"]
        L = data["outputs"][0]
        ans = np.linalg.solve(a, b)
        np.copyto(L, ans)

    def backward_code1(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]
        updim = lambda x: x if x.ndim == a.ndim else x[..., None]
        t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
        np.copyto(out, t)

    def backward_code2(np, data):
        # The gradient with respect to b is not propagated; it is set to zero.
        out = data["outputs"][0]
        np.copyto(out, 0)

    l_ans = jt.numpy_code(
        [b.shape],
        [b.dtype],
        [a, b],
        forward_code,
        [backward_code1, backward_code2],
    )
    ans = l_ans[0]
    return ans

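# --- Illustrative sketch, not part of the original module: solve(a, b) should
# satisfy a @ x == b on a well-conditioned system. The helper name
# _check_solve is hypothetical.
def _check_solve():
    import numpy as np
    a_np = np.random.rand(3, 3).astype(np.float32) + 3 * np.eye(3, dtype=np.float32)
    b_np = np.random.rand(3).astype(np.float32)
    x = solve(jt.array(a_np), jt.array(b_np))
    assert np.allclose(a_np @ x.data, b_np, atol=1e-4)
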
def qr(x):
    """
    Compute the :math:`QR` decomposition of the input matrix.

    Args:
        - x (array): matrix of shape (M, M) to decompose.

    Returns:
        The :math:`QR` decomposition results, both matrices of shape (M, M).

    Example:
        >>> import jittor as jt
        >>> x = jt.random((2, 2))
        >>> q, r = jt.linalg.qr(x)
        >>> print(q, r)
        jt.Var([[-0.9639901 -0.2659382]
         [-0.2659382  0.9639901]], dtype=float32)
        jt.Var([[-1.0051305  -1.0211498 ]
         [ 0.          0.29402444]], dtype=float32)
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        q, r = data["outputs"]
        Q, R = np.linalg.qr(a)
        np.copyto(q, Q)
        np.copyto(r, R)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        _harmard = partial(np.einsum, '...ij,...ij->...ij')
        dout = data["dout"]
        out = data["outputs"][0]
        q, r = data["f_outputs"]
        out_index = data["out_index"]
        # pl = np.tril(np.ones((inp.shape[-1], inp.shape[-1]))) - diags
        if out_index == 0:
            # Q term
            q_t = _dot(T(q), dout)
            rhs_solve = q_t - T(q_t)
            rhs_solve = T(np.tril(rhs_solve, -1))
            qsolve = np.linalg.solve(r, rhs_solve)
            qsolve = T(qsolve)
            tq = _dot(q, qsolve)
            np.copyto(out, tq)
        else:
            # R term
            r_t = _dot(r, T(dout))
            rhs_solve = r_t - T(r_t)
            rhs_solve = np.tril(rhs_solve, -1)
            rhs_solve = T(rhs_solve)
            r_solve = np.linalg.solve(r, rhs_solve)
            tr = _dot(q, (T(r_solve) + dout))
            np.copyto(out, tr)

    q, r = jt.numpy_code(
        [x.shape, x.shape],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return q, r

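# --- Illustrative sketch, not part of the original module: q from qr() should
# be orthogonal and q @ r should reproduce the input. The helper name
# _check_qr is hypothetical.
def _check_qr():
    import numpy as np
    x_np = np.random.rand(3, 3).astype(np.float32)
    q, r = qr(jt.array(x_np))
    assert np.allclose(q.data @ r.data, x_np, atol=1e-4)
    assert np.allclose(q.data.T @ q.data, np.eye(3), atol=1e-4)
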
def einsum(string, *args):
    """
    Perform an einsum (Einstein summation) operation.

    Args:
        - string (str): the subscript string describing the operation; it
          specifies the inputs, the output and how the dimensions are
          contracted.
        - args (Sequence[Var]): one or more input tensors for the einsum
          operation.

    Returns:
        The Var produced by the einsum operation.

    Example:
        >>> import jittor as jt
        >>> a = jt.random((1, 2, 1))
        >>> b = jt.random((1, 1, 2))
        >>> c = jt.linalg.einsum('bij,bjk->bik', a, b)
        >>> print(c)
        jt.Var([[[0.10123717 0.6376321 ]
          [0.10257126 0.64603466]]], dtype=float32)
    """
    import numpy as np_cpu

    if string == "i,j->ij":
        # Fast path for the outer product.
        return args[0].broadcast((args[0].shape[0], args[1].shape[0]), dims=[1]).multiply(args[1])

    def forward_code(np, data):
        out = data["outputs"][0]
        npout = np.einsum(string, *data["inputs"])
        np.copyto(out, npout)

    def backward_code(np, data, argnum=0):
        real_len = len(data["inputs"]) - 2
        operands = data["inputs"][:real_len]
        _ops = operands
        if np_cpu is not np:
            # Fake numpy arrays so the subscript parser can run on the CPU.
            _ops = [np_cpu.zeros((1,) * o.ndim) for o in _ops]
        in_subs, out_subs, _ = np_cpu.core.einsumfunc._parse_einsum_input([string] + _ops)
        dout = data["dout"]
        out_index = data["out_index"]
        out = data["outputs"][0]
        inp = data["inputs"][argnum]
        c = data["f_outputs"]

        in_subs_list = in_subs.split(',')
        op_num = argnum
        subs_wrt = in_subs_list[op_num]
        rest_of_ops = operands[:op_num] + operands[op_num + 1:]
        rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num + 1:]
        other_named_subs = set(''.join([out_subs] + rest_of_subs))
        naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt)
                        if sub not in other_named_subs]
        if naked_summed:
            naked_summed_dims, ones_subs = zip(*naked_summed)
            ones_subs = ''.join(ones_subs)
            ones = np_cpu.ones(np_cpu.array(operands[op_num].shape)[list(naked_summed_dims)])
            new_input_subs = ','.join([out_subs, ones_subs] + rest_of_subs)
            new_operands = [dout, ones] + rest_of_ops
        else:
            new_input_subs = ','.join([out_subs] + rest_of_subs)
            new_operands = [dout] + rest_of_ops

        new_subscripts = new_input_subs + '->' + subs_wrt
        x = np.einsum(new_subscripts, *new_operands)
        while np.ndim(x) > np.ndim(inp):
            # Sum out extra leading broadcast dimensions (unbroadcasting).
            x = np.sum(x, axis=0)
        for axis, size in enumerate(inp.shape):
            if size == 1:
                x = np.sum(x, axis=axis, keepdims=True)
        np.copyto(out, x)

    def einsum_outshape(einsum_expr, inputs):
        shps = np_cpu.concatenate([in_.shape for in_ in inputs])
        p = einsum_expr.replace(" ", "").split(',')
        s = p[:-1] + p[-1].split('->')
        rec_shape = []
        ellip_expr = None
        const_rep = '1234567890'  # assume tensor shape no more than 10 dimensions
        for idx, expr in enumerate(s[:-1]):
            if "..." in expr:
                assert "..." in s[-1]
            else:
                continue
            shp = inputs[idx].shape
            ellipsis_pos = len(expr.replace("...", ""))
            nellip_expr = const_rep[0: len(shp) - ellipsis_pos]
            if ellip_expr is None:
                ellip_expr = nellip_expr
            else:
                assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis."
            s[idx] = expr.replace("...", ellip_expr)
        if ellip_expr:
            s[-1] = s[-1].replace("...", ellip_expr)
        if s[-1] == '':
            return ()
        else:
            inop = list(map(list, s))
            return tuple(shps[(np_cpu.concatenate(inop[:-1])[:, None] == inop[-1]).argmax(0)].astype(np_cpu.int64))

    output_shape = [int(x) for x in einsum_outshape(string, args)]
    backwards = [partial(backward_code, argnum=idx) for idx in range(len(args))]
    a = jt.numpy_code(
        [output_shape],
        [args[0].dtype],
        args,
        forward_code,
        backwards,
    )[0]
    return a

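# --- Illustrative sketch, not part of the original module: einsum() should
# agree with numpy's einsum on a batched matrix product. The helper name
# _check_einsum is hypothetical.
def _check_einsum():
    import numpy as np
    a_np = np.random.rand(2, 3, 4).astype(np.float32)
    b_np = np.random.rand(2, 4, 5).astype(np.float32)
    c = einsum('bij,bjk->bik', jt.array(a_np), jt.array(b_np))
    assert np.allclose(c.data, np.einsum('bij,bjk->bik', a_np, b_np), atol=1e-4)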