kesimeg committed on
Commit
6c5555a
1 Parent(s): 041895e

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +39 -0
model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models
5
+ from transformers import AutoTokenizer, AutoModel
6
+
7
class Net(nn.Module):
    """Two-tower image/text encoder producing same-sized embeddings.

    The image tower is a torchvision ResNet-18 whose classification layer is
    replaced by an identity, followed by a two-layer MLP projection. The text
    tower is a Hugging Face encoder loaded with ``AutoModel.from_pretrained``,
    from which the hidden state at ``target_token_idx`` is taken and passed
    through a second two-layer MLP projection.

    Sub-module attribute names (``image_encoder``, ``image_out``,
    ``text_encoder``, ``text_out``) and their registration order are kept
    stable so that existing ``state_dict`` checkpoints keep loading.
    """

    def __init__(
        self,
        text_model_name="dbmdz/distilbert-base-turkish-cased",
        embed_dim=256,
    ):
        """Build both towers.

        Args:
            text_model_name: Checkpoint identifier handed to
                ``AutoModel.from_pretrained``.
            embed_dim: Size of the embedding each projection head outputs
                (also used as the hidden width of each head).
        """
        super().__init__()

        # ResNet-18 is created without a weights argument, so no pretrained
        # image weights are requested here.
        backbone = models.resnet18()
        # Read the feature width from the layer being removed instead of
        # hard-coding it, then drop the classifier so the backbone returns
        # pooled features.
        image_feat_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.image_encoder = backbone
        self.image_out = self._projection_head(image_feat_dim, embed_dim)

        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        # Sequence position whose hidden state represents the whole text.
        # NOTE(review): position 0 is presumably the [CLS] token -- confirm
        # against the tokenizer used by the caller.
        self.target_token_idx = 0
        text_feat_dim = self.text_encoder.config.hidden_size
        self.text_out = self._projection_head(text_feat_dim, embed_dim)

    @staticmethod
    def _projection_head(in_dim, out_dim):
        """Return the Linear -> ReLU -> Linear head shared by both towers."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, image, text, mask):
        """Embed a batch of images and a batch of tokenised texts.

        Args:
            image: Input forwarded unchanged to the image encoder.
            text: Token ids forwarded to the text encoder as ``input_ids``.
            mask: Attention mask forwarded to the text encoder as
                ``attention_mask``.

        Returns:
            A tuple ``(image_vec, text_vec)`` holding the outputs of the
            image and text projection heads, in that order.
        """
        # Collapse everything after the batch axis so the projection head
        # always receives one feature row per sample.
        image_feats = torch.flatten(self.image_encoder(image), start_dim=1)
        image_vec = self.image_out(image_feats)

        # Keyword arguments keep the call correct regardless of the
        # positional parameter order of the concrete encoder class.
        encoded = self.text_encoder(input_ids=text, attention_mask=mask)
        # Integer indexing on the sequence axis already yields one hidden
        # vector per sample, so no further reshape is required.
        token_state = encoded.last_hidden_state[:, self.target_token_idx, :]
        text_vec = self.text_out(token_state)

        return image_vec, text_vec