File size: 4,009 Bytes
7794a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import mysql.connector
from mysql.connector import Error
import requests
import json
import os

def generate_sql_query(natural_language_query, schema_info, space_url, timeout=30):
    """Generate a SQL query from a natural-language request via an HTTP endpoint.

    Builds a prompt from the table schema and the user's request, POSTs it as
    JSON to ``space_url`` and extracts the ``generated_text`` field from the
    JSON reply.

    Args:
        natural_language_query: Plain-language description of the data wanted.
        schema_info: Text describing the table schema (e.g. CREATE TABLE DDL).
        space_url: URL the prompt is POSTed to.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled endpoint cannot hang the caller indefinitely.

    Returns:
        The generated text with surrounding whitespace stripped (an empty
        string when the reply carries no ``generated_text``), or ``None`` when
        the request or response handling failed. Failures are reported on
        stdout as ``API Error: ...`` rather than raised.
    """
    # Construct a more structured prompt
    prompt = f"""Given this SQL table schema:
{schema_info}

Write a SQL query that will:
{natural_language_query}

The query should be valid MySQL syntax and include only the SELECT statement."""

    payload = {
        "inputs": prompt,
        "options": {
            "use_cache": False,
        },
    }

    # This function is the error boundary for the remote call: callers rely on
    # getting None instead of an exception, hence the broad catch below.
    try:
        response = requests.post(space_url, json=payload, timeout=timeout)
        if response.status_code != 200:
            raise RuntimeError(f"API request failed: {response.text}")

        data = response.json()
        # Text-generation endpoints commonly reply with a list of result
        # objects rather than a bare object; use the first entry in that case.
        if isinstance(data, list):
            data = data[0] if data else {}
        return data.get('generated_text', '').strip()
    except Exception as e:
        print(f"API Error: {str(e)}")
        return None

def _is_select_statement(sql_query):
    """Return True when *sql_query* begins with SELECT (case-insensitive)."""
    return sql_query.lstrip().lower().startswith("select")


def _print_results(cursor, records):
    """Print fetched rows as a pipe-separated table, or a notice when empty."""
    if not records:
        print("\nNo results found.")
        return

    print("\nResults:")
    # Column names come from the cursor's description of the last result set.
    header = " | ".join(desc[0] for desc in cursor.description)
    print(header)
    print("-" * (len(header) + 10))
    for row in records:
        print(" | ".join(str(val) for val in row))


def main():
    """Run an interactive loop that turns questions into SQL and executes them.

    Connects to the MySQL database, then repeatedly reads a question from
    stdin, asks ``generate_sql_query`` for a matching query, executes it and
    prints the rows. The loop ends when the user types ``exit`` or when stdin
    is closed.

    Connection settings default to the local development values and can be
    overridden with the environment variables ``PIZZA_DB_HOST``,
    ``PIZZA_DB_NAME``, ``PIZZA_DB_USER``, ``PIZZA_DB_PASSWORD`` and
    ``PIZZA_DB_PORT``.

    Only statements that start with ``SELECT`` are executed; the generated SQL
    comes from a remote model and must not be trusted to be read-only.
    """
    # Initialised up front so the cleanup in ``finally`` can test them directly.
    connection = None
    cursor = None
    try:
        # Define the Hugging Face Space URL
        # NOTE(review): this looks like the Space's web page rather than an
        # inference endpoint -- confirm it accepts the JSON POST sent by
        # generate_sql_query.
        space_url = "https://huggingface.co/spaces/nileshhanotia/sql"

        # Define your schema information
        schema_info = """
        CREATE TABLE sales (
          pizza_id DECIMAL(8,2) PRIMARY KEY,
          order_id DECIMAL(8,2),
          pizza_name_id VARCHAR(14),
          quantity DECIMAL(4,2),
          order_date DATE,
          order_time VARCHAR(8),
          unit_price DECIMAL(5,2),
          total_price DECIMAL(5,2),
          pizza_size VARCHAR(3),
          pizza_category VARCHAR(7),
          pizza_ingredients VARCHAR(97),
          pizza_name VARCHAR(42)
        );
        """

        # Establish connection to the database
        connection = mysql.connector.connect(
            host=os.environ.get("PIZZA_DB_HOST", "localhost"),
            database=os.environ.get("PIZZA_DB_NAME", "pizza"),
            user=os.environ.get("PIZZA_DB_USER", "root"),
            password=os.environ.get("PIZZA_DB_PASSWORD", "root"),
            port=int(os.environ.get("PIZZA_DB_PORT", "8889")),
        )

        if not connection.is_connected():
            print("\nCould not connect to the database.")
            return

        cursor = connection.cursor()
        print("Database connected successfully!")

        while True:
            try:
                # Get user input
                print("\nEnter your question (or 'exit' to quit):")
                natural_language_query = input("> ").strip()

                if natural_language_query.lower() == 'exit':
                    break
                if not natural_language_query:
                    # Nothing to ask; prompt again instead of calling the API.
                    continue

                # Generate and execute query
                sql_query = generate_sql_query(natural_language_query, schema_info, space_url)
                if not sql_query:
                    print("\nNo SQL query was generated.")
                    continue

                if not _is_select_statement(sql_query):
                    print("\nRefusing to execute: only SELECT statements are allowed.")
                    continue

                print(f"\nExecuting SQL Query:\n{sql_query}")
                cursor.execute(sql_query)
                _print_results(cursor, cursor.fetchall())

            except EOFError:
                # stdin was closed: input() would fail on every iteration, so
                # leave the loop instead of retrying forever.
                print("\nInput closed; exiting.")
                break
            except KeyboardInterrupt:
                print("\nOperation cancelled by user.")
                continue
            except Exception as e:
                # Keep the session alive after a failed query.
                print(f"\nError: {str(e)}")
                continue

    except Error as e:
        print(f"\nDatabase error: {str(e)}")
    except Exception as e:
        print(f"\nApplication error: {str(e)}")
    finally:
        if connection is not None and connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()
            print("\nMySQL connection closed.")

# Start the interactive session only when this file is run as a script,
# so importing the module has no side effects.
if __name__ == "__main__":
    main()