import os
import json
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import pika

from mineru_single import Processor
from topic_extr import TopicExtractionProcessor

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)

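# This worker consumes jobs from the durable "gpu_server" queue, runs each
# file through Processor or TopicExtractionProcessor, and publishes the
# results back to the durable "ml_server" queue.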
class RabbitMQWorker: |
|
    def __init__(self, num_workers: int = 1):
        self.num_workers = num_workers
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        logger.info("Initializing RabbitMQWorker")
        self.processor = Processor()
        self.topic_processor = TopicExtractionProcessor()

        # Publisher connection/channel are created lazily by setup_publisher().
        self.publisher_connection = None
        self.publisher_channel = None

    def setup_publisher(self):
        """(Re)create the publisher connection and channel if missing or closed."""
        if not self.publisher_connection or self.publisher_connection.is_closed:
            logger.info("Setting up publisher connection to RabbitMQ.")
            connection_params = pika.URLParameters(self.rabbit_url)
            connection_params.heartbeat = 1000
            connection_params.blocked_connection_timeout = 500

            self.publisher_connection = pika.BlockingConnection(connection_params)
            self.publisher_channel = self.publisher_connection.channel()
            self.publisher_channel.queue_declare(queue="ml_server", durable=True)
            logger.info("Publisher connection/channel established successfully.")

    def publish_message(self, body_dict: dict, headers: dict) -> bool:
        """Publish over the persistent connection, retrying with a fresh one on failure."""
        max_retries = 3
        for attempt in range(max_retries):
            try:
                self.setup_publisher()
                self.publisher_channel.basic_publish(
                    exchange="",
                    routing_key="ml_server",
                    body=json.dumps(body_dict).encode("utf-8"),
                    properties=pika.BasicProperties(headers=headers)
                )
                logger.info("Published message to ml_server queue (attempt=%d).", attempt + 1)
                return True
            except Exception as e:
                logger.error("Publish attempt %d failed: %s", attempt + 1, e)

                # Tear down the broken connection so the next attempt starts clean.
                if self.publisher_connection and not self.publisher_connection.is_closed:
                    try:
                        self.publisher_connection.close()
                    except Exception:
                        pass
                self.publisher_connection = None
                self.publisher_channel = None

                if attempt == max_retries - 1:
                    logger.error("Failed to publish after %d attempts", max_retries)
                    return False
                time.sleep(2)
        return False

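    # The callback below expects message bodies shaped roughly like this
    # (inferred from the parsing logic; other fields in "data" pass through
    # untouched):
    #
    #   {
    #       "pattern": "process_files" | "topic_extraction",
    #       "data": {"input_files": [{"key": ..., "url": ...}, ...]}
    #   }
    #
    # plus a "request_id" entry in the AMQP headers for "process_files".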
    def callback(self, ch, method, properties, body):
        """Handle incoming RabbitMQ messages."""
        thread_id = threading.current_thread().name
        headers = properties.headers or {}

        logger.info("[Worker %s] Received message: %s, headers: %s", thread_id, body, headers)

        try:
            contexts = []
            body_dict = json.loads(body)
            pattern = body_dict.get("pattern")

            if pattern == "process_files":
                data = body_dict.get("data")
                input_files = data.get("input_files")
                logger.info("[Worker %s] Found %d file(s) to process.", thread_id, len(input_files))

                for file in input_files:
                    try:
                        context = {
                            "key": file["key"],
                            "body": self.processor.process(file["url"], headers["request_id"])
                        }
                        contexts.append(context)
                    except Exception as e:
                        err_str = f"Error processing file {file.get('key', '')}: {e}"
                        logger.error(err_str)
                        contexts.append({"key": file.get("key", ""), "body": err_str})

                data["md_context"] = contexts
                body_dict["pattern"] = "question_extraction_update_from_gpu_server"
                body_dict["data"] = data

                if self.publish_message(body_dict, headers):
                    logger.info("[Worker %s] Successfully published results to ml_server.", thread_id)
                    ch.basic_ack(delivery_tag=method.delivery_tag)
                else:
                    ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                    logger.error("[Worker %s] Failed to publish results.", thread_id)

                logger.info("[Worker %s] Contexts: %s", thread_id, contexts)

            elif pattern == "topic_extraction":
                data = body_dict.get("data")
                input_files = data.get("input_files")
                logger.info("[Worker %s] Found %d file(s) for topic extraction.", thread_id, len(input_files))

                for file in input_files:
                    try:
                        markdown_content = self.topic_processor.process(file)
                        context = {
                            "key": file["key"] + ".md",
                            "body": markdown_content
                        }
                        contexts.append(context)
                    except Exception as e:
                        err_str = f"Error processing file {file.get('key', '')}: {e}"
                        logger.error(err_str)
                        contexts.append({"key": file.get("key", ""), "body": err_str})

                data["md_context"] = contexts
                body_dict["pattern"] = "topic_extraction_update_from_gpu_server"
                body_dict["data"] = data

                if self.publish_message(body_dict, headers):
                    logger.info("[Worker %s] Published topic extraction results to ml_server.", thread_id)
                    ch.basic_ack(delivery_tag=method.delivery_tag)
                else:
                    ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
                    logger.error("[Worker %s] Failed to publish topic results.", thread_id)

                logger.info("[Worker %s] Topic contexts: %s", thread_id, contexts)

            else:
                # basic_ack takes no requeue argument; ack and drop unknown patterns
                # so they are not redelivered forever.
                ch.basic_ack(delivery_tag=method.delivery_tag)
                logger.warning("[Worker %s] Unknown pattern in message body: %s", thread_id, pattern)

        except Exception as e:
            logger.error("Error in callback: %s", e)
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)

    def connect_to_rabbitmq(self):
        """Establish a consumer connection to RabbitMQ with a heartbeat."""
        connection_params = pika.URLParameters(self.rabbit_url)
        connection_params.heartbeat = 1000
        connection_params.blocked_connection_timeout = 500
        logger.info("Connecting to RabbitMQ for consumer with heartbeat=1000.")

        connection = pika.BlockingConnection(connection_params)
        channel = connection.channel()

        # prefetch_count=1 means each worker handles one message at a time.
        channel.queue_declare(queue="gpu_server", durable=True)
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue="gpu_server",
            on_message_callback=self.callback
        )
        logger.info("Consumer connected. Listening on queue='gpu_server'...")
        return connection, channel

    def worker(self, channel):
        """Consume messages on the given channel until it stops or fails."""
        logger.info("Worker thread started. Beginning consuming...")
        try:
            channel.start_consuming()
        except Exception as e:
            logger.error("Worker thread encountered an error: %s", e)
        finally:
            logger.info("Worker thread shutting down. Closing channel.")
            channel.close()

    def start(self):
        """Spin up one consumer connection per worker thread and reconnect on failure."""
        logger.info("Starting %d workers in a ThreadPoolExecutor.", self.num_workers)
        while True:
            try:
                # pika's BlockingConnection is not thread-safe, so each worker
                # thread gets its own connection/channel. The executor's context
                # manager blocks until all worker threads exit, so the loop body
                # only re-runs after the consumers have died.
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    for _ in range(self.num_workers):
                        connection, channel = self.connect_to_rabbitmq()
                        executor.submit(self.worker, channel)
            except Exception as e:
                logger.error("Connection lost, reconnecting... Error: %s", e)
                time.sleep(5)


def main():
    worker = RabbitMQWorker()
    worker.start()


if __name__ == "__main__":
    main()
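
# This module requires RABBITMQ_URL to be set before it runs, e.g. (assumed
# local broker; adjust credentials/host for your deployment):
#
#   export RABBITMQ_URL="amqp://guest:guest@localhost:5672/%2F"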