sango07 commited on
Commit
00d3899
·
verified ·
1 Parent(s): d30e36b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -35
app.py CHANGED
@@ -16,38 +16,35 @@ from langchain_core.output_parsers import StrOutputParser
16
  from langchain_groq import ChatGroq
17
  from langchain_community.callbacks import get_openai_callback
18
  import os
19
-
20
- os.environ["GROQ_API_KEY"] = "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"
21
- llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
22
 
23
 
24
  def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None) -> SQLDatabase:
 
25
  try:
26
  db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
27
  if sslmode:
28
  db_uri += f"?sslmode={sslmode}"
29
 
30
- # Attempt to create a database connection
31
  db = SQLDatabase.from_uri(db_uri)
32
  return db
33
-
34
  except Exception as e:
35
  st.error("Unable to connect to the database. Please check your credentials and try again.")
36
  st.stop() # Stop further execution if an error occurs
37
 
38
-
39
- def answer_sql(question: str, db: SQLDatabase, chat_history: list):
40
  try:
41
- # setup llm
42
- llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
43
-
44
- prompt = PromptTemplate(input_variables=['input', 'table_info', 'top_k'],
45
- template="""You are a PostgreSQL expert. Given an input question,
46
  first create a syntactically correct PostgreSQL query to run,
47
  then look at the results of the query and return the answer to the input question.
48
  Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
49
  Wrap each column name in double quotes (") to denote them as delimited identifiers.
50
- Only use the following tables:\n{table_info}\n\nQuestion: {input}')""")
 
51
 
52
  QUERY = f"""
53
  Given an input question, look at the results of the query and return the answer in natural language to the user's question with all the records of SQLResult.
@@ -55,7 +52,13 @@ def answer_sql(question: str, db: SQLDatabase, chat_history: list):
55
  """
56
 
57
  db_chain = SQLDatabaseChain(
58
- llm=llm, database=db, top_k=100, verbose=True, use_query_checker=True, prompt=prompt, return_intermediate_steps=True
 
 
 
 
 
 
59
  )
60
 
61
  with get_openai_callback() as cb:
@@ -72,12 +75,10 @@ def answer_sql(question: str, db: SQLDatabase, chat_history: list):
72
  print("*" * 55)
73
 
74
  return response
75
-
76
  except Exception as e:
77
  st.error("A technical error occurred. Please try again later.")
78
  st.stop()
79
 
80
-
81
  if "chat_history" not in st.session_state:
82
  st.session_state.chat_history = [
83
  AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database."),
@@ -87,11 +88,15 @@ st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:"
87
  st.title("Chat with Postgres DB")
88
  st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)
89
 
90
- # Step 1: Prompt user to select database type (local or cloud)
91
  with st.sidebar:
92
- st.subheader("Database Setup")
 
 
 
 
 
93
  db_type = st.radio("Is your PostgreSQL database on a local server or in the cloud?", ("Local", "Cloud"))
94
-
95
  if db_type == "Local":
96
  st.write("Enter your local database credentials.")
97
  host = st.text_input("Host", value="localhost")
@@ -99,29 +104,28 @@ with st.sidebar:
99
  user = st.text_input("User", value="postgres")
100
  password = st.text_input("Password", type="password")
101
  database = st.text_input("Database", value="testing_3")
102
-
103
- # Connect Button
104
- if st.button("Connect"):
105
- with st.spinner("Connecting to the local database..."):
106
- db = init_database(user, password, host, port, database)
107
- st.session_state.db = db
108
- st.success("Connected to local database!")
109
-
110
  elif db_type == "Cloud":
111
  st.write("Enter your cloud database credentials.")
112
  host = st.text_input("Host (e.g., your-db-host.aws.com)")
113
- port = st.text_input("Port (default: 5432)", value="5432")
114
  user = st.text_input("User")
115
  password = st.text_input("Password", type="password")
116
  database = st.text_input("Database")
117
  sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])
118
 
119
- # Connect Button
120
- if st.button("Connect"):
121
- with st.spinner("Connecting to the cloud database..."):
122
- db = init_database(user, password, host, port, database, sslmode)
 
 
123
  st.session_state.db = db
124
- st.success("Connected to cloud database!")
 
 
 
 
 
125
 
126
  # Main chat interface
127
  for message in st.session_state.chat_history:
@@ -139,7 +143,7 @@ if user_query:
139
  st.markdown(user_query)
140
 
141
  with st.chat_message("AI"):
142
- response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
143
  st.markdown(response)
144
 
145
- st.session_state.chat_history.append(AIMessage(content=response))
 
16
  from langchain_groq import ChatGroq
17
  from langchain_community.callbacks import get_openai_callback
18
  import os
19
+ from langchain_openai import ChatOpenAI
20
+ # llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo")
 
21
 
22
 
23
  def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None) -> SQLDatabase:
24
+ """Initialize a connection to the PostgreSQL database."""
25
  try:
26
  db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
27
  if sslmode:
28
  db_uri += f"?sslmode={sslmode}"
29
 
 
30
  db = SQLDatabase.from_uri(db_uri)
31
  return db
 
32
  except Exception as e:
33
  st.error("Unable to connect to the database. Please check your credentials and try again.")
34
  st.stop() # Stop further execution if an error occurs
35
 
36
+ def answer_sql(question: str, db: SQLDatabase, chat_history: list, llm) -> str:
37
+ """Generate SQL answer based on the user's question and database content."""
38
  try:
39
+ prompt = PromptTemplate(
40
+ input_variables=['input', 'table_info', 'top_k'],
41
+ template="""You are a PostgreSQL expert. Given an input question,
 
 
42
  first create a syntactically correct PostgreSQL query to run,
43
  then look at the results of the query and return the answer to the input question.
44
  Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
45
  Wrap each column name in double quotes (") to denote them as delimited identifiers.
46
+ Only use the following tables:\n{table_info}\n\nQuestion: {input}')"""
47
+ )
48
 
49
  QUERY = f"""
50
  Given an input question, look at the results of the query and return the answer in natural language to the user's question with all the records of SQLResult.
 
52
  """
53
 
54
  db_chain = SQLDatabaseChain(
55
+ llm=llm,
56
+ database=db,
57
+ top_k=100,
58
+ verbose=True,
59
+ use_query_checker=True,
60
+ prompt=prompt,
61
+ return_intermediate_steps=True
62
  )
63
 
64
  with get_openai_callback() as cb:
 
75
  print("*" * 55)
76
 
77
  return response
 
78
  except Exception as e:
79
  st.error("A technical error occurred. Please try again later.")
80
  st.stop()
81
 
 
82
  if "chat_history" not in st.session_state:
83
  st.session_state.chat_history = [
84
  AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database."),
 
88
  st.title("Chat with Postgres DB")
89
  st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)
90
 
91
+ # Get API key from user
92
  with st.sidebar:
93
+ st.subheader("API Key and Database Credentials")
94
+
95
+ # Take OpenAI API key from the user
96
+ openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password")
97
+
98
+ # Database connection fields
99
  db_type = st.radio("Is your PostgreSQL database on a local server or in the cloud?", ("Local", "Cloud"))
 
100
  if db_type == "Local":
101
  st.write("Enter your local database credentials.")
102
  host = st.text_input("Host", value="localhost")
 
104
  user = st.text_input("User", value="postgres")
105
  password = st.text_input("Password", type="password")
106
  database = st.text_input("Database", value="testing_3")
 
 
 
 
 
 
 
 
107
  elif db_type == "Cloud":
108
  st.write("Enter your cloud database credentials.")
109
  host = st.text_input("Host (e.g., your-db-host.aws.com)")
110
+ port = st.text_input("Port", value="5432")
111
  user = st.text_input("User")
112
  password = st.text_input("Password", type="password")
113
  database = st.text_input("Database")
114
  sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])
115
 
116
+ if st.button("Connect"):
117
+ if openai_api_key:
118
+ os.environ["OPENAI_API_KEY"] = openai_api_key # Set the OpenAI API key in the environment
119
+ llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo") # Initialize model with user's API key
120
+ try:
121
+ db = init_database(user, password, host, port, database, sslmode if db_type == "Cloud" else None)
122
  st.session_state.db = db
123
+ st.session_state.llm = llm
124
+ st.success("Connected to the database!")
125
+ except Exception as e:
126
+ st.error("Failed to connect to the database. Please check your details and try again.")
127
+ else:
128
+ st.error("Please enter your OpenAI API key.")
129
 
130
  # Main chat interface
131
  for message in st.session_state.chat_history:
 
143
  st.markdown(user_query)
144
 
145
  with st.chat_message("AI"):
146
+ response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history, st.session_state.llm)
147
  st.markdown(response)
148
 
149
+ st.session_state.chat_history.append(AIMessage(content=response))