shukdevdatta123 committed on
Commit e1accc9 · verified · 1 Parent(s): 79a9268

Update app.py

Files changed (1)
  1. app.py +60 -25
app.py CHANGED
@@ -3,11 +3,11 @@ from transformers.image_utils import load_image
  from threading import Thread
  import time
  import torch
- import spaces
  import cv2
  import numpy as np
  from PIL import Image
  import re
+ import os
  from transformers import (
      Qwen2VLForConditionalGeneration,
      AutoProcessor,
@@ -105,28 +105,57 @@ def extract_medicine_names(text):

      return unique_medicines

- # Model and Processor Setup
- # Qwen2VL OCR (default branch)
- QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # [or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
- qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
- qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
-     QV_MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()
+ # Check for CUDA availability
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")

- # RolmOCR branch (@RolmOCR)
- ROLMOCR_MODEL_ID = "reducto/RolmOCR"
- rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
- rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     ROLMOCR_MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16
- ).to("cuda").eval()
+ # Adjust model loading based on device
+ dtype = torch.float16 if device == "cuda" else torch.float32
+ bfdtype = torch.bfloat16 if device == "cuda" else torch.float32
+
+ # Set lower precision for CPU if available
+ if device == "cpu":
+     try:
+         # Check if Intel MKL is available for better CPU performance
+         import intel_extension_for_pytorch as ipex
+         dtype = torch.bfloat16
+         print("Using Intel optimizations for PyTorch")
+     except ImportError:
+         print("Intel optimizations not available, using standard CPU mode")
+
+ # Model and Processor Setup with proper error handling
+ try:
+     # Qwen2VL OCR (default branch)
+     QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # [or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
+     qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
+     qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
+         QV_MODEL_ID,
+         trust_remote_code=True,
+         torch_dtype=dtype,
+         low_cpu_mem_usage=True,
+     ).to(device).eval()
+
+     # RolmOCR branch (@RolmOCR)
+     ROLMOCR_MODEL_ID = "reducto/RolmOCR"
+     rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
+     rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         ROLMOCR_MODEL_ID,
+         trust_remote_code=True,
+         torch_dtype=bfdtype,
+         low_cpu_mem_usage=True,
+     ).to(device).eval()
+
+     models_loaded = True
+ except Exception as e:
+     print(f"Error loading models: {str(e)}")
+     models_loaded = False

  # Main Inference Function
- @spaces.GPU
  def model_inference(input_dict, history):
+     if not models_loaded:
+         yield "Error: Models could not be loaded. Please check system requirements."
+         return
+
      text = input_dict["text"].strip()
      files = input_dict.get("files", [])

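Note: the device/dtype fallback added above can be sanity-checked in isolation, without downloading either model. The snippet below is an illustrative sketch, not part of this commit:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Allocate a small tensor with the selected settings to confirm they work.
    x = torch.randn(2, 3, device=device, dtype=dtype)
    print(f"device={device}, dtype={x.dtype}")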
 
@@ -154,7 +183,7 @@ def model_inference(input_dict, history):
      images=images,
      return_tensors="pt",
      padding=True,
- ).to("cuda")
+ ).to(device)

      # First, get the complete OCR text
      streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
@@ -210,7 +239,7 @@ def model_inference(input_dict, history):
      images=video_images,
      return_tensors="pt",
      padding=True,
- ).to("cuda")
+ ).to(device)
  else:
      # Assume image(s) or text query.
      if len(files) > 1:
@@ -235,7 +264,7 @@ def model_inference(input_dict, history):
      images=images if images else None,
      return_tensors="pt",
      padding=True,
- ).to("cuda")
+ ).to(device)
  streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
  thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
@@ -279,7 +308,7 @@ def model_inference(input_dict, history):
      images=images if images else None,
      return_tensors="pt",
      padding=True,
- ).to("cuda")
+ ).to(device)
  streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
  thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
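Note: both branches reuse the same streaming pattern shown in the context above: generate() runs on a background thread while the main thread drains a TextIteratorStreamer. A minimal, self-contained sketch of that pattern (the function name stream_generation is a placeholder, not code from this commit):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generation(model, processor, inputs, max_new_tokens=1024):
        # The streamer decodes new tokens as generate() produces them on the worker thread.
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        thread = Thread(target=model.generate,
                        kwargs=dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens))
        thread.start()
        buffer = ""
        for new_text in streamer:  # yields decoded text chunks incrementally
            buffer += new_text
            yield buffer
        thread.join()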
@@ -296,7 +325,6 @@ def model_inference(input_dict, history):
  examples = [
      [{"text": "@Prescription Extract medicines from this prescription", "files": ["examples/prescription1.jpg"]}],
      [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
-     [{"text": "@RolmOCR Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
      [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
      [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
  ]
@@ -325,6 +353,10 @@ description = """
  Upload your medical prescription images and get the medicine names extracted automatically!
  """

+ # Memory optimization for Hugging Face Spaces
+ import gc
+ max_memory = {i: f"{15}GiB" for i in range(torch.cuda.device_count())}
+
  demo = gr.ChatInterface(
      fn=model_inference,
      description=description,
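Note: as committed, import gc is never used and max_memory is never passed to either from_pretrained call, so this block only defines a dict (which is empty on a CPU-only Space, where torch.cuda.device_count() is 0). If per-GPU memory caps are actually wanted, they are typically wired in through accelerate's device_map; a hedged sketch, not part of this commit:

    # Assumes the accelerate package is installed; with device_map the explicit
    # .to(device) call afterwards should be dropped.
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        QV_MODEL_ID,
        trust_remote_code=True,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        max_memory=max_memory,
    ).eval()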
@@ -341,4 +373,7 @@ demo = gr.ChatInterface(
      css=css
  )

- demo.launch(debug=True)
+ if __name__ == "__main__":
+     # Add queue to prevent timeouts
+     demo.queue(concurrency_count=1)
+     demo.launch(debug=True, share=False)
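Note: queue(concurrency_count=1) matches the Gradio 3.x API; Gradio 4.x removed that parameter. If the Space pins Gradio 4.x, the rough equivalent (an assumption worth checking against the pinned version) would be:

    if __name__ == "__main__":
        # Gradio 4.x: limit concurrency via the queue's default instead.
        demo.queue(default_concurrency_limit=1)
        demo.launch(debug=True, share=False)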
 
 
 
 