File size: 5,869 Bytes
73d131e 0aedcf3 1fecd54 73d131e e4b41bd c52e28a 0aedcf3 c52e28a e4b41bd 0aedcf3 e4b41bd eb15d7d e4b41bd eb15d7d 73d131e eb15d7d 0aedcf3 eb15d7d 0aedcf3 eb15d7d 0aedcf3 eb15d7d 0aedcf3 eb15d7d 73d131e eb15d7d 0aedcf3 73d131e 0aedcf3 73d131e 0aedcf3 73d131e 0aedcf3 73d131e eb15d7d 73d131e 1e9bbdd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
#!/usr/bin/env python3
import os
import json
import time
import threading
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import pika
from typing import Tuple, Dict, Any
from mineru_single import Processor
class RabbitMQWorker:
    """Consume ``process_files`` requests and publish the processed results.

    Requests are read from the ``gpu_server`` queue. Each request's
    ``input_files`` are run through ``Processor.process``, and the request
    body, extended with an ``md_context`` list, is published to the
    ``ml_server`` queue.

    Threading model:
    - Each of ``num_workers`` threads creates and owns its consumer
      connection and, lazily, its publisher connection. A pika connection
      may only be used from the thread that created it.
    - File processing runs on a short-lived background thread, so the
      consumer connection keeps servicing heartbeats.
    - Publishing and acking are scheduled back onto the consumer's thread
      with ``add_callback_threadsafe``.

    NOTE(review): if a consumer connection drops while a file is still
    being processed, the broker redelivers the message and it may be
    processed again concurrently with the still-running background thread.
    """

    CONSUME_QUEUE = "gpu_server"
    PUBLISH_QUEUE = "ml_server"
    RECONNECT_DELAY = 5  # seconds to wait before reconnecting a consumer

    def __init__(self, num_workers: int = 1):
        self.num_workers = num_workers
        # NOTE(review): None when RABBITMQ_URL is unset, in which case every
        # connection attempt presumably fails; confirm whether a default is wanted.
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        self.processor = Processor()
        # Thread-local storage behind publisher_connection / publisher_channel,
        # so each worker thread gets its own publisher connection.
        self._publisher_local = threading.local()
        self.publisher_connection = None
        self.publisher_channel = None

    @property
    def publisher_connection(self):
        """The calling thread's publisher connection, or ``None``."""
        return getattr(self._publisher_local, "connection", None)

    @publisher_connection.setter
    def publisher_connection(self, value):
        self._publisher_local.connection = value

    @property
    def publisher_channel(self):
        """The calling thread's publisher channel, or ``None``."""
        return getattr(self._publisher_local, "channel", None)

    @publisher_channel.setter
    def publisher_channel(self, value):
        self._publisher_local.channel = value

    @staticmethod
    def _close_quietly(resource, description: str):
        """Close a pika connection/channel unless already closed; log errors."""
        if resource is not None and not resource.is_closed:
            try:
                resource.close()
            except Exception as e:
                print(f"Error closing {description}: {e}")

    def _close_publisher(self):
        """Close and forget the calling thread's publisher connection."""
        self._close_quietly(self.publisher_connection, "publisher connection")
        self.publisher_connection = None
        self.publisher_channel = None

    def setup_publisher(self):
        """Open the calling thread's publisher connection if missing or closed."""
        if not self.publisher_connection or self.publisher_connection.is_closed:
            connection_params = pika.URLParameters(self.rabbit_url)
            # Deliberately longer than the consumer's heartbeat: this
            # connection only does I/O while publishing. If the broker drops
            # it while idle, publish_message's retry reconnects.
            connection_params.heartbeat = 600
            connection_params.blocked_connection_timeout = 300
            connection = pika.BlockingConnection(connection_params)
            channel = connection.channel()
            channel.queue_declare(queue=self.PUBLISH_QUEUE, durable=True)
            # Publisher confirms: basic_publish raises unless the broker
            # accepted the message, so "published" really means published
            # before the request is acked.
            channel.confirm_delivery()
            self.publisher_connection = connection
            self.publisher_channel = channel

    def publish_message(self, body_dict: dict, headers: dict) -> bool:
        """Publish ``body_dict`` as a persistent JSON message to the result queue.

        Up to three attempts are made. After each failure this thread's
        publisher connection is torn down and rebuilt.

        Returns:
            True once the broker has confirmed the message, False after
            every attempt failed.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                self.setup_publisher()
                self.publisher_channel.basic_publish(
                    exchange="",
                    routing_key=self.PUBLISH_QUEUE,
                    body=json.dumps(body_dict),
                    properties=pika.BasicProperties(
                        delivery_mode=2,  # persistent
                        headers=headers,
                    ),
                )
                return True
            except Exception as e:
                print(f"Publish attempt {attempt + 1} failed: {e}")
                self._close_publisher()
                if attempt == max_retries - 1:
                    print(f"Failed to publish after {max_retries} attempts")
                    return False
                time.sleep(2)
        return False

    @staticmethod
    def _settle(ch, delivery_tag, ack: bool, requeue: bool = True):
        """Ack or nack ``delivery_tag`` on ``ch``, logging instead of raising.

        If the channel is already closed, the broker requeues the unacked
        message by itself.
        """
        try:
            if ack:
                ch.basic_ack(delivery_tag=delivery_tag)
            else:
                ch.basic_nack(delivery_tag=delivery_tag, requeue=requeue)
        except Exception as e:
            print(f"Failed to settle delivery {delivery_tag}: {e}")

    def callback(self, ch, method, properties, body):
        """Handle one delivery; pika invokes this on the consumer's thread.

        Outcomes by message:
        - ``request_type == "process_files"`` with a JSON-object body: the
          files are processed on a background thread. The message is acked
          only after the results are published, and is nacked with requeue
          if publishing fails.
        - ``process_files`` body that is not a JSON object: nacked without
          requeue, because redelivering it can never succeed.
        - Any other ``request_type``: nacked with requeue.
        """
        thread_id = threading.current_thread().name
        delivery_tag = method.delivery_tag
        headers = properties.headers or {}
        print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

        if headers.get("request_type") != "process_files":
            # NOTE(review): if this is the queue's only consumer, the requeued
            # message comes straight back; consider requeue=False plus a DLQ.
            self._settle(ch, delivery_tag, ack=False, requeue=True)
            print(f"[Worker {thread_id}] Unknown process")
            return

        try:
            body_dict = json.loads(body)
            if not isinstance(body_dict, dict):
                raise ValueError("message body must be a JSON object")
        except ValueError as e:  # includes JSONDecodeError / UnicodeDecodeError
            print(f"Error in callback: {e}")
            self._settle(ch, delivery_tag, ack=False, requeue=False)
            return

        try:
            # Processing can be long; doing it here would block this
            # connection's I/O loop and starve its heartbeats.
            threading.Thread(
                target=self._process_in_background,
                args=(ch, delivery_tag, body_dict, headers, thread_id),
                name=f"{thread_id}-process",
                daemon=True,
            ).start()
        except Exception as e:
            print(f"Error in callback: {e}")
            self._settle(ch, delivery_tag, ack=False, requeue=True)

    def _process_files(self, body_dict: Dict[str, Any]) -> list:
        """Run ``self.processor.process`` on every entry of ``input_files``.

        Returns one ``{"key": ..., "body": ...}`` dict per input file, in
        order. A file that raises gets ``"Error: <message>"`` as its body
        instead of aborting the batch.
        """
        contexts = []
        for file in body_dict.get("input_files", []):
            try:
                result = self.processor.process(file["url"], file["key"])
            except Exception as e:
                print(f"Error processing file {file['key']}: {e}")
                result = f"Error: {str(e)}"
            contexts.append({"key": file["key"], "body": result})
        return contexts

    def _process_in_background(self, ch, delivery_tag, body_dict, headers, thread_id):
        """Process one request off the consumer thread, then hand the
        publish/ack step back to that thread.

        The only call made on the consumer connection from here is
        ``add_callback_threadsafe``; all channel operations run inside
        ``finish`` on the connection's own thread.
        """
        try:
            body_dict["md_context"] = self._process_files(body_dict)
            succeeded = True
        except Exception as e:
            print(f"Error in callback: {e}")
            succeeded = False

        def finish():
            if succeeded:
                self._publish_and_ack(ch, delivery_tag, body_dict, headers, thread_id)
            else:
                self._settle(ch, delivery_tag, ack=False, requeue=True)

        try:
            ch.connection.add_callback_threadsafe(finish)
        except Exception as e:
            # The consumer connection is already gone; the broker requeues
            # the unacked message, so it will be delivered again.
            print(f"[Worker {thread_id}] Could not schedule ack for delivery {delivery_tag}: {e}")

    def _publish_and_ack(self, ch, delivery_tag, body_dict, headers, thread_id):
        """Publish the results, then ack; nack with requeue if publishing failed.

        Must run on the consumer connection's thread. The ack comes after
        the publish, so a failed publish leads to redelivery instead of a
        lost result.
        """
        if self.publish_message(body_dict, headers):
            self._settle(ch, delivery_tag, ack=True)
            print(f"[Worker {thread_id}] Successfully published results")
        else:
            self._settle(ch, delivery_tag, ack=False, requeue=True)
            print(f"[Worker {thread_id}] Failed to publish results")
        print(f"[Worker {thread_id}] Contexts: {body_dict.get('md_context')}")

    def connect_to_rabbitmq(self) -> Tuple[Any, Any]:
        """Open a consumer connection and register ``callback`` on the queue.

        Call this from the thread that will run the consume loop.

        Returns:
            ``(connection, channel)``.
        """
        connection_params = pika.URLParameters(self.rabbit_url)
        # A short heartbeat is fine here: file processing runs off this
        # connection's thread, so its I/O loop stays responsive.
        connection_params.heartbeat = 30
        connection_params.blocked_connection_timeout = 10
        connection = pika.BlockingConnection(connection_params)
        try:
            channel = connection.channel()
            channel.queue_declare(queue=self.CONSUME_QUEUE, durable=True)
            # At most one unacked message per consumer.
            channel.basic_qos(prefetch_count=1)
            channel.basic_consume(
                queue=self.CONSUME_QUEUE,
                on_message_callback=self.callback,
            )
        except Exception:
            self._close_quietly(connection, "consumer connection")
            raise
        return connection, channel

    def worker(self, channel):
        """Run ``channel``'s consume loop until it stops, then close the channel."""
        print("Worker started")
        try:
            channel.start_consuming()
        except Exception as e:
            print(f"Worker stopped: {e}")
        finally:
            self._close_quietly(channel, "channel")

    def _consume_forever(self):
        """Connect, consume and reconnect forever on the calling thread.

        The connection is created here, on the thread that uses it, and
        both it and this thread's publisher are closed before reconnecting.
        """
        while True:
            connection = None
            try:
                connection, channel = self.connect_to_rabbitmq()
                self.worker(channel)
            except Exception as e:
                print(f"Connection lost, reconnecting... Error: {e}")
            finally:
                self._close_quietly(connection, "consumer connection")
                self._close_publisher()
            time.sleep(self.RECONNECT_DELAY)  # Wait before reconnecting

    def start(self):
        """Run ``num_workers`` self-reconnecting consumer threads; never returns.

        If a worker thread dies from an unexpected error, the error is
        logged and the pool is rebuilt once the remaining workers have also
        stopped.
        """
        print(f"Starting {self.num_workers} workers")
        while True:
            try:
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    futures = [
                        executor.submit(self._consume_forever)
                        for _ in range(self.num_workers)
                    ]
                    for future in futures:
                        # _consume_forever never returns normally; this only
                        # surfaces an unexpected exception.
                        future.result()
            except Exception as e:
                print(f"Connection lost, reconnecting... Error: {e}")
            time.sleep(self.RECONNECT_DELAY)
def main():
    """Run a RabbitMQWorker with its default single consumer thread.

    ``RabbitMQWorker.start`` loops forever, so this does not return
    under normal operation.
    """
    worker = RabbitMQWorker()
    worker.start()


if __name__ == "__main__":
    # The file has a python3 shebang; without this guard, executing it
    # directly would only define names and exit.
    main()