shukdevdatta123 committed on
Commit 37acc53 · verified · 1 Parent(s): e1accc9

Update app.py

Files changed (1)
  1. app.py +106 -361
app.py CHANGED
@@ -1,379 +1,124 @@
  import gradio as gr
- from transformers.image_utils import load_image
- from threading import Thread
- import time
- import torch
- import cv2
- import numpy as np
- from PIL import Image
  import re
- import os
- from transformers import (
-     Qwen2VLForConditionalGeneration,
-     AutoProcessor,
-     TextIteratorStreamer,
- )
- from transformers import Qwen2_5_VLForConditionalGeneration

- # ---------------------------
- # Helper Functions
- # ---------------------------
- def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
-     """
-     Returns an HTML snippet for a thin animated progress bar with a label.
-     Colors can be customized; default colors are used for Qwen2VL/Aya‑Vision.
-     """
-     return f'''
-     <div style="display: flex; align-items: center;">
-         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
-         </div>
-     </div>
-     <style>
-     @keyframes loading {{
-         0% {{ transform: translateX(-100%); }}
-         100% {{ transform: translateX(100%); }}
-     }}
-     </style>
-     '''
-
- def downsample_video(video_path):
-     """
-     Downsamples a video file by extracting 10 evenly spaced frames.
-     Returns a list of tuples (PIL.Image, timestamp).
-     """
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     if total_frames <= 0 or fps <= 0:
-         vidcap.release()
-         return frames
-     # Determine 10 evenly spaced frame indices.
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames
-
- def extract_medicine_names(text):
-     """
-     Extracts medicine names from OCR text output.
-     Uses a combination of pattern matching and formatting to identify medications.
-     Returns a formatted list of medicines found.
-     """
-     # Common medicine patterns (extended to catch more formats)
-     lines = text.split('\n')
-     medicines = []
-
-     # Look for patterns typical in prescriptions
-     for line in lines:
-         # Clean and standardize the line
-         clean_line = line.strip()
-
-         # Skip very short lines, headers, or non-relevant text
-         if len(clean_line) < 3 or re.search(r'(prescription|rx|patient|name|date|doctor|hospital|clinic|address)', clean_line.lower()):
-             continue
-
-         # Medicine names often appear at the beginning of lines, with dosage info following
-         # Look for tablet/capsule/mg indicators - strong indicators of medication
-         if re.search(r'(tab|tablet|cap|capsule|mg|ml|injection|syrup|solution|suspension|ointment|cream|gel|patch|suppository|inhaler|drops)', clean_line.lower()):
-             # Extract the likely medicine name - the part before the dosage/form or the entire line if it's short
-             medicine_match = re.split(r'(\d+\s*mg|\d+\s*ml|\d+\s*tab|\d+\s*cap)', clean_line, 1)[0].strip()
-             if medicine_match and len(medicine_match) > 2:
-                 medicines.append(medicine_match)
-
-         # Check for brand names or generic medication patterns
-         elif re.match(r'^[A-Z][a-z]+\s*[A-Z0-9]', clean_line) or re.match(r'^[A-Z][a-z]+', clean_line):
-             # Likely a medicine name starting with a capital letter
-             medicine_parts = re.split(r'(\d+|\s+\d+\s*times|\s+\d+\s*times\s+daily)', clean_line, 1)
-             if medicine_parts and len(medicine_parts[0]) > 2:
-                 medicines.append(medicine_parts[0].strip())
-
-     # Remove duplicates while preserving order
-     unique_medicines = []
-     for med in medicines:
-         if med not in unique_medicines:
-             unique_medicines.append(med)

-     return unique_medicines
-
- # Check for CUDA availability
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {device}")
-
- # Adjust model loading based on device
- dtype = torch.float16 if device == "cuda" else torch.float32
- bfdtype = torch.bfloat16 if device == "cuda" else torch.float32
-
- # Set lower precision for CPU if available
- if device == "cpu":
      try:
-         # Check if Intel MKL is available for better CPU performance
-         import intel_extension_for_pytorch as ipex
-         dtype = torch.bfloat16
-         print("Using Intel optimizations for PyTorch")
-     except ImportError:
-         print("Intel optimizations not available, using standard CPU mode")
-
- # Model and Processor Setup with proper error handling
- try:
-     # Qwen2VL OCR (default branch)
-     QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # [or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
-     qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
-     qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
-         QV_MODEL_ID,
-         trust_remote_code=True,
-         torch_dtype=dtype,
-         low_cpu_mem_usage=True,
-     ).to(device).eval()
-
-     # RolmOCR branch (@RolmOCR)
-     ROLMOCR_MODEL_ID = "reducto/RolmOCR"
-     rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
-     rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-         ROLMOCR_MODEL_ID,
-         trust_remote_code=True,
-         torch_dtype=bfdtype,
-         low_cpu_mem_usage=True,
-     ).to(device).eval()

-     models_loaded = True
- except Exception as e:
-     print(f"Error loading models: {str(e)}")
-     models_loaded = False
-
- # Main Inference Function
- def model_inference(input_dict, history):
-     if not models_loaded:
-         yield "Error: Models could not be loaded. Please check system requirements."
-         return
-
-     text = input_dict["text"].strip()
-     files = input_dict.get("files", [])

-     # Check for prescription-specific command
-     if text.lower().startswith("@prescription") or text.lower().startswith("@med"):
-         # Specific mode for medicine extraction
-         if not files:
-             yield "Error: Please upload a prescription image to extract medicine names."
-             return
-
-         # Use RolmOCR for better text extraction from prescriptions
-         images = [load_image(image) for image in files[:1]] # Taking just the first image for processing
-
-         messages = [{
-             "role": "user",
-             "content": [
-                 {"type": "image", "image": images[0]},
-                 {"type": "text", "text": "Extract all text from this medical prescription image, focus on medicine names, dosages, and instructions."},
              ],
-         }]

-         prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = rolmocr_processor(
-             text=[prompt_full],
-             images=images,
-             return_tensors="pt",
-             padding=True,
-         ).to(device)

-         # First, get the complete OCR text
-         streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-         thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
-         thread.start()

-         ocr_text = ""
-         yield progress_bar_html("Processing Prescription with Medicine Extractor")

-         for new_text in streamer:
-             ocr_text += new_text
-             ocr_text = ocr_text.replace("<|im_end|>", "")
-             time.sleep(0.01)

-         # After getting full OCR text, extract medicine names
-         medicines = extract_medicine_names(ocr_text)

-         # Format the results nicely
-         result = "## Extracted Medicine Names\n\n"
-         if medicines:
-             for i, med in enumerate(medicines, 1):
-                 result += f"{i}. {med}\n"
-         else:
-             result += "No medicine names detected in the prescription.\n\n"
-
-         result += "\n\n## Full OCR Text\n\n```\n" + ocr_text + "\n```"
-         yield result
-         return

-     # RolmOCR Inference (@RolmOCR)
-     if text.lower().startswith("@rolmocr"):
-         # Remove the tag from the query.
-         text_prompt = text[len("@rolmocr"):].strip()
-         # Check if a video is provided for inference.
-         if files and isinstance(files[0], str) and files[0].lower().endswith((".mp4", ".avi", ".mov")):
-             video_path = files[0]
-             frames = downsample_video(video_path)
-             if not frames:
-                 yield "Error: Could not extract frames from the video."
-                 return
-             # Build the message: prompt followed by each frame with its timestamp.
-             content_list = [{"type": "text", "text": text_prompt}]
-             for image, timestamp in frames:
-                 content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
-                 content_list.append({"type": "image", "image": image})
-             messages = [{"role": "user", "content": content_list}]
-             # For video, extract images only.
-             video_images = [image for image, _ in frames]
-             prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-             inputs = rolmocr_processor(
-                 text=[prompt_full],
-                 images=video_images,
-                 return_tensors="pt",
-                 padding=True,
-             ).to(device)
-         else:
-             # Assume image(s) or text query.
-             if len(files) > 1:
-                 images = [load_image(image) for image in files]
-             elif len(files) == 1:
-                 images = [load_image(files[0])]
-             else:
-                 images = []
-             if text_prompt == "" and not images:
-                 yield "Error: Please input a text query and/or provide an image for the @RolmOCR feature."
-                 return
-             messages = [{
-                 "role": "user",
-                 "content": [
-                     *[{"type": "image", "image": image} for image in images],
-                     {"type": "text", "text": text_prompt},
-                 ],
-             }]
-             prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-             inputs = rolmocr_processor(
-                 text=[prompt_full],
-                 images=images if images else None,
-                 return_tensors="pt",
-                 padding=True,
-             ).to(device)
-         streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-         thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         # Use a different color scheme for RolmOCR (purple-themed).
-         yield progress_bar_html("Processing with Qwen2.5VL (RolmOCR)")
-         for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
-         return
-
-     # Default Inference: Qwen2VL OCR
-     # Process files: support multiple images.
-     if len(files) > 1:
-         images = [load_image(image) for image in files]
-     elif len(files) == 1:
-         images = [load_image(files[0])]
-     else:
-         images = []

-     if text == "" and not images:
-         yield "Error: Please input a text query and optionally image(s)."
-         return
-     if text == "" and images:
-         yield "Error: Please input a text query along with the image(s)."
-         return
-
-     messages = [{
-         "role": "user",
-         "content": [
-             *[{"type": "image", "image": image} for image in images],
-             {"type": "text", "text": text},
-         ],
-     }]
-     prompt_full = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     inputs = qwen_processor(
-         text=[prompt_full],
-         images=images if images else None,
-         return_tensors="pt",
-         padding=True,
-     ).to(device)
-     streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-     thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     yield progress_bar_html("Processing with Qwen2VL OCR")
-     for new_text in streamer:
-         buffer += new_text
-         buffer = buffer.replace("<|im_end|>", "")
-         time.sleep(0.01)
-         yield buffer
-
- # Gradio Interface
- examples = [
-     [{"text": "@Prescription Extract medicines from this prescription", "files": ["examples/prescription1.jpg"]}],
-     [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
-     [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
-     [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
- ]
-
- css = """
- .gradio-container {
-     font-family: 'Roboto', sans-serif;
- }
- .prescription-header {
-     background-color: #4B0082;
-     color: white;
-     padding: 10px;
-     border-radius: 5px;
-     margin-bottom: 10px;
- }
- """
-
- description = """
- # **Multimodal OCR with Medicine Extraction**
-
- ## Modes:
- - **@Prescription** - Upload a prescription image to extract medicine names
- - **@RolmOCR** - Use RolmOCR for general text extraction
- - **Default** - Use Qwen2VL OCR for general purposes
-
- Upload your medical prescription images and get the medicine names extracted automatically!
- """
-
- # Memory optimization for Hugging Face Spaces
- import gc
- max_memory = {i: f"{15}GiB" for i in range(torch.cuda.device_count())}
-
- demo = gr.ChatInterface(
-     fn=model_inference,
-     description=description,
-     examples=examples,
-     textbox=gr.MultimodalTextbox(
-         label="Query Input",
-         file_types=["image", "video"],
-         file_count="multiple",
-         placeholder="Use @Prescription to extract medicines, @RolmOCR for RolmOCR, or leave blank for default Qwen2VL OCR"
-     ),
-     stop_btn="Stop Generation",
-     multimodal=True,
-     cache_examples=False,
-     css=css
- )

  if __name__ == "__main__":
-     # Add queue to prevent timeouts
-     demo.queue(concurrency_count=1)
-     demo.launch(debug=True, share=False)
  import gradio as gr
+ from openai import OpenAI
  import re

+ def get_openrouter_client(api_key):
+     """Initialize OpenRouter client with user-provided API key"""
+     if not api_key or api_key.strip() == "":
+         return None, "Please enter your OpenRouter API key"

      try:
+         client = OpenAI(
+             base_url="https://openrouter.ai/api/v1",
+             api_key=api_key
+         )
+         return client, None
+     except Exception as e:
+         return None, f"Error initializing client: {str(e)}"
+
+ def extract_medicine_names(image, api_key):
+     """Extract medicine names from a prescription image using Gemini via OpenRouter"""
+     if not image:
+         return "Please upload a prescription image."

+     # Get client with user-provided API key
+     client, error = get_openrouter_client(api_key)
+     if error:
+         return error

+     try:
+         response = client.chat.completions.create(
+             extra_headers={
+                 "HTTP-Referer": "https://medicine-extractor-app.com",
+                 "X-Title": "Medicine Name Extractor",
+             },
+             model="google/gemini-2.5-pro-exp-03-25:free",
+             messages=[
+                 {
+                     "role": "system",
+                     "content": "You are an AI specialized in extracting medication names from prescription images. Only list the medication names, nothing else."
+                 },
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "text",
+                             "text": "Extract ONLY the names of medications from this prescription image. Provide them as a numbered list. If this isn't a medical prescription, respond with 'No prescription detected'."
+                         },
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": image
+                             }
+                         }
+                     ]
+                 }
              ],
+             max_tokens=300
+         )

+         result = response.choices[0].message.content.strip()

+         # Check if no prescription was detected
+         if "No prescription detected" in result:
+             return "No prescription detected in the image."

+         # Clean up the response to just include the medication names
+         # Remove any explanatory text that might appear before or after the list
+         medicines = []
+         for line in result.split('\n'):
+             # Look for numbered lines or lines starting with medication names
+             if re.match(r'^\d+\.', line.strip()):
+                 # Extract text after the number and period
+                 med_name = re.sub(r'^\d+\.\s*', '', line.strip())
+                 medicines.append(med_name)

+         if not medicines:
+             # If numbered list processing didn't work, return the raw output
+             return result

+         return "\n".join([f"{i+1}. {med}" for i, med in enumerate(medicines)])

+     except Exception as e:
+         return f"Error: {str(e)}"

+ # Create the Gradio interface
+ with gr.Blocks(title="Prescription Medicine Extractor") as app:
+     gr.Markdown("# Prescription Medicine Name Extractor")
+     gr.Markdown("Upload a prescription image to extract medication names.")

+     api_key = gr.Textbox(
+         label="OpenRouter API Key",
+         placeholder="Enter your OpenRouter API key here",
+         type="password"
+     )
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="filepath", label="Upload Prescription Image")
+             submit_btn = gr.Button("Extract Medicine Names", variant="primary")
+
+         with gr.Column():
+             output = gr.Textbox(label="Extracted Medicine Names", lines=10)
+
+     submit_btn.click(
+         fn=extract_medicine_names,
+         inputs=[image_input, api_key],
+         outputs=[output]
+     )
+
+     gr.Markdown("""
+     ## Usage Instructions
+     1. Enter your OpenRouter API key (get one from https://openrouter.ai)
+     2. Upload a clear image of a medical prescription
+     3. Click the "Extract Medicine Names" button
+     4. The names of medications will be displayed in the output box
+
+     **Note:** For best results, ensure the image is clear and the text is readable.
+     **Privacy Notice:** Your API key and images are processed only during the active session and are not stored.
+     """)

+ # Launch the app
  if __name__ == "__main__":
+     print("Starting Prescription Medicine Name Extractor application...")
+     app.launch()
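The updated `extract_medicine_names` forwards the `gr.Image(type="filepath")` value straight into the `image_url` field of the chat-completions request. OpenAI-compatible endpoints such as OpenRouter generally expect an HTTP(S) URL or a base64 `data:` URL there rather than a local file path, so a minimal sketch of that conversion is shown below; `image_path_to_data_url` is an illustrative helper under that assumption, not part of this commit.

```python
import base64
import mimetypes

def image_path_to_data_url(image_path: str) -> str:
    """Read a local image file and return it as a base64-encoded data URL."""
    mime_type, _ = mimetypes.guess_type(image_path)
    mime_type = mime_type or "image/jpeg"  # fall back when the type cannot be guessed
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"

# Hypothetical use inside the request body shown in the diff:
#   {"type": "image_url", "image_url": {"url": image_path_to_data_url(image)}}
```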