minar09 committed on
Commit
45cbcdf
·
verified ·
1 Parent(s): ef59284

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +54 -45
main.py CHANGED
@@ -4,15 +4,21 @@ import time
4
  import logging
5
  from pathlib import Path
6
  from typing import List, Dict, Optional
7
- from dataclasses import dataclass, asdict
8
-
9
- from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
10
- from magic_pdf.data.dataset import PymuDocDataset
11
- from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
12
- from magic_pdf.config.enums import SupportedPdfParseMethod
13
  from sentence_transformers import SentenceTransformer
14
  from llama_cpp import Llama
15
- from fastapi.encoders import jsonable_encoder
 
 
 
 
 
 
 
 
 
 
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
@@ -57,39 +63,43 @@ class PDFProcessor:
57
 
58
  os.makedirs(local_image_dir, exist_ok=True)
59
 
60
- image_writer = FileBasedDataWriter(str(local_image_dir))
61
- md_writer = FileBasedDataWriter(str(local_md_dir))
62
-
63
- # Read PDF
64
- reader = FileBasedDataReader("")
65
- pdf_bytes = reader.read(pdf_path)
66
-
67
- # Create dataset and process
68
- ds = PymuDocDataset(pdf_bytes)
69
-
70
- if ds.classify() == SupportedPdfParseMethod.OCR:
71
- infer_result = ds.apply(doc_analyze, ocr=True)
72
- pipe_result = infer_result.pipe_ocr_mode(image_writer)
73
- else:
74
- infer_result = ds.apply(doc_analyze, ocr=False)
75
- pipe_result = infer_result.pipe_txt_mode(image_writer)
76
-
77
- # Get structured content
78
- middle_json = pipe_result.get_middle_json()
79
- tables = self._extract_tables(middle_json)
80
- text_blocks = self._extract_text_blocks(middle_json)
81
-
82
- # Process text blocks with LLM
83
- products = []
84
- for block in text_blocks:
85
- product = self._process_text_block(block)
86
- if product:
87
- product.tables = tables
88
- products.append(product.to_dict())
89
-
90
- logger.info(f"Processed {len(products)} products in {time.time()-start_time:.2f}s")
91
- return {"products": products, "tables": tables}
92
-
 
 
 
 
93
  def _extract_tables(self, middle_json: Dict) -> List[Dict]:
94
  """Extract tables from MinerU's middle JSON"""
95
  tables = []
@@ -102,7 +112,7 @@ class PDFProcessor:
102
  "content": table.get('content', [])
103
  })
104
  return tables
105
-
106
  def _extract_text_blocks(self, middle_json: Dict) -> List[str]:
107
  """Extract text blocks from MinerU's middle JSON"""
108
  text_blocks = []
@@ -111,7 +121,7 @@ class PDFProcessor:
111
  if block.get('type') == 'text':
112
  text_blocks.append(block.get('text', ''))
113
  return text_blocks
114
-
115
  def _process_text_block(self, text: str) -> Optional[ProductSpec]:
116
  """Process text block with LLM"""
117
  prompt = self._generate_query_prompt(text)
@@ -126,12 +136,11 @@ class PDFProcessor:
126
  except Exception as e:
127
  logger.warning(f"Error processing text block: {e}")
128
  return None
129
-
130
  def _generate_query_prompt(self, text: str) -> str:
131
  """Generate extraction prompt"""
132
  return f"""Extract product specifications from this text:
133
  {text}
134
-
135
  Return JSON format:
136
  {{
137
  "name": "product name",
@@ -139,7 +148,7 @@ Return JSON format:
139
  "price": numeric_price,
140
  "attributes": {{ "key": "value" }}
141
  }}"""
142
-
143
  def _parse_response(self, response: str) -> Optional[ProductSpec]:
144
  """Parse LLM response"""
145
  try:
 
4
  import logging
5
  from pathlib import Path
6
  from typing import List, Dict, Optional
7
+ from dataclasses import dataclass
8
+ from fastapi.encoders import jsonable_encoder
 
 
 
 
9
  from sentence_transformers import SentenceTransformer
10
  from llama_cpp import Llama
11
+
12
+ # Import magic_pdf up front; if the package is missing, log a clear hint and re-raise (no module-path adjustment is performed)
13
+ try:
14
+ from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
15
+ from magic_pdf.data.dataset import PymuDocDataset
16
+ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
17
+ from magic_pdf.config.enums import SupportedPdfParseMethod
18
+ except ModuleNotFoundError as e:
19
+ logging.error(f"Failed to import magic_pdf modules: {e}")
20
+ logging.info("Ensure that the magic_pdf package is installed and accessible in your Python environment.")
21
+ raise e
22
 
23
  logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger(__name__)
 
63
 
64
  os.makedirs(local_image_dir, exist_ok=True)
65
 
66
+ try:
67
+ image_writer = FileBasedDataWriter(str(local_image_dir))
68
+ md_writer = FileBasedDataWriter(str(local_md_dir))
69
+
70
+ # Read PDF
71
+ reader = FileBasedDataReader("")
72
+ pdf_bytes = reader.read(pdf_path)
73
+
74
+ # Create dataset and process
75
+ ds = PymuDocDataset(pdf_bytes)
76
+
77
+ if ds.classify() == SupportedPdfParseMethod.OCR:
78
+ infer_result = ds.apply(doc_analyze, ocr=True)
79
+ pipe_result = infer_result.pipe_ocr_mode(image_writer)
80
+ else:
81
+ infer_result = ds.apply(doc_analyze, ocr=False)
82
+ pipe_result = infer_result.pipe_txt_mode(image_writer)
83
+
84
+ # Get structured content
85
+ middle_json = pipe_result.get_middle_json()
86
+ tables = self._extract_tables(middle_json)
87
+ text_blocks = self._extract_text_blocks(middle_json)
88
+
89
+ # Process text blocks with LLM
90
+ products = []
91
+ for block in text_blocks:
92
+ product = self._process_text_block(block)
93
+ if product:
94
+ product.tables = tables
95
+ products.append(product.to_dict())
96
+
97
+ logger.info(f"Processed {len(products)} products in {time.time()-start_time:.2f}s")
98
+ return {"products": products, "tables": tables}
99
+ except Exception as e:
100
+ logger.error(f"Error during PDF processing: {e}")
101
+ raise RuntimeError("PDF processing failed.") from e
102
+
103
  def _extract_tables(self, middle_json: Dict) -> List[Dict]:
104
  """Extract tables from MinerU's middle JSON"""
105
  tables = []
 
112
  "content": table.get('content', [])
113
  })
114
  return tables
115
+
116
  def _extract_text_blocks(self, middle_json: Dict) -> List[str]:
117
  """Extract text blocks from MinerU's middle JSON"""
118
  text_blocks = []
 
121
  if block.get('type') == 'text':
122
  text_blocks.append(block.get('text', ''))
123
  return text_blocks
124
+
125
  def _process_text_block(self, text: str) -> Optional[ProductSpec]:
126
  """Process text block with LLM"""
127
  prompt = self._generate_query_prompt(text)
 
136
  except Exception as e:
137
  logger.warning(f"Error processing text block: {e}")
138
  return None
139
+
140
  def _generate_query_prompt(self, text: str) -> str:
141
  """Generate extraction prompt"""
142
  return f"""Extract product specifications from this text:
143
  {text}
 
144
  Return JSON format:
145
  {{
146
  "name": "product name",
 
148
  "price": numeric_price,
149
  "attributes": {{ "key": "value" }}
150
  }}"""
151
+
152
  def _parse_response(self, response: str) -> Optional[ProductSpec]:
153
  """Parse LLM response"""
154
  try: