# Source code for jittor.einops.einops

import functools
import itertools
import string
import typing
from collections import OrderedDict
from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar

if typing.TYPE_CHECKING:
    import numpy as np

from jittor.einops import EinopsError
from jittor.einops._backends import get_backend
from jittor.einops.parsing import ParsedExpression, _ellipsis, AnonymousAxis

# Generic placeholder for whatever tensor type the selected backend operates on.
Tensor = TypeVar('Tensor')
# Signature of a user-supplied reduction: it is called with the tensor and the axes to reduce
# (see _reduce_axes, which actually passes the axes as a tuple).
ReductionCallable = Callable[[Tensor, List[int]], Tensor]
# A reduction is given either by name (str) or as a custom callable.
Reduction = Union[str, ReductionCallable]

# Names accepted as built-in reductions.
_reductions = ('min', 'max', 'sum', 'mean', 'prod')
# Marker group placed among a recipe's output axes where the right-hand side contains a bare
# ellipsis; recognised by is_ellipsis_not_in_parenthesis.
_ellipsis_not_in_parenthesis: List[int] = [-999]
# Sentinel length for an axis whose size is not known when the recipe is prepared.
_unknown_axis_length = -999999


def is_ellipsis_not_in_parenthesis(group: List[int]) -> bool:
    """ True only for a one-element group whose single entry is the ellipsis marker value -999 """
    return len(group) == 1 and group[0] == -999


def _product(sequence: List[int]) -> int:
    """ minimalistic product that works both with numbers and symbols. Supports empty lists """
    result = 1
    for element in sequence:
        result *= element
    return result


def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
    """
    Reduce ``tensor`` over ``reduced_axes``.

    A callable ``reduction_type`` is always invoked with the tensor and a tuple of axes,
    even when that tuple is empty. A built-in reduction (one of ``_reductions``) is
    dispatched to ``backend.reduce``; with no axes to reduce the tensor is returned as is.

    Raises NotImplementedError when 'mean' is requested for a tensor the backend does not
    report as floating point.
    """
    axes = tuple(reduced_axes)

    # custom callable: hand everything over to the user's function
    if callable(reduction_type):
        return reduction_type(tensor, axes)

    # built-in operations: nothing to do when no axis is reduced
    if not axes:
        return tensor

    assert reduction_type in _reductions
    if reduction_type == 'mean' and not backend.is_float_type(tensor):
        raise NotImplementedError('reduce_mean is not available for non-floating tensors')
    return backend.reduce(tensor, reduction_type, axes)


def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
    """
    Simplify a (reshape, reduce, transpose, reshape) plan by merging adjacent axes.

    Two kinds of neighbours are merged:
      * consecutive axes that are both reduced;
      * consecutive kept axes that are also consecutive (same order) after the transposition.

    :param init_shapes: lengths of axes after the first reshape. It has to support slicing,
        concatenation and item assignment (a list); the caller's object is not modified in
        place because a new list is built by slicing before any element is updated.
    :param reduced_axes: indices (into init_shapes) of axes that are reduced.
    :param axes_reordering: permutation applied to the axes remaining after reduction.
    :param final_shapes: shape of the result; returned untouched.
    :return: tuple (init_shapes, reduced_axes, axes_reordering, final_shapes) describing an
        equivalent plan with possibly fewer axes. reduced_axes comes back as a sorted tuple.

    Added axes are not supported (see TODO below). In the visible part of this file the only
    reference to this function is in commented-out code.
    """
    # 'collapses' neighboring axes if those participate in the result pattern in the same order
    # TODO add support for added_axes
    assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
    # joining consecutive axes that will be reduced
    # possibly we can skip this if all backends can optimize this (not sure)
    reduced_axes = tuple(sorted(reduced_axes))
    # walk from the end so that removing an axis does not invalidate indices still to be visited
    for i in range(len(reduced_axes) - 1)[::-1]:
        if reduced_axes[i] + 1 == reduced_axes[i + 1]:
            removed_axis = reduced_axes[i + 1]
            removed_length = init_shapes[removed_axis]
            # drop the second axis of the pair and fold its length into the first one
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
            init_shapes[removed_axis - 1] *= removed_length
            # every reduced axis located after the removed one shifts left by one
            reduced_axes = reduced_axes[:i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:])

    # removing axes that are moved together during reshape
    def build_mapping():
        # maps each initial axis to its position after transposition, or None if it is reduced
        init_to_final = {}
        for axis in range(len(init_shapes)):
            if axis in reduced_axes:
                init_to_final[axis] = None
            else:
                # index of this axis among the axes that survive the reduction
                after_reduction = sum(x is not None for x in init_to_final.values())
                init_to_final[axis] = list(axes_reordering).index(after_reduction)
        return init_to_final

    init_axis_to_final_axis = build_mapping()

    for init_axis in range(len(init_shapes) - 1)[::-1]:
        # only pairs of kept axes can be merged here
        if init_axis_to_final_axis[init_axis] is None:
            continue
        if init_axis_to_final_axis[init_axis + 1] is None:
            continue
        # neighbours before the transposition that stay neighbours (same order) after it
        if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
            removed_axis = init_axis + 1
            removed_length = init_shapes[removed_axis]
            # position of the removed axis counted among non-reduced axes only
            removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))

            reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
            init_shapes[removed_axis - 1] *= removed_length
            # rebuild the permutation without the removed axis, renumbering the later ones
            old_reordering = axes_reordering
            axes_reordering = []
            for axis in old_reordering:
                if axis == removed_axis_after_reduction:
                    pass
                elif axis < removed_axis_after_reduction:
                    axes_reordering.append(axis)
                else:
                    axes_reordering.append(axis - 1)
            # the mapping depends on everything that was just changed
            init_axis_to_final_axis = build_mapping()

    return init_shapes, reduced_axes, axes_reordering, final_shapes


# Result of _reconstruct_from_shape_uncached, in order:
# (init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes)
CookedRecipe = Tuple[List[int], List[int], List[int], Dict[int, int], List[int]]


class TransformRecipe:
    """
    Shape-independent description of the computation behind one pattern.

    A recipe is prepared once per pattern and can then be applied to a tensor or variable.
    The structure is meant to be treated as immutable; in future it could become a frozen
    dataclass (python 3.7+).
    """

    def __init__(self,
                 elementary_axes_lengths: List[int],
                 input_composite_axes: List[Tuple[List[int], List[int]]],
                 reduced_elementary_axes: List[int],
                 axes_permutation: List[int],
                 added_axes: Dict[int, int],
                 output_composite_axes: List[List[int]],
                 ellipsis_position_in_lhs: Optional[int] = None,
                 ):
        """
        :param elementary_axes_lengths: expressions (or just sizes) for elementary axes in the
            order they appear in the left expression. Once unknown parts are computed this is
            the shape used for the first reshape. An ellipsis, if present, occupies a single
            entry at its position.
        :param input_composite_axes: one (known, unknown) pair per input dimension; both lists
            hold indices into elementary_axes_lengths. Each input dimension either helps to
            reconstruct the length of one elementary axis or verifies a dimension.
        :param reduced_elementary_axes: indices of the axes that get reduced.
        :param axes_permutation: order in which remaining axes are reshuffled after reduction.
        :param added_axes: position in the result -> index of the elementary axis put there.
        :param output_composite_axes: ids of axes as they appear in the result (again indices
            into elementary_axes_lengths); only used to infer the dimensions of the result.
        :param ellipsis_position_in_lhs: position of the ellipsis in the left expression,
            or None when the pattern has no ellipsis.
        """
        # None is replaced by a large number to avoid handling Nones in reconstruct_from_shape
        if ellipsis_position_in_lhs is None:
            ellipsis_position_in_lhs = 10000

        self.elementary_axes_lengths: List[int] = elementary_axes_lengths
        self.input_composite_axes: List[Tuple[List[int], List[int]]] = input_composite_axes
        self.output_composite_axes: List[List[int]] = output_composite_axes
        self.axes_permutation: List[int] = axes_permutation
        self.added_axes: Dict[int, int] = added_axes
        # redundant information, but more convenient to use
        self.reduced_elementary_axes: List[int] = reduced_elementary_axes
        self.ellipsis_position_in_lhs: int = ellipsis_position_in_lhs


def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> CookedRecipe:
    """
    Reconstruct all actual parameters using shape.
    Shape is a tuple that may contain integers, shape symbols (tf, keras, theano) and UnknownSize (keras, mxnet)
    known axes can be integers or symbols, but not Nones.

    :param self: the recipe prepared for a pattern (this is a plain function, not a method).
    :param shape: shape of the input tensor. When this function is called through the
        lru_cache wrapper, the value has to be hashable (e.g. a tuple).
    :return: tuple (init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes).
    :raises EinopsError: if the number of dimensions does not fit the pattern, if a dimension
        covered by the ellipsis is None, or if an integer dimension disagrees with / is not
        divisible by the product of the known axes in its group.
    """
    # work on a copy: unknown entries are filled in below and the recipe must stay untouched
    axes_lengths: List[int] = list(self.elementary_axes_lengths)
    # 10000 is the value TransformRecipe stores when the pattern has no ellipsis
    if self.ellipsis_position_in_lhs != 10000:
        # the ellipsis may stand for zero or more dimensions
        if len(shape) < len(self.input_composite_axes) - 1:
            raise EinopsError('Expected at least {} dimensions, got {}'.format(
                len(self.input_composite_axes) - 1, len(shape)))
    else:
        if len(shape) != len(self.input_composite_axes):
            raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape)))

    ellipsis_shape: List[int] = []
    for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composite_axes):
        # index into shape for a group located before the ellipsis ...
        before_ellipsis = input_axis
        # ... and for a group located after it (shifted by the extra dimensions the ellipsis covers)
        after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes)
        if input_axis == self.ellipsis_position_in_lhs:
            assert len(known_axes) == 0 and len(unknown_axes) == 1
            unknown_axis, = unknown_axes
            # all dimensions covered by the ellipsis; they become a single elementary axis
            ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1]
            for d in ellipsis_shape:
                if d is None:
                    raise EinopsError("Couldn't infer shape for one or more axes represented by ellipsis")
            total_dim_size: int = _product(ellipsis_shape)
            axes_lengths[unknown_axis] = total_dim_size
        else:
            if input_axis < self.ellipsis_position_in_lhs:
                length = shape[before_ellipsis]
            else:
                length = shape[after_ellipsis]
            known_product = 1
            for axis in known_axes:
                known_product *= axes_lengths[axis]

            if len(unknown_axes) == 0:
                # the checks are applied only when both values are plain ints; anything else is let through
                if isinstance(length, int) and isinstance(known_product, int) and length != known_product:
                    raise EinopsError('Shape mismatch, {} != {}'.format(length, known_product))
            # this is enforced when recipe is created
            # elif len(unknown_axes) > 1:
            #     raise EinopsError(
            #         "Lengths of two or more axes in parenthesis not provided (dim={}), can't infer dimensions".
            #             format(known_product)
            #     )
            else:
                if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0:
                    raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format(
                        length, known_product))

                # exactly one unknown axis in this group: its length is whatever remains
                unknown_axis: int = unknown_axes[0]
                inferred_length: int = length // known_product
                axes_lengths[unknown_axis] = inferred_length

    # at this point all axes_lengths are computed (either have values or variables, but not Nones)

    # TODO more readable expression
    # added axes are stored at the end of axes_lengths and are not part of the input tensor
    init_shapes = axes_lengths[:len(axes_lengths) - len(self.added_axes)]
    final_shapes: List[int] = []
    for output_axis, grouping in enumerate(self.output_composite_axes):
        if is_ellipsis_not_in_parenthesis(grouping):
            # a bare ellipsis on the right expands back into its original dimensions
            final_shapes.extend(ellipsis_shape)
        else:
            lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
            final_shapes.append(_product(lengths))
    reduced_axes = self.reduced_elementary_axes
    axes_reordering = self.axes_permutation
    # translate "position -> elementary axis index" into "position -> length"
    added_axes: Dict[int, int] = {
        pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()}
    # if optimize:
    #     assert len(self.added_axes) == 0
    #     return _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes)
    return init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes


# Memoized variant (up to 1024 entries) keyed on the (recipe, shape) arguments;
# both therefore have to be hashable.
_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached)


def _apply_recipe(recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction) -> Tensor:
    """
    Run ``recipe`` on ``tensor``: reshape -> reduce -> transpose -> (add axes) -> reshape.

    The backend is looked up from the tensor and the concrete shapes come from the cached
    ``_reconstruct_from_shape``.
    NOTE(review): the original comment here read "this method works for all backends but
    not compilable with" and was cut off mid-sentence.
    """
    backend = get_backend(tensor)
    cooked = _reconstruct_from_shape(recipe, backend.shape(tensor))
    init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = cooked

    result = backend.reshape(tensor, init_shapes)
    result = _reduce_axes(result, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend)
    result = backend.transpose(result, axes_reordering)
    if added_axes:
        total_axes = len(axes_reordering) + len(added_axes)
        result = backend.add_axes(result, n_axes=total_axes, pos2len=added_axes)
    return backend.reshape(result, final_shapes)


@functools.lru_cache(256)
def _prepare_transformation_recipe(pattern: str,
                                   operation: Reduction,
                                   axes_lengths: Tuple[Tuple, ...]) -> TransformRecipe:
    """ Perform initial parsing of pattern and provided supplementary info
    axes_lengths is a tuple of tuples (axis_name, axis_length)

    The result is cached (up to 256 entries), so all arguments have to be hashable.

    :param pattern: transformation pattern; it is split on '->' into a left and a right part.
    :param operation: 'rearrange', 'repeat', one of _reductions, or a callable.
    :param axes_lengths: tuple of (axis_name, axis_length) pairs.
    :return: a TransformRecipe that does not depend on a concrete input shape.
    :raises EinopsError: when the two sides of the pattern are inconsistent for the requested
        operation, when an axis name in axes_lengths is invalid or unused, or when more than
        one axis length in a parenthesized group of the left side is unknown.
    """
    left, rght = pattern.split('->')
    left = ParsedExpression(left)
    rght = ParsedExpression(rght)

    # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction
    if not left.has_ellipsis and rght.has_ellipsis:
        raise EinopsError('Ellipsis found in right side, but not left side of a pattern {}'.format(pattern))
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise EinopsError('Ellipsis is parenthesis in the left side is not allowed: {}'.format(pattern))
    if operation == 'rearrange':
        # rearrange: both sides must mention exactly the same identifiers
        difference = set.symmetric_difference(left.identifiers, rght.identifiers)
        if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes:
            raise EinopsError('Non-unitary anonymous axes are not supported in rearrange (exception is length 1)')
        if len(difference) > 0:
            raise EinopsError('Identifiers only on one side of expression (should be on both): {}'.format(difference))
    elif operation == 'repeat':
        # repeat: axes may appear on the right only, and each new named axis needs a length
        difference = set.difference(left.identifiers, rght.identifiers)
        if len(difference) > 0:
            raise EinopsError('Unexpected identifiers on the left side of repeat: {}'.format(difference))
        axes_without_size = set.difference({ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)},
                                           {*left.identifiers, *(ax for ax, _ in axes_lengths)})
        if len(axes_without_size) > 0:
            raise EinopsError('Specify sizes for new axes in repeat: {}'.format(axes_without_size))
    elif operation in _reductions or callable(operation):
        # reduction: axes may disappear, but none may appear
        difference = set.difference(rght.identifiers, left.identifiers)
        if len(difference) > 0:
            raise EinopsError('Unexpected identifiers on the right side of reduce {}: {}'.format(operation, difference))
    else:
        raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions))

    # parsing all dimensions to find out lengths
    # insertion order matters: the position of a name in this dict is its elementary axis index
    axis_name2known_length = OrderedDict()
    for composite_axis in left.composition:
        for axis_name in composite_axis:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length

    # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point

    # axes present only on the right side are appended after all left-side axes
    repeat_axes_names = []
    for axis_name in rght.identifiers:
        if axis_name not in axis_name2known_length:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length
            repeat_axes_names.append(axis_name)

    axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}
    # every axis that does not appear on the right side is reduced
    reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if
                               axis not in rght.identifiers]
    reduced_axes: List[int] = list(sorted(reduced_axes))

    # lengths supplied by the caller override the "unknown" sentinel
    for elementary_axis, axis_length in axes_lengths:
        if not ParsedExpression.check_axis_name(elementary_axis):
            raise EinopsError('Invalid name for an axis', elementary_axis)
        if elementary_axis not in axis_name2known_length:
            raise EinopsError('Axis {} is not used in transform'.format(elementary_axis))
        axis_name2known_length[elementary_axis] = axis_length

    input_axes_known_unknown = []
    # some of shapes will be inferred later - all information is prepared for faster inference
    for composite_axis in left.composition:
        known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
        unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
        # a single input dimension can determine at most one unknown length
        if len(unknown) > 1:
            raise EinopsError('Could not infer sizes for {}'.format(unknown))
        assert len(unknown) + len(known) == len(composite_axis)
        input_axes_known_unknown.append(
            ([axis_name2position[axis] for axis in known],
             [axis_name2position[axis] for axis in unknown])
        )

    # index of each surviving left-side axis once the reduced axes are gone
    axis_position_after_reduction = {}
    for axis_name in itertools.chain(*left.composition):
        if axis_name in rght.identifiers:
            axis_position_after_reduction[axis_name] = len(axis_position_after_reduction)

    result_axes_grouping: List[List[int]] = []
    for composite_axis in rght.composition:
        if composite_axis == _ellipsis:
            # a bare ellipsis on the right is recorded with the marker group
            result_axes_grouping.append(_ellipsis_not_in_parenthesis)
        else:
            result_axes_grouping.append([axis_name2position[axis] for axis in composite_axis])

    ordered_axis_right = list(itertools.chain(*rght.composition))
    # permutation bringing the surviving axes into the order used on the right side
    axes_permutation = [
        axis_position_after_reduction[axis] for axis in ordered_axis_right if axis in left.identifiers]
    # position on the right side -> elementary axis index, for axes that exist only on the right
    added_axes = {i: axis_name2position[axis_name] for i, axis_name in enumerate(ordered_axis_right)
                  if axis_name not in left.identifiers}

    ellipsis_left = None if _ellipsis not in left.composition else left.composition.index(_ellipsis)

    return TransformRecipe(
        elementary_axes_lengths=list(axis_name2known_length.values()),
        input_composite_axes=input_axes_known_unknown,
        reduced_elementary_axes=reduced_axes,
        axes_permutation=axes_permutation,
        added_axes=added_axes,
        output_composite_axes=result_axes_grouping,
        ellipsis_position_in_lhs=ellipsis_left,
    )


def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
    """
    Combined reordering and reduction of a tensor, driven by a readable pattern.

    Args:
        tensor: input tensor of a type for which ``get_backend`` finds a backend.
        pattern: reduction pattern, e.g. ``'b c h w -> b c'``.
        reduction: name of a built-in reduction ('min', 'max', 'sum', 'mean', 'prod') or a
            callable taking ``(tensor, axes)``. ``rearrange`` and ``repeat`` call this
            function with 'rearrange' and 'repeat' respectively.
        **axes_lengths: lengths of named axes that cannot be inferred from the input shape.

    Returns:
        The transformed tensor.

    Raises:
        EinopsError: if the pattern cannot be processed. The message contains the pattern,
            the input shape (or a note that the input is a list), the supplied axis lengths
            and the text of the underlying error, which is also attached as ``__cause__``.

    Examples:
        >>> x = jt.randn(100, 32, 64)
        >>> # max-reduction over the first axis
        >>> y = reduce(x, 't b c -> b c', 'max')
        >>> # same as above, but with more explicit axis names
        >>> y = reduce(x, 'time batch channel -> batch channel', 'max')

        >>> x = jt.randn(10, 20, 30, 40)
        >>> # 2d max-pooling with a 2x2 kernel, as used for images
        >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
        >>> # to get back the original height and width, apply the depth-to-space trick
        >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
        >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')

        >>> # adaptive 2d max-pooling to a 3x4 grid
        >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
        (10, 20, 3, 4)

        >>> # global average pooling
        >>> reduce(x, 'b c h w -> b c', 'mean').shape
        (10, 20)

        >>> # subtract the per-channel mean computed over the batch
        >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')
        >>> # subtract the per-image, per-channel mean
        >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
    """
    try:
        # keyword arguments are turned into a sorted tuple so that the recipe can be cached
        hashable_axes_lengths = tuple(sorted(axes_lengths.items()))
        recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths)
        return _apply_recipe(recipe, tensor, reduction_type=reduction)
    except EinopsError as e:
        # re-raise with the full context of the call
        message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
        if not isinstance(tensor, list):
            message += '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor))
        else:
            message += '\n Input is list. '
        message += 'Additional info: {}.'.format(axes_lengths)
        raise EinopsError(message + '\n {}'.format(e)) from e
def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor:
    """
    Readable, pattern-driven rearrangement of the elements of a multidimensional tensor.

    Covers transposition (axis permutation), reshape (view), squeeze, unsqueeze, stack,
    concatenate and similar operations.

    Args:
        tensor: a tensor of any supported library (e.g. numpy.ndarray, jittor.Var), or a
            non-empty list of tensors. A list is first stacked along a new leading axis using
            the backend of its first element.
        pattern: string describing the rearrangement.
        **axes_lengths: optional additional specification of axis lengths.

    Returns:
        A tensor of the same type as the input.
        NOTE(review): the original documentation states a view of the input is returned when
        possible; that depends on the backend and cannot be confirmed from this function.

    Raises:
        TypeError: if ``tensor`` is an empty list.
        EinopsError: if the pattern cannot be applied (raised by ``reduce``).

    Examples:
        >>> from jittor import einops
        >>> import numpy as np
        >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
        >>> images = [np.random.randn(30, 40, 3) for _ in range(32)]

        >>> # stack along the first (batch) axis, the output is a single array
        >>> einops.rearrange(images, 'b h w c -> b h w c').shape
        (32, 30, 40, 3)

        >>> # concatenate images along height, 960 = 32 * 30
        >>> einops.rearrange(images, 'b h w c -> (b h) w c').shape
        (960, 40, 3)

        >>> # concatenate images along width, 1280 = 32 * 40
        >>> einops.rearrange(images, 'b h w c -> h (b w) c').shape
        (30, 1280, 3)

        >>> # reorder axes to "b c h w" format
        >>> einops.rearrange(images, 'b h w c -> b c h w').shape
        (32, 3, 30, 40)

        >>> # flatten each image into a vector, 3600 = 30 * 40 * 3
        >>> einops.rearrange(images, 'b h w c -> b (c h w)').shape
        (32, 3600)

        >>> # split each image into 4 smaller quadrants, 128 = 32 * 2 * 2
        >>> einops.rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        (128, 15, 20, 3)
    """
    if isinstance(tensor, list):
        if not tensor:
            raise TypeError("Rearrange can't be applied to an empty list")
        tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor)
    return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)
def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
    """
    Reorder and repeat elements in arbitrary combinations.

    Covers the functionality of repeat, tile and broadcast functions.

    Args:
        tensor: a tensor of any supported library (e.g. numpy.ndarray, jittor.Var). Unlike
            ``rearrange``, this function does not stack a list of tensors itself; the
            argument is forwarded to ``reduce`` unchanged.
        pattern: string describing the rearrangement.
        **axes_lengths: additional specification of axis lengths; every new named axis on
            the right side of the pattern needs a length.

    Returns:
        A tensor of the same type as the input.
        NOTE(review): the original documentation states a view of the input is returned when
        possible; that depends on the backend and cannot be confirmed from this function.

    Raises:
        EinopsError: if the pattern cannot be applied (raised by ``reduce``).

    Examples:
        >>> from jittor import einops
        >>> import numpy as np
        >>> # a grayscale image of shape 30x40
        >>> image = np.random.randn(30, 40)

        >>> # convert it to RGB format by repeating it along a new channel axis
        >>> einops.repeat(image, 'h w -> h w c', c=3).shape
        (30, 40, 3)

        >>> # repeat the image 2 times along the height axis
        >>> einops.repeat(image, 'h w -> (repeat h) w', repeat=2).shape
        (60, 40)

        >>> # repeat the image 2 times along height and 3 times along width
        >>> einops.repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape
        (60, 120)

        >>> # upsample the image by a factor of 2 by repeating each pixel
        >>> einops.repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
        (60, 80)

        >>> # downsample the image, then upsample it again
        >>> downsampled = einops.reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
        >>> einops.repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
        (30, 40)
    """
    return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
def parse_shape(x, pattern: str) -> dict:
    """
    Parse a tensor shape into a dictionary mapping axis names to their lengths.

    Args:
        x: input tensor.
        pattern: space-separated names for the axes; an underscore skips a dimension and an
            ellipsis skips as many dimensions as needed.

    Returns:
        dict mapping every named axis to its length, in pattern order.

    Raises:
        RuntimeError: if the pattern contains composite (parenthesized) axes, or if the number
            of dimensions of ``x`` does not fit the pattern.

    Examples:
        >>> x = jt.zeros([2, 3, 5, 7])
        >>> # use underscore to skip a dimension while parsing
        >>> parse_shape(x, 'batch _ h w')
        {'batch': 2, 'h': 5, 'w': 7}

        >>> # the output of parse_shape can be used to specify axes_lengths for other operations
        >>> y = jt.zeros([700])
        >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
        (2, 10, 5, 7)
    """
    exp = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if exp.has_composed_axes():
        raise RuntimeError("Can't parse shape with composite axes: {pattern} {shape}".format(
            pattern=pattern, shape=shape))
    if len(shape) != len(exp.composition):
        if exp.has_ellipsis:
            # the ellipsis may cover zero dimensions, hence the "- 1"
            if len(shape) < len(exp.composition) - 1:
                raise RuntimeError("Can't parse shape with this number of dimensions: {pattern} {shape}".format(
                    pattern=pattern, shape=shape))
        else:
            raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format(
                pattern=pattern, shape=shape))
    if exp.has_ellipsis:
        # replace the ellipsis by one skipped ('_') entry per dimension it covers
        ellipsis_idx = exp.composition.index(_ellipsis)
        skipped = ['_'] * (len(shape) - len(exp.composition) + 1)
        composition = exp.composition[:ellipsis_idx] + skipped + exp.composition[ellipsis_idx + 1:]
    else:
        composition = exp.composition
    result = {}
    # every entry holds exactly one name: composite axes were rejected above
    for (axis_name,), axis_length in zip(composition, shape):
        if axis_name != '_':
            result[axis_name] = axis_length
    return result
# this one is probably not needed in the public API
def _enumerate_directions(x):
    """
    For an n-dimensional tensor, returns tensors to enumerate each axis.
    ```python
    x = np.zeros([2, 3, 4]) # or any other tensor
    i, j, k = _enumerate_directions(x)
    result = i + 2*j + 3*k
    ```

    `result[i, j, k] = i + 2j + 3k`, and `result` has the same shape as `x`.
    Works very similarly to numpy.ogrid (open indexing grid).

    Returns a list with one tensor per axis of `x`; the tensor for an axis holds
    `arange(0, axis_length)` reshaped so that all other dimensions have length 1.
    """
    backend = get_backend(x)
    shape = backend.shape(x)
    n_axes = len(shape)
    result = []
    for axis_id, axis_length in enumerate(shape):
        # separate name: do not rebind `shape` while iterating over it
        direction_shape = [1] * n_axes
        direction_shape[axis_id] = axis_length
        result.append(backend.reshape(backend.arange(0, axis_length), direction_shape))
    return result
def asnumpy(tensor) -> 'np.ndarray':
    """
    Convert a tensor to numpy.ndarray.

    Args:
        tensor: input tensor of a type for which ``get_backend`` finds a backend.

    Returns:
        The result of the backend's ``to_numpy`` conversion of ``tensor``.

    Examples:
        >>> from jittor import einops
        >>> einops.asnumpy(jt.ones(3, 3))
        array([[1., 1., 1.],
               [1., 1., 1.],
               [1., 1., 1.]], dtype=float32)
    """
    return get_backend(tensor).to_numpy(tensor)
def _validate_einsum_axis_name(axis_name):
    """
    Check that one entry of a parsed einsum expression is a single, named, non-empty axis.

    ``axis_name`` is one element of ``ParsedExpression.composition`` (a group of axes).

    Raises NotImplementedError for an empty group ``()``, a group with several axes, or an
    anonymous axis; raises RuntimeError for a name that is not a string or is empty.
    """
    if len(axis_name) == 0:
        raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
    if len(axis_name) > 1:
        raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")

    axis_name = axis_name[0]

    if isinstance(axis_name, AnonymousAxis):
        raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
    # type check comes first so that len() below is only ever applied to a string
    if not isinstance(axis_name, str):
        raise RuntimeError("Axis name in einsum must be a string.")
    if len(axis_name) == 0:
        raise RuntimeError("Encountered empty axis name in einsum.")


@functools.lru_cache(256)
def _compactify_pattern_for_einsum(pattern: str) -> str:
    """
    Translate an einops-style einsum pattern with multi-character axis names into the
    compact single-letter form, e.g. 'batch h w, h w channel -> batch channel' becomes
    'abc,bcd->ad'. Letters are handed out in order of first appearance on the left side.

    Raises ValueError if the pattern does not contain exactly one '->', RuntimeError if there
    are more distinct axes than available letters, and EinopsError if the right side uses an
    axis that is absent from the left side.
    """
    arrow_count = pattern.count('->')
    if arrow_count == 0:
        # numpy allows this, so make sure users
        # don't accidentally do something like this.
        raise ValueError("Einsum pattern must contain '->'.")
    if arrow_count > 1:
        raise ValueError("Einsum pattern must contain exactly one '->'.")
    lefts, right = pattern.split('->')

    lefts = [
        ParsedExpression(left, allow_underscore=True, allow_duplicates=True)
        for left in lefts.split(',')
    ]
    right = ParsedExpression(right, allow_underscore=True)

    # Start from 'a' and go up to 'Z'
    output_axis_names = string.ascii_letters
    axis_name_mapping = {}

    left_patterns = []
    for left in lefts:
        letters = []
        for raw_axis_name in left.composition:
            if raw_axis_name == _ellipsis:
                letters.append('...')
                continue

            _validate_einsum_axis_name(raw_axis_name)
            axis_name = raw_axis_name[0]
            if axis_name not in axis_name_mapping:
                # the next free letter is indexed by the number of letters already assigned
                next_letter = len(axis_name_mapping)
                if next_letter >= len(output_axis_names):
                    raise RuntimeError("Too many axes in einsum.")
                axis_name_mapping[axis_name] = output_axis_names[next_letter]

            letters.append(axis_name_mapping[axis_name])
        left_patterns.append("".join(letters))

    right_letters = []
    for raw_axis_name in right.composition:
        if raw_axis_name == _ellipsis:
            right_letters.append('...')
            continue

        _validate_einsum_axis_name(raw_axis_name)
        axis_name = raw_axis_name[0]

        if axis_name not in axis_name_mapping:
            raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.")

        right_letters.append(axis_name_mapping[axis_name])

    return ",".join(left_patterns) + "->" + "".join(right_letters)


@typing.overload
def einsum(tensor: Tensor, pattern: str) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ...


def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor:
    """
    Call einsum with einops-style named axes; works with any number of tensors.

    Unlike the typical einsum syntax, the tensors are passed first and the pattern last.

    Args:
        *tensors_and_pattern: one or more tensors of a supported library (e.g. numpy, jittor),
            followed by the einsum pattern string. The backend is chosen from the first tensor.

    Returns:
        The result computed by the backend's ``einsum``.

    Raises:
        ValueError: if fewer than two arguments are given, if the last argument is not a
            string, or if the pattern does not contain exactly one '->'.

    Examples:
        >>> import numpy as np
        >>> from jittor.einops import einops
        >>> # three 20x20x20 tensors x, y, z
        >>> x, y, z = np.random.randn(3, 20, 20, 20)
        >>> output = einops.einsum(x, y, z, 'a b c, c b d, a g k -> a b k')
        >>> output.shape
        (20, 20, 20)

        >>> # filter a set of images
        >>> batched_images = np.random.randn(128, 16, 16)
        >>> filters = np.random.randn(16, 16, 30)
        >>> result = einops.einsum(batched_images, filters, 'batch h w, h w channel -> batch channel')
        >>> result.shape
        (128, 30)

        >>> # matrix multiplication with an unknown input shape
        >>> batch_shape = (50, 30)
        >>> data = np.random.randn(*batch_shape, 20)
        >>> weights = np.random.randn(10, 20)
        >>> result = einops.einsum(weights, data, 'out_dim in_dim, ... in_dim -> ... out_dim')
        >>> result.shape
        (50, 30, 10)

        >>> # trace of a single matrix
        >>> matrix = np.random.randn(10, 10)
        >>> result = einops.einsum(matrix, 'i i -> ')
    """
    if len(tensors_and_pattern) <= 1:
        raise ValueError(
            "`einops.einsum` takes at minimum two arguments: the tensors (at least one),"
            " followed by the pattern."
        )
    pattern = tensors_and_pattern[-1]
    if not isinstance(pattern, str):
        raise ValueError(
            "The last argument passed to `einops.einsum` must be a string,"
            " representing the einsum pattern."
        )
    tensors = tensors_and_pattern[:-1]
    pattern = _compactify_pattern_for_einsum(pattern)
    return get_backend(tensors[0]).einsum(pattern, *tensors)