File size: 2,469 Bytes
0d38ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TextEncoderHead(nn.Module):
    """Project a frozen text backbone's per-token output into the joint embedding space.

    The selected backbone output is flattened per sample and passed through a
    small MLP (``seq1``) that ends in a LayerNorm. Every backbone parameter is
    frozen, so only ``seq1`` is trainable.

    Args:
        model: Backbone called as ``model(input_ids=..., attention_mask=...)``.
            Its output must expose ``last_hidden_state``, or ``logits`` when
            ``use_logits`` is true.
        hidden_size: Size of the last dimension of the selected backbone
            output. Defaults to 768.
        seq_len: Number of tokens per sample. Because the output is flattened
            into a fixed-width linear layer, every batch must contain exactly
            this many tokens. Inputs are presumably padded/truncated to it
            upstream; confirm against the tokenizer call.
        use_logits: Read ``outputs.logits`` instead of
            ``outputs.last_hidden_state``. The ChemBERTa setup that used to be
            kept as commented-out code corresponds to
            ``hidden_size=767, use_logits=True``.
        embed_dim: Size of the produced embedding. Defaults to 512.
    """

    def __init__(self, model, *, hidden_size=768, seq_len=256, use_logits=False, embed_dim=512):
        super().__init__()
        self.model = model
        self.use_logits = use_logits
        # Freeze the backbone: only the projection head below receives gradients.
        # This does not switch the backbone to eval mode, so any dropout inside
        # it still follows this module's train()/eval() state.
        for param in self.model.parameters():
            param.requires_grad = False
        # Layer order (and therefore state_dict keys seq1.1.*, seq1.4.*, seq1.5.*)
        # is unchanged so previously saved checkpoints still load.
        self.seq1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(hidden_size * seq_len, 2000),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(2000, embed_dim),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token sequences.

        Args:
            input_ids: Token ids, forwarded unchanged to the backbone.
            attention_mask: Attention mask, forwarded unchanged to the backbone.

        Returns:
            A contiguous tensor of shape ``(batch, embed_dim)``.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        features = outputs.logits if self.use_logits else outputs.last_hidden_state
        return self.seq1(features).contiguous()
    
class ImageEncoderHead(nn.Module):
    """Project a frozen image backbone's token outputs into the joint embedding space.

    The backbone's ``last_hidden_state`` is mean-pooled over dimension 1. The
    result then goes through a small MLP (``seq1``) that ends in a LayerNorm.
    Every backbone parameter is frozen, so only ``seq1`` is trainable.

    Note: a ResNet backbone was previously supported only through
    commented-out code that used a different head
    (Flatten -> Linear(512*7*7, 1000) -> Linear(1000, 512) -> LayerNorm).
    That variant is not covered by this class.

    Args:
        model: Backbone called positionally as ``model(pixel_values)``. Its
            output must expose ``last_hidden_state``, presumably
            ``(batch, tokens, hidden_size)`` for a ViT-style model; confirm
            against the backbone in use.
        hidden_size: Size of the last dimension of ``last_hidden_state``.
            Defaults to 768 (the ViT setup).
        embed_dim: Size of the produced embedding. Defaults to 512.
    """

    def __init__(self, model, *, hidden_size=768, embed_dim=512):
        super().__init__()
        self.model = model
        # Freeze the backbone: only the projection head below receives gradients.
        # This does not switch the backbone to eval mode, so any dropout inside
        # it still follows this module's train()/eval() state.
        for param in self.model.parameters():
            param.requires_grad = False
        # Layer order (and therefore state_dict keys seq1.0.*, seq1.3.*, seq1.4.*)
        # is unchanged so previously saved checkpoints still load.
        self.seq1 = nn.Sequential(
            nn.Linear(hidden_size, 1000),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(1000, embed_dim),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, pixel_values):
        """Encode a batch of images.

        Args:
            pixel_values: Image tensor, passed positionally to the backbone.

        Returns:
            A contiguous tensor of the projected embeddings. Its shape is
            ``(batch, embed_dim)`` when ``last_hidden_state`` is 3-D.
        """
        outputs = self.model(pixel_values)
        # Average over the token axis (dim 1) to get one vector per image.
        pooled = outputs.last_hidden_state.mean(dim=1)
        return self.seq1(pooled).contiguous()
    
class CLIPChemistryModel(nn.Module, PyTorchModelHubMixin):
    """Two-tower CLIP-style model pairing an image encoder with a text encoder.

    ``PyTorchModelHubMixin`` is mixed in for Hugging Face Hub save/load
    support. ``forward`` only runs both towers and returns their outputs; no
    similarity score or loss is computed here.
    """

    def __init__(self, text_encoder, image_encoder):
        super().__init__()
        # When these are nn.Module instances they are registered as
        # submodules, so their parameters appear in parameters()/state_dict().
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder

    def forward(self, image, input_ids, attention_mask):
        """Embed an image batch and a text batch.

        Args:
            image: Passed unchanged to ``image_encoder``.
            input_ids: Passed, with ``attention_mask``, to ``text_encoder``.
            attention_mask: Passed, with ``input_ids``, to ``text_encoder``.

        Returns:
            Tuple ``(image_embeddings, text_embeddings)``. The image tower is
            evaluated first.
        """
        image_embeddings = self.image_encoder(image)
        text_embeddings = self.text_encoder(input_ids, attention_mask)
        return image_embeddings, text_embeddings