Ahmed-El-Sharkawy commited on
Commit
a131fad
·
1 Parent(s): 4481824

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -51
app.py CHANGED
@@ -5,77 +5,68 @@ import torch.nn as nn
5
  import torchvision.transforms as transforms
6
  import torchvision.models as models
7
  import os
 
8
 
9
  # Set device
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
- # Define available models and their corresponding file names
13
- model_paths = {
14
- "ResNet-18": "best_model9.pth",
15
- "MobileNetV3 Large": "best_model3_mobilenetv3_large.pth",
16
- "MobileNetV3 Small": "best_model13_mobilenetv3_small.pth"
17
- }
18
-
19
- def initialize_model(model_name):
20
- if model_name == "ResNet-18":
21
- model = models.resnet18(weights=None)
22
- num_ftrs = model.fc.in_features
23
- model.fc = nn.Sequential(
24
- nn.Dropout(p=0.5),
25
- nn.Linear(num_ftrs, 2)
26
- )
27
- elif model_name == "MobileNetV3 Large":
28
- model = models.mobilenet_v3_large(weights=None)
29
- num_ftrs = model.classifier[-1].in_features
30
- model.classifier[-1] = nn.Linear(num_ftrs, 2)
31
- elif model_name == "MobileNetV3 Small":
32
- model = models.mobilenet_v3_small(weights=None)
33
- num_ftrs = model.classifier[3].in_features
34
- model.classifier[3] = nn.Linear(num_ftrs, 2)
35
- else:
36
- raise ValueError("Invalid model name")
37
-
38
- return model
39
 
40
- def load_model(model_name):
41
- model = initialize_model(model_name)
42
- model_path = model_paths[model_name]
43
- model.load_state_dict(torch.load(model_path, map_location=device))
44
- model = model.to(device)
45
- model.eval()
46
- return model
47
 
 
48
  classes_name = ['AI-generated Image', 'Real Image']
49
 
50
- # Define preprocessing transformations
 
 
 
 
 
 
 
 
 
51
  preprocess = transforms.Compose([
52
- transforms.Lambda(lambda img: img.convert('RGB') if img.mode in ('P', 'RGBA') else img),
53
- transforms.Resize((224, 224)),
54
  transforms.ToTensor(),
55
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
56
  ])
57
 
58
- def classify_image(image, model_name):
59
- if image is None:
60
- return "Please upload an image."
61
-
62
- model = load_model(model_name)
63
  image = Image.fromarray(image)
 
 
64
  input_image = preprocess(image).unsqueeze(0).to(device)
65
 
 
66
  with torch.no_grad():
67
- output = model(input_image)
68
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
69
  confidence, predicted_class = torch.max(probabilities, 0)
70
 
71
- return f"Image is: {classes_name[predicted_class]} (Confidence: {confidence.item():.4f})"
 
 
 
 
72
 
73
- # Gradio interface
74
- image_input = gr.Image(image_mode="RGB")
75
- model_choice = gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model", value="ResNet-18")
76
  output_text = gr.Textbox()
77
 
78
- gr.Interface(fn=classify_image, inputs=[image_input, model_choice], outputs=[output_text],
79
- title="AI-Generated Image Detector",
80
- description="Upload an image and choose a model to detect AI-generated images.",
81
  theme="default").launch()
 
5
  import torchvision.transforms as transforms
6
  import torchvision.models as models
7
  import os
8
+ import torch
9
 
10
  # Set device
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
+ # Load the main classifier (Detector_best_model.pth)
14
+ main_model = models.resnet18(weights=None) # Updated: weights=None
15
+ num_ftrs = main_model.fc.in_features
16
+ # main_model.fc = nn.Linear(num_ftrs, 2) # 2 classes: AI-generated_Image, Real_Image
17
+ main_model.fc = nn.Sequential(
18
+ nn.Dropout(p=0.5), # Match the training architecture
19
+ nn.Linear(num_ftrs, 2) # 2 classes: AI-generated Image, Real Image
20
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ main_model.load_state_dict(torch.load('best_model9.pth', map_location=device, weights_only=True)) # Updated: weights_only=True
23
+ main_model = main_model.to(device)
24
+ main_model.eval()
 
 
 
 
25
 
26
+ # Define class names for the classifier based on the Folder structure
27
  classes_name = ['AI-generated Image', 'Real Image']
28
 
29
+ def convert_to_rgb(image):
30
+ """
31
+ Converts 'P' mode images with transparency to 'RGBA', and then to 'RGB'.
32
+ This is to avoid transparency issues during model training.
33
+ """
34
+ if image.mode in ('P', 'RGBA'):
35
+ return image.convert('RGB')
36
+ return image
37
+
38
+ # Define preprocessing transformations (same used during training)
39
  preprocess = transforms.Compose([
40
+ transforms.Lambda(convert_to_rgb),
41
+ transforms.Resize((224, 224)), # Resize here, no need for shape argument in gr.Image
42
  transforms.ToTensor(),
43
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization
44
  ])
45
 
46
+ def classify_image(image):
47
+ # Open the image using PIL
 
 
 
48
  image = Image.fromarray(image)
49
+
50
+ # Preprocess the image
51
  input_image = preprocess(image).unsqueeze(0).to(device)
52
 
53
+ # Perform inference with the main classifier
54
  with torch.no_grad():
55
+ output = main_model(input_image)
56
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
57
  confidence, predicted_class = torch.max(probabilities, 0)
58
 
59
+ # Main classifier result
60
+ main_prediction = classes_name[predicted_class]
61
+ main_confidence = confidence.item()
62
+
63
+ return f"Image is : {main_prediction} (Confidence: {main_confidence:.4f})"
64
 
65
+ # Gradio interface (updated)
66
+ image_input = gr.Image(image_mode="RGB") # Removed shape argument
 
67
  output_text = gr.Textbox()
68
 
69
+ gr.Interface(fn=classify_image, inputs=image_input, outputs=[output_text],
70
+ title="Detect AI-generated Image ",
71
+ description="Upload an image to Detected AI-generated Image .",
72
  theme="default").launch()