Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,10 @@ from PIL import Image
|
|
7 |
import numpy as np
|
8 |
from utils.goat import call_inference
|
9 |
import io
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Ensure using GPU if available
|
12 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -168,6 +172,43 @@ def predict_image(img, confidence_threshold):
|
|
168 |
}
|
169 |
return img_pil, combined_results
|
170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
# Define the Gradio interface
|
172 |
with gr.Blocks() as iface:
|
173 |
gr.Markdown("# AI Generated Image Classification")
|
@@ -185,42 +226,5 @@ with gr.Blocks() as iface:
|
|
185 |
|
186 |
gr.Button("Predict").click(fn=predict_image_with_html, inputs=inputs, outputs=outputs)
|
187 |
|
188 |
-
# Define a function to generate the HTML content
|
189 |
-
def generate_results_html(results):
|
190 |
-
html_content = f"""
|
191 |
-
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
|
192 |
-
<div class="container">
|
193 |
-
<div class="row mt-4">
|
194 |
-
<div class="col">
|
195 |
-
<h5>SwinV2/detect</h5>
|
196 |
-
<p>{results.get("SwinV2/detect", "N/A")}</p>
|
197 |
-
</div>
|
198 |
-
<div class="col">
|
199 |
-
<h5>ViT/AI-vs-Real</h5>
|
200 |
-
<p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
|
201 |
-
</div>
|
202 |
-
<div class="col">
|
203 |
-
<h5>Swin/SDXL</h5>
|
204 |
-
<p>{results.get("Swin/SDXL", "N/A")}</p>
|
205 |
-
</div>
|
206 |
-
<div class="col">
|
207 |
-
<h5>Swin/SDXL-FLUX</h5>
|
208 |
-
<p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
|
209 |
-
</div>
|
210 |
-
<div class="col">
|
211 |
-
<h5>GOAT</h5>
|
212 |
-
<p>{results.get("GOAT", "N/A")}</p>
|
213 |
-
</div>
|
214 |
-
</div>
|
215 |
-
</div>
|
216 |
-
"""
|
217 |
-
return html_content
|
218 |
-
|
219 |
-
# Modify the predict_image function to return the HTML content
|
220 |
-
def predict_image_with_html(img, confidence_threshold):
|
221 |
-
img_pil, results = predict_image(img, confidence_threshold)
|
222 |
-
html_content = generate_results_html(results)
|
223 |
-
return img_pil, html_content
|
224 |
-
|
225 |
# Launch the interface
|
226 |
iface.launch()
|
|
|
7 |
import numpy as np
|
8 |
from utils.goat import call_inference
|
9 |
import io
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
# Suppress warnings
|
13 |
+
warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
|
14 |
|
15 |
# Ensure using GPU if available
|
16 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
172 |
}
|
173 |
return img_pil, combined_results
|
174 |
|
175 |
+
# Define a function to generate the HTML content
|
176 |
+
def generate_results_html(results):
|
177 |
+
html_content = f"""
|
178 |
+
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
|
179 |
+
<div class="container">
|
180 |
+
<div class="row mt-4">
|
181 |
+
<div class="col">
|
182 |
+
<h5>SwinV2/detect</h5>
|
183 |
+
<p>{results.get("SwinV2/detect", "N/A")}</p>
|
184 |
+
</div>
|
185 |
+
<div class="col">
|
186 |
+
<h5>ViT/AI-vs-Real</h5>
|
187 |
+
<p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
|
188 |
+
</div>
|
189 |
+
<div class="col">
|
190 |
+
<h5>Swin/SDXL</h5>
|
191 |
+
<p>{results.get("Swin/SDXL", "N/A")}</p>
|
192 |
+
</div>
|
193 |
+
<div class="col">
|
194 |
+
<h5>Swin/SDXL-FLUX</h5>
|
195 |
+
<p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
|
196 |
+
</div>
|
197 |
+
<div class="col">
|
198 |
+
<h5>GOAT</h5>
|
199 |
+
<p>{results.get("GOAT", "N/A")}</p>
|
200 |
+
</div>
|
201 |
+
</div>
|
202 |
+
</div>
|
203 |
+
"""
|
204 |
+
return html_content
|
205 |
+
|
206 |
+
# Modify the predict_image function to return the HTML content
|
207 |
+
def predict_image_with_html(img, confidence_threshold):
|
208 |
+
img_pil, results = predict_image(img, confidence_threshold)
|
209 |
+
html_content = generate_results_html(results)
|
210 |
+
return img_pil, html_content
|
211 |
+
|
212 |
# Define the Gradio interface
|
213 |
with gr.Blocks() as iface:
|
214 |
gr.Markdown("# AI Generated Image Classification")
|
|
|
226 |
|
227 |
gr.Button("Predict").click(fn=predict_image_with_html, inputs=inputs, outputs=outputs)
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
# Launch the interface
|
230 |
iface.launch()
|