import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from torchvision import models, transforms
import torch.nn as nn
import os
import json
import cv2
from PIL import Image
import gradio as gr

class MultimodalRiskBehaviorModel(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        super(MultimodalRiskBehaviorModel, self).__init__()

        # Text branch: a sequence-classification transformer (BERT by default)
        self.text_model_name = text_model_name
        self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=2)

        # Visual branch: ResNet50 with its classification head replaced by identity
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()

        # Fusion and classification layers
        text_feature_dim = self.text_model.config.hidden_size
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, encoding, frames):
        # Uses the module-level `device` defined after model construction below
        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)

        # Text features: CLS-token hidden state of the last layer (hidden_size-dim),
        # matching the text width that fc1 was constructed with above
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        text_features = text_outputs.hidden_states[-1][:, 0, :]

        # Visual features: run every frame through ResNet50, then average over frames
        frames = frames.to(device)
        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.view(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)

        # Concatenate both modalities and classify; sigmoid yields a risk probability
        combined_features = torch.cat((text_features, visual_features), dim=1)
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        output = torch.sigmoid(self.fc2(x))
        return output

    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features
        }
        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        if os.path.exists(load_directory):
            # Local directory: rebuild the model from the saved config and weights
            config_path = os.path.join(load_directory, 'config.json')
            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
            state_dict = torch.load(state_dict_path, map_location=map_location)
            model.load_state_dict(state_dict)
        else:
            # Hub repo id: only the text branch is restored from the Hub; the
            # visual and fusion layers keep their default initialization
            hf_model = AutoModelForSequenceClassification.from_pretrained(load_directory, num_labels=2)
            model = cls(text_model_name=hf_model.config.name_or_path, hidden_dim=hf_model.config.hidden_size)
            model.text_model = hf_model
        return model
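
# Usage sketch for the helpers above (the local path "./risk_model" is illustrative):
#   model.save_pretrained("./risk_model")   # writes pytorch_model.bin and config.json
#   reloaded = MultimodalRiskBehaviorModel.from_pretrained("./risk_model", map_location="cpu")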

tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')  # on a CPU-only machine, pass map_location='cpu'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
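# At this point the whole model (BERT, ResNet50, and fusion layers) lives on `device`.
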
# Function to load frames from a video
def load_frames_from_video(video_path, transform, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while frame_count < num_frames:  # limit the number of frames for efficiency
        success, frame = cap.read()
        if not success:
            break
        # Convert the BGR NumPy frame to a PIL image and apply the transforms
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)
        frames.append(frame)
        frame_count += 1
    cap.release()

    # Stack frames and add a batch dimension: (1, num_frames, channels, height, width)
    frames = torch.stack(frames)
    frames = frames.unsqueeze(0)
    return frames
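
# Shape sketch (assuming the video yields all 10 frames): each transformed frame is
# (3, 224, 224), so load_frames_from_video returns a tensor of shape (1, 10, 3, 224, 224).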

def predict_video(model, video_path, text_input, tokenizer, transform):
    try:
        # Set the model to evaluation mode
        model.eval()

        # Tokenize the caption
        encoding = tokenizer(
            text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
        )
        encoding = {key: val.to(device) for key, val in encoding.items()}

        # Load frames from the video
        frames = load_frames_from_video(video_path, transform)
        frames = frames.to(device)

        # Log input devices and shapes
        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")

        # Forward pass; the model already applies sigmoid, so threshold the
        # returned probability at 0.5
        with torch.no_grad():
            output = model(encoding, frames)
        prediction = (output.squeeze(-1) > 0.5).float()
        return prediction.item()
    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error during prediction"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
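# The Normalize statistics above are the standard ImageNet mean/std that the
# pretrained ResNet50 expects; applying `transform` to any RGB PIL image yields
# a float tensor of shape (3, 224, 224).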

# Define the demo video paths and their captions
video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
]
video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment I've been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
]

def predict_risk(video_index):
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]

    # Make a prediction and map it to a human-readable label
    prediction = predict_video(model, video_path, text_input, tokenizer, transform)
    if prediction == "Error during prediction":
        return "Error during prediction"
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"

# Interface setup
with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    # Input option selector
    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    # Map the selected choice to its video URL (played by gr.Video) and caption
    def show_selected_video(choice):
        idx = int(choice.split()[-1]) - 1
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video(width=320, height=240)
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box]
    )

    # Prediction button and output
    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")
    predict_button.click(
        fn=lambda idx: predict_risk(int(idx.split()[-1]) - 1),
        inputs=video_selector,
        outputs=output_text
    )

# Launch the app
interface.launch()