# DocScope OCR Space (Hugging Face Spaces, running on ZeroGPU)

import time
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image

# Helper Functions
def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """
    Return an HTML snippet for a thin animated progress bar with a label.
    Colors can be customized; the defaults match the purple theme used here.
    """
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''
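
# Example usage (this is how model_inference yields it below):
#   yield progress_bar_html("Processing with DocScopeOCR")
# gr.ChatInterface renders the returned HTML string directly in the chat.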

# Model and Processor Setup
# Qwen2-VL OCR: a 2B instruction-tuned OCR model, loaded in float16.
QV_MODEL_ID = "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QV_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda").eval()

# DocScopeOCR: a larger 7B Qwen2.5-VL based model, loaded in bfloat16.
DOCSCOPEOCR_MODEL_ID = "prithivMLmods/docscopeOCR-7B-050425-exp"
docscopeocr_processor = AutoProcessor.from_pretrained(DOCSCOPEOCR_MODEL_ID, trust_remote_code=True)
docscopeocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    DOCSCOPEOCR_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

# Main Inference Function
@spaces.GPU  # ZeroGPU: attach a GPU for this call (restores the otherwise-unused `spaces` import)
def model_inference(message, history, use_docscopeocr):
    text = message["text"].strip()
    files = message.get("files", [])
    if not text and not files:
        yield "Error: Please input a text query or provide image files."
        return

    # Process files: images only
    image_list = []
    for idx, file in enumerate(files):
        try:
            img = load_image(file)
            label = f"Image {idx+1}:"
            image_list.append((label, img))
        except Exception as e:
            yield f"Error loading image: {str(e)}"
            return

    # Build the multimodal content list: the query text first, then each
    # image preceded by its "Image N:" label.
    content = [{"type": "text", "text": text}]
    for label, img in image_list:
        content.append({"type": "text", "text": label})
        content.append({"type": "image", "image": img})
    messages = [{"role": "user", "content": content}]

    # Select processor and model based on the checkbox state
    if use_docscopeocr:
        processor = docscopeocr_processor
        model = docscopeocr_model
        model_name = "DocScopeOCR"
    else:
        processor = qwen_processor
        model = qwen_model
        model_name = "Qwen2VL OCR"
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    all_images = [item["image"] for item in content if item["type"] == "image"]
    inputs = processor(
        text=[prompt_full],
        images=all_images if all_images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    yield progress_bar_html(f"Processing with {model_name}")
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer
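
# For reference, gr.ChatInterface(multimodal=True) passes model_inference a
# dict like the following (the file path is illustrative):
#   message = {"text": "OCR the text in the image", "files": ["/tmp/gradio/image1.jpg"]}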

# Gradio Interface
examples = [
    [{"text": "OCR the text in the image", "files": ["example/image1.jpg"]}],
    [{"text": "Describe the content of the image", "files": ["example/image2.jpg"]}],
    [{"text": "Extract the image content", "files": ["example/image3.jpg"]}],
]
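# The example images are assumed to ship with the Space under example/.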

demo = gr.ChatInterface(
    fn=model_inference,
    description="# **DocScope OCR `VL/OCR`**",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image"],
        file_count="multiple",
        placeholder="Input your query and optionally upload image(s). Select the model using the checkbox.",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    theme="bethecloud/storj_theme",
    additional_inputs=[
        gr.Checkbox(label="Use DocScopeOCR", value=True, info="Check to use DocScopeOCR, uncheck to use Qwen2VL OCR")
    ],
)
demo.launch(debug=True, ssr_mode=False)
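
# Likely runtime dependencies (not pinned in the original): gradio, spaces,
# torch, transformers, and pillow. Outside a ZeroGPU Space, @spaces.GPU is a
# no-op and a local CUDA device is required.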