LeonceNsh committed
Commit 8cb3a33 · verified · 1 Parent(s): 0fd7668

Update app.py

Files changed (1):
  1. app.py +84 -94
app.py CHANGED
@@ -1,109 +1,99 @@
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-import gradio as gr
-
-
-def generate_prompt(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
-    """
-    Generates the prompt by reading the prompt template and table metadata,
-    then formatting them with the user's question.
-    """
-    try:
-        with open(prompt_file, "r") as f:
-            prompt = f.read()
-    except FileNotFoundError:
-        return "Error: prompt.md file not found."
-
-    try:
-        with open(metadata_file, "r") as f:
-            table_metadata_string = f.read()
-    except FileNotFoundError:
-        return "Error: metadata.sql file not found."
-
-    prompt = prompt.format(
-        user_question=question, table_metadata_string=table_metadata_string
-    )
-    return prompt
-
-
-def get_tokenizer_model(model_name):
-    """
-    Loads the tokenizer and model from the specified model repository.
-    """
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        trust_remote_code=True,  # Set to True if the model uses custom code
-        torch_dtype=torch.float16,
-        device_map="auto",  # Automatically maps the model to available devices
-        use_cache=True,
-    )
-    return tokenizer, model
-
-
-# Load the tokenizer and model once when the script starts
-model_name = "defog/sqlcoder-7b-2"  # Replace with your model name
-print("Loading model and tokenizer...")
-tokenizer, model = get_tokenizer_model(model_name)
-print("Model and tokenizer loaded successfully.")
-
-# Initialize the text generation pipeline
-text_gen_pipeline = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    max_new_tokens=300,
-    do_sample=False,  # Disable sampling for deterministic output
-    return_full_text=False,
-    num_beams=5,  # Use beam search for better quality
-)
-
-
-def run_inference_gradio(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
-    """
-    Generates an SQL query based on the user's natural language question.
-    """
-    if not question.strip():
-        return "Please enter a valid question."
-
-    prompt = generate_prompt(question, prompt_file, metadata_file)
-
-    if prompt.startswith("Error:"):
-        return prompt  # Return the error message if files are missing
-
-    eos_token_id = tokenizer.eos_token_id
-    try:
-        generated = text_gen_pipeline(
-            prompt,
-            num_return_sequences=1,
-            eos_token_id=eos_token_id,
-            pad_token_id=eos_token_id,
-        )
-    except Exception as e:
-        return f"Error during model inference: {str(e)}"
-
-    generated_text = generated[0]["generated_text"]
-
-    # Extract the SQL query from the generated text
-    sql_query = generated_text.split(";")[0].split("```")[0].strip() + ";"
-    return sql_query
-
-
-# Define the Gradio interface
-iface = gr.Interface(
-    fn=run_inference_gradio,
-    inputs=gr.Textbox(
-        lines=4,
-        placeholder="Enter your natural language question here...",
-        label="Question"
-    ),
-    outputs=gr.Textbox(label="Generated SQL Query"),
-    title="Text-to-SQL Generator",
-    description=(
-        "Enter a natural language question related to your database, and this tool "
-        "will generate the corresponding SQL query. Ensure that 'prompt.md' and "
-        "'metadata.sql' are correctly set up in the application directory."
-    ),
-    examples=[
-        ["Do we get more sales from customers in New York compared to customers in San Francisco? Give me the total sales for each city, and the difference between the two."]
-    ],
-    allow_flagging="never"
-)
-
-if __name__ == "__main__":
-    iface.launch()
+import gradio as gr
+import torch
+import pyperclip
+import openai
+import os
+import pandas as pd
+from sqlalchemy import create_engine, inspect
+from llama_index.legacy import (
+    VectorStoreIndex,
+    SQLDatabase,
+    ServiceContext,
+)
+from llama_index.legacy.indices.struct_store import NLSQLTableQueryEngine
+from llama_index.legacy.llms import OpenAI
+import sqlite3
+
+# Set up the OpenAI API key (placeholder: supply a real key before launching)
+os.environ['OPENAI_API_KEY'] = "YOUR_OPENAI_API_KEY"
+
+
+# Function to load the database and the LLM
+def load_db_llm():
+    engine = create_engine("sqlite:///gov-contracts.db")
+    sql_database = SQLDatabase(engine)
+    llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo-1106")
+    service_context = ServiceContext.from_defaults(llm=llm, embed_model="local")
+    return sql_database, service_context, engine
+
+
+# Load the LLM and database context once at startup
+sql_database, service_context, engine = load_db_llm()
+query_engine = NLSQLTableQueryEngine(
+    sql_database=sql_database, synthesize_response=True, service_context=service_context
+)
+
+# Inspect the schema to list the available tables
+inspector = inspect(engine)
+table_names = inspector.get_table_names()
+
+
+# Load a table's contents into a DataFrame (defined here but not yet wired into the UI)
+def get_table_data(table_name):
+    conn = sqlite3.connect('gov-contracts.db')
+    query = f"SELECT * FROM {table_name}"  # table_name comes from the schema inspector, not raw user input
+    df = pd.read_sql_query(query, conn)
+    conn.close()
+    return df
+
+
+# Chat-based interaction for Gradio
+def generate_response(user_input, selected_table=None, example_prompt=None):
+    # selected_table is accepted from the UI but unused; the query engine infers tables itself
+    if example_prompt:
+        user_input = example_prompt
+
+    response = query_engine.query(f"User Question: {user_input}")
+    sql_query = f"```sql\n{response.metadata['sql_query']}\n```\n**Response:**\n{response.response}\n"
+    try:
+        pyperclip.copy(sql_query)  # Optional: copy to clipboard; unavailable on headless servers
+    except pyperclip.PyperclipException:
+        pass
+    return sql_query
+
+
+# Define the Gradio app layout and components
+with gr.Blocks() as gradio_app:
+    gr.Markdown("## Natural Language to SQL Query Application")
+    gr.Markdown("### Ask a question about the data in the database to receive a precise SQL query.")
+
+    # Sidebar: database schema and example prompts
+    with gr.Row():
+        with gr.Column():
+            table_dropdown = gr.Dropdown(choices=table_names, label="Select a Table")
+            example_prompt_box = gr.Radio(
+                choices=[
+                    "Return the department_ind_agency and the sum of award in descending order",
+                    "Return the sum of award in descending order grouped by type limited to the top 10",
+                    "Return the sum of award by year where the sub_tier is the FEDERAL ACQUISITION SERVICE"
+                ],
+                label="Select an Example Prompt"
+            )
+            query_btn = gr.Button("Generate Query")
+
+        with gr.Column():
+            user_query = gr.Textbox(
+                label="Enter your natural language query about the database",
+                placeholder="Ask your question here..."
+            )
+            chat_output = gr.Textbox(
+                label="Generated SQL Query",
+                placeholder="SQL query will appear here..."
+            )
+
+    # Function to call on click
+    def query_callback(user_input, table_name, example_prompt):
+        return generate_response(user_input, selected_table=table_name, example_prompt=example_prompt)
+
+    # Button click event
+    query_btn.click(query_callback, inputs=[user_query, table_dropdown, example_prompt_box], outputs=chat_output)
+
+    gr.Markdown("#### Created by Leonce Nshuti")
+    gr.Markdown("""
+    - [LinkedIn](https://www.linkedin.com/in/leoncenshuti/)
+    - [GitHub](https://github.com/LNshuti)
+    """)
+
+if __name__ == "__main__":
+    gradio_app.launch()
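
One configuration note on the new version: `YOUR_OPENAI_API_KEY` is a placeholder, and assigning a real key into `os.environ` inside `app.py` would expose it in the repository. A minimal sketch of the usual alternative, assuming the key is supplied as an environment variable (e.g., a Space secret named `OPENAI_API_KEY`, which both the openai client and llama_index fall back to):

```python
import os

# Read the key from the environment (e.g., a Hugging Face Space secret)
# instead of hardcoding it in app.py; fail fast with a clear message if absent.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError("OPENAI_API_KEY is not set; add it as a secret before launching.")
```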