Ari commited on
Commit
1c810c3
·
verified ·
1 Parent(s): 2e33de0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -25
app.py CHANGED
@@ -9,7 +9,9 @@ import logging
9
  from sklearn.linear_model import LinearRegression
10
  from sklearn.model_selection import train_test_split
11
  from sklearn.metrics import mean_squared_error, r2_score
12
- import statsmodels.api as sm
 
 
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -64,33 +66,11 @@ Instructions:
64
 
65
  - If the question involves data retrieval or simple aggregations, generate a SQL query.
66
  - If the question requires statistical analysis or time series analysis, generate a Python code snippet using pandas, numpy, and statsmodels.
67
- - If the question involves predictions or modeling, generate a Python code snippet using scikit-learn.
68
  - Ensure that you only use the columns provided.
69
  - Do not include any import statements in the code.
70
  - Provide the code between <CODE> and </CODE> tags.
71
 
72
- Examples:
73
-
74
- User Query: Calculate the average sales.
75
- Response:
76
- <CODE>
77
- result = data['Sales'].mean()
78
- </CODE>
79
-
80
- User Query: Total sales for Product A
81
- Response:
82
- <CODE>
83
- SELECT SUM(Sales) AS "Total sales" FROM {table_name} WHERE Product = 'Product A' COLLATE NOCASE
84
- </CODE>
85
-
86
- User Query: Show me the sales trend over time.
87
- Response:
88
- <CODE>
89
- data['Date'] = pd.to_datetime(data['Date'])
90
- sales_by_date = data.groupby('Date')['Sales'].sum().reset_index()
91
- result = sales_by_date
92
- </CODE>
93
-
94
  Question: {question}
95
 
96
  Table name: {table_name}
@@ -100,6 +80,9 @@ Valid columns: {columns}
100
  Response:
101
  """
102
 
 
 
 
103
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
104
 
105
  # Set up the LLM Chain
@@ -140,8 +123,10 @@ def execute_code(code):
140
  'train_test_split': train_test_split,
141
  'mean_squared_error': mean_squared_error,
142
  'r2_score': r2_score,
143
- 'sm': sm
 
144
  }
 
145
  exec(code, {}, local_vars)
146
  result = local_vars.get('result')
147
  return result
@@ -214,3 +199,7 @@ for message in st.session_state.history:
214
  st.markdown(f"**Assistant:** {content}")
215
  else:
216
  st.markdown(f"**Assistant:** {content}")
 
 
 
 
 
9
  from sklearn.linear_model import LinearRegression
10
  from sklearn.model_selection import train_test_split
11
  from sklearn.metrics import mean_squared_error, r2_score
12
+ import statsmodels.api as sm # For time series analysis
13
+ from sklearn.metrics.pairwise import cosine_similarity # For recommendations
14
+
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
 
66
 
67
  - If the question involves data retrieval or simple aggregations, generate a SQL query.
68
  - If the question requires statistical analysis or time series analysis, generate a Python code snippet using pandas, numpy, and statsmodels.
69
+ - If the question involves predictions, modeling, or recommendations, generate a Python code snippet using scikit-learn or pandas.
70
  - Ensure that you only use the columns provided.
71
  - Do not include any import statements in the code.
72
  - Provide the code between <CODE> and </CODE> tags.
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  Question: {question}
75
 
76
  Table name: {table_name}
 
80
  Response:
81
  """
82
 
83
+
84
+
85
+
86
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
87
 
88
  # Set up the LLM Chain
 
123
  'train_test_split': train_test_split,
124
  'mean_squared_error': mean_squared_error,
125
  'r2_score': r2_score,
126
+ 'sm': sm, # Added statsmodels
127
+ 'cosine_similarity': cosine_similarity # Added cosine_similarity
128
  }
129
+
130
  exec(code, {}, local_vars)
131
  result = local_vars.get('result')
132
  return result
 
199
  st.markdown(f"**Assistant:** {content}")
200
  else:
201
  st.markdown(f"**Assistant:** {content}")
202
+
203
+ # Place the text input after displaying the conversation
204
+ st.text_input("Enter your question:", key='user_input', on_change=process_input)
205
+