jittor.contrib 源代码

# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved. 
# Maintainers: 
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Dun Liang <randonlang@gmail.com>. 
# 
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
from jittor import pool
from collections.abc import Sequence


[文档] def argmax_pool(x, size, stride, padding=0): """ 对输入张量进行最大值池化操作。即在池化窗口内找到最大值作为该窗口的值。该操作按照给定的步长值来移动池化窗口。 .. math:: Y = \\max_{(i, j) \\in \\text{{window}}(size)} X_{i, j, k} 其中,:math:`\\text{{window}}(size)` 是池化窗口,:math:`X_{i,j,k}` 是输入张量在 :math:`(i, j, k)` 位置的元素,:math:`Y` 是进行最大池化操作后的输出张量值。 参数: - x(``Var``): 输入的张量。 - size(``int``): 池化窗口尺寸。 - stride(``int``): 池化窗口的移动步长,步长值确定了池化窗口的移动速度。 - padding(``int``, 可选): 输入的每一条边补充 0 的层数,默认值: ``0`` 代码示例: >>> from jittor.contrib import argmax_pool >>> input_array = jt.random([1, 1, 4, 4]) jt.Var([[[[0.49449673 0.00643021 0.07254869 0.3258533 ] [0.61617774 0.09950083 0.3104945 0.48131013] [0.37913334 0.09407917 0.18861724 0.09006661] [0.7495838 0.25495356 0.00436674 0.3918325 ]]]], dtype=float32) >>> output = argmax_pool(input_array, 2, 2) jt.Var([[[[0.61617774 0.48131013] [0.7495838 0.3918325 ]]]], dtype=float32) 返回值: 经过最大池化操作后的结果(``Var``) """ return pool.pool(x, size, 'maximum', padding, stride)
def concat(arr, dim): """ 沿指定维度连接输入张量序列。所有输入张量的尺寸必须匹配,除了连接维度上的尺寸。在连接维度上,所有其他尺寸必须相同。 参数: - arr (``list``): 要连接的张量序列 - dim (``int``): 要连接的维度。默认值:``0`` 代码示例: >>> jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1) jt.Var([[1,2],[2,2]],dtype=int32) 返回值: 连接后的张量(``Var``) """ # TODO: low performance when concat lots of vars total_dim = 0 if dim < 0: dim += len(arr[0].shape) for a in arr: total_dim += a.shape[dim] cdim = 0 s = None indexes = [ f"i{i}" for i in range(len(a.shape)) ] for a in arr: shape = list(a.shape) shape[dim] = total_dim indexes[dim] = f"i{dim}-{cdim}" b = a.reindex(shape, indexes) # ugly fix for preventing large fused op if len(arr)>=100: b.stop_fuse() if s is None: s = b else: s += b cdim += a.shape[dim] return s
[文档] def check(bc): """ 检查bc中的每个元素是否等于1或等于bc中维度0的最大值。 参数: - bc(``Var``): 输入的数组,代表要进行检查的数组。 代码示例: >>> import jittor as jt >>> bc = jt.Var([[1, 2, 3], [1, 1, 1]]) >>> print(jt.contrib.check(bc)) [1 2 3] >>> bc = jt.Var([[1, 2, 3], [1, 4, 1]]) >>> print(jt.contrib.check(bc)) Exception: Shape not match. 返回值: 返回输入数组的按照轴 ``0`` 进行最大值操作后的结果( ``int`` )。 """ bc = np.array(bc) if ((bc != 1) * (bc != bc.max(0))).sum() > 0: raise Exception(f"Shape not match.") else: return bc.max(0)
[文档] def slice_var_index(x, slices): """ 对于给定的变量 ``x`` 和切片 ``slices`` ,执行切片操作。切片操作根据数组的索引范围、步长等信息来获取数组的一部分。该函数主要用于实现切片操作,可以将 ``slices`` 中的切片应用到 ``x`` 上,返回一个新的张量。 参数值: - 输入的张量(Var) - slices( ``tuple, list, numpy.ndarray,Var`` ): - 切片可以是 ``int, slice, bool`` 等类型,具体由输入数据决定。 - 如果 ``slices`` 不是元组,那么将 ``slices`` 转换为元组。 - 如果 ``len(slices)==1`` 且 ``slices[0]`` 的 ``dtype = bool``,则将 ``slices[0]`` 重定位。 代码示例: >>> import jittor as jt >>> from jittor.contrib import slice_var_index >>> x = jt.array([[1, 2, 3], [4, 5, 6]]) >>> slices = (0, slice(1, None, 2)) >>> out_shape, out_index, _, __, ___ = slice_var_index(x, slices) >>> print(out_shape) [1, 1] >>> print(out_index) ['0', '1+i1*2'] 返回值: - out_shape: 输出张量形状的列表(``list``)。 - out_index: 输出张量索引的列表(``list``)。 - 0: 一个常数 0,表示无额外输出。 - []: 空列表(``list``),表示无额外输出。 - extras: 其中包含需要执行额外操作的切片(``list``)。 """ if not isinstance(slices, tuple): slices = (slices,) if isinstance(slices[0], jt.Var): if len(slices) == 1 and slices[0].dtype == "bool": return slice_var_index(x, tuple(slices[0].where())) bc = [] ml = -1 for idx, s in enumerate(slices): if isinstance(s, jt.Var): shape = s.shape elif isinstance(s, np.ndarray): shape = list(s.shape) elif isinstance(s, list): shape = list(np.array(s).shape) else: continue if len(shape) >= ml: ml = len(shape) bc.append(shape) for idx, shape in enumerate(bc): if len(shape) < ml: shape = (ml - len(shape)) * [1] + shape bc[idx] = shape if len(bc) >= 1: bc_shape = check(bc) ss = [] for idx, s in enumerate(slices): if isinstance(s, np.ndarray) or isinstance(s, list): ss.append(jt.array(s).broadcast(bc_shape.tolist())) elif isinstance(s, jt.Var): ss.append(s.broadcast(bc_shape.tolist())) else: ss.append(s) slices = ss out_shape = [] out_index = [] shape = x.shape cnt_list = 0 extras_idx = [] extras = [] has_ellipse = 0 ellipse_index = 0 for s,i in zip(slices,range(len(slices))): if isinstance(s,type(...)): has_ellipse+=1 ellipse_index = i if has_ellipse>1: raise Exception(f"There are more than one ...") elif has_ellipse==1: slices = list(slices) del slices[ellipse_index] while len(slices)<len(shape): slices.insert(ellipse_index,slice(None)) for i in range(len(shape)): if i>=len(slices): s = slice(None) else: s = slices[i] sp = shape[i] j = len(out_shape) if isinstance(s, int): if s<0: s += sp out_index.append(str(s)) elif isinstance(s, slice): if s == slice(None): out_shape.append(sp) out_index.append(f"i{j}") continue start = 0 if s.start is None else s.start stop = sp if s.stop is None else s.stop step = 1 if s.step is None else s.step if start<0: start += sp if stop<0: stop += sp if stop>sp+1: stop = sp out_shape.append(1+int(max(0, (stop-start-1)//step))) out_index.append(f"{start}+i{j}*{step}") elif isinstance(s, jt.Var): if cnt_list == 0: for idx in range(len(bc_shape)): extras_idx.append(f"i{len(out_shape) + idx}") out_shape += bc_shape.tolist() out_index.append(f"@e{cnt_list}("+ ",".join(extras_idx) + ")") cnt_list += 1 extras.append(s) else: raise Exception(f"Not support slice {s}") if len(out_shape)==0: out_shape = [1] # Stop fuse both input and output, prevent recompile x.stop_fuse() return (out_shape, out_index, 0, [], extras)
def _slice_var_old(x, slices): """ 使用指定的片段对给定变量进行重索引。函数首先通过调用 :math:`slice` _ :math:`var` _ :math:`index` 函数生成一个新的索引片段对变量进行索引,并通过调用 :math:`reindex` 函数对变量进行重索引。 在开始和结束重索引之前,通过调用 :math:`stop` _ :math:`fuse` 函数停止梯度融合以防止误导梯度计算。 参数: - x(Var): 输入的 Jittor 变量。 - slices(tuple, list): 用于索引的切片。此参数应为一个元组或列表,包含一个或多个整数或者切片对象。 代码示例: >>> import jittor as jt >>> from jittor.contrib import _slice_var_old >>> x = jt.array([1, 2, 3, 4, 5]) >>> slices = (2, 4) >>> result = _slice_var_old(x, slices) jt.Var([3], dtype=int32) 返回值: 重新索引后的变量(Var)。 """ reindex_args = slice_var_index(x, slices) x.stop_fuse() return x.reindex(*reindex_args).stop_fuse() def _setitem_old(x, slices, value): """ 对数组 ``x`` 进行切片赋值。函数的目标是将一个数组的切片赋予一个特定的值。首先通过 :math:`slice` :math:`var` :math:`index` 函数处理切片信息,然后创建一个与目标切片相同形状的广播值。然后将广播值累加到目标切片,并将得到的结果赋值回原数组。 参数: - x (Var): 原始数组。 - slices (int, slice object 或者 tuple): 对数组的切片信息。如果是 :math:`tuple`,其长度需要和数组x的维度一致。 - value (int, float 或者 Var): 要赋给数组切片的值。 代码示例: >>> from jittor.contrib import _setitem_old >>> import jittor as jt >>> x = jt.array([0, 1, 2, 3, 4]) >>> _setitem_old(x, slice(1, 4), 9) [0, 9, 9, 9, 4] 返回值: 赋值后的数组(Var)。 """ reindex_args = slice_var_index(x, slices) reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:] xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse() value = jt.broadcast(value, xslice) value = value.cast(x.dtype) one = jt.broadcast(1, xslice) if not isinstance(reindex_args[0][0], jt.Var): reindex_args = (x.shape,) + reindex_args[1:] mask = one.reindex_reduce("add", *reindex_reduce_args) data = value.reindex_reduce("add", *reindex_reduce_args) # Stop fuse both input and output, prevent recompile out = mask.ternary(data, x).stop_fuse() x.assign(out) return x # PATCH def getitem(x, slices): """ 对数组 ``x`` 进行切片获取。函数的目标是从一个数组中按照指定的切片信息提取部分数据。首先处理切片信息,对于布尔类型的Var,会先转换为对应的索引。然后根据处理后的切片信息从原数组中获取对应部分。 参数: - x (``Var``): 原始数组。 - slices (``int``, ``slice object``, ``tuple`` 或者 ``Var``): 对数组的切片信息。如果是 :math:`tuple`,其长度需要和数组x的维度一致。如果是Var且数据类型为"bool",则将其视为布尔索引。 代码示例: >>> import jittor as jt >>> x = jt.array([0, 1, 2, 3, 4]) >>> print(getitem(x, slice(1, 4))) [1, 2, 3] >>> bool_idx = jt.array([False, True, True, False, True]) >>> print(getitem(x, bool_idx)) [1, 2, 4] 返回值: 提取的数组部分(``Var``)。 """ if isinstance(slices, jt.Var) and slices.dtype == "bool": return getitem(x, slices.where()) if isinstance(slices, tuple): ss = [] for s in slices: if isinstance(s, jt.Var) and s.dtype == "bool": ss.extend(s.where()) else: ss.append(s) slices = tuple(ss) return x.getitem(slices)
[文档] def setitem(x, slices, value): """ 对数组 ``x`` 进行切片赋值。函数的目标是将一个数组的切片赋予一个特定的值。首先通过 ``slice, var, index`` 函数处理切片信息,然后创建一个与目标切片相同形状的广播值。然后将广播值累加到目标切片,并将得到的结果赋值回原数组。 参数: - x (``Var``): 原始数组。 - slices (``int``, ``slice object`` 或者 ``tuple``): 对数组的切片信息。如果是 ``tuple``,其长度需要和数组x的维度一致。 - value (``int``, ``float`` 或者 ``Var``): 要赋给数组切片的值。 代码示例: >>> from jittor.contrib import setitem >>> import jittor as jt >>> x = jt.array([0, 1, 2, 3, 4]) >>> setitem(x, slice(1, 4), 9) jt.Var([0 9 9 9 4], dtype=int32) 返回值: 赋值后的数组(``Var``)。 """ if isinstance(slices, jt.Var) and slices.dtype == "bool": if slices.shape == x.shape: if isinstance(value, (int, float)): value = jt.array(value).broadcast(x.shape) return x.assign(slices.ternary(value, x)) elif isinstance(value, jt.Var) and value.shape == [1,]: value = jt.broadcast(value, x.shape) return x.assign(slices.ternary(value, x)) slices = slices.where() elif isinstance(slices, tuple): ss = [] for s in slices: if isinstance(s, jt.Var) and s.dtype == "bool": ss.extend(s.where()) else: ss.append(s) slices = tuple(ss) return x.check_cascade_setitem(x.setitem(slices, value))
jt.Var.__getitem__ = jt.Var.slice_var = getitem jt.Var.__setitem__ = setitem def _merge_dtypes(dtypes): """ 根据输入的数据类型列表,将这些数据类型进行合并返回一个新的数据类型。合并的规则是使用 Jittor 的 :math:`binary` _ :math:`dtype` _ :math:`infer` 函数,其操作方式类似于 :math:`Python` 的 :math:`add` 操作。 其中,:math:`add` 操作的数学公式可以表示为: .. math:: \\text{{new_dtype}} = \\text{{dtype1}} \\oplus \\text{{dtype2}} 其中,:math:`\\oplus` 表示的是 :math:`add` 操作,:math:`\\text{{dtype1}}` 和 :math:`\\text{{dtype2}}` 分别表示两个进行操作的数据类型。 最后,返回结果为float32,表示合并后的新的数据类型。 参数: dtypes(List[str]): 需要被合并的数据类型列表。 代码例子: >>> from jittor.contrib import _merge_dtypes >>> dtypes = [\"float32\", \"int32\"] >>> _merge_dtypes(dtypes) \"float32\" 返回值: 合并后的新的数据类型(str)。 """ dtype = dtypes[0] for i in range(1, len(dtypes)): dtype = jt.binary_dtype_infer("add", dtype, dtypes[i]) return dtype @jt.flag_scope(amp_reg=4) # _custom_flag def concat(arr, dim=0): """ 沿指定维度连接输入张量序列。所有输入张量的尺寸必须匹配,除了连接维度上的尺寸。在连接维度上,所有其他尺寸必须相同。 参数: - arr (list): 要连接的张量序列 - dim (int): 要连接的维度。默认值:0 代码示例: >>> jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1) jt.Var([[1,2],[2,2]],dtype=int32) 返回值: 连接后的张量(Var) """ if not isinstance(arr, Sequence): raise TypeError("concat arr needs to be a tuple or list") if len(arr) == 0: raise ValueError("need at least one array to concat") total_dim = 0 if dim < 0: dim += len(arr[0].shape) dtypes = [] for a in arr: total_dim += a.shape[dim] dtypes.append(str(a.dtype)) cdim = 0 shape = list(a.shape) shape[dim] = total_dim s = jt.empty(shape, dtype = _merge_dtypes(dtypes)) slices = [slice(None)]*len(a.shape) for a in arr: if a.shape[dim] == 0: continue slices[dim] = slice(cdim, cdim+a.shape[dim]) # print(slices, type(a)) s = s.setitem(tuple(slices), a) # s = jt.setitem(s, tuple(slices), a) cdim += a.shape[dim] return s cat = concat