2025 USA-NA-AIO Round 2, Problem 2, Part 9

Now, let us study another variant of MHA: Multi-head Latent Attention (MLA). MLA was introduced by DeepSeek and is a core component of DeepSeek's large language models (LLMs).

The key intuition of MLA is as follows. In MHA, the key and value projection matrices

\mathbf{W}^{\mathbf{K}, MHA} \in \Bbb R^{H \cdot D_{qk} \times D_2} , \quad \mathbf{W}^{\mathbf{V}, MHA} \in \Bbb R^{H \cdot D_v \times D_2}

may be high dimensional.

For instance, suppose H \cdot D_{qk} = H \cdot D_v = D_2 = 4096.

However, these matrices do not necessarily have high rank (such as 4096). Their actual ranks (or the number of leading singular values needed for a truncated singular value decomposition (SVD) to closely approximate them) may be much lower than that.

To capture this low-rank structure, MLA proposes the following model:

\begin{align*} \mathbf{W}^{\mathbf{K}, MHA} & = \color{blue}{\mathbf{W}^{\mathbf{UK}, MLA}} \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}} \\ \mathbf{W}^{\mathbf{V}, MHA} & = \color{green}{\mathbf{W}^{\mathbf{UV}, MLA}} \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}} \end{align*}

where

  • \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}} \in \Bbb R^{r \times D_2}: down-projection matrix for computing keys and values.

  • \color{blue}{\mathbf{W}^{\mathbf{UK}, MLA}} \in \Bbb R^{H \cdot D_{qk} \times r}: up-projection matrix for computing keys.

  • \color{green}{\mathbf{W}^{\mathbf{UV}, MLA}} \in \Bbb R^{H \cdot D_v \times r}: up-projection matrix for computing values.

In practice, the rank r is typically much smaller than \min \left\{ H \cdot D_{qk} , H \cdot D_v , D_2 \right\}.
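
To make the shapes and the savings concrete, here is a minimal NumPy sketch of the MLA factorization. The dimensions (scaled down from the 4096 example above) and the latent rank r are assumed purely for illustration, not taken from the problem statement.

```python
import numpy as np

# Scaled-down, assumed dimensions (the text's example uses 4096).
D2 = 256            # model dimension D_2
H_Dqk = H_Dv = 256  # H * D_qk and H * D_v
r = 32              # latent rank, r << min(H * D_qk, H * D_v, D_2)

rng = np.random.default_rng(0)

# MLA stores one shared down-projection and two up-projections...
W_DKV = rng.standard_normal((r, D2))     # W^{DKV, MLA}: (r, D_2)
W_UK = rng.standard_normal((H_Dqk, r))   # W^{UK, MLA}:  (H * D_qk, r)
W_UV = rng.standard_normal((H_Dv, r))    # W^{UV, MLA}:  (H * D_v, r)

# ...which induce MHA-style key/value projections of rank at most r.
W_K = W_UK @ W_DKV                       # (H * D_qk, D_2)
W_V = W_UV @ W_DKV                       # (H * D_v, D_2)
assert np.linalg.matrix_rank(W_K) <= r
assert np.linalg.matrix_rank(W_V) <= r

# Parameter comparison: two full projections vs. the three MLA factors.
full_params = W_K.size + W_V.size                  # 2 * 256 * 256 = 131072
mla_params = W_DKV.size + W_UK.size + W_UV.size    # 3 * 32 * 256  = 24576
print(full_params, mla_params)
```

Note that the down-projection \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}} is shared between keys and values; this is what allows MLA to cache only the r-dimensional latent per token rather than full per-head keys and values.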

In all remaining parts of this problem, to simplify your analysis and to highlight the relationships among MHA, GQA, and MLA, we make the following assumptions:

  • D_1 = D_2 = \color{red}{D}.

  • D_{qk} = D_v = \color{red}{d}.

  • d is a factor of D.

Under these assumptions, the number of heads H satisfies

\color{red}{H} = \frac{D}{d} .

Part 9 (10 points, non-coding task)

In this part, you are asked to prove that GQA can be equivalently represented by MLA.

In your solution, it is sufficient to prove that for each \mathbf{M} \in \left\{ \mathbf{K}, \mathbf{V} \right\}, the matrix

\mathbf{\tilde W}^{\mathbf{M}, GQA} = \begin{bmatrix} \color{orange}{\mathbf{W}^{\mathbf{M}, GQA}} \\ \color{orange}{\mathbf{W}^{\mathbf{M}, GQA}} \\ \vdots \\ \color{orange}{\mathbf{W}^{\mathbf{M}, GQA}} \end{bmatrix} \in \Bbb R^{D \times D}

(defined in Part 6), which is the vertical concatenation of \frac{H}{G} copies of

{\color{orange}{\mathbf{W}^{\mathbf{M}, GQA}}} = \begin{bmatrix} \mathbf{W}^{\mathbf{M}, GQA}_0 \\ \vdots \\ \mathbf{W}^{\mathbf{M}, GQA}_{G-1} \end{bmatrix} \in \Bbb R^{G \cdot d \times D}

can be decomposed as

\mathbf{\tilde W}^{\mathbf{M}, GQA} = \color{blue}{\mathbf{W}^{\mathbf{UM}, MLA}} \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}}

where

  • \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}} \in \Bbb R^{r \times D}: down-projection matrix for computing keys and values.

  • \color{blue}{\mathbf{W}^{\mathbf{UM}, MLA}} \in \Bbb R^{D \times r}: up-projection matrix for computing \mathbf{M} (keys or values).

  • r = G \cdot d.

\color{green}{\text{### WRITE YOUR SOLUTION HERE ###}}

We have

\begin{align*} \text{rank} \left( \mathbf{\tilde W}^{\mathbf{M}, GQA} \right) & = \text{rank} \left( \mathbf{W}^{\mathbf{M}, GQA} \right) \\ & \leq \min \left\{ G \cdot d, D \right\} \\ & = \min \left\{ r, D \right\} \\ & = r , \end{align*}

where the first equality follows from Part 6, and the last equality holds because r = G \cdot d \leq H \cdot d = D (since G \leq H).

Therefore, \mathbf{\tilde W}^{\mathbf{M}, GQA} has at most r nonzero singular values, so its SVD, truncated to the top r singular triplets, gives

\begin{align*} \mathbf{\tilde W}^{\mathbf{M}, GQA} & = \sum_{i=0}^{r-1} \sigma_i \mathbf{u}_i \mathbf{v}_i^\top \\ & = \underbrace{\begin{bmatrix} \mathbf{u}_0 & \mathbf{u}_1 & \cdots & \mathbf{u}_{r-1} \end{bmatrix}}_{\color{blue}{\mathbf{W}^{\mathbf{UM}, MLA}}} \underbrace{\begin{bmatrix} \sigma_0 \mathbf{v}_0^\top \\ \sigma_1 \mathbf{v}_1^\top \\ \vdots \\ \sigma_{r-1} \mathbf{v}_{r-1}^\top \end{bmatrix}}_{\color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}}} . \end{align*}

The two factors have exactly the required dimensions, \color{blue}{\mathbf{W}^{\mathbf{UM}, MLA}} \in \Bbb R^{D \times r} and \color{red}{\mathbf{W}^{\mathbf{DKV}, MLA}} \in \Bbb R^{r \times D} with r = G \cdot d, which completes the proof.
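
As a sanity check on this argument, the following sketch (with small values of d, G, and H assumed only for illustration) builds \mathbf{\tilde W}^{\mathbf{M}, GQA} by stacking H/G copies of a random GQA block, verifies the rank bound, and reads the two MLA factors off the truncated SVD exactly as above.

```python
import numpy as np

# Small, assumed dimensions for illustration: D = H * d, r = G * d.
d, G, H = 8, 2, 6        # head dim, KV groups, heads (G divides H)
D = H * d                # model dimension, 48
r = G * d                # latent rank, 16

rng = np.random.default_rng(1)

# A GQA projection: G distinct head blocks stacked vertically, shape (G*d, D).
W_M_GQA = rng.standard_normal((G * d, D))

# \tilde W^{M, GQA}: H/G vertical copies of W^{M, GQA}, shape (D, D).
W_tilde = np.vstack([W_M_GQA] * (H // G))
assert W_tilde.shape == (D, D)
assert np.linalg.matrix_rank(W_tilde) <= r

# SVD: only the top r singular values can be nonzero.
U, s, Vt = np.linalg.svd(W_tilde)
assert np.allclose(s[r:], 0.0)

# Read the MLA factors off the truncated SVD, as in the proof.
W_UM_MLA = U[:, :r]                    # W^{UM, MLA}:  (D, r)
W_DKV_MLA = s[:r, None] * Vt[:r, :]    # W^{DKV, MLA}: (r, D)
assert np.allclose(W_UM_MLA @ W_DKV_MLA, W_tilde)
print("GQA matrix exactly recovered from rank-r MLA factors.")
```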

\color{red}{\text{""" END OF THIS PART """}}