Spaces:
Souha Ben Hassine
committed
Commit 64e81a6 · 1 Parent(s): a05784d
initial commit
app.py ADDED
@@ -0,0 +1,207 @@
import torch
from torch.utils.data import Dataset
from transformers import AutoModel, AutoTokenizer
from torchvision import transforms, models
import torch.nn as nn
import os
import json
import cv2
from PIL import Image
import gradio as gr

class MultimodalRiskBehaviorModel(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        super(MultimodalRiskBehaviorModel, self).__init__()

        # Text model
        self.text_model_name = text_model_name
        self.text_model = AutoModel.from_pretrained(text_model_name)

        # Visual model (ResNet50)
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Save the original `fc.in_features` before replacing it
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()  # Replace with identity layer

        # Get the hidden dimension of the text model
        text_feature_dim = self.text_model.config.hidden_size

        # Fusion and classification layers
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, encoding, frames):
        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)

        # Text embeddings from BERT
        text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]

        # Visual features from ResNet50
        frames = frames.to(device)  # Move frames to the same device as the model

        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.view(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)

        # Reshape back to (batch_size, num_frames, visual_feature_dim) and average over frames
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)

        # Combine text and visual features
        combined_features = torch.cat((text_features, visual_features), dim=1)

        # Pass through the classifier
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        output = torch.sigmoid(self.fc2(x))  # Sigmoid for binary classification

        return output

    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features
        }
        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        with open(os.path.join(load_directory, 'config.json'), 'r') as f:
            config = json.load(f)

        model = cls(text_model_name=config["text_model_name"], hidden_dim=config["hidden_dim"])

        state_dict = torch.load(
            os.path.join(load_directory, 'pytorch_model.bin'),
            map_location=map_location
        )
        model.load_state_dict(state_dict)

        return model

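The fusion is a late concatenation: the [CLS] embedding from bert-base-uncased (768 dimensions) and the frame-averaged ResNet50 features (2048 dimensions) are joined and passed through a 512-unit head with a single sigmoid output. A minimal shape check (a sketch; it assumes CPU, randomly initialized weights, and dummy inputs rather than the pretrained checkpoint):

device = torch.device('cpu')          # forward() reads this module-level name
sanity_model = MultimodalRiskBehaviorModel()
dummy_encoding = {
    'input_ids': torch.ones(2, 1, 128, dtype=torch.long),       # (batch, 1, seq_len)
    'attention_mask': torch.ones(2, 1, 128, dtype=torch.long),
}
dummy_frames = torch.randn(2, 10, 3, 224, 224)                  # (batch, num_frames, C, H, W)
print(sanity_model(dummy_encoding, dummy_frames).shape)         # expected: torch.Size([2, 1])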
tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')  # on a CPU-only machine, pass map_location='cpu'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

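Note that this custom `from_pretrained` reads `config.json` and `pytorch_model.bin` from a plain directory path, so passing the repo id 'Souha-BH/BERT_Resnet50' assumes those files already sit at that relative path in the Space. If they only live on the Hugging Face Hub, one hedged option is to download a snapshot first and point the loader at the local copy:

from huggingface_hub import snapshot_download

# Fetch config.json and pytorch_model.bin from the Hub, then load from the local folder.
local_dir = snapshot_download(repo_id='Souha-BH/BERT_Resnet50')
model = MultimodalRiskBehaviorModel.from_pretrained(local_dir, map_location='cpu')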
# Function to load frames from a video
def load_frames_from_video(video_path, transform, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while frame_count < num_frames:  # Limit to a number of frames for efficiency
        success, frame = cap.read()
        if not success:
            break
        # Convert frame (NumPy array) to PIL image and apply transformations
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)
        frames.append(frame)
        frame_count += 1
    cap.release()

    # Stack frames and add batch dimension (1, num_frames, channels, height, width)
    frames = torch.stack(frames)
    frames = frames.unsqueeze(0)  # Add batch dimension
    return frames

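If OpenCV cannot open the path at all (for instance a URL it cannot stream), the loop above exits immediately, `frames` stays empty, and `torch.stack` raises a bare RuntimeError. A defensive variant (a sketch with a hypothetical name, same output shape) fails with a clearer message:

def load_frames_from_video_safe(video_path, transform, num_frames=10):
    """Sketch of a defensive wrapper: same output as above, clearer failure mode."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    while len(frames) < num_frames:
        success, frame = cap.read()
        if not success:
            break
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frames.append(transform(frame))
    cap.release()
    if not frames:
        raise ValueError(f"Could not read any frames from {video_path!r}")
    return torch.stack(frames).unsqueeze(0)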
# Prediction function for a single video
def predict_video(model, video_path, text_input, tokenizer, transform):
    # Set model to evaluation mode
    model.eval()

    # Tokenize the text input and move to device
    encoding = tokenizer(
        text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
    )
    encoding = {key: val.to(device) for key, val in encoding.items()}  # Ensure text input is on the device

    # Load frames from the video and move to device
    frames = load_frames_from_video(video_path, transform)
    frames = frames.to(device)  # Ensure frames are on the device

    # Perform forward pass through the model
    with torch.no_grad():
        output = model(encoding, frames)

    # The model output is already a sigmoid probability; threshold at 0.5
    prediction = (output.squeeze(-1) > 0.5).float()

    # Return the predicted label (0 or 1)
    return prediction.item()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

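With the model, tokenizer, and transform defined, one clip can be classified end to end. A usage sketch (`sample_clip.mp4` and the caption are placeholders, not files shipped with the Space):

# Classify a single local clip if one happens to be present.
if os.path.exists("sample_clip.mp4"):
    label = predict_video(model, "sample_clip.mp4", "example caption #fyp", tokenizer, transform)
    print("Risky" if label == 1 else "Not risky")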
# Define your video paths and captions
video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
]

video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم 😓 #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king 🍔👑 #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment 😂 I’ve been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
]

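`cv2.VideoCapture` can sometimes stream these direct-download links, but Google Drive may answer with a redirect or an HTML confirmation page instead of video bytes, in which case no frames decode. A fallback sketch (hypothetical helper; large files behind a confirmation page would still need extra handling) downloads the clip to a temporary file first:

import tempfile
import urllib.request

def open_video_locally(url):
    # Save the clip to a temporary file and open the local copy with OpenCV.
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()
    urllib.request.urlretrieve(url, tmp.name)
    return cv2.VideoCapture(tmp.name)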
def predict_risk(video_index):
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]

    # Make prediction
    prediction = predict_video(model, video_path, text_input, tokenizer, transform)

    # Return the corresponding label
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"

# Interface setup
with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    # Input option selector
    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    # Return the selected video's URL and caption; gr.Video can play a direct URL
    def show_selected_video(choice):
        idx = int(choice.split()[-1]) - 1
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video()
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box]
    )

    # Prediction button and output
    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    predict_button.click(
        fn=lambda idx: predict_risk(int(idx.split()[-1]) - 1),
        inputs=video_selector,
        outputs=output_text
    )

# Launch the app
interface.launch()
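Each prediction runs BERT plus ResNet50 over ten frames, so simultaneous clicks can stack up behind one another. If that becomes a problem, routing requests through Gradio's built-in queue is one option (a sketch of an alternative final line):

# Alternative: serialize requests through Gradio's queue before launching.
interface.queue().launch()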