update
- app.py +379 -2
- requirements.txt +11 -0
app.py
CHANGED
@@ -1,4 +1,381 @@
import streamlit as st
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import umap
from sklearn.mixture import GaussianMixture

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.document_loaders import TextLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma


def global_cluster_embeddings(
    embeddings: np.ndarray,
    dim: int,
    n_neighbors: Optional[int] = None,
    metric: str = "cosine",
) -> np.ndarray:
    # Reduce all embeddings to `dim` dimensions with UMAP before global clustering
    if n_neighbors is None:
        n_neighbors = int((len(embeddings) - 1) ** 0.5)
    return umap.UMAP(
        n_neighbors=n_neighbors, n_components=dim, metric=metric
    ).fit_transform(embeddings)


def local_cluster_embeddings(
    embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
) -> np.ndarray:
    # Same reduction, but scoped to the members of a single global cluster
    return umap.UMAP(
        n_neighbors=num_neighbors, n_components=dim, metric=metric
    ).fit_transform(embeddings)


def get_optimal_clusters(
    embeddings: np.ndarray, max_clusters: int = 50, random_state: int = 200
) -> int:
    # Choose the GMM component count that minimizes the Bayesian Information Criterion
    max_clusters = min(max_clusters, len(embeddings))
    n_clusters = np.arange(1, max_clusters)
    bics = []
    for n in n_clusters:
        gm = GaussianMixture(n_components=n, random_state=random_state)
        gm.fit(embeddings)
        bics.append(gm.bic(embeddings))
    return n_clusters[np.argmin(bics)]


def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):
    # Soft clustering: a point belongs to every cluster whose posterior
    # probability exceeds `threshold`, so each label is an array of cluster ids
    n_clusters = get_optimal_clusters(embeddings, random_state=200)
    gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
    gm.fit(embeddings)
    probs = gm.predict_proba(embeddings)
    labels = [np.where(prob > threshold)[0] for prob in probs]
    return labels, n_clusters


def perform_clustering(
    embeddings: np.ndarray,
    dim: int,
    threshold: float,
) -> List[np.ndarray]:
    if len(embeddings) <= dim + 1:
        # Avoid clustering when there's insufficient data
        return [np.array([0]) for _ in range(len(embeddings))]

    # Global dimensionality reduction
    reduced_embeddings_global = global_cluster_embeddings(embeddings, dim)
    # Global clustering
    global_clusters, n_global_clusters = GMM_cluster(
        reduced_embeddings_global, threshold
    )

    all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
    total_clusters = 0

    # Iterate through each global cluster to perform local clustering
    for i in range(n_global_clusters):
        # Extract embeddings belonging to the current global cluster
        global_cluster_embeddings_ = embeddings[
            np.array([i in gc for gc in global_clusters])
        ]

        if len(global_cluster_embeddings_) == 0:
            continue
        if len(global_cluster_embeddings_) <= dim + 1:
            # Handle small clusters with direct assignment
            local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
            n_local_clusters = 1
        else:
            # Local dimensionality reduction and clustering
            reduced_embeddings_local = local_cluster_embeddings(
                global_cluster_embeddings_, dim
            )
            local_clusters, n_local_clusters = GMM_cluster(
                reduced_embeddings_local, threshold
            )

        # Assign local cluster IDs, adjusting for total clusters already processed
        for j in range(n_local_clusters):
            local_cluster_embeddings_ = global_cluster_embeddings_[
                np.array([j in lc for lc in local_clusters])
            ]
            indices = np.where(
                (embeddings == local_cluster_embeddings_[:, None]).all(-1)
            )[1]
            for idx in indices:
                all_local_clusters[idx] = np.append(
                    all_local_clusters[idx], j + total_clusters
                )

        total_clusters += n_local_clusters

    return all_local_clusters


def embed(embd, texts):
    # Embed the raw texts and return them as a NumPy array
    text_embeddings = embd.embed_documents(texts)
    text_embeddings_np = np.array(text_embeddings)
    return text_embeddings_np


def embed_cluster_texts(embd, texts):
    text_embeddings_np = embed(embd, texts)  # Generate embeddings
    cluster_labels = perform_clustering(
        text_embeddings_np, 10, 0.1
    )  # Perform clustering on the embeddings
    df = pd.DataFrame()  # Initialize a DataFrame to store the results
    df["text"] = texts  # Store original texts
    df["embd"] = list(text_embeddings_np)  # Store embeddings as a list in the DataFrame
    df["cluster"] = cluster_labels  # Store cluster labels
    return df


def fmt_txt(df: pd.DataFrame) -> str:
    unique_txt = df["text"].tolist()
    return "--- --- \n --- --- ".join(unique_txt)


def embed_cluster_summarize_texts(
    model, embd, texts: List[str], level: int
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    df_clusters = embed_cluster_texts(embd, texts)

    # Prepare to expand the DataFrame for easier manipulation of clusters
    expanded_list = []

    # Expand DataFrame entries to document-cluster pairings for straightforward processing
    for index, row in df_clusters.iterrows():
        for cluster in row["cluster"]:
            expanded_list.append(
                {"text": row["text"], "embd": row["embd"], "cluster": cluster}
            )

    # Create a new DataFrame from the expanded list
    expanded_df = pd.DataFrame(expanded_list)

    # Retrieve unique cluster identifiers for processing
    all_clusters = expanded_df["cluster"].unique()

    # Summarization
    template = """You are a chatbot that supports university admissions and current students. Write a detailed summary of the regulation document below.
Make sure the summary helps the reader clearly understand the rules and procedures related to admissions or training at the university.
Document:
{context}
"""
    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | model | StrOutputParser()

    # Summarize each cluster's concatenated texts with the LLM
    summaries = []
    for i in all_clusters:
        df_cluster = expanded_df[expanded_df["cluster"] == i]
        formatted_txt = fmt_txt(df_cluster)
        summaries.append(chain.invoke({"context": formatted_txt}))

    df_summary = pd.DataFrame(
        {
            "summaries": summaries,
            "level": [level] * len(summaries),
            "cluster": list(all_clusters),
        }
    )
    return df_clusters, df_summary


def recursive_embed_cluster_summarize(
    model, embd, texts: List[str], level: int = 1, n_levels: int = 3
) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:
    # Build the tree level by level: cluster the texts, summarize each cluster,
    # then recurse on the summaries until n_levels or a single cluster remains
    results = {}
    df_clusters, df_summary = embed_cluster_summarize_texts(model, embd, texts, level)
    results[level] = (df_clusters, df_summary)

    unique_clusters = df_summary["cluster"].nunique()
    if level < n_levels and unique_clusters > 1:
        new_texts = df_summary["summaries"].tolist()
        next_level_results = recursive_embed_cluster_summarize(
            model, embd, new_texts, level + 1, n_levels
        )
        results.update(next_level_results)

    return results


st.title("Chat with Gemini")

# Session-state slots that must survive Streamlit reruns
if "gemini_api" not in st.session_state:
    st.session_state.gemini_api = None

if "rag" not in st.session_state:
    st.session_state.rag = None

if "llm" not in st.session_state:
    st.session_state.llm = None

if "model" not in st.session_state:
    st.session_state.model = None

if "embd" not in st.session_state:
    st.session_state.embd = None

if "save_dir" not in st.session_state:
    st.session_state.save_dir = None

if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = set()


@st.dialog("Setup Gemini")
def vote():
    st.markdown(
        """
        To use Google Gemini you need to provide an API key. Create your key [here](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) and paste it below.
        """
    )
    key = st.text_input("Key:", "")
    if st.button("Save"):
        st.session_state.gemini_api = key
        st.rerun()


if st.session_state.gemini_api is None:
    vote()
else:
    os.environ["GOOGLE_API_KEY"] = st.session_state.gemini_api

    st.session_state.model = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )

    st.write(f"Key is set to: {st.session_state.gemini_api}")

    model_name = "bkai-foundation-models/vietnamese-bi-encoder"
    model_kwargs = {"device": "cpu"}
    encode_kwargs = {"normalize_embeddings": False}

    st.session_state.embd = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )

    st.write("Loaded vietnamese-bi-encoder")

if st.session_state.save_dir is None:
    save_dir = "./Documents"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    st.session_state.save_dir = save_dir


def load_txt(file_path):
    loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
    doc = loader_sv.load()
    return doc


documents = []  # Documents newly uploaded during this rerun

with st.sidebar:
    uploaded_files = st.file_uploader(
        "Choose TXT files", accept_multiple_files=True, type=["txt"]
    )
    if st.session_state.gemini_api:
        if uploaded_files:
            uploaded_file_names = set()
            new_docs = False
            for uploaded_file in uploaded_files:
                uploaded_file_names.add(uploaded_file.name)
                if uploaded_file.name not in st.session_state.uploaded_files:
                    # Persist new files to disk so TextLoader can read them
                    file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
                    with open(file_path, mode="wb") as w:
                        w.write(uploaded_file.getvalue())
                else:
                    continue

                new_docs = True
                doc = load_txt(file_path)
                documents.extend(doc)

            if new_docs:
                # New material invalidates the existing RAG index
                st.session_state.uploaded_files = uploaded_file_names
                st.session_state.rag = None
        else:
            # All files were removed, so drop the index as well
            st.session_state.uploaded_files = set()
            st.session_state.rag = None

if st.session_state.uploaded_files:
    if st.session_state.gemini_api is not None:
        if st.session_state.rag is None:
            docs_texts = [d.page_content for d in documents]

            # Recursively embed, cluster and summarize the documents
            results = recursive_embed_cluster_summarize(
                st.session_state.model, st.session_state.embd, docs_texts, level=1, n_levels=3
            )

            # Index the original texts plus every level of summaries
            all_texts = docs_texts.copy()
            for level in sorted(results.keys()):
                summaries = results[level][1]["summaries"].tolist()
                all_texts.extend(summaries)

            vectorstore = Chroma.from_texts(texts=all_texts, embedding=st.session_state.embd)
            retriever = vectorstore.as_retriever()

            def format_docs(docs):
                return "\n\n".join(doc.page_content for doc in docs)

            template = """<|im_start|>system
You are an AI assistant that supports admissions and students. Answer the question accurately, focusing on the information relevant to it. If you do not know the answer, say you do not know; do not try to make one up.

Here is relevant information you can use:
{context}<|im_end|>
<|im_start|>please answer:
{question}<|im_end|>
<|im_start|>assistant"""
            prompt = PromptTemplate(template=template, input_variables=["context", "question"])
            rag_chain = (
                {"context": retriever | format_docs, "question": RunnablePassthrough()}
                | prompt
                | st.session_state.model
                | StrOutputParser()
            )
            st.session_state.rag = rag_chain

if st.session_state.gemini_api is not None:
    if st.session_state.llm is None:
        mess = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are a smart chatbot",
                ),
                ("human", "{input}"),
            ]
        )
        # StrOutputParser keeps chat-history entries as plain strings
        chain = mess | st.session_state.model | StrOutputParser()
        st.session_state.llm = chain

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.write(message["content"])

prompt = st.chat_input("What would you like to ask?")
if st.session_state.gemini_api:
    if prompt:
        st.session_state.chat_history.append({"role": "user", "content": prompt})

        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            # Prefer the document-grounded RAG chain when an index exists
            if st.session_state.rag is not None:
                response = st.session_state.rag.invoke(prompt)
            else:
                response = st.session_state.llm.invoke(prompt)
            st.write(response)

        st.session_state.chat_history.append({"role": "assistant", "content": response})
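For reference, the clustering stage above can be exercised on its own, with no API key and no LLM calls. A minimal sketch, assuming the helper functions from app.py (global_cluster_embeddings through perform_clustering) are in scope, with synthetic random vectors standing in for real embeddings (the app itself uses dim=10; dim=5 is used here only to keep UMAP comfortable on a small sample):

import numpy as np

rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(60, 32))  # 60 pseudo-documents, 32-dim placeholder vectors

# Each result entry is an array of cluster ids, because GMM_cluster assigns
# soft, threshold-based memberships (a text can land in several clusters).
labels = perform_clustering(fake_embeddings, dim=5, threshold=0.1)
print(len(labels), labels[:3])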
requirements.txt
CHANGED
@@ -0,0 +1,11 @@
streamlit
langchain-google-genai
langchain-core
langchain-community
langchain-huggingface
numpy
pandas
umap-learn
scikit-learn
langchain-chroma
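To try the Space locally, the standard Streamlit workflow applies: `pip install -r requirements.txt`, then `streamlit run app.py`.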