Update app.py
Browse files
app.py
CHANGED
@@ -4,13 +4,9 @@ from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassificat
|
|
4 |
from torchvision import transforms
|
5 |
import torch
|
6 |
from PIL import Image
|
7 |
-
import pandas as pd
|
8 |
-
import warnings
|
9 |
-
import math
|
10 |
import numpy as np
|
11 |
from utils.goat import call_inference
|
12 |
import io
|
13 |
-
import sys
|
14 |
|
15 |
# Suppress warnings
|
16 |
warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
|
@@ -157,11 +153,10 @@ def predict_image(img, confidence_threshold):
|
|
157 |
label_4 = f"Error: {str(e)}"
|
158 |
|
159 |
try:
|
160 |
-
|
161 |
-
response5_raw = call_inference(
|
162 |
-
|
163 |
-
response5
|
164 |
-
|
165 |
label_5 = f"Result: {response5}"
|
166 |
except Exception as e:
|
167 |
label_5 = f"Error: {str(e)}"
|
@@ -191,43 +186,37 @@ with gr.Blocks() as iface:
|
|
191 |
results_html = gr.HTML(label="Model Predictions")
|
192 |
outputs = [image_output, results_html]
|
193 |
|
194 |
-
gr.Button("Predict").click(fn=
|
195 |
|
196 |
-
# Define a function to generate the HTML content
|
197 |
def generate_results_html(results):
|
198 |
-
html_content = """
|
199 |
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
|
200 |
<div class="container">
|
201 |
<div class="row mt-4">
|
202 |
<div class="col">
|
203 |
<h5>SwinV2/detect</h5>
|
204 |
-
<p>{
|
205 |
</div>
|
206 |
<div class="col">
|
207 |
<h5>ViT/AI-vs-Real</h5>
|
208 |
-
<p>{
|
209 |
</div>
|
210 |
<div class="col">
|
211 |
<h5>Swin/SDXL</h5>
|
212 |
-
<p>{
|
213 |
</div>
|
214 |
<div class="col">
|
215 |
<h5>Swin/SDXL-FLUX</h5>
|
216 |
-
<p>{
|
217 |
</div>
|
218 |
<div class="col">
|
219 |
<h5>GOAT</h5>
|
220 |
-
<p>{GOAT}</p>
|
221 |
</div>
|
222 |
</div>
|
223 |
</div>
|
224 |
-
"""
|
225 |
-
SwinV2_detect=results.get("SwinV2/detect", "N/A"),
|
226 |
-
ViT_AI_vs_Real=results.get("ViT/AI-vs-Real", "N/A"),
|
227 |
-
Swin_SDXL=results.get("Swin/SDXL", "N/A"),
|
228 |
-
Swin_SDXL_FLUX=results.get("Swin/SDXL-FLUX", "N/A"),
|
229 |
-
GOAT=results.get("GOAT", "N/A")
|
230 |
-
)
|
231 |
return html_content
|
232 |
|
233 |
# Modify the predict_image function to return the HTML content
|
@@ -236,8 +225,5 @@ with gr.Blocks() as iface:
|
|
236 |
html_content = generate_results_html(results)
|
237 |
return img_pil, html_content
|
238 |
|
239 |
-
# Update the button click to use the new function
|
240 |
-
gr.Button("Predict").click(fn=predict_image_with_html, inputs=inputs, outputs=outputs)
|
241 |
-
|
242 |
# Launch the interface
|
243 |
iface.launch()
|
|
|
4 |
from torchvision import transforms
|
5 |
import torch
|
6 |
from PIL import Image
|
|
|
|
|
|
|
7 |
import numpy as np
|
8 |
from utils.goat import call_inference
|
9 |
import io
|
|
|
10 |
|
11 |
# Suppress warnings
|
12 |
warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
|
|
|
153 |
label_4 = f"Error: {str(e)}"
|
154 |
|
155 |
try:
|
156 |
+
img_bytes = convert_pil_to_bytes(img_pil)
|
157 |
+
response5_raw = call_inference(img_bytes)
|
158 |
+
response5 = response5_raw.json()
|
159 |
+
print(response5)
|
|
|
160 |
label_5 = f"Result: {response5}"
|
161 |
except Exception as e:
|
162 |
label_5 = f"Error: {str(e)}"
|
|
|
186 |
results_html = gr.HTML(label="Model Predictions")
|
187 |
outputs = [image_output, results_html]
|
188 |
|
189 |
+
gr.Button("Predict").click(fn=predict_image_with_html, inputs=inputs, outputs=outputs)
|
190 |
|
191 |
+
# Define a function to generate the HTML content
|
192 |
def generate_results_html(results):
|
193 |
+
html_content = f"""
|
194 |
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
|
195 |
<div class="container">
|
196 |
<div class="row mt-4">
|
197 |
<div class="col">
|
198 |
<h5>SwinV2/detect</h5>
|
199 |
+
<p>{results.get("SwinV2/detect", "N/A")}</p>
|
200 |
</div>
|
201 |
<div class="col">
|
202 |
<h5>ViT/AI-vs-Real</h5>
|
203 |
+
<p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
|
204 |
</div>
|
205 |
<div class="col">
|
206 |
<h5>Swin/SDXL</h5>
|
207 |
+
<p>{results.get("Swin/SDXL", "N/A")}</p>
|
208 |
</div>
|
209 |
<div class="col">
|
210 |
<h5>Swin/SDXL-FLUX</h5>
|
211 |
+
<p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
|
212 |
</div>
|
213 |
<div class="col">
|
214 |
<h5>GOAT</h5>
|
215 |
+
<p>{results.get("GOAT", "N/A")}</p>
|
216 |
</div>
|
217 |
</div>
|
218 |
</div>
|
219 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
return html_content
|
221 |
|
222 |
# Modify the predict_image function to return the HTML content
|
|
|
225 |
html_content = generate_results_html(results)
|
226 |
return img_pil, html_content
|
227 |
|
|
|
|
|
|
|
228 |
# Launch the interface
|
229 |
iface.launch()
|