MinerU / worker.py
princhman's picture
finalasing
c6bff1f
raw
history blame
5.47 kB
#!/usr/bin/env python3
import json
import multiprocessing
import os
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor

import pika

from mineru_single import Processor
# Single Processor instance, created once at import time. run_pipeline uses it
# for every message, and run_pipeline is called from callback on each consumer
# thread, so all workers share this one object.
# NOTE(review): confirm Processor.process_batch is safe to call concurrently
# from several threads.
processor = Processor()
def run_pipeline(body_bytes: bytes, batch_processor=None):
    """
    Turn one queue message into a "question_extraction_update_from_gpu_server" payload.

    Steps:
      1) Decode ``body_bytes`` as UTF-8 and parse it as JSON. Expected shape::

            {
              "headers": {"request_type": "process_files", "request_id": "..."},
              "body": {
                "input_files": [{"key": "...", "url": "..."}, ...],
                "topics": [...]
              }
            }

      2) If ``request_type`` is not "process_files", return
         ``("No processing done", data)`` with the parsed message unchanged.
      3) Otherwise call ``process_batch`` on the URLs of all input files and
         pair each returned markdown string with its file's ``key``.

    Args:
        body_bytes: raw message body as delivered by RabbitMQ.
        batch_processor: object exposing ``process_batch(urls)`` that returns
            a ``{url: markdown_string}`` mapping. Defaults to the module-level
            ``processor``; pass another one to reuse the pipeline (e.g. in
            tests) without loading the real model.

    Returns:
        ``(raw_text_outputs, parsed_json_outputs)``: the outgoing payload as a
        JSON string and as a dict. For non-"process_files" requests these are
        ``"No processing done"`` and the parsed input message.

    Raises:
        UnicodeDecodeError: if the body is not valid UTF-8.
        json.JSONDecodeError: if the body is not valid JSON.
        NOTE(review): a JSON value that is not an object (e.g. a list) still
        fails with AttributeError on ``.get``.
    """
    data = json.loads(body_bytes.decode("utf-8"))

    # ``or {}`` / ``or []`` also cover keys that are present but null, which
    # ``dict.get(key, default)`` would hand back as None and then crash on.
    headers = data.get("headers") or {}
    request_type = headers.get("request_type", "")
    request_id = headers.get("request_id", "")
    body = data.get("body") or {}

    # Only "process_files" requests are processed; anything else is echoed back.
    if request_type != "process_files":
        return "No processing done", data

    input_files = body.get("input_files") or []
    topics = body.get("topics", [])

    urls = []
    file_key_map = {}  # url -> file key, used to label each markdown result
    for f in input_files:
        url = f.get("url", "")
        urls.append(url)
        file_key_map[url] = f.get("key", "")

    if batch_processor is None:
        batch_processor = processor
    batch_results = batch_processor.process_batch(urls)  # {url: markdown_string}

    md_context = [
        {"key": file_key_map.get(url, ""), "body": md_content}
        for url, md_content in batch_results.items()
    ]

    final_json = {
        "headers": {
            "request_type": "question_extraction_update_from_gpu_server",
            "request_id": request_id,
        },
        "body": {
            "input_files": input_files,
            "topics": topics,
            "md_context": md_context,
        },
    }
    return json.dumps(final_json, ensure_ascii=False), final_json
def callback(ch, method, properties, body):
    """
    Handle one RabbitMQ delivery (pika ``on_message_callback`` signature).

    Messages whose ``process`` header is "topic_extraction" are run through
    run_pipeline and the resulting payload is printed. For any other message
    the thread logs it and sleeps for 10 seconds.

    A failure while processing one message is logged with its traceback and
    not re-raised. pika runs this callback inside ``start_consuming``, so an
    exception escaping here would stop consumption and end the worker thread
    for good; a single malformed message could take a worker down.

    Args:
        ch: channel the message arrived on (unused).
        method: delivery metadata (unused).
        properties: message properties; only ``properties.headers`` is read.
        body: raw message body bytes, passed unchanged to run_pipeline.
    """
    thread_id = threading.current_thread().name
    headers = properties.headers or {}
    print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")
    if headers.get("process") == "topic_extraction":
        try:
            raw_text_outputs, _parsed_json_outputs = run_pipeline(body)
        except Exception:
            # connect_to_rabbitmq consumes with auto_ack=True, so this message
            # is already gone from the queue; log it and keep consuming.
            print(f"[Worker {thread_id}] Pipeline failed:")
            traceback.print_exc()
        else:
            # Do something with the result, e.g. print or store in DB
            print(f"[Worker {thread_id}] Pipeline result:\n{raw_text_outputs}")
    else:
        # Fallback if "process" is something else
        print(f"[Worker {thread_id}] Unknown process, sleeping 10s.")
        time.sleep(10)
    print("[Worker] Done")
def worker(channel):
    """
    Consume from ``channel`` on the current thread until consuming stops.

    ``channel.start_consuming()`` blocks and dispatches deliveries to the
    registered callback. If it raises (for example because the broker
    connection drops), the error is logged with its traceback and this
    function returns, which ends the worker thread. Nothing restarts it.

    Args:
        channel: a pika channel with a consumer already registered
            (as returned by connect_to_rabbitmq).
    """
    try:
        channel.start_consuming()
    except Exception as e:
        # Log the exception type and traceback too: str(e) alone can be empty
        # and says nothing about where the failure happened.
        print(f"[Worker] Error: {type(e).__name__}: {e}")
        traceback.print_exc()
def connect_to_rabbitmq(queue="ml_server", rabbit_url=None):
    """
    Open a blocking connection and register ``callback`` as consumer of ``queue``.

    Args:
        queue: name of the durable queue to declare and consume from.
        rabbit_url: AMQP URL. Defaults to the RABBITMQ_URL environment
            variable, falling back to the local guest account.

    Returns:
        ``(connection, channel)``. Call ``channel.start_consuming()`` to begin
        receiving messages.

    Raises:
        Whatever pika raises when the broker is unreachable or channel setup
        is rejected. If the connection was already open, it is closed before
        the error is re-raised.
    """
    if rabbit_url is None:
        rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
    try:
        channel = connection.channel()
        channel.queue_declare(queue=queue, durable=True)
        # NOTE: the broker ignores prefetch_count for consumers in automatic
        # acknowledgement mode. With auto_ack=True below, this does NOT limit
        # how many messages are pushed to each worker. It would only take
        # effect with manual acks, which would also require callback to call
        # ch.basic_ack.
        channel.basic_qos(prefetch_count=1)
        # auto_ack=True: a message counts as handled as soon as it is
        # delivered, so it is lost if processing fails or the worker dies.
        channel.basic_consume(
            queue=queue,
            on_message_callback=callback,
            auto_ack=True,
        )
    except Exception:
        # Don't leak the socket if any channel-setup step fails.
        if connection.is_open:
            connection.close()
        raise
    return connection, channel
def main(num_workers=2):
    """
    Main entry: start consumer threads on the "ml_server" queue and block.

    Each worker gets its own connection and channel from connect_to_rabbitmq.
    All connections are opened before any thread starts, so a connection
    error surfaces immediately. Previously, an error raised inside the
    ThreadPoolExecutor ``with`` block waited behind the executor's exit,
    which blocks on already-running workers that consume indefinitely.

    Args:
        num_workers: number of consumer threads (and connections) to start.
            Defaults to 2, the value that used to be hard-coded.

    Raises:
        ValueError: if num_workers is less than 1.
        Any error from connect_to_rabbitmq. Connections opened before the
        failure are closed first.
    """
    if num_workers < 1:
        raise ValueError(f"num_workers must be at least 1, got {num_workers}")
    print(f"Starting {num_workers} workers")

    consumers = []  # (connection, channel) pairs, one per worker
    try:
        for _ in range(num_workers):
            consumers.append(connect_to_rabbitmq())
    except Exception:
        for connection, _channel in consumers:
            if connection.is_open:
                connection.close()
        raise

    # Exiting the with-block waits until every worker has stopped consuming.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for _connection, channel in consumers:
            executor.submit(worker, channel)
def _publish_sample_message():
    """
    Send one hard-coded "process_files" request to the "ml_server" queue.

    The message carries the header ``process: topic_extraction``, so a running
    worker routes it through run_pipeline. The connection used here is closed
    after publishing; the workers open their own.
    """
    queue_name = "ml_server"
    connection = pika.BlockingConnection(
        pika.URLParameters(
            os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
        )
    )
    channel = connection.channel()
    channel.queue_declare(queue=queue_name, durable=True)

    mark_scheme_file = {
        "key": "file1",
        "url": "https://example.com/file1.pdf",
        "type": "mark_scheme",
    }
    question_file = {
        "key": "file2",
        "url": "https://example.com/file2.pdf",
        "type": "question",
    }
    payload = {
        "headers": {"request_type": "process_files", "request_id": "abc123"},
        "body": {
            "input_files": [mark_scheme_file, question_file],
            "topics": [{"title": "Algebra", "id": 123}],
        },
    }

    channel.basic_publish(
        exchange="",
        routing_key=queue_name,
        body=json.dumps(payload),
        properties=pika.BasicProperties(headers={"process": "topic_extraction"}),
    )
    connection.close()


if __name__ == "__main__":
    # Running this file directly first publishes one sample message, then
    # starts the workers, which will consume it.
    _publish_sample_message()
    main()