awacke1 committed
Commit e1bf9f9 · verified · 1 Parent(s): de31118

Update app.py

Files changed (1):
    app.py +16 -10
app.py CHANGED
@@ -13,7 +13,6 @@ import logging
 import asyncio
 import aiofiles
 from io import BytesIO
-import threading
 
 # Logging setup
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -53,12 +52,16 @@ def get_gallery_files(file_types):
     return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
 
 def update_gallery():
-    media_files = get_gallery_files(["png"])
+    media_files = get_gallery_files(["png", "txt"])
     if media_files:
         cols = st.sidebar.columns(2)
         for idx, file in enumerate(media_files[:gallery_size * 2]):
             with cols[idx % 2]:
-                st.image(Image.open(file), caption=file, use_container_width=True)
+                if file.endswith(".png"):
+                    st.image(Image.open(file), caption=file, use_container_width=True)
+                elif file.endswith(".txt"):
+                    with open(file, "r") as f:
+                        st.text(f.read()[:50] + "..." if len(f.read()) > 50 else f.read(), help=file)
 
 # Model Loaders (Smaller, CPU-focused)
 def load_ocr_qwen2vl():
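Review note on the new .txt branch: f.read() is called up to three times on the same handle, so after the first call the file pointer sits at EOF and both the length check and the else branch see an empty string. A minimal corrected sketch that reads the file once (same Streamlit calls as the hunk, including the help parameter used above):

    elif file.endswith(".txt"):
        with open(file, "r") as f:
            content = f.read()  # read once; further reads on f would return ""
        # truncate long previews to 50 chars, mirroring the intent of the original ternary
        st.text(content[:50] + "..." if len(content) > 50 else content, help=file)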
@@ -68,7 +71,7 @@ def load_ocr_qwen2vl():
     return processor, model
 
 def load_ocr_trocr():
-    model_id = "microsoft/trocr-small-handwritten" # Smaller, ~250 MB
+    model_id = "microsoft/trocr-small-handwritten" # ~250 MB
     processor = TrOCRProcessor.from_pretrained(model_id)
     model = VisionEncoderDecoderModel.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu").eval()
     return processor, model
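The loader above can be smoke-tested on its own; a sketch under the hunk's names (sample.png is an illustrative path, and the PIL import mirrors app.py):

    from PIL import Image
    processor, model = load_ocr_trocr()
    pixel_values = processor(images=Image.open("sample.png").convert("RGB"), return_tensors="pt").pixel_values
    ids = model.generate(pixel_values)  # greedy decode on CPU
    print(processor.batch_decode(ids, skip_special_tokens=True)[0])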
@@ -79,7 +82,7 @@ def load_image_gen():
     return pipeline
 
 def load_line_drawer():
-    # Simplified from your Torch Space (assuming edge detection)
+    # Simplified OpenCV-based edge detection (CPU-friendly substitute for Torch Space UNet)
     def edge_detection(image):
         img_np = np.array(image.convert("RGB"))
         gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
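The hunk cuts off inside edge_detection, and the rest of the function is not shown in this diff. A plausible completion matching the new comment's description, assuming Canny edges inverted to black-on-white lines (the 100/200 thresholds are illustrative):

    def edge_detection(image):
        img_np = np.array(image.convert("RGB"))
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 100, 200)   # binary edge map: white edges on black
        lines = cv2.bitwise_not(edges)      # invert to black lines on white
        return Image.fromarray(lines)
    return edge_detection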
@@ -94,20 +97,23 @@ async def process_ocr(image, prompt, model_name, output_file):
     status.text(f"Processing {model_name} OCR... (0s)")
     if model_name == "Qwen2-VL-OCR-2B":
         processor, model = load_ocr_qwen2vl()
-        inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
+        # Corrected input format: apply chat template
+        messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
+        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True).to("cpu")
         outputs = model.generate(**inputs, max_new_tokens=1024)
-        text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+        result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
     else: # TrOCR
         processor, model = load_ocr_trocr()
         pixel_values = processor(images=image, return_tensors="pt").pixel_values.to("cpu")
         outputs = model.generate(pixel_values)
-        text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+        result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
     elapsed = int(time.time() - start_time)
     status.text(f"{model_name} OCR completed in {elapsed}s!")
     async with aiofiles.open(output_file, "w") as f:
-        await f.write(text)
+        await f.write(result)
     st.session_state['captured_images'].append(output_file)
-    return text
+    return result
 
 async def process_image_gen(prompt, output_file):
     start_time = time.time()
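Why the Qwen2-VL change works: the processor expects a chat-formatted prompt, and apply_chat_template inserts the image placeholder tokens that let generate attend to the pixels; the old processor(text=prompt, images=image, ...) call bypassed that template. One remaining nit: generate returns prompt plus completion, so the decoded result still begins with the chat prompt unless the input length is sliced off first. A one-line variant using the hunk's own names:

    result = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]

And a minimal caller sketch showing how the coroutine could wire into the rest of the app (uploader and button labels are illustrative, and asyncio.run assumes no event loop is already running in the script thread):

    uploaded = st.file_uploader("Image to OCR", type=["png", "jpg"])
    if uploaded and st.button("Run OCR"):
        image = Image.open(uploaded).convert("RGB")
        output_file = f"ocr_{int(time.time())}.txt"
        text = asyncio.run(process_ocr(image, "Extract all text.", "Qwen2-VL-OCR-2B", output_file))
        st.text_area("OCR Result", text)
        update_gallery()  # the new .txt branch will preview the saved output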
 