# models_demo / modeling.py
import torch
import torch.nn as nn
from transformers import CLIPModel
from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    """Three-layer MLP head that maps CLIP image embeddings to output scores."""

    def __init__(self, input_dim=768, hidden_dim1=512, hidden_dim2=256, output_dim=8, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu1 = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x


class clip_lora_model(nn.Module):
    """CLIP ViT-L/14 vision encoder with LoRA adapters, followed by an MLP head."""

    def __init__(self, input_dim=768, hidden_dim1=512, hidden_dim2=256, output_dim=8,
                 dropout_rate=0.5, r=16, lora_alpha=8):
        super().__init__()
        self.output_dim = output_dim
        self.mlp = MLP(input_dim, hidden_dim1, hidden_dim2, output_dim, dropout_rate)

        model_name = "openai/clip-vit-large-patch14"
        model = CLIPModel.from_pretrained(model_name)

        # Reuse CLIP's pretrained visual projection (1024 -> 768) and keep it frozen.
        self.proj = model.visual_projection
        for param in self.proj.parameters():
            param.requires_grad = False

        # Wrap the vision transformer with LoRA adapters on the attention projections.
        encoder = model.vision_model
        target_modules = ["k_proj", "v_proj", "q_proj"]
        config = LoraConfig(
            r=int(r),
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=0.1,
            bias="none",
        )
        self.model = get_peft_model(encoder, config)

    def forward(self, x):
        model_outputs = self.model(x)
        image_embeds = model_outputs[1]  # pooled output of the vision encoder
        model_outputs = self.proj(image_embeds)
        outputs = self.mlp(model_outputs)
        return outputs
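

# Minimal usage sketch (not part of the original upload; shapes and names here are
# assumptions for illustration): instantiate the model and run a forward pass on
# dummy pixel values. Real inputs would come from CLIPImageProcessor for
# "openai/clip-vit-large-patch14" (224x224 normalized images).
if __name__ == "__main__":
    model = clip_lora_model()
    model.eval()
    dummy_pixel_values = torch.randn(2, 3, 224, 224)  # hypothetical batch of 2 images
    with torch.no_grad():
        scores = model(dummy_pixel_values)
    print(scores.shape)  # expected: torch.Size([2, 8]) with the default output_dim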