#!/usr/bin/env python3
import os
import json
import time
import threading
from concurrent.futures import ThreadPoolExecutor
import pika

from mineru_single import Processor

processor = Processor()

def run_pipeline(body_bytes: bytes):
    """

    1) Decode the body bytes to a string.

    2) Parse the JSON. We expect something like:

       {

         "headers": {"request_type": "process_files", "request_id": "..."},

         "body": {

             "input_files": [...],

             "topics": [...]

         }

       }

    3) If request_type == "process_files", call processor.process_batch(...) on the URLs.

    4) Return raw_text_outputs (str) and parsed_json_outputs (dict).

    """

    body_str = body_bytes.decode("utf-8")
    data = json.loads(body_str)

    headers = data.get("headers", {})
    request_type = headers.get("request_type", "")
    request_id = headers.get("request_id", "")
    body = data.get("body", {})

    # If it's not "process_files", we do nothing special
    if request_type != "process_files":
        return "No processing done", data

    # Gather file URLs
    input_files = body.get("input_files", [])
    topics = body.get("topics", [])

    urls = []
    file_key_map = {}
    for f in input_files:
        key = f.get("key", "")
        url = f.get("url", "")
        urls.append(url)
        file_key_map[url] = key

    batch_results = processor.process_batch(urls)  # {url: markdown_string}

    md_context = []
    for url, md_content in batch_results.items():
        key = file_key_map.get(url, "")
        md_context.append({"key": key, "body": md_content})

    out_headers = {
        "request_type": "question_extraction_update_from_gpu_server",
        "request_id": request_id
    }
    out_body = {
        "input_files": input_files,
        "topics": topics,
        "md_context": md_context
    }
    final_json = {
        "headers": out_headers,
        "body": out_body
    }

    return json.dumps(final_json, ensure_ascii=False), final_json
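
# For reference (a sketch, assuming processor.process_batch returns
# {url: markdown_string} as noted above): the sample message published in
# __main__ below would come back shaped roughly like
#   {"headers": {"request_type": "question_extraction_update_from_gpu_server",
#                "request_id": "abc123"},
#    "body": {"input_files": [...], "topics": [...],
#             "md_context": [{"key": "file1", "body": "..."}, ...]}}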

def callback(ch, method, properties, body):
    """

    This function is invoked for each incoming RabbitMQ message.

    """
    thread_id = threading.current_thread().name
    headers = properties.headers or {}

    print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

    # If the header "process" is "topic_extraction", we run our pipeline
    if headers.get("process") == "topic_extraction":
        raw_text_outputs, parsed_json_outputs = run_pipeline(body)
        # Do something with the result, e.g. print or store in DB
        print(f"[Worker {thread_id}] Pipeline result:\n{raw_text_outputs}")
    else:
        # Fallback if "process" is something else
        print(f"[Worker {thread_id}] Unknown process, sleeping 10s.")
        time.sleep(10)
        print("[Worker] Done")

def worker(channel):
    try:
        channel.start_consuming()
    except Exception as e:
        print(f"[Worker] Error: {e}")

def connect_to_rabbitmq():
    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
    channel = connection.channel()

    # Declare the queue
    channel.queue_declare(queue="ml_server", durable=True)

    # Limit unacknowledged messages per worker. Note: prefetch only applies
    # to unacked deliveries, so with auto_ack=True below it is effectively
    # a no-op; it starts to matter once you switch to manual acks.
    channel.basic_qos(prefetch_count=1)

    # auto_ack=True for simplicity; with manual acks (see
    # callback_with_manual_ack above) a message survives a worker crash
    channel.basic_consume(
        queue="ml_server",
        on_message_callback=callback,
        auto_ack=True
    )
    return connection, channel
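
# A reconnecting consumer loop (a sketch; assumes failures are transient
# connection errors worth retrying). worker() above lets its thread die on
# the first error; this variant rebuilds the connection and resumes.
def resilient_worker(retry_delay: int = 5):
    while True:
        try:
            _connection, channel = connect_to_rabbitmq()
            channel.start_consuming()
        except pika.exceptions.AMQPConnectionError as e:
            print(f"[Worker] Connection lost: {e}; retrying in {retry_delay}s")
            time.sleep(retry_delay)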

def main():
    """

    Main entry: starts multiple worker threads to consume from the queue.

    """
    num_workers = 2 # hard code for now
    print(f"Starting {num_workers} workers")

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for _ in range(num_workers):
            connection, channel = connect_to_rabbitmq()
            executor.submit(worker, channel)

if __name__ == "__main__":
    """

    If run directly, we also publish a test message, then start the workers.

    """
    rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
    connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
    channel = connection.channel()
    channel.queue_declare(queue="ml_server", durable=True)

    sample_message = {
        "headers": {
            "request_type": "process_files",
            "request_id": "abc123"
        },
        "body": {
            "input_files": [
                {
                    "key": "file1",
                    "url": "https://example.com/file1.pdf",
                    "type": "mark_scheme"
                },
                {
                    "key": "file2",
                    "url": "https://example.com/file2.pdf",
                    "type": "question"
                }
            ],
            "topics": [
                {
                    "title": "Algebra",
                    "id": 123
                }
            ]
        }
    }

    channel.basic_publish(
        exchange="",
        routing_key="ml_server",
        body=json.dumps(sample_message),
        properties=pika.BasicProperties(
            headers={"process": "topic_extraction"}
        )
    )
    connection.close()

    main()