# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import init, Module, nn
import numpy as np
import math
class MultiheadAttention(Module):
"""
多头注意力机制(multi-head attention)源自论文 `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ ,其定义为:
.. math::
\\text{MultiHead}(Q, K, V) = \\text{Concat}(head_1,\dots,head_h)W^O
其中 :math:`head_i = \\text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
参数:
- embed_dim (int): 输入的维度。
- num_heads (int): 多头的数量。
- kdim (int, optional): 键值的维度。默认值为None,表示与embed_dim相同。
- vdim (int, optional): 值的维度。默认值为None,表示与embed_dim相同。
- dropout (float, optional): dropout概率。默认值为0.0。
- bias (bool, optional): 是否使用偏置。默认值为True。
- add_bias_kv (bool, optional): 是否添加偏置。默认值为False。
- add_zero_attn (bool, optional): 是否添加零注意力。默认值为False。
- self_attention (bool, optional): 是否自注意力。默认值为False。
- encoder_decoder_attention (bool, optional): 是否编码器解码器注意力。默认值为False。
- q_noise (float, optional): query的噪声。默认值为0.0。
- qn_block_size (int, optional): 块大小。默认值为8。
属性:
- embed_dim (int): 输入的维度。
- kdim (int): 键值的维度。
- vdim (int): 值的维度。
- num_heads (int): 多头的数量。
- head_dim (int): 头的维度。
- scaling (float): 缩放因子。
- self_attention (bool): 是否自注意力。
- encoder_decoder_attention (bool): 是否编码器解码器注意力。
- k_proj (nn.Linear): 键的线性变换。
- v_proj (nn.Linear): 值的线性变换。
- q_proj (nn.Linear): 查询的线性变换。
- out_proj (nn.Linear): 输出的线性变换。
- bias_k (None): 键的偏置。
- bias_v (None): 值的偏置。
- add_zero_attn (bool): 是否添加零注意力。
形状:
- 输入: :math:`(L, N, E)` 其中 :math:`L` 表示序列长度, :math:`N` 表示batch大小, :math:`E` 表示输入的维度。
- 输出: :math:`(L, N, E)` 其中 :math:`L` 表示序列长度, :math:`N` 表示batch大小, :math:`E` 表示输入的维度。
代码示例:
>>> multihead_attn = jt.attention.MultiheadAttention(embed_dim, num_heads)
>>> attn, attn_weights = multihead_attn(query, key, value)
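
        A more complete sketch with concrete shapes (the numbers below are
        illustrative assumptions, not requirements of the API):

        >>> import jittor as jt
        >>> embed_dim, num_heads = 16, 4
        >>> multihead_attn = jt.attention.MultiheadAttention(embed_dim, num_heads)
        >>> query = jt.randn(10, 2, embed_dim)       # (L, N, E)
        >>> key = value = jt.randn(5, 2, embed_dim)  # (S, N, E)
        >>> attn, attn_weights = multihead_attn(query, key, value)
        >>> attn.shape          # [10, 2, 16], i.e. (L, N, E)
        >>> attn_weights.shape  # [2, 10, 5], weights averaged over heads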
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
assert dropout==0, "TODO: dropout>0"
self.head_dim = embed_dim // num_heads
assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
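        # scale factor 1/sqrt(d_k) from "Attention Is All You Need",
        # where d_k is the per-head dimension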
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, ("Self-attention requires query, key and " "value to be of the same size")
#TODO: quant_noise
self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
assert not add_bias_kv, "TODO: add_bias_kv=True"
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.reset_parameters()
self.onnx_trace = False
self.tpu = False
def reset_parameters(self):
        '''
        Initialize the parameters of the projection layers.

        Example:
            >>> multihead_attn = jt.attention.MultiheadAttention(embed_dim, num_heads)
            >>> multihead_attn.reset_parameters()
        '''
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
init.xavier_uniform_(self.k_proj.weight)
init.xavier_uniform_(self.v_proj.weight)
init.xavier_uniform_(self.q_proj.weight)
# init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
init.constant_(self.out_proj.bias, 0.)
if self.bias_k is not None:
init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
init.xavier_normal_(self.bias_v)
def execute(
self,
query,
        key=None,
        value=None,
        key_padding_mask=None,
        incremental_state=None,
        need_weights=True,
        static_kv=False,
        attn_mask=None,
        before_softmax=False,
        need_head_weights=False,
):
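        '''
        Compute multi-head attention over query, key and value.

        Args:
            - query (jt.Var): query of shape :math:`(L, N, E)`.
            - key (jt.Var, optional): key of shape :math:`(S, N, E)`. Ignored when self_attention is True.
            - value (jt.Var, optional): value of shape :math:`(S, N, E)`. Ignored when self_attention or encoder_decoder_attention is True.
            - key_padding_mask: must be None for now (TODO).
            - incremental_state: must be None for now (TODO).
            - need_weights (bool): whether to also return the attention weights.
            - static_kv (bool): reserved for incremental decoding.
            - attn_mask: must be None for now (TODO).
            - before_softmax (bool): if True, return the raw attention scores and values before the softmax.
            - need_head_weights (bool): if True, return per-head attention weights; implies need_weights=True.

        Returns:
            - attn (jt.Var): attention output of shape :math:`(L, N, E)`.
            - attn_weights (jt.Var or None): attention weights of shape :math:`(N, L, S)`, averaged over heads, or :math:`(H, N, L, S)` per head if need_head_weights is True.
        '''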
if need_head_weights:
need_weights = True
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
assert list(query.shape) == [tgt_len, bsz, embed_dim]
assert incremental_state is None, "TODO: incremental_state is not None"
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
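        # pre-scaling q lets the bmm below compute QK^T / sqrt(d_k) directly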
q = q*self.scaling
assert self.bias_k is None, "TODO: self.bias_k is not None:"
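        # fold heads into the batch dim: (L, N, E) -> (N * num_heads, L, head_dim)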
q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
if k is not None:
k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
if v is not None:
v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
assert saved_state is None, "TODO: saved_state is not None"
assert k is not None
src_len = k.shape[1]
assert key_padding_mask is None, "TODO: key_padding_mask is not None"
assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"
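        # attention logits: (N*H, L, D) x (N*H, D, S) -> (N*H, L, S)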
attn_weights = nn.bmm(q, k.transpose(0, 2, 1))
assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
assert attn_mask is None, "TODO: attn_mask is not None"
assert key_padding_mask is None, "TODO: key_padding_mask is not None"
if before_softmax:
return attn_weights, v
attn_weights_float = nn.softmax(attn_weights, dim=-1)
attn_weights = attn_weights_float.type_as(attn_weights)
assert v is not None
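        # weighted sum of the values: (N*H, L, S) x (N*H, S, D) -> (N*H, L, D)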
attn = nn.bmm(attn_weights, v)
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.shape[1] == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights = None
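        # optionally expose the attention maps, per head (H, N, L, S) or averaged over heads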
if need_weights:
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dims=[0])
return attn, attn_weights