2025 USA-NA-AIO Round 2, Problem 2, Part 10

Part 10 (5 points, coding task)

This question follows Part 9.

You are asked to define a function called GQA_2_MLA that performs the following tasks:

  • Input:

    • W_M_GQA: A numpy array with shape (r,D), where r is guaranteed to be a factor of D (not something you need to worry about).
  • Outputs:

    • W_DKV_MLA: A numpy array with shape (r,D).

    • W_UM_MLA: A numpy array with shape (D,r).

  • Things to do inside this function:

    • Compute W_M_GQA_tilde by concatenating D/r copies of W_M_GQA along axis 0, which gives an array with shape (D,D).

    • Print the shapes of W_UM_MLA and W_DKV_MLA.

    • Print the mean-squared error between W_M_GQA_tilde and W_UM_MLA @ W_DKV_MLA.
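
The following is a minimal sketch, not part of the required function, of why a pair of factors with shapes (D,r) and (r,D) can reproduce the stacked matrix: every row of the stacked matrix is a copy of a row of the original, so its rank is at most r. The seed and the names W, W_tilde, rank, and singular_values are illustrative assumptions.

```python
import numpy as np

# Illustrative sizes matching the test input of this part.
rng = np.random.default_rng(0)
r, D = 4, 24
W = rng.standard_normal((r, D))

# Stack D/r copies along axis 0 without an explicit loop -> shape (D, D).
W_tilde = np.concatenate([W] * (D // r), axis=0)

# Every row repeats a row of W, so the rank cannot exceed r.
rank = np.linalg.matrix_rank(W_tilde)
singular_values = np.linalg.svd(W_tilde, compute_uv=False)

print(W_tilde.shape)
print(rank)
print(singular_values[r:].max())  # trailing singular values are numerically zero
```

Because only the first r singular values are non-negligible, a rank-r factorization should leave a mean-squared error close to machine precision.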

Hints:

  • You may use np.linalg.

  • PyTorch is not allowed.

  • Do not use loops in your code.
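
As one option among the np.linalg routines, a singular value decomposition can produce a low-rank factorization. The sketch below only illustrates the return convention of np.linalg.svd on a small hypothetical matrix A: the third return value is already the transposed right factor, and the singular values come back as a 1-D array.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 3))  # hypothetical example matrix

# Reduced SVD: U is (6, 3), S is (3,), Vh is (3, 3), with A = U @ diag(S) @ Vh.
U, S, Vh = np.linalg.svd(A, full_matrices=False)

# Scale the rows of Vh by the singular values via broadcasting, then recombine.
A_rebuilt = U @ (S[:, None] * Vh)

print(U.shape, S.shape, Vh.shape)
print(np.allclose(A, A_rebuilt))
```

Folding the singular values into one of the two factors, as in `S[:, None] * Vh`, keeps the product unchanged while leaving exactly two matrices.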

After defining this function, test it with the input np.random.randn(4,24).

### WRITE YOUR SOLUTION HERE ###

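import numpy as np

def GQA_2_MLA(W_M_GQA):
    # r: latent dimension (rows); D: model dimension (columns)
    r, D = W_M_GQA.shape
    num_copies = D // r

    # Stack D/r copies along axis 0 -> shape (D, D), rank at most r
    W_M_GQA_tilde = np.concatenate([W_M_GQA] * num_copies, axis=0)

    # np.linalg.svd returns the right factor already transposed (Vh)
    U, S, Vh = np.linalg.svd(W_M_GQA_tilde)

    # Keep the top-r components; fold the singular values into the down-projection
    W_UM_MLA = U[:, :r]                    # shape (D, r)
    W_DKV_MLA = S[:r, None] * Vh[:r, :]    # shape (r, D)
    print(f"Shape of W_UM_MLA: {W_UM_MLA.shape}")
    print(f"Shape of W_DKV_MLA: {W_DKV_MLA.shape}")

    # Reconstruction error of the rank-r factorization
    MSE = np.mean((W_M_GQA_tilde - W_UM_MLA @ W_DKV_MLA) ** 2)
    print(f"Mean-squared error: {MSE}")

    return W_DKV_MLA, W_UM_MLA

GQA_2_MLA(np.random.randn(4, 24))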

""" END OF THIS PART """