MinerU / worker.py
princhman's picture
typoo fix
c52e28a
raw
history blame
5.87 kB
#!/usr/bin/env python3
import os
import json
import time
import threading
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import pika
from typing import Tuple, Dict, Any
from mineru_single import Processor
class RabbitMQWorker:
    """RabbitMQ consumer that runs ``Processor`` on incoming files and republishes the result.

    Messages are consumed from ``INPUT_QUEUE``. For messages whose
    ``request_type`` header is ``"process_files"``, every entry of the JSON
    body's ``input_files`` list is passed to ``self.processor.process`` and the
    collected results are stored under ``md_context`` before the whole body is
    published to ``OUTPUT_QUEUE``.

    NOTE(review): the publisher connection/channel is a single pair stored on
    the instance and therefore shared by all consumer threads when
    ``num_workers > 1``; pika blocking connections are not documented as
    thread-safe -- confirm before raising ``num_workers``.
    """

    # Queue this worker consumes jobs from.
    INPUT_QUEUE = "gpu_server"
    # Queue processed results are published to.
    OUTPUT_QUEUE = "ml_server"

    def __init__(self, num_workers: int = 1):
        """Initialise the worker.

        Args:
            num_workers: Number of consumer threads started by ``start``; each
                one gets its own consumer connection.
        """
        self.num_workers = num_workers
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        self.processor = Processor()
        # Created lazily by setup_publisher() and dropped again on publish errors.
        self.publisher_connection = None
        self.publisher_channel = None

    def setup_publisher(self):
        """Open the publisher connection and channel if there is no usable one."""
        if not self.publisher_connection or self.publisher_connection.is_closed:
            connection_params = pika.URLParameters(self.rabbit_url)
            # NOTE(review): the consumer connection uses heartbeat=30, so the
            # two connections do NOT share a heartbeat value -- confirm which
            # one is intended.
            connection_params.heartbeat = 600
            connection_params.blocked_connection_timeout = 300
            self.publisher_connection = pika.BlockingConnection(connection_params)
            self.publisher_channel = self.publisher_connection.channel()
            self.publisher_channel.queue_declare(queue=self.OUTPUT_QUEUE, durable=True)

    def _reset_publisher(self):
        """Drop the cached publisher connection so the next publish reconnects."""
        connection = self.publisher_connection
        self.publisher_connection = None
        self.publisher_channel = None
        if connection is not None and not connection.is_closed:
            try:
                connection.close()
            except Exception as close_error:
                # Best effort: the connection is being discarded anyway.
                print(f"Ignoring error while closing publisher connection: {close_error}")

    def publish_message(self, body_dict: dict, headers: dict) -> bool:
        """Publish ``body_dict`` as JSON to ``OUTPUT_QUEUE`` over the persistent connection.

        Up to three attempts are made; after a failed attempt the publisher
        connection is discarded and, unless it was the last attempt, the method
        sleeps for two seconds before reconnecting.

        Args:
            body_dict: JSON-serialisable message body.
            headers: Message headers forwarded unchanged.

        Returns:
            True if the message was handed to the channel, False if every
            attempt failed.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Ensure publisher connection is setup
                self.setup_publisher()
                self.publisher_channel.basic_publish(
                    exchange="",
                    routing_key=self.OUTPUT_QUEUE,
                    body=json.dumps(body_dict),
                    properties=pika.BasicProperties(
                        delivery_mode=2,  # persistent message
                        headers=headers,
                    ),
                )
                return True
            except Exception as e:
                print(f"Publish attempt {attempt + 1} failed: {e}")
                self._reset_publisher()
                if attempt == max_retries - 1:
                    print(f"Failed to publish after {max_retries} attempts")
                    return False
                time.sleep(2)
        return False

    def _process_files(self, body_dict: dict) -> list:
        """Run the processor on every entry of ``body_dict["input_files"]``.

        A failure on one file is recorded as an ``"Error: ..."`` body for that
        file instead of aborting the whole message.

        Returns:
            A list of ``{"key": ..., "body": ...}`` dicts, one per input file.
        """
        contexts = []
        for file_info in body_dict.get("input_files", []):
            try:
                result = self.processor.process(file_info["url"], file_info["key"])
                contexts.append({"key": file_info["key"], "body": result})
            except Exception as e:
                print(f"Error processing file {file_info['key']}: {e}")
                contexts.append({"key": file_info["key"], "body": f"Error: {str(e)}"})
        return contexts

    def callback(self, ch, method, properties, body):
        """Handle one incoming RabbitMQ message.

        The delivery is settled exactly once: it is acked only after the result
        has been published, and nacked (with requeue) when the request type is
        unknown, when processing raises, or when publishing fails.

        Args:
            ch: Channel the message was delivered on.
            method: Delivery metadata; ``delivery_tag`` is used to ack/nack.
            properties: Message properties; ``headers`` carries ``request_type``.
            body: Raw message body, expected to be JSON for ``process_files``.
        """
        thread_id = threading.current_thread().name
        headers = properties.headers or {}
        print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

        if headers.get("request_type") != "process_files":
            # NOTE(review): requeue=True sends the message straight back to the
            # queue; if no other consumer handles it this redelivers forever.
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
            print(f"[Worker {thread_id}] Unknown process")
            return

        try:
            body_dict = json.loads(body)
            contexts = self._process_files(body_dict)
            body_dict["md_context"] = contexts
            published = self.publish_message(body_dict, headers)
        except Exception as e:
            print(f"Error in callback: {e}")
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
            return

        # Settle only after the publish outcome is known; acking first and
        # nacking afterwards would settle the same delivery tag twice.
        if published:
            ch.basic_ack(delivery_tag=method.delivery_tag)
            print(f"[Worker {thread_id}] Successfully published results")
        else:
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
            print(f"[Worker {thread_id}] Failed to publish results")
        print(f"[Worker {thread_id}] Contexts: {contexts}")

    def connect_to_rabbitmq(self):
        """Open a consumer connection and register ``callback`` on ``INPUT_QUEUE``.

        Returns:
            A ``(connection, channel)`` tuple; the channel has prefetch set to
            one message and is ready for ``start_consuming``.
        """
        connection_params = pika.URLParameters(self.rabbit_url)
        connection_params.heartbeat = 30
        connection_params.blocked_connection_timeout = 10
        connection = pika.BlockingConnection(connection_params)
        channel = connection.channel()
        channel.queue_declare(queue=self.INPUT_QUEUE, durable=True)
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue=self.INPUT_QUEUE,
            on_message_callback=self.callback,
        )
        return connection, channel

    def worker(self, channel):
        """Consume from ``channel`` until it fails, then release its resources.

        Args:
            channel: A consumer channel as returned by ``connect_to_rabbitmq``.
        """
        print("Worker started")
        try:
            channel.start_consuming()
        except Exception as e:
            print(f"Worker stopped: {e}")
        finally:
            # Close the channel and then its owning connection. Each is closed
            # only if still open, because closing an already-closed channel
            # raises and the connection would otherwise never be released.
            connection = getattr(channel, "connection", None)
            for resource in (channel, connection):
                try:
                    if resource is not None and resource.is_open:
                        resource.close()
                except Exception as close_error:
                    print(f"Error while closing consumer resources: {close_error}")

    def start(self):
        """Start the worker threads and restart them whenever they all stop.

        Runs forever. Leaving the executor's ``with`` block waits for every
        submitted worker to finish; if connecting raises, the error is printed
        and the next round starts after a five second pause.
        """
        print(f"Starting {self.num_workers} workers")
        while True:
            try:
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    for _ in range(self.num_workers):
                        # The connection is reachable through the channel and is
                        # closed by worker() when consuming stops.
                        _connection, channel = self.connect_to_rabbitmq()
                        executor.submit(self.worker, channel)
            except Exception as e:
                print(f"Connection lost, reconnecting... Error: {e}")
                time.sleep(5)  # Wait before reconnecting
def main():
    """Build a ``RabbitMQWorker`` with its default settings and run it."""
    RabbitMQWorker().start()