2025 USA-NA-AIO Round 2, Problem 2, Part 6

Next, let us study a variant of MHA: Group Query Attention (GQA).

Recall that in MHA, the number of heads in queries, keys, and values is the same, H. Thus, query \mathbf{q}_{l_1, h} attends to key \mathbf{k}_{l_2, h} with the same head index h.

In GQA, we relax this constraint by allowing keys and values to have G heads (G \leq H), where G is a factor of H. For instance, if H = 12, then G \in \left\{ 1, 2, 3, 4, 6, 12 \right\}.
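The admissible values of G are exactly the positive divisors of H. A minimal Python sketch that enumerates them (the helper name `valid_group_counts` is hypothetical, chosen for this illustration):

```python
def valid_group_counts(num_heads: int) -> list[int]:
    """Return every G with 1 <= G <= num_heads such that G divides num_heads."""
    return [g for g in range(1, num_heads + 1) if num_heads % g == 0]


print(valid_group_counts(12))  # [1, 2, 3, 4, 6, 12]
```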

In GQA, a query \mathbf{q}_{l_1, \color{red}{h}} with head \color{red}{h} is permitted to attend to the key \mathbf{k}_{l_2, \color{blue}{g}} and to use the value \mathbf{v}_{l_2, \color{blue}{g}} with head \color{blue}{g} in computing its output if

\color{red}{h} \equiv \color{blue}{g} \pmod{G} .

Thus, each head \color{blue}{g} in keys and values is shared by exactly \frac{H}{G} \geq 1 heads in queries, namely \color{red}{h} = g, g + G, g + 2G, \cdots, g + H - G.

As an example, suppose H = 12 and G = 3. Then

  • Head \color{blue}{g} = 0 in keys and values is associated with heads \color{red}{h} = 0, 3, 6, 9 in queries.

  • Head \color{blue}{g} = 1 in keys and values is associated with heads \color{red}{h} = 1, 4, 7, 10 in queries.

  • Head \color{blue}{g} = 2 in keys and values is associated with heads \color{red}{h} = 2, 5, 8, 11 in queries.
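The modulo-G rule can be sketched in a few lines of Python. The helper `query_heads_for_kv_head` is hypothetical; it simply groups query heads by h mod G and reproduces the H = 12, G = 3 example above.

```python
def query_heads_for_kv_head(num_heads: int, num_groups: int) -> dict[int, list[int]]:
    """Map each key/value head g to the query heads h with h % num_groups == g."""
    assert num_heads % num_groups == 0, "G must divide H"
    mapping: dict[int, list[int]] = {g: [] for g in range(num_groups)}
    for h in range(num_heads):
        # Query head h is served by key/value head h mod G.
        mapping[h % num_groups].append(h)
    return mapping


for g, heads in query_heads_for_kv_head(12, 3).items():
    print(g, heads)
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 10]
# 2 [2, 5, 8, 11]
```

Each list has H / G = 4 entries, matching the count stated above.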

Part 6 (5 points, non-coding task)

For \mathbf{M} \in \left\{ \mathbf{K}, \mathbf{V} \right\}, denote the \mathbf{M}-projection matrix in GQA, whose block \mathbf{W}^{\mathbf{M}, GQA}_g is the projection for key/value head g, as

\mathbf{W}^{\mathbf{M}, GQA} = \begin{bmatrix} \mathbf{W}^{\mathbf{M}, GQA}_0 \\ \vdots \\ \mathbf{W}^{\mathbf{M}, GQA}_{G-1} \end{bmatrix}

Now, we concatenate \frac{H}{G} copies of the above matrix along axis 0:

\mathbf{\tilde W}^{\mathbf{M}, GQA} = \begin{bmatrix} \mathbf{W}^{\mathbf{M}, GQA} \\ \mathbf{W}^{\mathbf{M}, GQA} \\ \vdots \\ \mathbf{W}^{\mathbf{M}, GQA} \end{bmatrix}
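Stacking \frac{H}{G} copies places block \mathbf{W}^{\mathbf{M}, GQA}_{h \bmod G} at block position h, so block h of \mathbf{\tilde W}^{\mathbf{M}, GQA} is exactly the key/value projection that query head h uses under the rule h \equiv g \pmod{G}. A minimal sketch with plain Python lists, assuming each block is stored as a list of rows; the helper `tile_projection` and the toy 1 × 2 blocks are hypothetical:

```python
def tile_projection(blocks: list[list[list[int]]], num_heads: int) -> list[list[list[int]]]:
    """Stack H/G copies of [W_0; ...; W_{G-1}] along axis 0, returned as H blocks."""
    num_groups = len(blocks)
    assert num_heads % num_groups == 0, "G must divide H"
    # Outer loop over copies, inner loop over g: block order is 0, 1, ..., G-1, 0, 1, ...
    return [blocks[g] for _ in range(num_heads // num_groups) for g in range(num_groups)]


# Hypothetical toy blocks: G = 3 blocks of shape 1 x 2, tiled for H = 12 query heads.
W_blocks = [[[1, 0]], [[0, 1]], [[1, 1]]]
W_tilde = tile_projection(W_blocks, num_heads=12)

# Block h of the tiled matrix is block (h mod G) of the original.
assert all(W_tilde[h] == W_blocks[h % 3] for h in range(12))
print(len(W_tilde))  # 12
```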

What is the relationship between \text{rank} \left( \mathbf{\tilde W}^{\mathbf{M}, GQA} \right) and \text{rank} \left( \mathbf{W}^{\mathbf{M}, GQA} \right)?

  • Reasoning is required.

\color{green}{\text{### WRITE YOUR SOLUTION HERE ###}}

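Let r = \text{rank} \left( \mathbf{W}^{\mathbf{M}, GQA} \right), and let \left\{ \mathbf{w}^*_i : i \in \left\{ 0, 1, \cdots , r-1 \right\} \right\} be r linearly independent row vectors that span all row vectors of \mathbf{W}^{\mathbf{M}, GQA}, i.e., a basis of its row space.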

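Every row vector of \mathbf{\tilde W}^{\mathbf{M}, GQA} is a copy of some row vector of \mathbf{W}^{\mathbf{M}, GQA}, and every row vector of \mathbf{W}^{\mathbf{M}, GQA} appears \frac{H}{G} times in \mathbf{\tilde W}^{\mathbf{M}, GQA}. Hence the two matrices have exactly the same set of distinct rows, so their row spaces coincide. In particular, \left\{ \mathbf{w}^*_i : i \in \left\{ 0, 1, \cdots , r-1 \right\} \right\} also spans all row vectors of \mathbf{\tilde W}^{\mathbf{M}, GQA}, and since these r vectors are linearly independent, they form a basis of the row space of \mathbf{\tilde W}^{\mathbf{M}, GQA}.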

Therefore,

\text{rank} \left( \mathbf{\tilde W}^{\mathbf{M}, GQA} \right) = \text{rank} \left( \mathbf{W}^{\mathbf{M}, GQA} \right) .
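As a numerical sanity check (not a substitute for the proof above), the following sketch compares the two ranks with exact rational arithmetic. The helpers `matrix_rank` and `tile_rows` are hypothetical, written for this illustration, and W is a made-up 6 × 4 matrix (G = 3 heads with 2 rows each, H = 12) whose rank is 2 by construction.

```python
from fractions import Fraction


def matrix_rank(rows: list[list[int]]) -> int:
    """Rank via exact Gaussian elimination over the rationals."""
    m = [[Fraction(x) for x in row] for row in rows]
    rank = 0
    num_cols = len(m[0]) if m else 0
    for col in range(num_cols):
        # Find a row at or below position `rank` with a nonzero entry in this column.
        pivot = next((r for r in range(rank, len(m)) if m[r][col] != 0), None)
        if pivot is None:
            continue
        m[rank], m[pivot] = m[pivot], m[rank]
        # Eliminate this column from every row below the pivot row.
        for r in range(rank + 1, len(m)):
            factor = m[r][col] / m[rank][col]
            m[r] = [a - factor * b for a, b in zip(m[r], m[rank])]
        rank += 1
    return rank


def tile_rows(rows: list[list[int]], copies: int) -> list[list[int]]:
    """Concatenate `copies` copies of the matrix along axis 0."""
    return [list(row) for _ in range(copies) for row in rows]


# Hypothetical W with G * d_head = 3 * 2 = 6 rows; rows 2, 3, 5 depend on rows 0 and 1.
W = [
    [1, 2, 0, 1],
    [0, 1, 1, 0],
    [1, 3, 1, 1],  # row 0 + row 1
    [2, 4, 0, 2],  # 2 * row 0
    [0, 0, 0, 0],
    [3, 7, 1, 3],  # 3 * row 0 + row 1
]
H, G = 12, 3
W_tilde = tile_rows(W, H // G)
print(matrix_rank(W), matrix_rank(W_tilde))  # 2 2
```

Exact `Fraction` arithmetic is used instead of a floating-point rank routine so that the comparison cannot be affected by round-off.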

\color{red}{\text{""" END OF THIS PART """}}