#!/usr/bin/env python3
"""RabbitMQ worker that consumes file-processing jobs from the "gpu_server"
queue and publishes the results back to the "ml_server" queue."""
import os
import json
import time
import logging
import threading
from concurrent.futures import ThreadPoolExecutor

import pika

from mineru_single import Processor
from topic_extr import TopicExtractionProcessor

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)


class RabbitMQWorker:
    def __init__(self, num_workers: int = 1):
        self.num_workers = num_workers
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        logger.info("Initializing RabbitMQWorker")
        self.processor = Processor()
        self.topic_processor = TopicExtractionProcessor()

        # Persistent connection/channel used only for publishing results.
        self.publisher_connection = None
        self.publisher_channel = None

    def setup_publisher(self):
        """Lazily (re)create the publisher connection and channel."""
        if not self.publisher_connection or self.publisher_connection.is_closed:
            logger.info("Setting up publisher connection to RabbitMQ.")
            connection_params = pika.URLParameters(self.rabbit_url)
            connection_params.heartbeat = 1000  # Match the consumer heartbeat
            connection_params.blocked_connection_timeout = 500  # Generous timeout for blocked connections
            self.publisher_connection = pika.BlockingConnection(connection_params)
            self.publisher_channel = self.publisher_connection.channel()
            self.publisher_channel.queue_declare(queue="ml_server", durable=True)
            logger.info("Publisher connection/channel established successfully.")

    def publish_message(self, body_dict: dict, headers: dict) -> bool:
        """Publish on the persistent connection, retrying on failure."""
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Ensure the publisher connection is set up
                self.setup_publisher()
                self.publisher_channel.basic_publish(
                    exchange="",
                    routing_key="ml_server",
                    body=json.dumps(body_dict).encode("utf-8"),
                    properties=pika.BasicProperties(headers=headers)
                )
                logger.info("Published message to ml_server queue (attempt=%d).", attempt + 1)
                return True
            except Exception as e:
                logger.error("Publish attempt %d failed: %s", attempt + 1, e)
                # Drop the broken connection so the next attempt rebuilds it
                if self.publisher_connection and not self.publisher_connection.is_closed:
                    try:
                        self.publisher_connection.close()
                    except Exception:
                        pass
                self.publisher_connection = None
                self.publisher_channel = None
                if attempt == max_retries - 1:
                    logger.error("Failed to publish after %d attempts", max_retries)
                    return False
                time.sleep(2)
        return False

    def callback(self, ch, method, properties, body):
        """Handle incoming RabbitMQ messages."""
        thread_id = threading.current_thread().name
        headers = properties.headers or {}
        logger.info("[Worker %s] Received message: %s, headers: %s", thread_id, body, headers)

        try:
            contexts = []
            body_dict = json.loads(body)
            pattern = body_dict.get("pattern")

            if pattern == "process_files":
                data = body_dict.get("data")
                input_files = data.get("input_files")
                logger.info("[Worker %s] Found %d file(s) to process.", thread_id, len(input_files))

                for file in input_files:
                    try:
                        context = {
                            "key": file["key"],
                            "body": self.processor.process(file["url"], headers.get("request_id"))
                        }
                        contexts.append(context)
                    except Exception as e:
                        err_str = f"Error processing file {file.get('key', '')}: {e}"
                        logger.error(err_str)
                        contexts.append({"key": file.get("key", ""), "body": err_str})

                data["md_context"] = contexts
                body_dict["pattern"] = "question_extraction_update_from_gpu_server"
                body_dict["data"] = data

                # Publish the results back to the ML server
                if self.publish_message(body_dict, headers):
                    logger.info("[Worker %s] Successfully published results to ml_server.", thread_id)
                    ch.basic_ack(delivery_tag=method.delivery_tag)
                else:
                    ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                    logger.error("[Worker %s] Failed to publish results.", thread_id)

                logger.info("[Worker %s] Contexts: %s", thread_id, contexts)

            elif pattern == "topic_extraction":
                data = body_dict.get("data")
                input_files = data.get("input_files")
                logger.info("[Worker %s] Found %d file(s) for topic extraction.", thread_id, len(input_files))

                for file in input_files:
                    try:
                        # Process the file and get markdown content
                        markdown_content = self.topic_processor.process(file)
                        # Create a context entry keyed by the markdown file name
                        context = {
                            "key": file["key"] + ".md",
                            "body": markdown_content
                        }
                        contexts.append(context)
                    except Exception as e:
                        err_str = f"Error processing file {file.get('key', '')}: {e}"
                        logger.error(err_str)
                        contexts.append({"key": file.get("key", ""), "body": err_str})

                # Add the markdown contexts to the data
                data["md_context"] = contexts
                body_dict["pattern"] = "topic_extraction_update_from_gpu_server"
                body_dict["data"] = data

                # Publish the results back to the ML server
                if self.publish_message(body_dict, headers):
                    logger.info("[Worker %s] Published topic extraction results to ml_server.", thread_id)
                    ch.basic_ack(delivery_tag=method.delivery_tag)
                else:
                    ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                    logger.error("[Worker %s] Failed to publish topic results.", thread_id)

                logger.info("[Worker %s] Topic contexts: %s", thread_id, contexts)

            else:
                # Unknown pattern: ack and drop (basic_ack takes no requeue flag)
                ch.basic_ack(delivery_tag=method.delivery_tag)
                logger.warning("[Worker %s] Unknown pattern in message body: %s", thread_id, pattern)

        except Exception as e:
            logger.error("Error in callback: %s", e)
            # Requeue for retry; note a permanently malformed message will loop here
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)

    def connect_to_rabbitmq(self):
        """Establish a consumer connection to RabbitMQ with a heartbeat."""
        connection_params = pika.URLParameters(self.rabbit_url)
        connection_params.heartbeat = 1000
        connection_params.blocked_connection_timeout = 500

        logger.info("Connecting to RabbitMQ for consumer with heartbeat=1000.")
        connection = pika.BlockingConnection(connection_params)
        channel = connection.channel()
        channel.queue_declare(queue="gpu_server", durable=True)
        # Deliver one message at a time per worker
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue="gpu_server",
            on_message_callback=self.callback
        )
        logger.info("Consumer connected. Listening on queue='gpu_server'...")
        return connection, channel

    def worker(self, channel):
        """Consume messages until the connection drops or an error occurs."""
        logger.info("Worker thread started. Beginning consuming...")
        try:
            channel.start_consuming()
        except Exception as e:
            logger.error("Worker thread encountered an error: %s", e)
        finally:
            logger.info("Worker thread shutting down. Closing channel.")
            channel.close()

    def start(self):
        """Start the worker threads, reconnecting on failure."""
        logger.info("Starting %d worker(s) in a ThreadPoolExecutor.", self.num_workers)
        while True:
            try:
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    # Each worker gets its own connection/channel; the with-block
                    # waits for all workers to finish before reconnecting.
                    for _ in range(self.num_workers):
                        connection, channel = self.connect_to_rabbitmq()
                        executor.submit(self.worker, channel)
            except Exception as e:
                logger.error("Connection lost, reconnecting... Error: %s", e)
                time.sleep(5)  # Wait before reconnecting


def main():
    worker = RabbitMQWorker()
    worker.start()


if __name__ == "__main__":
    main()
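
# Example message shapes, inferred from the callback above. The field values
# (keys, URLs, request_id) are illustrative only; the structure is what the
# callback actually reads. A "process_files" job published to the "gpu_server"
# queue would look like:
#
#   body = {
#       "pattern": "process_files",
#       "data": {
#           "input_files": [
#               {"key": "doc-001", "url": "https://example.com/doc-001.pdf"}
#           ]
#       }
#   }
#   properties = pika.BasicProperties(headers={"request_id": "req-123"})
#
# A "topic_extraction" job uses the same envelope with
# "pattern": "topic_extraction"; each input-file dict is passed whole to
# TopicExtractionProcessor.process(). In both cases the results come back on
# "ml_server" under data["md_context"], with the pattern rewritten to the
# corresponding "*_update_from_gpu_server" value.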