Souha Ben Hassine committed
Commit · 4bc692b
1 Parent(s): eebb5f5
initial commit
app.py
CHANGED
@@ -1,6 +1,6 @@
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
-from torchvision import
+from torchvision import models, transforms
 import torch.nn as nn
 import os
 import json
@@ -11,22 +11,18 @@ import gradio as gr
 class MultimodalRiskBehaviorModel(nn.Module):
     def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
         super(MultimodalRiskBehaviorModel, self).__init__()
-
+
         # Text model using AutoModelForSequenceClassification
         self.text_model_name = text_model_name
         self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=1)
 
         # Visual model (ResNet50)
         self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
-
-        # Save the original `fc.in_features` before replacing it
         visual_feature_dim = self.visual_model.fc.in_features
-        self.visual_model.fc = nn.Identity()
+        self.visual_model.fc = nn.Identity()
 
-        #
+        # Fusion and classification layer setup
         text_feature_dim = self.text_model.config.hidden_size
-
-        # Fusion and classification layers
         self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
         self.dropout = nn.Dropout(dropout)
         self.fc2 = nn.Linear(hidden_dim, 1)
@@ -34,36 +30,66 @@ class MultimodalRiskBehaviorModel(nn.Module):
     def forward(self, encoding, frames):
         input_ids = encoding['input_ids'].squeeze(1).to(device)
         attention_mask = encoding['attention_mask'].squeeze(1).to(device)
-
-        #
+
+        # Extract text and visual features
         text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits
-
-        # Visual features from ResNet50
         frames = frames.to(device)
 
         batch_size, num_frames, channels, height, width = frames.size()
         frames = frames.view(batch_size * num_frames, channels, height, width)
         visual_features = self.visual_model(frames)
-
-        # Reshape back to (batch_size, num_frames, visual_feature_dim)
         visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)
-
-        # Combine
+
+        # Combine and classify
         combined_features = torch.cat((text_features, visual_features), dim=1)
-
-        # Pass through the classifier
         x = self.dropout(torch.relu(self.fc1(combined_features)))
         output = torch.sigmoid(self.fc2(x))
-
+
         return output
 
-
+    def save_pretrained(self, save_directory):
+        os.makedirs(save_directory, exist_ok=True)
+        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
+        config = {
+            "text_model_name": self.text_model_name,
+            "hidden_dim": self.fc1.out_features
+        }
+        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
+            json.dump(config, f)
+
+    @classmethod
+    def from_pretrained(cls, load_directory, map_location=None):
+        if os.path.exists(load_directory):
+            config_path = os.path.join(load_directory, 'config.json')
+            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
+
+            if not os.path.exists(config_path):
+                raise FileNotFoundError(f"{config_path} not found.")
+            if not os.path.exists(state_dict_path):
+                raise FileNotFoundError(f"{state_dict_path} not found.")
+
+            with open(config_path, 'r') as f:
+                config_dict = json.load(f)
+            model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
+            state_dict = torch.load(state_dict_path, map_location=map_location)
+            model.load_state_dict(state_dict)
+
+        else:
+            config = AutoConfig.from_pretrained(load_directory)
+            hf_model = AutoModelForSequenceClassification.from_pretrained(load_directory, config=config)
+            model = cls(text_model_name=config.name_or_path, hidden_dim=hf_model.config.hidden_size)
+            model.text_model = hf_model
+
+        return model
+
 tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
-model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50'
+model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')  # if cpu add arg map_location='cpu'
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
 
+
 # Function to load frames from a video
 def load_frames_from_video(video_path, transform, num_frames=10):
     cap = cv2.VideoCapture(video_path)
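
The added save_pretrained / from_pretrained pair follows the Hugging Face naming convention: saving writes the full state dict to pytorch_model.bin plus a minimal config.json, and loading rebuilds the model from that config before restoring weights. A minimal round-trip sketch of the committed helpers, assuming the class definition above is in scope; the checkpoint directory name is hypothetical:

    # Minimal sketch; "./risk_model_ckpt" is an illustrative path, not one from the repo.
    model = MultimodalRiskBehaviorModel(text_model_name="bert-base-uncased", hidden_dim=512)
    model.save_pretrained("./risk_model_ckpt")   # writes pytorch_model.bin + config.json

    # The local-directory branch of from_pretrained reads config.json, rebuilds
    # the model, then restores the saved weights.
    restored = MultimodalRiskBehaviorModel.from_pretrained("./risk_model_ckpt", map_location="cpu")

When the argument is not a local directory (as with the 'Souha-BH/BERT_Resnet50' repo id above), the else branch pulls only the text model from the Hub; the ResNet tower and fusion layers keep their __init__ defaults (ImageNet weights and fresh random layers, respectively).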
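The diff cuts off just after cap = cv2.VideoCapture(video_path), so the rest of load_frames_from_video is not shown. A sketch of one common implementation, offered purely as an assumption (evenly spaced frame sampling, BGR-to-RGB conversion, then the caller's transform); the committed body may differ:

    import cv2
    import torch

    def load_frames_from_video(video_path, transform, num_frames=10):
        # Sketch only: the committed body past the first line is not in this diff.
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Choose num_frames evenly spaced frame indices across the clip
        wanted = {int(i * total / num_frames) for i in range(num_frames)}
        frames, idx = [], 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx in wanted:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
                frames.append(transform(frame))
            idx += 1
        cap.release()
        return torch.stack(frames)  # (num_frames, channels, height, width)

Stacked this way, each per-video tensor only needs a leading batch dimension to match the (batch_size, num_frames, channels, height, width) layout that forward unpacks.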