# Source: Hugging Face Space page (Space status shown as "Running").
import mysql.connector | |
from mysql.connector import Error | |
import requests | |
import json | |
import os | |
def _extract_generated_text(data):
    """Return the ``generated_text`` string from a decoded JSON response, or ''."""
    # Accept both a bare object ``{"generated_text": ...}`` and the
    # list-wrapped form ``[{"generated_text": ...}]`` used by the Hugging Face
    # Inference API for text generation; anything else yields ''.
    if isinstance(data, list):
        data = data[0] if data else {}
    if not isinstance(data, dict):
        return ''
    text = data.get('generated_text', '')
    return text if isinstance(text, str) else ''


def _strip_code_fences(text):
    """Strip whitespace and a surrounding Markdown code fence (``` or ```sql)."""
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line (which may carry a language tag) and a
        # closing fence line if present; models often wrap SQL this way.
        lines = text.splitlines()[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        text = "\n".join(lines).strip()
    return text


def generate_sql_query(natural_language_query, schema_info, space_url, timeout=30):
    """Generate a MySQL SELECT statement from a natural-language question.

    Builds a prompt from *schema_info* and *natural_language_query*, POSTs it
    as JSON to *space_url*, and returns the generated SQL text.

    Args:
        natural_language_query: The user's question in plain language.
        schema_info: ``CREATE TABLE`` DDL describing the table(s) to query.
        space_url: URL of the text-generation endpoint to POST to.
        timeout: Seconds to wait for the HTTP request. Without a timeout
            ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        The generated SQL with surrounding whitespace and Markdown code fences
        removed; an empty string if the response carried no text; or ``None``
        if the request failed, returned a non-200 status, or was not JSON.
    """
    # Construct a more structured prompt
    prompt = f"""Given this SQL table schema:
{schema_info}
Write a SQL query that will:
{natural_language_query}
The query should be valid MySQL syntax and include only the SELECT statement."""
    payload = {
        "inputs": prompt,
        "options": {
            "use_cache": False
        }
    }
    try:
        response = requests.post(space_url, json=payload, timeout=timeout)
    except requests.RequestException as e:
        print(f"API Error: {str(e)}")
        return None
    if response.status_code != 200:
        print(f"API Error: API request failed: {response.text}")
        return None
    try:
        data = response.json()
    except ValueError as e:
        # requests' JSONDecodeError subclasses ValueError (an HTML page, for
        # example, is not JSON).
        print(f"API Error: {str(e)}")
        return None
    return _strip_code_fences(_extract_generated_text(data))
def _is_single_select(sql):
    """Return True if *sql* looks like exactly one SELECT statement.

    This is a deliberately conservative guard for model-generated SQL. It
    rejects anything whose first keyword is not SELECT, and any embedded ``;``
    after trailing semicolons are removed. Because of that, a ``;`` inside a
    string literal is also rejected. It is not a substitute for connecting
    with a read-only database account.
    """
    statement = sql.strip().rstrip(";").strip()
    if not statement or ";" in statement:
        return False
    return statement.split(None, 1)[0].upper() == "SELECT"


def _print_results(cursor, records):
    """Print *records* as a pipe-separated table headed by the cursor's column names."""
    if not records:
        print("\nNo results found.")
        return
    print("\nResults:")
    header = " | ".join(desc[0] for desc in cursor.description)
    print(header)
    print("-" * (len(header) + 10))
    for row in records:
        print(" | ".join(str(val) for val in row))


def _run_query_loop(cursor, schema_info, space_url):
    """Prompt for questions until 'exit' or end of input, executing each as SQL."""
    while True:
        try:
            print("\nEnter your question (or 'exit' to quit):")
            natural_language_query = input("> ").strip()
        except EOFError:
            # stdin is closed (Ctrl+D or piped input exhausted). input() would
            # raise on every later iteration too, so stop instead of spinning.
            break
        except KeyboardInterrupt:
            print("\nOperation cancelled by user.")
            continue
        if natural_language_query.lower() == 'exit':
            break
        if not natural_language_query:
            continue
        try:
            sql_query = generate_sql_query(natural_language_query, schema_info, space_url)
            if sql_query is None:
                # generate_sql_query has already printed the API error.
                continue
            if not sql_query:
                print("\nNo SQL query was generated.")
                continue
            print(f"\nExecuting SQL Query:\n{sql_query}")
            # The SQL comes from a model and is run with this connection's
            # privileges, so enforce the "SELECT only" rule the prompt requests.
            if not _is_single_select(sql_query):
                print("\nRefusing to execute: only a single SELECT statement is allowed.")
                continue
            cursor.execute(sql_query)
            records = cursor.fetchall()
            _print_results(cursor, records)
        except KeyboardInterrupt:
            print("\nOperation cancelled by user.")
        except Exception as e:
            print(f"\nError: {str(e)}")


def main():
    """Run an interactive natural-language-to-SQL session against MySQL.

    Connects to MySQL and reads questions from stdin until 'exit' or end of
    input. Each question is turned into SQL by ``generate_sql_query``, and each
    single-SELECT result is printed as a table.

    Configuration comes from environment variables, falling back to the
    previous hard-coded values: HF_SPACE_URL, DB_HOST, DB_NAME, DB_USER,
    DB_PASSWORD, DB_PORT.

    Database errors during connection are reported and end the session.
    Errors during a single question are reported and the loop continues.
    The cursor and connection are always closed on exit.
    """
    # NOTE(review): this default is the Space's web-page URL. A POST to it
    # most likely returns HTML rather than JSON. Set HF_SPACE_URL to the
    # Space's actual inference endpoint, and confirm the correct URL.
    space_url = os.environ.get("HF_SPACE_URL", "https://huggingface.co/spaces/nileshhanotia/sql")
    # Define your schema information
    schema_info = """
CREATE TABLE sales (
pizza_id DECIMAL(8,2) PRIMARY KEY,
order_id DECIMAL(8,2),
pizza_name_id VARCHAR(14),
quantity DECIMAL(4,2),
order_date DATE,
order_time VARCHAR(8),
unit_price DECIMAL(5,2),
total_price DECIMAL(5,2),
pizza_size VARCHAR(3),
pizza_category VARCHAR(7),
pizza_ingredients VARCHAR(97),
pizza_name VARCHAR(42)
);
"""
    # Initialised up front so the finally block never hits an unbound name.
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(
            host=os.environ.get("DB_HOST", "localhost"),
            database=os.environ.get("DB_NAME", "pizza"),
            user=os.environ.get("DB_USER", "root"),
            password=os.environ.get("DB_PASSWORD", "root"),
            port=int(os.environ.get("DB_PORT", "8889")),
        )
        if connection.is_connected():
            cursor = connection.cursor()
            print("Database connected successfully!")
            _run_query_loop(cursor, schema_info, space_url)
    except Error as e:
        print(f"\nDatabase error: {str(e)}")
    except Exception as e:
        print(f"\nApplication error: {str(e)}")
    finally:
        if cursor is not None:
            cursor.close()
        if connection is not None and connection.is_connected():
            connection.close()
            print("\nMySQL connection closed.")
if __name__ == "__main__":
    # Start the interactive session only when run as a script, not on import.
    main()