# tb-ocr / app.py
# Source: Hugging Face Space "tb-ocr" by wjbmattingly, commit b8c6096 ("Update app.py", verified), 1.92 kB.
import os
import subprocess
import sys

import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
# Install FlashAttention-2 at startup; attn_implementation="flash_attention_2"
# below cannot load without it.
# - pip is run through the current interpreter so the package lands in the same
#   environment this app is running in.
# - FLASH_ATTENTION_SKIP_CUDA_BUILD is merged into the existing environment
#   rather than replacing it, so PATH/HOME/CUDA variables remain visible to pip.
#   Per flash-attn's setup.py, it presumably makes the installer skip compiling
#   the CUDA extension from source (confirm against the flash-attn version used).
# - check=True surfaces a failed install here instead of as a later model-load error.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=True,
)

model_id = "yifeihu/TB-OCR-preview-0.1"
# NOTE(review): DEVICE is not referenced elsewhere in this file. Model placement
# below is pinned to "cuda" because FlashAttention-2 needs a CUDA GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 4-bit loading goes through a BitsAndBytesConfig. Passing load_in_4bit directly
# to from_pretrained is deprecated in transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    trust_remote_code=True,  # the checkpoint ships its own modeling code
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
# num_crops is forwarded to the checkpoint's remote-code processor (presumably
# the number of image tiles per input image; confirm on the model card).
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    num_crops=16,
)
def phi_ocr(image):
    """Transcribe the text in ``image`` to markdown with the TB-OCR model.

    Args:
        image: The image to read. It is passed to ``processor`` as the only
            element of the image list (a PIL image when called from the
            Gradio interface).

    Returns:
        The decoded completion, with the prompt tokens removed and everything
        from the first ``"<image_end>"`` marker onward discarded.
    """
    question = "Convert the text to markdown format."
    # "<|image_1|>" is the placeholder for the first (and only) image in the
    # list passed to the processor below.
    prompt_message = [{
        "role": "user",
        "content": f"<|image_1|>\n{question}",
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )
    # Follow the model's actual placement instead of hard-coding a device string.
    inputs = processor(prompt, [image], return_tensors="pt").to(model.device)

    # Greedy decoding. No temperature: it is ignored when do_sample=False, and
    # setting it only makes transformers emit a validation warning.
    generation_args = {
        "max_new_tokens": 1024,
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args
    )
    # generate() returns prompt + completion; keep only the newly generated tokens.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Drop anything the model emits after its "<image_end>" marker.
    response = response.split("<image_end>")[0]
    return response
@spaces.GPU
def process_image(input_image):
    """Gradio handler: run ``phi_ocr`` on the uploaded image.

    ``@spaces.GPU`` requests a GPU for the duration of each call (Hugging Face
    ZeroGPU).

    Args:
        input_image: The image from the ``gr.Image(type="pil")`` input, or
            ``None`` if the user pressed Submit without uploading anything.

    Returns:
        The markdown text produced by ``phi_ocr``.

    Raises:
        gr.Error: if no image was provided. Gradio shows this to the user as a
            message instead of surfacing an opaque processor traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    return phi_ocr(input_image)
# Single-input UI: one PIL image in, the model's markdown transcription out as
# plain text. Display strings come first, wiring second.
iface = gr.Interface(
    title="OCR with [TB-OCR-preview-0.1](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)",
    description="Upload an image to extract and convert text to markdown format.",
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
# Start the Gradio server (runs on import as well as when executed directly).
iface.launch()