Souha Ben Hassine committed
Commit 64e81a6 · 1 Parent(s): a05784d

initial commit

Files changed (1)
  1. app.py +207 -0
app.py ADDED
@@ -0,0 +1,207 @@
import os
import json

import cv2
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModel, AutoTokenizer

class MultimodalRiskBehaviorModel(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        super(MultimodalRiskBehaviorModel, self).__init__()

        # Text model
        self.text_model_name = text_model_name
        self.text_model = AutoModel.from_pretrained(text_model_name)

        # Visual model (ResNet50)
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Save the original `fc.in_features` before replacing it
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()  # Replace the classifier head with an identity layer

        # Hidden dimension of the text model
        text_feature_dim = self.text_model.config.hidden_size

        # Fusion and classification layers
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, encoding, frames):
        # Resolve the device from the model's own parameters so forward()
        # does not depend on a global `device` being defined first
        device = next(self.parameters()).device
        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)

        # Text embedding: the [CLS] token's final hidden state from BERT
        text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]

        # Visual features from ResNet50
        frames = frames.to(device)  # Move frames to the same device as the model

        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.view(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)

        # Reshape back to (batch_size, num_frames, visual_feature_dim) and average over frames
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)

        # Concatenate text and visual features
        combined_features = torch.cat((text_features, visual_features), dim=1)

        # Classifier head
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        output = torch.sigmoid(self.fc2(x))  # Sigmoid for binary classification

        return output

    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features
        }
        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        with open(os.path.join(load_directory, 'config.json'), 'r') as f:
            config = json.load(f)

        model = cls(text_model_name=config["text_model_name"], hidden_dim=config["hidden_dim"])

        state_dict = torch.load(
            os.path.join(load_directory, 'pytorch_model.bin'),
            map_location=map_location
        )
        model.load_state_dict(state_dict)

        return model
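
# Hedged sanity check (kept commented out so nothing heavy runs at import
# time): with bert-base-uncased (hidden size 768) and ResNet50 (2048 pooled
# features), fc1 fuses 768 + 2048 = 2816 inputs, and a forward pass yields one
# sigmoid probability per sample.
#
#     m = MultimodalRiskBehaviorModel()
#     enc = AutoTokenizer.from_pretrained("bert-base-uncased")(
#         ["a caption"], padding='max_length', max_length=128, return_tensors='pt')
#     dummy_frames = torch.randn(1, 10, 3, 224, 224)  # (batch, frames, C, H, W)
#     print(m(enc, dummy_frames).shape)               # torch.Size([1, 1])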

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
# map_location=device lets the checkpoint load on CPU-only machines too
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50', map_location=device)
model.to(device)
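
# NOTE (assumption): from_pretrained() above reads config.json and
# pytorch_model.bin from a local path, so 'Souha-BH/BERT_Resnet50' must exist
# as a directory here. If it is only a Hub repo, one hedged way to fetch it
# first would be:
#
#     from huggingface_hub import snapshot_download
#     local_dir = snapshot_download('Souha-BH/BERT_Resnet50')
#     model = MultimodalRiskBehaviorModel.from_pretrained(local_dir, map_location=device)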

# Function to load frames from a video
def load_frames_from_video(video_path, transform, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while frame_count < num_frames:  # Limit the number of frames for efficiency
        success, frame = cap.read()
        if not success:
            break
        # Convert the frame (a BGR NumPy array) to a PIL image and apply transformations
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)
        frames.append(frame)
        frame_count += 1
    cap.release()

    # Guard against unreadable videos; torch.stack fails on an empty list
    if not frames:
        raise ValueError(f"Could not read any frames from {video_path}")

    # Stack frames and add a batch dimension: (1, num_frames, channels, height, width)
    frames = torch.stack(frames)
    frames = frames.unsqueeze(0)
    return frames
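
# Illustrative usage: load_frames_from_video('sample.mp4', transform) returns
# a tensor of shape (1, 10, 3, 224, 224) with the transform defined below.
# cv2.VideoCapture also accepts URLs when OpenCV is built with FFmpeg, which
# is what the Google Drive links further down rely on.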

# Prediction function for a single video
def predict_video(model, video_path, text_input, tokenizer, transform):
    # Set the model to evaluation mode
    model.eval()

    # Tokenize the text input and move it to the device
    encoding = tokenizer(
        text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
    )
    encoding = {key: val.to(device) for key, val in encoding.items()}

    # Load frames from the video and move them to the device
    frames = load_frames_from_video(video_path, transform)
    frames = frames.to(device)

    # Forward pass through the model
    with torch.no_grad():
        output = model(encoding, frames)

    # The model already applies sigmoid in forward(), so `output` is a
    # probability; threshold at 0.5 to get the predicted class
    prediction = (output.squeeze(-1) > 0.5).float()

    # Return the predicted label (0 or 1)
    return prediction.item()
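
# Illustrative usage (assumes a local file 'sample.mp4'):
#     label = predict_video(model, 'sample.mp4', 'some caption', tokenizer, transform)
#     # label is 1.0 for predicted risky behavior, 0.0 otherwise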

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
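
# These mean/std values are the standard ImageNet normalization statistics,
# matching what the pretrained ResNet50 weights expect.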

# Video paths and captions for the demo
video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
]

video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم 😓 #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king 🍔👑 #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment 😂 I’ve been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
]
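
# The Arabic phrase in the first caption translates to "every time I try to
# start a diet"; captions are kept verbatim because they are the model's text
# inputs.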

def predict_risk(video_index):
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]

    # Make a prediction
    prediction = predict_video(model, video_path, text_input, tokenizer, transform)

    # Return the corresponding label
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"
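
# Illustrative usage: predict_risk(0) streams the first video, runs the
# multimodal forward pass, and returns one of the two label strings above.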

# Interface setup
with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    # Input option selector
    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    # Map the selected option to its video URL (played by `gr.Video`) and caption
    def show_selected_video(choice):
        idx = int(choice.split()[-1]) - 1
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video()
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box]
    )

    # Prediction button and output
    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    # The radio value is a string like "Video 2", so parse out the index
    predict_button.click(
        fn=lambda choice: predict_risk(int(choice.split()[-1]) - 1),
        inputs=video_selector,
        outputs=output_text
    )

# Launch the app
interface.launch()
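
# Optional: when running locally, interface.launch(share=True) would also
# expose a temporary public link (share is a standard gradio launch flag).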