Part 10 (5 points, coding task)
This question follows Part 9.
You are asked to define a function called GQA_2_MLA that performs the following tasks:
- Input:
  - W_M_GQA: A numpy array with shape (r, D), where r is guaranteed to be a factor of D (not something you need to worry about).
- Outputs:
  - W_DKV_MLA: A numpy array with shape (r, D).
  - W_UM_MLA: A numpy array with shape (D, r).
- Things to do inside this function:
  - Compute W_M_GQA_tilde, which concatenates D/r copies of W_M_GQA along axis 0.
  - Print the shapes of W_UM_MLA and W_DKV_MLA.
  - Print the mean-squared error between W_M_GQA_tilde and W_UM_MLA @ W_DKV_MLA.
- Hints:
  - You may use np.linalg.
  - PyTorch is not allowed.
  - No loop in your code.
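The stacking step can be illustrated with a minimal sketch. It assumes np.tile is an acceptable loop-free way to concatenate the copies, and it uses a seeded generator purely for reproducibility; neither choice comes from the task statement. Because every block of r rows repeats W_M_GQA, the stacked (D, D) matrix has rank at most r, which is why an exact factorization into a (D, r) and an (r, D) matrix can exist.

```python
import numpy as np

# Illustrative sizes matching the suggested test input (r = 4, D = 24).
rng = np.random.default_rng(0)  # seeding is an assumption, for reproducibility only
r, D = 4, 24
W_M_GQA = rng.standard_normal((r, D))

# Stack D // r copies along axis 0 without a Python loop.
W_M_GQA_tilde = np.tile(W_M_GQA, (D // r, 1))

print(W_M_GQA_tilde.shape)                   # (24, 24)
print(np.linalg.matrix_rank(W_M_GQA_tilde))  # at most r, since the rows repeat
```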
After defining this function, test it with the input np.random.randn(4,24).
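One possible approach, offered as a sketch rather than the intended solution, is a truncated SVD from np.linalg. The choices below are assumptions not fixed by the task statement: the singular values are absorbed into W_UM_MLA, and the return order follows the order in which the outputs are listed. Part 9 may prescribe a different split.

```python
import numpy as np


def GQA_2_MLA(W_M_GQA):
    """Sketch: factor the stacked GQA matrix into (D, r) @ (r, D) via SVD."""
    r, D = W_M_GQA.shape

    # Concatenate D // r copies along axis 0; list repetition avoids a loop.
    W_M_GQA_tilde = np.concatenate([W_M_GQA] * (D // r), axis=0)  # (D, D)

    # The stacked matrix has rank at most r, so keeping the top r
    # singular triplets reconstructs it up to floating-point error.
    U, S, Vt = np.linalg.svd(W_M_GQA_tilde, full_matrices=False)

    # Assumption: singular values are folded into the up-projection.
    W_UM_MLA = U[:, :r] * S[:r]   # (D, r)
    W_DKV_MLA = Vt[:r, :]         # (r, D)

    print("W_UM_MLA shape:", W_UM_MLA.shape)
    print("W_DKV_MLA shape:", W_DKV_MLA.shape)

    mse = np.mean((W_M_GQA_tilde - W_UM_MLA @ W_DKV_MLA) ** 2)
    print("MSE:", mse)

    # Assumption: outputs are returned in the order they are listed.
    return W_DKV_MLA, W_UM_MLA


W_DKV_MLA, W_UM_MLA = GQA_2_MLA(np.random.randn(4, 24))
```

With this split, the printed mean-squared error should be close to machine precision, since no nonzero singular value is discarded.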