Spaces:
Running
Running
File size: 5,109 Bytes
5dbe551 97b296f 5dbe551 97b296f 5dbe551 97b296f 37acc53 97b296f 4905934 97b296f ade4954 97b296f 5dbe551 97b296f 4905934 97b296f 37acc53 97b296f 37acc53 97b296f 37acc53 4905934 ade4954 97b296f 37acc53 97b296f 5dbe551 37acc53 e1accc9 37acc53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import re
# Load the OCR model and its processor on CPU
def load_model():
    """Instantiate the Qwen2-VL OCR model and its matching processor.

    The model is created with ``float32`` weights and ``device_map="cpu"``.
    Both objects are built from the same Hugging Face checkpoint id.

    Returns:
        tuple: ``(model, processor)``.
    """
    checkpoint = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float32,
    )
    processor = AutoProcessor.from_pretrained(checkpoint)
    return model, processor
# Function to extract medicine names
def extract_medicine_names(image):
    """Extract medicine names from a prescription image.

    Previously this function called ``load_model()`` on every invocation and
    carried its own copy of the inference pipeline. It now delegates to
    ``extract_medicine_names_optimized``, which obtains the model through
    ``get_model_and_processor()`` so the checkpoint is loaded once and reused.

    Args:
        image: The prescription image to read, or ``None`` when nothing was
            uploaded.

    Returns:
        str: The text generated by the model for the extraction prompt, or
        ``"Please upload an image."`` when ``image`` is ``None``.
    """
    # Answer the empty-upload case up front so no model work is triggered.
    if image is None:
        return "Please upload an image."
    return extract_medicine_names_optimized(image)
# Module-level cache so the model and processor are not reloaded per request
model_instance = None
processor_instance = None


def get_model_and_processor():
    """Return the cached ``(model, processor)`` pair, loading it if needed.

    ``load_model()`` is called whenever either cached object is still
    ``None``; its result is stored in the module-level ``model_instance`` and
    ``processor_instance`` and handed back on every later call.

    Returns:
        tuple: ``(model_instance, processor_instance)``.
    """
    global model_instance, processor_instance
    already_loaded = model_instance is not None and processor_instance is not None
    if not already_loaded:
        model_instance, processor_instance = load_model()
    return model_instance, processor_instance
# Extraction entry point that reuses the cached model and processor
def extract_medicine_names_optimized(image):
    """Ask the OCR model for the medicine names in a prescription image.

    Args:
        image: The uploaded prescription image, or ``None`` when nothing was
            uploaded.

    Returns:
        str: ``"Please upload an image."`` when ``image`` is ``None``.
        Otherwise the model's decoded reply (up to 256 new tokens) with any
        literal ``<|im_end|>`` removed and surrounding whitespace stripped.
    """
    if image is None:
        return "Please upload an image."

    model, processor = get_model_and_processor()

    # Single-turn chat: the image followed by the extraction instruction.
    instruction = "Extract and list ONLY the names of medicines/drugs from this prescription image. Output the medicine names as a numbered list without any additional information or descriptions."
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": instruction},
            ],
        }
    ]

    # Render the chat prompt and collect the vision inputs for the processor.
    chat_prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(conversation)
    batch = processor(
        text=[chat_prompt],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )

    sequences = model.generate(**batch, max_new_tokens=256)

    # Drop the leading prompt-length slice so only newly generated ids remain.
    completions = []
    for prompt_ids, full_ids in zip(batch.input_ids, sequences):
        completions.append(full_ids[len(prompt_ids):])

    decoded = processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    reply = decoded[0]

    # Strip an end-of-turn marker in case it survives decoding as plain text.
    return reply.replace("<|im_end|>", "").strip()
# Gradio interface: an upload column and a result column in one row,
# followed by usage notes.
with gr.Blocks(title="Medicine Name Extractor") as app:
    gr.Markdown("# Medicine Name Extractor")
    gr.Markdown("Upload a medical prescription image to extract the names of medicines.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Prescription Image", type="pil")
            extract_btn = gr.Button("Extract Medicine Names", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(lines=10, label="Extracted Medicine Names")

    # Wire the button to the extraction function that reuses the cached model.
    extract_btn.click(
        extract_medicine_names_optimized,
        inputs=input_image,
        outputs=output_text,
    )

    gr.Markdown("### Notes")
    gr.Markdown("- This tool uses the Qwen2-VL-OCR model to extract text from prescription images")
    gr.Markdown("- For best results, ensure the prescription image is clear and readable")
    gr.Markdown("- Processing may take some time as the model runs on CPU")

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()