Ari committed on
Commit
5ad9e6e
·
verified ·
1 Parent(s): e14b81b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -74
app.py CHANGED
@@ -3,6 +3,8 @@ import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  from langchain import OpenAI, LLMChain, PromptTemplate
 
 
6
  import sqlparse
7
  import logging
8
 
@@ -11,7 +13,6 @@ if 'history' not in st.session_state:
11
  st.session_state.history = []
12
 
13
  # OpenAI API key (ensure it is securely stored)
14
- # You can set the API key in your environment variables or a .env file
15
  openai_api_key = os.getenv("OPENAI_API_KEY")
16
 
17
  # Check if the API key is set
@@ -19,6 +20,18 @@ if not openai_api_key:
19
  st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
20
  st.stop()
21
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Step 1: Upload CSV data file (or use default)
23
  st.title("Natural Language to SQL Query App with Enhanced Insights")
24
  st.write("Upload a CSV file to get started, or use the default dataset.")
@@ -43,8 +56,8 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
43
  valid_columns = list(data.columns)
44
  st.write(f"Valid columns: {valid_columns}")
45
 
46
- # Step 3: Set up the LLM Chains
47
- # SQL Generation Chain
48
  sql_template = """
49
  You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
50
 
@@ -66,34 +79,22 @@ Valid columns: {columns}
66
  SQL Query:
67
  """
68
  sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
69
- llm = OpenAI(temperature=0, openai_api_key=openai_api_key, max_tokens = 180)
70
- sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
71
-
72
- # Insights Generation Chain
73
- insights_template = """
74
- You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
75
-
76
- User's Question: {question}
77
 
78
- SQL Query Result:
79
- {result}
 
 
80
 
81
- Concise Analysis (max 200 words):
82
- """
83
- insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
84
- insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
85
-
86
- # General Insights and Recommendations Chain
87
- general_insights_template = """
88
- You are an expert data scientist. Based on the entire dataset provided below, generate a concise analysis with key insights and recommendations. Limit the response to 150 words.
89
 
90
- Dataset Summary:
91
- {dataset_summary}
92
 
93
- Concise Analysis and Recommendations (max 150 words):
94
- """
95
- general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=['dataset_summary'])
96
- general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)
97
 
98
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
99
  def clean_sql_query(query):
@@ -130,7 +131,7 @@ def classify_query(question):
130
  Category (SQL/INSIGHTS):
131
  """
132
  classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
133
- classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
134
  category = classification_chain.run({'question': question}).strip().upper()
135
  if category.startswith('SQL'):
136
  return 'SQL'
@@ -140,17 +141,7 @@ def classify_query(question):
140
  # Function to generate dataset summary
141
  def generate_dataset_summary(data):
142
  """Generate a summary of the dataset for general insights."""
143
- summary_template = """
144
- You are an expert data scientist. Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
145
-
146
- Dataset:
147
- {data}
148
-
149
- Dataset Summary:
150
- """
151
- summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
152
- summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
153
- summary = summary_chain.run({'data': data.head().to_string(index=False)})
154
  return summary
155
 
156
  # Define the callback function
@@ -178,21 +169,9 @@ def process_input():
178
  }).strip()
179
 
180
  if generated_sql.upper() == "NO_SQL":
181
- # Handle cases where no SQL should be generated
182
- assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
183
-
184
- # Generate dataset summary
185
- dataset_summary = generate_dataset_summary(data)
186
-
187
- # Generate general insights and recommendations
188
- general_insights = general_insights_chain.run({
189
- 'dataset_summary': dataset_summary
190
- })
191
-
192
- # Append the assistant's insights to the history
193
- st.session_state.history.append({"role": "assistant", "content": general_insights})
194
  else:
195
- # Clean the SQL query
196
  cleaned_sql = clean_sql_query(generated_sql)
197
  logging.info(f"Generated SQL Query: {cleaned_sql}")
198
 
@@ -204,35 +183,18 @@ def process_input():
204
  assistant_response = "The query returned no results. Please try a different question."
205
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
206
  else:
207
- # Convert the result to a string for the insights prompt
208
- result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
209
-
210
- # Generate insights and recommendations based on the query result
211
- insights = insights_chain.run({
212
- 'question': user_prompt,
213
- 'result': result_str
214
- })
215
-
216
- # Append the assistant's insights to the history
217
- st.session_state.history.append({"role": "assistant", "content": insights})
218
- # Append the result DataFrame to the history
219
  st.session_state.history.append({"role": "assistant", "content": result})
 
220
  except Exception as e:
221
  logging.error(f"An error occurred during SQL execution: {e}")
222
  assistant_response = f"Error executing SQL query: {e}"
223
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
224
  else: # INSIGHTS category
225
- # Generate dataset summary
226
  dataset_summary = generate_dataset_summary(data)
 
 
227
 
228
- # Generate general insights and recommendations
229
- general_insights = general_insights_chain.run({
230
- 'dataset_summary': dataset_summary
231
- })
232
-
233
- # Append the assistant's insights to the history
234
- st.session_state.history.append({"role": "assistant", "content": general_insights})
235
-
236
  except Exception as e:
237
  logging.error(f"An error occurred: {e}")
238
  assistant_response = f"Error: {e}"
 
3
  import pandas as pd
4
  import sqlite3
5
  from langchain import OpenAI, LLMChain, PromptTemplate
6
+ from transformers import LlamaForCausalLM, LlamaTokenizer
7
+ import torch
8
  import sqlparse
9
  import logging
10
 
 
13
  st.session_state.history = []
14
 
15
  # OpenAI API key (ensure it is securely stored)
 
16
  openai_api_key = os.getenv("OPENAI_API_KEY")
17
 
18
  # Check if the API key is set
 
20
  st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
21
  st.stop()
22
 
23
+ # Load the LLaMA model and tokenizer
24
+ model_name = "huggingface/llama" # Replace with the actual LLaMA model name you want to use
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ llama_tokenizer = LlamaTokenizer.from_pretrained(model_name)
27
+ llama_model = LlamaForCausalLM.from_pretrained(model_name).to(device)
28
+
29
+ # Function to generate responses using LLaMA
30
+ def generate_llama_response(prompt):
31
+ inputs = llama_tokenizer(prompt, return_tensors="pt").to(device)
32
+ outputs = llama_model.generate(inputs.input_ids, max_length=200)
33
+ return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
34
+
35
  # Step 1: Upload CSV data file (or use default)
36
  st.title("Natural Language to SQL Query App with Enhanced Insights")
37
  st.write("Upload a CSV file to get started, or use the default dataset.")
 
56
  valid_columns = list(data.columns)
57
  st.write(f"Valid columns: {valid_columns}")
58
 
59
+ # Step 3: Set up the LLM Chains (SQL generation with OpenAI, insights with LLaMA)
60
+ # SQL Generation Chain with OpenAI
61
  sql_template = """
62
  You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
63
 
 
79
  SQL Query:
80
  """
81
  sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
82
+ sql_llm = OpenAI(temperature=0, openai_api_key=openai_api_key, max_tokens=180)
83
+ sql_generation_chain = LLMChain(llm=sql_llm, prompt=sql_prompt)
 
 
 
 
 
 
84
 
85
+ # General Insights and Recommendations Chain with LLaMA
86
+ def generate_insights_llama(question, data_summary):
87
+ insights_template = f"""
88
+ You are an expert data scientist. Based on the user's question and the dataset summary provided below, generate concise data insights and actionable recommendations.
89
 
90
+ User's Question: {question}
 
 
 
 
 
 
 
91
 
92
+ Dataset Summary:
93
+ {data_summary}
94
 
95
+ Concise Insights and Recommendations:
96
+ """
97
+ return generate_llama_response(insights_template)
 
98
 
99
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
100
  def clean_sql_query(query):
 
131
  Category (SQL/INSIGHTS):
132
  """
133
  classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
134
+ classification_chain = LLMChain(llm=sql_llm, prompt=classification_prompt)
135
  category = classification_chain.run({'question': question}).strip().upper()
136
  if category.startswith('SQL'):
137
  return 'SQL'
 
141
  # Function to generate dataset summary
142
  def generate_dataset_summary(data):
143
  """Generate a summary of the dataset for general insights."""
144
+ summary = f"Number of records: {len(data)}, Number of columns: {len(data.columns)}, Columns: {list(data.columns)}"
 
 
 
 
 
 
 
 
 
 
145
  return summary
146
 
147
  # Define the callback function
 
169
  }).strip()
170
 
171
  if generated_sql.upper() == "NO_SQL":
172
+ assistant_response = "No SQL query could be generated."
173
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
 
 
 
 
 
 
 
 
 
 
 
174
  else:
 
175
  cleaned_sql = clean_sql_query(generated_sql)
176
  logging.info(f"Generated SQL Query: {cleaned_sql}")
177
 
 
183
  assistant_response = "The query returned no results. Please try a different question."
184
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
185
  else:
186
+ # Display query results
 
 
 
 
 
 
 
 
 
 
 
187
  st.session_state.history.append({"role": "assistant", "content": result})
188
+
189
  except Exception as e:
190
  logging.error(f"An error occurred during SQL execution: {e}")
191
  assistant_response = f"Error executing SQL query: {e}"
192
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
193
  else: # INSIGHTS category
 
194
  dataset_summary = generate_dataset_summary(data)
195
+ insights = generate_insights_llama(user_prompt, dataset_summary)
196
+ st.session_state.history.append({"role": "assistant", "content": insights})
197
 
 
 
 
 
 
 
 
 
198
  except Exception as e:
199
  logging.error(f"An error occurred: {e}")
200
  assistant_response = f"Error: {e}"