# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
from jittor import pool
from collections.abc import Sequence
[文档]
def argmax_pool(x, size, stride, padding=0):
"""
对输入张量进行最大值池化操作。即在池化窗口内找到最大值作为该窗口的值。该操作按照给定的步长值来移动池化窗口。
.. math::
Y = \\max_{(i, j) \\in \\text{{window}}(size)} X_{i, j, k}
其中,:math:`\\text{{window}}(size)` 是池化窗口,:math:`X_{i,j,k}` 是输入张量在 :math:`(i, j, k)` 位置的元素,:math:`Y` 是进行最大池化操作后的输出张量值。
参数:
- x(``Var``): 输入的张量。
- size(``int``): 池化窗口尺寸。
- stride(``int``): 池化窗口的移动步长,步长值确定了池化窗口的移动速度。
- padding(``int``, 可选): 输入的每一条边补充 0 的层数,默认值: ``0``
代码示例:
>>> from jittor.contrib import argmax_pool
>>> input_array = jt.random([1, 1, 4, 4])
jt.Var([[[[0.49449673 0.00643021 0.07254869 0.3258533 ]
[0.61617774 0.09950083 0.3104945 0.48131013]
[0.37913334 0.09407917 0.18861724 0.09006661]
[0.7495838 0.25495356 0.00436674 0.3918325 ]]]], dtype=float32)
>>> output = argmax_pool(input_array, 2, 2)
jt.Var([[[[0.61617774 0.48131013]
[0.7495838 0.3918325 ]]]], dtype=float32)
返回值:
经过最大池化操作后的结果(``Var``)
"""
return pool.pool(x, size, 'maximum', padding, stride)
def concat(arr, dim):
"""
沿指定维度连接输入张量序列。所有输入张量的尺寸必须匹配,除了连接维度上的尺寸。在连接维度上,所有其他尺寸必须相同。
参数:
- arr (``list``): 要连接的张量序列
- dim (``int``): 要连接的维度。默认值:``0``
代码示例:
>>> jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
jt.Var([[1,2],[2,2]],dtype=int32)
返回值:
连接后的张量(``Var``)
"""
# TODO: low performance when concat lots of vars
total_dim = 0
if dim < 0: dim += len(arr[0].shape)
for a in arr:
total_dim += a.shape[dim]
cdim = 0
s = None
indexes = [ f"i{i}" for i in range(len(a.shape)) ]
for a in arr:
shape = list(a.shape)
shape[dim] = total_dim
indexes[dim] = f"i{dim}-{cdim}"
b = a.reindex(shape, indexes)
# ugly fix for preventing large fused op
if len(arr)>=100:
b.stop_fuse()
if s is None:
s = b
else:
s += b
cdim += a.shape[dim]
return s
[文档]
def check(bc):
"""
检查bc中的每个元素是否等于1或等于bc中维度0的最大值。
参数:
- bc(``Var``): 输入的数组,代表要进行检查的数组。
代码示例:
>>> import jittor as jt
>>> bc = jt.Var([[1, 2, 3], [1, 1, 1]])
>>> print(jt.contrib.check(bc))
[1 2 3]
>>> bc = jt.Var([[1, 2, 3], [1, 4, 1]])
>>> print(jt.contrib.check(bc))
Exception: Shape not match.
返回值:
返回输入数组的按照轴 ``0`` 进行最大值操作后的结果( ``int`` )。
"""
bc = np.array(bc)
if ((bc != 1) * (bc != bc.max(0))).sum() > 0:
raise Exception(f"Shape not match.")
else:
return bc.max(0)
[文档]
def slice_var_index(x, slices):
"""
对于给定的变量 ``x`` 和切片 ``slices`` ,执行切片操作。切片操作根据数组的索引范围、步长等信息来获取数组的一部分。该函数主要用于实现切片操作,可以将 ``slices`` 中的切片应用到 ``x`` 上,返回一个新的张量。
参数值:
- 输入的张量(Var)
- slices( ``tuple, list, numpy.ndarray,Var`` ):
- 切片可以是 ``int, slice, bool`` 等类型,具体由输入数据决定。
- 如果 ``slices`` 不是元组,那么将 ``slices`` 转换为元组。
- 如果 ``len(slices)==1`` 且 ``slices[0]`` 的 ``dtype = bool``,则将 ``slices[0]`` 重定位。
代码示例:
>>> import jittor as jt
>>> from jittor.contrib import slice_var_index
>>> x = jt.array([[1, 2, 3], [4, 5, 6]])
>>> slices = (0, slice(1, None, 2))
>>> out_shape, out_index, _, __, ___ = slice_var_index(x, slices)
>>> print(out_shape)
[1, 1]
>>> print(out_index)
['0', '1+i1*2']
返回值:
- out_shape: 输出张量形状的列表(``list``)。
- out_index: 输出张量索引的列表(``list``)。
- 0: 一个常数 0,表示无额外输出。
- []: 空列表(``list``),表示无额外输出。
- extras: 其中包含需要执行额外操作的切片(``list``)。
"""
if not isinstance(slices, tuple):
slices = (slices,)
if isinstance(slices[0], jt.Var):
if len(slices) == 1 and slices[0].dtype == "bool":
return slice_var_index(x, tuple(slices[0].where()))
bc = []
ml = -1
for idx, s in enumerate(slices):
if isinstance(s, jt.Var):
shape = s.shape
elif isinstance(s, np.ndarray):
shape = list(s.shape)
elif isinstance(s, list):
shape = list(np.array(s).shape)
else:
continue
if len(shape) >= ml:
ml = len(shape)
bc.append(shape)
for idx, shape in enumerate(bc):
if len(shape) < ml:
shape = (ml - len(shape)) * [1] + shape
bc[idx] = shape
if len(bc) >= 1:
bc_shape = check(bc)
ss = []
for idx, s in enumerate(slices):
if isinstance(s, np.ndarray) or isinstance(s, list):
ss.append(jt.array(s).broadcast(bc_shape.tolist()))
elif isinstance(s, jt.Var):
ss.append(s.broadcast(bc_shape.tolist()))
else:
ss.append(s)
slices = ss
out_shape = []
out_index = []
shape = x.shape
cnt_list = 0
extras_idx = []
extras = []
has_ellipse = 0
ellipse_index = 0
for s,i in zip(slices,range(len(slices))):
if isinstance(s,type(...)):
has_ellipse+=1
ellipse_index = i
if has_ellipse>1:
raise Exception(f"There are more than one ...")
elif has_ellipse==1:
slices = list(slices)
del slices[ellipse_index]
while len(slices)<len(shape):
slices.insert(ellipse_index,slice(None))
for i in range(len(shape)):
if i>=len(slices):
s = slice(None)
else:
s = slices[i]
sp = shape[i]
j = len(out_shape)
if isinstance(s, int):
if s<0: s += sp
out_index.append(str(s))
elif isinstance(s, slice):
if s == slice(None):
out_shape.append(sp)
out_index.append(f"i{j}")
continue
start = 0 if s.start is None else s.start
stop = sp if s.stop is None else s.stop
step = 1 if s.step is None else s.step
if start<0: start += sp
if stop<0: stop += sp
if stop>sp+1: stop = sp
out_shape.append(1+int(max(0, (stop-start-1)//step)))
out_index.append(f"{start}+i{j}*{step}")
elif isinstance(s, jt.Var):
if cnt_list == 0:
for idx in range(len(bc_shape)):
extras_idx.append(f"i{len(out_shape) + idx}")
out_shape += bc_shape.tolist()
out_index.append(f"@e{cnt_list}("+ ",".join(extras_idx) + ")")
cnt_list += 1
extras.append(s)
else:
raise Exception(f"Not support slice {s}")
if len(out_shape)==0:
out_shape = [1]
# Stop fuse both input and output, prevent recompile
x.stop_fuse()
return (out_shape, out_index, 0, [], extras)
def _slice_var_old(x, slices):
"""
使用指定的片段对给定变量进行重索引。函数首先通过调用 :math:`slice` _ :math:`var` _ :math:`index` 函数生成一个新的索引片段对变量进行索引,并通过调用 :math:`reindex` 函数对变量进行重索引。
在开始和结束重索引之前,通过调用 :math:`stop` _ :math:`fuse` 函数停止梯度融合以防止误导梯度计算。
参数:
- x(Var): 输入的 Jittor 变量。
- slices(tuple, list): 用于索引的切片。此参数应为一个元组或列表,包含一个或多个整数或者切片对象。
代码示例:
>>> import jittor as jt
>>> from jittor.contrib import _slice_var_old
>>> x = jt.array([1, 2, 3, 4, 5])
>>> slices = (2, 4)
>>> result = _slice_var_old(x, slices)
jt.Var([3], dtype=int32)
返回值:
重新索引后的变量(Var)。
"""
reindex_args = slice_var_index(x, slices)
x.stop_fuse()
return x.reindex(*reindex_args).stop_fuse()
def _setitem_old(x, slices, value):
"""
对数组 ``x`` 进行切片赋值。函数的目标是将一个数组的切片赋予一个特定的值。首先通过 :math:`slice` :math:`var` :math:`index` 函数处理切片信息,然后创建一个与目标切片相同形状的广播值。然后将广播值累加到目标切片,并将得到的结果赋值回原数组。
参数:
- x (Var): 原始数组。
- slices (int, slice object 或者 tuple): 对数组的切片信息。如果是 :math:`tuple`,其长度需要和数组x的维度一致。
- value (int, float 或者 Var): 要赋给数组切片的值。
代码示例:
>>> from jittor.contrib import _setitem_old
>>> import jittor as jt
>>> x = jt.array([0, 1, 2, 3, 4])
>>> _setitem_old(x, slice(1, 4), 9)
[0, 9, 9, 9, 4]
返回值:
赋值后的数组(Var)。
"""
reindex_args = slice_var_index(x, slices)
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
value = jt.broadcast(value, xslice)
value = value.cast(x.dtype)
one = jt.broadcast(1, xslice)
if not isinstance(reindex_args[0][0], jt.Var):
reindex_args = (x.shape,) + reindex_args[1:]
mask = one.reindex_reduce("add", *reindex_reduce_args)
data = value.reindex_reduce("add", *reindex_reduce_args)
# Stop fuse both input and output, prevent recompile
out = mask.ternary(data, x).stop_fuse()
x.assign(out)
return x
# PATCH
def getitem(x, slices):
"""
对数组 ``x`` 进行切片获取。函数的目标是从一个数组中按照指定的切片信息提取部分数据。首先处理切片信息,对于布尔类型的Var,会先转换为对应的索引。然后根据处理后的切片信息从原数组中获取对应部分。
参数:
- x (``Var``): 原始数组。
- slices (``int``, ``slice object``, ``tuple`` 或者 ``Var``): 对数组的切片信息。如果是 :math:`tuple`,其长度需要和数组x的维度一致。如果是Var且数据类型为"bool",则将其视为布尔索引。
代码示例:
>>> import jittor as jt
>>> x = jt.array([0, 1, 2, 3, 4])
>>> print(getitem(x, slice(1, 4)))
[1, 2, 3]
>>> bool_idx = jt.array([False, True, True, False, True])
>>> print(getitem(x, bool_idx))
[1, 2, 4]
返回值:
提取的数组部分(``Var``)。
"""
if isinstance(slices, jt.Var) and slices.dtype == "bool":
return getitem(x, slices.where())
if isinstance(slices, tuple):
ss = []
for s in slices:
if isinstance(s, jt.Var) and s.dtype == "bool":
ss.extend(s.where())
else:
ss.append(s)
slices = tuple(ss)
return x.getitem(slices)
[文档]
def setitem(x, slices, value):
"""
对数组 ``x`` 进行切片赋值。函数的目标是将一个数组的切片赋予一个特定的值。首先通过 ``slice, var, index`` 函数处理切片信息,然后创建一个与目标切片相同形状的广播值。然后将广播值累加到目标切片,并将得到的结果赋值回原数组。
参数:
- x (``Var``): 原始数组。
- slices (``int``, ``slice object`` 或者 ``tuple``): 对数组的切片信息。如果是 ``tuple``,其长度需要和数组x的维度一致。
- value (``int``, ``float`` 或者 ``Var``): 要赋给数组切片的值。
代码示例:
>>> from jittor.contrib import setitem
>>> import jittor as jt
>>> x = jt.array([0, 1, 2, 3, 4])
>>> setitem(x, slice(1, 4), 9)
jt.Var([0 9 9 9 4], dtype=int32)
返回值:
赋值后的数组(``Var``)。
"""
if isinstance(slices, jt.Var) and slices.dtype == "bool":
if slices.shape == x.shape:
if isinstance(value, (int, float)):
value = jt.array(value).broadcast(x.shape)
return x.assign(slices.ternary(value, x))
elif isinstance(value, jt.Var) and value.shape == [1,]:
value = jt.broadcast(value, x.shape)
return x.assign(slices.ternary(value, x))
slices = slices.where()
elif isinstance(slices, tuple):
ss = []
for s in slices:
if isinstance(s, jt.Var) and s.dtype == "bool":
ss.extend(s.where())
else:
ss.append(s)
slices = tuple(ss)
return x.check_cascade_setitem(x.setitem(slices, value))
jt.Var.__getitem__ = jt.Var.slice_var = getitem
jt.Var.__setitem__ = setitem
def _merge_dtypes(dtypes):
"""
根据输入的数据类型列表,将这些数据类型进行合并返回一个新的数据类型。合并的规则是使用 Jittor 的 :math:`binary` _ :math:`dtype` _ :math:`infer` 函数,其操作方式类似于 :math:`Python` 的 :math:`add` 操作。
其中,:math:`add` 操作的数学公式可以表示为:
.. math::
\\text{{new_dtype}} = \\text{{dtype1}} \\oplus \\text{{dtype2}}
其中,:math:`\\oplus` 表示的是 :math:`add` 操作,:math:`\\text{{dtype1}}` 和 :math:`\\text{{dtype2}}` 分别表示两个进行操作的数据类型。
最后,返回结果为float32,表示合并后的新的数据类型。
参数:
dtypes(List[str]): 需要被合并的数据类型列表。
代码例子:
>>> from jittor.contrib import _merge_dtypes
>>> dtypes = [\"float32\", \"int32\"]
>>> _merge_dtypes(dtypes)
\"float32\"
返回值:
合并后的新的数据类型(str)。
"""
dtype = dtypes[0]
for i in range(1, len(dtypes)):
dtype = jt.binary_dtype_infer("add", dtype, dtypes[i])
return dtype
@jt.flag_scope(amp_reg=4) # _custom_flag
def concat(arr, dim=0):
"""
沿指定维度连接输入张量序列。所有输入张量的尺寸必须匹配,除了连接维度上的尺寸。在连接维度上,所有其他尺寸必须相同。
参数:
- arr (list): 要连接的张量序列
- dim (int): 要连接的维度。默认值:0
代码示例:
>>> jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
jt.Var([[1,2],[2,2]],dtype=int32)
返回值:
连接后的张量(Var)
"""
if not isinstance(arr, Sequence):
raise TypeError("concat arr needs to be a tuple or list")
if len(arr) == 0:
raise ValueError("need at least one array to concat")
total_dim = 0
if dim < 0: dim += len(arr[0].shape)
dtypes = []
for a in arr:
total_dim += a.shape[dim]
dtypes.append(str(a.dtype))
cdim = 0
shape = list(a.shape)
shape[dim] = total_dim
s = jt.empty(shape, dtype = _merge_dtypes(dtypes))
slices = [slice(None)]*len(a.shape)
for a in arr:
if a.shape[dim] == 0:
continue
slices[dim] = slice(cdim, cdim+a.shape[dim])
# print(slices, type(a))
s = s.setitem(tuple(slices), a)
# s = jt.setitem(s, tuple(slices), a)
cdim += a.shape[dim]
return s
cat = concat