File size: 5,109 Bytes
5dbe551
97b296f
 
 
 
5dbe551
97b296f
 
 
 
 
 
 
 
 
 
 
 
 
5dbe551
97b296f
 
 
 
 
37acc53
97b296f
 
 
 
 
 
 
4905934
97b296f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ade4954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97b296f
5dbe551
97b296f
4905934
97b296f
 
37acc53
 
 
97b296f
 
37acc53
 
97b296f
37acc53
4905934
ade4954
97b296f
 
37acc53
 
97b296f
 
 
 
5dbe551
37acc53
e1accc9
37acc53
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import re

# Hugging Face Hub ID of the OCR-tuned Qwen2-VL checkpoint used by this app.
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"


def load_model(model_id=MODEL_ID):
    """Load the Qwen2-VL OCR model and its processor on the CPU.

    Args:
        model_id: Hugging Face Hub ID (or local path) of the checkpoint.
            Defaults to ``MODEL_ID``. The model and the processor are always
            loaded from the same ID so the tokenizer/image preprocessing
            cannot drift from the weights.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_id,
        # float32 because the model is pinned to the CPU below.
        torch_dtype=torch.float32,
        device_map="cpu",
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return model, processor

# Function to extract medicine names
def extract_medicine_names(image):
    """Extract medicine names from a prescription image.

    Wrapper kept so callers of this name keep working. It delegates to
    ``extract_medicine_names_optimized``. That function reuses the shared
    model/processor from ``get_model_and_processor`` instead of reloading the
    full model from disk on every call. It also shares the exact same
    prompt and decoding pipeline, so the two entry points cannot diverge.

    Args:
        image: Prescription image handed to the vision-language model
            (e.g. a PIL image, as produced by ``gr.Image(type="pil")``).

    Returns:
        The decoded model answer as a string (the requested numbered list of
        medicine names), or ``"Please upload an image."`` when ``image`` is
        None.
    """
    return extract_medicine_names_optimized(image)

# Process-wide cache for the loaded model/processor. Both start empty and are
# filled on first use so repeated requests reuse the same objects.
model_instance = None
processor_instance = None


def get_model_and_processor():
    """Return the shared ``(model, processor)`` pair, loading it lazily.

    ``load_model`` is invoked only while either cached object is still unset.
    """
    global model_instance, processor_instance
    cache_ready = model_instance is not None and processor_instance is not None
    if not cache_ready:
        model_instance, processor_instance = load_model()
    return model_instance, processor_instance

# Instruction sent alongside the image: ask for medicine names only, as a
# numbered list.
MEDICINE_PROMPT = (
    "Extract and list ONLY the names of medicines/drugs from this prescription image. "
    "Output the medicine names as a numbered list without any additional information or descriptions."
)


def _build_messages(image, prompt=MEDICINE_PROMPT):
    """Build the single-turn chat payload (image followed by instruction)."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def _decode_new_tokens(processor, input_ids, generated_ids):
    """Decode only the tokens produced after the prompt and return clean text."""
    # The generated sequences include the prompt tokens; slice them off so
    # only the model's answer is decoded.
    trimmed = [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]
    text = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Defensive cleanup in case the chat end marker survives decoding.
    return text.replace("<|im_end|>", "").strip()


# Optimized extraction function that uses the singleton model
def extract_medicine_names_optimized(image):
    """Run the OCR model on a prescription image and return medicine names.

    Uses the cached model/processor from ``get_model_and_processor`` so the
    weights are loaded at most once across successful calls.

    Args:
        image: Prescription image (the Gradio UI passes a PIL image), or None
            when nothing was uploaded.

    Returns:
        The model's answer as a stripped string, or
        ``"Please upload an image."`` when ``image`` is None.
    """
    if image is None:
        return "Please upload an image."

    model, processor = get_model_and_processor()
    messages = _build_messages(image)

    # Render the chat template to text and collect the vision inputs
    # referenced by the messages.
    prompt_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[prompt_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    generated_ids = model.generate(**inputs, max_new_tokens=256)
    return _decode_new_tokens(processor, inputs.input_ids, generated_ids)

# Create Gradio interface.
# Component creation order inside the `with` blocks determines the on-screen
# layout: a row holding an input column (image upload + button) next to an
# output column (text box), followed by usage notes.
with gr.Blocks(title="Medicine Name Extractor") as app:
    gr.Markdown("# Medicine Name Extractor")
    gr.Markdown("Upload a medical prescription image to extract the names of medicines.")
    
    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand the handler a PIL image; presumably
            # None when nothing is uploaded (the handler guards for None).
            input_image = gr.Image(type="pil", label="Upload Prescription Image")
            extract_btn = gr.Button("Extract Medicine Names", variant="primary")
        
        with gr.Column():
            # Module-level component; unrelated to the local `output_text`
            # string inside the extraction functions.
            output_text = gr.Textbox(label="Extracted Medicine Names", lines=10)
    
    # Clicking the button runs the cached-model extractor on the uploaded
    # image and writes its returned string into the text box.
    extract_btn.click(
        fn=extract_medicine_names_optimized,
        inputs=input_image,
        outputs=output_text
    )
    
    gr.Markdown("### Notes")
    gr.Markdown("- This tool uses the Qwen2-VL-OCR model to extract text from prescription images")
    gr.Markdown("- For best results, ensure the prescription image is clear and readable")
    gr.Markdown("- Processing may take some time as the model runs on CPU")

# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    app.launch()