Svngoku committed on
Commit 8a5a9ab · verified · 1 Parent(s): 099e67a

Update app.py

Files changed (1)
app.py +114 -60
app.py CHANGED
@@ -1,5 +1,4 @@
 import gradio as gr
-from langchain.document_loaders import PyMuPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.schema import Document
 from typing import List
@@ -8,22 +7,95 @@ from pathlib import Path
 import requests
 import base64
 import io
-import fitz
 from PIL import Image
 from datasets import Dataset
 from huggingface_hub import HfApi
 import os
+from mistralai import Mistral
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Original chunk_pdf function (slightly modified for Gradio)
-def chunk_pdf(
-    file_path: str,
+# Mistral OCR setup (ensure you have your API key set)
+api_key = os.environ.get("MISTRAL_API_KEY")
+if not api_key:
+    raise ValueError("MISTRAL_API_KEY environment variable not set")
+client = Mistral(api_key=api_key)
+
+# Function to encode image to base64
+def encode_image(image_path):
+    try:
+        with open(image_path, "rb") as image_file:
+            return base64.b64encode(image_file.read()).decode('utf-8')
+    except FileNotFoundError:
+        return "Error: The file was not found."
+    except Exception as e:
+        return f"Error: {e}"
+
+# Function to replace images in markdown with base64 strings
+def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
+    for img_name, base64_str in images_dict.items():
+        markdown_str = markdown_str.replace(f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})")
+    return markdown_str
+
+# Function to combine markdown from OCR response
+def get_combined_markdown(ocr_response) -> tuple:
+    markdowns = []
+    raw_markdowns = []
+    for page in ocr_response.pages:
+        image_data = {}
+        for img in page.images:
+            image_data[img.id] = img.image_base64
+        markdowns.append(replace_images_in_markdown(page.markdown, image_data))
+        raw_markdowns.append(page.markdown)
+    return "\n\n".join(markdowns), "\n\n".join(raw_markdowns)
+
+# Perform OCR on uploaded file
+def perform_ocr_file(file):
+    try:
+        if file.name.lower().endswith('.pdf'):
+            uploaded_pdf = client.files.upload(
+                file={
+                    "file_name": file.name,
+                    "content": open(file.name, "rb"),
+                },
+                purpose="ocr"
+            )
+            signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id)
+            ocr_response = client.ocr.process(
+                model="mistral-ocr-latest",
+                document={
+                    "type": "document_url",
+                    "document_url": signed_url.url,
+                },
+                include_image_base64=True
+            )
+            client.files.delete(file_id=uploaded_pdf.id)
+
+        elif file.name.lower().endswith(('.png', '.jpg', '.jpeg')):
+            base64_image = encode_image(file.name)
+            ocr_response = client.ocr.process(
+                model="mistral-ocr-latest",
+                document={
+                    "type": "image_url",
+                    "image_url": f"data:image/jpeg;base64,{base64_image}"
+                },
+                include_image_base64=True
+            )
+        else:
+            return "Unsupported file type. Please provide a PDF or an image (png, jpeg, jpg).", ""
+
+        combined_markdown, raw_markdown = get_combined_markdown(ocr_response)
+        return combined_markdown, raw_markdown
+    except Exception as e:
+        return f"Error during OCR: {str(e)}", ""
+
+# Function to chunk markdown text
+def chunk_markdown(
+    markdown_text: str,
     chunk_size: int = 1000,
     chunk_overlap: int = 200,
-    encoding: str = "utf-8",
     preserve_numbering: bool = True
 ) -> List[Document]:
     if chunk_size <= 0:
@@ -34,27 +106,8 @@ def chunk_pdf(
         raise ValueError("chunk_overlap must be less than chunk_size")
 
     try:
-        temp_file = None
-        if file_path.startswith(("http://", "https://")):
-            logger.info(f"Downloading PDF from {file_path}")
-            response = requests.get(file_path, stream=True, timeout=10)
-            response.raise_for_status()
-            temp_file = Path("temp.pdf")
-            with open(temp_file, "wb") as f:
-                for chunk in response.iter_content(chunk_size=8192):
-                    f.write(chunk)
-            file_path = str(temp_file)
-        elif not Path(file_path).exists():
-            raise FileNotFoundError(f"PDF file not found at: {file_path}")
-
-        logger.info(f"Loading PDF from {file_path}")
-        loader = PyMuPDFLoader(file_path)
-        pages = loader.load()
+        document = Document(page_content=markdown_text, metadata={"source": "ocr_output"})
 
-        if not pages:
-            logger.warning(f"No content extracted from {file_path}")
-            return []
-
         separators = (
             ["\n\d+\.\s+", "\n\n", "\n", ".", " ", ""]
            if preserve_numbering
@@ -65,14 +118,14 @@
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap,
             length_function=len,
-            separators=separators,
+            separators=separators,
             keep_separator=True,
             add_start_index=True,
             is_separator_regex=preserve_numbering
         )
 
-        logger.info(f"Splitting {len(pages)} pages into chunks")
-        chunks = text_splitter.split_documents(pages)
+        logger.info("Splitting markdown text into chunks")
+        chunks = text_splitter.split_documents([document])
 
         if preserve_numbering:
             merged_chunks = []
@@ -98,31 +151,28 @@ def chunk_pdf(
         return chunks
 
     except Exception as e:
-        logger.error(f"Error processing PDF {file_path}: {str(e)}")
+        logger.error(f"Error processing markdown: {str(e)}")
         raise
-    finally:
-        if temp_file and temp_file.exists():
-            temp_file.unlink()
-
-# Custom function to convert PDF page to base64
-def pdf_page_to_base64(pdf_path: str, page_number: int):
-    pdf_document = fitz.open(pdf_path)
-    page = pdf_document.load_page(page_number - 1) # input is one-indexed
-    pix = page.get_pixmap()
-    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
 
+# Placeholder image generation (for chunks without images)
+def text_to_base64_dummy(text: str, chunk_index: int):
+    img = Image.new('RGB', (200, 200), color='white')
     buffer = io.BytesIO()
     img.save(buffer, format="PNG")
     return base64.b64encode(buffer.getvalue()).decode("utf-8")
 
-# Function to process PDF and create dataset
-def process_pdf_and_save(pdf_file, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name):
+# Process file: OCR -> Chunk -> Save
+def process_file_and_save(file, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name):
     try:
-        # Save uploaded file temporarily
-        pdf_path = pdf_file.name
-        chunks = chunk_pdf(pdf_path, chunk_size, chunk_overlap, "utf-8", preserve_numbering)
+        # Step 1: Perform OCR
+        combined_markdown, raw_markdown = perform_ocr_file(file)
+        if "Error" in combined_markdown:
+            return combined_markdown
+
+        # Step 2: Chunk the markdown
+        chunks = chunk_markdown(combined_markdown, chunk_size, chunk_overlap, preserve_numbering)
 
-        # Prepare dataset
+        # Step 3: Prepare dataset
        data = {
             "chunk_id": [],
             "content": [],
@@ -134,14 +184,20 @@ def process_pdf_and_save(pdf_file, chunk_size, chunk_overlap, preserve_numbering
             data["chunk_id"].append(i)
             data["content"].append(chunk.page_content)
             data["metadata"].append(chunk.metadata)
-            page_num = chunk.metadata.get("page", 1)
-            img_base64 = pdf_page_to_base64(pdf_path, page_num)
+            # Extract base64 images from markdown if present, else use placeholder
+            img_base64 = None
+            if "![image" in chunk.page_content:
+                # Simple extraction (assumes one image per chunk for simplicity)
+                start = chunk.page_content.find("data:image")
+                if start != -1:
+                    end = chunk.page_content.find(")", start)
+                    img_base64 = chunk.page_content[start:end]
+            if not img_base64:
+                img_base64 = text_to_base64_dummy(chunk.page_content, i)
             data["page_image"].append(img_base64)
 
-        # Create Hugging Face dataset
+        # Step 4: Create and push dataset to Hugging Face
         dataset = Dataset.from_dict(data)
-
-        # Push to Hugging Face
         api = HfApi()
         api.create_repo(repo_id=repo_name, token=hf_token, repo_type="dataset", exist_ok=True)
         dataset.push_to_hub(repo_name, token=hf_token)
@@ -151,13 +207,13 @@ def process_pdf_and_save(pdf_file, chunk_size, chunk_overlap, preserve_numbering
         return f"Error: {str(e)}"
 
 # Gradio Interface
-with gr.Blocks(title="PDF Chunking and Dataset Creator") as demo:
-    gr.Markdown("# PDF Chunking and Dataset Creator")
-    gr.Markdown("Upload a PDF, configure chunking parameters, and save the dataset to Hugging Face.")
+with gr.Blocks(title="PDF/Image OCR, Chunking, and Dataset Creator") as demo:
+    gr.Markdown("# PDF/Image OCR, Chunking, and Dataset Creator")
+    gr.Markdown("Upload a PDF or image, extract text/images with Mistral OCR, chunk the markdown, and save to Hugging Face.")
 
     with gr.Row():
         with gr.Column():
-            pdf_input = gr.File(label="Upload PDF")
+            file_input = gr.File(label="Upload PDF or Image")
             chunk_size = gr.Slider(500, 2000, value=1000, step=100, label="Chunk Size")
             chunk_overlap = gr.Slider(0, 500, value=200, step=50, label="Chunk Overlap")
             preserve_numbering = gr.Checkbox(label="Preserve Numbering", value=True)
@@ -169,11 +225,9 @@ with gr.Blocks(title="PDF Chunking and Dataset Creator") as demo:
            output = gr.Textbox(label="Result")
 
     submit_btn.click(
-        fn=process_pdf_and_save,
-        inputs=[pdf_input, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name],
+        fn=process_file_and_save,
+        inputs=[file_input, chunk_size, chunk_overlap, preserve_numbering, hf_token, repo_name],
         outputs=output
     )
 
-demo.launch(
-    share=True,
-)
+demo.launch(share=True)
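
A minimal, self-contained sketch (not part of app.py) of how the numbering-aware splitting introduced by the new chunk_markdown() behaves. It only mirrors the splitter arguments and separator list from the diff above, so it can be tried on a small markdown string without setting MISTRAL_API_KEY or launching the Gradio app; the sample text and chunk sizes are made up for illustration.

# Hypothetical smoke test: replicate the splitter settings used by chunk_markdown()
# in this commit and run them on a tiny markdown string. Not part of app.py.
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document

splitter = RecursiveCharacterTextSplitter(
    chunk_size=60,                 # small value, just for demonstration
    chunk_overlap=0,
    length_function=len,
    separators=[r"\n\d+\.\s+", "\n\n", "\n", ".", " ", ""],  # same list as in app.py
    keep_separator=True,
    add_start_index=True,
    is_separator_regex=True,       # corresponds to preserve_numbering=True
)

doc = Document(page_content="1. First point\n\nSome details.\n\n2. Second point\n\nMore details.")
for chunk in splitter.split_documents([doc]):
    print(chunk.metadata.get("start_index"), repr(chunk.page_content))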