import gradio as gr
import json
import requests
import os
import subprocess
import time
import wget
from loguru import logger
from data_utils.line_based_parsing import parse_line_based_query, convert_to_lines
from data_utils.base_conversion_utils import (
    build_schema_maps,
    convert_modified_to_actual_code_string
)
from data_utils.schema_utils import schema_to_line_based
from configs.prompt_config import SYSTEM_PROMPT_V3, MODEL_PROMPT_V3

LLAMA_SERVER_URL = "http://127.0.0.1:8080/v1/chat/completions"
MODEL_PATH = "./models/unsloth.Q8_0.gguf"
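
# The server speaks an OpenAI-compatible chat API, so it can be smoke-tested
# from a Python shell once it is up (assuming the default llama-cpp-python
# routes, where the "model" field is optional because one model is preloaded):
#
#     r = requests.post(LLAMA_SERVER_URL,
#                       json={"messages": [{"role": "user", "content": "hello"}]})
#     print(r.json()["choices"][0]["message"]["content"])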

def download_model():
    """Download the model if it doesn't exist"""
    os.makedirs("./models", exist_ok=True)
    if not os.path.exists(MODEL_PATH):
        logger.info("Downloading model weights...")
        wget.download(
            "https://huggingface.co/ByteMaster01/NL2SQL/resolve/main/unsloth.Q8_0.gguf",
            MODEL_PATH
        )
        logger.info("\nModel download complete!")

def start_llama_server():
    """Start the llama.cpp server with the downloaded model"""
    try:
        logger.info("Starting llama.cpp server...")
        # Popen returns immediately; the server keeps loading the model in the
        # background, so callers must wait for it to become ready
        subprocess.Popen([
            "python", "-m", "llama_cpp.server",
            "--model", MODEL_PATH,
            "--port", "8080"
        ])
        logger.info("Server started successfully!")
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        raise
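
def wait_for_server(url: str = "http://127.0.0.1:8080/v1/models", timeout: float = 120.0):
    """Block until the llama.cpp server answers, or raise after `timeout` seconds.

    A minimal readiness probe, assuming the llama-cpp-python server exposes the
    OpenAI-compatible /v1/models route (true for recent releases, but verify
    against the installed version). A fixed sleep is fragile here: loading an
    ~8 GB GGUF can easily take longer than a few seconds.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=2).ok:
                logger.info("llama.cpp server is ready")
                return
        except requests.RequestException:
            pass  # Server not up yet; keep polling
        time.sleep(1)
    raise TimeoutError(f"llama.cpp server did not become ready within {timeout}s")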

def convert_line_parsed_to_mongo(line_parsed: str, schema: dict) -> str:
    """Convert the model's line-based output back into a MongoDB query string."""
    try:
        modified_query = parse_line_based_query(line_parsed)
        collection_name = schema["collections"][0]["name"]
        in2out, _ = build_schema_maps(schema)
        reconstructed_query = convert_modified_to_actual_code_string(modified_query, in2out, collection_name)
        return reconstructed_query
    except Exception as e:
        logger.error(f"Error converting line parsed to MongoDB query: {e}")
        return ""

def process_query(schema_text: str, nl_query: str, additional_info: str = "") -> list:
    """Run the full pipeline: format the prompt, query the model, and convert
    its line-based output into a MongoDB query.

    Returns [mongo_query, line_based_output], matching the two Gradio outputs.
    """
    try:
        # Parse schema from string to dict
        schema = json.loads(schema_text)
        
        # Convert schema to line-based format
        line_based_schema = schema_to_line_based(schema)
        
        # Format prompt with line-based schema
        prompt = MODEL_PROMPT_V3.format(
            schema=line_based_schema,
            natural_language_query=nl_query,
            additional_info=additional_info
        )
        
        # Prepare the request payload. slot_id, n_keep, and cache_prompt are
        # llama.cpp-native options; the OpenAI-compatible endpoint may simply
        # ignore them.
        payload = {
            "slot_id": 0,
            "temperature": 0.1,
            "n_keep": -1,
            "cache_prompt": True,
            "messages": [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT_V3,
                },
                {
                    "role": "user",
                    "content": prompt
                },
            ]
        }
        
        # Make request to the llama.cpp server; CPU generation can be slow,
        # so use a generous timeout rather than hanging forever
        response = requests.post(LLAMA_SERVER_URL, json=payload, timeout=300)
        response.raise_for_status()
        
        # Extract output from response
        output = response.json()["choices"][0]["message"]["content"].strip()
        logger.info(f"Model output: {output}")
        
        # Convert line-based output to MongoDB query
        mongo_query = convert_line_parsed_to_mongo(output, schema)
        
        return [mongo_query, output]
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        error_msg = f"Error: {e}"
        # Two entries, matching the two Gradio output components
        return [error_msg, error_msg]
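
# Illustrative usage (hypothetical schema; requires the server to be running).
# The exact line-based syntax of the second return value is defined by
# data_utils.line_based_parsing, so treat this as a sketch, not a spec:
#
#     demo_schema = json.dumps({"collections": [{"name": "events", "document":
#         {"properties": {"severity": {"bsonType": "int"}}}}]})
#     mongo_query, line_query = process_query(demo_schema,
#                                             "Find events with severity over 5")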

def create_interface():
    # Create Gradio interface
    iface = gr.Interface(
        fn=process_query,
        inputs=[
            gr.Textbox(
                label="Schema (JSON format)", 
                placeholder="Enter your MongoDB schema in JSON format...",
                lines=10
            ),
            gr.Textbox(
                label="Natural Language Query",
                placeholder="Enter your query in natural language..."
            ),
            gr.Textbox(
                label="Additional Info (Optional)",
                placeholder="Enter any additional context (timestamps, etc)..."
            ),
        ],
        outputs=[
            gr.Code(label="MongoDB Query", language="javascript", lines=1),
            gr.Textbox(label="Line-based Query")
        ],
        title="Natural Language to MongoDB Query Converter",
        description="Convert natural language queries to MongoDB queries based on your schema.",
        examples=[
            [
                '''{
    "collections": [{
        "name": "events",
        "document": {
            "properties": {
                "timestamp": {"bsonType": "int"},
                "severity": {"bsonType": "int"},
                "location": {
                    "bsonType": "object",
                    "properties": {
                        "lat": {"bsonType": "double"},
                        "lon": {"bsonType": "double"}
                    }
                }
            }
        }
    }]}''',
                "Find all events with severity greater than 5",
                ""
            ],
            [
                '''{
    "collections": [{
        "name": "vehicles",
        "document": {
            "properties": {
                "timestamp": {"bsonType": "int"},
                "vehicle_details": {
                    "bsonType": "object",
                    "properties": {
                        "license_plate": {"bsonType": "string"},
                        "make": {"bsonType": "string"},
                        "model": {"bsonType": "string"},
                        "year": {"bsonType": "int"},
                        "color": {"bsonType": "string"}
                    }
                },
                "speed": {"bsonType": "double"},
                "location": {
                    "bsonType": "object",
                    "properties": {
                        "lat": {"bsonType": "double"},
                        "lon": {"bsonType": "double"}
                    }
                }
            }
        }
    }]}''',
                "Find red Toyota vehicles manufactured after 2020 with speed above 60",
                ""
            ],
            [
                '''{
    "collections": [{
        "name": "sensors",
        "document": {
            "properties": {
                "sensor_id": {"bsonType": "string"},
                "readings": {
                    "bsonType": "object",
                    "properties": {
                        "temperature": {"bsonType": "double"},
                        "humidity": {"bsonType": "double"},
                        "pressure": {"bsonType": "double"}
                    }
                },
                "timestamp": {"bsonType": "date"},
                "status": {"bsonType": "string"}
            }
        }
    }]}''',
                "Find active sensors with temperature above 30 degrees in the last one day",
                '''current date is 21 january 2025'''
            ],
            [
                '''{
    "collections": [{
        "name": "orders",
        "document": {
            "properties": {
                "order_id": {"bsonType": "string"},
                "customer": {
                    "bsonType": "object",
                    "properties": {
                        "id": {"bsonType": "string"},
                        "name": {"bsonType": "string"},
                        "email": {"bsonType": "string"}
                    }
                },
                "items": {
                    "bsonType": "array",
                    "items": {
                        "bsonType": "object",
                        "properties": {
                            "product_id": {"bsonType": "string"},
                            "quantity": {"bsonType": "int"},
                            "price": {"bsonType": "double"}
                        }
                    }
                },
                "total_amount": {"bsonType": "double"},
                "status": {"bsonType": "string"},
                "created_at": {"bsonType": "int"}
            }
        }
    }]}''',
                "Find orders with total amount greater than $100 that contain more than 3 items and were created in the last 24 hours",
                '''{"current_time": 1685890800, "last_24_hours": 1685804400}'''
            ]
        ],
        cache_examples=False,
    )
    return iface

if __name__ == "__main__":
    # Download the model weights if they are not already present
    download_model()

    # Start the llama.cpp server in the background
    start_llama_server()

    # Wait until the server actually answers instead of sleeping a fixed
    # five seconds, which fails on machines where model loading is slow
    wait_for_server()

    # Launch the Gradio interface
    logger.info("Starting Gradio interface...")
    iface = create_interface()
    iface.launch()