princhman committed on
Commit 1e9bbdd · 1 Parent(s): afe430b
Files changed (1)
  1. worker.py +112 -165
worker.py CHANGED
@@ -6,174 +6,121 @@ import threading
 import multiprocessing
 from concurrent.futures import ThreadPoolExecutor
 import pika
+from typing import Tuple, Dict, Any

 from mineru_single import Processor

-processor = Processor()
-
-def run_pipeline(body_bytes: bytes):
-    """
-    1) Decode the body bytes to a string.
-    2) Parse the JSON. We expect something like:
-       {
-         "headers": {"request_type": "process_files", "request_id": "..."},
-         "body": {
-           "input_files": [...],
-           "topics": [...]
-         }
-       }
-    3) If request_type == "process_files", call processor.process_batch(...) on the URLs.
-    4) Return raw_text_outputs (str) and parsed_json_outputs (dict).
-    """
-
-    body_str = body_bytes.decode("utf-8")
-    data = json.loads(body_str)
-
-    headers = data.get("headers", {})
-    request_type = headers.get("request_type", "")
-    request_id = headers.get("request_id", "")
-    body = data.get("body", {})
-
-    # If it's not "process_files", we do nothing special
-    if request_type != "process_files":
-        return "No processing done", data
-
-    # Gather file URLs
-    input_files = body.get("input_files", [])
-    topics = body.get("topics", [])
-
-    urls = []
-    file_key_map = {}
-    for f in input_files:
-        key = f.get("key", "")
-        url = f.get("url", "")
-        urls.append(url)
-        file_key_map[url] = key
-
-    batch_results = processor.process_batch(urls)  # {url: markdown_string}
-
-    md_context = []
-    for url, md_content in batch_results.items():
-        key = file_key_map.get(url, "")
-        md_context.append({"key": key, "body": md_content})
-
-    out_headers = {
-        "request_type": "question_extraction_update_from_gpu_server",
-        "request_id": request_id
-    }
-    out_body = {
-        "input_files": input_files,
-        "topics": topics,
-        "md_context": md_context
-    }
-    final_json = {
-        "headers": out_headers,
-        "body": out_body
-    }
-
-    return json.dumps(final_json, ensure_ascii=False), final_json
-
-def callback(ch, method, properties, body):
-    """
-    This function is invoked for each incoming RabbitMQ message.
-    """
-    thread_id = threading.current_thread().name
-    headers = properties.headers or {}
-
-    print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")
-
-    # If the header "process" is "topic_extraction", we run our pipeline
-    if headers.get("process") == "topic_extraction":
-        raw_text_outputs, parsed_json_outputs = run_pipeline(body)
-        # Do something with the result, e.g. print or store in DB
-        print(f"[Worker {thread_id}] Pipeline result:\n{raw_text_outputs}")
-    else:
-        # Fallback if "process" is something else
-        print(f"[Worker {thread_id}] Unknown process, sleeping 10s.")
-        time.sleep(10)
-    print("[Worker] Done")
-
-def worker(channel):
-    try:
-        channel.start_consuming()
-    except Exception as e:
-        print(f"[Worker] Error: {e}")
-
-def connect_to_rabbitmq():
-    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
-    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
-    channel = connection.channel()
-
-    # Declare the queue
-    channel.queue_declare(queue="ml_server", durable=True)
-
-    # Limit messages per worker
-    channel.basic_qos(prefetch_count=1)
-
-    # auto_ack=True for simplicity, else you must ack manually
-    channel.basic_consume(
-        queue="ml_server",
-        on_message_callback=callback,
-        auto_ack=True
-    )
-    return connection, channel
-
-def main():
-    """
-    Main entry: starts multiple worker threads to consume from the queue.
-    """
-    num_workers = 2  # hard code for now
-    print(f"Starting {num_workers} workers")
-
-    with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        for _ in range(num_workers):
-            connection, channel = connect_to_rabbitmq()
-            executor.submit(worker, channel)
-
-if __name__ == "__main__":
-    """
-    If run directly, we also publish a test message, then start the workers.
-    """
-    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
-    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
-    channel = connection.channel()
-    channel.queue_declare(queue="ml_server", durable=True)
-
-    sample_message = {
-        "headers": {
-            "request_type": "process_files",
-            "request_id": "abc123"
-        },
-        "body": {
-            "input_files": [
-                {
-                    "key": "file1",
-                    "url": "https://example.com/file1.pdf",
-                    "type": "mark_scheme"
-                },
-                {
-                    "key": "file2",
-                    "url": "https://example.com/file2.pdf",
-                    "type": "question"
-                }
-            ],
-            "topics": [
-                {
-                    "title": "Algebra",
-                    "id": 123
-                }
-            ]
-        }
-    }
-
-    channel.basic_publish(
-        exchange="",
-        routing_key="ml_server",
-        body=json.dumps(sample_message),
-        properties=pika.BasicProperties(
-            headers={"process": "topic_extraction"}
-        )
-    )
-    connection.close()
-
-    main()
+class MessageProcessor:
+    def __init__(self):
+        self.processor = Processor()
+
+    def process_message(self, body_bytes: bytes) -> Tuple[str, Dict[str, Any]]:
+        """Process incoming message and return processed results"""
+        body_str = body_bytes.decode("utf-8")
+        data = json.loads(body_str)
+
+        headers = data.get("headers", {})
+        request_type = headers.get("request_type", "")
+        request_id = headers.get("request_id", "")
+        body = data.get("body", {})
+
+        if request_type != "process_files":
+            return "No processing done", data
+
+        input_files = body.get("input_files", [])
+        topics = body.get("topics", [])
+
+        urls, file_key_map = self._extract_urls_and_keys(input_files)
+        batch_results = self.processor.process_batch(urls)
+        md_context = self._create_markdown_context(batch_results, file_key_map)
+
+        final_json = self._create_response_json(request_id, input_files, topics, md_context)
+        return json.dumps(final_json, ensure_ascii=False), final_json
+
+    def _extract_urls_and_keys(self, input_files: list) -> Tuple[list, dict]:
+        """Extract URLs and create file key mapping"""
+        urls = []
+        file_key_map = {}
+        for f in input_files:
+            key = f.get("key", "")
+            url = f.get("url", "")
+            urls.append(url)
+            file_key_map[url] = key
+        return urls, file_key_map
+
+    def _create_markdown_context(self, batch_results: dict, file_key_map: dict) -> list:
+        """Create markdown context from batch results"""
+        md_context = []
+        for url, md_content in batch_results.items():
+            key = file_key_map.get(url, "")
+            md_context.append({"key": key, "body": md_content})
+        return md_context
+
+    def _create_response_json(self, request_id: str, input_files: list,
+                              topics: list, md_context: list) -> dict:
+        """Create the final response JSON"""
+        return {
+            "headers": {
+                "request_type": "question_extraction_update_from_gpu_server",
+                "request_id": request_id
+            },
+            "body": {
+                "input_files": input_files,
+                "topics": topics,
+                "md_context": md_context
+            }
+        }
+
+class RabbitMQWorker:
+    def __init__(self, num_workers: int = 1):
+        self.num_workers = num_workers
+        self.message_processor = MessageProcessor()
+
+    def callback(self, ch, method, properties, body):
+        """Handle incoming RabbitMQ messages"""
+        thread_id = threading.current_thread().name
+        headers = properties.headers or {}
+
+        print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")
+
+        if headers.get("process") == "topic_extraction":
+            raw_text_outputs, parsed_json_outputs = self.message_processor.process_message(body)
+            print(f"[Worker {thread_id}] Pipeline result:\n{raw_text_outputs}")
+        else:
+            print(f"[Worker {thread_id}] Unknown process, sleeping 10s.")
+            time.sleep(10)
+        print("[Worker] Done")
+
+    def worker(self, channel):
+        """Worker process to consume messages"""
+        try:
+            channel.start_consuming()
+        except Exception as e:
+            print(f"[Worker] Error: {e}")
+
+    def connect_to_rabbitmq(self):
+        """Establish connection to RabbitMQ"""
+        rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
+        connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
+        channel = connection.channel()
+
+        channel.queue_declare(queue="ml_server", durable=True)
+        channel.basic_qos(prefetch_count=1)
+        channel.basic_consume(
+            queue="ml_server",
+            on_message_callback=self.callback,
+            auto_ack=True
+        )
+        return connection, channel
+
+    def start(self):
+        """Start the worker threads"""
+        print(f"Starting {self.num_workers} workers")
+        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+            for _ in range(self.num_workers):
+                connection, channel = self.connect_to_rabbitmq()
+                executor.submit(self.worker, channel)
+
+def main():
+    worker = RabbitMQWorker()
+    worker.start()
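
The refactor drops the old inline test publisher, and within the shown hunk main() is defined but never invoked (no `if __name__ == "__main__":` guard appears in the new range). Below is a minimal sketch of how the refactored worker might be exercised end to end; the queue name, routing header, and sample payload are taken verbatim from the deleted code, while the publish_test_message helper and the `from worker import RabbitMQWorker` import are assumptions for illustration.

import json
import os

import pika

from worker import RabbitMQWorker  # assumption: worker.py is importable as `worker`


def publish_test_message():
    """Re-create the test publish that the removed __main__ block performed."""
    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
    channel = connection.channel()
    channel.queue_declare(queue="ml_server", durable=True)

    # Payload shape copied from the deleted sample_message.
    sample_message = {
        "headers": {"request_type": "process_files", "request_id": "abc123"},
        "body": {
            "input_files": [
                {"key": "file1", "url": "https://example.com/file1.pdf", "type": "mark_scheme"},
                {"key": "file2", "url": "https://example.com/file2.pdf", "type": "question"},
            ],
            "topics": [{"title": "Algebra", "id": 123}],
        },
    }

    channel.basic_publish(
        exchange="",
        routing_key="ml_server",
        body=json.dumps(sample_message),
        # The worker's callback only runs the pipeline when this header is set.
        properties=pika.BasicProperties(headers={"process": "topic_extraction"}),
    )
    connection.close()


if __name__ == "__main__":
    publish_test_message()
    RabbitMQWorker(num_workers=2).start()

Note that both the old and new consumers pass auto_ack=True, so a message is acknowledged as soon as it is delivered; the old comment ("else you must ack manually") still applies to the refactor, and a message whose pipeline raises is dropped rather than redelivered.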