Saurabh Kumar committed
Commit f90e854 · verified · 1 Parent(s): 7ca46d0

Update app.py

Files changed (1):
  1. app.py +61 -62
app.py CHANGED
@@ -4,12 +4,12 @@ import streamlit as st
import torch
from PIL import Image

- @st.cache_resource
# default: Load the model on the available device(s)
- model = Qwen2VLForConditionalGeneration.from_pretrained(
-     "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
- )
-
+ @st.cache_resource
+ def init_qwen_model():
+     model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
+     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+     return model, processor
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2-VL-7B-Instruct",
@@ -17,74 +17,73 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )
-
- # default processer
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
-
# The default range for the number of visual tokens per image in the model is 4-16384. You can set min_pixels and max_pixels according to your needs, such as a token count range of 256-1280, to balance speed and memory usage.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

@st.cache_data
+ def get_qwen_text(uploaded_file):
+     if uploaded_file is not None:
+         # Open the uploaded image file
+         image = Image.open(uploaded_file)
+         st.image(image, caption="Uploaded Image", use_column_width=True)
+
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "image",
+                         "image": image,
+                     },
+                     {"type": "text", "text": "Run Optical Character recognition on the image."},
+                 ],
+             }
+         ]
+
+         # Preparation for inference
+         text = processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+         inputs = inputs.to("cpu")
+
+         # Inference: Generation of the output
+         generated_ids = model.generate(**inputs, max_new_tokens=128)
+         generated_ids_trimmed = [
+             out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         ]
+         output_text = processor.batch_decode(
+             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+         )
+         return output_text
+
# Streamlit app title
st.title("OCR Image Text Extraction")

# File uploader for images
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

- if uploaded_file is not None:
-     # Open the uploaded image file
-     image = Image.open(uploaded_file)
-     st.image(image, caption="Uploaded Image", use_column_width=True)
-
-     messages = [
-         {
-             "role": "user",
-             "content": [
-                 {
-                     "type": "image",
-                     "image": image,
-                 },
-                 {"type": "text", "text": "Run Optical Character recognition on the image."},
-             ],
-         }
-     ]
-
-     # Preparation for inference
-     text = processor.apply_chat_template(
-         messages, tokenize=False, add_generation_prompt=True
-     )
-     image_inputs, video_inputs = process_vision_info(messages)
-     inputs = processor(
-         text=[text],
-         images=image_inputs,
-         videos=video_inputs,
-         padding=True,
-         return_tensors="pt",
-     )
-     inputs = inputs.to("cpu")
-
-     # Inference: Generation of the output
-     generated_ids = model.generate(**inputs, max_new_tokens=128)
-     generated_ids_trimmed = [
-         out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-     ]
-     output_text = processor.batch_decode(
-         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-     )
-
-     st.subheader("Extracted Text:")
-     st.write(output_text)
+ st.subheader("Extracted Text:")
+ output = get_qwen_text(uploaded_file)
+ st.write(output)

-     # Keyword search functionality
-     st.subheader("Keyword Search")
-     search_query = st.text_input("Enter keywords to search within the extracted text")
+ # Keyword search functionality
+ st.subheader("Keyword Search")
+ search_query = st.text_input("Enter keywords to search within the extracted text")

-     if search_query:
-         # Check if the search query is in the extracted text
-         if search_query.lower() in extracted_text.lower():
-             highlighted_text = extracted_text.replace(search_query, f"**{search_query}**")
-             st.write(f"Matching Text: {highlighted_text}")
-         else:
-             st.write("No matching text found.")
+ if search_query:
+     # Check if the search query is in the extracted text
+     if search_query.lower() in output.lower():
+         highlighted_text = output.replace(search_query, f"**{search_query}**")
+         st.write(f"Matching Text: {highlighted_text}")
+     else:
+         st.write("No matching text found.")