Ahmed-El-Sharkawy committed on
Commit
68f66aa
·
1 Parent(s): 533dac9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -43
app.py CHANGED
@@ -5,68 +5,66 @@ import torch.nn as nn
5
  import torchvision.transforms as transforms
6
  import torchvision.models as models
7
  import os
8
- import torch
9
 
10
  # Set device
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
- # Load the main classifier (Detector_best_model.pth)
14
- main_model = models.resnet18(weights=None) # Updated: weights=None
15
- num_ftrs = main_model.fc.in_features
16
- # main_model.fc = nn.Linear(num_ftrs, 2) # 2 classes: AI-generated_Image, Real_Image
17
- main_model.fc = nn.Sequential(
18
- nn.Dropout(p=0.5), # Match the training architecture
19
- nn.Linear(num_ftrs, 2) # 2 classes: AI-generated Image, Real Image
20
- )
21
-
22
- main_model.load_state_dict(torch.load('best_model9.pth', map_location=device, weights_only=True)) # Updated: weights_only=True
23
- main_model = main_model.to(device)
24
- main_model.eval()
25
 
26
- # Define class names for the classifier based on the Folder structure
27
  classes_name = ['AI-generated Image', 'Real Image']
28
 
29
- def convert_to_rgb(image):
30
- """
31
- Converts 'P' mode images with transparency to 'RGBA', and then to 'RGB'.
32
- This is to avoid transparency issues during model training.
33
- """
34
- if image.mode in ('P', 'RGBA'):
35
- return image.convert('RGB')
36
- return image
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Define preprocessing transformations (same used during training)
39
  preprocess = transforms.Compose([
40
- transforms.Lambda(convert_to_rgb),
41
- transforms.Resize((224, 224)), # Resize here, no need for shape argument in gr.Image
42
  transforms.ToTensor(),
43
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization
44
  ])
45
 
46
- def classify_image(image):
47
- # Open the image using PIL
48
- image = Image.fromarray(image)
49
 
50
- # Preprocess the image
51
  input_image = preprocess(image).unsqueeze(0).to(device)
52
 
53
- # Perform inference with the main classifier
54
  with torch.no_grad():
55
- output = main_model(input_image)
56
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
57
  confidence, predicted_class = torch.max(probabilities, 0)
58
 
59
- # Main classifier result
60
- main_prediction = classes_name[predicted_class]
61
- main_confidence = confidence.item()
62
-
63
- return f"Image is : {main_prediction} (Confidence: {main_confidence:.4f})"
64
 
65
- # Gradio interface (updated)
66
- image_input = gr.Image(image_mode="RGB") # Removed shape argument
 
67
  output_text = gr.Textbox()
68
 
69
- gr.Interface(fn=classify_image, inputs=image_input, outputs=[output_text],
70
- title="Detect AI-generated Image ",
71
- description="Upload an image to Detected AI-generated Image .",
72
- theme="default").launch()
 
5
  import torchvision.transforms as transforms
6
  import torchvision.models as models
7
  import os
 
8
 
9
  # Set device
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
# Registry of selectable architectures.
# Maps the display name shown in the UI to a pair of
# (torchvision model constructor, checkpoint file name).
model_options = {
    "ResNet-18": (
        models.resnet18,
        "resnet18_model.pth",
    ),
    "MobileNetV3 Large": (
        models.mobilenet_v3_large,
        "mobilenet_v3_large_model.pth",
    ),
    "MobileNetV3 Small": (
        models.mobilenet_v3_small,
        "mobilenet_v3_small_model.pth",
    ),
}

# Human-readable labels, looked up by the predicted class index.
classes_name = ['AI-generated Image', 'Real Image']
20
 
21
# Models already built in this process, keyed by display name, so that a
# checkpoint is read from disk once rather than on every classification.
_loaded_models = {}


def load_model(model_name):
    """Return the 2-class classifier registered under ``model_name``.

    The architecture is built without pretrained weights, its final layer
    is replaced by a 2-output head, and the checkpoint listed in
    ``model_options`` is loaded into it. The model is moved to ``device``
    and put in eval mode. Loaded models are cached, so repeated calls with
    the same name return the same instance.

    Args:
        model_name: A key of ``model_options`` (e.g. ``"ResNet-18"``).

    Returns:
        The ``nn.Module`` on ``device``, in eval mode.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_options``.
    """
    cached = _loaded_models.get(model_name)
    if cached is not None:
        return cached

    model_func, model_path = model_options[model_name]
    model = model_func(weights=None)  # architecture only; weights come from the checkpoint

    if "resnet" in model_name.lower():
        # ResNet head: dropout followed by the 2-class linear layer.
        num_ftrs = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(num_ftrs, 2),
        )
    else:  # For MobileNetV3
        # Only the last layer of the classifier stack is swapped.
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, 2)

    # weights_only=True avoids unpickling arbitrary objects from the
    # checkpoint file; a plain state_dict does not need more than that.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    _loaded_models[model_name] = model
    return model
40
 
41
def _ensure_rgb(img):
    """Return ``img`` as an RGB PIL image, converting from any other mode.

    Every non-RGB mode is converted (not just 'P' and 'RGBA') because the
    3-channel normalization below fails on grayscale, 'LA' or CMYK input.
    """
    return img if img.mode == 'RGB' else img.convert('RGB')


# Define preprocessing transformations: force RGB, resize to 224x224,
# convert to a tensor, then normalize with the ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Lambda(_ensure_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
48
 
49
def classify_image(image, model_name):
    """Classify an uploaded image as AI-generated or real.

    Args:
        image: The image as a NumPy array (as delivered by ``gr.Image``),
            or ``None`` when nothing was uploaded.
        model_name: Key of ``model_options`` selecting the classifier.

    Returns:
        A message with the predicted label and its softmax confidence, or
        a prompt to upload an image when ``image`` is ``None``.
    """
    # Gradio passes None when the user submits without an image;
    # Image.fromarray would raise on it.
    if image is None:
        return "Please upload an image first."

    model = load_model(model_name)

    pil_image = Image.fromarray(image)
    # unsqueeze(0) adds the batch dimension the model expects.
    input_image = preprocess(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_image)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_class = torch.max(probabilities, 0)

    # .item() yields a plain int for the list lookup.
    label = classes_name[predicted_class.item()]
    return f"Image is: {label} (Confidence: {confidence.item():.4f})"
 
 
 
 
61
 
62
# Gradio UI: an image upload and a model selector feed classify_image,
# whose result string is shown in a text box.
image_input = gr.Image(image_mode="RGB")
model_choice = gr.Radio(
    choices=list(model_options),
    label="Choose Model",
    value="ResNet-18",
)
output_text = gr.Textbox()

demo = gr.Interface(
    fn=classify_image,
    inputs=[image_input, model_choice],
    outputs=[output_text],
    title="AI-Generated Image Detector",
    description="Upload an image and choose a model to detect AI-generated images.",
    theme="default",
)
demo.launch()