#!/usr/bin/env python3
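"""RabbitMQ worker for MinerU: consumes processing jobs from the "gpu_server"
queue, runs them through the MinerU processors, and publishes the results back
to the "ml_server" queue.

Supported message patterns:
  - "process_files":     full document processing via Processor
  - "topic_extraction":  topic extraction via TopicExtractionProcessor
"""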
import os
import json
import time
import logging
import threading
from concurrent.futures import ThreadPoolExecutor

import pika

from mineru_single import Processor
from topic_extr import TopicExtractionProcessor

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)


class RabbitMQWorker:
    def __init__(self, num_workers: int = 1):
        self.num_workers = num_workers
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        logger.info("Initializing RabbitMQWorker")
        self.processor = Processor()
        self.topic_processor = TopicExtractionProcessor()
        self.publisher_connection = None
        self.publisher_channel = None
    def setup_publisher(self):
        if not self.publisher_connection or self.publisher_connection.is_closed:
            logger.info("Setting up publisher connection to RabbitMQ.")
            connection_params = pika.URLParameters(self.rabbit_url)
            connection_params.heartbeat = 1000  # Match the consumer heartbeat
            connection_params.blocked_connection_timeout = 500  # Increase timeout
            self.publisher_connection = pika.BlockingConnection(connection_params)
            self.publisher_channel = self.publisher_connection.channel()
            self.publisher_channel.queue_declare(queue="ml_server", durable=True)
            logger.info("Publisher connection/channel established successfully.")
    def publish_message(self, body_dict: dict, headers: dict) -> bool:
        """Publish via the persistent publisher connection, retrying on failure."""
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Ensure the publisher connection is set up
                self.setup_publisher()
                # NOTE: messages are published transient; set delivery_mode=2 in
                # BasicProperties if they must survive a broker restart.
                self.publisher_channel.basic_publish(
                    exchange="",
                    routing_key="ml_server",
                    body=json.dumps(body_dict).encode("utf-8"),
                    properties=pika.BasicProperties(headers=headers)
                )
                logger.info("Published message to ml_server queue (attempt=%d).", attempt + 1)
                return True
            except Exception as e:
                logger.error("Publish attempt %d failed: %s", attempt + 1, e)
                # Discard the failed connection so the next attempt rebuilds it
                if self.publisher_connection and not self.publisher_connection.is_closed:
                    try:
                        self.publisher_connection.close()
                    except Exception:
                        pass
                self.publisher_connection = None
                self.publisher_channel = None
                if attempt == max_retries - 1:
                    logger.error("Failed to publish after %d attempts", max_retries)
                    return False
                time.sleep(2)
    def callback(self, ch, method, properties, body):
        """Handle incoming RabbitMQ messages."""
        thread_id = threading.current_thread().name
        headers = properties.headers or {}
        logger.info("[Worker %s] Received message: %s, headers: %s", thread_id, body, headers)
        try:
            contexts = []
            body_dict = json.loads(body)
            pattern = body_dict.get("pattern")
            if pattern == "process_files":
                data = body_dict.get("data")
                input_files = data.get("input_files")
                logger.info("[Worker %s] Found %d file(s) to process.", thread_id, len(input_files))
                for file in input_files:
                    try:
                        context = {
                            "key": file["key"],
                            "body": self.processor.process(file["url"], headers["request_id"])
                        }
                        contexts.append(context)
                    except Exception as e:
                        err_str = f"Error processing file {file.get('key', '')}: {e}"
                        logger.error(err_str)
                        contexts.append({"key": file.get("key", ""), "body": err_str})
                data["md_context"] = contexts
                body_dict["pattern"] = "question_extraction_update_from_gpu_server"
                body_dict["data"] = data
                # Publish results
                if self.publish_message(body_dict, headers):
                    logger.info("[Worker %s] Successfully published results to ml_server.", thread_id)
                    ch.basic_ack(delivery_tag=method.delivery_tag)
                else:
                    ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                    logger.error("[Worker %s] Failed to publish results.", thread_id)
                logger.info("[Worker %s] Contexts: %s", thread_id, contexts)
            elif pattern == "topic_extraction":
                data = body_dict.get("data")
                input_files = data.get("input_files")
                logger.info("[Worker %s] Found %d file(s) for topic extraction.", thread_id, len(input_files))
                for file in input_files:
                    try:
                        # Process the file and build a context from the markdown output
                        markdown_content = self.topic_processor.process(file)
                        context = {
                            "key": file["key"] + ".md",
                            "body": markdown_content
                        }
                        contexts.append(context)
                    except Exception as e:
                        err_str = f"Error processing file {file.get('key', '')}: {e}"
                        logger.error(err_str)
                        contexts.append({"key": file.get("key", ""), "body": err_str})
                # Attach the markdown contexts to the data
                data["md_context"] = contexts
                body_dict["pattern"] = "topic_extraction_update_from_gpu_server"
                body_dict["data"] = data
                # Publish the results back to the ML server
                if self.publish_message(body_dict, headers):
                    logger.info("[Worker %s] Published topic extraction results to ml_server.", thread_id)
                    ch.basic_ack(delivery_tag=method.delivery_tag)
                else:
                    ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                    logger.error("[Worker %s] Failed to publish topic results.", thread_id)
                logger.info("[Worker %s] Topic contexts: %s", thread_id, contexts)
            else:
                # basic_ack has no requeue parameter; unknown patterns are acked and dropped
                ch.basic_ack(delivery_tag=method.delivery_tag)
                logger.warning("[Worker %s] Unknown pattern in message body: %s", thread_id, pattern)
        except Exception as e:
            logger.error("Error in callback: %s", e)
            # Requeue so another worker can retry; a persistently failing
            # message will cycle back here.
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
    def connect_to_rabbitmq(self):
        """Establish a consumer connection to RabbitMQ with a heartbeat."""
        connection_params = pika.URLParameters(self.rabbit_url)
        connection_params.heartbeat = 1000
        connection_params.blocked_connection_timeout = 500
        logger.info("Connecting to RabbitMQ for consumer with heartbeat=1000.")
        connection = pika.BlockingConnection(connection_params)
        channel = connection.channel()
        channel.queue_declare(queue="gpu_server", durable=True)
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue="gpu_server",
            on_message_callback=self.callback
        )
        logger.info("Consumer connected. Listening on queue='gpu_server'...")
        return connection, channel
    def worker(self, channel):
        """Consume messages until the channel dies."""
        logger.info("Worker thread started. Beginning consuming...")
        try:
            channel.start_consuming()
        except Exception as e:
            logger.error("Worker thread encountered an error: %s", e)
        finally:
            logger.info("Worker thread shutting down. Closing channel.")
            channel.close()
    def start(self):
        """Start the worker threads."""
        logger.info("Starting %d workers in a ThreadPoolExecutor.", self.num_workers)
        while True:
            try:
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    for _ in range(self.num_workers):
                        connection, channel = self.connect_to_rabbitmq()
                        executor.submit(self.worker, channel)
            except Exception as e:
                logger.error("Connection lost, reconnecting... Error: %s", e)
                time.sleep(5)  # Wait before reconnecting


def main():
    worker = RabbitMQWorker()
    worker.start()


if __name__ == "__main__":
    main()
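
# Example usage (the broker URL below is illustrative):
#   export RABBITMQ_URL="amqp://user:pass@localhost:5672/%2F"
#   python worker.py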