jittor.attention

这里是Jittor的 注意力 模块的API文档,您可以通过from jittor import attention来获取该模块。

class jittor.attention.MultiheadAttention(embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8)[源代码]

多头注意力机制(multi-head attention)源自论文 Attention Is All You Need ,其定义为:

\[\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O\]

其中 \(head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\).

参数:
  • embed_dim (int): 输入的维度。

  • num_heads (int): 多头的数量。

  • kdim (int, optional): 键值的维度。默认值为None,表示与embed_dim相同。

  • vdim (int, optional): 值的维度。默认值为None,表示与embed_dim相同。

  • dropout (float, optional): dropout概率。默认值为0.0。

  • bias (bool, optional): 是否使用偏置。默认值为True。

  • add_bias_kv (bool, optional): 是否添加偏置。默认值为False。

  • add_zero_attn (bool, optional): 是否添加零注意力。默认值为False。

  • self_attention (bool, optional): 是否自注意力。默认值为False。

  • encoder_decoder_attention (bool, optional): 是否编码器解码器注意力。默认值为False。

  • q_noise (float, optional): query的噪声。默认值为0.0。

  • qn_block_size (int, optional): 块大小。默认值为8。

属性:
  • embed_dim (int): 输入的维度。

  • kdim (int): 键值的维度。

  • vdim (int): 值的维度。

  • num_heads (int): 多头的数量。

  • head_dim (int): 头的维度。

  • scaling (float): 缩放因子。

  • self_attention (bool): 是否自注意力。

  • encoder_decoder_attention (bool): 是否编码器解码器注意力。

  • k_proj (nn.Linear): 键的线性变换。

  • v_proj (nn.Linear): 值的线性变换。

  • q_proj (nn.Linear): 查询的线性变换。

  • out_proj (nn.Linear): 输出的线性变换。

  • bias_k (None): 键的偏置。

  • bias_v (None): 值的偏置。

  • add_zero_attn (bool): 是否添加零注意力。

形状:
  • 输入: \((L, N, E)\) 其中 \(L\) 表示序列长度, \(N\) 表示batch大小, \(E\) 表示输入的维度。

  • 输出: \((L, N, E)\) 其中 \(L\) 表示序列长度, \(N\) 表示batch大小, \(E\) 表示输入的维度。

代码示例:
>>> multihead_attn = jt.attention.MultiheadAttention(embed_dim, num_heads)
>>> attn, attn_weights = multihead_attn(query, key, value)
reset_parameters()[源代码]

初始化参数

代码示例:
>>> multihead_attn = jt.attention.MultiheadAttention(embed_dim, num_heads)
>>> multihead_attn.reset_parameters()