jittor.contrib
这里是Jittor的贡献代码模块模块的API文档,此模块的代码可能还没有完全成熟,我们将在后续迭代开发中继续完善,您可以通过from jittor import contrib
来获取该模块。
- jittor.contrib.argmax_pool(x, size, stride, padding=0)[源代码]
对输入张量进行最大值池化操作。即在池化窗口内找到最大值作为该窗口的值。该操作按照给定的步长值来移动池化窗口。
\[Y = \max_{(i, j) \in \text{{window}}(size)} X_{i, j, k}\]其中,\(\text{{window}}(size)\) 是池化窗口,\(X_{i,j,k}\) 是输入张量在 \((i, j, k)\) 位置的元素,\(Y\) 是进行最大池化操作后的输出张量值。
- 参数:
x(
Var
): 输入的张量。size(
int
): 池化窗口尺寸。stride(
int
): 池化窗口的移动步长,步长值确定了池化窗口的移动速度。padding(
int
, 可选): 输入的每一条边补充 0 的层数,默认值:0
- 代码示例:
>>> from jittor.contrib import argmax_pool >>> input_array = jt.random([1, 1, 4, 4]) jt.Var([[[[0.49449673 0.00643021 0.07254869 0.3258533 ] [0.61617774 0.09950083 0.3104945 0.48131013] [0.37913334 0.09407917 0.18861724 0.09006661] [0.7495838 0.25495356 0.00436674 0.3918325 ]]]], dtype=float32) >>> output = argmax_pool(input_array, 2, 2) jt.Var([[[[0.61617774 0.48131013] [0.7495838 0.3918325 ]]]], dtype=float32)
- 返回值:
经过最大池化操作后的结果(
Var
)
- jittor.contrib.check(bc)[源代码]
检查bc中的每个元素是否等于1或等于bc中维度0的最大值。
- 参数:
bc(
Var
): 输入的数组,代表要进行检查的数组。
- 代码示例:
>>> import jittor as jt >>> bc = jt.Var([[1, 2, 3], [1, 1, 1]]) >>> print(jt.contrib.check(bc)) [1 2 3] >>> bc = jt.Var([[1, 2, 3], [1, 4, 1]]) >>> print(jt.contrib.check(bc)) Exception: Shape not match.
- 返回值:
返回输入数组的按照轴
0
进行最大值操作后的结果(int
)。
- jittor.contrib.setitem(x, slices, value)[源代码]
对数组
x
进行切片赋值。函数的目标是将一个数组的切片赋予一个特定的值。首先通过slice, var, index
函数处理切片信息,然后创建一个与目标切片相同形状的广播值。然后将广播值累加到目标切片,并将得到的结果赋值回原数组。- 参数:
x (
Var
): 原始数组。slices (
int
,slice object
或者tuple
): 对数组的切片信息。如果是tuple
,其长度需要和数组x的维度一致。value (
int
,float
或者Var
): 要赋给数组切片的值。
- 代码示例:
>>> from jittor.contrib import setitem >>> import jittor as jt >>> x = jt.array([0, 1, 2, 3, 4]) >>> setitem(x, slice(1, 4), 9) jt.Var([0 9 9 9 4], dtype=int32)
- 返回值:
赋值后的数组(
Var
)。
- jittor.contrib.slice_var_index(x, slices)[源代码]
对于给定的变量
x
和切片slices
,执行切片操作。切片操作根据数组的索引范围、步长等信息来获取数组的一部分。该函数主要用于实现切片操作,可以将slices
中的切片应用到x
上,返回一个新的张量。- 参数值:
输入的张量(Var)
- slices(
tuple, list, numpy.ndarray,Var
): 切片可以是
int, slice, bool
等类型,具体由输入数据决定。如果
slices
不是元组,那么将slices
转换为元组。如果
len(slices)==1
且slices[0]
的dtype = bool
,则将slices[0]
重定位。
- slices(
- 代码示例:
>>> import jittor as jt >>> from jittor.contrib import slice_var_index >>> x = jt.array([[1, 2, 3], [4, 5, 6]]) >>> slices = (0, slice(1, None, 2)) >>> out_shape, out_index, _, __, ___ = slice_var_index(x, slices) >>> print(out_shape) [1, 1] >>> print(out_index) ['0', '1+i1*2']
- 返回值:
out_shape: 输出张量形状的列表(
list
)。out_index: 输出张量索引的列表(
list
)。0: 一个常数 0,表示无额外输出。
[]: 空列表(
list
),表示无额外输出。extras: 其中包含需要执行额外操作的切片(
list
)。