Source code for jittor.attention

# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved. 
# Maintainers:
#     Guowei Yang <471184555@qq.com>
#     Dun Liang <randonlang@gmail.com>. 
#
# 
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

import jittor as jt
from jittor import init, Module, nn
import numpy as np
import math

class MultiheadAttention(Module):
    """ Multi-head attention, as introduced in `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_ , defined as:

    .. math::

        \\text{MultiHead}(Q, K, V) = \\text{Concat}(head_1, \\dots, head_h)W^O

    where :math:`head_i = \\text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:

    - embed_dim (int): dimension of the input embeddings.
    - num_heads (int): number of attention heads.
    - kdim (int, optional): dimension of the keys. Defaults to None, meaning the same as embed_dim.
    - vdim (int, optional): dimension of the values. Defaults to None, meaning the same as embed_dim.
    - dropout (float, optional): dropout probability. Defaults to 0.0.
    - bias (bool, optional): whether the linear projections use a bias. Defaults to True.
    - add_bias_kv (bool, optional): whether to add a bias to the keys and values. Defaults to False.
    - add_zero_attn (bool, optional): whether to add a zero-attention entry. Defaults to False.
    - self_attention (bool, optional): whether this layer performs self-attention. Defaults to False.
    - encoder_decoder_attention (bool, optional): whether this layer performs encoder-decoder attention. Defaults to False.
    - q_noise (float, optional): amount of quantization noise. Defaults to 0.0.
    - qn_block_size (int, optional): block size for quantization noise. Defaults to 8.

    Attributes:

    - embed_dim (int): dimension of the input embeddings.
    - kdim (int): dimension of the keys.
    - vdim (int): dimension of the values.
    - num_heads (int): number of attention heads.
    - head_dim (int): dimension of each head.
    - scaling (float): scaling factor applied to the queries.
    - self_attention (bool): whether this layer performs self-attention.
    - encoder_decoder_attention (bool): whether this layer performs encoder-decoder attention.
    - k_proj (nn.Linear): linear projection for the keys.
    - v_proj (nn.Linear): linear projection for the values.
    - q_proj (nn.Linear): linear projection for the queries.
    - out_proj (nn.Linear): linear projection for the output.
    - bias_k (None): bias for the keys.
    - bias_v (None): bias for the values.
    - add_zero_attn (bool): whether to add a zero-attention entry.

    Shape:

    - Input: :math:`(L, N, E)` where :math:`L` is the sequence length, :math:`N` is the batch size, and :math:`E` is the embedding dimension.
    - Output: :math:`(L, N, E)` where :math:`L` is the sequence length, :math:`N` is the batch size, and :math:`E` is the embedding dimension.

    Example:

    >>> multihead_attn = jt.attention.MultiheadAttention(embed_dim, num_heads)
    >>> attn, attn_weights = multihead_attn(query, key, value)
    """
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        assert dropout == 0, "TODO: dropout>0"

        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and value to be of the same size"
        )

        # TODO: quant_noise
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        assert not add_bias_kv, "TODO: add_bias_kv=True"
        self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.tpu = False
    def reset_parameters(self):
        ''' Initialize the parameters.

        Example:

        >>> multihead_attn = jt.attention.MultiheadAttention(embed_dim, num_heads)
        >>> multihead_attn.reset_parameters()
        '''
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            init.xavier_uniform_(self.k_proj.weight)
            init.xavier_uniform_(self.v_proj.weight)
            init.xavier_uniform_(self.q_proj.weight)

        # init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            init.xavier_normal_(self.bias_v)
    def execute(
        self,
        query,
        key=None,
        value=None,
        key_padding_mask=None,
        incremental_state=None,
        need_weights=True,
        static_kv=False,
        attn_mask=None,
        before_softmax=False,
        need_head_weights=False,
    ):
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim
        assert list(query.shape) == [tgt_len, bsz, embed_dim]

        assert incremental_state is None, "TODO: incremental_state is not None"
        saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q = q * self.scaling

        assert self.bias_k is None, "TODO: self.bias_k is not None"

        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
        if k is not None:
            k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
        if v is not None:
            v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)

        assert saved_state is None, "TODO: saved_state is not None"
        assert k is not None
        src_len = k.shape[1]

        assert key_padding_mask is None, "TODO: key_padding_mask is not None"
        assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"

        attn_weights = nn.bmm(q, k.transpose(0, 2, 1))

        assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]

        assert attn_mask is None, "TODO: attn_mask is not None"
        assert key_padding_mask is None, "TODO: key_padding_mask is not None"

        if before_softmax:
            return attn_weights, v

        attn_weights_float = nn.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        assert v is not None
        attn = nn.bmm(attn_weights, v)
        assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.shape[1] == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dims=[0])

        return attn, attn_weights
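
# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# self-attention run with assumed sizes embed_dim=16, num_heads=4, sequence
# length L=5 and batch size N=2. Inputs use the (L, N, E) layout documented
# above; the returned weights are (N, L, L) after averaging over heads, or
# (num_heads, N, L, L) when need_head_weights=True.
if __name__ == "__main__":
    attn_layer = MultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
    x = jt.random((5, 2, 16))                     # query/key/value, shape (L, N, E)
    attn_out, attn_weights = attn_layer(x, x, x)
    print(attn_out.shape)                         # [5, 2, 16]
    print(attn_weights.shape)                     # [2, 5, 5] (head-averaged)
    _, head_weights = attn_layer(x, x, x, need_head_weights=True)
    print(head_weights.shape)                     # [4, 2, 5, 5] (per-head)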