ayush2607 committed
Commit 150fc7f · verified · Parent(s): be3cffa

Update app.py

Files changed (1): app.py (+42, -38)
app.py CHANGED
@@ -1,24 +1,21 @@
 import streamlit as st
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
 from qwen_vl_utils import process_vision_info
 import torch
 from PIL import Image
+from threading import Thread
 
 @st.cache_resource
 def load_model():
-    # Load model on CPU
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.float32, device_map=None
-    ).to("cpu")  # type:ignore # Ensure the model is on CPU
-
+    ).to("cpu")
     min_pixels = 256*28*28
     max_pixels = 1280*28*28
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
     return model, processor
-
 
-def process_file(img, model, processor):
-    # Prepare the image for the model
+def process_file_streaming(img, model, processor):
     messages = [
         {
             "role": "system",
@@ -29,7 +26,7 @@ def process_file(img, model, processor):
             "content": [
                 {
                     "type": "image",
-                    "image": img,  # Pass the image object directly
+                    "image": img,
                 },
                 {
                     "type": "text",
@@ -39,7 +36,6 @@ def process_file(img, model, processor):
         }
     ]
 
-    # Process the image for inference
    text = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
@@ -51,23 +47,22 @@ def process_file(img, model, processor):
         padding=True,
         return_tensors="pt",
     )
-    inputs = inputs.to("cpu")  # Send the inputs to CPU
+    inputs = inputs.to("cpu")
 
-    # Inference on CPU
-    generated_ids = model.generate(**inputs, max_new_tokens=200)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-
-    return output_text[0]
+    # Stream tokens
+    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200)
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
 
+    return streamer
 
-# Streamlit app
-st.title("OCR Application with Keyword Search")
+# Load model and processor once
+model, processor = load_model()
+
+# Streamlit app
+st.title("OCR Application with Real-Time Token Streaming")
 
 # Initialize session state variables
 if 'current_image' not in st.session_state:
@@ -75,9 +70,6 @@ if 'current_image' not in st.session_state:
 if 'extracted_text' not in st.session_state:
     st.session_state.extracted_text = None
 
-
-model, processor = load_model()
-
 # Upload image
 uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
 
@@ -85,19 +77,32 @@ if uploaded_file is not None:
     # Convert the uploaded file to an image
     img = Image.open(uploaded_file)
 
-    if st.session_state.current_image != uploaded_file:
-        st.session_state.current_image = uploaded_file
-        st.session_state.extracted_text = process_file(img, model, processor)
-
     # Display the uploaded image
     st.image(img, caption="Uploaded Image", use_column_width=True)
 
-    # if 'extracted_text' not in st.session_state:
-    #     st.session_state.extracted_text = process_file(img, model, processor)
-
-    # Display the extracted text
-    st.subheader("Extracted Text")
-    st.write(st.session_state.extracted_text)
+    # Check if the uploaded image is different from the current one
+    if st.session_state.current_image != uploaded_file:
+        st.session_state.current_image = uploaded_file
+
+        # Process the image with streaming
+        streamer = process_file_streaming(img, model, processor)
+
+        # Display streaming results
+        st.subheader("Extracted Text (Streaming)")
+        text_placeholder = st.empty()
+        collected_text = ""
+
+        for new_text in streamer:
+            collected_text += new_text
+            text_placeholder.markdown(collected_text)
+
+        # Store the final extracted text
+        st.session_state.extracted_text = collected_text
+
+    else:
+        # Display the previously extracted text
+        st.subheader("Extracted Text")
+        st.write(st.session_state.extracted_text)
 
 # Keyword Search
 keyword = st.text_input("Enter keyword to search in the extracted text")
@@ -105,9 +110,8 @@ if keyword and st.session_state.extracted_text:
     if keyword.lower() in st.session_state.extracted_text.lower():
         highlighted_text = st.session_state.extracted_text.replace(keyword, f"**{keyword}**")
         st.subheader("Keyword Found")
-        st.markdown(highlighted_text, unsafe_allow_html=True)
+        st.markdown(highlighted_text)
     else:
         st.write("Keyword not found in the extracted text.")
 elif keyword:
-    st.write("Please upload an image first to extract text.")
-
+    st.write("Please upload an image first before searching.")