Souha Ben Hassine committed
Commit · 088c633
1 Parent(s): af300e8
initial commit

app.py CHANGED
@@ -12,16 +12,16 @@ class MultimodalRiskBehaviorModel(nn.Module):
     def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
         super(MultimodalRiskBehaviorModel, self).__init__()
 
-        #
+        # Text model using AutoModelForSequenceClassification
         self.text_model_name = text_model_name
         self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=1)
 
-        # Visual model
+        # Visual model (ResNet50)
         self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
         visual_feature_dim = self.visual_model.fc.in_features
         self.visual_model.fc = nn.Identity()
 
-        # Fusion and classification
+        # Fusion and classification layer setup
         text_feature_dim = self.text_model.config.hidden_size
         self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
         self.dropout = nn.Dropout(dropout)
@@ -31,22 +31,32 @@ class MultimodalRiskBehaviorModel(nn.Module):
         input_ids = encoding['input_ids'].squeeze(1).to(device)
         attention_mask = encoding['attention_mask'].squeeze(1).to(device)
 
-        #
-        text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits
+        # Extract text and visual features
+        text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits
         frames = frames.to(device)
 
         batch_size, num_frames, channels, height, width = frames.size()
         frames = frames.view(batch_size * num_frames, channels, height, width)
         visual_features = self.visual_model(frames)
         visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)
-
-        # Combine
+
+        # Combine and classify
         combined_features = torch.cat((text_features, visual_features), dim=1)
         x = self.dropout(torch.relu(self.fc1(combined_features)))
         output = torch.sigmoid(self.fc2(x))
 
         return output
 
+    def save_pretrained(self, save_directory):
+        os.makedirs(save_directory, exist_ok=True)
+        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
+        config = {
+            "text_model_name": self.text_model_name,
+            "hidden_dim": self.fc1.out_features
+        }
+        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
+            json.dump(config, f)
+
     @classmethod
     def from_pretrained(cls, load_directory, map_location=None):
         if os.path.exists(load_directory):
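As a reading aid for the reshaping done in forward() above, the following standalone sketch (mine, not part of the commit; the batch size, frame count, and input resolution are illustrative) reproduces the frame-pooling path: ResNet50 with its classification head swapped for nn.Identity yields a 2048-dimensional embedding per frame, and the per-frame embeddings are averaged into one vector per clip.

import torch
import torch.nn as nn
from torchvision import models

# ResNet50 backbone as configured in __init__: replacing fc with Identity
# exposes the 2048-dim pooled features instead of ImageNet class scores.
visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
visual_model.fc = nn.Identity()
visual_model.eval()

# Illustrative clip batch: 2 clips of 8 RGB frames at 224x224.
frames = torch.randn(2, 8, 3, 224, 224)
batch_size, num_frames, channels, height, width = frames.size()

with torch.no_grad():
    # Fold frames into the batch dimension so ResNet50 sees single images,
    # then unfold and average over frames, as forward() does.
    per_frame = visual_model(frames.view(batch_size * num_frames, channels, height, width))
    clip_features = per_frame.view(batch_size, num_frames, -1).mean(dim=1)

print(clip_features.shape)  # torch.Size([2, 2048]), i.e. visual_feature_dim

These clip-level features are what forward() concatenates with the text features before fc1. The save_pretrained helper added in this commit then persists the module's state_dict plus a minimal config.json, presumably so that from_pretrained (only its first lines are visible in this hunk) can rebuild the model from that directory.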