Saurabh Kumar committed
Commit 763589d · verified · 1 Parent(s): 34bbd4c

Create app.py

Files changed (1)
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import streamlit as st  # needed for the st.* UI calls below
from PIL import Image   # needed to open the uploaded file as an image
import re               # needed for case-insensitive keyword highlighting

# default: Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving,
# especially in multi-image and video scenarios (uncommenting also requires `import torch`).
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2-VL-7B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

# default processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# The default range for the number of visual tokens per image is 4-16384. You can set
# min_pixels and max_pixels according to your needs, such as a token count range of
# 256-1280, to balance speed and memory usage.
# min_pixels = 256 * 28 * 28
# max_pixels = 1280 * 28 * 28
# processor = AutoProcessor.from_pretrained(
#     "Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
# )

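# (Background on the arithmetic above: each visual token in Qwen2-VL covers a 28x28-pixel
#  patch, so 256 * 28 * 28 pixels is roughly a 256-token budget and 1280 * 28 * 28 is
#  roughly a 1280-token budget, which is where the 256-1280 range comes from.)
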
# Streamlit app title
st.title("OCR Image Text Extraction")

# File uploader for images
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Open the uploaded image file
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Build the chat-style prompt: one user turn containing the image and an OCR instruction
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,  # pass the uploaded PIL image, not the literal string "image"
                },
                {"type": "text", "text": "Run Optical Character Recognition on the image."},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
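    # Note: .to("cuda") assumes a GPU is present; with device_map="auto" a more portable
    # choice would be inputs.to(model.device), so the tensors follow wherever the model landed.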

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    extracted_text = output_text[0]  # batch_decode returns a list; take the single decoded string

    st.subheader("Extracted Text:")
    st.write(extracted_text)

    # Keyword search functionality
    st.subheader("Keyword Search")
    search_query = st.text_input("Enter keywords to search within the extracted text")

    if search_query:
        # Check if the search query is in the extracted text (case-insensitive)
        if search_query.lower() in extracted_text.lower():
            # Bold every case-insensitive match so it stands out in the Markdown output
            highlighted_text = re.sub(
                f"({re.escape(search_query)})", r"**\1**", extracted_text, flags=re.IGNORECASE
            )
            st.write(f"Matching Text: {highlighted_text}")
        else:
            st.write("No matching text found.")