2025 USA-NA-AIO Round 2, Problem 2, Part 12

MLA not only enjoys the advantage of being more general than MHA and GQA; it is also computationally more efficient.

An intuitive approach to computing MLA:

  1. Compute the key-projection matrix \mathbf{W}^{\mathbf{UK}, MLA} \mathbf{W}^{\mathbf{DKV}, MLA} \in \Bbb R^{D \times D} and the value-projection matrix \mathbf{W}^{\mathbf{UV}, MLA} \mathbf{W}^{\mathbf{DKV}, MLA} \in \Bbb R^{D \times D}.

  2. Follow the standard steps in MHA.

This approach is hereafter called the \color{red}{\textbf{vanilla approach}}. It fails to exploit the low-rank structure of \mathbf{W}^{\mathbf{DKV}, MLA}, \mathbf{W}^{\mathbf{UK}, MLA}, and \mathbf{W}^{\mathbf{UV}, MLA}; a NumPy sketch of this vanilla computation is given below.
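To make the comparison concrete, here is a minimal NumPy sketch of the vanilla approach. It is not part of the problem statement; the sizes D, H, r, L and the per-head block layout of \mathbf{W}^{\mathbf{Q}}, \mathbf{W}^{\mathbf{UK}}, \mathbf{W}^{\mathbf{UV}}, and \mathbf{W}^O are illustrative assumptions chosen to match the shapes used in this part.

```python
import numpy as np

# Illustrative sizes (assumptions, not from the problem statement):
# D = model dim, H = heads, r = latent dim, L = sequence length, Dh = per-head dim.
D, H, r, L = 16, 4, 4, 6
Dh = D // H

rng = np.random.default_rng(0)
W_DKV = rng.standard_normal((r, D))   # down-projection to the latent space
W_UK  = rng.standard_normal((D, r))   # key up-projection, H stacked blocks of shape (Dh, r)
W_UV  = rng.standard_normal((D, r))   # value up-projection, same layout
W_Q   = rng.standard_normal((D, D))   # query projection, H stacked blocks of shape (Dh, D)
W_O   = rng.standard_normal((D, D))   # out-projection, H side-by-side blocks of shape (D, Dh)

X = rng.standard_normal((L, D))       # attending sequence, rows are x_{l1}
Y = rng.standard_normal((L, D))       # attended sequence, rows are y_{l2}

# Vanilla approach: materialize the full D x D key/value projections, then run standard MHA.
W_K_full = W_UK @ W_DKV               # D x D
W_V_full = W_UV @ W_DKV               # D x D

def mha(X, Y, W_Q, W_K, W_V, W_O):
    out = np.zeros((X.shape[0], D))
    for h in range(H):
        Q = X @ W_Q[h*Dh:(h+1)*Dh, :].T        # (L, Dh) queries for head h
        K = Y @ W_K[h*Dh:(h+1)*Dh, :].T        # (L, Dh) keys
        V = Y @ W_V[h*Dh:(h+1)*Dh, :].T        # (L, Dh) values
        scores = Q @ K.T / np.sqrt(Dh)         # (L, L) scaled dot-product scores
        scores -= scores.max(axis=1, keepdims=True)
        alpha = np.exp(scores)
        alpha /= alpha.sum(axis=1, keepdims=True)
        out += alpha @ V @ W_O[:, h*Dh:(h+1)*Dh].T   # per-head contribution after W^O_h
    return out

out_vanilla = mha(X, Y, W_Q, W_K_full, W_V_full, W_O)
```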

Part 12 (10 points, non-coding task)

In this part, you are asked to study an alternative approach to compute MLA.

  1. Find a head-independent reduced key-projection matrix \color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}} \in \Bbb R^{r \times D} and a reduced query-projection matrix \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}} \in \Bbb R^{H \cdot r \times D}, such that

    • The reduced key at position l_2 in the attended sequence is shared across all heads (head-independent) and is given by:
    \mathbf{\hat k}_{l_2} = \color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}} \mathbf{y}_{l_2} \in \Bbb R^r
    • The reduced query at position l_1 for head h in the attending sequence is given by:
    \mathbf{\hat q}_{l_1, h} = \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_h} \mathbf{x}_{l_1} \in \Bbb R^r

    \quad \quad where

    \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}} = \begin{bmatrix} \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_0} \\ \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_1 }\\ \vdots \\ \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_{H-1}} \end{bmatrix}
    • The attention score (query-key similarity) is invariant between the original and the reduced forms. That is,
    \frac{\mathbf{q}_{l_1,h}^\top \mathbf{k}_{l_2,h}}{\sqrt{D/H}} = \frac{\mathbf{\hat q}_{l_1,h}^\top \mathbf{\hat k}_{l_2}}{\sqrt{r}} . \quad (1)
  2. Find a head-independent reduced value-projection matrix \color{blue}{\hat {\mathbf W}^{\mathbf{V}, MLA}} \in \Bbb R^{r \times D} and a reduced out-projection matrix \color{blue}{\hat {\mathbf W}^{O, MLA}} \in \Bbb R^{D \times H \cdot r}, such that

    • The reduced value at position l_2 in the attended sequence is shared across all heads (head-independent) and is given by:
    \mathbf{\hat v}_{l_2} = {\color{blue}{\hat {\mathbf W}^{\mathbf{V}, MLA}}} \mathbf{y}_{l_2} \in \Bbb R^r
    • The post-out-projection output is invariant between the original and the reduced forms.

      Let

    \color{blue}{\hat {\mathbf W}^{O, MLA}} = \begin{bmatrix} \color{blue}{\hat {\mathbf W}^{O, MLA}_0} & \color{blue}{\hat {\mathbf W}^{O, MLA}_1} & \cdots & \color{blue}{\hat {\mathbf W}^{O, MLA}_{H-1}} \end{bmatrix}

    \quad \quad Then we must have

    \sum_{h=0}^{H-1} \mathbf W^O_h \sum_{l_2 = 0}^{L_2 - 1} \alpha_{h, l_1 l_2} \mathbf{v}_{l_2,h} = \sum_{h=0}^{H-1} {\color{blue}{\hat {\mathbf W}^{O, MLA}_h}} \sum_{l_2 = 0}^{L_2 - 1} \alpha_{h, l_1 l_2} \mathbf{\hat v}_{l_2} . \quad (2)
  3. Your answers for \color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}}, \color{blue}{\hat {\mathbf W}^{\mathbf{V}, MLA}}, \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}}, and \color{blue}{\hat {\mathbf W}^{O, MLA}} should be written in terms of \mathbf{W}^{\mathbf{DKV}}, \mathbf{W}^{\mathbf{UK}}, \mathbf{W}^{\mathbf{UV}}, \mathbf{W}^{\mathbf{Q}}, and \mathbf{W}^O.

\color{green}{\text{### WRITE YOUR SOLUTION HERE ###}}

First, we study Equation (1).

For the LHS in (1), we have

\begin{align*} \frac{\mathbf{q}_{l_1,h}^\top \mathbf{k}_{l_2,h}}{\sqrt{D/H}} & = \frac{1}{\sqrt{D/H}} \left( {\mathbf W}^{\mathbf{Q}}_h \mathbf{x}_{l_1} \right)^\top \left( {\mathbf W}^{\mathbf{UK}}_h {\mathbf W}^{\mathbf{DKV}} \mathbf{y}_{l_2} \right) \\ & = \frac{1}{\sqrt{D/H}} \mathbf{x}_{l_1}^\top {\mathbf W}^{\mathbf{Q}, \top}_h {\mathbf W}^{\mathbf{UK}}_h {\mathbf W}^{\mathbf{DKV}} \mathbf{y}_{l_2} \quad (1.1) \end{align*}

For the RHS in (1), we have

\begin{align*} \frac{\mathbf{\hat q}_{l_1,h}^\top \mathbf{\hat k}_{l_2}}{\sqrt{r}} & = \frac{1}{\sqrt{r}} \left( {\color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_h} }\mathbf{x}_{l_1} \right)^\top \left( {\color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}}} \mathbf{y}_{l_2} \right) \\ & = \frac{1}{\sqrt{r}} \mathbf{x}_{l_1}^\top {\color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA, \top}_h} \color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}}} \mathbf{y}_{l_2} \quad (1.2) \end{align*}

By equating (1.1) and (1.2), we can set

{\color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}}} = \boxed{\mathbf{W}^{\mathbf{DKV}} }

and

{\color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_h}} = \frac{\sqrt{r}}{\sqrt{D/H}} \mathbf{W}^{\mathbf{UK}, \top}_h \mathbf{W}^{\mathbf{Q}}_h

Therefore,

\begin{align*} \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}} & = \begin{bmatrix} \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_0} \\ \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_1} \\ \vdots \\ \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_{H-1}} \end{bmatrix} \\ & = \boxed{\frac{\sqrt{r}}{\sqrt{D/H}} \begin{bmatrix} \mathbf{W}^{\mathbf{UK}, \top}_0 \mathbf{W}^{\mathbf{Q}}_0 \\ \mathbf{W}^{\mathbf{UK}, \top}_1 \mathbf{W}^{\mathbf{Q}}_1 \\ \vdots \\ \mathbf{W}^{\mathbf{UK}, \top}_{H-1} \mathbf{W}^{\mathbf{Q}}_{H-1} \end{bmatrix}} \end{align*}
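As a quick numerical sanity check (the task itself is non-coding), the snippet below continues the NumPy sketch from the vanilla-approach section and verifies Equation (1) with these boxed choices. The head index h and positions l1, l2 are chosen arbitrarily.

```python
# Continues the NumPy sketch above (reuses D, H, r, Dh, W_Q, W_UK, W_DKV, X, Y).
h, l1, l2 = 1, 2, 3                        # arbitrary head and positions

W_Q_h  = W_Q[h*Dh:(h+1)*Dh, :]             # (Dh, D)
W_UK_h = W_UK[h*Dh:(h+1)*Dh, :]            # (Dh, r)

# Original score: q_{l1,h}^T k_{l2,h} / sqrt(D/H).
q   = W_Q_h @ X[l1]
k   = W_UK_h @ W_DKV @ Y[l2]
lhs = q @ k / np.sqrt(Dh)

# Reduced score with hat{W}^K = W^DKV and hat{W}^Q_h = sqrt(r/(D/H)) * W_UK_h^T W_Q_h.
W_hatK   = W_DKV                                # (r, D)
W_hatQ_h = np.sqrt(r / Dh) * W_UK_h.T @ W_Q_h   # (r, D)
q_hat = W_hatQ_h @ X[l1]
k_hat = W_hatK @ Y[l2]
rhs   = q_hat @ k_hat / np.sqrt(r)

assert np.allclose(lhs, rhs)               # Equation (1) holds
```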

Second, we study Equation (2).

For the LHS in (2), we have

\begin{align*} \sum_{h=0}^{H-1} \mathbf W^O_h \sum_{l_2 = 0}^{L_2 - 1} \alpha_{h, l_1 l_2} \mathbf{v}_{l_2,h} & = \sum_{h=0}^{H-1} \sum_{l_2 = 0}^{L_2 - 1} \alpha_{h, l_1 l_2} \mathbf W^O_h {\mathbf W}^{\mathbf{UV}}_h {\mathbf W}^{\mathbf{DKV}} \mathbf{y}_{l_2} \quad (2.1) \end{align*}

For the RHS in (2), we have

\begin{align*} \sum_{h=0}^{H-1} {\color{blue}{\hat {\mathbf W}^{O, MLA}_h}} \sum_{l_2 = 0}^{L_2 - 1} \alpha_{h, l_1 l_2} \mathbf{\hat v}_{l_2} & = \sum_{h=0}^{H-1} \sum_{l_2 = 0}^{L_2 - 1} \alpha_{h, l_1 l_2} {\color{blue}{\hat {\mathbf W}^{O, MLA}_h} \color{blue}{\hat {\mathbf W}^{\mathbf{V}, MLA}}} \mathbf{y}_{l_2} \quad (2.2) \end{align*}

By equating (2.1) and (2.2), we can set

{\color{blue}{\hat {\mathbf W}^{\mathbf{V}, MLA}}} = \boxed{\mathbf{W}^{\mathbf{DKV}} }

and

{\color{blue}{\hat {\mathbf W}^{O, MLA}_h}} = \mathbf{W}^O_h \mathbf{W}^{\mathbf{UV}}_h

Therefore,

\begin{align*} \color{blue}{\hat {\mathbf W}^{O, MLA}} & = \begin{bmatrix} \color{blue}{\hat {\mathbf W}^{O, MLA}_0} & \color{blue}{\hat {\mathbf W}^{O, MLA}_1} & \cdots & \color{blue}{\hat {\mathbf W}^{O, MLA}_{H-1}} \end{bmatrix} \\ & = \boxed{\begin{bmatrix} \mathbf{W}^O_0 \mathbf{W}^{\mathbf{UV}}_0 & \mathbf{W}^O_1 \mathbf{W}^{\mathbf{UV}}_1 & \cdots & \mathbf{W}^O_{H-1} \mathbf{W}^{\mathbf{UV}}_{H-1} \end{bmatrix}} \end{align*}
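Continuing the same illustrative sketch, one can also confirm Equation (2) end to end: with the boxed reduced projections, attention computed entirely in the r-dimensional latent space reproduces the vanilla MHA output (out_vanilla from the first snippet). This is only an optional numerical check under the assumed shapes, not part of the required answer.

```python
# Continues the NumPy sketch above (reuses D, H, r, Dh, L, the projection matrices,
# X, Y, and out_vanilla from the vanilla-approach snippet).
W_hatV = W_DKV                                            # (r, D)
W_hatO = np.concatenate(
    [W_O[:, h*Dh:(h+1)*Dh] @ W_UV[h*Dh:(h+1)*Dh, :] for h in range(H)],
    axis=1)                                               # (D, H*r)

out_reduced = np.zeros((L, D))
for h in range(H):
    W_hatQ_h = np.sqrt(r / Dh) * W_UK[h*Dh:(h+1)*Dh, :].T @ W_Q[h*Dh:(h+1)*Dh, :]
    Q_hat = X @ W_hatQ_h.T                                # (L, r) reduced queries
    K_hat = Y @ W_DKV.T                                   # (L, r) shared reduced keys
    V_hat = Y @ W_hatV.T                                  # (L, r) shared reduced values
    scores = Q_hat @ K_hat.T / np.sqrt(r)                 # same scores as the vanilla MHA
    scores -= scores.max(axis=1, keepdims=True)
    alpha = np.exp(scores)
    alpha /= alpha.sum(axis=1, keepdims=True)
    out_reduced += alpha @ V_hat @ W_hatO[:, h*r:(h+1)*r].T   # hat{W}^O_h acts on latent values

assert np.allclose(out_vanilla, out_reduced)              # outputs match, so Equation (2) holds
```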

\color{red}{\text{""" END OF THIS PART """}}