Spaces:
Runtime error
Runtime error
Ari
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,8 @@ import streamlit as st
|
|
3 |
import pandas as pd
|
4 |
import sqlite3
|
5 |
from langchain import OpenAI, LLMChain, PromptTemplate
|
|
|
|
|
6 |
import sqlparse
|
7 |
import logging
|
8 |
|
@@ -11,7 +13,6 @@ if 'history' not in st.session_state:
|
|
11 |
st.session_state.history = []
|
12 |
|
13 |
# OpenAI API key (ensure it is securely stored)
|
14 |
-
# You can set the API key in your environment variables or a .env file
|
15 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
16 |
|
17 |
# Check if the API key is set
|
@@ -19,6 +20,18 @@ if not openai_api_key:
|
|
19 |
st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
|
20 |
st.stop()
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
# Step 1: Upload CSV data file (or use default)
|
23 |
st.title("Natural Language to SQL Query App with Enhanced Insights")
|
24 |
st.write("Upload a CSV file to get started, or use the default dataset.")
|
@@ -43,8 +56,8 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
|
|
43 |
valid_columns = list(data.columns)
|
44 |
st.write(f"Valid columns: {valid_columns}")
|
45 |
|
46 |
-
# Step 3: Set up the LLM Chains
|
47 |
-
# SQL Generation Chain
|
48 |
sql_template = """
|
49 |
You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
|
50 |
|
@@ -66,34 +79,22 @@ Valid columns: {columns}
|
|
66 |
SQL Query:
|
67 |
"""
|
68 |
sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
|
69 |
-
|
70 |
-
sql_generation_chain = LLMChain(llm=
|
71 |
-
|
72 |
-
# Insights Generation Chain
|
73 |
-
insights_template = """
|
74 |
-
You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
|
75 |
-
|
76 |
-
User's Question: {question}
|
77 |
|
78 |
-
|
79 |
-
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
"""
|
83 |
-
insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
|
84 |
-
insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
|
85 |
-
|
86 |
-
# General Insights and Recommendations Chain
|
87 |
-
general_insights_template = """
|
88 |
-
You are an expert data scientist. Based on the entire dataset provided below, generate a concise analysis with key insights and recommendations. Limit the response to 150 words.
|
89 |
|
90 |
-
Dataset Summary:
|
91 |
-
{
|
92 |
|
93 |
-
Concise
|
94 |
-
"""
|
95 |
-
|
96 |
-
general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)
|
97 |
|
98 |
# Optional: Clean up function to remove incorrect COLLATE NOCASE usage
|
99 |
def clean_sql_query(query):
|
@@ -130,7 +131,7 @@ def classify_query(question):
|
|
130 |
Category (SQL/INSIGHTS):
|
131 |
"""
|
132 |
classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
|
133 |
-
classification_chain = LLMChain(llm=
|
134 |
category = classification_chain.run({'question': question}).strip().upper()
|
135 |
if category.startswith('SQL'):
|
136 |
return 'SQL'
|
@@ -140,17 +141,7 @@ def classify_query(question):
|
|
140 |
# Function to generate dataset summary
|
141 |
def generate_dataset_summary(data):
|
142 |
"""Generate a summary of the dataset for general insights."""
|
143 |
-
|
144 |
-
You are an expert data scientist. Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
|
145 |
-
|
146 |
-
Dataset:
|
147 |
-
{data}
|
148 |
-
|
149 |
-
Dataset Summary:
|
150 |
-
"""
|
151 |
-
summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
|
152 |
-
summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
|
153 |
-
summary = summary_chain.run({'data': data.head().to_string(index=False)})
|
154 |
return summary
|
155 |
|
156 |
# Define the callback function
|
@@ -178,21 +169,9 @@ def process_input():
|
|
178 |
}).strip()
|
179 |
|
180 |
if generated_sql.upper() == "NO_SQL":
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
# Generate dataset summary
|
185 |
-
dataset_summary = generate_dataset_summary(data)
|
186 |
-
|
187 |
-
# Generate general insights and recommendations
|
188 |
-
general_insights = general_insights_chain.run({
|
189 |
-
'dataset_summary': dataset_summary
|
190 |
-
})
|
191 |
-
|
192 |
-
# Append the assistant's insights to the history
|
193 |
-
st.session_state.history.append({"role": "assistant", "content": general_insights})
|
194 |
else:
|
195 |
-
# Clean the SQL query
|
196 |
cleaned_sql = clean_sql_query(generated_sql)
|
197 |
logging.info(f"Generated SQL Query: {cleaned_sql}")
|
198 |
|
@@ -204,35 +183,18 @@ def process_input():
|
|
204 |
assistant_response = "The query returned no results. Please try a different question."
|
205 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
206 |
else:
|
207 |
-
#
|
208 |
-
result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
|
209 |
-
|
210 |
-
# Generate insights and recommendations based on the query result
|
211 |
-
insights = insights_chain.run({
|
212 |
-
'question': user_prompt,
|
213 |
-
'result': result_str
|
214 |
-
})
|
215 |
-
|
216 |
-
# Append the assistant's insights to the history
|
217 |
-
st.session_state.history.append({"role": "assistant", "content": insights})
|
218 |
-
# Append the result DataFrame to the history
|
219 |
st.session_state.history.append({"role": "assistant", "content": result})
|
|
|
220 |
except Exception as e:
|
221 |
logging.error(f"An error occurred during SQL execution: {e}")
|
222 |
assistant_response = f"Error executing SQL query: {e}"
|
223 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
224 |
else: # INSIGHTS category
|
225 |
-
# Generate dataset summary
|
226 |
dataset_summary = generate_dataset_summary(data)
|
|
|
|
|
227 |
|
228 |
-
# Generate general insights and recommendations
|
229 |
-
general_insights = general_insights_chain.run({
|
230 |
-
'dataset_summary': dataset_summary
|
231 |
-
})
|
232 |
-
|
233 |
-
# Append the assistant's insights to the history
|
234 |
-
st.session_state.history.append({"role": "assistant", "content": general_insights})
|
235 |
-
|
236 |
except Exception as e:
|
237 |
logging.error(f"An error occurred: {e}")
|
238 |
assistant_response = f"Error: {e}"
|
|
|
3 |
import pandas as pd
|
4 |
import sqlite3
|
5 |
from langchain import OpenAI, LLMChain, PromptTemplate
|
6 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
7 |
+
import torch
|
8 |
import sqlparse
|
9 |
import logging
|
10 |
|
|
|
13 |
st.session_state.history = []
|
14 |
|
15 |
# OpenAI API key (ensure it is securely stored)
|
|
|
16 |
openai_api_key = os.getenv("OPENAI_API_KEY")
|
17 |
|
18 |
# Check if the API key is set
|
|
|
20 |
st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
|
21 |
st.stop()
|
22 |
|
23 |
+
# Load the LLaMA model and tokenizer
|
24 |
+
model_name = "huggingface/llama" # Replace with the actual LLaMA model name you want to use
|
25 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
26 |
+
llama_tokenizer = LlamaTokenizer.from_pretrained(model_name)
|
27 |
+
llama_model = LlamaForCausalLM.from_pretrained(model_name).to(device)
|
28 |
+
|
29 |
+
# Function to generate responses using LLaMA
|
30 |
+
def generate_llama_response(prompt):
|
31 |
+
inputs = llama_tokenizer(prompt, return_tensors="pt").to(device)
|
32 |
+
outputs = llama_model.generate(inputs.input_ids, max_length=200)
|
33 |
+
return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
34 |
+
|
35 |
# Step 1: Upload CSV data file (or use default)
|
36 |
st.title("Natural Language to SQL Query App with Enhanced Insights")
|
37 |
st.write("Upload a CSV file to get started, or use the default dataset.")
|
|
|
56 |
valid_columns = list(data.columns)
|
57 |
st.write(f"Valid columns: {valid_columns}")
|
58 |
|
59 |
+
# Step 3: Set up the LLM Chains (SQL generation with OpenAI, insights with LLaMA)
|
60 |
+
# SQL Generation Chain with OpenAI
|
61 |
sql_template = """
|
62 |
You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
|
63 |
|
|
|
79 |
SQL Query:
|
80 |
"""
|
81 |
sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
|
82 |
+
sql_llm = OpenAI(temperature=0, openai_api_key=openai_api_key, max_tokens=180)
|
83 |
+
sql_generation_chain = LLMChain(llm=sql_llm, prompt=sql_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
# General Insights and Recommendations Chain with LLaMA
|
86 |
+
def generate_insights_llama(question, data_summary):
|
87 |
+
insights_template = f"""
|
88 |
+
You are an expert data scientist. Based on the user's question and the dataset summary provided below, generate concise data insights and actionable recommendations.
|
89 |
|
90 |
+
User's Question: {question}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
+
Dataset Summary:
|
93 |
+
{data_summary}
|
94 |
|
95 |
+
Concise Insights and Recommendations:
|
96 |
+
"""
|
97 |
+
return generate_llama_response(insights_template)
|
|
|
98 |
|
99 |
# Optional: Clean up function to remove incorrect COLLATE NOCASE usage
|
100 |
def clean_sql_query(query):
|
|
|
131 |
Category (SQL/INSIGHTS):
|
132 |
"""
|
133 |
classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
|
134 |
+
classification_chain = LLMChain(llm=sql_llm, prompt=classification_prompt)
|
135 |
category = classification_chain.run({'question': question}).strip().upper()
|
136 |
if category.startswith('SQL'):
|
137 |
return 'SQL'
|
|
|
141 |
# Function to generate dataset summary
|
142 |
def generate_dataset_summary(data):
|
143 |
"""Generate a summary of the dataset for general insights."""
|
144 |
+
summary = f"Number of records: {len(data)}, Number of columns: {len(data.columns)}, Columns: {list(data.columns)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
return summary
|
146 |
|
147 |
# Define the callback function
|
|
|
169 |
}).strip()
|
170 |
|
171 |
if generated_sql.upper() == "NO_SQL":
|
172 |
+
assistant_response = "No SQL query could be generated."
|
173 |
+
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
else:
|
|
|
175 |
cleaned_sql = clean_sql_query(generated_sql)
|
176 |
logging.info(f"Generated SQL Query: {cleaned_sql}")
|
177 |
|
|
|
183 |
assistant_response = "The query returned no results. Please try a different question."
|
184 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
185 |
else:
|
186 |
+
# Display query results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
st.session_state.history.append({"role": "assistant", "content": result})
|
188 |
+
|
189 |
except Exception as e:
|
190 |
logging.error(f"An error occurred during SQL execution: {e}")
|
191 |
assistant_response = f"Error executing SQL query: {e}"
|
192 |
st.session_state.history.append({"role": "assistant", "content": assistant_response})
|
193 |
else: # INSIGHTS category
|
|
|
194 |
dataset_summary = generate_dataset_summary(data)
|
195 |
+
insights = generate_insights_llama(user_prompt, dataset_summary)
|
196 |
+
st.session_state.history.append({"role": "assistant", "content": insights})
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
except Exception as e:
|
199 |
logging.error(f"An error occurred: {e}")
|
200 |
assistant_response = f"Error: {e}"
|