iLoRA / model /mlp_projector.py
MingLi
Forked from https://github.com/AkaliKong/iLoRA, with bug fixes.
9f13819
raw
history blame contribute delete
361 Bytes
from torch import nn
class MlpProjector(nn.Module):
    """Two-layer MLP mapping ``rec_size``-dim features to ``llm_size``-dim features.

    Architecture: ``Linear(rec_size, llm_size) -> GELU -> Linear(llm_size, llm_size)``.
    Judging only by the parameter names, this presumably projects recommender
    embeddings into an LLM's hidden space (TODO confirm against callers).

    Args:
        rec_size: Size of the last dimension of the input. Defaults to 64.
        llm_size: Width of the hidden layer and of the output. Defaults to 4096.
    """

    def __init__(self, rec_size: int = 64, llm_size: int = 4096) -> None:
        super().__init__()
        # The layers are created in this order and wrapped in a Sequential stored
        # as ``mlp_proj``. This keeps parameter-initialisation order and
        # state_dict keys (``mlp_proj.0.*`` / ``mlp_proj.2.*``) unchanged;
        # saved checkpoints presumably rely on that layout.
        in_proj = nn.Linear(rec_size, llm_size)
        activation = nn.GELU()
        out_proj = nn.Linear(llm_size, llm_size)
        self.mlp_proj = nn.Sequential(in_proj, activation, out_proj)

    def forward(self, x):
        """Project ``x`` through the MLP.

        Args:
            x: Tensor whose last dimension has size ``rec_size``. Leading
                dimensions are passed through unchanged by ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``llm_size``.
        """
        return self.mlp_proj(x)