awacke1 committed on
Commit 090a3c3 · verified · 1 Parent(s): 2578d93

Update app.py

Files changed (1)
  1. app.py +129 -81
app.py CHANGED
@@ -3,9 +3,10 @@ import os
  import glob
  import time
  import streamlit as st
+ import fitz  # PyMuPDF
+ import requests
  from PIL import Image
- import torch
- from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, AutoTokenizer, AutoModel, TrOCRProcessor, VisionEncoderDecoderModel
+ from transformers import AutoTokenizer, AutoModel
  from diffusers import StableDiffusionPipeline
  import cv2
  import numpy as np
@@ -31,21 +32,18 @@ st.set_page_config(
      page_icon="🤖",
      layout="wide",
      initial_sidebar_state="expanded",
-     menu_items={'About': "AI Vision Titans: OCR, Image Gen, Line Drawings on CPU! 🌌"}
+     menu_items={'About': "AI Vision Titans: PDF Snapshots, OCR, Image Gen, Line Drawings on CPU! 🌌"}
  )

  # Initialize st.session_state
- if 'captured_images' not in st.session_state:
-     st.session_state['captured_images'] = []
+ if 'captured_files' not in st.session_state:
+     st.session_state['captured_files'] = []
  if 'processing' not in st.session_state:
      st.session_state['processing'] = {}

  # Utility Functions
  def generate_filename(sequence, ext="png"):
-     from datetime import datetime
-     import pytz
-     central = pytz.timezone('US/Central')
-     timestamp = datetime.now(central).strftime("%d%m%Y%H%M%S%p")
+     timestamp = time.strftime("%d%m%Y%H%M%S")
      return f"{sequence}{timestamp}.{ext}"

  def get_gallery_files(file_types):
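
A quick illustration of what the simplified generate_filename() now produces. Unlike the removed pytz version, it uses the server's local time and drops the AM/PM suffix; the sample values below are hypothetical.

import time

def generate_filename(sequence, ext="png"):
    timestamp = time.strftime("%d%m%Y%H%M%S")  # day-month-year-hour-minute-second, local time
    return f"{sequence}{timestamp}.{ext}"

print(generate_filename(0))            # e.g. "025032025143015.png"
print(generate_filename("thumbnail"))  # e.g. "thumbnail25032025143015.png"
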
@@ -61,20 +59,27 @@ def update_gallery():
              st.image(Image.open(file), caption=file, use_container_width=True)
          elif file.endswith(".txt"):
              with open(file, "r") as f:
-                 st.text(f.read()[:50] + "..." if len(f.read()) > 50 else f.read(), help=file)
+                 content = f.read()
+                 st.text(content[:50] + "..." if len(content) > 50 else content, help=file)

- # Model Loaders (Smaller, CPU-focused)
- def load_ocr_qwen2vl():
-     model_id = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
-     processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-     model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
-     return processor, model
+ def download_pdf(url, output_path):
+     try:
+         response = requests.get(url, stream=True, timeout=10)
+         if response.status_code == 200:
+             with open(output_path, "wb") as f:
+                 for chunk in response.iter_content(chunk_size=8192):
+                     f.write(chunk)
+             return True
+     except requests.RequestException as e:
+         logger.error(f"Failed to download {url}: {e}")
+     return False

- def load_ocr_trocr():
-     model_id = "microsoft/trocr-small-handwritten"  # ~250 MB
-     processor = TrOCRProcessor.from_pretrained(model_id)
-     model = VisionEncoderDecoderModel.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu").eval()
-     return processor, model
+ # Model Loaders (CPU-focused)
+ def load_ocr_got():
+     model_id = "ucaslcl/GOT-OCR2_0"
+     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+     model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
+     return tokenizer, model

  def load_image_gen():
      model_id = "OFA-Sys/small-stable-diffusion-v0"  # ~300 MB
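
Note that load_ocr_got() above still passes torch_dtype=torch.float32 even though the import torch line is dropped from the header in this commit, and the model.chat(tokenizer, image, ocr_type='ocr') call used later is provided by the model's custom code on the Hub, which is why trust_remote_code=True is set. A standalone sketch of the loader with the import restored:

import torch
from transformers import AutoTokenizer, AutoModel

def load_ocr_got():
    model_id = "ucaslcl/GOT-OCR2_0"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.float32
    ).to("cpu").eval()
    return tokenizer, model
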
@@ -82,37 +87,60 @@ def load_image_gen():
      return pipeline

  def load_line_drawer():
-     # Simplified OpenCV-based edge detection (CPU-friendly substitute for Torch Space UNet)
-     def edge_detection(image):
+     def edge_detection(image, style="fine"):
          img_np = np.array(image.convert("RGB"))
          gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
-         edges = cv2.Canny(gray, 100, 200)
+         if style == "fine":
+             edges = cv2.Canny(gray, 50, 150)  # Finer lines
+         else:  # "bold"
+             edges = cv2.Canny(gray, 100, 200)  # Bolder lines
          return Image.fromarray(edges)
      return edge_detection

  # Async Processing Functions
- async def process_ocr(image, prompt, model_name, output_file):
+ async def process_pdf_snapshot(pdf_path, mode="thumbnail"):
      start_time = time.time()
      status = st.empty()
-     status.text(f"Processing {model_name} OCR... (0s)")
-     if model_name == "Qwen2-VL-OCR-2B":
-         processor, model = load_ocr_qwen2vl()
-         # Corrected input format: apply chat template
-         messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
-         text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True).to("cpu")
-         outputs = model.generate(**inputs, max_new_tokens=1024)
-         result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
-     else:  # TrOCR
-         processor, model = load_ocr_trocr()
-         pixel_values = processor(images=image, return_tensors="pt").pixel_values.to("cpu")
-         outputs = model.generate(pixel_values)
-         result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+     status.text(f"Processing PDF Snapshot ({mode})... (0s)")
+     doc = fitz.open(pdf_path)
+     output_files = []
+
+     if mode == "thumbnail":
+         page = doc[0]
+         pix = page.get_pixmap(matrix=fitz.Matrix(0.5, 0.5))  # 50% scale
+         output_file = generate_filename("thumbnail", "png")
+         pix.save(output_file)
+         output_files.append(output_file)
+     elif mode == "twopage":
+         for i in range(min(2, len(doc))):
+             page = doc[i]
+             pix = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0))  # Full scale
+             output_file = generate_filename(f"twopage_{i}", "png")
+             pix.save(output_file)
+             output_files.append(output_file)
+
+     doc.close()
      elapsed = int(time.time() - start_time)
-     status.text(f"{model_name} OCR completed in {elapsed}s!")
+     status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
+     for file in output_files:
+         if file not in st.session_state['captured_files']:
+             st.session_state['captured_files'].append(file)
+     update_gallery()
+     return output_files
+
+ async def process_ocr(image, output_file):
+     start_time = time.time()
+     status = st.empty()
+     status.text("Processing GOT-OCR2_0... (0s)")
+     tokenizer, model = load_ocr_got()
+     result = model.chat(tokenizer, image, ocr_type='ocr')
+     elapsed = int(time.time() - start_time)
+     status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
      async with aiofiles.open(output_file, "w") as f:
          await f.write(result)
-     st.session_state['captured_images'].append(output_file)
+     if output_file not in st.session_state['captured_files']:
+         st.session_state['captured_files'].append(output_file)
+     update_gallery()
      return result

  async def process_image_gen(prompt, output_file):
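
For reference, the core of the new PDF snapshot path is PyMuPDF's pixmap rendering; a minimal standalone sketch, with "sample.pdf" and the output name as hypothetical stand-ins:

import fitz  # PyMuPDF

doc = fitz.open("sample.pdf")                           # hypothetical input path
pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5))   # first page rendered at 50% scale
pix.save("sample_thumbnail.png")                        # hypothetical output path
doc.close()
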
@@ -120,30 +148,34 @@ async def process_image_gen(prompt, output_file):
      status = st.empty()
      status.text("Processing Image Gen... (0s)")
      pipeline = load_image_gen()
-     gen_image = pipeline(prompt, num_inference_steps=20).images[0]  # Reduced steps for speed
+     gen_image = pipeline(prompt, num_inference_steps=20).images[0]
      elapsed = int(time.time() - start_time)
      status.text(f"Image Gen completed in {elapsed}s!")
      gen_image.save(output_file)
-     st.session_state['captured_images'].append(output_file)
+     if output_file not in st.session_state['captured_files']:
+         st.session_state['captured_files'].append(output_file)
+     update_gallery()
      return gen_image

- async def process_line_drawing(image, output_file):
+ async def process_line_drawing(image, style, output_file):
      start_time = time.time()
      status = st.empty()
-     status.text("Processing Line Drawing... (0s)")
+     status.text(f"Processing Line Drawing ({style})... (0s)")
      edge_fn = load_line_drawer()
-     line_drawing = edge_fn(image)
+     line_drawing = edge_fn(image, style=style)
      elapsed = int(time.time() - start_time)
-     status.text(f"Line Drawing completed in {elapsed}s!")
+     status.text(f"Line Drawing ({style}) completed in {elapsed}s!")
      line_drawing.save(output_file)
-     st.session_state['captured_images'].append(output_file)
+     if output_file not in st.session_state['captured_files']:
+         st.session_state['captured_files'].append(output_file)
+     update_gallery()
      return line_drawing

  # Main App
- st.title("AI Vision Titans 🚀 (OCR, Gen, Drawings!)")
+ st.title("AI Vision Titans 🚀")

  # Sidebar Gallery
- st.sidebar.header("Captured Images 🎨")
+ st.sidebar.header("Captured Files 📜")
  gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 4)
  update_gallery()
 
@@ -154,7 +186,7 @@ with log_container:
          st.write(f"{record.asctime} - {record.levelname} - {record.message}")

  # Tabs
- tab1, tab2, tab3, tab4 = st.tabs(["Camera Snap 📷", "Test OCR 🔍", "Test Image Gen 🎨", "Test Line Drawings ✏️"])
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(["Camera Snap 📷", "Download PDFs 📥", "Test OCR 🔍", "Test Image Gen 🎨", "Test Line Drawings ✏️"])

  with tab1:
      st.header("Camera Snap 📷")
@@ -164,23 +196,23 @@ with tab1:
          cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
          if cam0_img:
              filename = generate_filename(0)
-             if filename not in st.session_state['captured_images']:
+             if filename not in st.session_state['captured_files']:
                  with open(filename, "wb") as f:
                      f.write(cam0_img.getvalue())
                  st.image(Image.open(filename), caption=filename, use_container_width=True)
                  logger.info(f"Saved snapshot from Camera 0: {filename}")
-                 st.session_state['captured_images'].append(filename)
+                 st.session_state['captured_files'].append(filename)
                  update_gallery()
      with cols[1]:
          cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
          if cam1_img:
              filename = generate_filename(1)
-             if filename not in st.session_state['captured_images']:
+             if filename not in st.session_state['captured_files']:
                  with open(filename, "wb") as f:
                      f.write(cam1_img.getvalue())
                  st.image(Image.open(filename), caption=filename, use_container_width=True)
                  logger.info(f"Saved snapshot from Camera 1: {filename}")
-                 st.session_state['captured_images'].append(filename)
+                 st.session_state['captured_files'].append(filename)
                  update_gallery()

      st.subheader("Burst Capture")
@@ -194,42 +226,57 @@ with tab1:
              img = st.camera_input(f"Frame {i}", key=f"burst_{i}_{time.time()}")
              if img:
                  filename = generate_filename(f"burst_{i}")
-                 if filename not in st.session_state['captured_images']:
+                 if filename not in st.session_state['captured_files']:
                      with open(filename, "wb") as f:
                          f.write(img.getvalue())
                      st.session_state['burst_frames'].append(filename)
                      logger.info(f"Saved burst frame {i}: {filename}")
                      st.image(Image.open(filename), caption=filename, use_container_width=True)
-                     time.sleep(0.5)  # Small delay for visibility
-         st.session_state['captured_images'].extend([f for f in st.session_state['burst_frames'] if f not in st.session_state['captured_images']])
+                     time.sleep(0.5)
+         st.session_state['captured_files'].extend([f for f in st.session_state['burst_frames'] if f not in st.session_state['captured_files']])
          update_gallery()
          placeholder.success(f"Captured {len(st.session_state['burst_frames'])} frames!")

  with tab2:
+     st.header("Download PDFs 📥")
+     url_input = st.text_area("Enter PDF URLs (one per line)", height=100)
+     mode = st.selectbox("Snapshot Mode", ["Thumbnail", "Two-Page View"], key="download_mode")
+     if st.button("Download & Snapshot 📸"):
+         urls = url_input.strip().split("\n")
+         for url in urls:
+             if url:
+                 pdf_path = generate_filename("downloaded", "pdf")
+                 if download_pdf(url, pdf_path):
+                     logger.info(f"Downloaded PDF from {url} to {pdf_path}")
+                     snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode.lower().replace(" ", "")))
+                     for snapshot in snapshots:
+                         st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
+                 else:
+                     st.error(f"Failed to download {url}")
+
+ with tab3:
      st.header("Test OCR 🔍")
-     captured_images = get_gallery_files(["png"])
-     if captured_images:
-         selected_image = st.selectbox("Select Image", captured_images, key="ocr_select")
-         image = Image.open(selected_image)
+     captured_files = get_gallery_files(["png"])
+     if captured_files:
+         selected_file = st.selectbox("Select Image", captured_files, key="ocr_select")
+         image = Image.open(selected_file)
          st.image(image, caption="Input Image", use_container_width=True)
-         ocr_model = st.selectbox("Select OCR Model", ["Qwen2-VL-OCR-2B", "TrOCR-Small"], key="ocr_model_select")
-         prompt = st.text_area("Prompt", "Extract text from the image", key="ocr_prompt")
          if st.button("Run OCR 🚀", key="ocr_run"):
              output_file = generate_filename("ocr_output", "txt")
              st.session_state['processing']['ocr'] = True
-             result = asyncio.run(process_ocr(image, prompt, ocr_model, output_file))
+             result = asyncio.run(process_ocr(image, output_file))
              st.text_area("OCR Result", result, height=200, key="ocr_result")
              st.success(f"OCR output saved to {output_file}")
              st.session_state['processing']['ocr'] = False
      else:
-         st.warning("No images captured yet. Use Camera Snap first!")
+         st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")

- with tab3:
+ with tab4:
      st.header("Test Image Gen 🎨")
-     captured_images = get_gallery_files(["png"])
-     if captured_images:
-         selected_image = st.selectbox("Select Image", captured_images, key="gen_select")
-         image = Image.open(selected_image)
+     captured_files = get_gallery_files(["png"])
+     if captured_files:
+         selected_file = st.selectbox("Select Image", captured_files, key="gen_select")
+         image = Image.open(selected_file)
          st.image(image, caption="Reference Image", use_container_width=True)
          prompt = st.text_area("Prompt", "Generate a similar superhero image", key="gen_prompt")
          if st.button("Run Image Gen 🚀", key="gen_run"):
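
One thing to watch in the new Download PDFs tab above: the snapshot mode handed to process_pdf_snapshot() is built as mode.lower().replace(" ", ""), so the "Two-Page View" label becomes "two-pageview", which matches neither the "thumbnail" nor the "twopage" branch of that function. A purely illustrative mapping (not part of the commit) that would line up with the keys the function checks:

MODE_KEYS = {"Thumbnail": "thumbnail", "Two-Page View": "twopage"}  # hypothetical helper

label = "Two-Page View"
print(label.lower().replace(" ", ""))  # "two-pageview" -- misses both branches
print(MODE_KEYS[label])                # "twopage"
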
@@ -240,24 +287,25 @@ with tab3:
              st.success(f"Image saved to {output_file}")
              st.session_state['processing']['gen'] = False
      else:
-         st.warning("No images captured yet. Use Camera Snap first!")
+         st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")

- with tab4:
+ with tab5:
      st.header("Test Line Drawings ✏️")
-     captured_images = get_gallery_files(["png"])
-     if captured_images:
-         selected_image = st.selectbox("Select Image", captured_images, key="line_select")
-         image = Image.open(selected_image)
+     captured_files = get_gallery_files(["png"])
+     if captured_files:
+         selected_file = st.selectbox("Select Image", captured_files, key="line_select")
+         image = Image.open(selected_file)
          st.image(image, caption="Input Image", use_container_width=True)
+         style = st.selectbox("Line Style", ["Fine", "Bold"], key="line_style")
          if st.button("Run Line Drawing 🚀", key="line_run"):
-             output_file = generate_filename("line_output", "png")
+             output_file = generate_filename(f"line_{style.lower()}", "png")
              st.session_state['processing']['line'] = True
-             result = asyncio.run(process_line_drawing(image, output_file))
-             st.image(result, caption="Line Drawing", use_container_width=True)
+             result = asyncio.run(process_line_drawing(image, style.lower(), output_file))
+             st.image(result, caption=f"{style} Line Drawing", use_container_width=True)
              st.success(f"Line drawing saved to {output_file}")
              st.session_state['processing']['line'] = False
      else:
-         st.warning("No images captured yet. Use Camera Snap first!")
+         st.warning("No images captured yet. Use Camera Snap or Download PDFs first!")

  # Initial Gallery Update
  update_gallery()
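
Finally, each tab drives its async helper from Streamlit's synchronous flow with asyncio.run(), and the OCR result is written through aiofiles as in process_ocr(); a minimal sketch of that pattern with a hypothetical coroutine and output path:

import asyncio
import aiofiles

async def write_result(text, output_file):
    async with aiofiles.open(output_file, "w") as f:
        await f.write(text)
    return output_file

saved = asyncio.run(write_result("recognized text", "ocr_demo.txt"))  # hypothetical path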
 