MLA not only has the advantage of being more general than MHA and GQA; it is also computationally more efficient.
An intuitive approach to computing MLA is as follows:
- Compute the key-projection matrix \mathbf{W}^{\mathbf{UK}, MLA} \mathbf{W}^{\mathbf{DKV}, MLA} \in \Bbb R^{D \times D} and the value-projection matrix \mathbf{W}^{\mathbf{UV}, MLA} \mathbf{W}^{\mathbf{DKV}, MLA} \in \Bbb R^{D \times D}.
- Follow the standard steps in MHA.
This approach is hereafter called the \color{red}{\textbf{vanilla approach}}. It fails to exploit the low-rank structure of \mathbf{W}^{\mathbf{DKV}, MLA}, \mathbf{W}^{\mathbf{UK}, MLA}, and \mathbf{W}^{\mathbf{UV}, MLA}.
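For concreteness, here is a minimal numerical sketch of the vanilla approach. All shapes and names below are assumptions made only for illustration (\mathbf{W}^{\mathbf{DKV}, MLA} \in \Bbb R^{r \times D}; \mathbf{W}^{\mathbf{UK}, MLA}, \mathbf{W}^{\mathbf{UV}, MLA} \in \Bbb R^{D \times r}; per-head row/column slices of \mathbf{W}^{\mathbf{Q}} and \mathbf{W}^{O}); the snippet is not part of the assignment.

```python
import numpy as np

# A minimal sketch of the vanilla approach; all shapes are assumptions
# (W^{DKV, MLA}: (r, D); W^{UK, MLA} and W^{UV, MLA}: (D, r);
#  W^{Q} and W^{O}: (D, D), split into H heads of size Dh = D // H).
rng = np.random.default_rng(0)
D, H, r, L1, L2 = 64, 4, 16, 5, 7
Dh = D // H

W_DKV = rng.standard_normal((r, D))
W_UK  = rng.standard_normal((D, r))
W_UV  = rng.standard_normal((D, r))
W_Q   = rng.standard_normal((D, D))
W_O   = rng.standard_normal((D, D))

X = rng.standard_normal((L1, D))   # attending sequence x_0, ..., x_{L1-1}
Y = rng.standard_normal((L2, D))   # attended sequence  y_0, ..., y_{L2-1}

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Step 1: materialize the full D x D key- and value-projection matrices.
W_K = W_UK @ W_DKV   # (D, D); the rank-r factorization is never exploited
W_V = W_UV @ W_DKV   # (D, D)

# Step 2: standard MHA with the materialized projections.
out = np.zeros((L1, D))
for h in range(H):
    q = X @ W_Q[h*Dh:(h+1)*Dh, :].T                   # (L1, Dh)
    k = Y @ W_K[h*Dh:(h+1)*Dh, :].T                   # (L2, Dh)
    v = Y @ W_V[h*Dh:(h+1)*Dh, :].T                   # (L2, Dh)
    scores = softmax(q @ k.T / np.sqrt(Dh), axis=-1)  # (L1, L2)
    out += (scores @ v) @ W_O[:, h*Dh:(h+1)*Dh].T     # accumulate per-head outputs
```

Note that W_K and W_V are stored and applied as full D × D matrices, which is exactly the inefficiency described above: the low-rank factorization is discarded before any computation happens.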
Part 12 (10 points, non-coding task)
In this part, you are asked to study an alternative approach to computing MLA.
- Find a head-independent reduced key-projection matrix \color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}} \in \Bbb R^{r \times D} and a reduced query-projection matrix \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}} \in \Bbb R^{H \cdot r \times D}, such that
  - The reduced key at position l_2 in the attended sequence is head-independent (the same for every head h) and is given by:
    \mathbf{\hat k}_{l_2} = \color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}} \mathbf{y}_{l_2} \in \Bbb R^r
  - The reduced query at position l_1 for head h in the attending sequence is given by:
    \mathbf{\hat q}_{l_1, h} = \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_h} \mathbf{x}_{l_1} \in \Bbb R^r
    where
    \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}} = \begin{bmatrix} \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_0} \\ \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_1} \\ \vdots \\ \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}_{H-1}} \end{bmatrix}
  - The attention score (query-key similarity) is the same in both the original and the reduced forms. That is,
    \frac{\mathbf{q}_{l_1,h}^\top \mathbf{k}_{l_2,h}}{\sqrt{D/H}} = \frac{\mathbf{\hat q}_{l_1,h}^\top \mathbf{\hat k}_{l_2}}{\sqrt{r}} . \quad (1)
- Find a head-independent reduced value-projection matrix \color{blue}{\hat {\mathbf W}^{\mathbf{V}, MLA}} \in \Bbb R^{r \times D} and a reduced out-projection matrix \color{blue}{\hat {\mathbf W}^{O, MLA}} \in \Bbb R^{D \times H \cdot r}, such that
  - The reduced value at position l_2 in the attended sequence is head-independent (the same for every head h) and is given by:
    \mathbf{\hat v}_{l_2} = {\color{blue}{\hat {\mathbf W}^{\mathbf{V}, MLA}}} \mathbf{y}_{l_2} \in \Bbb R^r
  - The post-out-projection output is the same in both the original and the reduced forms. Let
    \color{blue}{\hat {\mathbf W}^{O, MLA}} = \begin{bmatrix} \color{blue}{\hat {\mathbf W}^{O, MLA}_0} & \color{blue}{\hat {\mathbf W}^{O, MLA}_1} & \cdots & \color{blue}{\hat {\mathbf W}^{O, MLA}_{H-1}} \end{bmatrix}
    Then we must have
    \sum_{h=0}^{H-1} \mathbf W^O_h \sum_{l_2 = 0}^{L_2 - 1} \alpha_{h, l_1 l_2} \mathbf{v}_{l_2,h} = \sum_{h=0}^{H-1} {\color{blue}{\hat {\mathbf W}^{O, MLA}_h}} \sum_{l_2 = 0}^{L_2 - 1} \alpha_{h, l_1 l_2} \mathbf{\hat v}_{l_2} . \quad (2)
- Your answers for \color{blue}{\hat {\mathbf W}^{\mathbf{K}, MLA}}, \color{blue}{\hat {\mathbf W}^{\mathbf{V}, MLA}}, \color{blue}{\hat {\mathbf W}^{\mathbf{Q}, MLA}}, and \color{blue}{\hat {\mathbf W}^{O, MLA}} should be written in terms of \mathbf{W}^{\mathbf{DKV}}, \mathbf{W}^{\mathbf{UK}}, \mathbf{W}^{\mathbf{UV}}, \mathbf{W}^{\mathbf{Q}}, and \mathbf{W}^O.
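Although this part is a non-coding task, you may find it useful to sanity-check a candidate answer numerically. The sketch below only tests conditions (1) and (2); it reuses the hypothetical shapes from the earlier snippet, and the matrices W_hat_K, W_hat_Q, W_hat_V, and W_hat_O are placeholders to be replaced by your own answer (with the zero placeholders, the check prints False).

```python
import numpy as np

# Illustrative check of conditions (1) and (2); all shapes are assumptions
# (same hypothetical layout as the earlier vanilla-approach sketch).
rng = np.random.default_rng(0)
D, H, r, L2 = 64, 4, 16, 7
Dh = D // H

W_DKV = rng.standard_normal((r, D))   # W^{DKV, MLA}
W_UK  = rng.standard_normal((D, r))   # W^{UK, MLA}; head h = rows h*Dh:(h+1)*Dh
W_UV  = rng.standard_normal((D, r))   # W^{UV, MLA}; head h = rows h*Dh:(h+1)*Dh
W_Q   = rng.standard_normal((D, D))   # W^{Q};       head h = rows h*Dh:(h+1)*Dh
W_O   = rng.standard_normal((D, D))   # W^{O};       head h = columns h*Dh:(h+1)*Dh

# ---- placeholders: replace with your own answer ----
W_hat_K = np.zeros((r, D))            # candidate \hat W^{K, MLA}
W_hat_Q = np.zeros((H * r, D))        # candidate \hat W^{Q, MLA}; head h = rows h*r:(h+1)*r
W_hat_V = np.zeros((r, D))            # candidate \hat W^{V, MLA}
W_hat_O = np.zeros((D, H * r))        # candidate \hat W^{O, MLA}; head h = columns h*r:(h+1)*r
# ----------------------------------------------------

x = rng.standard_normal(D)            # attending token x_{l_1}
Y = rng.standard_normal((L2, D))      # attended sequence y_0, ..., y_{L2-1}
alpha = rng.random((H, L2))
alpha /= alpha.sum(axis=1, keepdims=True)   # arbitrary attention weights alpha_{h, l_1 l_2}

cond1 = True
lhs2 = np.zeros(D)
rhs2 = np.zeros(D)
for h in range(H):
    q     = W_Q[h*Dh:(h+1)*Dh, :] @ x                 # q_{l_1, h}
    q_hat = W_hat_Q[h*r:(h+1)*r, :] @ x               # \hat q_{l_1, h}
    for l2 in range(L2):
        k     = W_UK[h*Dh:(h+1)*Dh, :] @ (W_DKV @ Y[l2])   # k_{l_2, h}
        k_hat = W_hat_K @ Y[l2]                             # \hat k_{l_2}
        cond1 = cond1 and np.isclose(q @ k / np.sqrt(Dh),
                                     q_hat @ k_hat / np.sqrt(r))  # condition (1)
    V     = (Y @ W_DKV.T) @ W_UV[h*Dh:(h+1)*Dh, :].T   # rows are v_{l_2, h}
    V_hat = Y @ W_hat_V.T                              # rows are \hat v_{l_2}
    lhs2 += W_O[:, h*Dh:(h+1)*Dh] @ (alpha[h] @ V)         # original side of (2)
    rhs2 += W_hat_O[:, h*r:(h+1)*r] @ (alpha[h] @ V_hat)   # reduced side of (2)

print("condition (1):", bool(cond1), " condition (2):", np.allclose(lhs2, rhs2))
```

A correct choice of the four reduced matrices should make both conditions print True for random inputs; the check itself does not reveal what that choice is.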