Create app.py
app.py
ADDED
@@ -0,0 +1,200 @@
"""Streamlit chat app ("AskUSTH"): answers questions with Gemini, optionally
grounded in user-uploaded .txt documents via a Chroma-backed RAG chain."""

import os

import streamlit as st
from langchain_chroma import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

st.title("Chat with AskUSTH")
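
# Streamlit re-executes this script top-to-bottom on every interaction, so
# anything that must survive a rerun (API key, built chains, upload
# bookkeeping) is kept in st.session_state.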
if "gemini_api" not in st.session_state:
    st.session_state.gemini_api = None

if "rag" not in st.session_state:
    st.session_state.rag = None

if "llm" not in st.session_state:
    st.session_state.llm = None
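
# @st.cache_resource memoizes the constructed client per api_key, so the
# Gemini model is built once per session rather than on every rerun.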
@st.cache_resource
def get_chat_google_model(api_key):
    os.environ["GOOGLE_API_KEY"] = api_key
    return ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )
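
# Embeddings run locally on CPU; bkai-foundation-models/vietnamese-bi-encoder
# is a Vietnamese sentence bi-encoder from the Hugging Face Hub.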
@st.cache_resource
def get_embedding_model():
    model_name = "bkai-foundation-models/vietnamese-bi-encoder"
    model_kwargs = {"device": "cpu"}
    encode_kwargs = {"normalize_embeddings": False}

    model = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    return model

if "embd" not in st.session_state:
    st.session_state.embd = get_embedding_model()

if "model" not in st.session_state:
    st.session_state.model = None

if "save_dir" not in st.session_state:
    st.session_state.save_dir = None

if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = set()
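
# st.dialog renders the decorated function as a modal; st.rerun() after the
# key is saved closes the modal and re-runs the script with the key in place.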
@st.dialog("Setup Gemini")
def setup_gemini():
    # (VI) "To use Google Gemini you need to provide an API key. Create your
    # key at the link below and paste it here."
    st.markdown(
        "Để sử dụng Google Gemini, bạn cần cung cấp API key. "
        "Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) "
        "và dán vào bên dưới."
    )
    key = st.text_input("Key:", "")
    if st.button("Save") and key != "":
        st.session_state.gemini_api = key
        st.rerun()

if st.session_state.gemini_api is None:
    setup_gemini()

if st.session_state.gemini_api and st.session_state.model is None:
    st.session_state.model = get_chat_google_model(st.session_state.gemini_api)

if st.session_state.save_dir is None:
    save_dir = "./Documents"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    st.session_state.save_dir = save_dir

def load_txt(file_path):
    # TextLoader returns a list with one Document per file
    loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
    doc = loader_sv.load()
    return doc
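
# Sidebar upload flow: newly selected .txt files are persisted to save_dir,
# every currently selected file is (re)loaded, and the RAG chain is
# invalidated whenever the file set gains a new document.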
with st.sidebar:
    uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])  # (VI) "Choose txt files"
    if st.session_state.gemini_api:
        if uploaded_files:
            documents = []
            uploaded_file_names = set()
            new_docs = False
            for uploaded_file in uploaded_files:
                uploaded_file_names.add(uploaded_file.name)
                file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
                if uploaded_file.name not in st.session_state.uploaded_files:
                    # Persist files we haven't seen before
                    with open(file_path, mode="wb") as w:
                        w.write(uploaded_file.getvalue())
                    new_docs = True
                # Load every selected file, not just the new ones, so a
                # rebuilt index covers the whole upload set
                documents.extend(load_txt(file_path))

            if new_docs:
                st.session_state.uploaded_files = uploaded_file_names
                st.session_state.rag = None
        else:
            st.session_state.uploaded_files = set()
            st.session_state.rag = None

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
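
# Build the retrieval chain: split the corpus into chunks, embed them into a
# Chroma vector store, then wire retriever -> prompt -> Gemini -> string.
# Note: chunk_size=100 characters is very small; depending on the documents,
# something like chunk_size=1000 with chunk_overlap=100 may retrieve more
# coherent context (an untested suggestion, left as a comment only).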
@st.cache_resource
def compute_rag_chain(_model, _embd, docs_texts):
    # Split the concatenated documents into chunks (split_text expects a
    # single string, so join the per-document texts first)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0)
    texts = text_splitter.split_text("\n\n".join(docs_texts))

    # Create a vector store for similarity search
    vectorstore = Chroma.from_texts(texts=texts, embedding=_embd)
    retriever = vectorstore.as_retriever()

    # Prompt combining retrieved context and the user question.
    # (VI) "You are an AI assistant supporting admissions and students. Answer
    # accurately, focusing on information relevant to the question. If you do
    # not know the answer, say so rather than making one up."
    template = """
Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên.
Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi.
Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:
{context}
hãy trả lời:
{question}
"""
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    # RAG chain: the retriever fills {context}, the raw question passes through
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | _model
        | StrOutputParser()
    )
    return rag_chain
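
# The "Setup RAG" dialog doubles as a progress modal while the (cached) chain
# is built; st.rerun() closes it once st.session_state.rag is populated.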
@st.dialog("Setup RAG")
def load_rag():
    docs_texts = [d.page_content for d in documents]
    st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
    st.rerun()

if st.session_state.uploaded_files and st.session_state.model is not None:
    if st.session_state.rag is None:
        load_rag()
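
# Fallback for when no documents are indexed: a plain system+human prompt
# piped straight into Gemini, built once per session.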
if st.session_state.model is not None:
    if st.session_state.llm is None:
        # (VI) system message: "You are an AI assistant supporting admissions
        # and students"
        mess = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên",
                ),
                ("human", "{input}"),
            ]
        )
        chain = mess | st.session_state.model
        st.session_state.llm = chain

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Replay the conversation so far (the script reruns on every interaction)
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.write(message["content"])
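
# New turn: answer with the RAG chain when one exists, otherwise fall back to
# the plain Gemini chain.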
prompt = st.chat_input("Bạn muốn hỏi gì?")  # (VI) "What would you like to ask?"
if st.session_state.model is not None:
    if prompt:
        st.session_state.chat_history.append({"role": "user", "content": prompt})

        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            if st.session_state.rag is not None:
                response = st.session_state.rag.invoke(prompt)
                st.write(response)
            else:
                # The prompt template declares one variable, {input}
                ans = st.session_state.llm.invoke({"input": prompt})
                response = ans.content
                st.write(response)

        st.session_state.chat_history.append({"role": "assistant", "content": response})
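
# To run locally (assuming the imported packages are installed from PyPI:
# streamlit, langchain-google-genai, langchain-chroma, langchain-huggingface,
# langchain-community, langchain-text-splitters):
#   streamlit run app.py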