SixOpen commited on
Commit
664d48c
·
verified ·
1 Parent(s): a4e5a42

Update app.py

Browse files

Cleanup + cuda init update

Files changed (1) hide show
  1. app.py +46 -13
app.py CHANGED
@@ -21,16 +21,33 @@ def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
21
  imports.remove("flash_attn")
22
  return imports
23
 
24
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
-
26
- with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
27
- model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to(device).eval()
28
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
29
-
30
- colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
31
- 'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def run_example(task_prompt, image, text_input=None):
 
 
 
 
34
  prompt = task_prompt if text_input is None else task_prompt + text_input
35
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
36
  with torch.inference_mode():
@@ -38,6 +55,9 @@ def run_example(task_prompt, image, text_input=None):
38
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
39
  return processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.size[0], image.size[1]))
40
 
 
 
 
41
  def fig_to_pil(fig):
42
  buf = io.BytesIO()
43
  fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
@@ -85,7 +105,7 @@ def draw_ocr_bboxes(image, prediction):
85
  bboxes, labels = prediction['quad_boxes'], prediction['labels']
86
  for box, label in zip(bboxes, labels):
87
  color = random.choice(colormap)
88
- box_array = np.array(box).reshape(-1, 2) # respect format
89
  polygon = patches.Polygon(box_array, edgecolor=color, fill=False, linewidth=2)
90
  ax.add_patch(polygon)
91
  plt.text(box_array[0, 0], box_array[0, 1], label, color='white', fontsize=10, bbox=dict(facecolor=color, alpha=0.8))
@@ -101,7 +121,7 @@ def plot_bbox(image, data):
101
  draw.text((x1, y1), label, fill="white")
102
  return np.array(img_draw)
103
 
104
- @spaces.GPU(duration=130) #remains to be seen, increasing too much may leave people queueing for long
105
  def process_video(input_video_path, task_prompt):
106
  cap = cv2.VideoCapture(input_video_path)
107
  if not cap.isOpened():
@@ -118,7 +138,7 @@ def process_video(input_video_path, task_prompt):
118
 
119
  processed_frames = 0
120
  frame_results = []
121
- color_map = {} #consistency for chromakey possibility
122
 
123
  def get_color(label):
124
  if label not in color_map:
@@ -229,6 +249,10 @@ def process_video_p(input_video, task, text_input):
229
  return None, "Error: Video processing failed. Check logs above for info.", str(frame_results)
230
  return result, result, str(frame_results)
231
 
 
 
 
 
232
  with gr.Blocks() as demo:
233
  gr.HTML("<h1><center>Microsoft Florence-2-large-ft</center></h1>")
234
 
@@ -300,7 +324,16 @@ with gr.Blocks() as demo:
300
 
301
  video_task_dropdown.change(fn=update_video_text_input, inputs=video_task_dropdown, outputs=video_text_input)
302
 
303
- submit_btn.click(fn=process_image, inputs=[input_img, task_dropdown, text_input], outputs=[output_text, output_image])
304
- video_submit_btn.click(fn=process_video_p, inputs=[input_video, video_task_dropdown, video_text_input], outputs=[output_video, output_video, frame_results_output])
 
 
 
 
 
 
 
 
 
305
 
306
  demo.launch()
 
21
  imports.remove("flash_attn")
22
  return imports
23
 
24
+ def load_model():
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ "microsoft/Florence-2-large-ft",
29
+ trust_remote_code=True
30
+ ).to(device).eval()
31
+ processor = AutoProcessor.from_pretrained(
32
+ "microsoft/Florence-2-large-ft",
33
+ trust_remote_code=True
34
+ )
35
+ return model, processor, device
36
+
37
+ model = None
38
+ processor = None
39
+ device = None
40
+
41
+ @spaces.GPU
42
+ def initialize_model():
43
+ global model, processor, device
44
+ model, processor, device = load_model()
45
 
46
  def run_example(task_prompt, image, text_input=None):
47
+ global model, processor, device
48
+ if model is None or processor is None:
49
+ initialize_model()
50
+
51
  prompt = task_prompt if text_input is None else task_prompt + text_input
52
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
53
  with torch.inference_mode():
 
55
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
56
  return processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.size[0], image.size[1]))
57
 
58
+ colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
59
+ 'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']
60
+
61
  def fig_to_pil(fig):
62
  buf = io.BytesIO()
63
  fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
 
105
  bboxes, labels = prediction['quad_boxes'], prediction['labels']
106
  for box, label in zip(bboxes, labels):
107
  color = random.choice(colormap)
108
+ box_array = np.array(box).reshape(-1, 2)
109
  polygon = patches.Polygon(box_array, edgecolor=color, fill=False, linewidth=2)
110
  ax.add_patch(polygon)
111
  plt.text(box_array[0, 0], box_array[0, 1], label, color='white', fontsize=10, bbox=dict(facecolor=color, alpha=0.8))
 
121
  draw.text((x1, y1), label, fill="white")
122
  return np.array(img_draw)
123
 
124
+ @spaces.GPU
125
  def process_video(input_video_path, task_prompt):
126
  cap = cv2.VideoCapture(input_video_path)
127
  if not cap.isOpened():
 
138
 
139
  processed_frames = 0
140
  frame_results = []
141
+ color_map = {}
142
 
143
  def get_color(label):
144
  if label not in color_map:
 
249
  return None, "Error: Video processing failed. Check logs above for info.", str(frame_results)
250
  return result, result, str(frame_results)
251
 
252
+ @spaces.GPU
253
+ def process_image_with_gpu(image, task, text):
254
+ return process_image(image, task, text)
255
+
256
  with gr.Blocks() as demo:
257
  gr.HTML("<h1><center>Microsoft Florence-2-large-ft</center></h1>")
258
 
 
324
 
325
  video_task_dropdown.change(fn=update_video_text_input, inputs=video_task_dropdown, outputs=video_text_input)
326
 
327
+ submit_btn.click(
328
+ fn=process_image_with_gpu,
329
+ inputs=[input_img, task_dropdown, text_input],
330
+ outputs=[output_text, output_image]
331
+ )
332
+
333
+ video_submit_btn.click(
334
+ fn=process_video_p,
335
+ inputs=[input_video, video_task_dropdown, video_text_input],
336
+ outputs=[output_video, output_video, frame_results_output]
337
+ )
338
 
339
  demo.launch()