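"""
RabbitMQ worker that consumes file-processing requests from the "ml_server"
queue, runs the referenced file URLs through the mineru_single Processor,
and builds a response envelope containing the extracted markdown (md_context)
for each file.
"""
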
import os
import json
import time
import threading
from concurrent.futures import ThreadPoolExecutor

import pika

from mineru_single import Processor

# One shared Processor instance, reused by every worker thread.
processor = Processor()


def run_pipeline(body_bytes: bytes):
    """
    1) Decode the body bytes to a string.
    2) Parse the JSON. We expect something like:
       {
           "headers": {"request_type": "process_files", "request_id": "..."},
           "body": {
               "input_files": [...],
               "topics": [...]
           }
       }
    3) If request_type == "process_files", call processor.process_batch(...) on the URLs.
    4) Return raw_text_outputs (str) and parsed_json_outputs (dict).
    """
    body_str = body_bytes.decode("utf-8")
    data = json.loads(body_str)

    headers = data.get("headers", {})
    request_type = headers.get("request_type", "")
    request_id = headers.get("request_id", "")
    body = data.get("body", {})

    if request_type != "process_files":
        return "No processing done", data

    input_files = body.get("input_files", [])
    topics = body.get("topics", [])

    # Collect the URLs to process and remember which file key each URL maps to.
    urls = []
    file_key_map = {}
    for f in input_files:
        key = f.get("key", "")
        url = f.get("url", "")
        urls.append(url)
        file_key_map[url] = key

    # process_batch returns a mapping of url -> extracted markdown.
    batch_results = processor.process_batch(urls)

    md_context = []
    for url, md_content in batch_results.items():
        key = file_key_map.get(url, "")
        md_context.append({"key": key, "body": md_content})

    out_headers = {
        "request_type": "question_extraction_update_from_gpu_server",
        "request_id": request_id
    }
    out_body = {
        "input_files": input_files,
        "topics": topics,
        "md_context": md_context
    }
    final_json = {
        "headers": out_headers,
        "body": out_body
    }

    return json.dumps(final_json, ensure_ascii=False), final_json


def callback(ch, method, properties, body):
    """
    This function is invoked for each incoming RabbitMQ message.
    """
    thread_id = threading.current_thread().name
    headers = properties.headers or {}

    print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

    if headers.get("process") == "topic_extraction":
        raw_text_outputs, parsed_json_outputs = run_pipeline(body)
        print(f"[Worker {thread_id}] Pipeline result:\n{raw_text_outputs}")
    else:
        print(f"[Worker {thread_id}] Unknown process, sleeping 10s.")
        time.sleep(10)

    # Acknowledge only after the work is done so unprocessed messages are
    # re-queued if a worker dies (pairs with prefetch_count=1 and
    # auto_ack=False in connect_to_rabbitmq).
    ch.basic_ack(delivery_tag=method.delivery_tag)
    print("[Worker] Done")


def worker(channel):
    """Blocking consume loop for a single worker thread."""
    try:
        channel.start_consuming()
    except Exception as e:
        print(f"[Worker] Error: {e}")


def connect_to_rabbitmq():
    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
    channel = connection.channel()

    channel.queue_declare(queue="ml_server", durable=True)

    # Deliver at most one unacknowledged message to each worker at a time.
    channel.basic_qos(prefetch_count=1)

    # Manual acks: prefetch_count has no effect with auto_ack=True, and
    # auto-acked messages would be lost if a worker crashed mid-processing.
    # The callback acks after it finishes.
    channel.basic_consume(
        queue="ml_server",
        on_message_callback=callback,
        auto_ack=False
    )
    return connection, channel


def main():
    """
    Main entry: starts multiple worker threads to consume from the queue.
    """
    num_workers = 2
    print(f"Starting {num_workers} workers")

    # Each worker thread gets its own connection and channel; pika
    # connections must not be shared between threads.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for _ in range(num_workers):
            connection, channel = connect_to_rabbitmq()
            executor.submit(worker, channel)


if __name__ == "__main__":
    # If run directly, publish a test message first, then start the workers.
    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
    channel = connection.channel()
    channel.queue_declare(queue="ml_server", durable=True)

    sample_message = {
        "headers": {
            "request_type": "process_files",
            "request_id": "abc123"
        },
        "body": {
            "input_files": [
                {
                    "key": "file1",
                    "url": "https://example.com/file1.pdf",
                    "type": "mark_scheme"
                },
                {
                    "key": "file2",
                    "url": "https://example.com/file2.pdf",
                    "type": "question"
                }
            ],
            "topics": [
                {
                    "title": "Algebra",
                    "id": 123
                }
            ]
        }
    }

    # The "process" header is what callback routes on; the JSON body's
    # request_type is what run_pipeline checks.
    channel.basic_publish(
        exchange="",
        routing_key="ml_server",
        body=json.dumps(sample_message),
        properties=pika.BasicProperties(
            headers={"process": "topic_extraction"}
        )
    )
    connection.close()

    main()