USAAIO
May 14, 2025, 10:32pm
1
Part 2 (5 points, non-coding task)
For \mathbf{M} \in \left\{ \mathbf{Q}, \mathbf{K}, \mathbf{V} \right\} , we concatenate the \mathbf{M} -projection matrices \left\{ \mathbf{W}^{\mathbf{M}}_h : h \in \left\{ 0, 1, \cdots , H-1 \right\} \right\} along axis 0 as
\mathbf{W}^{\mathbf{M}}
= \begin{bmatrix}
\mathbf{W}^{\mathbf{M}}_0 \\
\mathbf{W}^{\mathbf{M}}_1 \\
\vdots \\
\mathbf{W}^{\mathbf{M}}_{H-1}
\end{bmatrix} .
At each position l_1 in an attending sequence, we concatenate queries \left\{ \mathbf{q}_{l_1,h} : h \in \left\{ 0, 1, \cdots , H-1 \right\} \right\} along axis 0 to get
\mathbf{q}_{l_1}
= \begin{bmatrix}
\mathbf{q}_{l_1,0} \\
\mathbf{q}_{l_1,1} \\
\vdots \\
\mathbf{q}_{l_1,H-1}
\end{bmatrix} .
At each position l_2 in the sequence being attended to, for \mathbf{m} \in \left\{ \mathbf{k}, \mathbf{v} \right\} , we concatenate the keys/values \left\{ \mathbf{m}_{l_2,h} : h \in \left\{ 0, 1, \cdots , H-1 \right\} \right\} along axis 0 to get
\mathbf{m}_{l_2}
= \begin{bmatrix}
\mathbf{m}_{l_2,0} \\
\mathbf{m}_{l_2,1} \\
\vdots \\
\mathbf{m}_{l_2,H-1}
\end{bmatrix} .
Do the following tasks (Reasoning is not required).
What is the shape of \mathbf{W}^{\mathbf{M}} for \mathbf{M} \in \left\{ \mathbf{Q}, \mathbf{K}, \mathbf{V} \right\} ?
What is the shape of \mathbf{q}_{l_1} ?
What is the relationship between \mathbf{q}_{l_1} and \mathbf{W}^{\mathbf{Q}} ?
For \mathbf{m} \in \left\{ \mathbf{k}, \mathbf{v} \right\} , what is the shape of \mathbf{m}_{l_2} ?
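What is the relationship between \mathbf{m}_{l_2} and the corresponding \mathbf{W}^{\mathbf{M}} (that is, \mathbf{M} = \mathbf{K} for \mathbf{m} = \mathbf{k} and \mathbf{M} = \mathbf{V} for \mathbf{m} = \mathbf{v} )?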
USAAIO
May 14, 2025, 10:32pm
2
\color{green}{\text{### WRITE YOUR SOLUTION HERE ###}}
The shape of \mathbf{W}^{\mathbf{Q}} is \left( H \cdot D_{qk}, D_1 \right) .
The shape of \mathbf{W}^{\mathbf{K}} is \left( H \cdot D_{qk}, D_2 \right) .
The shape of \mathbf{W}^{\mathbf{V}} is \left( H \cdot D_v, D_2 \right) .
The shape of \mathbf{q}_{l_1} is \left( H \cdot D_{qk}, \right) .
\mathbf{q}_{l_1} = \mathbf{W}^{\mathbf{Q}} \mathbf{x}_{l_1} .
The shape of \mathbf{k}_{l_2} is \left( H \cdot D_{qk}, \right) .
The shape of \mathbf{v}_{l_2} is \left( H \cdot D_v, \right) .
\mathbf{k}_{l_2} = \mathbf{W}^{\mathbf{K}} \mathbf{y}_{l_2} .
\mathbf{v}_{l_2} = \mathbf{W}^{\mathbf{V}} \mathbf{y}_{l_2} .
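Although reasoning is not required, the relationships above follow from block-row multiplication. The sketch below assumes the per-head projection \mathbf{q}_{l_1,h} = \mathbf{W}^{\mathbf{Q}}_h \mathbf{x}_{l_1} with \mathbf{W}^{\mathbf{Q}}_h of shape \left( D_{qk}, D_1 \right) ; that definition comes from the earlier part of the problem, which is not shown in this post, so treat it as an assumption here.

```latex
\mathbf{W}^{\mathbf{Q}} \mathbf{x}_{l_1}
= \begin{bmatrix}
\mathbf{W}^{\mathbf{Q}}_0 \\
\mathbf{W}^{\mathbf{Q}}_1 \\
\vdots \\
\mathbf{W}^{\mathbf{Q}}_{H-1}
\end{bmatrix} \mathbf{x}_{l_1}
= \begin{bmatrix}
\mathbf{W}^{\mathbf{Q}}_0 \mathbf{x}_{l_1} \\
\mathbf{W}^{\mathbf{Q}}_1 \mathbf{x}_{l_1} \\
\vdots \\
\mathbf{W}^{\mathbf{Q}}_{H-1} \mathbf{x}_{l_1}
\end{bmatrix}
= \begin{bmatrix}
\mathbf{q}_{l_1,0} \\
\mathbf{q}_{l_1,1} \\
\vdots \\
\mathbf{q}_{l_1,H-1}
\end{bmatrix}
= \mathbf{q}_{l_1} .
```

Under the analogous assumption for keys and values, the same argument with \mathbf{y}_{l_2} in place of \mathbf{x}_{l_1} gives the identities for \mathbf{k}_{l_2} and \mathbf{v}_{l_2} .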
\color{red}{\text{""" END OF THIS PART """}}
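The shapes and identities in this part can be checked numerically. The following is a minimal sketch in plain Python, not part of the official solution. The sizes, the helper names (`rand_matrix`, `matvec`, `stack_rows`, `shape`), and the per-head shapes \left( D_{qk}, D_1 \right) , \left( D_{qk}, D_2 \right) , \left( D_v, D_2 \right) are assumptions made for illustration. It uses lists instead of NumPy or PyTorch only so that it runs without dependencies.

```python
import random

# Hypothetical small sizes, chosen only for illustration.
H, D_QK, D_V = 3, 4, 5   # number of heads, per-head query/key dim, per-head value dim
D_1, D_2 = 6, 7          # feature dims of the attending / attended sequences

rng = random.Random(0)


def rand_matrix(rows, cols):
    """Return a rows x cols matrix (list of row lists) of random floats."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def stack_rows(blocks):
    """Concatenate per-head matrices along axis 0 by stacking their rows."""
    return [row for block in blocks for row in block]


def shape(W):
    """Shape of a non-empty list-of-rows matrix."""
    return (len(W), len(W[0]))


# Per-head projection matrices (shapes assumed, see the lead-in).
W_Q_heads = [rand_matrix(D_QK, D_1) for _ in range(H)]
W_K_heads = [rand_matrix(D_QK, D_2) for _ in range(H)]
W_V_heads = [rand_matrix(D_V, D_2) for _ in range(H)]

# Concatenation along axis 0, as in the problem statement.
W_Q = stack_rows(W_Q_heads)
W_K = stack_rows(W_K_heads)
W_V = stack_rows(W_V_heads)

x = [rng.uniform(-1.0, 1.0) for _ in range(D_1)]  # x_{l_1}
y = [rng.uniform(-1.0, 1.0) for _ in range(D_2)]  # y_{l_2}

# Left: concatenate the per-head vectors. Right: one product with the stacked matrix.
q_concat = [v for W_h in W_Q_heads for v in matvec(W_h, x)]
k_concat = [v for W_h in W_K_heads for v in matvec(W_h, y)]
v_concat = [v for W_h in W_V_heads for v in matvec(W_h, y)]
q, k, v = matvec(W_Q, x), matvec(W_K, y), matvec(W_V, y)

print("W_Q", shape(W_Q), "W_K", shape(W_K), "W_V", shape(W_V))
print("q", len(q), "k", len(k), "v", len(v))
print("identities hold:", q == q_concat and k == k_concat and v == v_concat)
```

The comparison uses exact equality because both sides compute each entry from the same row with the same sequence of operations. With array libraries, a tolerance-based comparison is the safer choice.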