#!/usr/bin/env python3
import os
import json
import time
import threading
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import pika
from typing import Tuple, Dict, Any

from mineru_single import Processor


class RabbitMQWorker:
    """Consume jobs from the ``gpu_server`` queue and publish results to ``ml_server``.

    Each worker thread gets its own consumer connection/channel (see
    ``connect_to_rabbitmq``).  Results are published through a single
    publisher connection stored on the instance.

    NOTE(review): the publisher connection is shared by every worker thread and
    pika's BlockingConnection is not documented as thread-safe; this is only
    known to be safe with the default ``num_workers=1`` -- confirm before
    raising the worker count.
    """

    def __init__(self, num_workers: int = 1):
        """Store configuration and create the file processor.

        Args:
            num_workers: Number of consumer threads started by ``start``.
        """
        self.num_workers = num_workers
        # Read once at construction; None when RABBITMQ_URL is not set.
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        self.processor = Processor()
        # Created lazily by setup_publisher().
        self.publisher_connection = None
        self.publisher_channel = None

    def setup_publisher(self):
        """Open the publisher connection/channel if none is currently open."""
        if not self.publisher_connection or self.publisher_connection.is_closed:
            connection_params = pika.URLParameters(self.rabbit_url)
            # Long heartbeat for the publisher (the consumer connection uses 30s).
            connection_params.heartbeat = 600
            connection_params.blocked_connection_timeout = 300  # Increase timeout
            self.publisher_connection = pika.BlockingConnection(connection_params)
            self.publisher_channel = self.publisher_connection.channel()
            self.publisher_channel.queue_declare(queue="ml_server", durable=True)

    def publish_message(self, body_dict: dict, headers: dict):
        """Publish ``body_dict`` as JSON to ``ml_server`` on the persistent connection.

        Makes up to three attempts, discarding the publisher connection and
        sleeping two seconds between failed attempts.

        Args:
            body_dict: JSON-serialisable message body.
            headers: Headers attached to the outgoing message.

        Returns:
            True when the message was published, False when every attempt failed.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Ensure publisher connection is setup
                self.setup_publisher()
                self.publisher_channel.basic_publish(
                    exchange="",
                    routing_key="ml_server",
                    body=json.dumps(body_dict),
                    properties=pika.BasicProperties(
                        delivery_mode=2,
                        headers=headers
                    )
                )
                return True
            except Exception as e:
                print(f"Publish attempt {attempt + 1} failed: {e}")
                # Close failed connection
                if self.publisher_connection and not self.publisher_connection.is_closed:
                    try:
                        self.publisher_connection.close()
                    except Exception:
                        # Best effort: the connection is being discarded anyway.
                        pass
                self.publisher_connection = None
                self.publisher_channel = None

                if attempt == max_retries - 1:
                    print(f"Failed to publish after {max_retries} attempts")
                    return False
                time.sleep(2)
        return False

    def _handle_message(self, thread_id, headers, body):
        """Process one message and publish its result; return True on success.

        Does not ack or nack: the caller settles the delivery exactly once
        based on the returned flag.
        """
        if headers.get("request_type") != "process_files":
            print(f"[Worker {thread_id}] Unknown process")
            return False

        body_dict = json.loads(body)
        contexts = []

        # Process files
        for file in body_dict.get("input_files", []):
            try:
                context = {"key": file["key"], "body": self.processor.process(file["url"], file["key"])}
                contexts.append(context)
            except Exception as e:
                # A failing file is reported in-band instead of failing the whole job.
                print(f"Error processing file {file['key']}: {e}")
                contexts.append({"key": file["key"], "body": f"Error: {str(e)}"})
        body_dict["md_context"] = contexts

        # Publish results
        published = self.publish_message(body_dict, headers)
        if published:
            print(f"[Worker {thread_id}] Successfully published results")
        else:
            print(f"[Worker {thread_id}] Failed to publish results")
        print(f"[Worker {thread_id}] Contexts: {contexts}")
        return published

    def callback(self, ch, method, properties, body):
        """Handle incoming RabbitMQ messages.

        The delivery is settled exactly once: acked after the results were
        published, otherwise nacked with ``requeue=True``.  Acking and then
        nacking the same delivery tag is a protocol error that closes the
        channel, so the ack must not happen before the publish outcome is known.

        NOTE(review): unknown request types and messages that raise during
        handling (e.g. invalid JSON) are requeued, as before; with a single
        consumer they will be redelivered indefinitely -- confirm this is
        intended.
        """
        thread_id = threading.current_thread().name
        headers = properties.headers or {}

        print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

        try:
            handled = self._handle_message(thread_id, headers, body)
        except Exception as e:
            print(f"Error in callback: {e}")
            handled = False

        if handled:
            ch.basic_ack(delivery_tag=method.delivery_tag)
        else:
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)

    def connect_to_rabbitmq(self):
        """Establish connection to RabbitMQ with heartbeat.

        Declares the durable ``gpu_server`` queue, limits prefetch to one
        message and registers ``callback`` as the consumer.

        Returns:
            A ``(connection, channel)`` tuple.
        """
        connection_params = pika.URLParameters(self.rabbit_url)
        connection_params.heartbeat = 30
        connection_params.blocked_connection_timeout = 10

        connection = pika.BlockingConnection(connection_params)
        channel = connection.channel()

        channel.queue_declare(queue="gpu_server", durable=True)
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue="gpu_server",
            on_message_callback=self.callback
        )
        return connection, channel

    def worker(self, channel):
        """Worker function: consume on ``channel`` until it stops, then clean up.

        Closes the channel and its owning connection on exit.  Closing is
        guarded and best-effort because both may already be closed when the
        worker stops due to a connection failure.
        """
        print(f"Worker started")
        try:
            channel.start_consuming()
        except Exception as e:
            print(f"Worker stopped: {e}")
        finally:
            # The channel's owning connection is not passed in, so fetch it
            # from the channel to avoid leaking it on every reconnect cycle.
            connection = getattr(channel, "connection", None)
            for resource in (channel, connection):
                if resource is None:
                    continue
                try:
                    if resource.is_open:
                        resource.close()
                except Exception as e:
                    print(f"Error while closing worker resources: {e}")

    def start(self):
        """Start the worker threads and restart them whenever they all exit.

        Runs forever.  Each cycle opens one consumer connection per worker,
        waits for every worker to finish, then pauses before reconnecting.
        """
        print(f"Starting {self.num_workers} workers")
        while True:
            try:
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    for _ in range(self.num_workers):
                        _connection, channel = self.connect_to_rabbitmq()
                        executor.submit(self.worker, channel)
            except Exception as e:
                print(f"Connection lost, reconnecting... Error: {e}")
            # Back off on every cycle: workers swallow their own errors, so the
            # pool can also drain without an exception reaching this loop.
            time.sleep(5)  # Wait before reconnecting


def main():
    """Create a worker with the default settings and run it."""
    worker = RabbitMQWorker()
    worker.start()