tagger / app.py
MonkeyJuice's picture
check ignore
c410097
raw
history blame
3.68 kB
#!/usr/bin/env python
from __future__ import annotations

import html
import zipfile

import gradio as gr
import PIL.Image

from genTag import genTag
from checkIgnore import is_ignore
def predict(image: PIL.Image.Image, score_threshold: float):
    """Tag a single image and render the results for the "Single" tab.

    Args:
        image: The uploaded image, or None when "Run" is clicked without one.
        score_threshold: Minimum score forwarded to ``genTag``.

    Returns:
        A ``(result_html, result_text)`` pair of HTML strings:
        ``result_html`` lists every tag returned by ``genTag`` with its score
        rounded to 3 places (tags rejected by ``is_ignore(label, 1)`` get
        class ``m5dd_list``, kept tags ``m5dd_list use``); ``result_text`` is a
        ``div#m5dd_result`` holding the comma-separated kept tags.
    """
    if image is None:
        # Nothing uploaded: show the same empty containers the UI starts with.
        return '<div></div>', '<div id="m5dd_result"></div>'

    result_threshold = genTag(image, score_threshold)
    rows = []
    kept = []
    for label, prob in result_threshold.items():
        ignored = is_ignore(label, 1)
        css_class = 'm5dd_list' if ignored else 'm5dd_list use'
        # Tag names can contain HTML metacharacters (e.g. ">_<", "<3");
        # escape them so they render as text instead of breaking the markup.
        text = html.escape(str(label))
        # str(round(...)) rather than an f-string: numpy scalar scores would
        # otherwise be formatted via float conversion and show extra digits.
        rows.append('<p class="' + css_class + '"><span>' + text
                    + '</span><span>' + str(round(prob, 3)) + '</span></p>')
        if not ignored:
            kept.append(text)
    result_html = '<div>' + ''.join(rows) + '</div>'
    result_text = '<div id="m5dd_result">' + ', '.join(kept) + '</div>'
    return result_html, result_text
def predict_batch(zip_file, score_threshold: float, progress=gr.Progress()):
    """Tag every image inside an uploaded ZIP archive for the "Batch" tab.

    Members whose name ends in .png, .jpg or .jpeg (case-insensitive) are
    tagged. macOS metadata entries (``__MACOSX/`` and ``._name`` files) are
    skipped: they carry image extensions but are not images.

    Args:
        zip_file: The archive from the ``gr.File`` component (anything
            ``zipfile.ZipFile`` accepts), or None when nothing was uploaded.
        score_threshold: Minimum score forwarded to ``genTag``.
        progress: Gradio progress tracker; advanced once per archive member.

    Returns:
        Plain text with two lines per tagged image: the member name, then its
        comma-separated tags that ``is_ignore(tag, 2)`` does not reject.
        An empty string when no archive was given.
    """
    if zip_file is None:
        return ''

    image_exts = ('.png', '.jpg', '.jpeg')
    chunks = []
    with zipfile.ZipFile(zip_file) as zf:
        for file in progress.tqdm(zf.namelist()):
            print(file)
            base = file.rsplit('/', 1)[-1]
            # AppleDouble resource-fork entries would make PIL raise and abort
            # the whole batch, so drop them up front.
            if file.startswith('__MACOSX/') or base.startswith('._'):
                continue
            if not base.lower().endswith(image_exts):
                continue
            # Close both the member stream and the PIL handle; convert()
            # returns a fully loaded copy, so it outlives them.
            with zf.open(file) as member, PIL.Image.open(member) as img:
                image = img.convert("RGB")
            result_threshold = genTag(image, score_threshold)
            tag = ', '.join(key for key in result_threshold
                            if not is_ignore(key, 2))
            chunks.append(str(file) + '\n' + str(tag) + '\n')
    return ''.join(chunks)
def _score_slider():
    """Create the score-threshold slider shared by both tabs (0-1, step 0.05, default 0.5)."""
    return gr.Slider(label='Score threshold',
                     minimum=0,
                     maximum=1,
                     step=0.05,
                     value=0.5)


# Two tabs: "Single" tags one uploaded/pasted image, "Batch" tags every image in a ZIP.
with gr.Blocks(css="style.css", js="script.js") as demo:
    with gr.Tab(label='Single'):
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(label='Upload an image',
                                 type='pil',
                                 sources=["upload", "clipboard"],
                                 height='20em')
                score_threshold = _score_slider()
                run_button = gr.Button('Run')
                # Receives predict()'s comma-separated list of kept tags.
                result_text = gr.HTML(value="<div></div>")
            with gr.Column(scale=2):
                # Receives predict()'s full tag/score listing.
                result_html = gr.HTML(value="<div></div>")
    with gr.Tab(label='Batch'):
        with gr.Row():
            with gr.Column(scale=1):
                batch_file = gr.File(label="Upload a ZIP file containing images",
                                     file_types=['.zip'],
                                     height='20em')
                score_threshold2 = _score_slider()
                run_button2 = gr.Button('Run')
            with gr.Column(scale=2):
                result_text2 = gr.Textbox(lines=5,
                                          label='Result',
                                          show_copy_button=True)

    # Event listeners must be attached inside the Blocks context.
    run_button.click(
        fn=predict,
        inputs=[image, score_threshold],
        outputs=[result_html, result_text],
        api_name='predict',
    )
    run_button2.click(
        fn=predict_batch,
        inputs=[batch_file, score_threshold2],
        outputs=[result_text2],
        api_name='predict_batch',
    )

# Launch only when run as a script, so importing this module (e.g. Gradio's
# reload mode, which looks up `demo`) does not start a second server.
# queue() enables the request queue that predict_batch's gr.Progress reports through.
if __name__ == "__main__":
    demo.queue().launch()