import json
import os
import subprocess
import time

import gradio as gr
import requests
import wget
from loguru import logger

from configs.prompt_config import SYSTEM_PROMPT_V3, MODEL_PROMPT_V3
from data_utils.base_conversion_utils import (
    build_schema_maps,
    convert_modified_to_actual_code_string,
)
from data_utils.line_based_parsing import parse_line_based_query
from data_utils.schema_utils import schema_to_line_based

LLAMA_SERVER_URL = "http://127.0.0.1:8080/v1/chat/completions"
MODEL_PATH = "./models/unsloth.Q8_0.gguf"


def download_model():
    """Download the model weights if they don't already exist."""
    os.makedirs("./models", exist_ok=True)
    if not os.path.exists(MODEL_PATH):
        logger.info("Downloading model weights...")
        wget.download(
            "https://huggingface.co/ByteMaster01/NL2SQL/resolve/main/unsloth.Q8_0.gguf",
            MODEL_PATH,
        )
        logger.info("\nModel download complete!")

def start_llama_server():
    """Start the llama.cpp server with the downloaded model."""
    try:
        logger.info("Starting llama.cpp server...")
        subprocess.Popen([
            "python", "-m", "llama_cpp.server",
            "--model", MODEL_PATH,
            "--port", "8080",
        ])
        # Popen returns immediately; the server may still be loading the model.
        logger.info("Server process launched!")
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        raise
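
# Optional readiness check (a sketch, not part of the original app): polls the
# OpenAI-compatible /v1/models endpoint that llama_cpp.server exposes until the
# model has loaded, instead of sleeping for a fixed interval.
def wait_for_server(timeout: float = 60.0, interval: float = 1.0) -> bool:
    """Return True once the server responds, False if the timeout elapses."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            # Any successful response means the server is up and serving.
            resp = requests.get("http://127.0.0.1:8080/v1/models", timeout=2)
            resp.raise_for_status()
            return True
        except requests.RequestException:
            time.sleep(interval)
    return False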

def convert_line_parsed_to_mongo(line_parsed: str, schema: dict) -> str:
    """Convert the model's line-based output back into a MongoDB query string."""
    try:
        modified_query = parse_line_based_query(line_parsed)
        collection_name = schema["collections"][0]["name"]
        in2out, _ = build_schema_maps(schema)
        reconstructed_query = convert_modified_to_actual_code_string(
            modified_query, in2out, collection_name
        )
        return reconstructed_query
    except Exception as e:
        logger.error(f"Error converting line-based output to MongoDB query: {e}")
        return ""

def process_query(schema_text: str, nl_query: str, additional_info: str = "") -> list:
    """Run the full pipeline: schema + natural language query -> MongoDB query."""
    try:
        # Parse the schema from a JSON string into a dict
        schema = json.loads(schema_text)

        # Convert the schema to the line-based format the model expects
        line_based_schema = schema_to_line_based(schema)

        # Format the prompt with the line-based schema
        prompt = MODEL_PROMPT_V3.format(
            schema=line_based_schema,
            natural_language_query=nl_query,
            additional_info=additional_info,
        )

        # Prepare the request payload
        payload = {
            "slot_id": 0,
            "temperature": 0.1,
            "n_keep": -1,
            "cache_prompt": True,
            "messages": [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT_V3,
                },
                {
                    "role": "user",
                    "content": prompt,
                },
            ],
        }

        # Send the request to the llama.cpp server
        response = requests.post(LLAMA_SERVER_URL, json=payload)
        response.raise_for_status()

        # Extract the model output from the response
        output = response.json()["choices"][0]["message"]["content"].strip()
        logger.info(f"Model output: {output}")

        # Convert the line-based output into a MongoDB query
        mongo_query = convert_line_parsed_to_mongo(output, schema)
        return [mongo_query, output]
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        error_msg = f"Error: {e}"
        # The interface has exactly two outputs, so return two values
        return [error_msg, error_msg]
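
# Minimal smoke test for process_query (a sketch; assumes the llama.cpp server
# is already running). Reuses the "events" schema from the interface examples.
def _smoke_test():
    schema = json.dumps({
        "collections": [{
            "name": "events",
            "document": {
                "properties": {
                    "timestamp": {"bsonType": "int"},
                    "severity": {"bsonType": "int"},
                }
            },
        }]
    })
    mongo_query, line_based = process_query(
        schema, "Find all events with severity greater than 5"
    )
    logger.info(f"MongoDB query: {mongo_query}")
    logger.info(f"Line-based form: {line_based}")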

def create_interface():
    """Build the Gradio interface."""
    iface = gr.Interface(
        fn=process_query,
        inputs=[
            gr.Textbox(
                label="Schema (JSON format)",
                placeholder="Enter your MongoDB schema in JSON format...",
                lines=10,
            ),
            gr.Textbox(
                label="Natural Language Query",
                placeholder="Enter your query in natural language...",
            ),
            gr.Textbox(
                label="Additional Info (Optional)",
                placeholder="Enter any additional context (timestamps, etc.)...",
            ),
        ],
        outputs=[
            gr.Code(label="MongoDB Query", language="javascript", lines=1),
            gr.Textbox(label="Line-based Query"),
        ],
        title="Natural Language to MongoDB Query Converter",
        description="Convert natural language queries to MongoDB queries based on your schema.",
        examples=[
            [
                '''{
    "collections": [{
        "name": "events",
        "document": {
            "properties": {
                "timestamp": {"bsonType": "int"},
                "severity": {"bsonType": "int"},
                "location": {
                    "bsonType": "object",
                    "properties": {
                        "lat": {"bsonType": "double"},
                        "lon": {"bsonType": "double"}
                    }
                }
            }
        }
    }]
}''',
                "Find all events with severity greater than 5",
                "",
            ],
            [
                '''{
    "collections": [{
        "name": "vehicles",
        "document": {
            "properties": {
                "timestamp": {"bsonType": "int"},
                "vehicle_details": {
                    "bsonType": "object",
                    "properties": {
                        "license_plate": {"bsonType": "string"},
                        "make": {"bsonType": "string"},
                        "model": {"bsonType": "string"},
                        "year": {"bsonType": "int"},
                        "color": {"bsonType": "string"}
                    }
                },
                "speed": {"bsonType": "double"},
                "location": {
                    "bsonType": "object",
                    "properties": {
                        "lat": {"bsonType": "double"},
                        "lon": {"bsonType": "double"}
                    }
                }
            }
        }
    }]
}''',
                "Find red Toyota vehicles manufactured after 2020 with speed above 60",
                "",
            ],
            [
                '''{
    "collections": [{
        "name": "sensors",
        "document": {
            "properties": {
                "sensor_id": {"bsonType": "string"},
                "readings": {
                    "bsonType": "object",
                    "properties": {
                        "temperature": {"bsonType": "double"},
                        "humidity": {"bsonType": "double"},
                        "pressure": {"bsonType": "double"}
                    }
                },
                "timestamp": {"bsonType": "date"},
                "status": {"bsonType": "string"}
            }
        }
    }]
}''',
                "Find active sensors with temperature above 30 degrees in the last one day",
                "current date is 21 january 2025",
            ],
            [
                '''{
    "collections": [{
        "name": "orders",
        "document": {
            "properties": {
                "order_id": {"bsonType": "string"},
                "customer": {
                    "bsonType": "object",
                    "properties": {
                        "id": {"bsonType": "string"},
                        "name": {"bsonType": "string"},
                        "email": {"bsonType": "string"}
                    }
                },
                "items": {
                    "bsonType": "array",
                    "items": {
                        "bsonType": "object",
                        "properties": {
                            "product_id": {"bsonType": "string"},
                            "quantity": {"bsonType": "int"},
                            "price": {"bsonType": "double"}
                        }
                    }
                },
                "total_amount": {"bsonType": "double"},
                "status": {"bsonType": "string"},
                "created_at": {"bsonType": "int"}
            }
        }
    }]
}''',
                "Find orders with total amount greater than $100 that contain more than 3 items and were created in the last 24 hours",
                '''{"current_time": 1685890800, "last_24_hours": 1685804400}''',
            ],
        ],
        cache_examples=False,
    )
    return iface

if __name__ == "__main__":
    # Download the model weights if needed
    download_model()

    # Start the llama.cpp server
    start_llama_server()

    # Give the server a moment to start (wait_for_server above is a more
    # robust alternative to a fixed sleep)
    time.sleep(5)

    # Launch the Gradio interface
    logger.info("Starting Gradio interface...")
    iface = create_interface()
    iface.launch()
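
# Example of calling the running interface programmatically with gradio_client
# (a sketch; "http://127.0.0.1:7860" and api_name="/predict" are the Gradio
# defaults for a locally launched gr.Interface, adjust if your setup differs):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   mongo_query, line_based = client.predict(
#       schema_json,  # schema as a JSON string
#       "Find all events with severity greater than 5",
#       "",           # additional info (optional)
#       api_name="/predict",
#   )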