"""Gradio front-end that converts natural-language queries to MongoDB queries.

Pipeline: download a GGUF model, serve it via llama.cpp's OpenAI-compatible
endpoint, prompt it with a line-based schema + NL query, then convert the
model's line-based output back into an executable MongoDB query string.
"""

import json
import os
import subprocess
import time

import gradio as gr
import requests
import wget
from loguru import logger

from data_utils.line_based_parsing import parse_line_based_query, convert_to_lines
from data_utils.base_conversion_utils import (
    build_schema_maps,
    convert_modified_to_actual_code_string
)
from data_utils.schema_utils import schema_to_line_based
from configs.prompt_config import SYSTEM_PROMPT_V3, MODEL_PROMPT_V3

# OpenAI-compatible chat endpoint exposed by the local llama.cpp server.
LLAMA_SERVER_URL = "http://127.0.0.1:8080/v1/chat/completions"
MODEL_PATH = "./models/unsloth.Q8_0.gguf"

# Cap on how long a single model request may take before the UI gives up.
REQUEST_TIMEOUT_SECONDS = 300


def download_model():
    """Download the GGUF model weights from Hugging Face if not already cached."""
    os.makedirs("./models", exist_ok=True)
    if not os.path.exists(MODEL_PATH):
        logger.info("Downloading model weights...")
        wget.download(
            "https://huggingface.co/ByteMaster01/NL2SQL/resolve/main/unsloth.Q8_0.gguf",
            MODEL_PATH
        )
        logger.info("\nModel download complete!")


def start_llama_server():
    """Start the llama.cpp server with the downloaded model.

    Launches ``llama_cpp.server`` as a detached subprocess on port 8080.

    Raises:
        Exception: re-raised if the subprocess could not be spawned.
    """
    try:
        logger.info("Starting llama.cpp server...")
        # List-form argv (shell=False): no shell-injection surface.
        subprocess.Popen([
            "python", "-m", "llama_cpp.server",
            "--model", MODEL_PATH,
            "--port", "8080"
        ])
        logger.info("Server started successfully!")
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        raise


def convert_line_parsed_to_mongo(line_parsed: str, schema: dict) -> str:
    """Convert the model's line-based output into an executable MongoDB query.

    Args:
        line_parsed: Line-based query text emitted by the model.
        schema: Parsed schema dict; the first collection's name is used.

    Returns:
        The reconstructed MongoDB query string, or "" on any failure
        (best-effort: the caller still shows the raw model output).
    """
    try:
        modified_query = parse_line_based_query(line_parsed)
        collection_name = schema["collections"][0]["name"]
        in2out, _ = build_schema_maps(schema)
        reconstructed_query = convert_modified_to_actual_code_string(
            modified_query, in2out, collection_name
        )
        return reconstructed_query
    except Exception as e:
        logger.error(f"Error converting line parsed to MongoDB query: {e}")
        return ""


def process_query(schema_text: str, nl_query: str, additional_info: str = "") -> list[str]:
    """Turn a natural-language query into a MongoDB query via the local LLM.

    Args:
        schema_text: MongoDB schema as a JSON string.
        nl_query: The user's natural-language query.
        additional_info: Optional extra context (timestamps, etc.).

    Returns:
        A two-element list ``[mongo_query, line_based_output]`` matching the
        two Gradio output components; on error both elements carry the
        error message.
    """
    try:
        # Parse schema from string to dict
        schema = json.loads(schema_text)

        # Convert schema to line-based format expected by the prompt template
        line_based_schema = schema_to_line_based(schema)

        # Format prompt with line-based schema
        prompt = MODEL_PROMPT_V3.format(
            schema=line_based_schema,
            natural_language_query=nl_query,
            additional_info=additional_info
        )

        # Prepare request payload (low temperature: deterministic-ish output)
        payload = {
            "slot_id": 0,
            "temperature": 0.1,
            "n_keep": -1,
            "cache_prompt": True,
            "messages": [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT_V3,
                },
                {
                    "role": "user",
                    "content": prompt
                },
            ]
        }

        # Make request to llama.cpp server; explicit timeout so a hung
        # server cannot block the UI forever.
        response = requests.post(
            LLAMA_SERVER_URL, json=payload, timeout=REQUEST_TIMEOUT_SECONDS
        )
        response.raise_for_status()

        # Extract output from response
        output = response.json()["choices"][0]["message"]["content"].strip()
        logger.info(f"Model output: {output}")

        # Convert line-based output to MongoDB query
        mongo_query = convert_line_parsed_to_mongo(output, schema)

        return [
            mongo_query,
            output
        ]

    except Exception as e:
        logger.error(f"Error processing query: {e}")
        error_msg = f"Error: {str(e)}"
        # Exactly two elements: the interface declares two output components
        # (the original three-element error return broke the Gradio UI).
        return [error_msg, error_msg]


def create_interface():
    """Build and return the Gradio Interface wiring inputs to process_query."""
    iface = gr.Interface(
        fn=process_query,
        inputs=[
            gr.Textbox(
                label="Schema (JSON format)",
                placeholder="Enter your MongoDB schema in JSON format...",
                lines=10
            ),
            gr.Textbox(
                label="Natural Language Query",
                placeholder="Enter your query in natural language..."
            ),
            gr.Textbox(
                label="Additional Info (Optional)",
                placeholder="Enter any additional context (timestamps, etc)..."
            ),
        ],
        outputs=[
            gr.Code(label="MongoDB Query", language="javascript", lines=1),
            gr.Textbox(label="Line-based Query")
        ],
        title="Natural Language to MongoDB Query Converter",
        description="Convert natural language queries to MongoDB queries based on your schema.",
        examples=[
            [
                '''{ "collections": [{ "name": "events", "document": { "properties": { "timestamp": {"bsonType": "int"}, "severity": {"bsonType": "int"}, "location": { "bsonType": "object", "properties": { "lat": {"bsonType": "double"}, "lon": {"bsonType": "double"} } } } } }]}''',
                "Find all events with severity greater than 5",
                ""
            ],
            [
                '''{ "collections": [{ "name": "vehicles", "document": { "properties": { "timestamp": {"bsonType": "int"}, "vehicle_details": { "bsonType": "object", "properties": { "license_plate": {"bsonType": "string"}, "make": {"bsonType": "string"}, "model": {"bsonType": "string"}, "year": {"bsonType": "int"}, "color": {"bsonType": "string"} } }, "speed": {"bsonType": "double"}, "location": { "bsonType": "object", "properties": { "lat": {"bsonType": "double"}, "lon": {"bsonType": "double"} } } } } }]}''',
                "Find red Toyota vehicles manufactured after 2020 with speed above 60",
                ""
            ],
            [
                '''{ "collections": [{ "name": "sensors", "document": { "properties": { "sensor_id": {"bsonType": "string"}, "readings": { "bsonType": "object", "properties": { "temperature": {"bsonType": "double"}, "humidity": {"bsonType": "double"}, "pressure": {"bsonType": "double"} } }, "timestamp": {"bsonType": "date"}, "status": {"bsonType": "string"} } } }]}''',
                "Find active sensors with temperature above 30 degrees in the last one day",
                '''current date is 21 january 2025'''
            ],
            [
                '''{ "collections": [{ "name": "orders", "document": { "properties": { "order_id": {"bsonType": "string"}, "customer": { "bsonType": "object", "properties": { "id": {"bsonType": "string"}, "name": {"bsonType": "string"}, "email": {"bsonType": "string"} } }, "items": { "bsonType": "array", "items": { "bsonType": "object", "properties": { "product_id": {"bsonType": "string"}, "quantity": {"bsonType": "int"}, "price": {"bsonType": "double"} } } }, "total_amount": {"bsonType": "double"}, "status": {"bsonType": "string"}, "created_at": {"bsonType": "int"} } } }]}''',
                "Find orders with total amount greater than $100 that contain more than 3 items and were created in the last 24 hours",
                '''{"current_time": 1685890800, "last_24_hours": 1685804400}'''
            ]
        ],
        cache_examples=False,
    )
    return iface


if __name__ == "__main__":
    # Download the model
    download_model()

    # Start the llama.cpp server
    start_llama_server()

    # Give the server a moment to start
    time.sleep(5)

    # Launch the Gradio interface
    logger.info("Starting Gradio interface...")
    iface = create_interface()
    iface.launch()