import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class TextEncoderHead(nn.Module):
    """Projects a frozen text backbone's token embeddings into a 512-d space."""

    def __init__(self, model):
        super(TextEncoderHead, self).__init__()
        self.model = model
        # Freeze the backbone; only the projection head is trained.
        for param in self.model.parameters():
            param.requires_grad = False
        # uncomment this for chemberta
        # self.seq1 = nn.Sequential(
        #     nn.Flatten(),
        #     nn.Linear(767*256, 2000),
        #     nn.Dropout(0.3),
        #     nn.ReLU(),
        #     nn.Linear(2000, 512),
        #     nn.LayerNorm(512)
        # )
        # Projection head: flattens the (256 tokens x 768 hidden) output down to 512 dims.
        self.seq1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(768 * 256, 2000),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(2000, 512),
            nn.LayerNorm(512)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # uncomment this for chemberta
        # outputs = outputs.logits
        outputs = outputs.last_hidden_state
        outputs = self.seq1(outputs)
        return outputs.contiguous()
class ImageEncoderHead(nn.Module):
    """Projects a frozen vision backbone's features into a 512-d space."""

    def __init__(self, model):
        super(ImageEncoderHead, self).__init__()
        self.model = model
        # Freeze the backbone; only the projection head is trained.
        for param in self.model.parameters():
            param.requires_grad = False
        # for resnet model
        # self.seq1 = nn.Sequential(
        #     nn.Flatten(),
        #     nn.Linear(512*7*7, 1000),
        #     nn.Linear(1000, 512),
        #     nn.LayerNorm(512)
        # )
        # for vit model
        self.seq1 = nn.Sequential(
            nn.Linear(768, 1000),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(1000, 512),
            nn.LayerNorm(512)
        )

    def forward(self, pixel_values):
        outputs = self.model(pixel_values)
        # Mean-pool the patch embeddings before projecting.
        outputs = outputs.last_hidden_state.mean(dim=1)
        outputs = self.seq1(outputs)
        return outputs.contiguous()
class CLIPChemistryModel(nn.Module, PyTorchModelHubMixin):
    """Pairs the two encoder heads so image and text embeddings can be trained jointly."""

    def __init__(self, text_encoder, image_encoder):
        super(CLIPChemistryModel, self).__init__()
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder

    def forward(self, image, input_ids, attention_mask):
        # calculate the embeddings
        ie = self.image_encoder(image)
        te = self.text_encoder(input_ids, attention_mask)
        return ie, te
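
# --- Usage sketch (not part of the original file) ---
# A minimal example of wiring the heads to pretrained `transformers` backbones.
# The specific checkpoints, the 256-token padding length, and the dummy inputs
# below are illustrative assumptions, not the authors' training setup; any text
# backbone and ViT with hidden size 768 match the projection heads above.
if __name__ == "__main__":
    import torch
    from transformers import AutoModel, AutoTokenizer, ViTModel

    text_backbone = AutoModel.from_pretrained("bert-base-uncased")
    image_backbone = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    model = CLIPChemistryModel(
        text_encoder=TextEncoderHead(text_backbone),
        image_encoder=ImageEncoderHead(image_backbone),
    )

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokens = tokenizer(
        "CCO",                      # placeholder molecule text
        padding="max_length",
        truncation=True,
        max_length=256,             # must match the 768*256 flatten in TextEncoderHead
        return_tensors="pt",
    )
    pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch

    image_emb, text_emb = model(pixel_values, tokens["input_ids"], tokens["attention_mask"])
    print(image_emb.shape, text_emb.shape)  # both torch.Size([1, 512])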