Souha Ben Hassine committed
Commit 4bc692b · 1 Parent(s): eebb5f5

initial commit

Files changed (1)
  1. app.py +47 -21
app.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
-from torchvision import transforms, models
+from torchvision import models, transforms
 import torch.nn as nn
 import os
 import json
@@ -11,22 +11,18 @@ import gradio as gr
 class MultimodalRiskBehaviorModel(nn.Module):
     def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
         super(MultimodalRiskBehaviorModel, self).__init__()
-
+
         # Text model using AutoModelForSequenceClassification
         self.text_model_name = text_model_name
         self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=1)
 
         # Visual model (ResNet50)
         self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
-
-        # Save the original `fc.in_features` before replacing it
         visual_feature_dim = self.visual_model.fc.in_features
-        self.visual_model.fc = nn.Identity()  # Replace with identity layer
+        self.visual_model.fc = nn.Identity()
 
-        # Get the hidden dimension of the text model
+        # Fusion and classification layer setup
         text_feature_dim = self.text_model.config.hidden_size
-
-        # Fusion and classification layers
         self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
         self.dropout = nn.Dropout(dropout)
         self.fc2 = nn.Linear(hidden_dim, 1)
@@ -34,36 +30,66 @@ class MultimodalRiskBehaviorModel(nn.Module):
     def forward(self, encoding, frames):
         input_ids = encoding['input_ids'].squeeze(1).to(device)
         attention_mask = encoding['attention_mask'].squeeze(1).to(device)
-
-        # Text embeddings from BERT
+
+        # Extract text and visual features
         text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits
-
-        # Visual features from ResNet50
         frames = frames.to(device)
 
         batch_size, num_frames, channels, height, width = frames.size()
         frames = frames.view(batch_size * num_frames, channels, height, width)
         visual_features = self.visual_model(frames)
-
-        # Reshape back to (batch_size, num_frames, visual_feature_dim)
         visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)
-
-        # Combine text and visual features
+
+        # Combine and classify
         combined_features = torch.cat((text_features, visual_features), dim=1)
-
-        # Pass through the classifier
         x = self.dropout(torch.relu(self.fc1(combined_features)))
         output = torch.sigmoid(self.fc2(x))
-
+
         return output
 
-    # Loading the tokenizer and model
+    def save_pretrained(self, save_directory):
+        os.makedirs(save_directory, exist_ok=True)
+        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
+        config = {
+            "text_model_name": self.text_model_name,
+            "hidden_dim": self.fc1.out_features
+        }
+        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
+            json.dump(config, f)
+
+    @classmethod
+    def from_pretrained(cls, load_directory, map_location=None):
+        if os.path.exists(load_directory):
+            config_path = os.path.join(load_directory, 'config.json')
+            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
+
+            if not os.path.exists(config_path):
+                raise FileNotFoundError(f"{config_path} not found.")
+            if not os.path.exists(state_dict_path):
+                raise FileNotFoundError(f"{state_dict_path} not found.")
+
+            with open(config_path, 'r') as f:
+                config_dict = json.load(f)
+            model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
+            state_dict = torch.load(state_dict_path, map_location=map_location)
+            model.load_state_dict(state_dict)
+
+        else:
+            config = AutoConfig.from_pretrained(load_directory)
+            hf_model = AutoModelForSequenceClassification.from_pretrained(load_directory, config=config)
+            model = cls(text_model_name=config.name_or_path, hidden_dim=hf_model.config.hidden_size)
+            model.text_model = hf_model
+
+        return model
+
 tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
-model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50', map_location='cpu')
+model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')  # if cpu add arg map_location='cpu'
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
 
+
 # Function to load frames from a video
 def load_frames_from_video(video_path, transform, num_frames=10):
     cap = cv2.VideoCapture(video_path)
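Some notes and sketches on the committed code follow. First, the fusion head: __init__ swaps ResNet50's final fc for nn.Identity() so the backbone emits pooled 2048-dim features, then sizes fc1 from the text model's hidden_size (768 for bert-base-uncased) plus those 2048. A minimal sketch of that fusion pattern with stand-in tensors (dimensions assumed from those two defaults):

import torch
import torch.nn as nn

batch = 2
text_dim, visual_dim, hidden_dim = 768, 2048, 512  # assumed from bert-base-uncased and ResNet50

fc1 = nn.Linear(text_dim + visual_dim, hidden_dim)
fc2 = nn.Linear(hidden_dim, 1)

text_features = torch.randn(batch, text_dim)      # stand-in for BERT text features
visual_features = torch.randn(batch, visual_dim)  # stand-in for pooled ResNet50 features
combined = torch.cat((text_features, visual_features), dim=1)  # (2, 2816)
out = torch.sigmoid(fc2(torch.relu(fc1(combined))))            # (2, 1)

One caveat worth flagging: forward takes its text features from .logits, which with num_labels=1 has shape (batch, 1) rather than (batch, hidden_size), so the concatenated tensor is 2049-dim while fc1 expects 2816. As committed, the two halves of the model appear to disagree on the text feature width.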
 
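The forward pass handles video by folding the frame dimension into the batch before the CNN, then mean-pooling per video. A self-contained sketch of that reshape-and-pool trick (frame count and resolution are assumed):

import torch
import torch.nn as nn
from torchvision import models

backbone = models.resnet50(weights=None)  # untrained weights suffice for a shape check
backbone.fc = nn.Identity()

frames = torch.randn(2, 10, 3, 224, 224)       # (batch, num_frames, C, H, W)
b, t, c, h, w = frames.size()
feats = backbone(frames.view(b * t, c, h, w))  # (20, 2048): one row per frame
feats = feats.view(b, t, -1).mean(dim=1)       # (2, 2048): averaged over frames

Mean-pooling treats the frames as an unordered set; any temporal ordering in the video is discarded at this step.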
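save_pretrained and from_pretrained mimic the Hugging Face method names but persist only a state dict plus a two-field JSON config. A round-trip sketch against a local directory (the path is illustrative):

model = MultimodalRiskBehaviorModel()            # downloads bert-base-uncased and ResNet50 weights
model.save_pretrained('checkpoints/risk_model')  # writes pytorch_model.bin and config.json

restored = MultimodalRiskBehaviorModel.from_pretrained('checkpoints/risk_model',
                                                       map_location='cpu')

Note the else branch: when load_directory is not a local path it is treated as a Hub repo id and only the text tower is restored from it; the ResNet50 half keeps its default ImageNet weights on that path.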
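With the tokenizer and model loaded and moved to device, a prediction takes a tokenized caption plus a frame tensor. A minimal usage sketch (the caption and frames are placeholders, and it assumes the fc1 dimension caveat noted above has been resolved):

model.eval()
encoding = tokenizer("sample video caption", return_tensors='pt',
                     truncation=True, padding='max_length', max_length=128)
frames = torch.randn(1, 10, 3, 224, 224)  # placeholder for real, transformed frames
with torch.no_grad():
    prob = model(encoding, frames)        # sigmoid output in [0, 1]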
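The diff shown here cuts off inside load_frames_from_video, so the committed body is not visible. For context, a typical uniform-sampling loader has the shape below; this helper is a hypothetical stand-in, not the committed implementation:

import cv2
import torch

def sample_frames(video_path, transform, num_frames=10):
    # Hypothetical sketch: grab num_frames evenly spaced frames and
    # stack them into a (num_frames, C, H, W) tensor via `transform`.
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total // num_frames, 1)
    frames = []
    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ok, frame = cap.read()
        if not ok:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR
        frames.append(transform(frame))
    cap.release()
    return torch.stack(frames)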