Now, let us study another variant of MHA: Multi-head Latent Attention (MLA). MLA was introduced by DeepSeek and is a core component of DeepSeek's large language models (LLMs).
The key intuition of MLA is as follows. In MHA, the key and value projection matrices
may be high dimensional.
For instance, suppose H \cdot D_{qk} = H \cdot D_v = D_2 = 4096.
However, these matrices do not necessarily have high rank (such as 4096). Their actual ranks (or the number of leading singular values needed for a truncated singular value decomposition (SVD) to closely approximate them) may be much lower than that.
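To make this intuition concrete, here is a minimal numerical sketch (assuming NumPy and small illustrative sizes, scaled down from the 4096-dimensional example above): it builds a key projection matrix that is exactly rank r, then checks that its rank-r truncated SVD reconstructs it almost perfectly.

```python
import numpy as np

# Illustrative sizes only (much smaller than the 4096-dimensional example).
H, D_qk, D_2, r = 8, 16, 128, 16

rng = np.random.default_rng(0)
# A (H * D_qk) x D_2 projection matrix that is exactly rank r by construction.
W_K = rng.standard_normal((H * D_qk, r)) @ rng.standard_normal((r, D_2))

U, s, Vt = np.linalg.svd(W_K, full_matrices=False)
W_K_r = (U[:, :r] * s[:r]) @ Vt[:r, :]        # rank-r truncated SVD

print(np.linalg.matrix_rank(W_K))             # 16, far below min(128, 128)
print(np.linalg.norm(W_K - W_K_r, "fro"))     # ~1e-13: truncation loses nothing
```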
To capture this low-rank structure, MLA proposes the following model:

\mathbf{W}^{\mathbf{K}, MLA} = \color{blue}{\mathbf{W}^{\mathbf{UK}, MLA}} \, \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}}, \qquad \mathbf{W}^{\mathbf{V}, MLA} = \color{green}{\mathbf{W}^{\mathbf{UV}, MLA}} \, \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}},

where
- \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}} \in \Bbb R^{r \times D_2}: down-projection matrix for computing keys and values.
- \color{blue}{\mathbf{W}^{\mathbf{UK}, MLA}} \in \Bbb R^{H \cdot D_{qk} \times r}: up-projection matrix for computing keys.
- \color{green}{\mathbf{W}^{\mathbf{UV}, MLA}} \in \Bbb R^{H \cdot D_v \times r}: up-projection matrix for computing values.
In practice, the rank r is typically much smaller than \min \left\{ H \cdot D_{qk} , H \cdot D_v , D_2 \right\}.
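As an illustration of this model, the following sketch (assuming NumPy, a column-vector convention, and made-up sizes; all variable names are illustrative, not from the problem's code) computes keys and values through the shared r-dimensional latent and checks that this matches the merged projections \mathbf{W}^{\mathbf{UK}, MLA} \mathbf{W}^{\mathbf{DKV}, MLA} and \mathbf{W}^{\mathbf{UV}, MLA} \mathbf{W}^{\mathbf{DKV}, MLA}.

```python
import numpy as np

# Made-up illustrative sizes with r much smaller than H * D_qk, H * D_v, D_2.
D_2, H, D_qk, D_v, r = 128, 8, 16, 16, 32

rng = np.random.default_rng(0)
W_DKV = rng.standard_normal((r, D_2))       # down-projection, shared by K and V
W_UK = rng.standard_normal((H * D_qk, r))   # up-projection for keys
W_UV = rng.standard_normal((H * D_v, r))    # up-projection for values

x = rng.standard_normal(D_2)                # one token's hidden state
c = W_DKV @ x                               # r-dim latent (what a KV cache would store)
k = W_UK @ c                                # keys for all H heads, length H * D_qk
v = W_UV @ c                                # values for all H heads, length H * D_v

# Merged view: W_UK @ W_DKV acts as a key projection matrix of rank at most r.
assert np.allclose(k, (W_UK @ W_DKV) @ x)
assert np.allclose(v, (W_UV @ W_DKV) @ x)
```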
In all remaining parts of this problem, to simplify your analysis and highlight the relationships among MHA, GQA, and MLA, we make the following assumptions:
- D_1 = D_2 = \color{red}{D}.
- D_{qk} = D_v = \color{red}{d}.
- d is a factor of D.
Under these assumptions, the number of heads H satisfies

H = \frac{D}{d}.

For instance, D = 4096 and d = 128 give H = 32.
Part 9 (10 points, non-coding task)
In this part, you are asked to prove that GQA can be equivalently represented by MLA.
In your solution, it is sufficient for you to prove that for \mathbf{M} \in \left\{ \mathbf{K}, \mathbf{V} \right\}, the matrix \mathbf{\tilde W}^{\mathbf{M}, GQA} \in \Bbb R^{D \times D} (defined in Part 6), which is the concatenation of \frac{H}{G} copies of \mathbf{W}^{\mathbf{M}, GQA} \in \Bbb R^{G \cdot d \times D}, can be decomposed as

\mathbf{\tilde W}^{\mathbf{M}, GQA} = \color{blue}{\mathbf{W}^{\mathbf{UM}, MLA}} \, \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}},
where
- \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}} \in \Bbb R^{r \times D}: down-projection matrix for computing keys and values.
- \color{blue}{\mathbf{W}^{\mathbf{UM}, MLA}} \in \Bbb R^{D \times r}: up-projection matrix for computing \mathbf{M} (keys or values).
- r = G \cdot d.
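Before writing the proof, it may help to see the claim numerically. The following sketch (assuming NumPy and made-up sizes; the particular choice of \mathbf{W}^{\mathbf{DKV}, MLA} and \mathbf{W}^{\mathbf{UM}, MLA} below is one valid construction, not necessarily the intended one) tiles \frac{H}{G} copies of a random GQA projection and verifies the decomposition with r = G \cdot d.

```python
import numpy as np

# Made-up illustrative sizes satisfying the assumptions above.
D, d, G = 128, 16, 2
H = D // d                                  # number of heads, H = D / d
r = G * d

rng = np.random.default_rng(0)
W_M_GQA = rng.standard_normal((r, D))       # one copy: (G * d) x D

W_tilde = np.tile(W_M_GQA, (H // G, 1))     # stack H/G copies -> D x D

# One valid choice: down-projection is W_M_GQA itself, and the up-projection
# stacks H/G identity blocks, so W_UM @ W_DKV reproduces the tiling.
W_DKV = W_M_GQA                             # r x D
W_UM = np.tile(np.eye(r), (H // G, 1))      # D x r

assert np.allclose(W_tilde, W_UM @ W_DKV)
print(np.linalg.matrix_rank(W_tilde))       # at most r = G * d = 32
```

The explicit pair used here (down-projection equal to the GQA projection, up-projection built from stacked identity blocks) is one concrete construction a proof can exhibit; any factorization with the stated shapes suffices.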