Ahmed-El-Sharkawy commited on
Commit
bc6f5db
·
verified ·
1 Parent(s): 8a8cb96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -15
app.py CHANGED
@@ -10,34 +10,43 @@ import os
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
  # Define available models and their corresponding file names
13
- model_options = {
14
- "ResNet-18": (models.resnet18, "resnet18_model.pth"),
15
- "MobileNetV3 Large": (models.mobilenet_v3_large, "best_model3_mobilenetv3_large.pth"),
16
- "MobileNetV3 Small": (models.mobilenet_v3_small, "best_model13_mobilenetv3_small.pth")
17
  }
18
 
19
- classes_name = ['AI-generated Image', 'Real Image']
20
-
21
- def load_model(model_name):
22
- model_func, model_path = model_options[model_name]
23
- model = model_func(weights=None) # Load model without pretrained weights
24
-
25
- if "resnet" in model_name.lower():
26
  num_ftrs = model.fc.in_features
27
  model.fc = nn.Sequential(
28
  nn.Dropout(p=0.5),
29
  nn.Linear(num_ftrs, 2)
30
  )
31
- else: # For MobileNetV3
 
 
 
 
 
32
  num_ftrs = model.classifier[-1].in_features
33
  model.classifier[-1] = nn.Linear(num_ftrs, 2)
 
 
34
 
 
 
 
 
 
35
  model.load_state_dict(torch.load(model_path, map_location=device))
36
  model = model.to(device)
37
  model.eval()
38
-
39
  return model
40
 
 
 
41
  # Define preprocessing transformations
42
  preprocess = transforms.Compose([
43
  transforms.Lambda(lambda img: img.convert('RGB') if img.mode in ('P', 'RGBA') else img),
@@ -63,10 +72,10 @@ def classify_image(image, model_name):
63
 
64
  # Gradio interface
65
  image_input = gr.Image(image_mode="RGB")
66
- model_choice = gr.Dropdown(choices=list(model_options.keys()), label="Choose Model", value="ResNet-18")
67
  output_text = gr.Textbox()
68
 
69
  gr.Interface(fn=classify_image, inputs=[image_input, model_choice], outputs=[output_text],
70
  title="AI-Generated Image Detector",
71
  description="Upload an image and choose a model to detect AI-generated images.",
72
- theme="default").launch()
 
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
  # Define available models and their corresponding file names
13
+ model_paths = {
14
+ "ResNet-18": "resnet18_model.pth",
15
+ "MobileNetV3 Large": "best_model3_mobilenetv3_large.pth",
16
+ "MobileNetV3 Small": "best_model13_mobilenetv3_small.pth"
17
  }
18
 
19
+ def initialize_model(model_name):
20
+ if model_name == "ResNet-18":
21
+ model = models.resnet18(weights=None)
 
 
 
 
22
  num_ftrs = model.fc.in_features
23
  model.fc = nn.Sequential(
24
  nn.Dropout(p=0.5),
25
  nn.Linear(num_ftrs, 2)
26
  )
27
+ elif model_name == "MobileNetV3 Large":
28
+ model = models.mobilenet_v3_large(weights=None)
29
+ num_ftrs = model.classifier[-1].in_features
30
+ model.classifier[-1] = nn.Linear(num_ftrs, 2)
31
+ elif model_name == "MobileNetV3 Small":
32
+ model = models.mobilenet_v3_small(weights=None)
33
  num_ftrs = model.classifier[-1].in_features
34
  model.classifier[-1] = nn.Linear(num_ftrs, 2)
35
+ else:
36
+ raise ValueError("Invalid model name")
37
 
38
+ return model
39
+
40
+ def load_model(model_name):
41
+ model = initialize_model(model_name)
42
+ model_path = model_paths[model_name]
43
  model.load_state_dict(torch.load(model_path, map_location=device))
44
  model = model.to(device)
45
  model.eval()
 
46
  return model
47
 
48
+ classes_name = ['AI-generated Image', 'Real Image']
49
+
50
  # Define preprocessing transformations
51
  preprocess = transforms.Compose([
52
  transforms.Lambda(lambda img: img.convert('RGB') if img.mode in ('P', 'RGBA') else img),
 
72
 
73
  # Gradio interface
74
  image_input = gr.Image(image_mode="RGB")
75
+ model_choice = gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model", value="ResNet-18")
76
  output_text = gr.Textbox()
77
 
78
  gr.Interface(fn=classify_image, inputs=[image_input, model_choice], outputs=[output_text],
79
  title="AI-Generated Image Detector",
80
  description="Upload an image and choose a model to detect AI-generated images.",
81
+ theme="default").launch()