jittor.sparse 源代码

# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved. 
# Maintainers:
#   Dun Liang <randonlang@gmail.com>.
#   Xiangli Li <190569238@qq.com>
#
# 
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

import jittor as jt
import numpy as np

[文档] class SparseVar: '''用于存储稀疏的 Var。 参数: - indices (jt.Var): 索引值,表示稀疏 Var 中非零元素的位置。 - values (jt.Var): 数值,表示稀疏 Var 中非零元素的值。 - shape (jt.NanoVector): 表示稀疏 Var 的形状。 形状: - indices: :math:`(D, N)`,其中 :math:`D` 为稀疏 Var 维数,:math:`N` 为非零元素索引个数。 - values: :math:`(N,)`,其中 :math:`N` 为非零元素索引个数。 - shape: 长度为 :math:`D` 的 NanoVector,其中 :math:`D` 为稀疏 Var 维数。 属性: - ndim: 表示稀疏 Var 的维数。 代码示例: >>> from jittor.sparse import SparseVar >>> indices = jt.array([[0, 1, 2], [1, 2, 3]]) >>> values = jt.array([2, 3, 3]) >>> shape = jt.NanoVector([5, 5,]) >>> a = SparseVar(indices, values, shape) >>> a.to_dense() jt.Var([[0 2 0 0 0] [0 0 3 0 0] [0 0 0 3 0] [0 0 0 0 0] [0 0 0 0 0]], dtype=int32) ''' def __init__(self,indices,values,shape): assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector) self.indices = indices self.values = values self.shape = shape self.ndim = len(shape) def _indices(self): return self.indices def _values(self): return self.values def t(self): indices = list(self.indices.split(1,dim=0)) indices[-1],indices[-2] = indices[-2],indices[-1] indices = jt.concat(indices,dim=0) shape = list(self.shape) shape[-1],shape[-2] = shape[-2],shape[-1] shape = jt.NanoVector(shape) return SparseVar(indices,self.values,shape) def to_dense(self): ret = jt.zeros(self.shape,self.values.dtype) indices = tuple(self.indices.split(1,dim=0)) ret[indices]=self.values return ret
[文档] def sparse_array(indices,values,shape): ''' 将给出的稀疏数组的索引、值和形状用于构建一个Jittor稀疏张量。 对于任何给定的索引,其对应的值不应该为零。 以下为数学表示,其中 :math:`(i_1, i_2, ..., i_N)` 是indices的某一行, :math:`(j)` 是该行在indices中的索引: .. math:: SparseVar[i_1, i_2, ..., i_N] = values[j] 参数: - indices (Var): 稀疏数组的索引,必须是二维整型的Jittor变量。其每一行是代表非零值在最终结果中的位置 - values (Var): 非零值组成的一维浮点型Jittor变量。 - shape (NanoVector): 稀疏张量的形状。长度必须和indices中的列数相同。 返回值: jt.sparse.SparseVar: 一个Jittor的稀疏张量 代码示例: >>> import jittor as jt >>> indices = jt.array([[0, 0], [1, 2]]) >>> values = jt.array([1, 2]) >>> shape = jt.NanoVector([3, 4]) >>> jt.sparse.sparse_array(indices, values, shape).shape [3,4,] ''' return SparseVar(indices,values,shape)
[文档] def spmm(spase_x,y): ''' 稀疏矩阵和密集矩阵的乘法操作。此函数先将稀疏矩阵转化为密集矩阵,然后使用Jittor的矩阵乘法操作将两个矩阵相乘。此函数需要等长的rows和columns,即输入的两个矩阵可以执行矩阵乘法操作。 参数: - spase_x (jittor.sparse.SparseVar): 2维稀疏矩阵。假设有N行和M列 - y (Var): 2维密集矩阵。假设有M行和P列 返回值: jt.Var: 结果矩阵。它是一个N行和P列的密集矩阵 代码示例: >>> import jittor as jt >>> indices = jt.array([[0, 0], [1, 2]]) >>> values = jt.array([1, 2]) >>> shape = jt.NanoVector([3, 4]) >>> sparse_mat = jt.sparse.sparse_array(indices, values, shape) [3,4,] >>> y = jt.randn(4,5) >>> jt.sparse.spmm(sparse_mat, y).shape [3,5,] ''' assert isinstance(spase_x,SparseVar) and isinstance(y,jt.Var) assert spase_x.ndim==2 and y.ndim==2 and spase_x.shape[-1]==y.shape[0] # TODO x = spase_x.to_dense() return jt.matmul(x,y)