Lalit Mahale commited on
Commit
db92763
·
unverified ·
0 Parent(s):

Add files via upload

Browse files
Files changed (5) hide show
  1. app.py +31 -0
  2. config.py +8 -0
  3. prompt.py +14 -0
  4. requirements.txt +5 -0
  5. utils.py +109 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils import Chain
3
+ from utils import DB
4
+ st.set_page_config(page_title="💬 Chat_to_DB")
5
+
6
+ # st.title(":red[Chat] to :red[Database]")
7
+ st.markdown("<h1 style='text-align: center;'>Chat to Database</h1>", unsafe_allow_html=True)
8
+
9
+ st.sidebar.subheader("See table")
10
+ row = st.sidebar.number_input("Enter Number of rows", min_value=5,step=1)
11
+ st.sidebar.write(DB().see_table(rows = row))
12
+
13
+ if "messages" not in st.session_state.keys():
14
+ st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
15
+
16
+ for message in st.session_state.messages:
17
+ with st.chat_message(message["role"]):
18
+ st.write(message["content"])
19
+
20
+ if prompt := st.chat_input():
21
+ st.session_state.messages.append({"role": "user", "content": prompt})
22
+ with st.chat_message("user"):
23
+ st.write(prompt)
24
+
25
+ if st.session_state.messages[-1]["role"] != "assistant":
26
+ with st.chat_message("assistant"):
27
+ with st.spinner("Thinking..."):
28
+ response = Chain().final_sql(prompt)
29
+ st.write(response)
30
+ message = {"role": "assistant", "content": response}
31
+ st.session_state.messages.append(message)
config.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ db_configuration = {"USER" : '',
2
+ "PASSWORD" : "",
3
+ "PORT" : "",
4
+ "DB" : "",
5
+ "HOST" :""
6
+ }
7
+
8
+ API_KEY = ''
prompt.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ def get_response():
4
+ return """
5
+ You are a nice chatbot who have nice converstion with human.
6
+ You have to understand user question and database response and give the proper, easy to understand.\n\n
7
+ user_query : {question}\n\n
8
+ database_response : {db_res}
9
+
10
+ Last converstion :
11
+ {last_conversion}
12
+
13
+ Response:
14
+ """
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ langchain-google-genai
2
+ streamlit
3
+ python-dotenv
4
+ pandas
5
+ sqlalchemy
utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_google_genai import GoogleGenerativeAI
3
+ from langchain_community.utilities import SQLDatabase
4
+ from dotenv import load_dotenv
5
+ from config import db_configuration, API_KEY
6
+ from prompt import get_response
7
+ from langchain_experimental.sql.base import SQLDatabaseSequentialChain
8
+ from langchain.chains import create_sql_query_chain
9
+ from langchain_core.prompts import PromptTemplate
10
+ from langchain.memory import ConversationBufferMemory
11
+ from langchain.chains import LLMChain
12
+ import pandas as pd
13
+ from sqlalchemy import create_engine
14
+
15
+ load_dotenv()
16
+
17
+ class DB:
18
+ def __init__(self):
19
+ self.host = db_configuration["HOST"]
20
+ self.password = db_configuration["PASSWORD"]
21
+ self.database = db_configuration["DB"]
22
+ self.port = db_configuration["PORT"]
23
+ self.user = db_configuration["USER"]
24
+
25
+ def db_conn(self):
26
+ url = f"""mysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}?"""
27
+ return SQLDatabase.from_uri(url)
28
+
29
+ def see_table(self,rows):
30
+ url = f"""mysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}?"""
31
+ conn = create_engine(url)
32
+ df = pd.read_sql_query(f"select * from cars_details limit {rows};",con=conn,index_col="id")
33
+ return df
34
+
35
+
36
+ class LLM_conn:
37
+ def __init__(self) -> None:
38
+ self.temparature = 0
39
+ self.model = "gemini-pro"
40
+
41
+ def llm(self):
42
+ return GoogleGenerativeAI(google_api_key=API_KEY, model=self.model,temperature=self.temparature)
43
+
44
+
45
+ class Chain:
46
+ def __init__(self) -> None:
47
+ self.description = DB().db_conn().run("DESC cars_details;")
48
+ self.db = DB().db_conn()
49
+ self.llm = LLM_conn().llm()
50
+ self.memory = ConversationBufferMemory(memory_key="chat_history")
51
+
52
+ def clean_sql_query(self,query):
53
+ return query.replace("sql","").replace("```","").replace("\n"," ").strip()
54
+
55
+ def sql_chain(self,query):
56
+ chain = create_sql_query_chain(self.llm, self.db)
57
+ res = chain.invoke({"question":query,"table_info":self.description})
58
+ res = self.clean_sql_query(res)
59
+ return res
60
+
61
+ def final_sql(self,query):
62
+ sql_q = self.sql_chain(query=query)
63
+ f_sql = self.db.run(sql_q)
64
+ llm_res = self.llm.invoke(get_response().format(question = query, db_res = f_sql))
65
+ return llm_res
66
+
67
+ def memory_base_chain(self, question):
68
+ # Assuming self.sql_chain and self.db.run are defined and work correctly
69
+ sql_q = self.sql_chain(query=question)
70
+ f_sql = self.db.run(sql_q)
71
+
72
+ template = f"""
73
+ You are a nice chatbot who has nice conversation with humans.
74
+ You have to understand user question and database response and give the proper, easy to understand.\n\n
75
+ user_query : {question}
76
+ database_response : {f_sql}
77
+ Last conversation :
78
+ {{chat_history}}
79
+
80
+ Response:
81
+ """
82
+
83
+ # You may need to fetch chat_history from self.memory or another source
84
+
85
+ # Format the prompt template with actual values
86
+ # prompt = template.format(Question=question, db_res=f_sql, chat_history="") # Provide chat_history if available
87
+
88
+ formatted_prompt = PromptTemplate.format_prompt(template)
89
+
90
+ conversation = LLMChain(llm=self.llm, prompt=formatted_prompt, memory=self.memory)
91
+
92
+ res = conversation({"Question": question, "db_res": f_sql})
93
+
94
+ print(res)
95
+ return res
96
+
97
+ if __name__ =="__main__":
98
+ # db = DB()
99
+ # db_conn = db.db_conn()
100
+ # print(db_conn.run("desc cars_details;"))
101
+ # llm_conn = LLM_conn()
102
+ # llm = llm_conn.llm()
103
+ # print(llm.invoke("hi"))
104
+ # res = chain.sql_chain(query="give me name and price of most selling 3 cars")
105
+ # print(res)
106
+ # print("\n\n\n")
107
+ query = input("Enter :")
108
+ final = Chain().memory_base_chain(question= query)
109
+ print(final)