Part 10 (5 points, coding task)
This question follows Part 9.
You are asked to define a function called GQA_2_MLA that performs the following tasks:
- Input:
  - W_M_GQA: a NumPy array with shape (r, D), where r is guaranteed to be a factor of D (not something you need to worry about).
- Outputs:
  - W_DKV_MLA: a NumPy array with shape (r, D).
  - W_UM_MLA: a NumPy array with shape (D, r).
- Things to do inside this function:
  - Compute W_M_GQA_tilde, which concatenates D/r copies of W_M_GQA along axis 0 (so it has shape (D, D)).
  - Print the shapes of W_UM_MLA and W_DKV_MLA.
  - Print the mean-squared error between W_M_GQA_tilde and W_UM_MLA @ W_DKV_MLA.
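A short derivation (an aside, not part of the required output) shows what MSE to expect. Writing W for W_M_GQA and k = D/r, the tiled matrix is the product of a (D, r) stack of identity matrices and W, so its rank is at most r and a (D, r) times (r, D) factorization can reproduce it exactly:

```latex
\tilde{W} =
\begin{bmatrix} W \\ W \\ \vdots \\ W \end{bmatrix}
= \underbrace{\begin{bmatrix} I_r \\ I_r \\ \vdots \\ I_r \end{bmatrix}}_{D \times r}
  \underbrace{W}_{r \times D}
= (\mathbf{1}_k \otimes I_r)\, W,
\qquad
\operatorname{rank}(\tilde{W}) \le r .
```

Consequently the printed MSE should sit at the level of floating-point round-off rather than reflect a genuine approximation error, for example when the factors come from a rank-r truncated SVD (np.linalg.svd), which is exact for a matrix of rank at most r.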
- Hints:
  - You may use np.linalg.
  - PyTorch is not allowed.
  - No loops in your code.
After defining this function, test it with the input np.random.randn(4,24).
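A minimal sketch of one possible solution follows. It assumes the intended factorization is a rank-r truncated SVD (suggested by the np.linalg hint; Part 9 is not shown here, so this is an assumption), and it assumes the function returns (W_DKV_MLA, W_UM_MLA) in the order the outputs are listed. np.tile builds W_M_GQA_tilde without a Python loop.

```python
import numpy as np


def GQA_2_MLA(W_M_GQA):
    """Factor the tiled GQA matrix into MLA-style down/up projections.

    Assumption for this sketch: the factors come from a rank-r truncated SVD,
    and the return order (W_DKV_MLA, W_UM_MLA) follows the listed outputs.
    """
    r, D = W_M_GQA.shape

    # Concatenate D/r copies of W_M_GQA along axis 0 (no loop): shape (D, D).
    W_M_GQA_tilde = np.tile(W_M_GQA, (D // r, 1))

    # Thin SVD of the square (D, D) matrix: U (D, D), S (D,), Vt (D, D).
    U, S, Vt = np.linalg.svd(W_M_GQA_tilde, full_matrices=False)

    # Keep the top-r singular triplets; fold singular values into the up-projection.
    W_UM_MLA = U[:, :r] * S[:r]  # broadcasting scales column j by S[j]: (D, r)
    W_DKV_MLA = Vt[:r, :]        # (r, D)

    print("W_UM_MLA shape:", W_UM_MLA.shape)
    print("W_DKV_MLA shape:", W_DKV_MLA.shape)

    # Mean-squared error of the rank-r reconstruction.
    mse = np.mean((W_M_GQA_tilde - W_UM_MLA @ W_DKV_MLA) ** 2)
    print("MSE:", mse)

    return W_DKV_MLA, W_UM_MLA


W_DKV_MLA, W_UM_MLA = GQA_2_MLA(np.random.randn(4, 24))
```

With this input r = 4 and D = 24, so the printed shapes should be (24, 4) for W_UM_MLA and (4, 24) for W_DKV_MLA, and the MSE should be round-off small because W_M_GQA_tilde has rank at most r. An SVD-free exact alternative is W_DKV_MLA = W_M_GQA with W_UM_MLA = np.tile(np.eye(r), (D // r, 1)); the SVD route is shown only because the hint points to np.linalg.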