Svngoku committed (verified)
Commit: 93be02b
Parent(s): 39c7fb4
Files changed (1):
  1. app.py +62 -28
app.py CHANGED
@@ -45,6 +45,29 @@ def encode_image_bytes(image_bytes: bytes) -> str:
     """Encodes image bytes to a base64 string."""
     return base64.b64encode(image_bytes).decode('utf-8')
 
+def extract_images_from_markdown(markdown_text: str) -> Dict[str, str]:
+    """
+    Extracts base64 image data URIs from markdown and maps them to reference IDs.
+    Returns a dictionary mapping reference IDs to base64 data URIs.
+    """
+    image_map = {}
+    img_refs = re.findall(r"!\[.*?\]\((data:image/[a-zA-Z+]+;base64,[A-Za-z0-9+/=]+)\)", markdown_text)
+    for idx, img_uri in enumerate(img_refs):
+        ref_id = f"img_ref_{idx+1}"
+        image_map[ref_id] = img_uri
+    return image_map
+
+def replace_image_references(markdown_text: str, image_map: Dict[str, str]) -> str:
+    """
+    Replaces base64 image data URIs in markdown with reference IDs (e.g., img_ref_1).
+    """
+    updated_markdown = markdown_text
+    for ref_id, img_uri in image_map.items():
+        escaped_uri = re.escape(img_uri)
+        pattern = r"(!\[.*?\]\()" + escaped_uri + r"(\))"
+        updated_markdown = re.sub(pattern, f"\\1{ref_id}\\2", updated_markdown)
+    return updated_markdown
+
 def get_combined_markdown(ocr_response: Any) -> Tuple[str, str, Dict[str, str]]:
     """Combines markdown from OCR pages, replacing image IDs with base64 data URIs."""
     processed_markdowns = []
@@ -58,12 +81,16 @@ def get_combined_markdown(ocr_response: Any) -> Tuple[str, str, Dict[str, str]]:
     try:
         for page_idx, page in enumerate(ocr_response.pages):
             if hasattr(page, 'images') and page.images:
+                logger.info(f"Page {page_idx}: Found {len(page.images)} images.")
                 for img in page.images:
                     if hasattr(img, 'id') and hasattr(img, 'image_base64') and img.image_base64:
                         image_data_map[img.id] = img.image_base64
+                        logger.debug(f"Page {page_idx}: Image ID {img.id} added to image_data_map.")
                     else:
-                        logger.warning(f"Page {page_idx}: Image object lacks 'id' or valid 'image_base64'.")
-
+                        logger.warning(f"Page {page_idx}: Image object lacks 'id' or valid 'image_base64'. Image: {img}")
+            else:
+                logger.info(f"Page {page_idx}: No images found.")
+
             if not hasattr(page, 'markdown'):
                 logger.warning(f"Page {page_idx} lacks 'markdown' attribute. Skipping.")
                 continue
@@ -73,6 +100,7 @@ def get_combined_markdown(ocr_response: Any) -> Tuple[str, str, Dict[str, str]]:
             current_processed_markdown = current_raw_markdown
 
             img_refs = re.findall(r"!\[.*?\]\((.*?)\)", current_processed_markdown)
+            logger.debug(f"Page {page_idx}: Found {len(img_refs)} image references in markdown.")
             for img_id in img_refs:
                 if img_id in image_data_map:
                     base64_data_uri = image_data_map[img_id]
@@ -84,11 +112,13 @@ def get_combined_markdown(ocr_response: Any) -> Tuple[str, str, Dict[str, str]]:
                         r"\1" + base64_data_uri + r"\2",
                         current_processed_markdown
                     )
+                    logger.debug(f"Page {page_idx}: Replaced image ID {img_id} with base64 data URI.")
                 elif not img_id.startswith(('http:', 'https:', 'data:')):
                     logger.warning(f"Page {page_idx}: Image ID '{img_id}' not in image data.")
 
             processed_markdowns.append(current_processed_markdown)
 
+        logger.info(f"Processed {len(processed_markdowns)} pages with {len(image_data_map)} images.")
         return "\n\n".join(processed_markdowns), "\n\n".join(raw_markdowns), image_data_map
 
     except Exception as e:
@@ -114,10 +144,9 @@ def perform_ocr_file(file_obj: Any) -> Tuple[str, str, Dict[str, str]]:
         if file_ext == '.pdf':
             try:
                 with open(file_path, "rb") as f:
-                    file_content = f.read() # Read the entire file content
+                    file_content = f.read()
 
                 logger.info(f"Uploading PDF {file_name} to Mistral...")
-                # Use dictionary format as per documentation
                 uploaded_pdf = client.files.upload(
                     file={
                         "file_name": file_name,
@@ -134,6 +163,7 @@ def perform_ocr_file(file_obj: Any) -> Tuple[str, str, Dict[str, str]]:
                     document={"type": "document_url", "document_url": signed_url_response.url},
                     include_image_base64=True
                 )
+                logger.info(f"OCR response received: {ocr_response}")
             finally:
                 if uploaded_file_id:
                     try:
@@ -155,12 +185,15 @@ def perform_ocr_file(file_obj: Any) -> Tuple[str, str, Dict[str, str]]:
                 document={"type": "image_url", "image_url": data_uri},
                 include_image_base64=True
             )
+            logger.info(f"OCR response received: {ocr_response}")
 
         else:
             return f"Unsupported file type: '{file_name}'.", "", {}
 
         if ocr_response:
-            return get_combined_markdown(ocr_response)
+            processed_md, raw_md, img_map = get_combined_markdown(ocr_response)
+            logger.info(f"Processed markdown length: {len(processed_md)}")
+            return processed_md, raw_md, img_map
         return f"Error: OCR failed for '{file_name}'.", "", {}
 
     except Exception as e:
@@ -173,11 +206,16 @@ def chunk_markdown(
     chunk_overlap: int = 200,
     strip_headers: bool = True
 ) -> List[Document]:
-    """Chunks markdown text, preserving headers in metadata and extracting base64 images."""
+    """Chunks markdown text, preserving headers in metadata and extracting images."""
     if not markdown_text_with_images or not markdown_text_with_images.strip():
         logger.warning("chunk_markdown received empty input.")
         return []
 
+    # Extract images and replace with reference IDs
+    image_map = extract_images_from_markdown(markdown_text_with_images)
+    updated_markdown = replace_image_references(markdown_text_with_images, image_map)
+    logger.info(f"Extracted {len(image_map)} images from markdown.")
+
     headers_to_split_on = [
         ("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3"),
         ("####", "Header 4"), ("#####", "Header 5"), ("######", "Header 6"),
@@ -185,47 +223,43 @@ def chunk_markdown(
     markdown_splitter = MarkdownHeaderTextSplitter(
         headers_to_split_on=headers_to_split_on, strip_headers=strip_headers
     )
-    header_chunks = markdown_splitter.split_text(markdown_text_with_images)
+    header_chunks = markdown_splitter.split_text(updated_markdown)
 
     if not header_chunks:
-        logger.warning("No chunks created from markdown splitting.")
-        return []
+        logger.warning("No header chunks created. Treating entire text as one chunk.")
+        return [Document(page_content=updated_markdown, metadata={"images_base64": list(image_map.values())})]
 
     final_chunks = []
     if chunk_size > 0:
         text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-            length_function=len,
+            chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len,
             separators=["\n\n", "\n", "(?<=\. )", "(?<=\? )", "(?<=! )", ", ", "; ", " ", ""],
             add_start_index=True
         )
        for i, header_chunk in enumerate(header_chunks):
-            if header_chunk.page_content and len(header_chunk.page_content) > chunk_size:
+            if header_chunk.page_content:
                 sub_chunks = text_splitter.split_documents([header_chunk])
                 final_chunks.extend(sub_chunks)
-            elif header_chunk.page_content:
-                final_chunks.append(header_chunk)
+                logger.debug(f"Header chunk {i}: Split into {len(sub_chunks)} sub-chunks.")
+            else:
+                logger.debug(f"Header chunk {i}: Empty, skipping.")
     else:
         final_chunks = [chunk for chunk in header_chunks if chunk.page_content]
 
-    # Extract base64 images and add to metadata
+    # Add image references to metadata for each chunk
     for chunk in final_chunks:
         if not hasattr(chunk, 'metadata'):
             chunk.metadata = {}
-
-        # Improved regex to capture full base64 data URI
-        images_in_chunk = re.findall(
-            r"!\[.*?\]\((data:image/[a-zA-Z]+;base64,[A-Za-z0-9+/]+={0,2})\)",
-            chunk.page_content
-        )
-        chunk.metadata["images_base64"] = images_in_chunk if images_in_chunk else []
-        logger.debug(f"Chunk metadata updated with {len(images_in_chunk)} base64 images")
-
-    logger.info(f"Created {len(final_chunks)} chunks with base64 metadata")
+        # Find image references in this chunk
+        chunk_img_refs = re.findall(r"!\[.*?\]\((img_ref_\d+)\)", chunk.page_content)
+        chunk_images = [image_map[ref_id] for ref_id in chunk_img_refs if ref_id in image_map]
+        chunk.metadata["images_base64"] = chunk_images
+        chunk.metadata["image_references"] = chunk_img_refs
+        logger.debug(f"Chunk {chunk.metadata.get('start_index', 'unknown')}: Found {len(chunk_images)} images.")
+
+    logger.info(f"Created {len(final_chunks)} final chunks.")
     return final_chunks
 
-
 def get_hf_token(explicit_token: str = None) -> str:
     """Retrieve Hugging Face token with fallback mechanisms."""
     global hf_token_global
@@ -280,7 +314,7 @@ def process_file_and_save(
     source_filename = getattr(file_obj, 'orig_name', os.path.basename(file_obj.name))
     logger.info(f"--- Starting processing for file: {source_filename} ---")
 
-    processed_markdown, _, _ = perform_ocr_file(file_obj)
+    processed_markdown, raw_markdown, img_map = perform_ocr_file(file_obj)
     if not processed_markdown or processed_markdown.startswith("Error:"):
         return processed_markdown
 
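Note on the two helpers added at the top of the diff: they form a round trip. extract_images_from_markdown pulls every inline base64 data URI out of the markdown and keys it as img_ref_N; replace_image_references swaps each URI for its key. A minimal standalone sketch of that round trip, reusing the committed function bodies (the sample data URI below is a made-up stub):

import re
from typing import Dict

def extract_images_from_markdown(markdown_text: str) -> Dict[str, str]:
    # Map each inline base64 data URI to a stable reference ID (img_ref_1, img_ref_2, ...).
    image_map = {}
    img_refs = re.findall(r"!\[.*?\]\((data:image/[a-zA-Z+]+;base64,[A-Za-z0-9+/=]+)\)", markdown_text)
    for idx, img_uri in enumerate(img_refs):
        image_map[f"img_ref_{idx+1}"] = img_uri
    return image_map

def replace_image_references(markdown_text: str, image_map: Dict[str, str]) -> str:
    # Swap each data URI for its reference ID; re.escape guards the pattern
    # against '+' and '/' inside the base64 payload.
    updated_markdown = markdown_text
    for ref_id, img_uri in image_map.items():
        pattern = r"(!\[.*?\]\()" + re.escape(img_uri) + r"(\))"
        updated_markdown = re.sub(pattern, f"\\1{ref_id}\\2", updated_markdown)
    return updated_markdown

md = "Intro.\n\n![figure](data:image/png;base64,iVBORw0KGgo=)\n\nOutro."
image_map = extract_images_from_markdown(md)
print(image_map)                                # {'img_ref_1': 'data:image/png;base64,iVBORw0KGgo='}
print(replace_image_references(md, image_map))  # ...![figure](img_ref_1)...

This keeps multi-kilobyte data URIs out of the text the splitters see, which is presumably why chunk_markdown now calls both before splitting.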
 
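In get_combined_markdown, the replacement arguments visible in the hunk (r"\1" + base64_data_uri + r"\2") imply a two-group pattern like the one in replace_image_references; the pattern line itself sits just above the shown context, so the following single-substitution sketch is a plausible reconstruction under that assumption, with hypothetical stand-in values:

import re

# Hypothetical stand-ins for one OCR page: an image ID and its base64 data URI.
img_id = "img-0.jpeg"
base64_data_uri = "data:image/jpeg;base64,/9j/4AAQ"
markdown = "See ![img-0.jpeg](img-0.jpeg) for the figure."

# re.escape matters here: IDs like 'img-0.jpeg' contain '.', a regex metacharacter.
pattern = r"(!\[.*?\]\()" + re.escape(img_id) + r"(\))"
print(re.sub(pattern, r"\1" + base64_data_uri + r"\2", markdown))
# -> See ![img-0.jpeg](data:image/jpeg;base64,/9j/4AAQ) for the figure.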
 
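For orientation, the PDF branch of perform_ocr_file boils down to upload, sign, OCR, delete. A condensed sketch built from the calls visible in the diff; the purpose argument, the model name, and get_signed_url are outside the shown hunks and are assumptions based on the mistralai Python SDK:

import os
from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

with open("sample.pdf", "rb") as f:
    file_content = f.read()

uploaded_pdf = client.files.upload(
    file={"file_name": "sample.pdf", "content": file_content},
    purpose="ocr",  # assumption: not visible in the diff
)
signed_url_response = client.files.get_signed_url(file_id=uploaded_pdf.id)

ocr_response = client.ocr.process(
    model="mistral-ocr-latest",  # assumption: the model argument is outside the shown hunks
    document={"type": "document_url", "document_url": signed_url_response.url},
    include_image_base64=True,
)
print(len(ocr_response.pages), "pages OCRed")

# Cleanup mirrors the finally-block in the diff.
client.files.delete(file_id=uploaded_pdf.id)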
 
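Downstream, each chunk now carries its images in metadata rather than inline in page_content. A small sketch of the per-chunk lookup the last big hunk adds, assuming langchain_core's Document class (consistent with the List[Document] return type):

import re
from langchain_core.documents import Document

# image_map as returned by extract_images_from_markdown for the whole document.
image_map = {"img_ref_1": "data:image/png;base64,iVBORw0KGgo="}

# A toy chunk as it might look after header splitting, placeholder still inline.
chunk = Document(page_content="## Results\n\n![figure](img_ref_1)", metadata={"Header 2": "Results"})

# The same lookup chunk_markdown performs per chunk.
chunk_img_refs = re.findall(r"!\[.*?\]\((img_ref_\d+)\)", chunk.page_content)
chunk.metadata["images_base64"] = [image_map[r] for r in chunk_img_refs if r in image_map]
chunk.metadata["image_references"] = chunk_img_refs
print(chunk.metadata)
# {'Header 2': 'Results', 'images_base64': ['data:image/png;base64,iVBORw0KGgo='], 'image_references': ['img_ref_1']}

Storing both the resolved URIs and the reference IDs lets a retriever re-inline images after the fact.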