import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import re

# Image-resolution bounds for the Qwen2-VL processor (in 28x28-pixel patches).
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28

def model_inference(images):
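    """Run Qwen2-VL-2B-Instruct over the gallery images and return the text.

    The checkpoint is reloaded on every call so no GPU memory is held between
    requests; hoist the load to module level if latency matters more.
    """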
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

    # Each gallery item arrives as a (file, caption) pair; open the file and
    # wrap it in the content format Qwen2-VL expects for images.
    images = [{"type": "image", "image": Image.open(image[0])} for image in images]
    
    # Put the images and the instruction in a single user turn, the chat
    # format Qwen2-VL expects, rather than splitting them across two turns.
    messages = [
        {
            "role": "user",
            "content": images + [{"type": "text", "text": "Please extract only the text from the following images."}],
        }
    ]
    
    # Render the prompt string from the chat template, then gather the vision
    # inputs (resized to the min/max pixel bounds) from the same messages.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Move the inputs and the model onto the same device before generating.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = inputs.to(device)
    model = model.to(device)

    generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Drop the prompt tokens so only the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    # Release the per-request model so memory is reclaimed between calls.
    del model
    del processor
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return output_text[0]

def search_and_highlight(text, keywords):
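    """Wrap every case-insensitive keyword match in Markdown bold."""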
    if not keywords:
        return text

    # Split on commas and drop empty entries so trailing commas are harmless.
    keywords = [kw.strip() for kw in keywords.split(',') if kw.strip()]
    highlighted_text = text

    for keyword in keywords:
        pattern = re.compile(re.escape(keyword), re.IGNORECASE)
        # Bold each match in place, preserving its original casing instead of
        # substituting the lowercased keyword back into the text.
        highlighted_text = pattern.sub(lambda m: f'**{m.group(0)}**', highlighted_text)

    return highlighted_text

def extract_and_search(images, keywords):
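    """Gradio handler: run OCR on the images, then highlight keyword hits."""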
    extracted_text = model_inference(images)
    highlighted_text = search_and_highlight(extracted_text, keywords)
    return extracted_text, highlighted_text

with gr.Blocks(theme=gr.themes.Soft()) as demo:
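    # Top row: the input image gallery and the comma-separated keyword box.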
    with gr.Row():
        input_gallery = gr.Gallery(label="Images", height=300, show_label=True)
        keywords = gr.Textbox(placeholder="Enter keywords to search (comma-separated)", label="Search Keywords")
    
    extract_button = gr.Button("Extract Text and Search", variant="primary")
    
    with gr.Row():
        raw_output = gr.Textbox(label="Extracted Text")
        highlighted_output = gr.Markdown(label="Highlighted Search Results")

    extract_button.click(extract_and_search, inputs=[input_gallery, keywords], outputs=[raw_output, highlighted_output])

if __name__ == "__main__":
    demo.queue(max_size=10).launch(share=True)