minar09 commited on
Commit
b2ae432
·
verified ·
1 Parent(s): 82b3972

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +89 -46
main.py CHANGED
@@ -6,7 +6,10 @@ from pathlib import Path
6
  from typing import List, Dict, Optional
7
  from dataclasses import dataclass, asdict
8
 
9
- from mineru import Mineru, Layout, Table
 
 
 
10
  from sentence_transformers import SentenceTransformer
11
  from llama_cpp import Llama
12
  from fastapi.encoders import jsonable_encoder
@@ -27,11 +30,10 @@ class ProductSpec:
27
 
28
  class PDFProcessor:
29
  def __init__(self):
30
- self.mineru = Mineru()
31
  self.emb_model = SentenceTransformer('all-MiniLM-L6-v2')
32
-
33
- # Initialize LLM with automatic download
34
  self.llm = self._initialize_llm()
 
 
35
 
36
  def _initialize_llm(self):
37
  """Initialize LLM with automatic download if needed"""
@@ -44,21 +46,89 @@ class PDFProcessor:
44
  verbose=False
45
  )
46
 
47
- def extract_layout(self, pdf_path: str) -> List[Layout]:
48
- """Extract structured layout using MinerU"""
49
- return self.mineru.process_pdf(pdf_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- def process_tables(self, tables: List[Table]) -> List[Dict]:
52
- """Convert MinerU tables to structured format"""
53
- return [{
54
- "page": table.page_number,
55
- "cells": table.cells,
56
- "header": table.headers,
57
- "content": table.content
58
- } for table in tables]
 
 
 
 
59
 
60
- def generate_query_prompt(self, text: str) -> str:
61
- """Create optimized extraction prompt"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  return f"""Extract product specifications from this text:
63
  {text}
64
 
@@ -70,8 +140,8 @@ Return JSON format:
70
  "attributes": {{ "key": "value" }}
71
  }}"""
72
 
73
- def parse_response(self, response: str) -> Optional[ProductSpec]:
74
- """Robust JSON parsing with fallbacks"""
75
  try:
76
  json_start = response.find('{')
77
  json_end = response.rfind('}') + 1
@@ -86,33 +156,6 @@ Return JSON format:
86
  logger.warning(f"Parse error: {e}")
87
  return None
88
 
89
- def process_pdf(self, pdf_path: str) -> Dict:
90
- """Main processing pipeline"""
91
- start_time = time.time()
92
-
93
- # Extract structured content
94
- layout = self.extract_layout(pdf_path)
95
- tables = self.process_tables(layout.tables)
96
-
97
- # Process text blocks
98
- products = []
99
- for block in layout.text_blocks:
100
- prompt = self.generate_query_prompt(block.text)
101
-
102
- # Generate response with hardware optimization
103
- response = self.llm.create_chat_completion(
104
- messages=[{"role": "user", "content": prompt}],
105
- temperature=0.1,
106
- max_tokens=512
107
- )
108
-
109
- if product := self.parse_response(response['choices'][0]['message']['content']):
110
- product.tables = tables
111
- products.append(product.to_dict())
112
-
113
- logger.info(f"Processed {len(products)} products in {time.time()-start_time:.2f}s")
114
- return {"products": products, "tables": tables}
115
-
116
  def process_pdf_catalog(pdf_path: str):
117
  processor = PDFProcessor()
118
  try:
 
6
  from typing import List, Dict, Optional
7
  from dataclasses import dataclass, asdict
8
 
9
+ from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
10
+ from magic_pdf.data.dataset import PymuDocDataset
11
+ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
12
+ from magic_pdf.config.enums import SupportedPdfParseMethod
13
  from sentence_transformers import SentenceTransformer
14
  from llama_cpp import Llama
15
  from fastapi.encoders import jsonable_encoder
 
30
 
31
  class PDFProcessor:
32
  def __init__(self):
 
33
  self.emb_model = SentenceTransformer('all-MiniLM-L6-v2')
 
 
34
  self.llm = self._initialize_llm()
35
+ self.output_dir = Path("./output")
36
+ self.output_dir.mkdir(exist_ok=True)
37
 
38
  def _initialize_llm(self):
39
  """Initialize LLM with automatic download if needed"""
 
46
  verbose=False
47
  )
48
 
49
def process_pdf(self, pdf_path: str) -> Dict:
    """Run the MinerU pipeline on one PDF and extract product specs.

    Args:
        pdf_path: Filesystem path of the PDF to process.

    Returns:
        Dict with two keys: "products" (list of parsed ProductSpec dicts,
        each carrying the document's tables) and "tables" (list of table
        dicts pulled from MinerU's middle JSON).
    """
    start_time = time.time()

    # Output layout: images go under <output_dir>/images, markdown beside them.
    local_image_dir = self.output_dir / "images"
    local_md_dir = self.output_dir
    # Use pathlib instead of os.makedirs: Path is already imported at module
    # level, while an `os` import is not visible in this file's import block.
    local_image_dir.mkdir(parents=True, exist_ok=True)

    image_writer = FileBasedDataWriter(str(local_image_dir))
    # NOTE(review): md_writer is never used below — kept in case
    # FileBasedDataWriter's constructor has side effects; confirm and drop.
    md_writer = FileBasedDataWriter(str(local_md_dir))

    # Read the raw PDF bytes through MinerU's reader abstraction.
    reader = FileBasedDataReader("")
    pdf_bytes = reader.read(pdf_path)

    # Let MinerU classify the document and pick OCR vs. text parsing.
    ds = PymuDocDataset(pdf_bytes)
    if ds.classify() == SupportedPdfParseMethod.OCR:
        infer_result = ds.apply(doc_analyze, ocr=True)
        pipe_result = infer_result.pipe_ocr_mode(image_writer)
    else:
        infer_result = ds.apply(doc_analyze, ocr=False)
        pipe_result = infer_result.pipe_txt_mode(image_writer)

    # Pull structured content out of MinerU's intermediate representation.
    middle_json = pipe_result.get_middle_json()
    tables = self._extract_tables(middle_json)
    text_blocks = self._extract_text_blocks(middle_json)

    # Run each text block through the LLM extractor; skip failed parses.
    products = []
    for block in text_blocks:
        if product := self._process_text_block(block):
            product.tables = tables
            products.append(product.to_dict())

    logger.info(f"Processed {len(products)} products in {time.time()-start_time:.2f}s")
    return {"products": products, "tables": tables}
92
 
93
+ def _extract_tables(self, middle_json: Dict) -> List[Dict]:
94
+ """Extract tables from MinerU's middle JSON"""
95
+ tables = []
96
+ for page in middle_json.get('pages', []):
97
+ for table in page.get('tables', []):
98
+ tables.append({
99
+ "page": page.get('page_number'),
100
+ "cells": table.get('cells', []),
101
+ "header": table.get('header', []),
102
+ "content": table.get('content', [])
103
+ })
104
+ return tables
105
 
106
+ def _extract_text_blocks(self, middle_json: Dict) -> List[str]:
107
+ """Extract text blocks from MinerU's middle JSON"""
108
+ text_blocks = []
109
+ for page in middle_json.get('pages', []):
110
+ for block in page.get('blocks', []):
111
+ if block.get('type') == 'text':
112
+ text_blocks.append(block.get('text', ''))
113
+ return text_blocks
114
+
115
def _process_text_block(self, text: str) -> Optional[ProductSpec]:
    """Send one text block through the LLM and parse the reply.

    Returns the parsed ProductSpec, or None when either the completion
    call or the parse fails; failures are logged, never raised.
    """
    prompt = self._generate_query_prompt(text)
    try:
        reply = self.llm.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=512,
        )
        content = reply['choices'][0]['message']['content']
        return self._parse_response(content)
    except Exception as e:
        logger.warning(f"Error processing text block: {e}")
        return None
129
+
130
+ def _generate_query_prompt(self, text: str) -> str:
131
+ """Generate extraction prompt"""
132
  return f"""Extract product specifications from this text:
133
  {text}
134
 
 
140
  "attributes": {{ "key": "value" }}
141
  }}"""
142
 
143
+ def _parse_response(self, response: str) -> Optional[ProductSpec]:
144
+ """Parse LLM response"""
145
  try:
146
  json_start = response.find('{')
147
  json_end = response.rfind('}') + 1
 
156
  logger.warning(f"Parse error: {e}")
157
  return None
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def process_pdf_catalog(pdf_path: str):
160
  processor = PDFProcessor()
161
  try: