jittor.attention
这里是Jittor的 注意力 模块的API文档,您可以通过from jittor import attention
来获取该模块。
- class jittor.attention.MultiheadAttention(embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8)[源代码]
多头注意力机制(multi-head attention)源自论文 Attention Is All You Need ,其定义为:
\[\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O\]其中 \(head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\).
- 参数:
embed_dim (int): 输入的维度。
num_heads (int): 多头的数量。
kdim (int, optional): 键值的维度。默认值为None,表示与embed_dim相同。
vdim (int, optional): 值的维度。默认值为None,表示与embed_dim相同。
dropout (float, optional): dropout概率。默认值为0.0。
bias (bool, optional): 是否使用偏置。默认值为True。
add_bias_kv (bool, optional): 是否添加偏置。默认值为False。
add_zero_attn (bool, optional): 是否添加零注意力。默认值为False。
self_attention (bool, optional): 是否自注意力。默认值为False。
encoder_decoder_attention (bool, optional): 是否编码器解码器注意力。默认值为False。
q_noise (float, optional): query的噪声。默认值为0.0。
qn_block_size (int, optional): 块大小。默认值为8。
- 属性:
embed_dim (int): 输入的维度。
kdim (int): 键值的维度。
vdim (int): 值的维度。
num_heads (int): 多头的数量。
head_dim (int): 头的维度。
scaling (float): 缩放因子。
self_attention (bool): 是否自注意力。
encoder_decoder_attention (bool): 是否编码器解码器注意力。
k_proj (nn.Linear): 键的线性变换。
v_proj (nn.Linear): 值的线性变换。
q_proj (nn.Linear): 查询的线性变换。
out_proj (nn.Linear): 输出的线性变换。
bias_k (None): 键的偏置。
bias_v (None): 值的偏置。
add_zero_attn (bool): 是否添加零注意力。
- 形状:
输入: \((L, N, E)\) 其中 \(L\) 表示序列长度, \(N\) 表示batch大小, \(E\) 表示输入的维度。
输出: \((L, N, E)\) 其中 \(L\) 表示序列长度, \(N\) 表示batch大小, \(E\) 表示输入的维度。
- 代码示例:
>>> multihead_attn = jt.attention.MultiheadAttention(embed_dim, num_heads) >>> attn, attn_weights = multihead_attn(query, key, value)