princhman commited on
Commit
eb15d7d
·
1 Parent(s): 5c5d339

rabbitmq logic update

Browse files
Files changed (1) hide show
  1. worker.py +61 -32
worker.py CHANGED
@@ -14,10 +14,33 @@ from mineru_single import Processor
14
  class RabbitMQWorker:
15
  def __init__(self, num_workers: int = 1):
16
  self.num_workers = num_workers
17
- self.public_connection = pika.BlockingConnection(pika.URLParameters(os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")))
18
- self.public_channel = self.public_connection.channel()
19
  self.processor = Processor()
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def callback(self, ch, method, properties, body):
22
  """Handle incoming RabbitMQ messages"""
23
  thread_id = threading.current_thread().name
@@ -25,37 +48,38 @@ class RabbitMQWorker:
25
 
26
  print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")
27
 
28
- if headers.get("request_type") == "process_files":
29
- contexts = []
30
- body_dict = json.loads(body)
31
- for file in body_dict.get("input_files", []):
32
- contexts.append({"key": file["key"], "body": self.processor.process(file["url"])})
33
- body_dict["md_context"] = contexts
34
- json_body = json.dumps(body_dict)
35
- self.public_channel.queue_declare(queue="ml_server", durable=True)
36
- self.public_channel.basic_publish(
37
- exchange="",
38
- routing_key="ml_server",
39
- body=json_body,
40
- properties=pika.BasicProperties(headers=headers)
41
- )
42
- print(f"[Worker {thread_id}] Contexts: {contexts}")
43
-
44
- else:
45
- print(f"[Worker {thread_id}] Unknown process")
46
- return
47
-
48
- def worker(self, channel):
49
- """Worker process to consume messages"""
50
  try:
51
- channel.start_consuming()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  except Exception as e:
53
- print(f"[Worker] Error: {e}")
54
 
55
  def connect_to_rabbitmq(self):
56
  """Establish connection to RabbitMQ"""
57
- rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
58
- connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
59
  channel = connection.channel()
60
 
61
  channel.queue_declare(queue="gpu_server", durable=True)
@@ -70,10 +94,15 @@ class RabbitMQWorker:
70
  def start(self):
71
  """Start the worker threads"""
72
  print(f"Starting {self.num_workers} workers")
73
- with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
74
- for _ in range(self.num_workers):
75
- connection, channel = self.connect_to_rabbitmq()
76
- executor.submit(self.worker, channel)
 
 
 
 
 
77
 
78
  def main():
79
  worker = RabbitMQWorker()
 
14
  class RabbitMQWorker:
15
  def __init__(self, num_workers: int = 1):
16
  self.num_workers = num_workers
17
+ self.rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
 
18
  self.processor = Processor()
19
 
20
+ def publish_message(self, body_dict: dict, headers: dict):
21
+ """Create a new connection for each publish operation"""
22
+ try:
23
+ connection = pika.BlockingConnection(pika.URLParameters(self.rabbit_url))
24
+ channel = connection.channel()
25
+
26
+ channel.queue_declare(queue="ml_server", durable=True)
27
+
28
+ channel.basic_publish(
29
+ exchange="",
30
+ routing_key="ml_server",
31
+ body=json.dumps(body_dict),
32
+ properties=pika.BasicProperties(
33
+ delivery_mode=2, # make message persistent
34
+ headers=headers
35
+ )
36
+ )
37
+
38
+ connection.close()
39
+ return True
40
+ except Exception as e:
41
+ print(f"Error publishing message: {e}")
42
+ return False
43
+
44
  def callback(self, ch, method, properties, body):
45
  """Handle incoming RabbitMQ messages"""
46
  thread_id = threading.current_thread().name
 
48
 
49
  print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  try:
52
+ if headers.get("request_type") == "process_files":
53
+ contexts = []
54
+ body_dict = json.loads(body)
55
+
56
+ # Process files
57
+ for file in body_dict.get("input_files", []):
58
+ try:
59
+ context = {"key": file["key"], "body": self.processor.process(file["url"])}
60
+ contexts.append(context)
61
+ except Exception as e:
62
+ print(f"Error processing file {file['key']}: {e}")
63
+ contexts.append({"key": file["key"], "body": f"Error: {str(e)}"})
64
+
65
+ body_dict["md_context"] = contexts
66
+
67
+ # Publish results
68
+ if self.publish_message(body_dict, headers):
69
+ print(f"[Worker {thread_id}] Successfully published results")
70
+ else:
71
+ print(f"[Worker {thread_id}] Failed to publish results")
72
+
73
+ print(f"[Worker {thread_id}] Contexts: {contexts}")
74
+ else:
75
+ print(f"[Worker {thread_id}] Unknown process")
76
+
77
  except Exception as e:
78
+ print(f"Error in callback: {e}")
79
 
80
  def connect_to_rabbitmq(self):
81
  """Establish connection to RabbitMQ"""
82
+ connection = pika.BlockingConnection(pika.URLParameters(self.rabbit_url))
 
83
  channel = connection.channel()
84
 
85
  channel.queue_declare(queue="gpu_server", durable=True)
 
94
  def start(self):
95
  """Start the worker threads"""
96
  print(f"Starting {self.num_workers} workers")
97
+ while True:
98
+ try:
99
+ with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
100
+ for _ in range(self.num_workers):
101
+ connection, channel = self.connect_to_rabbitmq()
102
+ executor.submit(self.worker, channel)
103
+ except Exception as e:
104
+ print(f"Connection lost, reconnecting... Error: {e}")
105
+ time.sleep(5) # Wait before reconnecting
106
 
107
  def main():
108
  worker = RabbitMQWorker()