sanjay9876 commited on
Commit
addd52c
·
verified ·
1 Parent(s): e697d1f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ import spaces
4
+
5
+ import requests
6
+ import copy
7
+
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import io
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+
13
+ import random
14
+ import numpy as np
15
+
16
+ import subprocess
17
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
18
+
19
+ models = {
20
+ 'microsoft/Florence-2-large-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True).to("cuda").eval(),
21
+ 'microsoft/Florence-2-large': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to("cuda").eval(),
22
+ 'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True).to("cuda").eval(),
23
+ 'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
24
+ }
25
+
26
+ processors = {
27
+ 'microsoft/Florence-2-large-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True),
28
+ 'microsoft/Florence-2-large': AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True),
29
+ 'microsoft/Florence-2-base-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True),
30
+ 'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
31
+ }
32
+
33
+
34
+ DESCRIPTION = "# [Handwriting Reader with 'Large Vision Model'](https://huggingface.co/microsoft/Florence-2-large)"
35
+
36
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
37
+ 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
38
+
39
+ def fig_to_pil(fig):
40
+ buf = io.BytesIO()
41
+ fig.savefig(buf, format='png')
42
+ buf.seek(0)
43
+ return Image.open(buf)
44
+
45
+ @spaces.GPU
46
+ def run_example(task_prompt, image, text_input=None, model_id='microsoft/Florence-2-large'):
47
+ model = models[model_id]
48
+ processor = processors[model_id]
49
+ if text_input is None:
50
+ prompt = task_prompt
51
+ else:
52
+ prompt = task_prompt + text_input
53
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
54
+ generated_ids = model.generate(
55
+ input_ids=inputs["input_ids"],
56
+ pixel_values=inputs["pixel_values"],
57
+ max_new_tokens=1024,
58
+ early_stopping=False,
59
+ do_sample=False,
60
+ num_beams=3,
61
+ )
62
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
63
+ parsed_answer = processor.post_process_generation(
64
+ generated_text,
65
+ task=task_prompt,
66
+ image_size=(image.width, image.height)
67
+ )
68
+ return parsed_answer
69
+
70
+ def plot_bbox(image, data):
71
+ fig, ax = plt.subplots()
72
+ ax.imshow(image)
73
+ for bbox, label in zip(data['bboxes'], data['labels']):
74
+ x1, y1, x2, y2 = bbox
75
+ rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
76
+ ax.add_patch(rect)
77
+ plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
78
+ ax.axis('off')
79
+ return fig
80
+
81
+ def draw_polygons(image, prediction, fill_mask=False):
82
+ draw = ImageDraw.Draw(image)
83
+ scale = 1
84
+ for polygons, label in zip(prediction['polygons'], prediction['labels']):
85
+ color = random.choice(colormap)
86
+ fill_color = random.choice(colormap) if fill_mask else None
87
+ for _polygon in polygons:
88
+ _polygon = np.array(_polygon).reshape(-1, 2)
89
+ if len(_polygon) < 3:
90
+ print('Invalid polygon:', _polygon)
91
+ continue
92
+ _polygon = (_polygon * scale).reshape(-1).tolist()
93
+ if fill_mask:
94
+ draw.polygon(_polygon, outline=color, fill=fill_color)
95
+ else:
96
+ draw.polygon(_polygon, outline=color)
97
+ draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
98
+ return image
99
+
100
+ def convert_to_od_format(data):
101
+ bboxes = data.get('bboxes', [])
102
+ labels = data.get('bboxes_labels', [])
103
+ od_results = {
104
+ 'bboxes': bboxes,
105
+ 'labels': labels
106
+ }
107
+ return od_results
108
+
109
+ def draw_ocr_bboxes(image, prediction):
110
+ scale = 1
111
+ draw = ImageDraw.Draw(image)
112
+ bboxes, labels = prediction['quad_boxes'], prediction['labels']
113
+ for box, label in zip(bboxes, labels):
114
+ color = random.choice(colormap)
115
+ new_box = (np.array(box) * scale).tolist()
116
+ draw.polygon(new_box, width=3, outline=color)
117
+ draw.text((new_box[0]+8, new_box[1]+2),
118
+ "{}".format(label),
119
+ align="right",
120
+ fill=color)
121
+ return image
122
+
123
+ def process_image(image, task_prompt, text_input=None, model_id='microsoft/Florence-2-large'):
124
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
125
+ if task_prompt == 'OCR':
126
+ task_prompt = '<OCR>'
127
+ results = run_example(task_prompt, image, model_id=model_id)
128
+ return results, None
129
+ elif task_prompt == 'OCR with Region':
130
+ task_prompt = '<OCR_WITH_REGION>'
131
+ results = run_example(task_prompt, image, model_id=model_id)
132
+ output_image = copy.deepcopy(image)
133
+ output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
134
+ return results, output_image
135
+ else:
136
+ return "", None # Return empty string and None for unknown task prompts
137
+
138
+ css = """
139
+ #output {
140
+ height: 500px;
141
+ overflow: auto;
142
+ border: 1px solid #ccc;
143
+ }
144
+ """
145
+
146
+
147
+ single_task_list =['OCR', 'OCR with Region']
148
+
149
+ cascased_task_list =[
150
+ 'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
151
+ ]
152
+
153
+
154
+ def update_task_dropdown(choice):
155
+ if choice == 'Cascased task':
156
+ return gr.Dropdown(choices=cascased_task_list, value='Caption + Grounding')
157
+ else:
158
+ return gr.Dropdown(choices=single_task_list, value='Caption')
159
+
160
+
161
+
162
+ with gr.Blocks(css=css) as demo:
163
+ gr.Markdown(DESCRIPTION)
164
+ with gr.Tab(label="Handwriting Reader with LVM"):
165
+ with gr.Row():
166
+ with gr.Column():
167
+ input_img = gr.Image(label="Input Picture")
168
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
169
+ task_type = gr.Radio(choices=['Single task', 'Cascased task'], label='Task type selector', value='Single task')
170
+ task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
171
+ task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
172
+ text_input = gr.Textbox(label="Text Input (optional)")
173
+ submit_btn = gr.Button(value="Submit")
174
+ with gr.Column():
175
+ output_text = gr.Textbox(label="Output Text")
176
+ output_img = gr.Image(label="Output Image")
177
+
178
+ gr.Examples(
179
+ examples=[
180
+ ["hw_pic_1.jpg", 'OCR'],
181
+ ["hw_pic_2.jpg", 'OCR'],
182
+ ["hw_pic_3.jpg", 'OCR'],
183
+ ["hw_pic_4.jpg", 'OCR'],
184
+ ["hw_pic_5.jpg", 'OCR']
185
+ ],
186
+ inputs=[input_img, task_prompt],
187
+ outputs=[output_text, output_img],
188
+ fn=process_image,
189
+ cache_examples=True,
190
+ label='Try examples'
191
+ )
192
+
193
+ submit_btn.click(process_image, [input_img, task_prompt, text_input, model_selector], [output_text, output_img])
194
+
195
+ demo.launch(debug=True)