Souha Ben Hassine commited on
Commit
088c633
·
1 Parent(s): af300e8

initial commit

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -12,16 +12,16 @@ class MultimodalRiskBehaviorModel(nn.Module):
12
  def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
13
  super(MultimodalRiskBehaviorModel, self).__init__()
14
 
15
- # Use AutoModelForSequenceClassification for classification tasks
16
  self.text_model_name = text_model_name
17
  self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=1)
18
 
19
- # Visual model initialization with ResNet50
20
  self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
21
  visual_feature_dim = self.visual_model.fc.in_features
22
  self.visual_model.fc = nn.Identity()
23
 
24
- # Fusion and classification layers
25
  text_feature_dim = self.text_model.config.hidden_size
26
  self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
27
  self.dropout = nn.Dropout(dropout)
@@ -31,22 +31,32 @@ class MultimodalRiskBehaviorModel(nn.Module):
31
  input_ids = encoding['input_ids'].squeeze(1).to(device)
32
  attention_mask = encoding['attention_mask'].squeeze(1).to(device)
33
 
34
- # Text and visual features extraction
35
- text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits.squeeze(-1)
36
  frames = frames.to(device)
37
 
38
  batch_size, num_frames, channels, height, width = frames.size()
39
  frames = frames.view(batch_size * num_frames, channels, height, width)
40
  visual_features = self.visual_model(frames)
41
  visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)
42
-
43
- # Combine features and classify
44
  combined_features = torch.cat((text_features, visual_features), dim=1)
45
  x = self.dropout(torch.relu(self.fc1(combined_features)))
46
  output = torch.sigmoid(self.fc2(x))
47
 
48
  return output
49
 
 
 
 
 
 
 
 
 
 
 
50
  @classmethod
51
  def from_pretrained(cls, load_directory, map_location=None):
52
  if os.path.exists(load_directory):
 
12
  def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
13
  super(MultimodalRiskBehaviorModel, self).__init__()
14
 
15
+ # Text model using AutoModelForSequenceClassification
16
  self.text_model_name = text_model_name
17
  self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=1)
18
 
19
+ # Visual model (ResNet50)
20
  self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
21
  visual_feature_dim = self.visual_model.fc.in_features
22
  self.visual_model.fc = nn.Identity()
23
 
24
+ # Fusion and classification layer setup
25
  text_feature_dim = self.text_model.config.hidden_size
26
  self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
27
  self.dropout = nn.Dropout(dropout)
 
31
  input_ids = encoding['input_ids'].squeeze(1).to(device)
32
  attention_mask = encoding['attention_mask'].squeeze(1).to(device)
33
 
34
+ # Extract text and visual features
35
+ text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits
36
  frames = frames.to(device)
37
 
38
  batch_size, num_frames, channels, height, width = frames.size()
39
  frames = frames.view(batch_size * num_frames, channels, height, width)
40
  visual_features = self.visual_model(frames)
41
  visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)
42
+
43
+ # Combine and classify
44
  combined_features = torch.cat((text_features, visual_features), dim=1)
45
  x = self.dropout(torch.relu(self.fc1(combined_features)))
46
  output = torch.sigmoid(self.fc2(x))
47
 
48
  return output
49
 
50
+ def save_pretrained(self, save_directory):
51
+ os.makedirs(save_directory, exist_ok=True)
52
+ torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
53
+ config = {
54
+ "text_model_name": self.text_model_name,
55
+ "hidden_dim": self.fc1.out_features
56
+ }
57
+ with open(os.path.join(save_directory, 'config.json'), 'w') as f:
58
+ json.dump(config, f)
59
+
60
  @classmethod
61
  def from_pretrained(cls, load_directory, map_location=None):
62
  if os.path.exists(load_directory):