Svngoku committed c5e1e79 · verified · 1 parent: b88bc18

Update app.py

Files changed (1)
  1. app.py +61 -65
app.py CHANGED
@@ -1,8 +1,9 @@
 import gradio as gr
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
 from langchain.schema import Document
 from typing import List
 import logging
+import re
 from pathlib import Path
 import requests
 import base64
@@ -91,61 +92,61 @@ def perform_ocr_file(file):
     except Exception as e:
         return f"Error during OCR: {str(e)}", ""
 
-# Function to chunk markdown text
+# Function to extract base64 images from markdown content
+def extract_images_from_markdown(markdown_text: str) -> List[str]:
+    # Regex to match markdown image syntax with base64 data
+    pattern = r"!\[.*?\]\((data:image/[a-z]+;base64,[^\)]+)\)"
+    return re.findall(pattern, markdown_text)
+
+# Function to chunk markdown text with image handling
 def chunk_markdown(
     markdown_text: str,
     chunk_size: int = 1000,
     chunk_overlap: int = 200,
-    preserve_numbering: bool = True
+    strip_headers: bool = True
 ) -> List[Document]:
-    if chunk_size <= 0:
-        raise ValueError("chunk_size must be positive")
-    if chunk_overlap < 0:
-        raise ValueError("chunk_overlap cannot be negative")
-    if chunk_overlap >= chunk_size:
-        raise ValueError("chunk_overlap must be less than chunk_size")
-
     try:
-        document = Document(page_content=markdown_text, metadata={"source": "ocr_output"})
-
-        separators = (
-            ["\n\d+\.\s+", "\n\n", "\n", ".", " ", ""]
-            if preserve_numbering
-            else ["\n\n", "\n", ".", " ", ""]
-        )
-
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-            length_function=len,
-            separators=separators, # Fixed parameter name
-            keep_separator=True,
-            add_start_index=True,
-            is_separator_regex=preserve_numbering
+        # Define headers to split on
+        headers_to_split_on = [
+            ("#", "Header 1"),
+            ("##", "Header 2"),
+            ("###", "Header 3"),
+        ]
+
+        # Initialize MarkdownHeaderTextSplitter
+        markdown_splitter = MarkdownHeaderTextSplitter(
+            headers_to_split_on=headers_to_split_on,
+            strip_headers=strip_headers
         )
 
-        logger.info("Splitting markdown text into chunks")
-        chunks = text_splitter.split_documents([document])
-
-        if preserve_numbering:
-            merged_chunks = []
-            current_chunk = None
-
+        # Split markdown by headers
+        logger.info("Splitting markdown by headers")
+        chunks = markdown_splitter.split_text(markdown_text)
+
+        # If chunk_size is specified, further split large chunks
+        if chunk_size > 0:
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                length_function=len,
+                separators=["\n\n", "\n", ".", " ", ""],
+                keep_separator=True,
+                add_start_index=True
+            )
+            logger.info(f"Applying character-level splitting with chunk_size={chunk_size}")
+            final_chunks = []
             for chunk in chunks:
-                content = chunk.page_content.strip()
-                if current_chunk is None:
-                    current_chunk = chunk
-                elif content.startswith(tuple(f"{i}." for i in range(10))):
-                    if current_chunk:
-                        merged_chunks.append(current_chunk)
-                    current_chunk = chunk
+                if len(chunk.page_content) > chunk_size:
+                    sub_chunks = text_splitter.split_documents([chunk])
+                    final_chunks.extend(sub_chunks)
                 else:
-                    current_chunk.page_content += "\n" + content
-                    current_chunk.metadata["end_index"] = chunk.metadata["start_index"] + len(content)
-
-            if current_chunk:
-                merged_chunks.append(current_chunk)
-            chunks = merged_chunks
+                    final_chunks.append(chunk)
+            chunks = final_chunks
+
+        # Add images to metadata
+        for chunk in chunks:
+            images = extract_images_from_markdown(chunk.page_content)
+            chunk.metadata["images"] = images
 
         logger.info(f"Created {len(chunks)} chunks")
        return chunks
@@ -162,7 +163,7 @@ def text_to_base64_dummy(text: str, chunk_index: int):
     return base64.b64encode(buffer.getvalue()).decode("utf-8")
 
 # Process file: OCR -> Chunk -> Save
-def process_file_and_save(file, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name):
+def process_file_and_save(file, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name):
     try:
         # Step 1: Perform OCR
         combined_markdown, raw_markdown = perform_ocr_file(file)
@@ -170,29 +171,24 @@ def process_file_and_save(file, chunk_size, chunk_overlap, preserve_numbering, h
             return combined_markdown
 
         # Step 2: Chunk the markdown
-        chunks = chunk_markdown(combined_markdown, chunk_size, chunk_overlap, preserve_numbering)
+        chunks = chunk_markdown(combined_markdown, chunk_size, chunk_overlap, strip_headers)
 
         # Step 3: Prepare dataset
         data = {
             "chunk_id": [],
             "content": [],
             "metadata": [],
-            "page_image": []
+            "images": [] # Changed to store list of images
         }
 
         for i, chunk in enumerate(chunks):
             data["chunk_id"].append(i)
             data["content"].append(chunk.page_content)
-            data["metadata"].append(chunk.metadata)
-            img_base64 = None
-            if "![image" in chunk.page_content:
-                start = chunk.page_content.find("data:image")
-                if start != -1:
-                    end = chunk.page_content.find(")", start)
-                    img_base64 = chunk.page_content[start:end]
-            if not img_base64:
-                img_base64 = text_to_base64_dummy(chunk.page_content, i)
-            data["page_image"].append(img_base64)
+            data["metadata"].append({k: v for k, v in chunk.metadata.items() if k != "images"}) # Exclude images from metadata column
+            images = chunk.metadata.get("images", [])
+            if not images: # If no images, add a placeholder
+                images = [text_to_base64_dummy(chunk.page_content, i)]
+            data["images"].append(images)
 
         # Step 4: Create and push dataset to Hugging Face
         dataset = Dataset.from_dict(data)
@@ -205,16 +201,16 @@ def process_file_and_save(file, chunk_size, chunk_overlap, preserve_numbering, h
         return f"Error: {str(e)}"
 
 # Gradio Interface
-with gr.Blocks(title="PDF/Image OCR, Chunking, and Dataset Creator") as demo:
-    gr.Markdown("# PDF/Image OCR, Chunking, and Dataset Creator")
-    gr.Markdown("Upload a PDF or image, extract text/images with Mistral OCR, chunk the markdown, and save to Hugging Face.")
+with gr.Blocks(title="PDF/Image OCR, Markdown Chunking, and Dataset Creator") as demo:
+    gr.Markdown("# PDF/Image OCR, Markdown Chunking, and Dataset Creator")
+    gr.Markdown("Upload a PDF or image, extract text/images with Mistral OCR, chunk the markdown by headers, and save to Hugging Face.")
 
     with gr.Row():
         with gr.Column():
             file_input = gr.File(label="Upload PDF or Image")
-            chunk_size = gr.Slider(500, 2000, value=1000, step=100, label="Chunk Size")
+            chunk_size = gr.Slider(0, 2000, value=1000, step=100, label="Max Chunk Size (0 to disable)")
             chunk_overlap = gr.Slider(0, 500, value=200, step=50, label="Chunk Overlap")
-            preserve_numbering = gr.Checkbox(label="Preserve Numbering", value=True)
+            strip_headers = gr.Checkbox(label="Strip Headers from Content", value=True)
             hf_token = gr.Textbox(label="Hugging Face Token", type="password")
             repo_name = gr.Textbox(label="Hugging Face Repository Name (e.g., username/dataset-name)")
             submit_btn = gr.Button("Process and Save")
@@ -224,7 +220,7 @@ with gr.Blocks(title="PDF/Image OCR, Chunking, and Dataset Creator") as demo:
 
     submit_btn.click(
         fn=process_file_and_save,
-        inputs=[file_input, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name],
+        inputs=[file_input, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name],
         outputs=output
     )
 
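As a quick check on the new image handling, the extraction regex can be exercised on its own. A minimal sketch; the sample markdown string is invented for illustration:

import re
from typing import List

# Pattern from the commit: capture the data URI inside a markdown image tag
def extract_images_from_markdown(markdown_text: str) -> List[str]:
    pattern = r"!\[.*?\]\((data:image/[a-z]+;base64,[^\)]+)\)"
    return re.findall(pattern, markdown_text)

sample = "Page text\n\n![image-0](data:image/png;base64,iVBORw0KGgo=)\n\nMore text"
print(extract_images_from_markdown(sample))
# -> ['data:image/png;base64,iVBORw0KGgo=']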
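The reworked chunk_markdown is a two-stage pipeline: MarkdownHeaderTextSplitter splits on #/##/### headers and records the header titles in each Document's metadata, then RecursiveCharacterTextSplitter re-splits only the sections still longer than chunk_size. A standalone sketch of that flow, assuming langchain-text-splitters is installed; the sample text and the chunk_size of 500 are invented for the demo:

from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter

md = "# Title\n\nIntro.\n\n## Section\n\n" + "A long body sentence. " * 60

# Stage 1: header-aware split; header titles land in Document metadata
header_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")],
    strip_headers=True,
)
docs = header_splitter.split_text(md)

# Stage 2: only sections exceeding the size cap get re-split
char_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
chunks = []
for doc in docs:
    if len(doc.page_content) > 500:
        chunks.extend(char_splitter.split_documents([doc]))
    else:
        chunks.append(doc)

for c in chunks:
    print(c.metadata, len(c.page_content))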
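Downstream, each dataset row now carries a list-valued images column rather than the old single page_image string. A sketch of the resulting Dataset shape, assuming the datasets library; the row values are invented, and the push_to_hub repo name and token are placeholders, so that call is left commented out:

from datasets import Dataset

data = {
    "chunk_id": [0, 1],
    "content": ["First chunk text.", "Second chunk text."],
    "metadata": [{"Header 1": "Title"}, {"Header 1": "Title", "Header 2": "Section"}],
    # One list of base64 data URIs per chunk; empty lists get a placeholder upstream
    "images": [["data:image/png;base64,iVBORw0KGgo="], []],
}

dataset = Dataset.from_dict(data)
print(dataset)  # 2 rows, columns: chunk_id, content, metadata, images
# dataset.push_to_hub("username/dataset-name", token="hf_...")  # placeholder repo and token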