wjbmattingly commited on
Commit
ad1dec5
·
verified ·
1 Parent(s): c499b78

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoProcessor
3
+ from PIL import Image
4
+ import requests
5
+ import gradio as gr
6
+ import spaces
7
+
8
+ model_id = "yifeihu/TB-OCR-preview-0.1"
9
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_id,
12
+ device_map="cuda",
13
+ trust_remote_code=True,
14
+ torch_dtype="auto",
15
+ attn_implementation='flash_attention_2',
16
+ load_in_4bit=True
17
+ )
18
+ processor = AutoProcessor.from_pretrained(model_id,
19
+ trust_remote_code=True,
20
+ num_crops=16
21
+ )
22
+
23
+ @spaces.GPU
24
+ def phi_ocr(image):
25
+ question = "Convert the text to markdown format."
26
+ prompt_message = [{
27
+ 'role': 'user',
28
+ 'content': f'<|image_1|>\n{question}',
29
+ }]
30
+ prompt = processor.tokenizer.apply_chat_template(prompt_message, tokenize=False, add_generation_prompt=True)
31
+ inputs = processor(prompt, [image], return_tensors="pt").to("cuda")
32
+ generation_args = {
33
+ "max_new_tokens": 1024,
34
+ "temperature": 0.1,
35
+ "do_sample": False
36
+ }
37
+ generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
38
+ generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
39
+ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
40
+ response = response.split("<image_end>")[0]
41
+ return response
42
+
43
+ def process_image(input_image):
44
+ return phi_ocr(input_image)
45
+
46
+ iface = gr.Interface(
47
+ fn=process_image,
48
+ inputs=gr.Image(type="pil"),
49
+ outputs="text",
50
+ title="OCR with Phi-3.5-vision-instruct",
51
+ description="Upload an image to extract and convert text to markdown format."
52
+ )
53
+
54
+ iface.launch()