# sql/app.py — published by nileshhanotia in the Hugging Face Space "sql"
# (verified commit 7794a8e, "Create app.py"; 4.01 kB).
import mysql.connector
from mysql.connector import Error
import requests
import json
import os
def generate_sql_query(natural_language_query, schema_info, space_url, *, timeout=60):
    """Ask a Hugging Face endpoint to turn a question into a MySQL SELECT query.

    Args:
        natural_language_query: The user's question in plain language.
        schema_info: ``CREATE TABLE`` text describing what the model may query.
        space_url: Endpoint that accepts a JSON POST of
            ``{"inputs": ..., "options": ...}``.
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        The ``generated_text`` value from the response with surrounding
        whitespace stripped (an empty string if the field is absent), or
        ``None`` when the request fails, the status is not 200, or the body
        is not the expected JSON shape. Failures are printed, not raised.
    """
    # Build a structured prompt: the schema first, then the request, then
    # the constraints on the output.
    prompt = f"""Given this SQL table schema:
{schema_info}
Write a SQL query that will:
{natural_language_query}
The query should be valid MySQL syntax and include only the SELECT statement."""
    payload = {
        "inputs": prompt,
        "options": {
            "use_cache": False,
        },
    }

    try:
        response = requests.post(space_url, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # Connection errors, timeouts, invalid URLs, ...
        print(f"API Error: {str(e)}")
        return None

    if response.status_code != 200:
        print(f"API Error: API request failed: {response.text}")
        return None

    try:
        data = response.json()
    except ValueError as e:
        # The body is not JSON (for example, an HTML page).
        print(f"API Error: {str(e)}")
        return None

    # Text-generation endpoints may answer with a list of results
    # ([{"generated_text": ...}]) rather than a single object; use the first.
    if isinstance(data, list):
        data = data[0] if data else {}
    generated = data.get("generated_text", "") if isinstance(data, dict) else None
    if not isinstance(generated, str):
        print(f"API Error: unexpected response format: {data!r}")
        return None
    return generated.strip()
# Schema text sent to the model so it knows the table and column names.
# Kept at module level so the triple-quoted text needs no extra indentation.
_SALES_SCHEMA = """
CREATE TABLE sales (
pizza_id DECIMAL(8,2) PRIMARY KEY,
order_id DECIMAL(8,2),
pizza_name_id VARCHAR(14),
quantity DECIMAL(4,2),
order_date DATE,
order_time VARCHAR(8),
unit_price DECIMAL(5,2),
total_price DECIMAL(5,2),
pizza_size VARCHAR(3),
pizza_category VARCHAR(7),
pizza_ingredients VARCHAR(97),
pizza_name VARCHAR(42)
);
"""

# Default Hugging Face Space URL; override with the SQL_SPACE_URL env var.
# NOTE(review): this is the Space's web-page address. A JSON POST to it
# probably does not reach the model — confirm the Space's API endpoint.
_DEFAULT_SPACE_URL = "https://huggingface.co/spaces/nileshhanotia/sql"


def _is_single_select(sql_query):
    """Return True if ``sql_query`` looks like exactly one SELECT statement.

    This is a coarse guard against running model output that would change
    data or files (DROP, UPDATE, stacked statements, SELECT ... INTO
    OUTFILE/DUMPFILE). It is not a SQL parser: a ``;`` inside a string
    literal is rejected too. Connecting with a read-only MySQL account is
    the real protection.
    """
    statement = sql_query.strip().rstrip(";").rstrip()
    upper = statement.upper()
    return (
        upper.startswith("SELECT")
        and ";" not in statement
        and "OUTFILE" not in upper
        and "DUMPFILE" not in upper
    )


def _run_question(cursor, question, schema_info, space_url):
    """Generate SQL for one question, run it if it is a lone SELECT, print rows."""
    sql_query = generate_sql_query(question, schema_info, space_url)
    if not sql_query:
        print("\nNo SQL query was generated.")
        return
    if not _is_single_select(sql_query):
        print(
            "\nRefusing to run generated SQL "
            f"(only a single SELECT statement is allowed):\n{sql_query}"
        )
        return

    print(f"\nExecuting SQL Query:\n{sql_query}")
    cursor.execute(sql_query)
    records = cursor.fetchall()
    if not records:
        print("\nNo results found.")
        return

    print("\nResults:")
    header = " | ".join(desc[0] for desc in cursor.description)
    print(header)
    # The underline is a little wider than the header row.
    print("-" * (len(header) + 10))
    for row in records:
        print(" | ".join(str(val) for val in row))


def main():
    """Run the interactive question -> SQL -> results loop against MySQL.

    Connects to MySQL using the ``PIZZA_DB_HOST``, ``PIZZA_DB_NAME``,
    ``PIZZA_DB_USER``, ``PIZZA_DB_PASSWORD`` and ``PIZZA_DB_PORT`` environment
    variables, falling back to a local ``pizza`` database on port 8889 as
    root/root. It then repeatedly reads a question, asks
    ``generate_sql_query`` for SQL, and prints the result rows.

    Typing ``exit`` or closing stdin ends the loop. Ctrl-C cancels only the
    current question. Errors are printed rather than raised, and an open
    connection is always closed on the way out.
    """
    space_url = os.environ.get("SQL_SPACE_URL", _DEFAULT_SPACE_URL)
    # Pre-bind both names so the finally clause never sees an unbound name,
    # even if connect() or cursor() fails.
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(
            host=os.environ.get("PIZZA_DB_HOST", "localhost"),
            database=os.environ.get("PIZZA_DB_NAME", "pizza"),
            user=os.environ.get("PIZZA_DB_USER", "root"),
            password=os.environ.get("PIZZA_DB_PASSWORD", "root"),
            port=int(os.environ.get("PIZZA_DB_PORT", "8889")),
        )
        if not connection.is_connected():
            return
        cursor = connection.cursor()
        print("Database connected successfully!")

        while True:
            try:
                print("\nEnter your question (or 'exit' to quit):")
                try:
                    natural_language_query = input("> ").strip()
                except EOFError:
                    # stdin is closed (Ctrl-D or piped input exhausted).
                    # input() would fail the same way on every iteration,
                    # so leave the loop instead of spinning.
                    break
                if natural_language_query.lower() == "exit":
                    break
                if not natural_language_query:
                    # Don't send an empty question to the API.
                    continue
                _run_question(cursor, natural_language_query, _SALES_SCHEMA, space_url)
            except KeyboardInterrupt:
                print("\nOperation cancelled by user.")
            except Exception as e:
                # Interactive boundary: report the failure (API, SQL, ...)
                # and keep prompting.
                print(f"\nError: {str(e)}")
    except Error as e:
        print(f"\nDatabase error: {str(e)}")
    except Exception as e:
        print(f"\nApplication error: {str(e)}")
    finally:
        if connection is not None and connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()
            print("\nMySQL connection closed.")
if __name__ == "__main__":
    # Start the interactive loop only when run as a script, not on import.
    main()