jittor.contrib

这里是Jittor的贡献代码模块模块的API文档,此模块的代码可能还没有完全成熟,我们将在后续迭代开发中继续完善,您可以通过from jittor import contrib来获取该模块。

jittor.contrib.argmax_pool(x, size, stride, padding=0)[源代码]

对输入张量进行最大值池化操作。即在池化窗口内找到最大值作为该窗口的值。该操作按照给定的步长值来移动池化窗口。

\[Y = \max_{(i, j) \in \text{{window}}(size)} X_{i, j, k}\]

其中,\(\text{{window}}(size)\) 是池化窗口,\(X_{i,j,k}\) 是输入张量在 \((i, j, k)\) 位置的元素,\(Y\) 是进行最大池化操作后的输出张量值。

参数:
  • x(Var): 输入的张量。

  • size(int): 池化窗口尺寸。

  • stride(int): 池化窗口的移动步长,步长值确定了池化窗口的移动速度。

  • padding(int, 可选): 输入的每一条边补充 0 的层数,默认值: 0

代码示例:
>>> from jittor.contrib import argmax_pool
>>> input_array = jt.random([1, 1, 4, 4])
jt.Var([[[[0.49449673 0.00643021 0.07254869 0.3258533 ]
  [0.61617774 0.09950083 0.3104945  0.48131013]
  [0.37913334 0.09407917 0.18861724 0.09006661]
  [0.7495838  0.25495356 0.00436674 0.3918325 ]]]], dtype=float32)
>>> output = argmax_pool(input_array, 2, 2)
jt.Var([[[[0.61617774 0.48131013]
  [0.7495838  0.3918325 ]]]], dtype=float32)
返回值:

经过最大池化操作后的结果(Var)

jittor.contrib.check(bc)[源代码]

检查bc中的每个元素是否等于1或等于bc中维度0的最大值。

参数:
  • bc(Var): 输入的数组,代表要进行检查的数组。

代码示例:
>>> import jittor as jt
>>> bc = jt.Var([[1, 2, 3], [1, 1, 1]])
>>> print(jt.contrib.check(bc))
[1 2 3]
>>> bc = jt.Var([[1, 2, 3], [1, 4, 1]])
>>> print(jt.contrib.check(bc))
Exception: Shape not match.
返回值:

返回输入数组的按照轴 0 进行最大值操作后的结果( int )。

jittor.contrib.setitem(x, slices, value)[源代码]

对数组 x 进行切片赋值。函数的目标是将一个数组的切片赋予一个特定的值。首先通过 slice, var, index 函数处理切片信息,然后创建一个与目标切片相同形状的广播值。然后将广播值累加到目标切片,并将得到的结果赋值回原数组。

参数:
  • x (Var): 原始数组。

  • slices (int, slice object 或者 tuple): 对数组的切片信息。如果是 tuple,其长度需要和数组x的维度一致。

  • value (int, float 或者 Var): 要赋给数组切片的值。

代码示例:
>>> from jittor.contrib import setitem
>>> import jittor as jt
>>> x = jt.array([0, 1, 2, 3, 4])
>>> setitem(x, slice(1, 4), 9)
jt.Var([0 9 9 9 4], dtype=int32)
返回值:

赋值后的数组(Var)。

jittor.contrib.slice_var_index(x, slices)[源代码]

对于给定的变量 x 和切片 slices ,执行切片操作。切片操作根据数组的索引范围、步长等信息来获取数组的一部分。该函数主要用于实现切片操作,可以将 slices 中的切片应用到 x 上,返回一个新的张量。

参数值:
  • 输入的张量(Var)

  • slices( tuple, list, numpy.ndarray,Var ):
    • 切片可以是 int, slice, bool 等类型,具体由输入数据决定。

    • 如果 slices 不是元组,那么将 slices 转换为元组。

    • 如果 len(slices)==1slices[0]dtype = bool,则将 slices[0] 重定位。

代码示例:
>>> import jittor as jt
>>> from jittor.contrib import slice_var_index
>>> x = jt.array([[1, 2, 3], [4, 5, 6]]) 
>>> slices = (0, slice(1, None, 2)) 
>>> out_shape, out_index, _, __, ___ = slice_var_index(x, slices) 
>>> print(out_shape)
[1, 1]
>>> print(out_index)
['0', '1+i1*2']
返回值:
  • out_shape: 输出张量形状的列表(list)。

  • out_index: 输出张量索引的列表(list)。

  • 0: 一个常数 0,表示无额外输出。

  • []: 空列表(list),表示无额外输出。

  • extras: 其中包含需要执行额外操作的切片(list)。