Update app.py
Browse files
app.py
CHANGED
@@ -11,13 +11,18 @@ import torch
|
|
11 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
|
13 |
# Load the main classifier (Detector_best_model.pth)
|
14 |
-
main_model = models.
|
15 |
-
|
|
|
16 |
# main_model.fc = nn.Linear(num_ftrs, 2) # 2 classes: AI-generated_Image, Real_Image
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
|
22 |
main_model.load_state_dict(torch.load('best_model6for_RESNET18_After_First_Half_training_part.pth', map_location=device, weights_only=True)) # Updated: weights_only=True
|
23 |
main_model = main_model.to(device)
|
|
|
11 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
|
13 |
# Load the main classifier (Detector_best_model.pth)
|
14 |
+
main_model = models.mobilenet_v3_large(weights=None) # Updated: weights=None
|
15 |
+
|
16 |
+
#num_ftrs = main_model.fc.in_features
|
17 |
# main_model.fc = nn.Linear(num_ftrs, 2) # 2 classes: AI-generated_Image, Real_Image
|
18 |
+
|
19 |
+
num_ftrs = main_model.classifier[3].in_features
|
20 |
+
main_model.classifier[3] = nn.Linear(num_ftrs, 2)
|
21 |
+
|
22 |
+
# main_model.fc = nn.Sequential(
|
23 |
+
# nn.Dropout(p=0.5), # Match the training architecture
|
24 |
+
# nn.Linear(num_ftrs, 2) # 2 classes: AI-generated Image, Real Image
|
25 |
+
# )
|
26 |
|
27 |
main_model.load_state_dict(torch.load('best_model6for_RESNET18_After_First_Half_training_part.pth', map_location=device, weights_only=True)) # Updated: weights_only=True
|
28 |
main_model = main_model.to(device)
|