attention 计算过程的一些细节
最近,有粉丝问我,attention 结构中计算 qkv 的时候,为什么要做 kvcache 呢?他看了一些文章,没看懂。
为什么要做 kvcache?
假设模型的输入序列长度是 2,隐藏层的维度是 H,那么 q、k、v 的维度分别是[2, H]
假设它们的值分别是:
那么首先 q*k 的结果为:
然后需要做一个 mask,只留下下三角的值,其他值都取 0,得到:
为什么要做 mask,我认为是要和训练时的规则保持一致,因为训练的时候,是认为每个 token 只能看到它前面的词的。
然后计算 qk*v:
完成后续的计算可以预测得到 1 个新的 token。如果还需要继续预测下一个词,在下一次计算的时候我们假设 q、k、v 为:
同样得到 q*k 为:
qk*v 为:
可以看到,第 2 次计算得到的 qk 相比于第 1 次的 qk 只是多了第 3 行。而第 3 行的值是 q3*[k1, k2, k3],所以为了避免重复计算,我们只需要在第 2 次计算的时候,只计算新 token 对应的 q3 和 k3,然后把 k3 和第 1 次计算得到的[k1, k2]拼接起来即可,[k1, k2]就是 k cache。
同样可以发现,第 2 次计算得到的 qkv 相比于第 1 次的 qkv 只是多了第 3 行。而第 3 行的值是 qk*[v1, v2, v3],所以为了避免重复计算,我们只需要在第 2 次计算的时候,只计算新 token 对应的 v3,然后把 v3 和第 1 次计算得到的[v1, v2]拼接起来即可,[v1, v2]就是 v cache。
以此类推,在后续的增量推理过程中,每次只需要计算新 token 的 q、k、v,然后利用之前缓存的 kv cache 计算 qk 和 qkv。
transformer 是怎样预测出下一个词的?
首先,从数学层面来讲,是这样计算的:
首先,假设输入序列的长度是 L,隐藏层的特征维度是 H,词汇表的长度是 V,那么在计算 qkv 的过程中,输入 x 的 shape 变化如下:
q*k:(L, H)x(H, L)->(L, L)
qk*v:(L, L)x(L, H)->(L, H)
然后再经过 forward layer 的一系列全连接层,得到的输出 shape 为(L, V),而它的最后一个分量,也就是 output[L-1],就是预测结果的概率分布。
那么怎么理解这个计算过程呢?这个就可以有很多答案了,我一般是这么给别人解释的:首先在计算 q*k 的时候,qk 的最后一个分量是用最后一个词去和其他词的 key 值做乘法,这一步相当于计算最后一个词和句子中每个词的相关性,然后乘以 v 就相当于把最后一个词和其他词的相关性进行一个组合,后面再通过多个全连接层进行上下文理解这个词在整个句子中的含义,并预测出下一个词。
这里又引入了另一个问题,既然在首次计算时,只用到了最后一个分量,为什么还要计算 qk 和 qkv 的第 1 到第 L-1 个分量的值呢?这是因为大模型由多个 decoder layer 叠加组成。第 1 个 decoder 输出的结果还需要作为 x 输入给第 2 个 decoder layer,进行多轮"思考"。再具体一点,我们还是假设输入序列长度是 2,经过第 1 个 decoder layer 后输出为:
那么它再作为输入传给第 2 个 decoder layer,第 2 个 decoder layer 计算得到的 qkv 是:
它的最后一个分量是 q2k1v1+q2k2v2,其中的 k1、v1 都和 h1 相关,所以做首次计算(也就是我们常说的全量计算)时,qk 和 qkv 的每个分量都要计算。
大家还有什么疑问呢?欢迎讨论哦!
评论