LPX55 commited on
Commit
bd23e86
·
verified ·
1 Parent(s): 4439436

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -18
app.py CHANGED
@@ -30,11 +30,8 @@ clf_2 = pipeline("image-classification", model=model_2_path, device=device)
30
 
31
  # Load additional models
32
  models = ["Organika/sdxl-detector", "cmckinle/sdxl-flux-detector"]
33
-
34
- # Load the third and fourth models
35
  feature_extractor_3 = AutoFeatureExtractor.from_pretrained(models[0], device=device)
36
  model_3 = AutoModelForImageClassification.from_pretrained(models[0]).to(device)
37
-
38
  feature_extractor_4 = AutoFeatureExtractor.from_pretrained(models[1], device=device)
39
  model_4 = AutoModelForImageClassification.from_pretrained(models[1]).to(device)
40
 
@@ -56,7 +53,6 @@ def convert_pil_to_bytes(image, format='JPEG'):
56
 
57
  @spaces.GPU(duration=10)
58
  def predict_image(img, confidence_threshold):
59
-
60
  # Ensure the image is a PIL Image
61
  if not isinstance(img, Image.Image):
62
  raise ValueError(f"Expected a PIL Image, but got {type(img)}")
@@ -66,7 +62,7 @@ def predict_image(img, confidence_threshold):
66
  img_pil = img.convert('RGB')
67
  else:
68
  img_pil = img
69
-
70
  # Resize the image
71
  img_pil = transforms.Resize((256, 256))(img_pil)
72
 
@@ -79,7 +75,6 @@ def predict_image(img, confidence_threshold):
79
  for class_name in class_names_1:
80
  if class_name not in result_1:
81
  result_1[class_name] = 0.0
82
-
83
  # Check if either class meets the confidence threshold
84
  if result_1['artificial'] >= confidence_threshold:
85
  label_1 = f"AI, Confidence: {result_1['artificial']:.4f}"
@@ -99,7 +94,6 @@ def predict_image(img, confidence_threshold):
99
  for class_name in class_names_2:
100
  if class_name not in result_2:
101
  result_2[class_name] = 0.0
102
-
103
  # Check if either class meets the confidence threshold
104
  if result_2['AI Image'] >= confidence_threshold:
105
  label_2 = f"AI, Confidence: {result_2['AI Image']:.4f}"
@@ -117,7 +111,6 @@ def predict_image(img, confidence_threshold):
117
  outputs_3 = model_3(**inputs_3)
118
  logits_3 = outputs_3.logits
119
  probabilities_3 = softmax(logits_3.cpu().numpy()[0])
120
-
121
  result_3 = {
122
  labels_3[0]: float(probabilities_3[0]), # AI
123
  labels_3[1]: float(probabilities_3[1]) # Real
@@ -127,7 +120,6 @@ def predict_image(img, confidence_threshold):
127
  for class_name in labels_3:
128
  if class_name not in result_3:
129
  result_3[class_name] = 0.0
130
-
131
  # Check if either class meets the confidence threshold
132
  if result_3['AI'] >= confidence_threshold:
133
  label_3 = f"AI, Confidence: {result_3['AI']:.4f}"
@@ -145,7 +137,6 @@ def predict_image(img, confidence_threshold):
145
  outputs_4 = model_4(**inputs_4)
146
  logits_4 = outputs_4.logits
147
  probabilities_4 = softmax(logits_4.cpu().numpy()[0])
148
-
149
  result_4 = {
150
  labels_4[0]: float(probabilities_4[0]), # AI
151
  labels_4[1]: float(probabilities_4[1]) # Real
@@ -155,7 +146,6 @@ def predict_image(img, confidence_threshold):
155
  for class_name in labels_4:
156
  if class_name not in result_4:
157
  result_4[class_name] = 0.0
158
-
159
  # Check if either class meets the confidence threshold
160
  if result_4['AI'] >= confidence_threshold:
161
  label_4 = f"AI, Confidence: {result_4['AI']:.4f}"
@@ -165,16 +155,17 @@ def predict_image(img, confidence_threshold):
165
  label_4 = "Uncertain Classification"
166
  except Exception as e:
167
  label_4 = f"Error: {str(e)}"
168
-
169
  try:
170
  img_bytes = convert_pil_to_bytes(img_pil)
171
  response5_raw = call_inference(img_bytes)
172
- response5 = response5_raw.json()
173
- print(response5)
 
 
174
  except Exception as e:
175
  label_5 = f"Error: {str(e)}"
176
 
177
-
178
  # Combine results
179
  combined_results = {
180
  "SwinV2/detect": label_1,
@@ -183,7 +174,6 @@ def predict_image(img, confidence_threshold):
183
  "Swin/SDXL-FLUX": label_4,
184
  "GOAT": label_5
185
  }
186
-
187
  return img_pil, combined_results
188
 
189
  # Define the Gradio interface
@@ -197,10 +187,57 @@ with gr.Blocks() as iface:
197
  inputs = [image_input, confidence_slider]
198
  with gr.Column():
199
  image_output = gr.Image(label="Processed Image")
200
- label_output = gr.JSON(label="Model Predictions")
201
- outputs = [image_output, label_output]
 
202
 
203
  gr.Button("Predict").click(fn=predict_image, inputs=inputs, outputs=outputs)
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  # Launch the interface
206
  iface.launch()
 
30
 
31
  # Load additional models
32
  models = ["Organika/sdxl-detector", "cmckinle/sdxl-flux-detector"]
 
 
33
  feature_extractor_3 = AutoFeatureExtractor.from_pretrained(models[0], device=device)
34
  model_3 = AutoModelForImageClassification.from_pretrained(models[0]).to(device)
 
35
  feature_extractor_4 = AutoFeatureExtractor.from_pretrained(models[1], device=device)
36
  model_4 = AutoModelForImageClassification.from_pretrained(models[1]).to(device)
37
 
 
53
 
54
  @spaces.GPU(duration=10)
55
  def predict_image(img, confidence_threshold):
 
56
  # Ensure the image is a PIL Image
57
  if not isinstance(img, Image.Image):
58
  raise ValueError(f"Expected a PIL Image, but got {type(img)}")
 
62
  img_pil = img.convert('RGB')
63
  else:
64
  img_pil = img
65
+
66
  # Resize the image
67
  img_pil = transforms.Resize((256, 256))(img_pil)
68
 
 
75
  for class_name in class_names_1:
76
  if class_name not in result_1:
77
  result_1[class_name] = 0.0
 
78
  # Check if either class meets the confidence threshold
79
  if result_1['artificial'] >= confidence_threshold:
80
  label_1 = f"AI, Confidence: {result_1['artificial']:.4f}"
 
94
  for class_name in class_names_2:
95
  if class_name not in result_2:
96
  result_2[class_name] = 0.0
 
97
  # Check if either class meets the confidence threshold
98
  if result_2['AI Image'] >= confidence_threshold:
99
  label_2 = f"AI, Confidence: {result_2['AI Image']:.4f}"
 
111
  outputs_3 = model_3(**inputs_3)
112
  logits_3 = outputs_3.logits
113
  probabilities_3 = softmax(logits_3.cpu().numpy()[0])
 
114
  result_3 = {
115
  labels_3[0]: float(probabilities_3[0]), # AI
116
  labels_3[1]: float(probabilities_3[1]) # Real
 
120
  for class_name in labels_3:
121
  if class_name not in result_3:
122
  result_3[class_name] = 0.0
 
123
  # Check if either class meets the confidence threshold
124
  if result_3['AI'] >= confidence_threshold:
125
  label_3 = f"AI, Confidence: {result_3['AI']:.4f}"
 
137
  outputs_4 = model_4(**inputs_4)
138
  logits_4 = outputs_4.logits
139
  probabilities_4 = softmax(logits_4.cpu().numpy()[0])
 
140
  result_4 = {
141
  labels_4[0]: float(probabilities_4[0]), # AI
142
  labels_4[1]: float(probabilities_4[1]) # Real
 
146
  for class_name in labels_4:
147
  if class_name not in result_4:
148
  result_4[class_name] = 0.0
 
149
  # Check if either class meets the confidence threshold
150
  if result_4['AI'] >= confidence_threshold:
151
  label_4 = f"AI, Confidence: {result_4['AI']:.4f}"
 
155
  label_4 = "Uncertain Classification"
156
  except Exception as e:
157
  label_4 = f"Error: {str(e)}"
158
+
159
  try:
160
  img_bytes = convert_pil_to_bytes(img_pil)
161
  response5_raw = call_inference(img_bytes)
162
+ print(response5_raw)
163
+ response5 = response5_raw
164
+
165
+ label_5 = f"Result: {response5}"
166
  except Exception as e:
167
  label_5 = f"Error: {str(e)}"
168
 
 
169
  # Combine results
170
  combined_results = {
171
  "SwinV2/detect": label_1,
 
174
  "Swin/SDXL-FLUX": label_4,
175
  "GOAT": label_5
176
  }
 
177
  return img_pil, combined_results
178
 
179
  # Define the Gradio interface
 
187
  inputs = [image_input, confidence_slider]
188
  with gr.Column():
189
  image_output = gr.Image(label="Processed Image")
190
+ # Custom HTML component to display results in 5 columns
191
+ results_html = gr.HTML(label="Model Predictions")
192
+ outputs = [image_output, results_html]
193
 
194
  gr.Button("Predict").click(fn=predict_image, inputs=inputs, outputs=outputs)
195
 
196
+ # Define a function to generate the HTML content for the results
197
+ def generate_results_html(results):
198
+ html_content = """
199
+ <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
200
+ <div class="container">
201
+ <div class="row mt-4">
202
+ <div class="col">
203
+ <h5>SwinV2/detect</h5>
204
+ <p>{SwinV2_detect}</p>
205
+ </div>
206
+ <div class="col">
207
+ <h5>ViT/AI-vs-Real</h5>
208
+ <p>{ViT_AI_vs_Real}</p>
209
+ </div>
210
+ <div class="col">
211
+ <h5>Swin/SDXL</h5>
212
+ <p>{Swin_SDXL}</p>
213
+ </div>
214
+ <div class="col">
215
+ <h5>Swin/SDXL-FLUX</h5>
216
+ <p>{Swin_SDXL_FLUX}</p>
217
+ </div>
218
+ <div class="col">
219
+ <h5>GOAT</h5>
220
+ <p>{GOAT}</p>
221
+ </div>
222
+ </div>
223
+ </div>
224
+ """.format(
225
+ SwinV2_detect=results.get("SwinV2/detect", "N/A"),
226
+ ViT_AI_vs_Real=results.get("ViT/AI-vs-Real", "N/A"),
227
+ Swin_SDXL=results.get("Swin/SDXL", "N/A"),
228
+ Swin_SDXL_FLUX=results.get("Swin/SDXL-FLUX", "N/A"),
229
+ GOAT=results.get("GOAT", "N/A")
230
+ )
231
+ return html_content
232
+
233
+ # Modify the predict_image function to return the HTML content
234
+ def predict_image_with_html(img, confidence_threshold):
235
+ img_pil, results = predict_image(img, confidence_threshold)
236
+ html_content = generate_results_html(results)
237
+ return img_pil, html_content
238
+
239
+ # Update the button click to use the new function
240
+ gr.Button("Predict").click(fn=predict_image_with_html, inputs=inputs, outputs=outputs)
241
+
242
  # Launch the interface
243
  iface.launch()