Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
-
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification
|
4 |
from torchvision import transforms
|
5 |
import torch
|
6 |
from PIL import Image
|
|
|
7 |
import warnings
|
|
|
8 |
|
9 |
# Suppress warnings
|
10 |
warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
|
@@ -20,13 +22,24 @@ clf_1 = pipeline(model=model_1, task="image-classification", image_processor=ima
|
|
20 |
|
21 |
# Load the second model
|
22 |
model_2_path = "Heem2/AI-vs-Real-Image-Detection"
|
23 |
-
clf_2 = pipeline("image-classification", model=model_2_path
|
24 |
|
25 |
-
#
|
|
|
|
|
|
|
|
|
|
|
26 |
class_names_1 = ['artificial', 'real']
|
27 |
-
class_names_2 = ['AI Image', 'Real Image']
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
@spaces.GPU(duration=
|
30 |
def predict_image(img, confidence_threshold):
|
31 |
# Ensure the image is a PIL Image
|
32 |
if not isinstance(img, Image.Image):
|
@@ -81,10 +94,56 @@ def predict_image(img, confidence_threshold):
|
|
81 |
except Exception as e:
|
82 |
label_2 = f"Error: {str(e)}"
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
# Combine results
|
85 |
combined_results = {
|
86 |
"SwinV2": label_1,
|
87 |
-
"AI-vs-Real-Image-Detection": label_2
|
|
|
|
|
88 |
}
|
89 |
|
90 |
return combined_results
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
+
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification, AutoFeatureExtractor, AutoModelForImageClassification
|
4 |
from torchvision import transforms
|
5 |
import torch
|
6 |
from PIL import Image
|
7 |
+
import pandas as pd
|
8 |
import warnings
|
9 |
+
import math
|
10 |
|
11 |
# Suppress warnings
|
12 |
warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
|
|
|
22 |
|
23 |
# Load the second model
|
24 |
model_2_path = "Heem2/AI-vs-Real-Image-Detection"
|
25 |
+
clf_2 = pipeline("image-classification", model=model_2_path)
|
26 |
|
27 |
+
# Load additional models
|
28 |
+
models = ["Organika/sdxl-detector", "cmckinle/sdxl-flux-detector"]
|
29 |
+
pipe0 = pipeline("image-classification", model=models[0])
|
30 |
+
pipe1 = pipeline("image-classification", model=models[1])
|
31 |
+
|
32 |
+
# Define class names for all models
|
33 |
class_names_1 = ['artificial', 'real']
|
34 |
+
class_names_2 = ['AI Image', 'Real Image']
|
35 |
+
class_names_3 = ['AI', 'Real']
|
36 |
+
class_names_4 = ['AI', 'Real']
|
37 |
+
|
38 |
+
def softmax(vector):
|
39 |
+
e = math.exp(vector - vector.max()) # for numerical stability
|
40 |
+
return e / e.sum()
|
41 |
|
42 |
+
@spaces.GPU(duration=10)
|
43 |
def predict_image(img, confidence_threshold):
|
44 |
# Ensure the image is a PIL Image
|
45 |
if not isinstance(img, Image.Image):
|
|
|
94 |
except Exception as e:
|
95 |
label_2 = f"Error: {str(e)}"
|
96 |
|
97 |
+
# Predict using the third model
|
98 |
+
try:
|
99 |
+
prediction_3 = pipe0(img_pil)
|
100 |
+
result_3 = {}
|
101 |
+
for idx, result in enumerate(prediction_3):
|
102 |
+
result_3[class_names_3[idx]] = float(result['score'])
|
103 |
+
|
104 |
+
# Ensure the result dictionary contains all class names
|
105 |
+
for class_name in class_names_3:
|
106 |
+
if class_name not in result_3:
|
107 |
+
result_3[class_name] = 0.0
|
108 |
+
|
109 |
+
# Check if either class meets the confidence threshold
|
110 |
+
if result_3['AI'] >= confidence_threshold:
|
111 |
+
label_3 = f"Label: AI, Confidence: {result_3['AI']:.4f}"
|
112 |
+
elif result_3['Real'] >= confidence_threshold:
|
113 |
+
label_3 = f"Label: Real, Confidence: {result_3['Real']:.4f}"
|
114 |
+
else:
|
115 |
+
label_3 = "Uncertain Classification"
|
116 |
+
except Exception as e:
|
117 |
+
label_3 = f"Error: {str(e)}"
|
118 |
+
|
119 |
+
# Predict using the fourth model
|
120 |
+
try:
|
121 |
+
prediction_4 = pipe1(img_pil)
|
122 |
+
result_4 = {}
|
123 |
+
for idx, result in enumerate(prediction_4):
|
124 |
+
result_4[class_names_4[idx]] = float(result['score'])
|
125 |
+
|
126 |
+
# Ensure the result dictionary contains all class names
|
127 |
+
for class_name in class_names_4:
|
128 |
+
if class_name not in result_4:
|
129 |
+
result_4[class_name] = 0.0
|
130 |
+
|
131 |
+
# Check if either class meets the confidence threshold
|
132 |
+
if result_4['AI'] >= confidence_threshold:
|
133 |
+
label_4 = f"Label: AI, Confidence: {result_4['AI']:.4f}"
|
134 |
+
elif result_4['Real'] >= confidence_threshold:
|
135 |
+
label_4 = f"Label: Real, Confidence: {result_4['Real']:.4f}"
|
136 |
+
else:
|
137 |
+
label_4 = "Uncertain Classification"
|
138 |
+
except Exception as e:
|
139 |
+
label_4 = f"Error: {str(e)}"
|
140 |
+
|
141 |
# Combine results
|
142 |
combined_results = {
|
143 |
"SwinV2": label_1,
|
144 |
+
"AI-vs-Real-Image-Detection": label_2,
|
145 |
+
"Organika/sdxl-detector": label_3,
|
146 |
+
"cmckinle/sdxl-flux-detector": label_4
|
147 |
}
|
148 |
|
149 |
return combined_results
|