import torch
import torch.nn as nn
from transformers import CLIPModel
from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    """Three-layer classification head applied to projected CLIP image embeddings."""

    def __init__(self, input_dim=768, hidden_dim1=512, hidden_dim2=256,
                 output_dim=8, dropout_rate=0.5):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu1 = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x


class clip_lora_model(nn.Module):
    """CLIP ViT-L/14 vision encoder fine-tuned with LoRA, followed by an MLP head.

    Only the LoRA adapter weights and the MLP head are trainable; the pretrained
    visual projection is kept frozen and maps the encoder's pooled output to the
    768-d CLIP embedding space before classification.
    """

    def __init__(self, input_dim=768, hidden_dim1=512, hidden_dim2=256,
                 output_dim=8, dropout_rate=0.5, r=16, lora_alpha=8):
        super(clip_lora_model, self).__init__()
        self.output_dim = output_dim
        self.mlp = MLP(input_dim, hidden_dim1, hidden_dim2, output_dim, dropout_rate)

        model_name = 'openai/clip-vit-large-patch14'
        model = CLIPModel.from_pretrained(model_name)

        # Keep the pretrained visual projection, but freeze its weights.
        self.proj = model.visual_projection
        for param in self.proj.parameters():
            param.requires_grad = False

        # Wrap the vision encoder with LoRA adapters on the attention projections.
        encoder = model.vision_model
        target_modules = ["k_proj", "v_proj", "q_proj"]
        config = LoraConfig(
            r=int(r),
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=0.1,
            bias="none",
        )
        self.model = get_peft_model(encoder, config)

    def forward(self, x):
        # x: pixel values of shape (batch, 3, 224, 224).
        model_outputs = self.model(x)
        image_embeds = model_outputs[1]          # pooled output of the vision encoder
        image_embeds = self.proj(image_embeds)   # project into the CLIP embedding space
        outputs = self.mlp(image_embeds)         # class logits of shape (batch, output_dim)
        return outputs
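

# Minimal usage sketch (an assumption, not part of the original module): real inputs
# would be preprocessed with the CLIPProcessor for the same
# 'openai/clip-vit-large-patch14' checkpoint into 224x224 pixel values; the random
# tensor below only stands in for such a batch to illustrate the expected shapes.
if __name__ == "__main__":
    model = clip_lora_model(output_dim=8)
    model.eval()

    # Dummy batch of 2 "images" in CLIP's expected input shape.
    pixel_values = torch.randn(2, 3, 224, 224)

    with torch.no_grad():
        logits = model(pixel_values)
    print(logits.shape)  # expected: torch.Size([2, 8])

    # Only the LoRA adapters and the MLP head should require gradients.
    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    print(f"{len(trainable)} trainable parameter tensors")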