LPX55
commited on
Commit
·
2484926
1
Parent(s):
4378fd8
Enhance prediction output and result display in app.py
Browse files- Modify prediction methods to generate structured output lists for each model
- Add model-specific output tracking with confidence scores and classification labels
- Update HTML results display to include model badges
- Adjust Gradio interface layout for better visualization
- Improve error handling and logging in prediction functions
- .gitignore +1 -0
- app.py +42 -18
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.env
|
app.py
CHANGED
@@ -67,7 +67,8 @@ def predict_image(img, confidence_threshold):
|
|
67 |
try:
|
68 |
prediction_1 = clf_1(img_pil)
|
69 |
result_1 = {pred['label']: pred['score'] for pred in prediction_1}
|
70 |
-
|
|
|
71 |
# Ensure the result dictionary contains all class names
|
72 |
for class_name in class_names_1:
|
73 |
if class_name not in result_1:
|
@@ -75,18 +76,23 @@ def predict_image(img, confidence_threshold):
|
|
75 |
# Check if either class meets the confidence threshold
|
76 |
if result_1['artificial'] >= confidence_threshold:
|
77 |
label_1 = f"AI, Confidence: {result_1['artificial']:.4f}"
|
|
|
78 |
elif result_1['real'] >= confidence_threshold:
|
79 |
label_1 = f"Real, Confidence: {result_1['real']:.4f}"
|
|
|
80 |
else:
|
81 |
label_1 = "Uncertain Classification"
|
|
|
|
|
82 |
except Exception as e:
|
83 |
label_1 = f"Error: {str(e)}"
|
84 |
-
|
85 |
# Predict using the second model
|
86 |
try:
|
87 |
prediction_2 = clf_2(img_pil)
|
88 |
result_2 = {pred['label']: pred['score'] for pred in prediction_2}
|
89 |
-
|
|
|
90 |
# Ensure the result dictionary contains all class names
|
91 |
for class_name in class_names_2:
|
92 |
if class_name not in result_2:
|
@@ -94,10 +100,13 @@ def predict_image(img, confidence_threshold):
|
|
94 |
# Check if either class meets the confidence threshold
|
95 |
if result_2['AI Image'] >= confidence_threshold:
|
96 |
label_2 = f"AI, Confidence: {result_2['AI Image']:.4f}"
|
|
|
97 |
elif result_2['Real Image'] >= confidence_threshold:
|
98 |
label_2 = f"Real, Confidence: {result_2['Real Image']:.4f}"
|
|
|
99 |
else:
|
100 |
label_2 = "Uncertain Classification"
|
|
|
101 |
except Exception as e:
|
102 |
label_2 = f"Error: {str(e)}"
|
103 |
|
@@ -109,10 +118,11 @@ def predict_image(img, confidence_threshold):
|
|
109 |
logits_3 = outputs_3.logits
|
110 |
probabilities_3 = softmax(logits_3.cpu().numpy()[0])
|
111 |
result_3 = {
|
112 |
-
labels_3[
|
113 |
-
labels_3[
|
114 |
}
|
115 |
-
|
|
|
116 |
# Ensure the result dictionary contains all class names
|
117 |
for class_name in labels_3:
|
118 |
if class_name not in result_3:
|
@@ -120,10 +130,13 @@ def predict_image(img, confidence_threshold):
|
|
120 |
# Check if either class meets the confidence threshold
|
121 |
if result_3['AI'] >= confidence_threshold:
|
122 |
label_3 = f"AI, Confidence: {result_3['AI']:.4f}"
|
|
|
123 |
elif result_3['Real'] >= confidence_threshold:
|
124 |
label_3 = f"Real, Confidence: {result_3['Real']:.4f}"
|
|
|
125 |
else:
|
126 |
label_3 = "Uncertain Classification"
|
|
|
127 |
except Exception as e:
|
128 |
label_3 = f"Error: {str(e)}"
|
129 |
|
@@ -135,9 +148,10 @@ def predict_image(img, confidence_threshold):
|
|
135 |
logits_4 = outputs_4.logits
|
136 |
probabilities_4 = softmax(logits_4.cpu().numpy()[0])
|
137 |
result_4 = {
|
138 |
-
labels_4[
|
139 |
-
labels_4[
|
140 |
}
|
|
|
141 |
print(result_4)
|
142 |
# Ensure the result dictionary contains all class names
|
143 |
for class_name in labels_4:
|
@@ -146,19 +160,27 @@ def predict_image(img, confidence_threshold):
|
|
146 |
# Check if either class meets the confidence threshold
|
147 |
if result_4['AI'] >= confidence_threshold:
|
148 |
label_4 = f"AI, Confidence: {result_4['AI']:.4f}"
|
|
|
149 |
elif result_4['Real'] >= confidence_threshold:
|
150 |
label_4 = f"Real, Confidence: {result_4['Real']:.4f}"
|
|
|
151 |
else:
|
152 |
label_4 = "Uncertain Classification"
|
|
|
153 |
except Exception as e:
|
154 |
label_4 = f"Error: {str(e)}"
|
155 |
|
156 |
try:
|
|
|
157 |
img_bytes = convert_pil_to_bytes(img_pil)
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
160 |
print(response5)
|
161 |
label_5 = f"Result: {response5}"
|
|
|
162 |
except Exception as e:
|
163 |
label_5 = f"Error: {str(e)}"
|
164 |
|
@@ -170,32 +192,34 @@ def predict_image(img, confidence_threshold):
|
|
170 |
"Swin/SDXL-FLUX": label_4,
|
171 |
"GOAT": label_5
|
172 |
}
|
173 |
-
|
|
|
174 |
|
175 |
# Define a function to generate the HTML content
|
176 |
def generate_results_html(results):
|
|
|
177 |
html_content = f"""
|
178 |
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
|
179 |
<div class="container">
|
180 |
<div class="row mt-4">
|
181 |
<div class="col">
|
182 |
-
<h5>SwinV2/detect</h5>
|
183 |
<p>{results.get("SwinV2/detect", "N/A")}</p>
|
184 |
</div>
|
185 |
<div class="col">
|
186 |
-
<h5>ViT/AI-vs-Real</h5>
|
187 |
<p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
|
188 |
</div>
|
189 |
<div class="col">
|
190 |
-
<h5>Swin/SDXL</h5>
|
191 |
<p>{results.get("Swin/SDXL", "N/A")}</p>
|
192 |
</div>
|
193 |
<div class="col">
|
194 |
-
<h5>Swin/SDXL-FLUX</h5>
|
195 |
<p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
|
196 |
</div>
|
197 |
<div class="col">
|
198 |
-
<h5>GOAT</h5>
|
199 |
<p>{results.get("GOAT", "N/A")}</p>
|
200 |
</div>
|
201 |
</div>
|
@@ -214,11 +238,11 @@ with gr.Blocks() as iface:
|
|
214 |
gr.Markdown("# AI Generated Image Classification")
|
215 |
|
216 |
with gr.Row():
|
217 |
-
with gr.Column():
|
218 |
image_input = gr.Image(label="Upload Image to Analyze", sources=['upload'], type='pil')
|
219 |
confidence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
|
220 |
inputs = [image_input, confidence_slider]
|
221 |
-
with gr.Column():
|
222 |
image_output = gr.Image(label="Processed Image")
|
223 |
# Custom HTML component to display results in 5 columns
|
224 |
results_html = gr.HTML(label="Model Predictions")
|
|
|
67 |
try:
|
68 |
prediction_1 = clf_1(img_pil)
|
69 |
result_1 = {pred['label']: pred['score'] for pred in prediction_1}
|
70 |
+
result_1output = [1, result_1['real'], result_1['artificial']]
|
71 |
+
print(result_1output)
|
72 |
# Ensure the result dictionary contains all class names
|
73 |
for class_name in class_names_1:
|
74 |
if class_name not in result_1:
|
|
|
76 |
# Check if either class meets the confidence threshold
|
77 |
if result_1['artificial'] >= confidence_threshold:
|
78 |
label_1 = f"AI, Confidence: {result_1['artificial']:.4f}"
|
79 |
+
result_1output += ['AI']
|
80 |
elif result_1['real'] >= confidence_threshold:
|
81 |
label_1 = f"Real, Confidence: {result_1['real']:.4f}"
|
82 |
+
result_1output += ['REAL']
|
83 |
else:
|
84 |
label_1 = "Uncertain Classification"
|
85 |
+
result_1output += ['UNCERTAIN']
|
86 |
+
|
87 |
except Exception as e:
|
88 |
label_1 = f"Error: {str(e)}"
|
89 |
+
print(result_1output)
|
90 |
# Predict using the second model
|
91 |
try:
|
92 |
prediction_2 = clf_2(img_pil)
|
93 |
result_2 = {pred['label']: pred['score'] for pred in prediction_2}
|
94 |
+
result_2output = [2, result_2['Real Image'], result_2['AI Image']]
|
95 |
+
print(result_2output)
|
96 |
# Ensure the result dictionary contains all class names
|
97 |
for class_name in class_names_2:
|
98 |
if class_name not in result_2:
|
|
|
100 |
# Check if either class meets the confidence threshold
|
101 |
if result_2['AI Image'] >= confidence_threshold:
|
102 |
label_2 = f"AI, Confidence: {result_2['AI Image']:.4f}"
|
103 |
+
result_2output += ['AI']
|
104 |
elif result_2['Real Image'] >= confidence_threshold:
|
105 |
label_2 = f"Real, Confidence: {result_2['Real Image']:.4f}"
|
106 |
+
result_2output += ['REAL']
|
107 |
else:
|
108 |
label_2 = "Uncertain Classification"
|
109 |
+
result_2output += ['UNCERTAIN']
|
110 |
except Exception as e:
|
111 |
label_2 = f"Error: {str(e)}"
|
112 |
|
|
|
118 |
logits_3 = outputs_3.logits
|
119 |
probabilities_3 = softmax(logits_3.cpu().numpy()[0])
|
120 |
result_3 = {
|
121 |
+
labels_3[1]: float(probabilities_3[1]), # Real
|
122 |
+
labels_3[0]: float(probabilities_3[0]) # AI
|
123 |
}
|
124 |
+
result_3output = [3, float(probabilities_3[1]), float(probabilities_3[0])]
|
125 |
+
print(result_3output)
|
126 |
# Ensure the result dictionary contains all class names
|
127 |
for class_name in labels_3:
|
128 |
if class_name not in result_3:
|
|
|
130 |
# Check if either class meets the confidence threshold
|
131 |
if result_3['AI'] >= confidence_threshold:
|
132 |
label_3 = f"AI, Confidence: {result_3['AI']:.4f}"
|
133 |
+
result_3output += ['AI']
|
134 |
elif result_3['Real'] >= confidence_threshold:
|
135 |
label_3 = f"Real, Confidence: {result_3['Real']:.4f}"
|
136 |
+
result_3output += ['REAL']
|
137 |
else:
|
138 |
label_3 = "Uncertain Classification"
|
139 |
+
result_3output += ['UNCERTAIN']
|
140 |
except Exception as e:
|
141 |
label_3 = f"Error: {str(e)}"
|
142 |
|
|
|
148 |
logits_4 = outputs_4.logits
|
149 |
probabilities_4 = softmax(logits_4.cpu().numpy()[0])
|
150 |
result_4 = {
|
151 |
+
labels_4[1]: float(probabilities_4[1]), # Real
|
152 |
+
labels_4[0]: float(probabilities_4[0]) # AI
|
153 |
}
|
154 |
+
result_4output = [4, float(probabilities_4[1]), float(probabilities_4[0])]
|
155 |
print(result_4)
|
156 |
# Ensure the result dictionary contains all class names
|
157 |
for class_name in labels_4:
|
|
|
160 |
# Check if either class meets the confidence threshold
|
161 |
if result_4['AI'] >= confidence_threshold:
|
162 |
label_4 = f"AI, Confidence: {result_4['AI']:.4f}"
|
163 |
+
result_4output += ['AI']
|
164 |
elif result_4['Real'] >= confidence_threshold:
|
165 |
label_4 = f"Real, Confidence: {result_4['Real']:.4f}"
|
166 |
+
result_4output += ['REAL']
|
167 |
else:
|
168 |
label_4 = "Uncertain Classification"
|
169 |
+
result_4output += ['UNCERTAIN']
|
170 |
except Exception as e:
|
171 |
label_4 = f"Error: {str(e)}"
|
172 |
|
173 |
try:
|
174 |
+
result_5output = [5, 0.0, 0.0, 'MAINTENANCE']
|
175 |
img_bytes = convert_pil_to_bytes(img_pil)
|
176 |
+
# print(img)
|
177 |
+
# print(img_bytes)
|
178 |
+
response5_raw = call_inference(img)
|
179 |
+
print(response5_raw)
|
180 |
+
response5 = response5_raw
|
181 |
print(response5)
|
182 |
label_5 = f"Result: {response5}"
|
183 |
+
|
184 |
except Exception as e:
|
185 |
label_5 = f"Error: {str(e)}"
|
186 |
|
|
|
192 |
"Swin/SDXL-FLUX": label_4,
|
193 |
"GOAT": label_5
|
194 |
}
|
195 |
+
combined_outputs = [ result_1output, result_2output, result_3output, result_4output, result_5output ]
|
196 |
+
return img_pil, combined_outputs
|
197 |
|
198 |
# Define a function to generate the HTML content
|
199 |
def generate_results_html(results):
|
200 |
+
print(results)
|
201 |
html_content = f"""
|
202 |
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
|
203 |
<div class="container">
|
204 |
<div class="row mt-4">
|
205 |
<div class="col">
|
206 |
+
<h5>SwinV2/detect <span class="badge badge-secondary">M1</span></h5>
|
207 |
<p>{results.get("SwinV2/detect", "N/A")}</p>
|
208 |
</div>
|
209 |
<div class="col">
|
210 |
+
<h5>ViT/AI-vs-Real <span class="badge badge-secondary">M2</span></h5>
|
211 |
<p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
|
212 |
</div>
|
213 |
<div class="col">
|
214 |
+
<h5>Swin/SDXL <span class="badge badge-secondary">M3</span></h5>
|
215 |
<p>{results.get("Swin/SDXL", "N/A")}</p>
|
216 |
</div>
|
217 |
<div class="col">
|
218 |
+
<h5>Swin/SDXL-FLUX <span class="badge badge-secondary">M4</span></h5>
|
219 |
<p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
|
220 |
</div>
|
221 |
<div class="col">
|
222 |
+
<h5>GOAT <span class="badge badge-secondary">M5</span></h5>
|
223 |
<p>{results.get("GOAT", "N/A")}</p>
|
224 |
</div>
|
225 |
</div>
|
|
|
238 |
gr.Markdown("# AI Generated Image Classification")
|
239 |
|
240 |
with gr.Row():
|
241 |
+
with gr.Column(scale=2):
|
242 |
image_input = gr.Image(label="Upload Image to Analyze", sources=['upload'], type='pil')
|
243 |
confidence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
|
244 |
inputs = [image_input, confidence_slider]
|
245 |
+
with gr.Column(scale=3):
|
246 |
image_output = gr.Image(label="Processed Image")
|
247 |
# Custom HTML component to display results in 5 columns
|
248 |
results_html = gr.HTML(label="Model Predictions")
|