# Manga-Panel-OCR / app.py
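#
# Gradio demo: grounded OCR for manga panels with a fine-tuned Qwen2.5-VL-3B.
# The model returns recognized text spans together with quadrilateral
# coordinates, which the app draws onto the panel image.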

# Install FlashAttention at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
# tells the flash-attn build to skip compiling CUDA kernels and use the
# prebuilt binary instead.
import os
import subprocess

subprocess.run(
    "pip install flash-attn --no-build-isolation",
    # Merge with the current environment rather than replacing it, so the
    # shell still sees PATH and the other variables pip needs.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import base64
import re
from io import BytesIO

import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
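
# The processor (tokenizer + image preprocessor) is loaded once at import
# time; the fine-tuned model itself is loaded inside the GPU function below.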
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")


def pil2base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


@spaces.GPU
@torch.inference_mode()
def inference_fn(
    image: Image.Image | None,
    # progress=gr.Progress(track_tqdm=True),
) -> tuple[str, Image.Image | None]:
    if image is None:
        gr.Warning("Please upload an image!", duration=10)
        return "Please upload an image!", None
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "yuki-imajuku/Qwen2.5-VL-3B-Instruct-FT-Manga109-OCR-Panel-Onomatopoeia",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    base64_image = pil2base64(image)
    messages = [
        {"role": "user", "content": [
            # pil2base64 produces PNG bytes, so label the data URI accordingly.
            {"type": "image", "image": f"data:image/png;base64,{base64_image}"},
            {"type": "text", "text": "With this image, please output the result of OCR with grounding."},
        ]},
    ]
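    # Render the chat template to a prompt string and let qwen_vl_utils
    # decode and resize the embedded image into the form the model expects.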
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
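    # Drop the prompt tokens so only the newly generated tokens are decoded.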
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
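    # Keep special tokens: the grounding markers (<|object_ref_start|>,
    # <|quad_start|>, ...) are special tokens and are parsed below.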
    raw_output = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )[0]
    print(raw_output)  # log the raw model output for debugging
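    # Draw on the resized image returned by process_vision_info: the model's
    # quad coordinates refer to this resolution, not the original upload.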
    result_image = image_inputs[0].copy()
    draw = ImageDraw.Draw(result_image)
    ocr_texts = []
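    # The fine-tuned model emits grounded spans of the form (hypothetical
    # example output):
    #   <|object_ref_start|>ドーン<|object_ref_end|><|quad_start|>12,34,96,34,96,80,12,80<|quad_end|>
    # i.e. the recognized text followed by four x,y corner points.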
    for ocr_text, ocr_quad in re.findall(
        r"<\|object_ref_start\|>(.+?)<\|object_ref_end\|><\|quad_start\|>([\d,]+)<\|quad_end\|>",
        raw_output,
    ):
        ocr_texts.append(f"{ocr_text} -> {ocr_quad}")
        quad = [int(x) for x in ocr_quad.split(",")]
        # Connect the four corners in order, closing back to the first point.
        for i in range(4):
            start_point = quad[i * 2:i * 2 + 2]
            end_point = quad[i * 2 + 2:i * 2 + 4] if i < 3 else quad[:2]
            draw.line(start_point + end_point, fill="red", width=4)
    ocr_texts_str = "\n".join(ocr_texts)
    return ocr_texts_str, result_image


with gr.Blocks() as demo:
    gr.Markdown("# Manga Panel OCR")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", image_mode="RGB", type="pil")
            input_button = gr.Button(value="Submit")
        with gr.Column():
            ocr_text = gr.Textbox(label="Result", lines=5)
            ocr_image = gr.Image(label="OCR Result", type="pil", show_label=False)
    input_button.click(
        fn=inference_fn,
        inputs=[input_image],
        outputs=[ocr_text, ocr_image],
    )
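    # Placeholder examples section; no sample panels are bundled with the
    # Space yet. To add some, list image paths here, e.g. (hypothetical):
    #   examples=[["examples/panel1.png"], ["examples/panel2.png"]]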
    ocr_examples = gr.Examples(
        examples=[],
        fn=inference_fn,
        inputs=[input_image],
        outputs=[ocr_text, ocr_image],
        cache_examples=False,
    )

demo.queue().launch()