Svngoku committed on
Commit f13386c · verified · 1 Parent(s): 716b14f

Update app.py

Files changed (1)
  1. app.py +585 -153
app.py CHANGED
@@ -1,215 +1,619 @@
1
  import gradio as gr
2
  from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
3
  from langchain.schema import Document
4
- from typing import List, Dict, Any
5
  import logging
6
  import re
7
- from pathlib import Path
8
- import requests
9
  import base64
10
- import io
11
- from PIL import Image
12
  from datasets import Dataset
13
  from huggingface_hub import HfApi
14
  import os
15
- from mistralai import Mistral
16
 
17
  # Configure logging
18
- logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
- # Mistral OCR setup
22
  api_key = os.environ.get("MISTRAL_API_KEY")
23
  if not api_key:
24
- raise ValueError("MISTRAL_API_KEY environment variable not set")
25
- client = Mistral(api_key=api_key)
26
 
27
- # Function to encode image to base64
28
- def encode_image(image_path):
29
  try:
30
- with open(image_path, "rb") as image_file:
31
- return base64.b64encode(image_file.read()).decode('utf-8')
32
- except FileNotFoundError:
33
- return "Error: The file was not found."
34
  except Exception as e:
35
- return f"Error: {e}"
36
 
37
- # Function to replace images in markdown with base64 strings
38
- def replace_images_in_markdown(markdown_str: str, images_dict: Dict[str, str]) -> str:
39
- for img_name, base64_str in images_dict.items():
40
- markdown_str = markdown_str.replace(f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})")
41
- return markdown_str
42
 
43
- # Function to combine markdown from OCR response
44
- def get_combined_markdown(ocr_response) -> tuple[str, str, Dict[str, str]]:
45
- markdowns = []
46
  raw_markdowns = []
47
- image_data = {} # Collect all image data
48
- for page in ocr_response.pages:
49
- for img in page.images:
50
- image_data[img.id] = img.image_base64
51
- markdowns.append(replace_images_in_markdown(page.markdown, image_data))
52
- raw_markdowns.append(page.markdown)
53
- return "\n\n".join(markdowns), "\n\n".join(raw_markdowns), image_data
54
-
55
- # Perform OCR on uploaded file
56
- def perform_ocr_file(file):
57
  try:
58
- if file.name.lower().endswith('.pdf'):
59
- uploaded_pdf = client.files.upload(
60
- file={
61
- "file_name": file.name,
62
- "content": open(file.name, "rb"),
63
- },
64
- purpose="ocr"
65
- )
66
- signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id)
67
- ocr_response = client.ocr.process(
68
- model="mistral-ocr-latest",
69
- document={
70
- "type": "document_url",
71
- "document_url": signed_url.url,
72
- },
73
- include_image_base64=True
74
- )
75
- client.files.delete(file_id=uploaded_pdf.id)
76
-
77
- elif file.name.lower().endswith(('.png', '.jpg', '.jpeg')):
78
- base64_image = encode_image(file.name)
79
- ocr_response = client.ocr.process(
80
- model="mistral-ocr-latest",
81
- document={
82
- "type": "image_url",
83
- "image_url": f"data:image/jpeg;base64,{base64_image}"
84
- },
85
- include_image_base64=True
86
- )
87
- else:
88
- return "Unsupported file type. Please provide a PDF or an image (png, jpeg, jpg).", "", {}
89
 
90
- combined_markdown, raw_markdown, image_data = get_combined_markdown(ocr_response)
91
- return combined_markdown, raw_markdown, image_data
92
  except Exception as e:
93
- return f"Error during OCR: {str(e)}", "", {}
94
 
95
- # Function to extract image names from markdown content
96
- def extract_image_names_from_markdown(markdown_text: str) -> List[str]:
97
- # Regex to match markdown image syntax
98
- pattern = r"!\[(.*?)\]\("
99
- return [match.replace("![","").replace("](","") for match in re.findall(pattern, markdown_text)]
100
 
101
- # Function to chunk markdown text with image handling
102
  def chunk_markdown(
103
- markdown_text: str,
104
- image_data: Dict[str, str],
105
  chunk_size: int = 1000,
106
  chunk_overlap: int = 200,
107
  strip_headers: bool = True
108
  ) -> List[Document]:
109
  try:
110
- # Define headers to split on
111
  headers_to_split_on = [
112
  ("#", "Header 1"),
113
  ("##", "Header 2"),
114
  ("###", "Header 3"),
115
  ]
116
-
117
  # Initialize MarkdownHeaderTextSplitter
118
  markdown_splitter = MarkdownHeaderTextSplitter(
119
  headers_to_split_on=headers_to_split_on,
120
- strip_headers=strip_headers
121
  )
122
-
123
- # Split markdown by headers
124
- logger.info("Splitting markdown by headers")
125
- chunks = markdown_splitter.split_text(markdown_text)
126
 
127
- # If chunk_size is specified, further split large chunks
128
  if chunk_size > 0:
129
  text_splitter = RecursiveCharacterTextSplitter(
130
  chunk_size=chunk_size,
131
  chunk_overlap=chunk_overlap,
132
  length_function=len,
133
- separators=["\n\n", "\n", ".", " ", ""],
134
- keep_separator=True,
135
- add_start_index=True
136
  )
137
- logger.info(f"Applying character-level splitting with chunk_size={chunk_size}")
138
- final_chunks = []
139
- for chunk in chunks:
140
- if len(chunk.page_content) > chunk_size:
141
- sub_chunks = text_splitter.split_documents([chunk])
142
- final_chunks.extend(sub_chunks)
143
  else:
144
- final_chunks.append(chunk)
145
- chunks = final_chunks
146
-
147
- # Add images to metadata
148
- for chunk in chunks:
149
- image_names = extract_image_names_from_markdown(chunk.page_content)
150
- images = {name: image_data.get(name, None) for name in image_names}
151
- # Add a dummy field if the images dictionary is empty
152
- if not images:
153
- images = {"dummy": None}
154
- chunk.metadata["images"] = images
155
-
156
- logger.info(f"Created {len(chunks)} chunks")
157
- return chunks
158
-
159
  except Exception as e:
160
- logger.error(f"Error processing markdown: {str(e)}")
161
- raise
162
 
163
- # Process file: OCR -> Chunk -> Save
164
- def process_file_and_save(file, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name):
165
  try:
166
- # Step 1: Perform OCR
167
- combined_markdown, raw_markdown, image_data = perform_ocr_file(file)
168
- if "Error" in combined_markdown:
169
- return combined_markdown
170
-
171
- # Step 2: Chunk the markdown
172
- chunks = chunk_markdown(combined_markdown, image_data, chunk_size, chunk_overlap, strip_headers)
173
-
174
- # Step 3: Prepare dataset
175
  data: Dict[str, List[Any]] = {
176
  "chunk_id": [],
177
- "content": [],
178
  "metadata": [],
179
  }
180
-
181
  for i, chunk in enumerate(chunks):
182
- data["chunk_id"].append(i)
183
- data["content"].append(chunk.page_content)
184
- data["metadata"].append(chunk.metadata)
185
-
186
- # Step 4: Create and push dataset to Hugging Face
187
- dataset = Dataset.from_dict(data)
188
- api = HfApi()
189
- api.create_repo(repo_id=repo_name, token=hf_token, repo_type="dataset", exist_ok=True)
190
- dataset.push_to_hub(repo_name, token=hf_token)
191
-
192
- return f"Dataset created with {len(chunks)} chunks and saved to Hugging Face at {repo_name}"
193
  except Exception as e:
194
- return f"Error: {str(e)}"
195
 
196
- # Gradio Interface
197
- with gr.Blocks(title="PDF/Image OCR, Markdown Chunking, and Dataset Creator") as demo:
198
- gr.Markdown("# PDF/Image OCR, Markdown Chunking, and Dataset Creator")
199
- gr.Markdown("Upload a PDF or image, extract text/images with Mistral OCR, chunk the markdown by headers, and save to Hugging Face.")
200
-
201
  with gr.Row():
202
- with gr.Column():
203
- file_input = gr.File(label="Upload PDF or Image")
204
- chunk_size = gr.Slider(0, 2000, value=1000, step=100, label="Max Chunk Size (0 to disable)")
205
- chunk_overlap = gr.Slider(0, 500, value=200, step=50, label="Chunk Overlap")
206
- strip_headers = gr.Checkbox(label="Strip Headers from Content", value=True)
207
- hf_token = gr.Textbox(label="Hugging Face Token", type="password")
208
- repo_name = gr.Textbox(label="Hugging Face Repository Name (e.g., username/dataset-name)")
209
- submit_btn = gr.Button("Process and Save")
210
-
211
- with gr.Column():
212
- output = gr.Textbox(label="Result")
213
 
214
  submit_btn.click(
215
  fn=process_file_and_save,
@@ -217,4 +621,32 @@ with gr.Blocks(title="PDF/Image OCR, Markdown Chunking, and Dataset Creator") as
217
  outputs=output
218
  )
219
 
220
- demo.launch(share=True)
1
  import gradio as gr
2
  from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
3
  from langchain.schema import Document
4
+ from typing import List, Dict, Any, Tuple
5
  import logging
6
  import re
7
  import base64
8
+ import mimetypes # Added
9
  from datasets import Dataset
10
  from huggingface_hub import HfApi
11
+ import huggingface_hub # Added for token checking and errors
12
  import os
13
+ from mistralai import Mistral # Assuming this is the correct import for the client
14
 
15
  # Configure logging
16
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
17
  logger = logging.getLogger(__name__)
18
 
19
+ # --- Mistral OCR Setup ---
20
  api_key = os.environ.get("MISTRAL_API_KEY")
21
  if not api_key:
22
+ logger.warning("MISTRAL_API_KEY environment variable not set. Attempting to use Hugging Face token.")
23
+ try:
24
+ api_key = huggingface_hub.get_token()
25
+ if not api_key:
26
+ # If running locally, this might still fail if not logged in.
27
+ logger.warning("Could not retrieve token from Hugging Face login.")
28
+ # Error will be raised later if client init fails or during HF push if token still missing
29
+ else:
30
+ logger.info("Using Hugging Face token as MISTRAL_API_KEY.")
31
+ except Exception as e:
32
+ logger.warning(f"Could not check Hugging Face login for token: {e}")
33
+ # Proceed without API key, client initialization might fail
34
 
35
+ # Initialize Mistral Client
36
+ client = None
37
+ if api_key:
38
  try:
39
+ client = Mistral(api_key=api_key)
40
+ logger.info("Mistral client initialized successfully.")
41
  except Exception as e:
42
+ logger.error(f"Failed to initialize Mistral client: {e}", exc_info=True)
43
+ # Raise a clearer error for Gradio startup if client fails
44
+ raise RuntimeError(f"Failed to initialize Mistral client. Check API key and mistralai installation. Error: {e}")
45
+ else:
46
+ # This path might be hit if no env var and no HF token found
47
+ logger.error("Mistral API key is not available. OCR functionality will fail.")
48
+ # We could raise an error here, or let it fail when client methods are called.
49
+ # Let's allow Gradio to load but OCR will fail clearly later.
50
+
51
+
52
+ # --- Helper Functions ---
53
 
54
+ def encode_image_bytes(image_bytes: bytes) -> str:
55
+ """Encodes image bytes to a base64 string."""
56
+ return base64.b64encode(image_bytes).decode('utf-8')
57
 
58
+ def get_combined_markdown(ocr_response: Any) -> Tuple[str, str, Dict[str, str]]:
59
+ """
60
+ Combines markdown from OCR pages, replacing image IDs with base64 data URIs.
61
+
62
+ Args:
63
+ ocr_response: The response object from the Mistral OCR API.
64
+
65
+ Returns:
66
+ A tuple containing:
67
+ - combined_markdown_with_images: Markdown string with image references replaced by base64 data URIs.
68
+ - combined_raw_markdown: Original markdown string without image replacement.
69
+ - image_data_map: A dictionary mapping image IDs to their base64 data URIs.
70
+ Raises ValueError on unexpected response structure.
71
+ """
72
+ processed_markdowns = []
73
  raw_markdowns = []
74
+ image_data_map = {} # Collect image_id -> base64_data_uri
75
+
76
+ if not hasattr(ocr_response, 'pages') or not ocr_response.pages:
77
+ logger.warning("OCR response has no 'pages' attribute or pages list is empty.")
78
+ return "", "", {}
79
+
80
  try:
81
+ # Collect all image data first (assuming image_base64 includes data URI prefix from Mistral)
82
+ for page_idx, page in enumerate(ocr_response.pages):
83
+ if hasattr(page, 'images') and page.images:
84
+ for img in page.images:
85
+ if hasattr(img, 'id') and hasattr(img, 'image_base64') and img.image_base64:
86
+ image_data_map[img.id] = img.image_base64 # Assuming this is the full data URI
87
+ else:
88
+ logger.warning(f"Page {page_idx}: Image object lacks 'id' or valid 'image_base64'. Image: {img}")
89
+ # else: # Don't warn if a page simply has no images
90
+ # logger.debug(f"Page {page_idx} has no 'images' attribute or no images found.")
91
+
92
+
93
+ # Process markdown for each page
94
+ for page_idx, page in enumerate(ocr_response.pages):
95
+ if not hasattr(page, 'markdown'):
96
+ logger.warning(f"Page {page_idx} in OCR response lacks 'markdown' attribute. Skipping.")
97
+ continue # Skip page if no markdown
98
+
99
+ current_raw_markdown = page.markdown if page.markdown else ""
100
+ raw_markdowns.append(current_raw_markdown)
101
+ current_processed_markdown = current_raw_markdown
102
+
103
+ # Find all image references like ![alt_text](image_id)
104
+ # Regex to find the image ID (content within parentheses)
105
+ img_refs = re.findall(r"!\[.*?\]\((.*?)\)", current_processed_markdown)
106
+ for img_id in img_refs:
107
+ if img_id in image_data_map:
108
+ base64_data_uri = image_data_map[img_id]
109
+ # Escape potential regex special characters in img_id before using in replace
110
+ escaped_img_id = re.escape(img_id)
111
+ # Replace ![...](image_id) with ![...](data:...)
112
+ # Use a specific regex for replacement: find the exact pattern ![...](img_id)
113
+ pattern = r"(!\[.*?\]\()" + escaped_img_id + r"(\))"
114
+ # Check if replacement target exists before replacing
115
+ if re.search(pattern, current_processed_markdown):
116
+ current_processed_markdown = re.sub(
117
+ pattern,
118
+ r"\1" + base64_data_uri + r"\2",
119
+ current_processed_markdown
120
+ )
121
+ else:
122
+ # This case shouldn't happen often if img_id came from findall on the same string
123
+ logger.warning(f"Page {page_idx}: Found img_id '{img_id}' but couldn't find exact pattern '{pattern}' for replacement.")
124
 
125
+ else:
126
+ # Only log warning if the ID looks like an expected image ID pattern (e.g., 'image_X')
127
+ # Avoid warning for regular URLs that might be in the markdown
128
+ if not img_id.startswith(('http:', 'https:', 'data:')): # Check if it's not already a URL
129
+ logger.warning(f"Page {page_idx}: Image ID '{img_id}' found in markdown but not in collected image data.")
130
+
131
+ processed_markdowns.append(current_processed_markdown)
132
+
133
+ return "\n\n".join(processed_markdowns), "\n\n".join(raw_markdowns), image_data_map
134
+
135
+ except AttributeError as ae:
136
+ logger.error(f"Attribute error accessing OCR response structure: {ae}", exc_info=True)
137
+ raise ValueError(f"Unexpected OCR response structure. Check Mistral API changes. Error: {ae}")
138
  except Exception as e:
139
+ logger.error(f"Error processing OCR response markdown: {e}", exc_info=True)
140
+ raise
141
 
142
+ def perform_ocr_file(file_obj: Any) -> Tuple[str, str, Dict[str, str]]:
143
+ """
144
+ Performs OCR on an uploaded file (PDF or image) using the Mistral API.
145
+
146
+ Args:
147
+ file_obj: The file object from Gradio's gr.File component.
148
+
149
+ Returns:
150
+ A tuple containing:
151
+ - processed_markdown: Markdown string with base64 images, or error message.
152
+ - raw_markdown: Original markdown string.
153
+ - image_data_map: Dictionary mapping image IDs to base64 data URIs.
154
+ """
155
+ if not client:
156
+ return "Error: Mistral client not initialized. Check API key setup.", "", {}
157
+ if not file_obj:
158
+ # This check might be redundant if called from process_file_and_save, but good practice
159
+ return "Error: No file provided to OCR function.", "", {}
160
+
161
+ try:
162
+ file_path = file_obj if isinstance(file_obj, str) else file_obj.name # gr.File(type="filepath") passes a plain path string; older Gradio versions pass a tempfile wrapper
163
+ # Use the original filename if available (Gradio>=4), else use the temp path's basename
164
+ file_name = getattr(file_obj, 'orig_name', os.path.basename(file_path))
165
+ logger.info(f"Performing OCR on file: {file_name} (temp path: {file_path})")
166
+
167
+ # Determine file type from extension
168
+ file_ext = os.path.splitext(file_name)[1].lower()
169
+
170
+ ocr_response = None
171
+ uploaded_file_id = None
172
+
173
+ if file_ext == '.pdf':
174
+ try:
175
+ with open(file_path, "rb") as f:
176
+ logger.info(f"Uploading PDF {file_name} to Mistral...")
177
+ # Pass as tuple (filename, file-like object)
178
+ uploaded_pdf = client.files.upload(
179
+ file=(file_name, f),
180
+ purpose="ocr"
181
+ )
182
+ uploaded_file_id = uploaded_pdf.id
183
+ logger.info(f"PDF uploaded successfully. File ID: {uploaded_file_id}")
184
+
185
+ logger.info(f"Getting signed URL for file ID: {uploaded_file_id}")
186
+ signed_url_response = client.files.get_signed_url(file_id=uploaded_file_id)
187
+ logger.info(f"Got signed URL: {signed_url_response.url[:50]}...")
188
+
189
+ logger.info("Sending PDF URL to Mistral OCR (model: mistral-ocr-latest)...")
190
+ ocr_response = client.ocr.process(
191
+ model="mistral-ocr-latest",
192
+ document={
193
+ "type": "document_url",
194
+ "document_url": signed_url_response.url,
195
+ },
196
+ include_image_base64=True
197
+ )
198
+ logger.info("OCR processing complete for PDF.")
199
+
200
+ finally:
201
+ # Ensure cleanup even if OCR fails after upload
202
+ if uploaded_file_id:
203
+ try:
204
+ logger.info(f"Deleting temporary Mistral file: {uploaded_file_id}")
205
+ client.files.delete(file_id=uploaded_file_id)
206
+ except Exception as delete_err:
207
+ logger.warning(f"Failed to delete temporary Mistral file {uploaded_file_id}: {delete_err}")
208
+
209
+ elif file_ext in ['.png', '.jpg', '.jpeg', '.webp', '.bmp']:
210
+ try:
211
+ with open(file_path, "rb") as f:
212
+ image_bytes = f.read()
213
+
214
+ if not image_bytes:
215
+ return f"Error: Uploaded image file '{file_name}' is empty.", "", {}
216
+
217
+ base64_encoded_image = encode_image_bytes(image_bytes)
218
+
219
+ # Determine MIME type
220
+ mime_type, _ = mimetypes.guess_type(file_path)
221
+ if not mime_type or not mime_type.startswith('image'):
222
+ logger.warning(f"Could not determine MIME type for {file_name} using extension. Defaulting to image/jpeg.")
223
+ mime_type = 'image/jpeg' # Fallback
224
+
225
+ data_uri = f"data:{mime_type};base64,{base64_encoded_image}"
226
+ logger.info(f"Sending image {file_name} ({mime_type}) as data URI to Mistral OCR (model: mistral-ocr-latest)...")
227
+
228
+ ocr_response = client.ocr.process(
229
+ model="mistral-ocr-latest",
230
+ document={
231
+ "type": "image_url",
232
+ "image_url": data_uri
233
+ },
234
+ include_image_base64=True
235
+ )
236
+ logger.info(f"OCR processing complete for image {file_name}.")
237
+ except Exception as img_ocr_err:
238
+ logger.error(f"Error during image OCR for {file_name}: {img_ocr_err}", exc_info=True)
239
+ return f"Error during OCR for image '{file_name}': {img_ocr_err}", "", {}
240
+
241
+ else:
242
+ unsupported_msg = f"Unsupported file type: '{file_name}'. Please provide a PDF or an image (png, jpg, jpeg, webp, bmp)."
243
+ logger.warning(unsupported_msg)
244
+ return unsupported_msg, "", {}
245
+
246
+ # Process the OCR response (common path for PDF/Image)
247
+ if ocr_response:
248
+ logger.info("Processing OCR response to combine markdown and images...")
249
+ processed_md, raw_md, img_map = get_combined_markdown(ocr_response)
250
+ logger.info("Markdown and image data extraction complete.")
251
+ return processed_md, raw_md, img_map
252
+ else:
253
+ # This case might occur if OCR processing itself failed silently or returned None
254
+ logger.error(f"OCR processing for '{file_name}' did not return a valid response.")
255
+ return f"Error: OCR processing failed for '{file_name}'. No response received.", "", {}
256
+
257
+ except FileNotFoundError:
258
+ logger.error(f"Temporary file not found: {file_path}", exc_info=True)
259
+ return f"Error: Could not read the uploaded file '{file_name}'. Ensure it uploaded correctly.", "", {}
260
+ except Exception as e:
261
+ logger.error(f"Unexpected error during OCR processing file {file_name}: {e}", exc_info=True)
262
+ # Provide more context in the error message returned to the user
263
+ return f"Error during OCR processing for '{file_name}': {str(e)}", "", {}
264
 
265
  def chunk_markdown(
266
+ markdown_text_with_images: str,
267
  chunk_size: int = 1000,
268
  chunk_overlap: int = 200,
269
  strip_headers: bool = True
270
  ) -> List[Document]:
271
+ """
272
+ Chunks markdown text, preserving headers in metadata and adding embedded image info.
273
+
274
+ Args:
275
+ markdown_text_with_images: The markdown string containing base64 data URIs for images.
276
+ chunk_size: The target size for chunks (characters). 0 to disable recursive splitting.
277
+ chunk_overlap: The overlap between consecutive chunks (characters).
278
+ strip_headers: Whether to remove header syntax (e.g., '# ') from the chunk content.
279
+
280
+ Returns:
281
+ A list of Langchain Document objects representing the chunks. Returns empty list if input is empty.
282
+ """
283
+ if not markdown_text_with_images or not markdown_text_with_images.strip():
284
+ logger.warning("chunk_markdown received empty or whitespace-only input string.")
285
+ return []
286
  try:
287
  headers_to_split_on = [
288
  ("#", "Header 1"),
289
  ("##", "Header 2"),
290
  ("###", "Header 3"),
291
+ ("####", "Header 4"),
292
+ ("#####", "Header 5"), # Added more levels
293
+ ("######", "Header 6"),
294
  ]
295
  # Initialize MarkdownHeaderTextSplitter
296
  markdown_splitter = MarkdownHeaderTextSplitter(
297
  headers_to_split_on=headers_to_split_on,
298
+ strip_headers=strip_headers,
299
+ return_each_line=False # Process blocks
300
  )
301
 
302
+ logger.info("Splitting markdown by headers...")
303
+ header_chunks = markdown_splitter.split_text(markdown_text_with_images)
304
+ logger.info(f"Split into {len(header_chunks)} chunks based on headers.")
305
+
306
+ if not header_chunks:
307
+ logger.warning("MarkdownHeaderTextSplitter returned zero chunks.")
308
+ # Maybe the input had no headers? Treat the whole text as one chunk?
309
+ # Or just return empty? Let's return empty for now, as header splitting is intended.
310
+ # Alternative: create a single Document if header_chunks is empty but input wasn't.
311
+ # doc = Document(page_content=markdown_text_with_images, metadata={})
312
+ # header_chunks = [doc]
313
+ # logger.info("No headers found, treating input as a single chunk.")
314
+ # For now, stick to returning empty list if no header chunks are made.
315
+ return []
316
+
317
+
318
+ final_chunks = []
319
+ # If chunk_size is specified and > 0, further split large chunks
320
  if chunk_size > 0:
321
  text_splitter = RecursiveCharacterTextSplitter(
322
  chunk_size=chunk_size,
323
  chunk_overlap=chunk_overlap,
324
  length_function=len,
325
+ # More robust separators
326
+ separators=["\n\n", "\n", "(?<=\. )", "(?<=\? )", "(?<=! )", ", ", "; ", " ", ""],
327
+ keep_separator=False,
328
+ add_start_index=True # Add start index relative to the parent (header) chunk
329
  )
330
+ logger.info(f"Applying recursive character splitting (size={chunk_size}, overlap={chunk_overlap})...")
331
+ processed_chunks_count = 0
332
+ for i, header_chunk in enumerate(header_chunks):
333
+ # Check if page_content exists and is longer than chunk_size
334
+ if header_chunk.page_content and len(header_chunk.page_content) > chunk_size:
335
+ logger.debug(f"Header chunk {i} (length {len(header_chunk.page_content)}) needs recursive splitting.")
336
+ try:
337
+ # split_documents preserves metadata from the parent chunk
338
+ sub_chunks = text_splitter.split_documents([header_chunk])
339
+ final_chunks.extend(sub_chunks)
340
+ processed_chunks_count += len(sub_chunks)
341
+ logger.debug(f" -> Split into {len(sub_chunks)} sub-chunks.")
342
+ except Exception as split_err:
343
+ logger.error(f"Error splitting header chunk {i}: {split_err}", exc_info=True)
344
+ # Option: Add the original large chunk instead? Or skip? Let's skip broken ones.
345
+ logger.warning(f"Skipping header chunk {i} due to splitting error.")
346
+ continue
347
  else:
348
+ # If the chunk is already small enough or empty, just add it
349
+ if header_chunk.page_content: # Add only if it has content
350
+ final_chunks.append(header_chunk)
351
+ processed_chunks_count += 1
352
+ logger.debug(f"Header chunk {i} (length {len(header_chunk.page_content)}) kept as is.")
353
+ else:
354
+ logger.debug(f"Header chunk {i} was empty, skipping.")
355
+ logger.info(f"Recursive character splitting finished. Processed {processed_chunks_count} chunks.")
356
+ else:
357
+ # If chunk_size is 0, use only non-empty header chunks
358
+ logger.info("chunk_size is 0, using only non-empty header-based chunks.")
359
+ final_chunks = [chunk for chunk in header_chunks if chunk.page_content]
360
+
361
+
362
+ # Post-process final chunks: Extract embedded image data URIs and add to metadata
363
+ logger.info("Extracting embedded image data URIs for final chunk metadata...")
364
+ for chunk in final_chunks:
365
+ images_in_chunk = []
366
+ if chunk.page_content:
367
+ try:
368
+ # Regex to find all base64 data URIs in the chunk content
369
+ # Non-greedy alt text `.*?`, robust base64 chars `[A-Za-z0-9+/=]+`
370
+ # Ensure the closing parenthesis `\)` is matched correctly
371
+ pattern = r"!\[.*?\]\((data:image/[a-zA-Z+]+;base64,[A-Za-z0-9+/=]+)\)"
372
+ images_in_chunk = re.findall(pattern, chunk.page_content)
373
+ except Exception as regex_err:
374
+ logger.error(f"Regex error extracting images from chunk: {regex_err}", exc_info=True)
375
+ # Leave images list empty for this chunk
376
+
377
+ # Ensure metadata exists and add images list (can be empty)
378
+ if not hasattr(chunk, 'metadata'):
379
+ chunk.metadata = {}
380
+ chunk.metadata["images_base64"] = images_in_chunk # Use a more specific key name
381
+
382
+ logger.info(f"Created {len(final_chunks)} final chunks after processing and filtering.")
383
+ return final_chunks
384
+
385
  except Exception as e:
386
+ logger.error(f"Error during markdown chunking process: {str(e)}", exc_info=True)
387
+ raise # Re-raise to be caught by the main processing function
388
+
389
+ # --- Main Processing Function ---
390
+
391
+ def process_file_and_save(
392
+ file_obj: Any, # Gradio File object
393
+ chunk_size: int,
394
+ chunk_overlap: int,
395
+ strip_headers: bool,
396
+ hf_token: str,
397
+ repo_name: str
398
+ ) -> str:
399
+ """
400
+ Orchestrates the OCR, chunking, and saving process to Hugging Face Hub.
401
+
402
+ Args:
403
+ file_obj: The uploaded file object from Gradio.
404
+ chunk_size: Max chunk size for text splitting (chars). 0 disables recursive splitting.
405
+ chunk_overlap: Overlap for text splitting (chars).
406
+ strip_headers: Whether to remove markdown headers from chunk content.
407
+ hf_token: Hugging Face API token (write permission).
408
+ repo_name: Name for the Hugging Face dataset repository (e.g., 'username/my-ocr-dataset').
409
+
410
+ Returns:
411
+ A string indicating success or failure, suitable for display in Gradio.
412
+ """
413
+ # --- Input Validation ---
414
+ if not file_obj:
415
+ return "Error: No file uploaded. Please upload a PDF or image file."
416
+ if not repo_name or '/' not in repo_name:
417
+ return "Error: Invalid Hugging Face Repository Name. Use format 'username/dataset-name'."
418
+
419
+ # Validate chunking parameters
420
+ if chunk_size < 0:
421
+ logger.warning("Chunk size cannot be negative. Setting to 0 (header splits only).")
422
+ chunk_size = 0
423
+ if chunk_overlap < 0:
424
+ logger.warning("Chunk overlap cannot be negative. Setting to 0.")
425
+ chunk_overlap = 0
426
+ if chunk_size > 0 and chunk_overlap >= chunk_size:
427
+ logger.warning(f"Chunk overlap ({chunk_overlap}) >= chunk size ({chunk_size}). Adjusting overlap to {min(200, chunk_size // 2)}.")
428
+ chunk_overlap = min(200, chunk_size // 2) # Set a reasonable overlap
429
+
430
+ # Handle Hugging Face Token
431
+ if not hf_token:
432
+ logger.info("No explicit HF token provided. Trying to use token from local Hugging Face login.")
433
+ try:
434
+ hf_token = huggingface_hub.get_token()
435
+ if not hf_token:
436
+ return "Error: Hugging Face Token is required. Please provide a token or log in using `huggingface-cli login`."
437
+ logger.info("Using HF token from local login for dataset operations.")
438
+ except Exception as e:
439
+ logger.error(f"Error checking HF login for token: {e}", exc_info=True)
440
+ return f"Error: Hugging Face Token is required. Could not verify HF login: {e}"
441
 
442
  try:
443
+ source_filename = os.path.basename(file_obj) if isinstance(file_obj, str) else getattr(file_obj, 'orig_name', os.path.basename(file_obj.name))
444
+ logger.info(f"--- Starting processing for file: {source_filename} ---")
445
+
446
+ # --- Step 1: Perform OCR ---
447
+ logger.info("Step 1: Performing OCR...")
448
+ processed_markdown, _, _ = perform_ocr_file(file_obj) # raw_markdown and image_map not directly used later
449
+
450
+ # Check if OCR returned an error message or was empty/invalid
451
+ if not processed_markdown or isinstance(processed_markdown, str) and (
452
+ processed_markdown.startswith("Error:") or processed_markdown.startswith("Unsupported file type:")):
453
+ logger.error(f"OCR failed or returned error/unsupported: {processed_markdown}")
454
+ return processed_markdown # Return the error message directly
455
+ if not isinstance(processed_markdown, str) or len(processed_markdown.strip()) == 0:
456
+ logger.error("OCR processing returned empty or invalid markdown content.")
457
+ return "Error: OCR returned empty or invalid content."
458
+ logger.info("Step 1: OCR finished successfully.")
459
+
460
+ # --- Step 2: Chunk the markdown ---
461
+ logger.info("Step 2: Chunking the markdown...")
462
+ chunks = chunk_markdown(processed_markdown, chunk_size, chunk_overlap, strip_headers)
463
+
464
+ if not chunks:
465
+ logger.error("Chunking resulted in zero chunks. Check OCR output and chunking parameters.")
466
+ return "Error: Failed to chunk the document (possibly empty after OCR or no headers found)."
467
+ logger.info(f"Step 2: Chunking finished, produced {len(chunks)} chunks.")
468
+
469
+ # --- Step 3: Prepare dataset ---
470
+ logger.info("Step 3: Preparing data for Hugging Face dataset...")
471
  data: Dict[str, List[Any]] = {
472
  "chunk_id": [],
473
+ "text": [], # Renamed 'content' to 'text'
474
  "metadata": [],
475
+ "source_filename": [],
476
  }
477
+
478
  for i, chunk in enumerate(chunks):
479
+ chunk_id = f"{source_filename}_chunk_{i}"
480
+ data["chunk_id"].append(chunk_id)
481
+ data["text"].append(chunk.page_content if chunk.page_content else "") # Ensure text is string
482
+
483
+ # Ensure metadata is serializable (dicts, lists, primitives) for HF Datasets
484
+ serializable_metadata = {}
485
+ if hasattr(chunk, 'metadata') and chunk.metadata:
486
+ for k, v in chunk.metadata.items():
487
+ if isinstance(v, (str, int, float, bool, list, dict, type(None))):
488
+ serializable_metadata[k] = v
489
+ else:
490
+ # Convert potentially problematic types (like Langchain objects) to string
491
+ logger.warning(f"Chunk {chunk_id}: Metadata key '{k}' has non-standard type {type(v)}. Converting to string.")
492
+ try:
493
+ serializable_metadata[k] = str(v)
494
+ except Exception as str_err:
495
+ logger.error(f"Chunk {chunk_id}: Failed to convert metadata key '{k}' to string: {str_err}")
496
+ serializable_metadata[k] = f"ERROR_CONVERTING_{type(v).__name__}"
497
+ data["metadata"].append(serializable_metadata)
498
+ data["source_filename"].append(source_filename)
499
+
500
+
501
+ # --- Step 4: Create and push dataset to Hugging Face ---
502
+ logger.info(f"Step 4: Creating Hugging Face Dataset object for repo '{repo_name}'...")
503
+ try:
504
+ # Explicitly define features for robustness, especially if metadata varies
505
+ # features = datasets.Features({
506
+ # "chunk_id": datasets.Value("string"),
507
+ # "text": datasets.Value("string"),
508
+ # "metadata": datasets.features.Features({}), # Define known metadata fields if possible, or leave open
509
+ # "source_filename": datasets.Value("string"),
510
+ # })
511
+ # dataset = Dataset.from_dict(data, features=features)
512
+ dataset = Dataset.from_dict(data) # Simpler approach, infers features
513
+ logger.info(f"Dataset object created with {len(chunks)} rows.")
514
+ except Exception as ds_err:
515
+ logger.error(f"Failed to create Dataset object: {ds_err}", exc_info=True)
516
+ return f"Error: Failed to create dataset structure. Check logs. ({ds_err})"
517
+
518
+ logger.info(f"Connecting to Hugging Face Hub API to push to '{repo_name}'...")
519
+ try:
520
+ api = HfApi(token=hf_token) # Pass token explicitly
521
+
522
+ # Create repo if it doesn't exist
523
+ try:
524
+ api.repo_info(repo_id=repo_name, repo_type="dataset")
525
+ logger.info(f"Repository '{repo_name}' already exists. Will overwrite content.")
526
+ except huggingface_hub.utils.RepositoryNotFoundError:
527
+ logger.info(f"Repository '{repo_name}' does not exist. Creating...")
528
+ api.create_repo(repo_id=repo_name, repo_type="dataset", private=False) # Default to public
529
+ logger.info(f"Successfully created repository '{repo_name}'.")
530
+
531
+ # Push the dataset
532
+ logger.info(f"Pushing dataset to '{repo_name}'...")
533
+ commit_message = f"Add/update OCR data from file: {source_filename}"
534
+ # push_to_hub overwrites the dataset by default
535
+ dataset.push_to_hub(repo_name, token=hf_token, commit_message=commit_message)
536
+ repo_url = f"https://huggingface.co/datasets/{repo_name}"
537
+ logger.info(f"Dataset successfully pushed to {repo_url}")
538
+ return f"Success! Dataset with {len(chunks)} chunks saved to Hugging Face: {repo_url}"
539
+
540
+ except huggingface_hub.utils.HfHubHTTPError as hf_http_err:
541
+ logger.error(f"Hugging Face Hub HTTP Error: {hf_http_err}", exc_info=True)
542
+ return f"Error: Hugging Face Hub Error pushing to '{repo_name}'. Status: {hf_http_err.response.status_code}. Check token permissions, repo name, and network. Details: {hf_http_err}"
543
+ except Exception as push_err:
544
+ logger.error(f"Failed to push dataset to '{repo_name}': {push_err}", exc_info=True)
545
+ return f"Error: Failed to push dataset to Hugging Face repository '{repo_name}'. ({push_err})"
546
+
547
  except Exception as e:
548
+ # Catch any unexpected errors during the overall process
549
+ logger.error(f"An unexpected error occurred processing '{source_filename}': {str(e)}", exc_info=True)
550
+ return f"An unexpected error occurred: {str(e)}"
551
+ finally:
552
+ logger.info(f"--- Finished processing for file: {source_filename} ---")
553
+
554
+
555
+ # --- Gradio Interface ---
556
+ with gr.Blocks(title="Mistral OCR & Dataset Creator", theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as demo:
557
+ gr.Markdown("# Mistral OCR, Markdown Chunking, and Hugging Face Dataset Creator")
558
+ gr.Markdown(
559
+ """
560
+ Upload a PDF or image file (PNG, JPG, WEBP, BMP). The application will:
561
+ 1. Extract text and images using **Mistral OCR**.
562
+ 2. Embed images as base64 data URIs directly within the extracted markdown text.
563
+ 3. Chunk the resulting markdown based on **headers** and optionally **recursively by character count**.
564
+ 4. Store any embedded base64 images found **within each chunk** in the chunk's metadata (`metadata['images_base64']`).
565
+ 5. Create or update a **Hugging Face Dataset** with the processed chunks (`chunk_id`, `text`, `metadata`, `source_filename`).
566
+ """
567
+ )
568
 
569
  with gr.Row():
570
+ with gr.Column(scale=1):
571
+ file_input = gr.File(
572
+ label="Upload PDF or Image File",
573
+ file_types=['.pdf', '.png', '.jpg', '.jpeg', '.webp', '.bmp'],
574
+ type="filepath" # Ensures we get a path usable by `open()`
575
+ )
576
+
577
+ gr.Markdown("## Chunking Options")
578
+ chunk_size = gr.Slider(
579
+ minimum=0, maximum=8000, value=1000, step=100, # Increased max size
580
+ label="Max Chunk Size (Characters)",
581
+ info="Approximate target size. Set to 0 to disable recursive splitting (uses only header splits)."
582
+ )
583
+ chunk_overlap = gr.Slider(
584
+ minimum=0, maximum=1000, value=200, step=50,
585
+ label="Chunk Overlap (Characters)",
586
+ info="Number of characters to overlap between consecutive chunks (if recursive splitting is enabled)."
587
+ )
588
+ strip_headers = gr.Checkbox(
589
+ label="Strip Markdown Headers (#) from Chunk Content",
590
+ value=True,
591
+ info="If checked, removes '#', '##' etc. from the start of the text in each chunk."
592
+ )
593
+
594
+ gr.Markdown("## Hugging Face Output Options")
595
+ repo_name = gr.Textbox(
596
+ label="Target Hugging Face Dataset Repository",
597
+ placeholder="your-username/your-dataset-name",
598
+ info="The dataset will be pushed here (e.g., 'my-org/my-ocr-data'). Will be created if it doesn't exist."
599
+ )
600
+ hf_token = gr.Textbox(
601
+ label="Hugging Face Token (write permission)",
602
+ type="password",
603
+ placeholder="hf_...",
604
+ info="Required to create/push the dataset. If blank, will try using token from local `huggingface-cli login`.",
605
+ # value=os.environ.get("HF_TOKEN", "") # Optionally pre-fill from env var if desired
606
+ )
607
+
608
+ submit_btn = gr.Button("Process File and Save to Hugging Face", variant="primary")
609
+
610
+ with gr.Column(scale=1):
611
+ output = gr.Textbox(
612
+ label="Processing Log / Result Status",
613
+ lines=20,
614
+ interactive=False,
615
+ placeholder="Processing steps and final result will appear here..."
616
+ )
617
 
618
  submit_btn.click(
619
  fn=process_file_and_save,
621
  outputs=output
622
  )
623
 
624
+ gr.Examples(
625
+ examples=[
626
+ [None, 1000, 200, True, "", "hf-username/my-first-ocr-dataset"],
627
+ [None, 2000, 400, True, "", "hf-username/large-chunk-ocr-data"],
628
+ [None, 0, 0, False, "", "hf-username/header-only-ocr-data"], # Example for header-only splitting
629
+ ],
630
+ inputs=[file_input, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name],
631
+ outputs=output,
632
+ fn=process_file_and_save, # Make examples clickable
633
+ cache_examples=False # Avoid caching as it involves API calls and file processing
634
+ )
635
+
636
+ gr.Markdown("--- \n *Requires `MISTRAL_API_KEY` environment variable or being logged in via `huggingface-cli login`.*")
637
+
638
+ # --- Launch the Gradio App ---
639
+ if __name__ == "__main__":
640
+ # Check if client initialization failed earlier
641
+ if not client and api_key: # Check if key was present but init failed
642
+ print("\nCRITICAL: Mistral client failed to initialize. The application cannot perform OCR.")
643
+ print("Please check your MISTRAL_API_KEY and network connection.\n")
644
+ # Optionally exit, or let Gradio launch with limited functionality
645
+ # exit(1)
646
+ elif not client and not api_key:
647
+ print("\nWARNING: Mistral client not initialized because no API key was found.")
648
+ print("OCR functionality will fail. Please set MISTRAL_API_KEY or log in via `huggingface-cli login`.\n")
649
+
650
+ # share=True creates a public link (useful for Colab/Spaces)
651
+ # debug=True provides detailed errors in the console during development
652
+ demo.launch(share=os.getenv('GRADIO_SHARE', 'False').lower() == 'true', debug=True,)
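As a quick way to exercise the new chunking path outside the Gradio UI, here is a minimal sketch (not part of the commit) that calls chunk_markdown() from this app.py on a small markdown sample and prints the header metadata plus the new metadata['images_base64'] field. It assumes app.py is importable from the working directory and that its module-level setup can run (MISTRAL_API_KEY set, or a Hugging Face login token available); the base64 payload below is a dummy placeholder.

    # Sketch: inspect chunk_markdown() output from this commit's app.py.
    from app import chunk_markdown

    sample_md = (
        "# Title\n\n"
        "Some introductory text.\n\n"
        "## Section\n\n"
        "![figure](data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==)\n\n"
        "More text that belongs to this section.\n"
    )

    chunks = chunk_markdown(sample_md, chunk_size=200, chunk_overlap=50, strip_headers=True)
    for chunk in chunks:
        # Header metadata comes from MarkdownHeaderTextSplitter; images_base64 is added by chunk_markdown.
        print(chunk.metadata.get("Header 1"), "/", chunk.metadata.get("Header 2"))
        print(len(chunk.metadata["images_base64"]), "embedded image(s)")
        print(chunk.page_content[:80])

The rows pushed to the Hub follow the same shape (chunk_id, text, metadata, source_filename), so a saved dataset can be inspected afterwards with datasets.load_dataset("your-username/your-dataset-name") (placeholder repository name).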