princhman committed
Commit 0aedcf3 · Parent: 78a8154

final update of the logic

__pycache__/inference_svm_model.cpython-310.pyc CHANGED
Binary files a/__pycache__/inference_svm_model.cpython-310.pyc and b/__pycache__/inference_svm_model.cpython-310.pyc differ
 
__pycache__/mineru_single.cpython-310.pyc CHANGED
Binary files a/__pycache__/mineru_single.cpython-310.pyc and b/__pycache__/mineru_single.cpython-310.pyc differ
 
__pycache__/worker.cpython-310.pyc CHANGED
Binary files a/__pycache__/worker.cpython-310.pyc and b/__pycache__/worker.cpython-310.pyc differ
 
app.py CHANGED
@@ -24,36 +24,6 @@ app.add_middleware(
 async def root():
     return {"status": "ok", "message": "API is running"}
 
-@app.post("/process")
-async def process_pdf(
-    input_json: dict = Body(...),
-    x_api_key: str = Header(None, alias="X-API-Key")
-):
-    if not x_api_key:
-        raise HTTPException(status_code=401, detail="API key is missing")
-    if x_api_key != API_KEY:
-        raise HTTPException(status_code=401, detail="Invalid API key")
-
-    # Connect to RabbitMQ
-    rabbit_url = os.getenv("RABBITMQ_URL")
-    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
-    channel = connection.channel()
-    channel.queue_declare(queue="ml_server", durable=True)
-
-    channel.basic_publish(
-        exchange="",
-        routing_key="gpu_server",
-        body=json.dumps(input_json),
-        properties=pika.BasicProperties(
-            headers={"process": "topic_extraction"}
-        )
-    )
-    connection.close()
-
-    return {
-        "message": "Job queued",
-        "request_id": input_json.get("headers", {}).get("request_id", str(uuid.uuid4()))
-    }
 
 if __name__ == "__main__":
     os.system('python download_models_hf.py')
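
With the /process endpoint gone from app.py, jobs now enter the pipeline by publishing to RabbitMQ directly. For reference, a minimal stand-alone publisher reproducing the deleted route's behavior might look like the sketch below (assuming the same RABBITMQ_URL environment variable; the input_files payload and URL are hypothetical). Note the deleted route declared ml_server but published to gpu_server; the sketch declares the gpu_server queue the worker actually consumes.

import json
import os
import uuid

import pika

# Sketch: enqueue a topic-extraction job the way the removed /process route did.
rabbit_url = os.getenv("RABBITMQ_URL")
connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
channel = connection.channel()
channel.queue_declare(queue="gpu_server", durable=True)

input_json = {
    "headers": {"request_id": str(uuid.uuid4())},
    "input_files": [
        # Hypothetical example; the worker expects key/url pairs.
        {"key": "doc-1", "url": "https://example.com/doc-1.pdf"}
    ],
}
channel.basic_publish(
    exchange="",
    routing_key="gpu_server",
    body=json.dumps(input_json),
    properties=pika.BasicProperties(headers={"process": "topic_extraction"}),
)
connection.close()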
inference_svm_model.py CHANGED
@@ -1,29 +1,31 @@
 #!/usr/bin/env python3
 import cv2
 import numpy as np
+import os
 from joblib import load
 
-def load_svm_model(model_path: str):
-    return load(model_path)
+class SVMModel:
+    def __init__(self):
+        path = os.getenv("SVM_MODEL_PATH", "/home/user/app/model_classification/svm_model.joblib")
+        self.model = load(path)
 
-def classify_image(
-    image_path: str,
-    loaded_model,
-    label_map: dict,
-    image_size=(128, 128)
-) -> str:
-    img = cv2.imread(image_path)
-    if img is None:
-        # If image fails to load, default to "irrelevant" or handle differently
-        return label_map[0]
+    def classify_image(
+        self,
+        image_bytes: bytes,
+        image_size=(128, 128)
+    ) -> int:
+        img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
+        if img is None:
+            # If image fails to decode, default to 0 ("irrelevant")
+            return 0
 
-    img = cv2.resize(img, image_size)
-    x = img.flatten().reshape(1, -1)
-    pred = loaded_model.predict(x)[0]
-    return label_map[pred]
+        img = cv2.resize(img, image_size)
+        x = img.flatten().reshape(1, -1)
+        pred = self.model.predict(x)[0]
+        return pred
 
 if __name__ == "__main__":
-    model = load_svm_model("/home/user/app/model_classification/svm_model.joblib")
-    label_map = {0: "irrelevant", 1: "relevant"}
-    result = classify_image("test.jpg", model, label_map)
+    model = SVMModel()
+    with open("test.jpg", "rb") as f:
+        result = model.classify_image(f.read())
     print("Classification result:", result)
mineru_single.py CHANGED
@@ -10,7 +10,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.data.io.s3 import S3Writer
 from magic_pdf.data.data_reader_writer.base import DataWriter
 
-from inference_svm_model import load_svm_model, classify_image
+from inference_svm_model import SVMModel
 
 class Processor:
     def __init__(self):
@@ -21,9 +21,7 @@ class Processor:
             endpoint_url=os.getenv("S3_ENDPOINT"),
         )
 
-        model_path = os.getenv("SVM_MODEL_PATH", "/home/user/app/model_classification/svm_model.joblib")
-        self.svm_model = load_svm_model(model_path)
-        self.label_map = {0: "irrelevant", 1: "relevant"}
+        self.svm_model = SVMModel()
 
         with open("/home/user/magic-pdf.json", "r") as f:
             config = json.load(f)
@@ -37,7 +35,7 @@ class Processor:
         bucket = os.getenv("S3_BUCKET_NAME", "")
         self.prefix = f"{endpoint}/{bucket}/document-extracts/"
 
-    def process(self, file_url: str) -> str:
+    def process(self, file_url: str, key: str) -> str:
         logger.info("Processing file: {}", file_url)
         response = requests.get(file_url)
         if response.status_code != 200:
@@ -54,53 +52,30 @@ class Processor:
             table_enable=self.table_enable
         )
 
-        image_writer = ImageWriter(self.s3_writer, self.svm_model, self.label_map)
+        image_writer = ImageWriter(self.s3_writer, self.svm_model)
 
         pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
 
-        folder_name = str(uuid.uuid4())
-        md_content = pipe_result.get_markdown(self.prefix + folder_name + "/")
+        md_content = pipe_result.get_markdown(self.prefix + key + "/")
 
         # Remove references to images classified as "irrelevant"
         final_markdown = image_writer.remove_redundant_images(md_content)
         return final_markdown
 
-    def process_batch(self, file_urls: list[str]) -> dict:
-        results = {}
-        for url in file_urls:
-            try:
-                md = self.process(url)
-                results[url] = md
-            except Exception as e:
-                results[url] = f"Error: {str(e)}"
-        return results
-
 class ImageWriter(DataWriter):
     """
     Receives each extracted image. Classifies it, uploads if relevant, or flags
     it for removal if irrelevant.
     """
-    def __init__(self, s3_writer: S3Writer, svm_model, label_map):
+    def __init__(self, s3_writer: S3Writer, svm_model: SVMModel):
         self.s3_writer = s3_writer
         self.svm_model = svm_model
-        self.label_map = label_map
         self._redundant_images_paths = []
 
     def write(self, path: str, data: bytes) -> None:
-        import tempfile
-        import os
-        import uuid
-
-        tmp_name = f"{uuid.uuid4()}.jpg"
-        tmp_path = os.path.join(tempfile.gettempdir(), tmp_name)
-        with open(tmp_path, "wb") as f:
-            f.write(data)
-
-        label_str = classify_image(tmp_path, self.svm_model, self.label_map)
-
-        os.remove(tmp_path)
-
-        if label_str == "relevant":
+        label_str = self.svm_model.classify_image(data)
+
+        if label_str == 1:
             # Upload to S3
             self.s3_writer.write(path, data)
         else:
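
Note that process() now threads the caller-supplied key into the S3 prefix instead of a random uuid4 folder, so re-processing the same document writes to the same document-extracts/<key>/ location. A minimal call sketch, with a hypothetical URL and key:

from mineru_single import Processor

processor = Processor()

# "doc-1" becomes the folder under document-extracts/, matching the key
# the worker pulls from each entry in input_files.
markdown = processor.process("https://example.com/doc-1.pdf", "doc-1")
print(markdown[:200])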
worker.py CHANGED
@@ -14,13 +14,17 @@ from mineru_single import Processor
 class RabbitMQWorker:
     def __init__(self, num_workers: int = 1):
         self.num_workers = num_workers
-        self.rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
+        self.rabbit_url = os.getenv("RABBITMQ_URL")
         self.processor = Processor()
 
     def publish_message(self, body_dict: dict, headers: dict):
         """Create a new connection for each publish operation"""
         try:
-            connection = pika.BlockingConnection(pika.URLParameters(self.rabbit_url))
+            connection_params = pika.URLParameters(self.rabbit_url)
+            connection_params.heartbeat = 10
+            connection_params.blocked_connection_timeout = 5
+
+            connection = pika.BlockingConnection(connection_params)
             channel = connection.channel()
 
             channel.queue_declare(queue="ml_server", durable=True)
@@ -56,41 +60,58 @@ class RabbitMQWorker:
                 # Process files
                 for file in body_dict.get("input_files", []):
                     try:
-                        context = {"key": file["key"], "body": self.processor.process(file["url"])}
+                        context = {"key": file["key"], "body": self.processor.process(file["url"], file["key"])}
                         contexts.append(context)
                     except Exception as e:
                         print(f"Error processing file {file['key']}: {e}")
                         contexts.append({"key": file["key"], "body": f"Error: {str(e)}"})
 
                 body_dict["md_context"] = contexts
 
                 # Publish results
                 if self.publish_message(body_dict, headers):
+                    ch.basic_ack(delivery_tag=method.delivery_tag)
                     print(f"[Worker {thread_id}] Successfully published results")
                 else:
+                    ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                     print(f"[Worker {thread_id}] Failed to publish results")
 
                 print(f"[Worker {thread_id}] Contexts: {contexts}")
             else:
+                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                 print(f"[Worker {thread_id}] Unknown process")
 
         except Exception as e:
             print(f"Error in callback: {e}")
+            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
 
     def connect_to_rabbitmq(self):
-        """Establish connection to RabbitMQ"""
-        connection = pika.BlockingConnection(pika.URLParameters(self.rabbit_url))
+        """Establish connection to RabbitMQ with heartbeat"""
+        connection_params = pika.URLParameters(self.rabbit_url)
+        connection_params.heartbeat = 30
+        connection_params.blocked_connection_timeout = 10
+
+        connection = pika.BlockingConnection(connection_params)
         channel = connection.channel()
 
         channel.queue_declare(queue="gpu_server", durable=True)
         channel.basic_qos(prefetch_count=1)
         channel.basic_consume(
             queue="gpu_server",
-            on_message_callback=self.callback,
-            auto_ack=True
+            on_message_callback=self.callback
         )
         return connection, channel
 
+    def worker(self, channel):
+        """Consume until the connection drops, then clean up"""
+        print("Worker started")
+        try:
+            channel.start_consuming()
+        except Exception as e:
+            print(f"Worker stopped: {e}")
+        finally:
+            channel.close()
+
     def start(self):
         """Start the worker threads"""
         print(f"Starting {self.num_workers} workers")