SkyNait committed on
Commit ee89119 · 1 Parent(s): aa23348

Topic extraction upgrades

__pycache__/inference_svm_model.cpython-310.pyc CHANGED
Binary files a/__pycache__/inference_svm_model.cpython-310.pyc and b/__pycache__/inference_svm_model.cpython-310.pyc differ
 
__pycache__/mineru_single.cpython-310.pyc CHANGED
Binary files a/__pycache__/mineru_single.cpython-310.pyc and b/__pycache__/mineru_single.cpython-310.pyc differ
 
__pycache__/topic_extraction_upgrade.cpython-310.pyc ADDED
Binary file (10.9 kB)
 
__pycache__/worker.cpython-310.pyc CHANGED
Binary files a/__pycache__/worker.cpython-310.pyc and b/__pycache__/worker.cpython-310.pyc differ
 
input_output/aqa-Mathematics-specification.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6d2a3998a4988ef6881262f22660a8e0719bb8d648a757db91041cfabbf40bb3
+ size 888895
mineru_test_local.py ADDED
@@ -0,0 +1,374 @@
+ #!/usr/bin/env python3
+ import os
+ import re
+ import gc
+ from magic_pdf.data.dataset import PymuDocDataset
+ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+ import json
+ import base64
+ import logging
+ import concurrent.futures
+ from io import BytesIO
+ from google import genai
+ from google.genai import types
+ import torch
+ import cv2
+
+ from inference_svm_model import SVMModel
+ from topic_extraction_upgrade import TableExtractor
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
+     handlers=[
+         logging.StreamHandler(),
+         logging.FileHandler('mineru.log')
+     ]
+ )
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)
+
+ def call_gemini_for_table_classification(image_data: bytes) -> str:
+     """
+     Returns one of: "TWO_COLUMN", "THREE_COLUMN", or "NO_TABLE".
+     """
+     if genai is None or types is None:
+         logger.warning("Gemini libraries not available. Defaulting to NO_TABLE.")
+         return "NO_TABLE"
+
+     prompt = """You are given an image. Determine if it shows a relevant table that has exactly 2 or 3 columns.
+ The 'relevant' table examples are the first and second reference images. The third reference image is irrelevant.
+ If the image is a relevant table with 2 columns, respond with 'TWO_COLUMN'.
+ If the image is a relevant table with 3 columns, respond with 'THREE_COLUMN'.
+ If the image does not show a relevant table with 2 or 3 columns, respond with 'NO_TABLE'.
+ Return only one of these exact labels as your entire response:
+ TWO_COLUMN
+ THREE_COLUMN
+ NO_TABLE
+ """
+     try:
+         client = genai.Client(api_key="YOUR_GEMINI_API_KEY")  # Provide your real API key
+         response = client.models.generate_content(
+             model="gemini-2.0-flash",
+             config=types.GenerateContentConfig(temperature=0.),
+             contents=[
+                 {
+                     "parts": [
+                         {"text": prompt},
+                         {
+                             "inline_data": {
+                                 "mime_type": "image/jpeg",
+                                 "data": base64.b64encode(image_data).decode('utf-8')
+                             }
+                         }
+                     ]
+                 }
+             ]
+         )
+
+         classification = response.text.strip() if (response and response.text) else "NO_TABLE"
+         classification = classification.upper()
+         if "THREE" in classification:
+             return "THREE_COLUMN"
+         elif "TWO" in classification:
+             return "TWO_COLUMN"
+         else:
+             return "NO_TABLE"
+
+     except Exception as e:
+         logger.error(f"[Gemini Table Classification Error]: {str(e)}")
+         return "NO_TABLE"
+
+ def call_gemini_for_image_description(image_data: bytes) -> str:
+     if genai is None or types is None:
+         logger.warning("Gemini libraries not available. Returning fallback description.")
+         return "Image description unavailable"
+
+     try:
+         client = genai.Client(api_key="YOUR_GEMINI_API_KEY")  # Provide your real API key
+         response = client.models.generate_content(
+             model="gemini-2.0-flash",
+             config=types.GenerateContentConfig(temperature=0.),
+             contents=[
+                 {
+                     "parts": [
+                         {
+                             "text": """The provided image is a part of a question paper or mark scheme.
+ Extract all the necessary information from the image to be able to identify the question.
+ To identify the question, we only need the following: question number and question part.
+ Don't include redundant information.
+ For example, if the image contains text like: "Q1 Part A Answer: Life on earth was created by deity..."
+ you should return just "Q1 Part A Mark Scheme"
+ If there is no text on this image, return a description of the image. 20 words max.
+ If there is not enough data, consider information from the surrounding context.
+ Additionally, if the image contains a truncated part, you must describe it and mark it as a
+ part of another image that comes before or after the current image.
+ If the image is of a multiple-choice question's options, then modify your answer by appending
+ 'MCQ: A [option] B [option] C [option] D [option]' (replacing [option] with the actual options).
+ Otherwise, follow the above instructions strictly.
+ """},
+                         {
+                             "inline_data": {
+                                 "mime_type": "image/jpeg",
+                                 "data": base64.b64encode(image_data).decode('utf-8')
+                             }
+                         }
+                     ]
+                 }
+             ]
+         )
+
+         description = response.text.strip() if (response and response.text) else "Image description unavailable"
+         return description
+
+     except Exception as e:
+         logger.error(f"[Gemini Description Error]: {str(e)}")
+         return "Image description unavailable"
+
+ class DataWriter:
+     """
+     Base class for handling extracted images.
+     """
+     def write(self, path: str, data: bytes) -> None:
+         raise NotImplementedError
+
+     def post_process(self, key: str, md_content: str) -> str:
+         raise NotImplementedError
+
+ class LocalImageWriter(DataWriter):
+     """
+     Stores extracted images locally so they can be referenced in local Markdown previews.
+     SVM filters out blank images. Then we do Gemini classification for table detection or normal description.
+     Finally, we rewrite the Markdown to reference these local images.
+     """
+     def __init__(self, output_folder: str, svm_model: SVMModel):
+         """
+         :param output_folder: Base folder where images and the final MD will be saved.
+         :param svm_model: SVM model for blank-image detection.
+         """
+         self.output_folder = output_folder
+         self.svm_model = svm_model
+         self.descriptions = {}
+         """
+         self.descriptions structure:
+         {
+             "{local_id_or_path}": {
+                 "data": bytes,
+                 "relative_path": str,  # relative path to the saved image
+                 "description": "",  # Gemini description
+                 "table_classification": "TWO_COLUMN" / "THREE_COLUMN" / "NO_TABLE",
+                 "final_alt": ""  # final alt text for the MD
+             }
+         }
+         """
+         os.makedirs(self.output_folder, exist_ok=True)
+         self.images_dir = os.path.join(self.output_folder, "images")
+         os.makedirs(self.images_dir, exist_ok=True)
+
+         self._img_count = 0
+
+     def write(self, path: str, data: bytes) -> None:
+         """
+         1) Use the SVM to check if the image is blank/irrelevant.
+         2) If not blank, save the image locally (images/img_{count}.png).
+         3) Keep track in self.descriptions for post-process usage.
+         """
+         is_blank = self.svm_model.is_blank_image(data)
+         if is_blank:
+             logger.info(f"[SVM] Detected blank/irrelevant image: {path}. Skipping.")
+             return
+
+         self._img_count += 1
+         # Example local path
+         local_filename = f"img_{self._img_count}.png"
+         local_path = os.path.join(self.images_dir, local_filename)
+
+         with open(local_path, "wb") as f:
+             f.write(data)
+
+         rel_path_for_md = os.path.relpath(local_path, self.output_folder)
+
+         self.descriptions[path] = {
+             "data": data,
+             "relative_path": rel_path_for_md,  # e.g. "images/img_1.png"
+             "description": "",
+             "table_classification": "NO_TABLE",
+             "final_alt": ""
+         }
+
+     def post_process(self, key: str, md_content: str) -> str:
+         """
+         1) Gemini classification (table vs. no table).
+         2) If table => alt = "HAS TO BE PROCESSED - two/three column table".
+            Else => normal Gemini-based description.
+         3) Replace all ![]({key}{path}) with ![final_alt](relative_local_path).
+         4) For any "HAS TO BE PROCESSED" images, run TableExtractor.
+         """
+         # Step A: Table classification
+         logger.info("Starting Gemini table classification for each local image...")
+         if not self.descriptions:
+             return md_content
+
+         max_workers = len(self.descriptions)
+         with concurrent.futures.ThreadPoolExecutor(max_workers=max(max_workers, 1)) as executor:
+             future_to_path = {
+                 executor.submit(call_gemini_for_table_classification, info['data']): p
+                 for p, info in self.descriptions.items()
+             }
+             for future in concurrent.futures.as_completed(future_to_path):
+                 path = future_to_path[future]
+                 try:
+                     classification = future.result()
+                     self.descriptions[path]['table_classification'] = classification
+                 except Exception as e:
+                     logger.error(f"[Gemini Table Classification Error for {path}]: {str(e)}")
+                     self.descriptions[path]['table_classification'] = "NO_TABLE"
+
+         # Step B: For images classified "NO_TABLE", get a normal Gemini-based description
+         logger.info("Starting Gemini question-based description for non-table images...")
+         with concurrent.futures.ThreadPoolExecutor(max_workers=max(max_workers, 1)) as executor:
+             fut_map = {}
+             for path, info in self.descriptions.items():
+                 if info['table_classification'] == "NO_TABLE":
+                     fut = executor.submit(call_gemini_for_image_description, info['data'])
+                     fut_map[fut] = path
+
+             for fut in concurrent.futures.as_completed(fut_map):
+                 path = fut_map[fut]
+                 try:
+                     desc = fut.result()
+                     self.descriptions[path]['description'] = desc
+                 except Exception as e:
+                     logger.error(f"[Gemini Description Error for {path}]: {str(e)}")
+                     self.descriptions[path]['description'] = "Image description unavailable"
+
+         # Step C: Construct the final alt text
+         for path, info in self.descriptions.items():
+             classification = info['table_classification']
+             if classification == "TWO_COLUMN":
+                 final_alt = "HAS TO BE PROCESSED - two column table"
+             elif classification == "THREE_COLUMN":
+                 final_alt = "HAS TO BE PROCESSED - three column table"
+             else:
+                 # normal Gemini-based description
+                 final_alt = info['description'] or "Image description unavailable"
+             info['final_alt'] = final_alt
+
+         # Step D: Rewrite the Markdown image tags, then process any table images
+         for path, info in self.descriptions.items():
+             old_md_tag = f"![]({key}{path})"
+             new_md_tag = f"![{info['final_alt']}]({info['relative_path']})"
+             md_content = md_content.replace(old_md_tag, new_md_tag)
+
+         md_content = self._process_table_images_in_markdown(md_content)
+
+         return md_content
+
+     def _process_table_images_in_markdown(self, md_content: str) -> str:
+         """
+         Finds images with alt text like:
+           ![HAS TO BE PROCESSED - (two|three) column table](images/img_1.png)
+         then runs TableExtractor with specific parameters for two/three columns.
+         Saves each cell as a separate image in a subfolder next to the original.
+         """
+         pattern = r"!\[HAS TO BE PROCESSED - (two|three) column table\]\(([^)]+)\)"
+         matches = re.findall(pattern, md_content, flags=re.IGNORECASE)
+         if not matches:
+             return md_content
+
+         for (col_type, image_path) in matches:
+             logger.info(f"Detected table image in MD: {image_path}, columns={col_type}")
+             # Convert image_path to an absolute path
+             abs_image_path = os.path.join(self.output_folder, image_path)
+
+             try:
+                 if col_type.lower() == 'two':
+                     # For two-column tables
+                     extractor = TableExtractor(
+                         merge_two_col_rows=True,
+                         enable_subtopic_merge=True
+                     )
+                 else:
+                     # For three-column tables
+                     extractor = TableExtractor(
+                         merge_two_col_rows=False,
+                         enable_subtopic_merge=False
+                     )
+
+                 row_boxes = extractor.process_image(abs_image_path)
+
+                 out_folder = abs_image_path + "_rows"
+                 os.makedirs(out_folder, exist_ok=True)
+
+                 extractor.save_extracted_cells(abs_image_path, row_boxes, out_folder)
+                 logger.info(f"Table extraction done for {image_path}, saved to {out_folder}")
+
+             except Exception as e:
+                 logger.error(f"Error processing table image {image_path}: {e}")
+
+         return md_content
+
+ class LocalPDFProcessor:
+     def __init__(self, output_folder: str):
+         self.output_folder = output_folder
+         os.makedirs(self.output_folder, exist_ok=True)
+
+         self.svm_model = SVMModel()
+         logger.info("Classification (SVM) model initialized successfully")
+
+         self.layout_mode = "layoutlmv3"
+         self.ocr_enable = False
+         self.formula_enable = True
+         self.table_enable = False
+         self.language = "en"
+
+         logger.info("LocalPDFProcessor initialized successfully")
+
+     def cleanup_gpu(self):
+         gc.collect()
+         torch.cuda.empty_cache()
+         logger.info("GPU memory cleaned up.")
+
+     def process(self, pdf_path: str) -> str:
+         logger.info(f"Processing local PDF: {pdf_path}")
+         try:
+             # Read PDF bytes
+             with open(pdf_path, "rb") as f:
+                 pdf_bytes = f.read()
+
+             dataset = PymuDocDataset(pdf_bytes)
+             inference = doc_analyze(
+                 dataset,
+                 ocr=self.ocr_enable,
+                 lang=self.language,
+                 layout_model=self.layout_mode,
+                 formula_enable=self.formula_enable,
+                 table_enable=self.table_enable
+             )
+             logger.info("doc_analyze complete. Extracting images...")
+
+             image_writer = LocalImageWriter(self.output_folder, self.svm_model)
+             pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
+
+             logger.info("Image pipeline completed. Generating markdown...")
+             md_content = pipe_result.get_markdown("local-unique-prefix/")
+
+             final_markdown = image_writer.post_process("local-unique-prefix/", md_content)
+
+             # Save the final .md file
+             md_path = os.path.join(self.output_folder, "final_output.md")
+             with open(md_path, "w", encoding="utf-8") as f:
+                 f.write(final_markdown)
+
+             logger.info(f"Markdown saved to: {md_path}")
+             return final_markdown
+
+         finally:
+             self.cleanup_gpu()
+
+ if __name__ == "__main__":
+     input_pdf = "/home/user/app/input_output/aqa-Mathematics-specification.pdf"
+     output_dir = "/home/user/app/input_output/output"
+
+     processor = LocalPDFProcessor(output_folder=output_dir)
+     md_result = processor.process(input_pdf)
+     # print("Final Markdown:\n", md_result)
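
For reference, a minimal sketch of the Markdown rewrite that post_process performs (file names and alt text here are illustrative, not taken from a real run):

    ![](local-unique-prefix/abc.png)                               # before rewrite
    ![Q1 Part A Mark Scheme](images/img_1.png)                     # after: non-table image
    ![HAS TO BE PROCESSED - two column table](images/img_2.png)    # after: table image

Tags of the last form are then matched by _process_table_images_in_markdown, which crops each detected cell into images/img_2.png_rows/row_{i}/col_{j}.png.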
topic_extraction_upgrade.py ADDED
@@ -0,0 +1,423 @@
+ import cv2
+ import numpy as np
+ import logging
+ from pathlib import Path
+ from typing import List, Tuple
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # If you are working with 3-column tables, set `merge_two_col_rows` and `enable_subtopic_merge` to False;
+ # set them to True for 2-column tables (currently hardcoded in __main__ below, for testing).
+
+ class TableExtractor:
+     def __init__(
+         self,
+         # Preprocessing parameters
+         denoise_h: int = 10,
+         clahe_clip: float = 3.0,
+         clahe_grid: int = 8,
+         sharpen_kernel: np.ndarray = np.array([[-1, -1, -1],
+                                                [-1, 9, -1],
+                                                [-1, -1, -1]]),
+         thresh_block_size: int = 21,
+         thresh_C: int = 7,
+
+         # Row detection parameters
+         horizontal_scale: int = 20,
+         row_morph_iterations: int = 2,
+         min_row_height: int = 30,
+         min_row_density: float = 0.01,
+
+         # Column detection parameters
+         vertical_scale: int = 20,
+         col_morph_iterations: int = 2,
+         min_col_height_ratio: float = 0.5,
+         min_col_density: float = 0.01,
+
+         # Bounding box extraction
+         padding: int = 0,
+         skip_header: bool = True,
+
+         # Two-column & subtopic merges
+         merge_two_col_rows: bool = False,
+         enable_subtopic_merge: bool = False,
+         subtopic_threshold: float = 0.2,
+
+         # Gray-artifact filter
+         std_threshold_for_artifacts: float = 5.0,
+
+         # Parameters for the line-removal check
+         line_removal_scale: int = 15,
+         line_removal_iterations: int = 1,
+         min_text_ratio_after_line_removal: float = 0.001
+     ):
+         """
+         :param merge_two_col_rows: If True, a row with exactly 1 vertical line => merged into 1 bounding box.
+         :param enable_subtopic_merge: If True, a row with 2 vertical lines => 3 columns can become 2 if the left one is narrow.
+         :param subtopic_threshold: Fraction of row width for subtopic detection.
+         :param std_threshold_for_artifacts: Grayscale std dev below this => skip as artifact.
+         :param line_removal_scale: Larger => more aggressive line detection inside the cell.
+         :param line_removal_iterations: Morphological iterations for line removal.
+         :param min_text_ratio_after_line_removal: If the fraction of text left after removing lines is below this => skip cell.
+         """
+         # Preprocessing
+         self.denoise_h = denoise_h
+         self.clahe_clip = clahe_clip
+         self.clahe_grid = clahe_grid
+         self.sharpen_kernel = sharpen_kernel
+         self.thresh_block_size = thresh_block_size
+         self.thresh_C = thresh_C
+
+         # Row detection
+         self.horizontal_scale = horizontal_scale
+         self.row_morph_iterations = row_morph_iterations
+         self.min_row_height = min_row_height
+         self.min_row_density = min_row_density
+
+         # Column detection
+         self.vertical_scale = vertical_scale
+         self.col_morph_iterations = col_morph_iterations
+         self.min_col_height_ratio = min_col_height_ratio
+         self.min_col_density = min_col_density
+
+         # Bbox extraction
+         self.padding = padding
+         self.skip_header = skip_header
+
+         # Two-column / subtopic merges
+         self.merge_two_col_rows = merge_two_col_rows
+         self.enable_subtopic_merge = enable_subtopic_merge
+         self.subtopic_threshold = subtopic_threshold
+
+         # Artifact filtering (gray headers, purple, etc.) / currently not working well
+         self.std_threshold_for_artifacts = std_threshold_for_artifacts
+
+         # Line removal inside a cell
+         self.line_removal_scale = line_removal_scale
+         self.line_removal_iterations = line_removal_iterations
+         self.min_text_ratio_after_line_removal = min_text_ratio_after_line_removal
+
+     def preprocess(self, img: np.ndarray) -> np.ndarray:
+         """Grayscale, denoise, CLAHE, sharpen, adaptive threshold (binary_inv)."""
+         if img.ndim == 3:
+             gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+         else:
+             gray = img.copy()
+
+         denoised = cv2.fastNlMeansDenoising(gray, h=self.denoise_h)
+         clahe = cv2.createCLAHE(clipLimit=self.clahe_clip, tileGridSize=(self.clahe_grid, self.clahe_grid))
+         enhanced = clahe.apply(denoised)
+         sharpened = cv2.filter2D(enhanced, -1, self.sharpen_kernel)
+
+         binarized = cv2.adaptiveThreshold(
+             sharpened, 255,
+             cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+             cv2.THRESH_BINARY_INV,
+             self.thresh_block_size,
+             self.thresh_C
+         )
+         return binarized
+
+     def detect_full_rows(self, bin_img: np.ndarray) -> List[Tuple[int, int]]:
+         """Find horizontal row boundaries in the binarized image."""
+         h_kernel_size = max(1, bin_img.shape[1] // self.horizontal_scale)
+         horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (h_kernel_size, 1))
+
+         horizontal_lines = cv2.morphologyEx(bin_img, cv2.MORPH_OPEN, horizontal_kernel,
+                                             iterations=self.row_morph_iterations)
+         row_projection = np.sum(horizontal_lines, axis=1)
+         max_val = np.max(row_projection) if len(row_projection) else 0
+
+         # If no lines are found, treat the entire image as one row (optional fallback)
+         if max_val < 1e-5:
+             return [(0, bin_img.shape[0])]
+
+         threshold_val = 0.3 * max_val
+         line_indices = np.where(row_projection > threshold_val)[0]
+
+         if len(line_indices) < 2:
+             return [(0, bin_img.shape[0])]
+
+         # Group consecutive indices into line positions
+         lines = []
+         current = [line_indices[0]]
+         for i in range(1, len(line_indices)):
+             if line_indices[i] - line_indices[i - 1] <= 2:
+                 current.append(line_indices[i])
+             else:
+                 lines.append(int(np.mean(current)))
+                 current = [line_indices[i]]
+         if current:
+             lines.append(int(np.mean(current)))
+
+         row_bounds = []
+         for i in range(len(lines) - 1):
+             y1 = lines[i]
+             y2 = lines[i + 1]
+             if (y2 - y1) >= self.min_row_height:
+                 row_bounds.append((y1, y2))
+
+         return row_bounds if row_bounds else [(0, bin_img.shape[0])]
+
+     def detect_columns_in_row(self, row_img: np.ndarray, y1: int, y2: int) -> List[Tuple[int, int, int, int]]:
+         """
+         Detect up to two vertical lines => up to 3 bounding boxes.
+         - 0 lines => 1 bounding box
+         - 1 line  => 2 bounding boxes (unless merge_two_col_rows => 1)
+         - 2 lines => 3 bounding boxes by default;
+           if enable_subtopic_merge and the left box is narrower than subtopic_threshold => 2 boxes
+         """
+         row_height = (y2 - y1)
+         row_width = row_img.shape[1]
+
+         # Morphological kernel for vertical lines
+         v_kernel_size = max(1, row_height // self.vertical_scale)
+         vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, v_kernel_size))
+
+         vertical_lines = cv2.morphologyEx(row_img, cv2.MORPH_OPEN, vertical_kernel,
+                                           iterations=self.col_morph_iterations)
+         vertical_lines = cv2.dilate(vertical_lines, np.ones((3, 3), np.uint8), iterations=1)
+
+         # Find contours => x positions of candidate dividers
+         contours, _ = cv2.findContours(vertical_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+         x_positions = []
+         for c in contours:
+             x, y, w, h = cv2.boundingRect(c)
+             # Must span at least half the row height to be considered a real column divider
+             if h >= self.min_col_height_ratio * row_height:
+                 x_positions.append(x)
+         x_positions = sorted(set(x_positions))
+
+         # Keep at most 2 vertical lines
+         if len(x_positions) > 2:
+             x_positions = x_positions[:2]
+
+         # Build bounding boxes
+         if len(x_positions) == 0:
+             # 0 lines => single bounding box
+             boxes = [(0, y1, row_width, row_height)]
+
+         elif len(x_positions) == 1:
+             # 1 line => 2 bounding boxes by default
+             x1 = x_positions[0]
+             if self.merge_two_col_rows:
+                 # Merge => single bounding box
+                 boxes = [(0, y1, row_width, row_height)]
+             else:
+                 boxes = [
+                     (0, y1, x1, row_height),
+                     (x1, y1, row_width - x1, row_height)
+                 ]
+
+         else:
+             # 2 lines => normally 3 bounding boxes
+             x1, x2 = sorted(x_positions)
+             if self.enable_subtopic_merge:
+                 # If the left bounding box is very narrow => treat it as a subtopic => 2 bounding boxes
+                 left_box_width = x1
+                 if left_box_width < (self.subtopic_threshold * row_width):
+                     boxes = [
+                         (0, y1, x1, row_height),
+                         (x1, y1, row_width - x1, row_height)
+                     ]
+                 else:
+                     boxes = [
+                         (0, y1, x1, row_height),
+                         (x1, y1, x2 - x1, row_height),
+                         (x2, y1, row_width - x2, row_height)
+                     ]
+             else:
+                 boxes = [
+                     (0, y1, x1, row_height),
+                     (x1, y1, x2 - x1, row_height),
+                     (x2, y1, row_width - x2, row_height)
+                 ]
+
+         # Filter out columns with insufficient density
+         filtered = []
+         for (x, y, w, h) in boxes:
+             if w <= 0:
+                 continue
+             subregion = row_img[:, x : x + w]
+             white_pixels = np.sum(subregion == 255)
+             total_pixels = subregion.size
+             if total_pixels == 0:
+                 continue
+             density = white_pixels / total_pixels
+             if density >= self.min_col_density:
+                 filtered.append((x, y, w, h))
+
+         return filtered
+
+     def process_image(self, image_path: str) -> List[List[Tuple[int, int, int, int]]]:
+         """
+         1) Preprocess => bin_img
+         2) Detect row segments
+         3) Filter out rows by density (optionally skipping the first row as a header)
+         4) For each row => detect columns => bounding boxes
+         """
+         img = cv2.imread(image_path)
+         if img is None:
+             raise ValueError(f"Could not read image: {image_path}")
+
+         bin_img = self.preprocess(img)
+         row_segments = self.detect_full_rows(bin_img)
+
+         # Filter out rows with insufficient density
+         valid_rows = []
+         for (y1, y2) in row_segments:
+             row_region = bin_img[y1:y2, :]
+             area = row_region.size
+             if area == 0:
+                 continue
+             white_pixels = np.sum(row_region == 255)
+             density = white_pixels / area
+             if density >= self.min_row_density:
+                 valid_rows.append((y1, y2))
+
+         # Possibly skip the header row
+         if self.skip_header and len(valid_rows) > 1:
+             valid_rows = valid_rows[1:]
+
+         # Detect columns in each row
+         all_rows_boxes = []
+         for (y1, y2) in valid_rows:
+             row_img = bin_img[y1:y2, :]
+             col_boxes = self.detect_columns_in_row(row_img, y1, y2)
+             if col_boxes:
+                 all_rows_boxes.append(col_boxes)
+
+         return all_rows_boxes
+
+     def extract_box_image(self, original: np.ndarray, box: Tuple[int, int, int, int]) -> np.ndarray:
+         """Crop a bounding box from the original image, with optional padding."""
+         x, y, w, h = box
+         Y1 = max(0, y - self.padding)
+         Y2 = min(original.shape[0], y + h + self.padding)
+         X1 = max(0, x - self.padding)
+         X2 = min(original.shape[1], x + w + self.padding)
+         return original[Y1:Y2, X1:X2]
+
+     def _remove_lines_in_cell(self, gray_bin: np.ndarray) -> np.ndarray:
+         """
+         Remove horizontal + vertical lines from a binarized subregion
+         and return the 'text-only' mask.
+         """
+         # 1) Horizontal line detection
+         horiz_kernel_size = max(1, gray_bin.shape[1] // self.line_removal_scale)
+         horiz_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (horiz_kernel_size, 1))
+         horizontal = cv2.morphologyEx(gray_bin, cv2.MORPH_OPEN, horiz_kernel, iterations=self.line_removal_iterations)
+
+         # 2) Vertical line detection
+         vert_kernel_size = max(1, gray_bin.shape[0] // self.line_removal_scale)
+         vert_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_kernel_size))
+         vertical = cv2.morphologyEx(gray_bin, cv2.MORPH_OPEN, vert_kernel, iterations=self.line_removal_iterations)
+
+         # Combine the lines
+         lines = cv2.bitwise_or(horizontal, vertical)
+         # Subtract from the original => text-only
+         text_only = cv2.bitwise_and(gray_bin, cv2.bitwise_not(lines))
+         return text_only
+
+     def is_grey_artifact(self, cell_img: np.ndarray) -> bool:
+         """
+         1) If the grayscale std dev < std_threshold_for_artifacts => skip as uniform.
+         2) Otherwise, remove lines from an Otsu-binarized version of the cell
+            and check whether enough text is left. If not, skip as artifact.
+         """
+         if cell_img.size == 0:
+             return True
+
+         gray = cv2.cvtColor(cell_img, cv2.COLOR_BGR2GRAY)
+         std_val = np.std(gray)
+         if std_val < self.std_threshold_for_artifacts:
+             return True
+
+         # 2) Binarize => remove lines => check leftover text
+         #    (Otsu threshold on the local cell)
+         _, cell_bin = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
+
+         text_only = self._remove_lines_in_cell(cell_bin)
+         nonzero_text = cv2.countNonZero(text_only)
+         ratio = nonzero_text / float(cell_bin.size)
+
+         if ratio < self.min_text_ratio_after_line_removal:
+             # Hardly any text remains => artifact
+             return True
+
+         return False
+
+     def save_extracted_cells(
+         self, image_path: str, row_boxes: List[List[Tuple[int, int, int, int]]], output_dir: str
+     ):
+         """Save each cell from the original image, skipping uniform/gray artifacts."""
+         out_path = Path(output_dir)
+         out_path.mkdir(exist_ok=True, parents=True)
+
+         original = cv2.imread(image_path)
+         if original is None:
+             raise ValueError(f"Could not read original image: {image_path}")
+
+         for i, row in enumerate(row_boxes):
+             row_dir = out_path / f"row_{i}"
+             row_dir.mkdir(exist_ok=True)
+             for j, box in enumerate(row):
+                 cell_img = self.extract_box_image(original, box)
+                 # Skip uniform/grey or line-only artifacts
+                 if self.is_grey_artifact(cell_img):
+                     logger.info(f"Skipping artifact cell at row={i}, col={j}. (uniform/grey/line-only)")
+                     continue
+
+                 out_file = row_dir / f"col_{j}.png"
+                 cv2.imwrite(str(out_file), cell_img)
+                 logger.info(f"Saved cell image row={i}, col={j} -> {out_file}")
+
+ class TableExtractorApp:
+     def __init__(self, extractor: TableExtractor):
+         self.extractor = extractor
+
+     def run(self, input_image: str, output_folder: str):
+         row_boxes = self.extractor.process_image(input_image)
+         logger.info(f"Detected {len(row_boxes)} row(s).")
+         self.extractor.save_extracted_cells(input_image, row_boxes, output_folder)
+         logger.info("Done. Check the output folder for results.")
+
+
+ if __name__ == "__main__":
+     input_image = "images/test/img_2.png"
+     output_folder = "refined_outp"
+
+     # Two-column configuration (see the comment at the top of this file)
+     extractor = TableExtractor(
+         denoise_h=10,
+         clahe_clip=3.0,
+         clahe_grid=8,
+         thresh_block_size=21,
+         thresh_C=7,
+
+         horizontal_scale=20,
+         row_morph_iterations=2,
+         min_row_height=30,
+         min_row_density=0.01,
+
+         vertical_scale=20,
+         col_morph_iterations=2,
+         min_col_height_ratio=0.5,
+         min_col_density=0.01,
+
+         padding=1,
+         skip_header=True,
+
+         merge_two_col_rows=True,
+         enable_subtopic_merge=True,
+         subtopic_threshold=0.2,
+
+         std_threshold_for_artifacts=10.0,
+         line_removal_scale=20,
+         line_removal_iterations=1,
+         min_text_ratio_after_line_removal=0.001
+     )
+
+     app = TableExtractorApp(extractor)
+     app.run(input_image, output_folder)
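
For three-column tables, a minimal counterpart sketch of the configuration described in the comment at the top of this file (the input path below is a hypothetical placeholder, not part of the commit):

    from topic_extraction_upgrade import TableExtractor

    # Leave both merge flags at their False defaults so a row with two vertical
    # dividers yields three cell images instead of being merged.
    extractor = TableExtractor(merge_two_col_rows=False, enable_subtopic_merge=False)
    row_boxes = extractor.process_image("images/test/img_3.png")  # hypothetical test image
    extractor.save_extracted_cells("images/test/img_3.png", row_boxes, "three_col_out")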