|
|
|
import os |
|
import json |
|
import time |
|
import threading |
|
import multiprocessing |
|
from concurrent.futures import ThreadPoolExecutor |
|
import pika |
|
from typing import Tuple, Dict, Any |
|
|
|
from mineru_single import Processor |
|
|
|
|
|
class RabbitMQWorker:
    """Consume file-processing requests from RabbitMQ and publish the results.

    Requests are consumed from the durable ``gpu_server`` queue. For a message
    whose ``request_type`` header is ``process_files``, every entry of the
    body's ``input_files`` list is passed to ``Processor.process`` and the
    collected results are published to the durable ``ml_server`` queue under
    the ``md_context`` key of the original body.
    """

    def __init__(self, num_workers: int = 1):
        """Read the broker URL from ``RABBITMQ_URL`` and create the processor.

        Args:
            num_workers: Number of consumer threads started by ``start``.
        """
        self.num_workers = num_workers
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        self.processor = Processor()

        # The publisher connection is opened lazily by setup_publisher().
        self.publisher_connection = None
        self.publisher_channel = None
        # The publisher connection is shared by every worker thread, so all
        # access to it is serialized through this lock.
        self._publisher_lock = threading.Lock()

    def setup_publisher(self):
        """Open the publisher connection and channel if none is currently open.

        Declares the durable ``ml_server`` queue on the new channel. Does
        nothing when an open publisher connection already exists.
        """
        if not self.publisher_connection or self.publisher_connection.is_closed:
            connection_params = pika.URLParameters(self.rabbit_url)
            connection_params.heartbeat = 600
            connection_params.blocked_connection_timeout = 300

            self.publisher_connection = pika.BlockingConnection(connection_params)
            self.publisher_channel = self.publisher_connection.channel()
            self.publisher_channel.queue_declare(queue="ml_server", durable=True)

    def _reset_publisher(self):
        """Close the publisher connection (best effort) and forget it."""
        connection = self.publisher_connection
        if connection and not connection.is_closed:
            try:
                connection.close()
            except Exception as close_error:
                # Best effort: the connection is being discarded anyway.
                print(f"Error closing publisher connection: {close_error}")
        self.publisher_connection = None
        self.publisher_channel = None

    def publish_message(self, body_dict: dict, headers: dict):
        """Publish ``body_dict`` as persistent JSON to the ``ml_server`` queue.

        The persistent publisher connection is reused; after a failed attempt
        it is dropped and re-created on the next attempt.

        Args:
            body_dict: JSON-serializable message body.
            headers: AMQP headers attached to the published message.

        Returns:
            True when the message was published, False after three failed
            attempts (with a two second pause between attempts).
        """
        max_retries = 3
        for attempt in range(max_retries):
            # Set-up, publish and reset happen under one lock acquisition so a
            # failing thread never tears down a connection that another thread
            # has just re-created.
            with self._publisher_lock:
                try:
                    self.setup_publisher()
                    self.publisher_channel.basic_publish(
                        exchange="",
                        routing_key="ml_server",
                        body=json.dumps(body_dict),
                        properties=pika.BasicProperties(
                            delivery_mode=2,  # persistent message
                            headers=headers,
                        ),
                    )
                    return True
                except Exception as e:
                    print(f"Publish attempt {attempt + 1} failed: {e}")
                    self._reset_publisher()

            if attempt == max_retries - 1:
                print(f"Failed to publish after {max_retries} attempts")
                return False
            # Sleep outside the lock so other threads can publish meanwhile.
            time.sleep(2)

    def callback(self, ch, method, properties, body):
        """Handle one incoming RabbitMQ message.

        The delivery is settled exactly once: it is acked only after the
        results were published, and nacked (with requeue) when publishing
        fails, when the request type is not ``process_files``, or when an
        unexpected error occurs before settlement.

        Args:
            ch: Channel the message was delivered on.
            method: Delivery metadata; ``delivery_tag`` is used to settle.
            properties: Message properties; ``headers`` may be None.
            body: Raw message body, parsed as JSON for ``process_files``.
        """
        thread_id = threading.current_thread().name
        headers = properties.headers or {}

        print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

        # Tracks whether this delivery tag has already been acked/nacked, so
        # the error handler below never settles it a second time.
        settled = False
        try:
            if headers.get("request_type") == "process_files":
                contexts = []
                body_dict = json.loads(body)

                for file_info in body_dict.get("input_files", []):
                    # Captured up front so the error path can report the key
                    # without raising again when the entry has no "key".
                    key = None
                    try:
                        key = file_info["key"]
                        contexts.append(
                            {"key": key, "body": self.processor.process(file_info["url"], key)}
                        )
                    except Exception as e:
                        print(f"Error processing file {key}: {e}")
                        contexts.append({"key": key, "body": f"Error: {str(e)}"})

                body_dict["md_context"] = contexts

                # Publish first, then settle. Acking before publishing and
                # nacking afterwards would settle the same tag twice.
                if self.publish_message(body_dict, headers):
                    ch.basic_ack(delivery_tag=method.delivery_tag)
                    settled = True
                    print(f"[Worker {thread_id}] Successfully published results")
                else:
                    ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                    settled = True
                    print(f"[Worker {thread_id}] Failed to publish results")

                print(f"[Worker {thread_id}] Contexts: {contexts}")
            else:
                # NOTE(review): requeueing lets another consumer pick the
                # message up, but it is redelivered indefinitely if nobody
                # handles this request type -- confirm this is intended.
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                settled = True
                print(f"[Worker {thread_id}] Unknown process")

        except Exception as e:
            print(f"Error in callback: {e}")
            if not settled:
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)

    def connect_to_rabbitmq(self):
        """Open a consumer connection and register ``callback`` on ``gpu_server``.

        Returns:
            A ``(connection, channel)`` tuple. The channel has the durable
            ``gpu_server`` queue declared, a prefetch count of one, and
            ``callback`` registered as consumer.
        """
        connection_params = pika.URLParameters(self.rabbit_url)
        # NOTE(review): callback() does its processing inside the consuming
        # thread; presumably a run longer than the heartbeat interval can get
        # the connection dropped by the broker -- confirm against pika docs.
        connection_params.heartbeat = 30
        connection_params.blocked_connection_timeout = 10

        connection = pika.BlockingConnection(connection_params)
        channel = connection.channel()

        channel.queue_declare(queue="gpu_server", durable=True)
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue="gpu_server",
            on_message_callback=self.callback,
        )
        return connection, channel

    def worker(self, channel, connection=None):
        """Consume from ``channel`` until consuming stops, then clean up.

        Args:
            channel: Channel returned by ``connect_to_rabbitmq``.
            connection: Optional connection owning ``channel``; when given it
                is closed together with the channel so it does not leak.
        """
        print("Worker started")
        try:
            channel.start_consuming()
        except Exception as e:
            print(f"Worker stopped: {e}")
        finally:
            # Close the channel first, then its connection. Either may already
            # be closed (e.g. by the broker), so failures are only reported.
            for resource in (channel, connection):
                if resource is None:
                    continue
                try:
                    if resource.is_open:
                        resource.close()
                except Exception as close_error:
                    print(f"Error while closing worker resources: {close_error}")

    def start(self):
        """Start the worker threads and restart them whenever they all stop.

        Runs forever. Each round opens one consumer connection per worker and
        blocks until every worker has returned; a failure while connecting is
        reported and followed by a five second pause before the next round.
        """
        print(f"Starting {self.num_workers} workers")
        while True:
            try:
                # Leaving the ``with`` block waits for all submitted workers.
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    for _ in range(self.num_workers):
                        connection, channel = self.connect_to_rabbitmq()
                        executor.submit(self.worker, channel, connection)
            except Exception as e:
                print(f"Connection lost, reconnecting... Error: {e}")
                time.sleep(5)
|
|
|
def main():
    """Create a RabbitMQWorker with its default settings and run it."""
    RabbitMQWorker().start()