Ahmed-El-Sharkawy commited on
Commit
076763e
·
verified ·
1 Parent(s): 6b045fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -11,13 +11,18 @@ import torch
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Load the main classifier (Detector_best_model.pth)
14
- main_model = models.resnet18(weights=None) # Updated: weights=None
15
- num_ftrs = main_model.fc.in_features
 
16
  # main_model.fc = nn.Linear(num_ftrs, 2) # 2 classes: AI-generated_Image, Real_Image
17
- main_model.fc = nn.Sequential(
18
- nn.Dropout(p=0.5), # Match the training architecture
19
- nn.Linear(num_ftrs, 2) # 2 classes: AI-generated Image, Real Image
20
- )
 
 
 
 
21
 
22
  main_model.load_state_dict(torch.load('best_model6for_RESNET18_After_First_Half_training_part.pth', map_location=device, weights_only=True)) # Updated: weights_only=True
23
  main_model = main_model.to(device)
 
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Load the main classifier (Detector_best_model.pth)
14
+ main_model = models.mobilenet_v3_large(weights=None) # Updated: weights=None
15
+
16
+ #num_ftrs = main_model.fc.in_features
17
  # main_model.fc = nn.Linear(num_ftrs, 2) # 2 classes: AI-generated_Image, Real_Image
18
+
19
+ num_ftrs = main_model.classifier[3].in_features
20
+ main_model.classifier[3] = nn.Linear(num_ftrs, 2)
21
+
22
+ # main_model.fc = nn.Sequential(
23
+ # nn.Dropout(p=0.5), # Match the training architecture
24
+ # nn.Linear(num_ftrs, 2) # 2 classes: AI-generated Image, Real Image
25
+ # )
26
 
27
  main_model.load_state_dict(torch.load('best_model6for_RESNET18_After_First_Half_training_part.pth', map_location=device, weights_only=True)) # Updated: weights_only=True
28
  main_model = main_model.to(device)