2025 USA-NA-AIO Round 2, Problem 2, Part 2

Part 2 (5 points, non-coding task)

For \mathbf{M} \in \left\{ \mathbf{Q}, \mathbf{K}, \mathbf{V} \right\}, we concatenate the \mathbf{M}-projection matrices \left\{ \mathbf{W}^{\mathbf{M}}_h : h \in \left\{ 0, 1, \cdots , H-1 \right\} \right\} along axis 0 as

\mathbf{W}^{\mathbf{M}} = \begin{bmatrix} \mathbf{W}^{\mathbf{M}}_0 \\ \mathbf{W}^{\mathbf{M}}_1 \\ \vdots \\ \mathbf{W}^{\mathbf{M}}_{H-1} \end{bmatrix} .
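As a minimal NumPy sketch of this stacking, assume each per-head query projection \mathbf{W}^{\mathbf{Q}}_h has shape \left( D_{qk}, D_1 \right). The sizes H, D_{qk}, D_1 and the random entries are hypothetical. Concatenating along axis 0 multiplies the number of rows by H and leaves the number of columns unchanged. Head h occupies rows h \cdot D_{qk} through (h+1) \cdot D_{qk} - 1.

```python
import numpy as np

# Hypothetical sizes for illustration only.
H, D_qk, D_1 = 4, 8, 16
rng = np.random.default_rng(0)

# One query-projection matrix per head, each of shape (D_qk, D_1).
W_Q_heads = [rng.standard_normal((D_qk, D_1)) for _ in range(H)]

# Concatenate along axis 0, stacking the row blocks on top of each other.
W_Q = np.concatenate(W_Q_heads, axis=0)
print(W_Q.shape)  # (32, 16), i.e. (H * D_qk, D_1)

# Head h occupies rows h*D_qk ... (h+1)*D_qk - 1.
for h in range(H):
    assert np.array_equal(W_Q[h * D_qk:(h + 1) * D_qk], W_Q_heads[h])
```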

At each position l_1 in an attending sequence, we concatenate queries \left\{ \mathbf{q}_{l_1,h} : h \in \left\{ 0, 1, \cdots , H-1 \right\} \right\} along axis 0 to get

\mathbf{q}_{l_1} = \begin{bmatrix} \mathbf{q}_{l_1,0} \\ \mathbf{q}_{l_1,1} \\ \vdots \\ \mathbf{q}_{l_1,H-1} \end{bmatrix} .
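A similar sketch for the queries follows, assuming each \mathbf{q}_{l_1,h} has shape \left( D_{qk}, \right). The sizes and data are hypothetical. The sketch also shows that a row-major reshape of \mathbf{q}_{l_1} to \left( H, D_{qk} \right) undoes the concatenation. This is one common way to recover the per-head queries in code.

```python
import numpy as np

# Hypothetical sizes for illustration only.
H, D_qk = 4, 8
rng = np.random.default_rng(1)

# Per-head queries at one position l_1, each of shape (D_qk,).
q_heads = [rng.standard_normal(D_qk) for _ in range(H)]

# Concatenate along axis 0 into one vector of shape (H * D_qk,).
q = np.concatenate(q_heads, axis=0)
print(q.shape)  # (32,)

# A row-major reshape to (H, D_qk) undoes the concatenation: row h is head h's query.
q_split = q.reshape(H, D_qk)
assert all(np.array_equal(q_split[h], q_heads[h]) for h in range(H))
```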

At each position l_2 in an attended sequence, for \mathbf{m} \in \left\{ \mathbf{k}, \mathbf{v} \right\}, we concatenate the keys/values \left\{ \mathbf{m}_{l_2,h} : h \in \left\{ 0, 1, \cdots , H-1 \right\} \right\} along axis 0 to get

\mathbf{m}_{l_2} = \begin{bmatrix} \mathbf{m}_{l_2,0} \\ \mathbf{m}_{l_2,1} \\ \vdots \\ \mathbf{m}_{l_2,H-1} \end{bmatrix} .
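For keys and values, the same stacking is often applied to every position l_2 at once. The sketch below assumes per-head keys of shape \left( D_{qk}, \right) and per-head values of shape \left( D_v, \right). Both are stored in hypothetical arrays indexed by (l_2, h), with hypothetical sizes. For each l_2, merging the last two axes with a row-major reshape produces exactly the concatenation along axis 0 defined above.

```python
import numpy as np

# Hypothetical sizes for illustration only; D_v need not equal D_qk.
H, D_qk, D_v, L_2 = 4, 8, 6, 5
rng = np.random.default_rng(2)

# Per-head keys/values for every position l_2: axis 1 indexes the head h.
k_heads = rng.standard_normal((L_2, H, D_qk))
v_heads = rng.standard_normal((L_2, H, D_v))

# For each l_2, concatenating the H head vectors along axis 0 is the same as
# merging the (H, D) axes into a single axis of length H * D.
k = k_heads.reshape(L_2, H * D_qk)
v = v_heads.reshape(L_2, H * D_v)
print(k.shape, v.shape)  # (5, 32) (5, 24)

# Check against explicit concatenation at one position.
l_2 = 3
assert np.array_equal(k[l_2], np.concatenate([k_heads[l_2, h] for h in range(H)]))
assert np.array_equal(v[l_2], np.concatenate([v_heads[l_2, h] for h in range(H)]))
```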

Do the following tasks (reasoning is not required).

  1. What is the shape of \mathbf{W}^{\mathbf{M}} for \mathbf{M} \in \left\{ \mathbf{Q}, \mathbf{K}, \mathbf{V} \right\}?

  2. What is the shape of \mathbf{q}_{l_1}?

  3. What is the relationship between \mathbf{q}_{l_1} and \mathbf{W}^{\mathbf{Q}}?

  4. For \mathbf{m} \in \left\{ \mathbf{k}, \mathbf{v} \right\}, what is the shape of \mathbf{m}_{l_2}?

  5. What is the relationship between \mathbf{m}_{l_2} and \mathbf{W}^{\mathbf{M}}?

\color{green}{\text{### WRITE YOUR SOLUTION HERE ###}}

  1. The shape of \mathbf{W}^{\mathbf{Q}} is \left( H \cdot D_{qk}, D_1 \right).

    The shape of \mathbf{W}^{\mathbf{K}} is \left( H \cdot D_{qk}, D_2 \right).

    The shape of \mathbf{W}^{\mathbf{V}} is \left( H \cdot D_v, D_2 \right).

  2. The shape of \mathbf{q}_{l_1} is \left( H \cdot D_{qk}, \right).

  3. \mathbf{q}_{l_1} = \mathbf{W}^{\mathbf{Q}} \mathbf{x}_{l_1} .

  4. The shape of \mathbf{k}_{l_2} is \left( H \cdot D_{qk}, \right).

    The shape of \mathbf{v}_{l_2} is \left( H \cdot D_v, \right).

  5. \mathbf{k}_{l_2} = \mathbf{W}^{\mathbf{K}} \mathbf{y}_{l_2} .

    \mathbf{v}_{l_2} = \mathbf{W}^{\mathbf{V}} \mathbf{y}_{l_2} .
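The relationships stated above follow from block-matrix multiplication: row block h of \mathbf{W}^{\mathbf{M}} \mathbf{z} equals \mathbf{W}^{\mathbf{M}}_h \mathbf{z}. The NumPy check below assumes bias-free per-head projections \mathbf{q}_{l_1,h} = \mathbf{W}^{\mathbf{Q}}_h \mathbf{x}_{l_1}, \mathbf{k}_{l_2,h} = \mathbf{W}^{\mathbf{K}}_h \mathbf{y}_{l_2}, and \mathbf{v}_{l_2,h} = \mathbf{W}^{\mathbf{V}}_h \mathbf{y}_{l_2}. The sizes and random data are hypothetical.

```python
import numpy as np

# Hypothetical sizes for illustration only.
H, D_qk, D_v, D_1, D_2 = 4, 8, 6, 16, 12
rng = np.random.default_rng(3)

# Per-head projection matrices (assumed bias-free).
W_Q_heads = [rng.standard_normal((D_qk, D_1)) for _ in range(H)]
W_K_heads = [rng.standard_normal((D_qk, D_2)) for _ in range(H)]
W_V_heads = [rng.standard_normal((D_v, D_2)) for _ in range(H)]

# Concatenated projection matrices (along axis 0).
W_Q = np.concatenate(W_Q_heads, axis=0)
W_K = np.concatenate(W_K_heads, axis=0)
W_V = np.concatenate(W_V_heads, axis=0)

x = rng.standard_normal(D_1)  # x_{l_1}: one attending-sequence input
y = rng.standard_normal(D_2)  # y_{l_2}: one attended-sequence input

# Concatenation of the per-head outputs ...
q = np.concatenate([W @ x for W in W_Q_heads])
k = np.concatenate([W @ y for W in W_K_heads])
v = np.concatenate([W @ y for W in W_V_heads])

# ... equals a single product with the concatenated matrix.
assert np.allclose(q, W_Q @ x) and q.shape == (H * D_qk,)
assert np.allclose(k, W_K @ y) and k.shape == (H * D_qk,)
assert np.allclose(v, W_V @ y) and v.shape == (H * D_v,)
print("block identity holds")
```

Computing all heads with one matrix product, instead of H separate products, is the practical reason to concatenate the per-head projections.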

\color{red}{\text{""" END OF THIS PART """}}