File size: 361 Bytes
9f13819
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torch import nn

class MlpProjector(nn.Module):
    """Two-layer MLP projector: Linear -> GELU -> Linear.

    Maps features whose last dimension is ``rec_size`` to features whose last
    dimension is ``llm_size`` (presumably an LLM hidden/embedding size, judging
    by the name -- confirm against callers).

    Args:
        rec_size: Size of the input's last dimension. Defaults to 64.
        llm_size: Size of both the hidden layer and the output. Defaults to 4096.
    """

    def __init__(self, rec_size: int = 64, llm_size: int = 4096):
        super().__init__()
        # Construction order is significant: it fixes the order in which the
        # Linear weights are drawn from the RNG, and the Sequential position
        # of each layer fixes the state_dict keys (mlp_proj.0.*, mlp_proj.2.*).
        in_layer = nn.Linear(rec_size, llm_size)
        activation = nn.GELU()
        out_layer = nn.Linear(llm_size, llm_size)
        self.mlp_proj = nn.Sequential(in_layer, activation, out_layer)

    def forward(self, x):
        """Return ``x`` passed through the projector MLP.

        nn.Linear acts on the last dimension, so ``x.shape[-1]`` must equal
        ``rec_size``; the output's last dimension is ``llm_size``.
        """
        return self.mlp_proj(x)