File size: 5,869 Bytes
73d131e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aedcf3
1fecd54
73d131e
e4b41bd
 
 
 
 
c52e28a
0aedcf3
c52e28a
e4b41bd
0aedcf3
e4b41bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb15d7d
e4b41bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb15d7d
73d131e
 
 
 
 
 
 
 
eb15d7d
 
 
 
 
 
 
0aedcf3
eb15d7d
 
 
 
 
 
0aedcf3
eb15d7d
 
 
 
 
0aedcf3
eb15d7d
 
 
 
0aedcf3
eb15d7d
 
73d131e
eb15d7d
0aedcf3
73d131e
 
0aedcf3
 
 
 
 
 
73d131e
 
 
 
 
 
0aedcf3
73d131e
 
 
0aedcf3
 
 
 
 
 
 
 
 
 
73d131e
 
 
eb15d7d
 
 
 
 
 
 
 
 
73d131e
 
 
1e9bbdd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3
import os
import json
import time
import threading
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import pika
from typing import Tuple, Dict, Any

from mineru_single import Processor


class RabbitMQWorker:
    """Consume file-processing requests from RabbitMQ and publish the results.

    Requests are consumed from the ``gpu_server`` queue. Messages whose
    ``request_type`` header equals ``process_files`` have each entry of
    ``input_files`` run through ``self.processor``; the results are stored
    under ``md_context`` and the message is republished to ``ml_server``.
    """

    def __init__(self, num_workers: int = 1):
        """Create the worker.

        Args:
            num_workers: Number of consumer threads started by ``start``.
        """
        self.num_workers = num_workers
        self.rabbit_url = os.getenv("RABBITMQ_URL")
        self.processor = Processor()

        # Created lazily (and recreated after failures) by setup_publisher().
        self.publisher_connection = None
        self.publisher_channel = None
        # Every consumer thread publishes through the same connection and
        # pika connections are not thread-safe, so publishing is serialized.
        self._publisher_lock = threading.Lock()

    def setup_publisher(self):
        """Open the publisher connection/channel if missing or closed."""
        if not self.publisher_connection or self.publisher_connection.is_closed:
            connection_params = pika.URLParameters(self.rabbit_url)
            # NOTE(review): the consumer connection uses heartbeat=30, not
            # 600 -- confirm which value is intended for both sides.
            connection_params.heartbeat = 600
            connection_params.blocked_connection_timeout = 300

            self.publisher_connection = pika.BlockingConnection(connection_params)
            self.publisher_channel = self.publisher_connection.channel()
            self.publisher_channel.queue_declare(queue="ml_server", durable=True)

    def _reset_publisher(self):
        """Drop the publisher connection so the next publish reconnects."""
        connection = self.publisher_connection
        self.publisher_connection = None
        self.publisher_channel = None
        if connection is not None and not connection.is_closed:
            try:
                connection.close()
            except Exception as e:
                # Best effort: the connection is being discarded anyway.
                print(f"Error closing publisher connection: {e}")

    def publish_message(self, body_dict: dict, headers: dict):
        """Publish ``body_dict`` as JSON to the ``ml_server`` queue.

        Up to three attempts are made, reconnecting after each failure and
        waiting two seconds between attempts.

        Args:
            body_dict: JSON-serializable message body.
            headers: Headers attached to the published message.

        Returns:
            True if the message was published, False if every attempt failed.
        """
        max_retries = 3
        for attempt in range(max_retries):
            # The lock is held per attempt so that the retry sleep below does
            # not block other threads from publishing.
            with self._publisher_lock:
                try:
                    self.setup_publisher()
                    self.publisher_channel.basic_publish(
                        exchange="",
                        routing_key="ml_server",
                        body=json.dumps(body_dict),
                        properties=pika.BasicProperties(
                            delivery_mode=2,  # persistent message
                            headers=headers,
                        ),
                    )
                    return True
                except Exception as e:
                    print(f"Publish attempt {attempt + 1} failed: {e}")
                    self._reset_publisher()

            if attempt == max_retries - 1:
                print(f"Failed to publish after {max_retries} attempts")
                return False
            time.sleep(2)
        return False

    def _process_files(self, input_files):
        """Run every entry of ``input_files`` through the processor.

        A failure on one file is recorded as an ``"Error: ..."`` body for that
        file instead of aborting the whole request.
        """
        contexts = []
        for file_info in input_files:
            # Read the key defensively so a malformed entry is reported as a
            # per-file error rather than raising again in the handler.
            key = file_info.get("key") if isinstance(file_info, dict) else None
            try:
                result = self.processor.process(file_info["url"], file_info["key"])
            except Exception as e:
                print(f"Error processing file {key}: {e}")
                result = f"Error: {str(e)}"
            contexts.append({"key": key, "body": result})
        return contexts

    def callback(self, ch, method, properties, body):
        """Handle one incoming RabbitMQ message.

        The delivery is settled exactly once: acked only after the results
        were published, otherwise nacked with ``requeue=True``.

        Args:
            ch: Channel the message was delivered on.
            method: Delivery metadata; ``delivery_tag`` is used to ack/nack.
            properties: Message properties; ``headers`` selects the request type.
            body: JSON-encoded request body.
        """
        thread_id = threading.current_thread().name
        headers = properties.headers or {}

        print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

        published = False
        try:
            if headers.get("request_type") == "process_files":
                body_dict = json.loads(body)
                contexts = self._process_files(body_dict.get("input_files", []))
                body_dict["md_context"] = contexts

                published = bool(self.publish_message(body_dict, headers))
                if published:
                    print(f"[Worker {thread_id}] Successfully published results")
                else:
                    print(f"[Worker {thread_id}] Failed to publish results")

                print(f"[Worker {thread_id}] Contexts: {contexts}")
            else:
                print(f"[Worker {thread_id}] Unknown process")
        except Exception as e:
            print(f"Error in callback: {e}")

        # Acking before publishing and nacking afterwards would settle the
        # same delivery tag twice, which is a channel-level protocol error.
        if published:
            ch.basic_ack(delivery_tag=method.delivery_tag)
        else:
            # NOTE(review): requeue=True redelivers unknown or malformed
            # messages indefinitely unless another consumer handles them.
            ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)

    def connect_to_rabbitmq(self):
        """Open a consumer connection subscribed to the ``gpu_server`` queue.

        Returns:
            A ``(connection, channel)`` tuple; the channel has ``callback``
            registered and a prefetch count of 1.
        """
        connection_params = pika.URLParameters(self.rabbit_url)
        # NOTE(review): processing runs inside the consumer callback; confirm
        # a 30 s heartbeat is long enough for the slowest request.
        connection_params.heartbeat = 30
        connection_params.blocked_connection_timeout = 10

        connection = pika.BlockingConnection(connection_params)
        try:
            channel = connection.channel()
            channel.queue_declare(queue="gpu_server", durable=True)
            channel.basic_qos(prefetch_count=1)
            channel.basic_consume(
                queue="gpu_server",
                on_message_callback=self.callback,
            )
        except Exception:
            # Do not leak the socket when channel setup fails.
            if not connection.is_closed:
                connection.close()
            raise
        return connection, channel

    def worker(self, channel, connection=None):
        """Consume messages on ``channel`` until consuming stops.

        Args:
            channel: Channel returned by ``connect_to_rabbitmq``.
            connection: Optional connection owning ``channel``; when given it
                is closed together with the channel once consuming ends.
        """
        print("Worker started")
        try:
            channel.start_consuming()
        except Exception as e:
            print(f"Worker stopped: {e}")
        finally:
            # Close the channel first, then its connection; skip anything
            # already closed so cleanup itself cannot raise.
            for resource in (channel, connection):
                if resource is None or resource.is_closed:
                    continue
                try:
                    resource.close()
                except Exception as e:
                    print(f"Error during worker cleanup: {e}")

    def start(self):
        """Start the consumer threads and restart them whenever they stop.

        Runs forever: each iteration opens one connection per worker, waits
        for all workers to finish, and then reconnects. A failure while
        connecting is followed by a five second pause.
        """
        print(f"Starting {self.num_workers} workers")
        while True:
            try:
                # Leaving the ``with`` block waits for every submitted worker.
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    for _ in range(self.num_workers):
                        connection, channel = self.connect_to_rabbitmq()
                        executor.submit(self.worker, channel, connection)
            except Exception as e:
                print(f"Connection lost, reconnecting... Error: {e}")
                time.sleep(5)  # Wait before reconnecting

def main():
    """Create a ``RabbitMQWorker`` with default arguments and start it."""
    RabbitMQWorker().start()