#!/usr/bin/env python3
import os
import json
import time
import threading
from concurrent.futures import ThreadPoolExecutor

import pika

from mineru_single import Processor

processor = Processor()

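# run_pipeline below assumes processor.process_batch(urls) returns a
# {url: markdown_string} mapping; the module-level Processor instance is
# shared by all worker threads.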
def run_pipeline(body_bytes: bytes):
    """
    1) Decode the body bytes to a string.
    2) Parse the JSON. We expect something like:
       {
         "headers": {"request_type": "process_files", "request_id": "..."},
         "body": {
           "input_files": [...],
           "topics": [...]
         }
       }
    3) If request_type == "process_files", call processor.process_batch(...) on the URLs.
    4) Return raw_text_outputs (str) and parsed_json_outputs (dict).
    """
    body_str = body_bytes.decode("utf-8")
    data = json.loads(body_str)

    headers = data.get("headers", {})
    request_type = headers.get("request_type", "")
    request_id = headers.get("request_id", "")
    body = data.get("body", {})

    # If it's not "process_files", we do nothing special
    if request_type != "process_files":
        return "No processing done", data

    # Gather file URLs
    input_files = body.get("input_files", [])
    topics = body.get("topics", [])

    urls = []
    file_key_map = {}
    for f in input_files:
        key = f.get("key", "")
        url = f.get("url", "")
        urls.append(url)
        file_key_map[url] = key

    batch_results = processor.process_batch(urls)  # {url: markdown_string}

    md_context = []
    for url, md_content in batch_results.items():
        key = file_key_map.get(url, "")
        md_context.append({"key": key, "body": md_content})

    out_headers = {
        "request_type": "question_extraction_update_from_gpu_server",
        "request_id": request_id
    }
    out_body = {
        "input_files": input_files,
        "topics": topics,
        "md_context": md_context
    }
    final_json = {
        "headers": out_headers,
        "body": out_body
    }
    return json.dumps(final_json, ensure_ascii=False), final_json

def callback(ch, method, properties, body):
    """
    This function is invoked for each incoming RabbitMQ message.
    """
    thread_id = threading.current_thread().name
    headers = properties.headers or {}
    print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

    # If the header "process" is "topic_extraction", we run our pipeline
    if headers.get("process") == "topic_extraction":
        raw_text_outputs, parsed_json_outputs = run_pipeline(body)
        # Do something with the result, e.g. print or store in DB
        print(f"[Worker {thread_id}] Pipeline result:\n{raw_text_outputs}")
    else:
        # Fallback if "process" is something else
        print(f"[Worker {thread_id}] Unknown process, sleeping 10s.")
        time.sleep(10)

    print("[Worker] Done")

def worker(channel):
    try:
        channel.start_consuming()
    except Exception as e:
        print(f"[Worker] Error: {e}")

def connect_to_rabbitmq():
    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
    channel = connection.channel()

    # Declare the queue
    channel.queue_declare(queue="ml_server", durable=True)

    # Limit unacknowledged messages per worker
    # (note: RabbitMQ ignores prefetch for auto_ack consumers like the one below)
    channel.basic_qos(prefetch_count=1)

    # auto_ack=True for simplicity; otherwise each message must be acked manually
    channel.basic_consume(
        queue="ml_server",
        on_message_callback=callback,
        auto_ack=True
    )
    return connection, channel

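# A minimal sketch of the manual-acknowledgement variant mentioned above, for
# when a message should not be lost if a worker crashes mid-processing (same
# queue and callback names assumed; not wired into this script):
#
#     channel.basic_consume(
#         queue="ml_server",
#         on_message_callback=callback,
#         auto_ack=False
#     )
#
#     # ...and at the end of callback(), once processing has succeeded:
#     ch.basic_ack(delivery_tag=method.delivery_tag)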
def main():
    """
    Main entry: starts multiple worker threads to consume from the queue.
    """
    num_workers = 2  # hard-coded for now
    print(f"Starting {num_workers} workers")

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for _ in range(num_workers):
            # pika connections are not thread-safe, so each worker thread gets
            # its own connection and channel
            connection, channel = connect_to_rabbitmq()
            executor.submit(worker, channel)

if __name__ == "__main__":
    # If run directly, we also publish a test message, then start the workers.
    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
    channel = connection.channel()
    channel.queue_declare(queue="ml_server", durable=True)

    sample_message = {
        "headers": {
            "request_type": "process_files",
            "request_id": "abc123"
        },
        "body": {
            "input_files": [
                {
                    "key": "file1",
                    "url": "https://example.com/file1.pdf",
                    "type": "mark_scheme"
                },
                {
                    "key": "file2",
                    "url": "https://example.com/file2.pdf",
                    "type": "question"
                }
            ],
            "topics": [
                {
                    "title": "Algebra",
                    "id": 123
                }
            ]
        }
    }

    channel.basic_publish(
        exchange="",
        routing_key="ml_server",
        body=json.dumps(sample_message),
        properties=pika.BasicProperties(
            headers={"process": "topic_extraction"}
        )
    )
    connection.close()

    main()