Update app.py
app.py
CHANGED
@@ -3,11 +3,11 @@ from transformers.image_utils import load_image
 from threading import Thread
 import time
 import torch
-import spaces
 import cv2
 import numpy as np
 from PIL import Image
 import re
+import os
 from transformers import (
     Qwen2VLForConditionalGeneration,
     AutoProcessor,
@@ -105,28 +105,57 @@ def extract_medicine_names(text):
 
     return unique_medicines
 
-#
-
-
-qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
-qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
-    QV_MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
+# Check for CUDA availability
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
 
-#
-
-
-
-
-
-
-
+# Adjust model loading based on device
+dtype = torch.float16 if device == "cuda" else torch.float32
+bfdtype = torch.bfloat16 if device == "cuda" else torch.float32
+
+# Set lower precision for CPU if available
+if device == "cpu":
+    try:
+        # Check if Intel MKL is available for better CPU performance
+        import intel_extension_for_pytorch as ipex
+        dtype = torch.bfloat16
+        print("Using Intel optimizations for PyTorch")
+    except ImportError:
+        print("Intel optimizations not available, using standard CPU mode")
+
+# Model and Processor Setup with proper error handling
+try:
+    # Qwen2VL OCR (default branch)
+    QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"  # [or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
+    qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
+    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
+        QV_MODEL_ID,
+        trust_remote_code=True,
+        torch_dtype=dtype,
+        low_cpu_mem_usage=True,
+    ).to(device).eval()
+
+    # RolmOCR branch (@RolmOCR)
+    ROLMOCR_MODEL_ID = "reducto/RolmOCR"
+    rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
+    rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        ROLMOCR_MODEL_ID,
+        trust_remote_code=True,
+        torch_dtype=bfdtype,
+        low_cpu_mem_usage=True,
+    ).to(device).eval()
+
+    models_loaded = True
+except Exception as e:
+    print(f"Error loading models: {str(e)}")
+    models_loaded = False
 
 # Main Inference Function
-@spaces.GPU
 def model_inference(input_dict, history):
+    if not models_loaded:
+        yield "Error: Models could not be loaded. Please check system requirements."
+        return
+
     text = input_dict["text"].strip()
     files = input_dict.get("files", [])
 
@@ -154,7 +183,7 @@ def model_inference(input_dict, history):
             images=images,
             return_tensors="pt",
             padding=True,
-        ).to("cuda")
+        ).to(device)
 
         # First, get the complete OCR text
         streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
@@ -210,7 +239,7 @@ def model_inference(input_dict, history):
                 images=video_images,
                 return_tensors="pt",
                 padding=True,
-            ).to("cuda")
+            ).to(device)
         else:
             # Assume image(s) or text query.
             if len(files) > 1:
@@ -235,7 +264,7 @@ def model_inference(input_dict, history):
                 images=images if images else None,
                 return_tensors="pt",
                 padding=True,
-            ).to("cuda")
+            ).to(device)
             streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
             generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
             thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
@@ -279,7 +308,7 @@ def model_inference(input_dict, history):
             images=images if images else None,
             return_tensors="pt",
             padding=True,
-        ).to("cuda")
+        ).to(device)
         streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
         thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
@@ -296,7 +325,6 @@ def model_inference(input_dict, history):
 examples = [
     [{"text": "@Prescription Extract medicines from this prescription", "files": ["examples/prescription1.jpg"]}],
     [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
-    [{"text": "@RolmOCR Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
    [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
     [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
 ]
@@ -325,6 +353,10 @@ description = """
 Upload your medical prescription images and get the medicine names extracted automatically!
 """
 
+# Memory optimization for Hugging Face Spaces
+import gc
+max_memory = {i: f"{15}GiB" for i in range(torch.cuda.device_count())}
+
 demo = gr.ChatInterface(
     fn=model_inference,
     description=description,
@@ -341,4 +373,7 @@ demo = gr.ChatInterface(
     css=css
 )
 
-
+if __name__ == "__main__":
+    # Add queue to prevent timeouts
+    demo.queue(concurrency_count=1)
+    demo.launch(debug=True, share=False)
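Note on the CPU branch in the model-setup hunk: importing intel_extension_for_pytorch only confirms the package is installed; the speed-up usually comes from wrapping the loaded models with ipex.optimize(). The sketch below shows that typical wiring, reusing the commit's own names (device, qwen_model, rolmocr_model); it illustrates the library's usual usage and is not something this commit does.

# Illustrative only: applying intel_extension_for_pytorch after the models
# are loaded on CPU (the commit itself only imports the package).
import torch

if device == "cpu":
    try:
        import intel_extension_for_pytorch as ipex
        # ipex.optimize returns an optimized copy of the module; bfloat16
        # matches the dtype the commit selects for this case.
        qwen_model = ipex.optimize(qwen_model.eval(), dtype=torch.bfloat16)
        rolmocr_model = ipex.optimize(rolmocr_model.eval(), dtype=torch.bfloat16)
    except ImportError:
        pass  # fall back to plain PyTorch CPU inference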
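Note on the "Memory optimization for Hugging Face Spaces" hunk: max_memory is defined (and gc imported) but neither is referenced anywhere else in this diff, and on a CPU-only Space torch.cuda.device_count() is 0, so the dict is empty. If the intent is to cap GPU memory during loading, such a dict is normally passed to from_pretrained together with device_map="auto" (both are standard transformers/accelerate loading arguments). A minimal sketch, not what the commit does:

# Hypothetical use of the max_memory cap defined in the diff.
import torch
from transformers import Qwen2VLForConditionalGeneration

QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"

if torch.cuda.is_available():
    max_memory = {i: "15GiB" for i in range(torch.cuda.device_count())}
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        QV_MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",      # let accelerate place weights under the cap
        max_memory=max_memory,
    ).eval()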
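Note on the new __main__ block: demo.queue(concurrency_count=1) uses the Gradio 3.x signature. If the Space runs Gradio 4.x, queue() no longer accepts concurrency_count; the rough equivalent there is default_concurrency_limit. A version-dependent sketch, assuming Gradio 4.x rather than the version this Space pins:

# Gradio 4.x variant of the queue call added in this commit.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1)
    demo.launch(debug=True, share=False)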