Part 13 (5 points, coding task)
Do the following tasks:
1. Define a function called `reduced_matrices`.
* Input arguments: `W_DKV, W_UK, W_UV, W_Q, W_O, H`
* Outputs: `W_K_MLA_hat, W_V_MLA_hat, W_Q_MLA_hat, W_O_MLA_hat`
* Requirements for your code:
    * The code that computes each output must fit on one line.
    * Loops are not allowed.
2. Set your device to the GPU if one is available:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3. Construct the following synthetic data:
D = 1024
H = 32
D_qkv = D // H
r = 50
W_DKV = torch.randn(r, D)
W_UK = torch.randn(D, r)
W_UV = torch.randn(D, r)
W_Q = torch.randn(D, D)
W_O = torch.randn(D, D)
B = 32
L_1 = 100
L_2 = 300
x = torch.randn(B, L_1, D).to(device)
y = torch.randn(B, L_2, D).to(device)
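
As a sanity check on the shapes above, here is a minimal sketch of the identity that a reduced model of this kind can rely on, shown for a single head only. It assumes the `nn.Linear` convention (weights of shape `(out_features, in_features)`) and that head `h` uses rows `h*D_qkv:(h+1)*D_qkv` of `W_Q` and `W_UK`. Neither assumption is confirmed by this section, and the sketch deliberately ignores how `MyMHA` lays out heads or scales attention scores. The sizes are smaller, hypothetical ones chosen to keep the check fast.

```python
import torch

torch.manual_seed(0)

# Smaller, hypothetical sizes chosen only to keep the check fast.
D, H, r = 64, 4, 10
D_qkv = D // H

W_DKV = torch.randn(r, D, dtype=torch.float64)  # down-projection: D -> r
W_UK = torch.randn(D, r, dtype=torch.float64)   # key up-projection: r -> D
W_Q = torch.randn(D, D, dtype=torch.float64)

x = torch.randn(D, dtype=torch.float64)  # one query token
y = torch.randn(D, dtype=torch.float64)  # one key token

h = 0  # head index (assumed row-block layout)
rows = slice(h * D_qkv, (h + 1) * D_qkv)
W_Q_h = W_Q[rows]    # (D_qkv, D)
W_UK_h = W_UK[rows]  # (D_qkv, r)

# Vanilla score: query and key both live in the D_qkv-dimensional head space.
score_vanilla = (W_Q_h @ x) @ (W_UK_h @ (W_DKV @ y))

# Absorbed score: fold W_UK_h into the query side, so the key is just the
# r-dimensional latent W_DKV @ y.
W_Q_h_hat = W_UK_h.T @ W_Q_h  # (r, D)
score_reduced = (W_Q_h_hat @ x) @ (W_DKV @ y)

# The full key matrix has shape (D, D) but rank at most r.
W_K_full = W_UK @ W_DKV
rank_K = torch.linalg.matrix_rank(W_K_full).item()

print(torch.allclose(score_vanilla, score_reduced))
print(W_K_full.shape, rank_K)
```

The two scores agree because both equal `x^T W_Q_h^T W_UK_h W_DKV y`; only the grouping of the matrix products differs. The same regrouping applies on the value/output side with `W_UV` and `W_O`.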
4. Study a vanilla attention model
* Initialize the model
```
model_MHA_vanilla = MyMHA(D, D, D_qkv, D_qkv, H)
```
* Update model parameters
* `model_MHA_vanilla.W_K.weight, model_MHA_vanilla.W_V.weight, model_MHA_vanilla.W_Q.weight, model_MHA_vanilla.W_O.weight`
* Compute the output
```
output_vanilla = model_MHA_vanilla(x, y)
```
5. Study a reduced attention model
* Initialize the model
```
model_MHA_reduced = MyMHA(D, D, r, r, H)
```
* Update model parameters
* `model_MHA_reduced.W_K.weight, model_MHA_reduced.W_V.weight, model_MHA_reduced.W_Q.weight, model_MHA_reduced.W_O.weight`
* Compute the output
```
output_reduced = model_MHA_reduced(x, y)
```
6. Check the correctness of the reduced model by computing and printing a relative error:
`relative_error = mse_output**.5 / torch.mean(output_vanilla**2)**.5`
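
The formula above uses `mse_output`, which this section does not define. The following is a minimal sketch, assuming `mse_output` is the mean squared difference between the two outputs. It uses small stand-in tensors (hypothetical) in place of the real model outputs.

```python
import torch

torch.manual_seed(0)

# Stand-ins for the two model outputs (hypothetical; the real ones come from
# model_MHA_vanilla(x, y) and model_MHA_reduced(x, y)).
output_vanilla = torch.randn(4, 10, 16)
output_reduced = output_vanilla + 1e-4 * torch.randn(4, 10, 16)

# Assumption: mse_output is the mean squared difference of the two outputs.
mse_output = torch.mean((output_vanilla - output_reduced) ** 2)

# Root-mean-square error divided by the root-mean-square of the reference.
relative_error = mse_output**.5 / torch.mean(output_vanilla**2)**.5
print(f"relative error: {relative_error.item():.2e}")
```

Normalizing by the root-mean-square of `output_vanilla` makes the error independent of the overall output scale, so a small value indicates agreement up to floating-point effects.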