# Author: top001
# Last commit: "Update app.py" (42f5a78, verified)
import os
import re
from functools import lru_cache
from typing import List, Mapping, Tuple
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image
from huggingface_hub import hf_hub_download
import io
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
import uvicorn
# FastAPI application that hosts the JSON `/tagging` endpoint; the Gradio UI
# is mounted onto this same object at the end of the file.
app: FastAPI = FastAPI()
def _yield_tags_from_txt_file(txt_file: str):
with open(txt_file, 'r') as f:
for line in f:
if line:
yield line.strip()
@lru_cache()
def get_deepdanbooru_tags() -> List[str]:
    """Return the DeepDanbooru tag vocabulary, downloading it on first use.

    ``tags.txt`` is fetched from the ``chinoll/deepdanbooru`` Hub repository
    and parsed line by line. ``lru_cache`` memoises the result, so every call
    after the first returns the same list object, which callers must not mutate.
    """
    local_path = hf_hub_download(repo_id='chinoll/deepdanbooru', filename='tags.txt')
    vocabulary = [*_yield_tags_from_txt_file(local_path)]
    return vocabulary
@lru_cache()
def get_deepdanbooru_onnx() -> ort.InferenceSession:
    """Return an ONNX Runtime session for the DeepDanbooru model.

    ``deepdanbooru.onnx`` is downloaded from the ``chinoll/deepdanbooru`` Hub
    repository on the first call. ``lru_cache`` then keeps the session, so
    later calls reuse it instead of reloading the model.
    """
    model_path = hf_hub_download(repo_id='chinoll/deepdanbooru', filename='deepdanbooru.onnx')
    session = ort.InferenceSession(model_path)
    return session
def image_preprocess(image: Image.Image, size: int = 512) -> np.ndarray:
    """Letterbox *image* into a single-item float32 batch for the tagger model.

    Steps:

    1. Convert the image to RGB if needed.
    2. Resize it so its longer side is *size* pixels, keeping the aspect ratio.
    3. Scale pixel values to ``[0, 1]``.
    4. Zero-pad both sides of the shorter axis to a ``size x size`` square.

    :param image: source PIL image in any mode.
    :param size: side length of the square output; defaults to 512, the value
        this app has always used.
    :returns: ``float32`` array of shape ``(1, size, size, 3)`` (B x H x W x C).
    :raises ValueError: if *image* has zero width or height.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    o_width, o_height = image.size
    if o_width <= 0 or o_height <= 0:
        raise ValueError(f'Cannot preprocess an empty image of size {image.size!r}.')

    scale = float(size) / max(o_width, o_height)
    # Truncate as before, but never below 1 px. For very elongated images
    # (e.g. 1 x 1000) the short side would otherwise become 0, which
    # Image.resize rejects.
    f_width, f_height = (max(1, int(x * scale)) for x in (o_width, o_height))
    image = image.resize((f_width, f_height))
    data = np.asarray(image).astype(np.float32) / 255  # H x W x C

    # Centre the resized image. When the leftover is odd, the extra padding
    # pixel goes to the bottom / right edge.
    pad_top = (size - f_height) // 2
    pad_bottom = size - f_height - pad_top
    pad_left = (size - f_width) // 2
    pad_right = size - f_width - pad_left
    data = np.pad(
        data,
        ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
        mode='constant',
        constant_values=0.0,
    )
    assert data.shape == (size, size, 3), \
        f'Shape {(size, size, 3)!r} expected, but {data.shape!r} found.'
    return data.reshape((1, size, size, 3))  # B x H x W x C
# Matches '(', ')' and a backslash. When escaping is requested, each match is
# prefixed with a backslash so exported tags cannot be confused with the
# `(tag:score)` rank format.
RE_SPECIAL = re.compile(r'([()\\])')
def image_to_deepdanbooru_tags(
    image: Image.Image,
    threshold: float,
    use_spaces: bool,
    use_escape: bool,
    include_ranks: bool,
    score_descend: bool
) -> Tuple[str, Mapping[str, float]]:
    """Run DeepDanbooru on *image* and keep the tags scoring at least *threshold*.

    :param image: PIL image to tag.
    :param threshold: minimum score (inclusive) for a tag to be kept.
    :param use_spaces: in the exported text, replace ``_`` with a space.
    :param use_escape: in the exported text, backslash-escape ``(``, ``)`` and ``\\``.
    :param include_ranks: in the exported text, render each tag as
        ``(tag:score)`` with three decimals.
    :param score_descend: order the exported text by descending score (ties
        broken by tag name); otherwise keep vocabulary order.
    :returns: ``(text, scores)``. *text* is the comma-separated export.
        *scores* maps every kept tag (unformatted, in vocabulary order) to its
        score as a Python float.
    """
    vocabulary = get_deepdanbooru_tags()
    session = get_deepdanbooru_onnx()
    input_name = session.get_inputs()[0].name
    fetch = [node.name for node in session.get_outputs()]
    batch = image_preprocess(image)
    # First output tensor, first (only) batch row: one score per vocabulary entry.
    scores = session.run(fetch, {input_name: batch})[0][0]

    kept = {}
    for name, value in zip(vocabulary, scores):
        if value >= threshold:
            kept[name] = float(value)

    ordered = kept.items()
    if score_descend:
        ordered = sorted(ordered, key=lambda pair: (-pair[1], pair[0]))

    pieces = []
    for name, value in ordered:
        text = name.replace('_', ' ') if use_spaces else name
        if use_escape:
            text = RE_SPECIAL.sub(r'\\\1', text)
        if include_ranks:
            text = f"({text}:{value:.3f})"
        pieces.append(text)

    return ', '.join(pieces), kept
from typing import Optional
@app.post("/tagging")
async def tagging_endpoint(
    image: UploadFile = File(...),
    threshold: Optional[float] = Form(0.5)
):
    """Tag an uploaded image and respond with ``{"tags": [...]}``.

    Only tags scoring at least *threshold* are returned. They are ordered by
    descending score, with ties broken by tag name.

    :param image: multipart file field holding the image bytes.
    :param threshold: minimum confidence (form field). Falls back to 0.5 when
        omitted or null.
    :returns: JSON ``{"tags": [...]}``, or HTTP 400 with ``{"error": ...}`` when
        the upload cannot be decoded as an image.
    """
    # Local import keeps this block self-contained. FastAPI re-exports
    # Starlette's helper for running sync code on a worker thread.
    from fastapi.concurrency import run_in_threadpool

    # Optional[...] lets a client send an explicit null. Treat it as the
    # default instead of failing later on `score >= None`.
    if threshold is None:
        threshold = 0.5

    image_data = await image.read()
    try:
        pil_image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError):
        # PIL's UnidentifiedImageError and truncated-file errors are OSError
        # subclasses. Report a client error instead of an unhandled 500.
        return JSONResponse(
            status_code=400,
            content={"error": "Uploaded file could not be decoded as an image."},
        )

    # ONNX inference is synchronous and CPU-bound. Running it directly in this
    # coroutine would stall the event loop, and with it the Gradio UI served by
    # the same app, for the whole call.
    _, filtered_tags = await run_in_threadpool(
        image_to_deepdanbooru_tags,
        pil_image,
        threshold=threshold,
        use_spaces=False,
        use_escape=False,
        include_ranks=False,
        score_descend=True,
    )
    # `filtered_tags` is in vocabulary order; score_descend only sorts the text
    # output. Sort here so the highest-confidence tags come first.
    tags = sorted(filtered_tags, key=lambda tag: (-filtered_tags[tag], tag))
    return JSONResponse(content={"tags": tags})
def gradio_interface(
    image: Image.Image,
    threshold: float,
    use_spaces: bool,
    use_escape: bool,
    include_ranks: bool,
    score_descend: bool
):
    """Click handler for the Gradio "Tagging" button.

    Forwards all inputs to ``image_to_deepdanbooru_tags`` and returns its
    ``(exported_text, filtered_tags)`` pair. The UI routes these to the
    "Exported Text" text area and the "Tags" label, in that order.

    :raises gr.Error: when no image is provided, so the UI shows a readable
        message instead of a crash inside preprocessing.
    """
    # gr.Image delivers None when nothing has been uploaded. Without this guard,
    # image_preprocess would fail on `None.mode`.
    if image is None:
        raise gr.Error('Please upload an image first.')
    output_text, filtered_tags = image_to_deepdanbooru_tags(
        image, threshold, use_spaces, use_escape, include_ranks, score_descend
    )
    return output_text, filtered_tags
# Gradio UI, built as one row with two columns:
# - the first column holds the inputs (image, threshold slider, a row of
#   formatting checkboxes, and the submit button);
# - the second column holds tabs showing the results.
# Statement order and nesting inside the `with` blocks define the layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr_input_image = gr.Image(type='pil', label='Original Image')
            gr_threshold = gr.Slider(0.0, 1.0, 0.5, label='Tagging Confidence Threshold')
            with gr.Row():
                gr_space = gr.Checkbox(value=False, label='Use Space Instead Of _')
                gr_escape = gr.Checkbox(value=True, label='Use Text Escape')
                gr_confidence = gr.Checkbox(value=False, label='Keep Confidences')
                gr_order = gr.Checkbox(value=True, label='Descend By Confidence')
            gr_btn_submit = gr.Button(value='Tagging', variant='primary')
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Tags"):
                    # Shows the tag -> score mapping.
                    gr_tags = gr.Label(label='Tags')
                with gr.Tab("Exported Text"):
                    # Shows the comma-separated tag string.
                    gr_output_text = gr.TextArea(label='Exported Text')
    gr_btn_submit.click(
        gradio_interface,
        # Order must match gradio_interface's positional parameters:
        # (image, threshold, use_spaces, use_escape, include_ranks, score_descend).
        inputs=[gr_input_image, gr_threshold, gr_space, gr_escape, gr_confidence, gr_order],
        # Order must match gradio_interface's return value: (output_text, filtered_tags).
        outputs=[gr_output_text, gr_tags],
    )
# Serve the Gradio UI at the root of the FastAPI app. The `/tagging` route is
# registered on `app` before this mount. Starlette matches routes in
# registration order, so it presumably still takes precedence over the
# catch-all "/" mount.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)