import functools
import itertools
import string
import typing
from collections import OrderedDict
from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar
if typing.TYPE_CHECKING:
import numpy as np
from jittor.einops import EinopsError
from jittor.einops._backends import get_backend
from jittor.einops.parsing import ParsedExpression, _ellipsis, AnonymousAxis
# Generic placeholder for whatever tensor type the selected backend works with.
Tensor = TypeVar('Tensor')
# Signature of a user-supplied reduction: called with a tensor and the axes to reduce, returns a tensor.
ReductionCallable = Callable[[Tensor, List[int]], Tensor]
# A reduction is specified either by name (str) or as a callable.
Reduction = Union[str, ReductionCallable]
# Names of the built-in reductions recognised by this module.
_reductions = ('min', 'max', 'sum', 'mean', 'prod')
# Sentinel grouping used in a recipe's output axes to mark an ellipsis that is not inside parentheses.
_ellipsis_not_in_parenthesis: List[int] = [-999]
# Placeholder length for an elementary axis whose size is not known when the recipe is prepared.
_unknown_axis_length = -999999
def is_ellipsis_not_in_parenthesis(group: List[int]) -> bool:
    """Return True when `group` is the one-element sentinel grouping [-999] marking a bare ellipsis."""
    return len(group) == 1 and group[0] == -999
def _product(sequence: List[int]) -> int:
""" minimalistic product that works both with numbers and symbols. Supports empty lists """
result = 1
for element in sequence:
result *= element
return result
def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
    """
    Reduce `tensor` over `reduced_axes`.

    Args:
        tensor: tensor to reduce.
        reduction_type: name of a built-in reduction or a callable taking (tensor, axes).
        reduced_axes: axes to reduce; converted to a tuple before use.
        backend: backend object providing `is_float_type` and `reduce`.

    Returns:
        The reduced tensor. For built-in reductions with no axes to reduce, the input
        tensor is returned as is.

    Raises:
        NotImplementedError: for 'mean' when the backend reports a non-floating tensor.
    """
    axes = tuple(reduced_axes)

    # custom callables are always invoked, even when there is nothing to reduce
    if callable(reduction_type):
        return reduction_type(tensor, axes)

    # built-in operations from here on
    if not axes:
        return tensor
    assert reduction_type in _reductions
    if reduction_type == 'mean' and not backend.is_float_type(tensor):
        raise NotImplementedError('reduce_mean is not available for non-floating tensors')
    return backend.reduce(tensor, reduction_type, axes)
def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
    """
    Collapse neighboring axes that can be handled as a single axis.

    Two kinds of merges are performed:
    1. consecutive axes that are both reduced are joined into one reduced axis;
    2. consecutive kept axes that stay adjacent and in the same order after the
       permutation are joined into one axis.

    Args:
        init_shapes: lengths of axes after the first reshape.
            NOTE(review): sliced, concatenated and item-assigned below, so presumably a list - confirm.
        reduced_axes: indices (into init_shapes) of axes that are reduced.
        axes_reordering: permutation applied to the axes that remain after reduction.
        final_shapes: shape of the final result; returned unchanged.

    Returns:
        Tuple (init_shapes, reduced_axes, axes_reordering, final_shapes) describing the
        simplified transformation.
    """
    # 'collapses' neighboring axes if those participate in the result pattern in the same order
    # TODO add support for added_axes
    assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
    # joining consecutive axes that will be reduced
    # possibly we can skip this if all backends can optimize this (not sure)
    reduced_axes = tuple(sorted(reduced_axes))
    # iterate from the end so that indices to the left of the merge point stay valid
    for i in range(len(reduced_axes) - 1)[::-1]:
        if reduced_axes[i] + 1 == reduced_axes[i + 1]:
            removed_axis = reduced_axes[i + 1]
            removed_length = init_shapes[removed_axis]
            # drop the right axis of the pair and fold its length into the left one
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
            init_shapes[removed_axis - 1] *= removed_length
            # later reduced axes shift one position to the left
            reduced_axes = reduced_axes[:i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:])

    # removing axes that are moved together during reshape
    def build_mapping():
        # maps each initial axis to its position in the output (None for reduced axes);
        # reads the current values of init_shapes / reduced_axes / axes_reordering
        init_to_final = {}
        for axis in range(len(init_shapes)):
            if axis in reduced_axes:
                init_to_final[axis] = None
            else:
                # index of this axis among the axes kept so far
                after_reduction = sum(x is not None for x in init_to_final.values())
                init_to_final[axis] = list(axes_reordering).index(after_reduction)
        return init_to_final

    init_axis_to_final_axis = build_mapping()

    for init_axis in range(len(init_shapes) - 1)[::-1]:
        # only pairs of kept (non-reduced) axes can be merged here
        if init_axis_to_final_axis[init_axis] is None:
            continue
        if init_axis_to_final_axis[init_axis + 1] is None:
            continue
        if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
            # the pair is adjacent and in the same order in the output: merge into one axis
            removed_axis = init_axis + 1
            removed_length = init_shapes[removed_axis]
            # position of the removed axis among the kept axes
            removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))

            reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
            init_shapes[removed_axis - 1] *= removed_length
            old_reordering = axes_reordering
            axes_reordering = []
            # rebuild the permutation without the removed axis, renumbering the axes after it
            for axis in old_reordering:
                if axis == removed_axis_after_reduction:
                    pass
                elif axis < removed_axis_after_reduction:
                    axes_reordering.append(axis)
                else:
                    axes_reordering.append(axis - 1)
            # the mapping depends on the variables just changed, so recompute it
            init_axis_to_final_axis = build_mapping()

    return init_shapes, reduced_axes, axes_reordering, final_shapes
# Shape-specific parameters computed from a recipe, in this order:
# (init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes).
CookedRecipe = Tuple[List[int], List[int], List[int], Dict[int, int], List[int]]
class TransformRecipe:
    """
    Description of the computation pathway for one transformation.

    A recipe stores everything derived from the pattern; it can be applied to a
    tensor or variable. Instances are intended to be treated as immutable.

    Args:
        elementary_axes_lengths: expressions (or plain sizes) of elementary axes in the order
            they appear in the left expression. After unknown parts are computed, this is
            the shape after the first reshape. An ellipsis, if present, forms a single
            dimension here (at its position).
        input_composite_axes: one (known, unknown) pair per input dimension. Each input
            dimension helps to reconstruct the length of one elementary axis or to verify
            a dimension; entries are indices into elementary_axes_lengths.
        reduced_elementary_axes: indices of axes to be squashed.
        axes_permutation: order in which axes are reshuffled after reduction.
        added_axes: maps positions to the elementary axes that should appear there.
        output_composite_axes: ids of axes as they appear in the result (indices into
            elementary_axes_lengths); only used to infer result dimensions.
        ellipsis_position_in_lhs: position of the ellipsis in the left side of the
            expression, or None when there is no ellipsis.
    """

    def __init__(self,
                 elementary_axes_lengths: List[int],
                 input_composite_axes: List[Tuple[List[int], List[int]]],
                 reduced_elementary_axes: List[int],
                 axes_permutation: List[int],
                 added_axes: Dict[int, int],
                 output_composite_axes: List[List[int]],
                 ellipsis_position_in_lhs: Optional[int] = None,
                 ):
        # A missing ellipsis is stored as a large number so that shape reconstruction
        # can compare positions without special-casing None.
        if ellipsis_position_in_lhs is None:
            ellipsis_position_in_lhs = 10000
        self.ellipsis_position_in_lhs: int = ellipsis_position_in_lhs

        # input side
        self.elementary_axes_lengths: List[int] = elementary_axes_lengths
        self.input_composite_axes: List[Tuple[List[int], List[int]]] = input_composite_axes

        # reduction / permutation / new axes (reduced axes are redundant information,
        # kept because they are convenient to use)
        self.reduced_elementary_axes: List[int] = reduced_elementary_axes
        self.axes_permutation: List[int] = axes_permutation
        self.added_axes: Dict[int, int] = added_axes

        # output side
        self.output_composite_axes: List[List[int]] = output_composite_axes
def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> CookedRecipe:
    """
    Reconstruct all actual parameters using shape.
    Shape is a tuple that may contain integers, shape symbols (tf, keras, theano) and UnknownSize (keras, mxnet)
    known axes can be integers or symbols, but not Nones.

    Args:
        self: the recipe to specialise for `shape`.
        shape: shape of the input tensor.

    Returns:
        Tuple (init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes).

    Raises:
        EinopsError: if the number of dimensions does not fit the recipe, a dimension
            covered by the ellipsis is None, or an integer dimension does not match /
            is not divisible by the product of the known axis lengths.
    """
    # work on a copy: unknown lengths are filled in below
    axes_lengths: List[int] = list(self.elementary_axes_lengths)
    # 10000 is the value TransformRecipe stores when the pattern has no ellipsis
    if self.ellipsis_position_in_lhs != 10000:
        # with an ellipsis, any number of dimensions >= (number of other composite axes) is fine
        if len(shape) < len(self.input_composite_axes) - 1:
            raise EinopsError('Expected at least {} dimensions, got {}'.format(
                len(self.input_composite_axes) - 1, len(shape)))
    else:
        if len(shape) != len(self.input_composite_axes):
            raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape)))

    ellipsis_shape: List[int] = []
    for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composite_axes):
        # index into `shape` for a composite axis located before / after the ellipsis;
        # the second one is shifted by the number of extra dimensions the ellipsis absorbs
        before_ellipsis = input_axis
        after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes)
        if input_axis == self.ellipsis_position_in_lhs:
            assert len(known_axes) == 0 and len(unknown_axes) == 1
            unknown_axis, = unknown_axes
            # all dimensions covered by the ellipsis; they form one elementary axis
            ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1]
            for d in ellipsis_shape:
                if d is None:
                    raise EinopsError("Couldn't infer shape for one or more axes represented by ellipsis")
            total_dim_size: int = _product(ellipsis_shape)
            axes_lengths[unknown_axis] = total_dim_size
        else:
            if input_axis < self.ellipsis_position_in_lhs:
                length = shape[before_ellipsis]
            else:
                length = shape[after_ellipsis]
            known_product = 1
            for axis in known_axes:
                known_product *= axes_lengths[axis]

            if len(unknown_axes) == 0:
                # nothing to infer: only verify, and only when both values are plain ints
                if isinstance(length, int) and isinstance(known_product, int) and length != known_product:
                    raise EinopsError('Shape mismatch, {} != {}'.format(length, known_product))
            # this is enforced when recipe is created
            # elif len(unknown_axes) > 1:
            #     raise EinopsError(
            #         "Lengths of two or more axes in parenthesis not provided (dim={}), can't infer dimensions".
            #             format(known_product)
            #     )
            else:
                if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0:
                    raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format(
                        length, known_product))

                # the single unknown axis takes whatever is left after the known ones
                unknown_axis: int = unknown_axes[0]
                inferred_length: int = length // known_product
                axes_lengths[unknown_axis] = inferred_length

    # at this point all axes_lengths are computed (either have values or variables, but not Nones)

    # TODO more readable expression
    # the last len(self.added_axes) entries are dropped: they are not part of the first reshape
    init_shapes = axes_lengths[:len(axes_lengths) - len(self.added_axes)]
    final_shapes: List[int] = []
    for output_axis, grouping in enumerate(self.output_composite_axes):
        if is_ellipsis_not_in_parenthesis(grouping):
            # a bare ellipsis in the output expands back into its individual dimensions
            final_shapes.extend(ellipsis_shape)
        else:
            lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
            final_shapes.append(_product(lengths))
    reduced_axes = self.reduced_elementary_axes
    axes_reordering = self.axes_permutation
    # translate "position -> elementary axis" into "position -> length"
    added_axes: Dict[int, int] = {
        pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()}
    # if optimize:
    #     assert len(self.added_axes) == 0
    #     return _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes)
    return init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes
# Memoized variant of _reconstruct_from_shape_uncached (LRU cache with up to 1024 entries).
# functools.lru_cache keys on the call arguments, so the recipe and the shape must be hashable.
_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached)
def _apply_recipe(recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction) -> Tensor:
    """
    Apply a prepared recipe to `tensor` using the backend detected for it.

    Steps: reshape to elementary axes, reduce, transpose, optionally add new axes,
    and reshape to the final shape.

    Args:
        recipe: recipe produced for the pattern.
        tensor: input tensor.
        reduction_type: reduction name or callable forwarded to the reduction step.

    Returns:
        The transformed tensor.
    """
    backend = get_backend(tensor)
    cooked = _reconstruct_from_shape(recipe, backend.shape(tensor))
    init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = cooked

    elementary = backend.reshape(tensor, init_shapes)
    reduced = _reduce_axes(elementary, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend)
    result = backend.transpose(reduced, axes_reordering)
    if added_axes:
        total_axes = len(axes_reordering) + len(added_axes)
        result = backend.add_axes(result, n_axes=total_axes, pos2len=added_axes)
    return backend.reshape(result, final_shapes)
@functools.lru_cache(256)
def _prepare_transformation_recipe(pattern: str,
                                   operation: Reduction,
                                   axes_lengths: Tuple[Tuple, ...]) -> TransformRecipe:
    """ Perform initial parsing of pattern and provided supplementary info
    axes_lengths is a tuple of tuples (axis_name, axis_length)

    Results are memoized with functools.lru_cache (up to 256 entries), so all
    arguments must be hashable.

    Args:
        pattern: pattern string with the left and right expression separated by '->'.
        operation: 'rearrange', 'repeat', the name of a built-in reduction, or a callable.
        axes_lengths: tuple of (axis_name, axis_length) pairs.

    Returns:
        TransformRecipe describing the transformation.

    Raises:
        EinopsError: if the two sides of the pattern are inconsistent for the given
            operation, the operation is unknown, an axis name in axes_lengths is invalid
            or unused, or more than one axis length within a composite axis is unknown.
    """
    left, rght = pattern.split('->')
    left = ParsedExpression(left)
    rght = ParsedExpression(rght)

    # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction
    if not left.has_ellipsis and rght.has_ellipsis:
        raise EinopsError('Ellipsis found in right side, but not left side of a pattern {}'.format(pattern))
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise EinopsError('Ellipsis is parenthesis in the left side is not allowed: {}'.format(pattern))
    if operation == 'rearrange':
        # rearrange: both sides must use exactly the same identifiers
        difference = set.symmetric_difference(left.identifiers, rght.identifiers)
        if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes:
            raise EinopsError('Non-unitary anonymous axes are not supported in rearrange (exception is length 1)')
        if len(difference) > 0:
            raise EinopsError('Identifiers only on one side of expression (should be on both): {}'.format(difference))
    elif operation == 'repeat':
        # repeat: axes may only be added, and every new named axis needs an explicit size
        difference = set.difference(left.identifiers, rght.identifiers)
        if len(difference) > 0:
            raise EinopsError('Unexpected identifiers on the left side of repeat: {}'.format(difference))
        axes_without_size = set.difference({ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)},
                                           {*left.identifiers, *(ax for ax, _ in axes_lengths)})
        if len(axes_without_size) > 0:
            raise EinopsError('Specify sizes for new axes in repeat: {}'.format(axes_without_size))
    elif operation in _reductions or callable(operation):
        # reduction: axes may only disappear
        difference = set.difference(rght.identifiers, left.identifiers)
        if len(difference) > 0:
            raise EinopsError('Unexpected identifiers on the right side of reduce {}: {}'.format(operation, difference))
    else:
        raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions))

    # parsing all dimensions to find out lengths
    # insertion order matters: the position of an axis in this dict is its elementary axis index
    axis_name2known_length = OrderedDict()
    for composite_axis in left.composition:
        for axis_name in composite_axis:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length

    # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point

    # axes that appear only on the right side are appended after all left-side axes
    repeat_axes_names = []
    for axis_name in rght.identifiers:
        if axis_name not in axis_name2known_length:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length
            repeat_axes_names.append(axis_name)

    axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}
    # axes that are absent from the right side get reduced
    reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if
                               axis not in rght.identifiers]
    reduced_axes: List[int] = list(sorted(reduced_axes))

    # apply user-provided sizes
    for elementary_axis, axis_length in axes_lengths:
        if not ParsedExpression.check_axis_name(elementary_axis):
            raise EinopsError('Invalid name for an axis', elementary_axis)
        if elementary_axis not in axis_name2known_length:
            raise EinopsError('Axis {} is not used in transform'.format(elementary_axis))
        axis_name2known_length[elementary_axis] = axis_length

    input_axes_known_unknown = []
    # some of shapes will be inferred later - all information is prepared for faster inference
    for composite_axis in left.composition:
        known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
        unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
        # at most one unknown length per composite axis can be inferred from the input shape
        if len(unknown) > 1:
            raise EinopsError('Could not infer sizes for {}'.format(unknown))
        assert len(unknown) + len(known) == len(composite_axis)
        input_axes_known_unknown.append(
            ([axis_name2position[axis] for axis in known],
             [axis_name2position[axis] for axis in unknown])
        )

    # position of each surviving left-side axis once the reduced axes are gone
    axis_position_after_reduction = {}
    for axis_name in itertools.chain(*left.composition):
        if axis_name in rght.identifiers:
            axis_position_after_reduction[axis_name] = len(axis_position_after_reduction)

    result_axes_grouping: List[List[int]] = []
    for composite_axis in rght.composition:
        if composite_axis == _ellipsis:
            # bare ellipsis on the right side is represented by the sentinel grouping
            result_axes_grouping.append(_ellipsis_not_in_parenthesis)
        else:
            result_axes_grouping.append([axis_name2position[axis] for axis in composite_axis])

    ordered_axis_right = list(itertools.chain(*rght.composition))
    # permutation only covers axes that exist on the left side
    axes_permutation = [
        axis_position_after_reduction[axis] for axis in ordered_axis_right if axis in left.identifiers]
    # axes that exist only on the right side: output position -> elementary axis index
    added_axes = {i: axis_name2position[axis_name] for i, axis_name in enumerate(ordered_axis_right)
                  if axis_name not in left.identifiers}

    ellipsis_left = None if _ellipsis not in left.composition else left.composition.index(_ellipsis)

    return TransformRecipe(
        elementary_axes_lengths=list(axis_name2known_length.values()),
        input_composite_axes=input_axes_known_unknown,
        reduced_elementary_axes=reduced_axes,
        axes_permutation=axes_permutation,
        added_axes=added_axes,
        output_composite_axes=result_axes_grouping,
        ellipsis_position_in_lhs=ellipsis_left,
    )
# [docs] - source-link marker left over from the rendered documentation page; not part of the code
def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
    """
    Combined reordering and reduction driven by a pattern.

    Args:
        tensor: input tensor.
        pattern: string describing the reduction pattern.
        reduction: reduction to apply, one of 'min', 'max', 'sum', 'mean', 'prod',
            or a callable taking (tensor, reduced_axes).
        axes_lengths: lengths of axes that cannot be inferred from the input shape.

    Returns:
        The transformed tensor.

    Raises:
        EinopsError: on any pattern/shape problem; the message includes the pattern,
            the input shape (or a note that the input is a list) and `axes_lengths`.

    Examples:
        >>> x = jt.randn(100, 32, 64)
        # max-reduction over the first axis
        >>> y = reduce(x, 't b c -> b c', 'max')
        # same as above, but with more verbose axis names
        >>> y = reduce(x, 'time batch channel -> batch channel', 'max')
        >>> x = jt.randn(10, 20, 30, 40)
        # 2d max-pooling with a 2x2 kernel, as used for images
        >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
        # to get back to the original height and width, apply the depth-to-space trick
        >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
        >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')
        # adaptive 2d max-pooling to a 3x4 grid
        >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
        (10, 20, 3, 4)
        # global average pooling
        >>> reduce(x, 'b c h w -> b c', 'mean').shape
        (10, 20)
        # subtract the mean over the batch for each channel
        >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')
        # subtract the per-image mean for each channel
        >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
    """
    try:
        # keyword arguments are turned into a sorted tuple so they can be used as a cache key
        recipe = _prepare_transformation_recipe(
            pattern, reduction, axes_lengths=tuple(sorted(axes_lengths.items())))
        return _apply_recipe(recipe, tensor, reduction_type=reduction)
    except EinopsError as e:
        header = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
        if isinstance(tensor, list):
            input_info = '\n Input is list. '
        else:
            input_info = '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor))
        extra_info = 'Additional info: {}.'.format(axes_lengths)
        message = ''.join((header, input_info, extra_info))
        raise EinopsError(message + '\n {}'.format(e))
# [docs] - source-link marker left over from the rendered documentation page; not part of the code
def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor:
    """
    Readable and flexible element rearrangement for multidimensional tensors.

    Covers transpose (axis permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and similar operations.

    Args:
        tensor: tensor of a supported library (e.g. numpy.ndarray, jittor.Var), or a
            list of tensors; a list is stacked along a new zeroth dimension first.
        pattern: string describing the rearrangement.
        axes_lengths: optional additional specification of axis sizes.

    Returns:
        Tensor of the same type as the input. A view of the original tensor is
        returned when possible.

    Raises:
        TypeError: if `tensor` is an empty list.

    Examples:
        >>> from jittor import einops
        >>> import numpy as np
        # suppose we have a set of 32 images in 30x40x3 format
        >>> images = [np.random.randn(30, 40, 3) for _ in range(32)]
        # stack along the batch axis, output is a single array
        >>> einops.rearrange(images, 'b h w c -> b h w c').shape
        (32, 30, 40, 3)
        # concatenate along the height axis, 960 = 32 * 30
        >>> einops.rearrange(images, 'b h w c -> (b h) w c').shape
        (960, 40, 3)
        # concatenate along the width axis, 1280 = 32 * 40
        >>> einops.rearrange(images, 'b h w c -> h (b w) c').shape
        (30, 1280, 3)
        # reorder axes to 'b c h w' format
        >>> einops.rearrange(images, 'b h w c -> b c h w').shape
        (32, 3, 30, 40)
        # flatten each image into a vector, 3600 = 30 * 40 * 3
        >>> einops.rearrange(images, 'b h w c -> b (c h w)').shape
        (32, 3600)
        # split each image into 4 smaller ones (top-left, top-right, bottom-left, bottom-right)
        >>> einops.rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        (128, 15, 20, 3)
    """
    if not isinstance(tensor, list):
        return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)

    if len(tensor) == 0:
        raise TypeError("Rearrange can't be applied to an empty list")
    # the backend is picked from the first element and used to stack the whole list
    stacked = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor)
    return reduce(stacked, pattern, reduction='rearrange', **axes_lengths)
# [docs] - source-link marker left over from the rendered documentation page; not part of the code
def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
    """
    Reorder and repeat elements in arbitrary combinations.

    Covers the functionality of repeat, tile and broadcast functions.

    Args:
        tensor: tensor of a supported library (e.g. numpy.ndarray, jittor.Var).
        pattern: string describing the rearrangement.
        axes_lengths: additional specification of axis sizes.

    Returns:
        Tensor of the same type as the input. A view of the original tensor is
        returned when possible.

    Examples:
        >>> from jittor import einops
        >>> import numpy as np
        # a grayscale image (30x40)
        >>> image = np.random.randn(30, 40)
        # convert it to RGB format by repeating along a new channel axis
        >>> einops.repeat(image, 'h w -> h w c', c=3).shape
        (30, 40, 3)
        # repeat the image 2 times along the height axis
        >>> einops.repeat(image, 'h w -> (repeat h) w', repeat=2).shape
        (60, 40)
        # repeat the image 2 times along height and 3 times along width
        >>> einops.repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape
        (60, 120)
        # upsample the image by a factor of 2
        >>> einops.repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
        (60, 80)
        # downsample, then upsample again
        >>> downsampled = einops.reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
        >>> einops.repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
        (30, 40)
    """
    # 'repeat' is handled by the shared reduce() machinery as a special operation name
    repeated = reduce(tensor, pattern, reduction='repeat', **axes_lengths)
    return repeated
# [docs] - source-link marker left over from the rendered documentation page; not part of the code
def parse_shape(x, pattern: str) -> dict:
    """
    Parse a tensor shape into a dictionary mapping axis names to their lengths.

    Args:
        x: input tensor.
        pattern: space-separated axis names; an underscore skips an axis.

    Returns:
        Dictionary mapping axis names to their lengths.

    Raises:
        RuntimeError: if the pattern contains composite axes or does not fit the
            number of dimensions of `x`.

    Examples:
        # use the underscore to skip a dimension while parsing
        >>> x = jt.zeros([2, 3, 5, 7])
        >>> parse_shape(x, 'batch _ h w')
        {'batch': 2, 'h': 5, 'w': 7}
        # the output of `parse_shape` can be used to specify axes_lengths for other operations:
        >>> y = jt.zeros([700])
        >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
        (2, 10, 5, 7)
    """
    exp = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if exp.has_composed_axes():
        raise RuntimeError("Can't parse shape with composite axes: {pattern} {shape}".format(
            pattern=pattern, shape=shape))

    n_dims = len(shape)
    n_groups = len(exp.composition)
    if n_dims != n_groups:
        if not exp.has_ellipsis:
            raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format(
                pattern=pattern, shape=shape))
        # with an ellipsis only a lower bound on the number of dimensions applies
        if n_dims < n_groups - 1:
            raise RuntimeError("Can't parse shape with this number of dimensions: {pattern} {shape}".format(
                pattern=pattern, shape=shape))

    composition = exp.composition
    if exp.has_ellipsis:
        # replace the ellipsis by as many skipped ('_') axes as it covers
        ellipsis_idx = composition.index(_ellipsis)
        skipped = ['_'] * (n_dims - n_groups + 1)
        composition = composition[:ellipsis_idx] + skipped + composition[ellipsis_idx + 1:]

    # every entry of composition holds exactly one axis name at this point
    return {
        axis_name: axis_length
        for (axis_name,), axis_length in zip(composition, shape)
        if axis_name != '_'
    }
# this one is probably not needed in the public API
def _enumerate_directions(x):
    """
    For an n-dimensional tensor, return one index tensor per axis.

    The tensor for axis `k` holds `arange(0, length_k)` reshaped so that it has
    length 1 along every other axis.

    ```python
    x = np.zeros([2, 3, 4]) # or any other tensor
    i, j, k = _enumerate_directions(x)
    result = i + 2*j + 3*k
    ```

    `result[i, j, k] = i + 2j + 3k`, and `result` has the same shape as `x`.
    Works very similarly to numpy.ogrid (open indexing grid)
    """
    backend = get_backend(x)
    full_shape = backend.shape(x)
    n_dims = len(full_shape)

    directions = []
    for axis_id, axis_length in enumerate(full_shape):
        # all ones except along the axis being enumerated
        direction_shape = [1] * n_dims
        direction_shape[axis_id] = axis_length
        directions.append(backend.reshape(backend.arange(0, axis_length), direction_shape))
    return directions
# [docs] - source-link marker left over from the rendered documentation page; not part of the code
def asnumpy(tensor) -> 'numpy.ndarray':
    """
    Convert a tensor to numpy.ndarray.

    Args:
        tensor: input tensor.

    Returns:
        The numpy.ndarray produced by the tensor's backend.

    Examples:
        >>> from jittor import einops
        >>> einops.asnumpy(jt.ones(3,3))
        array([[1., 1., 1.],
               [1., 1., 1.],
               [1., 1., 1.]], dtype=float32)
    """
    backend = get_backend(tensor)
    return backend.to_numpy(tensor)
def _validate_einsum_axis_name(axis_name):
if len(axis_name) == 0:
raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
if len(axis_name) > 1:
raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
axis_name = axis_name[0]
if isinstance(axis_name, AnonymousAxis):
raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
if len(axis_name) == 0:
raise RuntimeError("Encountered empty axis name in einsum.")
if not isinstance(axis_name, str):
raise RuntimeError("Axis name in einsum must be a string.")
@functools.lru_cache(256)
def _compactify_pattern_for_einsum(pattern: str) -> str:
if "->" not in pattern:
# numpy allows this, so make sure users
# don't accidentally do something like this.
raise ValueError("Einsum pattern must contain '->'.")
lefts, right = pattern.split('->')
lefts = lefts.split(',')
lefts = [
ParsedExpression(left, allow_underscore=True, allow_duplicates=True)
for left in lefts
]
right = ParsedExpression(right, allow_underscore=True)
# Start from 'a' and go up to 'Z'
output_axis_names = string.ascii_letters
i = 0
axis_name_mapping = {}
left_patterns = []
for left in lefts:
left_pattern = ""
for raw_axis_name in left.composition:
if raw_axis_name == _ellipsis:
left_pattern += '...'
continue
_validate_einsum_axis_name(raw_axis_name)
axis_name = raw_axis_name[0]
if axis_name not in axis_name_mapping:
if i >= len(output_axis_names):
raise RuntimeError("Too many axes in einsum.")
axis_name_mapping[axis_name] = output_axis_names[i]
i += 1
left_pattern += axis_name_mapping[axis_name]
left_patterns.append(left_pattern)
compact_pattern = ",".join(left_patterns) + "->"
for raw_axis_name in right.composition:
if raw_axis_name == _ellipsis:
compact_pattern += '...'
continue
_validate_einsum_axis_name(raw_axis_name)
axis_name = raw_axis_name[0]
if axis_name not in axis_name_mapping:
raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.")
compact_pattern += axis_name_mapping[axis_name]
return compact_pattern
@typing.overload
def einsum(tensor: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ...


def einsum(*tensors_and_pattern: List[Union[Tensor, str]]) -> Tensor:
    """
    Call einsum with einops-style named axes, for any number of tensors.

    Unlike the usual einsum syntax, the tensors are passed first and the pattern last.

    Args:
        tensors_and_pattern: tensors of a supported library (e.g. numpy, jittor)
            followed by the einsum pattern string.

    Returns:
        Tensor of the same type as the inputs, as produced by the backend's einsum.

    Raises:
        ValueError: if fewer than two arguments are given or the last argument is not a string.

    Examples:
        >>> import numpy as np
        >>> from jittor.einops import einops
        # suppose we have three 20x20x20 tensors x, y, z
        >>> x, y, z = np.random.randn(3, 20, 20, 20)
        >>> output = einops.einsum(x, y, z, 'a b c, c b d, a g k -> a b k')
        >>> output.shape
        (20, 20, 20)
        # filter a set of images
        >>> batched_images = np.random.randn(128, 16, 16)
        >>> filters = np.random.randn(16, 16, 30)
        >>> result = einops.einsum(batched_images, filters, 'batch h w, h w channel -> batch channel')
        >>> result.shape
        (128, 30)
        # matrix multiplication with an unknown input shape
        >>> batch_shape = (50, 30)
        >>> data = np.random.randn(*batch_shape, 20)
        >>> weights = np.random.randn(10, 20)
        >>> result = einops.einsum(weights, data, 'out_dim in_dim, ... in_dim -> ... out_dim')
        >>> result.shape
        (50, 30, 10)
        # trace of a single matrix
        >>> matrix = np.random.randn(10, 10)
        >>> result = einops.einsum(matrix, 'i i -> ')
    """
    if len(tensors_and_pattern) <= 1:
        raise ValueError(
            "`einops.einsum` takes at minimum two arguments: the tensors (at least one),"
            " followed by the pattern."
        )
    # everything before the last argument is a tensor, the last one is the pattern
    *tensors, pattern = tensors_and_pattern
    if not isinstance(pattern, str):
        raise ValueError(
            "The last argument passed to `einops.einsum` must be a string,"
            " representing the einsum pattern."
        )
    compact_pattern = _compactify_pattern_for_einsum(pattern)
    return get_backend(tensors[0]).einsum(compact_pattern, *tensors)