LPX55 commited on
Commit
2484926
·
1 Parent(s): 4378fd8

Enhance prediction output and result display in app.py

Browse files

- Modify prediction methods to generate structured output lists for each model
- Add model-specific output tracking with confidence scores and classification labels
- Update HTML results display to include model badges
- Adjust Gradio interface layout for better visualization
- Improve error handling and logging in prediction functions

Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +42 -18
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
app.py CHANGED
@@ -67,7 +67,8 @@ def predict_image(img, confidence_threshold):
67
  try:
68
  prediction_1 = clf_1(img_pil)
69
  result_1 = {pred['label']: pred['score'] for pred in prediction_1}
70
- print(result_1)
 
71
  # Ensure the result dictionary contains all class names
72
  for class_name in class_names_1:
73
  if class_name not in result_1:
@@ -75,18 +76,23 @@ def predict_image(img, confidence_threshold):
75
  # Check if either class meets the confidence threshold
76
  if result_1['artificial'] >= confidence_threshold:
77
  label_1 = f"AI, Confidence: {result_1['artificial']:.4f}"
 
78
  elif result_1['real'] >= confidence_threshold:
79
  label_1 = f"Real, Confidence: {result_1['real']:.4f}"
 
80
  else:
81
  label_1 = "Uncertain Classification"
 
 
82
  except Exception as e:
83
  label_1 = f"Error: {str(e)}"
84
-
85
  # Predict using the second model
86
  try:
87
  prediction_2 = clf_2(img_pil)
88
  result_2 = {pred['label']: pred['score'] for pred in prediction_2}
89
- print(result_2)
 
90
  # Ensure the result dictionary contains all class names
91
  for class_name in class_names_2:
92
  if class_name not in result_2:
@@ -94,10 +100,13 @@ def predict_image(img, confidence_threshold):
94
  # Check if either class meets the confidence threshold
95
  if result_2['AI Image'] >= confidence_threshold:
96
  label_2 = f"AI, Confidence: {result_2['AI Image']:.4f}"
 
97
  elif result_2['Real Image'] >= confidence_threshold:
98
  label_2 = f"Real, Confidence: {result_2['Real Image']:.4f}"
 
99
  else:
100
  label_2 = "Uncertain Classification"
 
101
  except Exception as e:
102
  label_2 = f"Error: {str(e)}"
103
 
@@ -109,10 +118,11 @@ def predict_image(img, confidence_threshold):
109
  logits_3 = outputs_3.logits
110
  probabilities_3 = softmax(logits_3.cpu().numpy()[0])
111
  result_3 = {
112
- labels_3[0]: float(probabilities_3[0]), # AI
113
- labels_3[1]: float(probabilities_3[1]) # Real
114
  }
115
- print(result_3)
 
116
  # Ensure the result dictionary contains all class names
117
  for class_name in labels_3:
118
  if class_name not in result_3:
@@ -120,10 +130,13 @@ def predict_image(img, confidence_threshold):
120
  # Check if either class meets the confidence threshold
121
  if result_3['AI'] >= confidence_threshold:
122
  label_3 = f"AI, Confidence: {result_3['AI']:.4f}"
 
123
  elif result_3['Real'] >= confidence_threshold:
124
  label_3 = f"Real, Confidence: {result_3['Real']:.4f}"
 
125
  else:
126
  label_3 = "Uncertain Classification"
 
127
  except Exception as e:
128
  label_3 = f"Error: {str(e)}"
129
 
@@ -135,9 +148,10 @@ def predict_image(img, confidence_threshold):
135
  logits_4 = outputs_4.logits
136
  probabilities_4 = softmax(logits_4.cpu().numpy()[0])
137
  result_4 = {
138
- labels_4[0]: float(probabilities_4[0]), # AI
139
- labels_4[1]: float(probabilities_4[1]) # Real
140
  }
 
141
  print(result_4)
142
  # Ensure the result dictionary contains all class names
143
  for class_name in labels_4:
@@ -146,19 +160,27 @@ def predict_image(img, confidence_threshold):
146
  # Check if either class meets the confidence threshold
147
  if result_4['AI'] >= confidence_threshold:
148
  label_4 = f"AI, Confidence: {result_4['AI']:.4f}"
 
149
  elif result_4['Real'] >= confidence_threshold:
150
  label_4 = f"Real, Confidence: {result_4['Real']:.4f}"
 
151
  else:
152
  label_4 = "Uncertain Classification"
 
153
  except Exception as e:
154
  label_4 = f"Error: {str(e)}"
155
 
156
  try:
 
157
  img_bytes = convert_pil_to_bytes(img_pil)
158
- response5_raw = call_inference(img_bytes)
159
- response5 = response5_raw.json()
 
 
 
160
  print(response5)
161
  label_5 = f"Result: {response5}"
 
162
  except Exception as e:
163
  label_5 = f"Error: {str(e)}"
164
 
@@ -170,32 +192,34 @@ def predict_image(img, confidence_threshold):
170
  "Swin/SDXL-FLUX": label_4,
171
  "GOAT": label_5
172
  }
173
- return img_pil, combined_results
 
174
 
175
  # Define a function to generate the HTML content
176
  def generate_results_html(results):
 
177
  html_content = f"""
178
  <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
179
  <div class="container">
180
  <div class="row mt-4">
181
  <div class="col">
182
- <h5>SwinV2/detect</h5>
183
  <p>{results.get("SwinV2/detect", "N/A")}</p>
184
  </div>
185
  <div class="col">
186
- <h5>ViT/AI-vs-Real</h5>
187
  <p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
188
  </div>
189
  <div class="col">
190
- <h5>Swin/SDXL</h5>
191
  <p>{results.get("Swin/SDXL", "N/A")}</p>
192
  </div>
193
  <div class="col">
194
- <h5>Swin/SDXL-FLUX</h5>
195
  <p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
196
  </div>
197
  <div class="col">
198
- <h5>GOAT</h5>
199
  <p>{results.get("GOAT", "N/A")}</p>
200
  </div>
201
  </div>
@@ -214,11 +238,11 @@ with gr.Blocks() as iface:
214
  gr.Markdown("# AI Generated Image Classification")
215
 
216
  with gr.Row():
217
- with gr.Column():
218
  image_input = gr.Image(label="Upload Image to Analyze", sources=['upload'], type='pil')
219
  confidence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
220
  inputs = [image_input, confidence_slider]
221
- with gr.Column():
222
  image_output = gr.Image(label="Processed Image")
223
  # Custom HTML component to display results in 5 columns
224
  results_html = gr.HTML(label="Model Predictions")
 
67
  try:
68
  prediction_1 = clf_1(img_pil)
69
  result_1 = {pred['label']: pred['score'] for pred in prediction_1}
70
+ result_1output = [1, result_1['real'], result_1['artificial']]
71
+ print(result_1output)
72
  # Ensure the result dictionary contains all class names
73
  for class_name in class_names_1:
74
  if class_name not in result_1:
 
76
  # Check if either class meets the confidence threshold
77
  if result_1['artificial'] >= confidence_threshold:
78
  label_1 = f"AI, Confidence: {result_1['artificial']:.4f}"
79
+ result_1output += ['AI']
80
  elif result_1['real'] >= confidence_threshold:
81
  label_1 = f"Real, Confidence: {result_1['real']:.4f}"
82
+ result_1output += ['REAL']
83
  else:
84
  label_1 = "Uncertain Classification"
85
+ result_1output += ['UNCERTAIN']
86
+
87
  except Exception as e:
88
  label_1 = f"Error: {str(e)}"
89
+ print(result_1output)
90
  # Predict using the second model
91
  try:
92
  prediction_2 = clf_2(img_pil)
93
  result_2 = {pred['label']: pred['score'] for pred in prediction_2}
94
+ result_2output = [2, result_2['Real Image'], result_2['AI Image']]
95
+ print(result_2output)
96
  # Ensure the result dictionary contains all class names
97
  for class_name in class_names_2:
98
  if class_name not in result_2:
 
100
  # Check if either class meets the confidence threshold
101
  if result_2['AI Image'] >= confidence_threshold:
102
  label_2 = f"AI, Confidence: {result_2['AI Image']:.4f}"
103
+ result_2output += ['AI']
104
  elif result_2['Real Image'] >= confidence_threshold:
105
  label_2 = f"Real, Confidence: {result_2['Real Image']:.4f}"
106
+ result_2output += ['REAL']
107
  else:
108
  label_2 = "Uncertain Classification"
109
+ result_2output += ['UNCERTAIN']
110
  except Exception as e:
111
  label_2 = f"Error: {str(e)}"
112
 
 
118
  logits_3 = outputs_3.logits
119
  probabilities_3 = softmax(logits_3.cpu().numpy()[0])
120
  result_3 = {
121
+ labels_3[1]: float(probabilities_3[1]), # Real
122
+ labels_3[0]: float(probabilities_3[0]) # AI
123
  }
124
+ result_3output = [3, float(probabilities_3[1]), float(probabilities_3[0])]
125
+ print(result_3output)
126
  # Ensure the result dictionary contains all class names
127
  for class_name in labels_3:
128
  if class_name not in result_3:
 
130
  # Check if either class meets the confidence threshold
131
  if result_3['AI'] >= confidence_threshold:
132
  label_3 = f"AI, Confidence: {result_3['AI']:.4f}"
133
+ result_3output += ['AI']
134
  elif result_3['Real'] >= confidence_threshold:
135
  label_3 = f"Real, Confidence: {result_3['Real']:.4f}"
136
+ result_3output += ['REAL']
137
  else:
138
  label_3 = "Uncertain Classification"
139
+ result_3output += ['UNCERTAIN']
140
  except Exception as e:
141
  label_3 = f"Error: {str(e)}"
142
 
 
148
  logits_4 = outputs_4.logits
149
  probabilities_4 = softmax(logits_4.cpu().numpy()[0])
150
  result_4 = {
151
+ labels_4[1]: float(probabilities_4[1]), # Real
152
+ labels_4[0]: float(probabilities_4[0]) # AI
153
  }
154
+ result_4output = [4, float(probabilities_4[1]), float(probabilities_4[0])]
155
  print(result_4)
156
  # Ensure the result dictionary contains all class names
157
  for class_name in labels_4:
 
160
  # Check if either class meets the confidence threshold
161
  if result_4['AI'] >= confidence_threshold:
162
  label_4 = f"AI, Confidence: {result_4['AI']:.4f}"
163
+ result_4output += ['AI']
164
  elif result_4['Real'] >= confidence_threshold:
165
  label_4 = f"Real, Confidence: {result_4['Real']:.4f}"
166
+ result_4output += ['REAL']
167
  else:
168
  label_4 = "Uncertain Classification"
169
+ result_4output += ['UNCERTAIN']
170
  except Exception as e:
171
  label_4 = f"Error: {str(e)}"
172
 
173
  try:
174
+ result_5output = [5, 0.0, 0.0, 'MAINTENANCE']
175
  img_bytes = convert_pil_to_bytes(img_pil)
176
+ # print(img)
177
+ # print(img_bytes)
178
+ response5_raw = call_inference(img)
179
+ print(response5_raw)
180
+ response5 = response5_raw
181
  print(response5)
182
  label_5 = f"Result: {response5}"
183
+
184
  except Exception as e:
185
  label_5 = f"Error: {str(e)}"
186
 
 
192
  "Swin/SDXL-FLUX": label_4,
193
  "GOAT": label_5
194
  }
195
+ combined_outputs = [ result_1output, result_2output, result_3output, result_4output, result_5output ]
196
+ return img_pil, combined_outputs
197
 
198
  # Define a function to generate the HTML content
199
  def generate_results_html(results):
200
+ print(results)
201
  html_content = f"""
202
  <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
203
  <div class="container">
204
  <div class="row mt-4">
205
  <div class="col">
206
+ <h5>SwinV2/detect <span class="badge badge-secondary">M1</span></h5>
207
  <p>{results.get("SwinV2/detect", "N/A")}</p>
208
  </div>
209
  <div class="col">
210
+ <h5>ViT/AI-vs-Real <span class="badge badge-secondary">M2</span></h5>
211
  <p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
212
  </div>
213
  <div class="col">
214
+ <h5>Swin/SDXL <span class="badge badge-secondary">M3</span></h5>
215
  <p>{results.get("Swin/SDXL", "N/A")}</p>
216
  </div>
217
  <div class="col">
218
+ <h5>Swin/SDXL-FLUX <span class="badge badge-secondary">M4</span></h5>
219
  <p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
220
  </div>
221
  <div class="col">
222
+ <h5>GOAT <span class="badge badge-secondary">M5</span></h5>
223
  <p>{results.get("GOAT", "N/A")}</p>
224
  </div>
225
  </div>
 
238
  gr.Markdown("# AI Generated Image Classification")
239
 
240
  with gr.Row():
241
+ with gr.Column(scale=2):
242
  image_input = gr.Image(label="Upload Image to Analyze", sources=['upload'], type='pil')
243
  confidence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
244
  inputs = [image_input, confidence_slider]
245
+ with gr.Column(scale=3):
246
  image_output = gr.Image(label="Processed Image")
247
  # Custom HTML component to display results in 5 columns
248
  results_html = gr.HTML(label="Model Predictions")