Update app.py
app.py
CHANGED
@@ -10,17 +10,32 @@ from langchain_core.runnables import RunnablePassthrough
 from langchain_chroma import Chroma
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 
+# App Title
 page = st.title("Chat with AskUSTH")
 
+# Initialize session states
 if "gemini_api" not in st.session_state:
     st.session_state.gemini_api = None
 
 if "rag" not in st.session_state:
     st.session_state.rag = None
 
 if "llm" not in st.session_state:
     st.session_state.llm = None
 
+if "embd" not in st.session_state:
+    st.session_state.embd = None
+
+if "model" not in st.session_state:
+    st.session_state.model = None
+
+if "save_dir" not in st.session_state:
+    st.session_state.save_dir = None
+
+if "uploaded_files" not in st.session_state:
+    st.session_state.uploaded_files = set()
+
+# Caching functions
 @st.cache_resource
 def get_chat_google_model(api_key):
     os.environ["GOOGLE_API_KEY"] = api_key
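The hunk shows only the first line of get_chat_google_model, so the actual model construction sits outside this diff. As a rough sketch only: helpers like this are commonly built on the langchain-google-genai package; the import and the model name below are assumptions, not something this commit shows.

import os

# Assumed dependency and model name; neither appears in this diff.
from langchain_google_genai import ChatGoogleGenerativeAI

def get_chat_google_model(api_key: str) -> ChatGoogleGenerativeAI:
    os.environ["GOOGLE_API_KEY"] = api_key  # matches the visible diff line
    return ChatGoogleGenerativeAI(model="gemini-1.5-flash")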
@@ -42,23 +57,56 @@ def get_embedding_model():
         model_name=model_name,
         model_kwargs=model_kwargs,
         encode_kwargs=encode_kwargs
-)
+    )
     return model
 
-if "embd" not in st.session_state:
-    st.session_state.embd = None
-
-if "model" not in st.session_state:
-    st.session_state.model = None
-
-if "save_dir" not in st.session_state:
-    st.session_state.save_dir = None
-
-if "uploaded_files" not in st.session_state:
-    st.session_state.uploaded_files = set()
-
+# Load and process text files
+def load_txt(file_path):
+    loader = TextLoader(file_path=file_path, encoding="utf-8")
+    doc = loader.load()
+    return doc
+
+def format_docs(docs):
+    return "\n\n".join(doc.page_content for doc in docs)
+
+# Compute RAG Chain
+@st.cache_resource
+def compute_rag_chain(_model, _embd, docs_texts):
+    if not docs_texts:
+        raise ValueError("No documents to process. Please upload valid text files.")
+
+    combined_text = "\n\n".join(docs_texts)
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
+    texts = text_splitter.split_text(combined_text)
+
+    if not texts:
+        raise ValueError("Text splitter did not generate any text chunks. Check your input.")
+
+    vectorstore = Chroma.from_texts(texts=texts, embedding=_embd)
+    retriever = vectorstore.as_retriever()
+
+    template = """
+    Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên.
+    Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi.
+    Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
+    Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:
+    {context}
+    hãy trả lời:
+    {question}
+    """
+    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+
+    rag_chain = (
+        {"context": retriever | format_docs, "question": RunnablePassthrough()}
+        | prompt
+        | _model
+        | StrOutputParser()
+    )
+    return rag_chain
+
+# Dialog to setup Gemini
 @st.dialog("Setup Gemini")
-def vote():
+def setup_gemini():
     st.markdown(
         """
         Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
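A note on the @st.cache_resource decorator that now guards compute_rag_chain: Streamlit hashes a cached function's arguments to decide whether to recompute, and it skips any parameter whose name starts with an underscore. Passing the chat model and embeddings as _model and _embd therefore avoids hashing unhashable objects, and docs_texts alone acts as the cache key. A minimal sketch of the same pattern; the function name and body here are illustrative, not from the commit.

import streamlit as st

@st.cache_resource
def build_index(_model, docs_texts):
    # "_model" is excluded from hashing (leading underscore), so an
    # unhashable client object can be passed through; a change in
    # docs_texts is what triggers recomputation.
    return {"num_texts": len(docs_texts)}

A consequence worth knowing: if a different model is passed with identical docs_texts, the previously cached value is reused.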
@@ -67,115 +115,44 @@ def vote():
     key = st.text_input("Key:", "")
     if st.button("Save") and key != "":
         st.session_state.gemini_api = key
         st.rerun()
 
 if st.session_state.gemini_api is None:
-    vote()
+    setup_gemini()
 
 if st.session_state.gemini_api and st.session_state.model is None:
     st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
 
+if st.session_state.embd is None:
+    st.session_state.embd = get_embedding_model()
+
 if st.session_state.save_dir is None:
     save_dir = "./Documents"
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     st.session_state.save_dir = save_dir
-
-def load_txt(file_path):
-    loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
-    doc = loader_sv.load()
-    return doc
 
+# Sidebar to upload files
 with st.sidebar:
     uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
-    if …
-        …
-            with open(file_path, mode='wb') as w:
-                w.write(uploaded_file.getvalue())
-        else:
-            continue
-
-        new_docs = True
+    if uploaded_files:
+        documents = []
+        uploaded_file_names = set()
+        for uploaded_file in uploaded_files:
+            uploaded_file_names.add(uploaded_file.name)
+            if uploaded_file.name not in st.session_state.uploaded_files:
+                file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
+                with open(file_path, mode='wb') as w:
+                    w.write(uploaded_file.getvalue())
 
                 doc = load_txt(file_path)
                 documents.extend([*doc])
 
-        …
-        st.session_state.uploaded_files = set()
-        st.session_state.rag = None
-
-def format_docs(docs):
-    return "\n\n".join(doc.page_content for doc in docs)
-
-@st.cache_resource
-def compute_rag_chain(_model, _embd, docs_texts):
-    # Combine all texts into one large string
-    combined_text = "\n\n".join(docs_texts)  # Join all document texts into one string
-
-    # Use RecursiveCharacterTextSplitter to split text into chunks
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0)
-    texts = text_splitter.split_text(combined_text)  # Now this will work as 'combined_text' is a string
-
-    # Create vector store for similarity search
-    vectorstore = Chroma.from_texts(texts=texts, embedding=_embd)
-    retriever = vectorstore.as_retriever()
-
-    # Prepare the prompt for context and question
-    template = """
-    Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. \n
-    Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. \n
-    Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.\n
-    Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:\n
-    {context}\n
-    hãy trả lời:\n
-    {question}
-    """
-    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-
-    # Chain for RAG
-    rag_chain = (
-        {"context": retriever | format_docs, "question": RunnablePassthrough()}
-        | prompt
-        | _model
-        | StrOutputParser()
-    )
-    return rag_chain
-
-@st.dialog("Setup RAG")
-def load_rag():
-    docs_texts = [d.page_content for d in documents]
-    st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
-    st.rerun()
-
-if st.session_state.uploaded_files and st.session_state.model is not None:
-    if st.session_state.rag is None:
-        load_rag()
-
-if st.session_state.model is not None:
-    if st.session_state.llm is None:
-        mess = ChatPromptTemplate.from_messages(
-            [
-                (
-                    "system",
-                    "Bản là một trợ lí AI hỗ trợ tuyển sinh và sinh viên",
-                ),
-                ("human", "{input}"),
-            ]
-        )
-        chain = mess | st.session_state.model
-        st.session_state.llm = chain
+
+        if documents:
+            docs_texts = [d.page_content for d in documents]
+            st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
+            st.session_state.uploaded_files = uploaded_file_names
 
+# Chat Interface
 if "chat_history" not in st.session_state:
     st.session_state.chat_history = []
 
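Beyond the renames, one substantive change in the moved compute_rag_chain is the splitter: chunk_overlap goes from 0 to 10, with chunk_size still 100 characters. A self-contained way to see what those two numbers do; the sample text is invented for illustration.

from langchain_text_splitters import RecursiveCharacterTextSplitter

# chunk_size counts characters by default; chunk_overlap repeats the tail
# of one chunk at the head of the next so neighbouring chunks share context.
splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
sample = " ".join(f"Thông tin tuyển sinh số {i}." for i in range(40))
for chunk in splitter.split_text(sample)[:3]:
    print(len(chunk), repr(chunk[:40]))

At 100 characters a chunk is roughly one sentence, which is very fine-grained for admissions documents; a larger chunk_size would likely retrieve more coherent passages.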
@@ -184,20 +161,14 @@ for message in st.session_state.chat_history:
         st.write(message["content"])
 
 prompt = st.chat_input("Bạn muốn hỏi gì?")
-if st.session_state.model …
-    …
-    else:
-        ans = st.session_state.llm.invoke(prompt)
-        respone = ans.content
-        st.write(respone)
-
-    st.session_state.chat_history.append({"role": "assistant", "content": respone})
+if prompt and st.session_state.model:
+    st.session_state.chat_history.append({"role": "user", "content": prompt})
+    with st.chat_message("user"):
+        st.write(prompt)
+    with st.chat_message("assistant"):
+        if st.session_state.rag:
+            response = st.session_state.rag.invoke(prompt)
+        else:
+            response = st.session_state.model.invoke(prompt).content
+        st.write(response)
+    st.session_state.chat_history.append({"role": "assistant", "content": response})
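For readers new to LCEL, the chain both versions build has the shape {"context": retriever | format_docs, "question": RunnablePassthrough()} | prompt | model | StrOutputParser(). The sketch below keeps that shape but swaps in a canned retriever and an echo model so it runs without an API key or vector store; in the app these slots are filled by the Chroma retriever and the Gemini chat model.

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Stand-ins for the app's retriever and chat model.
fake_retriever = RunnableLambda(lambda q: ["USTH tuyển sinh ngành ICT.", "Hạn nộp hồ sơ: 30/6."])
fake_model = RunnableLambda(lambda prompt_value: prompt_value.to_string())

def format_docs(docs):
    # The app joins doc.page_content; the stand-in retriever returns plain strings.
    return "\n\n".join(docs)

prompt = PromptTemplate(
    template="Context:\n{context}\n\nQuestion:\n{question}",
    input_variables=["context", "question"],
)

rag_chain = (
    {"context": fake_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | fake_model
    | StrOutputParser()
)

print(rag_chain.invoke("Điều kiện tuyển sinh là gì?"))

Because StrOutputParser sits last, invoke returns a plain string, which is why the new chat handler can pass st.session_state.rag.invoke(prompt) straight to st.write while the plain-model path still needs .content.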