#!/usr/bin/env python3
import json
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import pika

from mineru_single import Processor
from topic_extr import TopicExtractionProcessor

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)


class RabbitMQWorker:
def __init__(self, num_workers: int = 1):
self.num_workers = num_workers
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        if not self.rabbit_url:
            raise ValueError("RABBITMQ_URL environment variable is not set")
logger.info("Initializing RabbitMQWorker")
self.processor = Processor()
self.topic_processor = TopicExtractionProcessor()
self.publisher_connection = None
        self.publisher_channel = None

def setup_publisher(self):
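        """(Re)create the publisher connection and channel if they are missing or closed."""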
if not self.publisher_connection or self.publisher_connection.is_closed:
logger.info("Setting up publisher connection to RabbitMQ.")
connection_params = pika.URLParameters(self.rabbit_url)
            connection_params.heartbeat = 1000  # match the consumer connection's heartbeat
            connection_params.blocked_connection_timeout = 500  # tolerate long blocking operations
self.publisher_connection = pika.BlockingConnection(connection_params)
self.publisher_channel = self.publisher_connection.channel()
self.publisher_channel.queue_declare(queue="ml_server", durable=True)
logger.info("Publisher connection/channel established successfully.")
def publish_message(self, body_dict: dict, headers: dict):
"""Use persistent connection for publishing"""
max_retries = 3
for attempt in range(max_retries):
try:
# Ensure publisher connection is setup
self.setup_publisher()
self.publisher_channel.basic_publish(
exchange="",
routing_key="ml_server",
body=json.dumps(body_dict).encode('utf-8'),
properties=pika.BasicProperties(
headers=headers
)
)
logger.info("Published message to ml_server queue (attempt=%d).", attempt + 1)
return True
except Exception as e:
logger.error("Publish attempt %d failed: %s", attempt + 1, e)
# Close failed connection
if self.publisher_connection and not self.publisher_connection.is_closed:
try:
self.publisher_connection.close()
                    except Exception:
pass
self.publisher_connection = None
self.publisher_channel = None
if attempt == max_retries - 1:
logger.error("Failed to publish after %d attempts", max_retries)
return False
                time.sleep(2)

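    # Expected inbound message shape, inferred from the branches in callback
    # below (an assumption based on how the fields are read, not a documented
    # contract):
    #
    #   {
    #     "pattern": "process_files" | "topic_extraction",
    #     "data": {"input_files": [{"key": "...", "url": "..."}, ...]}
    #   }
    #
    # plus an optional "request_id" entry in the AMQP message headers.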
def callback(self, ch, method, properties, body):
"""Handle incoming RabbitMQ messages"""
thread_id = threading.current_thread().name
headers = properties.headers or {}
logger.info("[Worker %s] Received message: %s, headers: %s", thread_id, body, headers)
try:
contexts = []
body_dict = json.loads(body)
pattern = body_dict.get("pattern")
if pattern == "process_files":
data = body_dict.get("data")
input_files = data.get("input_files")
logger.info("[Worker %s] Found %d file(s) to process.", thread_id, len(input_files))
                for file in input_files:
                    try:
                        context = {
                            "key": file["key"],
                            "body": self.processor.process(file["url"], headers.get("request_id"))
                        }
                        contexts.append(context)
                    except Exception as e:
                        err_str = f"Error processing file {file.get('key', '')}: {e}"
                        logger.error(err_str)
                        contexts.append({"key": file.get("key", ""), "body": err_str})
data["md_context"] = contexts
# topics = data.get("topics", [])
body_dict["pattern"] = "question_extraction_update_from_gpu_server"
body_dict["data"] = data
# Publish results
if self.publish_message(body_dict, headers):
logger.info("[Worker %s] Successfully published results to ml_server.", thread_id)
ch.basic_ack(delivery_tag=method.delivery_tag)
else:
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
logger.error("[Worker %s] Failed to publish results.", thread_id)
logger.info("[Worker %s] Contexts: %s", thread_id, contexts)
elif pattern == "topic_extraction":
data = body_dict.get("data")
input_files = data.get("input_files")
logger.info("[Worker %s] Found %d file(s) for topic extraction.", thread_id, len(input_files))
for file in input_files:
try:
# Process the file and get markdown content
markdown_content = self.topic_processor.process(file)
# Create context with the markdown content
                        context = {
                            "key": file["key"] + ".md",
                            "body": markdown_content
                        }
contexts.append(context)
except Exception as e:
err_str = f"Error processing file {file.get('key', '')}: {e}"
logger.error(err_str)
contexts.append({"key": file.get("key", ""), "body": err_str})
# Add the markdown contexts to the data
data["md_context"] = contexts
body_dict["pattern"] = "topic_extraction_update_from_gpu_server"
body_dict["data"] = data
# Publish the results back to the ML server
if self.publish_message(body_dict, headers):
logger.info("[Worker %s] Published topic extraction results to ml_server.", thread_id)
ch.basic_ack(delivery_tag=method.delivery_tag)
else:
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
logger.error("[Worker %s] Failed to publish topic results.", thread_id)
logger.info("[Worker %s] Topic contexts: %s", thread_id, contexts)
else:
                logger.warning("[Worker %s] Unknown message pattern: %s", thread_id, pattern)
                ch.basic_ack(delivery_tag=method.delivery_tag)
        except Exception as e:
            logger.error("Error in callback: %s", e)
            # Requeue on first failure only; repeated failures are dropped to
            # avoid an endless poison-message redelivery loop.
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=not method.redelivered)

def connect_to_rabbitmq(self):
"""Establish connection to RabbitMQ with heartbeat"""
connection_params = pika.URLParameters(self.rabbit_url)
connection_params.heartbeat = 1000
connection_params.blocked_connection_timeout = 500
logger.info("Connecting to RabbitMQ for consumer with heartbeat=1000.")
connection = pika.BlockingConnection(connection_params)
channel = connection.channel()
channel.queue_declare(queue="gpu_server", durable=True)
        channel.basic_qos(prefetch_count=1)  # hand each worker one message at a time
channel.basic_consume(
queue="gpu_server",
on_message_callback=self.callback
)
logger.info("Consumer connected. Listening on queue='gpu_server'...")
        return connection, channel

def worker(self, channel):
"""Worker function"""
logger.info("Worker thread started. Beginning consuming...")
try:
channel.start_consuming()
except Exception as e:
logger.error("Worker thread encountered an error: %s", e)
finally:
logger.info("Worker thread shutting down. Closing channel.")
            channel.close()

def start(self):
"""Start the worker threads"""
logger.info("Starting %d workers in a ThreadPoolExecutor.", self.num_workers)
while True:
try:
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    for _ in range(self.num_workers):
                        # Each worker thread gets its own connection and channel;
                        # pika's BlockingConnection is not thread-safe.
                        connection, channel = self.connect_to_rabbitmq()
                        executor.submit(self.worker, channel)
except Exception as e:
logger.error("Connection lost, reconnecting... Error: %s", e)
                time.sleep(5)  # wait before reconnecting


def main():
worker = RabbitMQWorker()
    worker.start()


if __name__ == "__main__":
    main()
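
# Example (a sketch for local testing only, not part of the worker): publishing
# a sample "process_files" job to the gpu_server queue. The key, url, and
# request_id values below are placeholder assumptions; RABBITMQ_URL must point
# at the same broker the worker consumes from.
#
#   import json, os, pika
#
#   conn = pika.BlockingConnection(pika.URLParameters(os.environ["RABBITMQ_URL"]))
#   ch = conn.channel()
#   ch.queue_declare(queue="gpu_server", durable=True)
#   ch.basic_publish(
#       exchange="",
#       routing_key="gpu_server",
#       body=json.dumps({
#           "pattern": "process_files",
#           "data": {"input_files": [{"key": "doc-1", "url": "https://example.com/doc-1.pdf"}]},
#       }).encode("utf-8"),
#       properties=pika.BasicProperties(headers={"request_id": "local-test"}),
#   )
#   conn.close()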