nkcong206 commited on
Commit
8e20df2
·
1 Parent(s): 46c348b
Files changed (2) hide show
  1. app.py +379 -2
  2. requirements.txt +11 -0
app.py CHANGED
@@ -1,4 +1,381 @@
1
  import streamlit as st
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import os
3
+ from langchain_google_genai import ChatGoogleGenerativeAI
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_community.document_loaders import TextLoader
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+ from langchain.prompts import PromptTemplate
8
 
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import umap
14
+ from langchain.prompts import ChatPromptTemplate
15
+ from langchain_core.output_parsers import StrOutputParser
16
+ from sklearn.mixture import GaussianMixture
17
+
18
+ from langchain_core.runnables import RunnablePassthrough
19
+ from langchain_chroma import Chroma
20
+
21
+
22
+
23
+
24
+ def global_cluster_embeddings(
25
+ embeddings: np.ndarray,
26
+ dim: int,
27
+ n_neighbors: Optional[int] = None,
28
+ metric: str = "cosine",
29
+ ) -> np.ndarray:
30
+ if n_neighbors is None:
31
+ n_neighbors = int((len(embeddings) - 1) ** 0.5)
32
+ return umap.UMAP(
33
+ n_neighbors=n_neighbors, n_components=dim, metric=metric
34
+ ).fit_transform(embeddings)
35
+
36
+
37
+ def local_cluster_embeddings(
38
+ embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
39
+ ) -> np.ndarray:
40
+ return umap.UMAP(
41
+ n_neighbors=num_neighbors, n_components=dim, metric=metric
42
+ ).fit_transform(embeddings)
43
+
44
+
45
+ def get_optimal_clusters(
46
+ embeddings: np.ndarray, max_clusters: int = 50, random_state: int = 200
47
+ ) -> int:
48
+ max_clusters = min(max_clusters, len(embeddings))
49
+ n_clusters = np.arange(1, max_clusters)
50
+ bics = []
51
+ for n in n_clusters:
52
+ gm = GaussianMixture(n_components=n, random_state=random_state)
53
+ gm.fit(embeddings)
54
+ bics.append(gm.bic(embeddings))
55
+ return n_clusters[np.argmin(bics)]
56
+
57
+
58
+ def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):
59
+ n_clusters = get_optimal_clusters(embeddings, random_state = 200)
60
+ gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
61
+ gm.fit(embeddings)
62
+ probs = gm.predict_proba(embeddings)
63
+ labels = [np.where(prob > threshold)[0] for prob in probs]
64
+ return labels, n_clusters
65
+
66
+
67
+ def perform_clustering(
68
+ embeddings: np.ndarray,
69
+ dim: int,
70
+ threshold: float,
71
+ ) -> List[np.ndarray]:
72
+ if len(embeddings) <= dim + 1:
73
+ # Avoid clustering when there's insufficient data
74
+ return [np.array([0]) for _ in range(len(embeddings))]
75
+
76
+ # Global dimensionality reduction
77
+ reduced_embeddings_global = global_cluster_embeddings(embeddings, dim)
78
+ # Global clustering
79
+ global_clusters, n_global_clusters = GMM_cluster(
80
+ reduced_embeddings_global, threshold
81
+ )
82
+
83
+ all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
84
+ total_clusters = 0
85
+
86
+ # Iterate through each global cluster to perform local clustering
87
+ for i in range(n_global_clusters):
88
+ # Extract embeddings belonging to the current global cluster
89
+ global_cluster_embeddings_ = embeddings[
90
+ np.array([i in gc for gc in global_clusters])
91
+ ]
92
+
93
+ if len(global_cluster_embeddings_) == 0:
94
+ continue
95
+ if len(global_cluster_embeddings_) <= dim + 1:
96
+ # Handle small clusters with direct assignment
97
+ local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
98
+ n_local_clusters = 1
99
+ else:
100
+ # Local dimensionality reduction and clustering
101
+ reduced_embeddings_local = local_cluster_embeddings(
102
+ global_cluster_embeddings_, dim
103
+ )
104
+ local_clusters, n_local_clusters = GMM_cluster(
105
+ reduced_embeddings_local, threshold
106
+ )
107
+
108
+ # Assign local cluster IDs, adjusting for total clusters already processed
109
+ for j in range(n_local_clusters):
110
+ local_cluster_embeddings_ = global_cluster_embeddings_[
111
+ np.array([j in lc for lc in local_clusters])
112
+ ]
113
+ indices = np.where(
114
+ (embeddings == local_cluster_embeddings_[:, None]).all(-1)
115
+ )[1]
116
+ for idx in indices:
117
+ all_local_clusters[idx] = np.append(
118
+ all_local_clusters[idx], j + total_clusters
119
+ )
120
+
121
+ total_clusters += n_local_clusters
122
+
123
+ return all_local_clusters
124
+
125
+ def embed(embd,texts):
126
+ text_embeddings = embd.embed_documents(texts)
127
+ text_embeddings_np = np.array(text_embeddings)
128
+ return text_embeddings_np
129
+
130
+ def embed_cluster_texts(embd,texts):
131
+ text_embeddings_np = embed(embd,texts) # Generate embeddings
132
+ cluster_labels = perform_clustering(
133
+ text_embeddings_np, 10, 0.1
134
+ ) # Perform clustering on the embeddings
135
+ df = pd.DataFrame() # Initialize a DataFrame to store the results
136
+ df["text"] = texts # Store original texts
137
+ df["embd"] = list(text_embeddings_np) # Store embeddings as a list in the DataFrame
138
+ df["cluster"] = cluster_labels # Store cluster labels
139
+ return df
140
+
141
+ def fmt_txt(df: pd.DataFrame) -> str:
142
+ unique_txt = df["text"].tolist()
143
+ return "--- --- \n --- --- ".join(unique_txt)
144
+
145
+
146
+ def embed_cluster_summarize_texts(model,embd,
147
+ texts: List[str], level: int
148
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
149
+ df_clusters = embed_cluster_texts(embd,texts)
150
+
151
+ # Prepare to expand the DataFrame for easier manipulation of clusters
152
+ expanded_list = []
153
+
154
+ # Expand DataFrame entries to document-cluster pairings for straightforward processing
155
+ for index, row in df_clusters.iterrows():
156
+ for cluster in row["cluster"]:
157
+ expanded_list.append(
158
+ {"text": row["text"], "embd": row["embd"], "cluster": cluster}
159
+ )
160
+
161
+ # Create a new DataFrame from the expanded list
162
+ expanded_df = pd.DataFrame(expanded_list)
163
+
164
+ # Retrieve unique cluster identifiers for processing
165
+ all_clusters = expanded_df["cluster"].unique()
166
+ # Summarization
167
+ template = """Bạn là một chatbot hỗ trợ tuyển sinh và sinh viên đại học, hãy tóm tắt chi tiết tài liệu quy chế dưới đây.
168
+ Đảm bảo rằng nội dung tóm tắt giúp người dùng hiểu rõ các quy định và quy trình liên quan đến tuyển sinh hoặc đào tạo tại đại học.
169
+ Tài liệu:
170
+ {context}
171
+ """
172
+ prompt = ChatPromptTemplate.from_template(template)
173
+ chain = prompt | model | StrOutputParser()
174
+
175
+ summaries = []
176
+ for i in all_clusters:
177
+ df_cluster = expanded_df[expanded_df["cluster"] == i]
178
+ formatted_txt = fmt_txt(df_cluster)
179
+ summaries.append(chain.invoke({"context": formatted_txt}))
180
+ df_summary = pd.DataFrame(
181
+ {
182
+ "summaries": summaries,
183
+ "level": [level] * len(summaries),
184
+ "cluster": list(all_clusters),
185
+ }
186
+ )
187
+ return df_clusters, df_summary
188
+
189
+ def recursive_embed_cluster_summarize(model,embd,
190
+ texts: List[str], level: int = 1, n_levels: int = 3
191
+ ) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:
192
+ results = {}
193
+ df_clusters, df_summary = embed_cluster_summarize_texts(model,embd,texts, level)
194
+
195
+ results[level] = (df_clusters, df_summary)
196
+
197
+ unique_clusters = df_summary["cluster"].nunique()
198
+ if level < n_levels and unique_clusters > 1:
199
+ new_texts = df_summary["summaries"].tolist()
200
+ next_level_results = recursive_embed_cluster_summarize(model,embd,
201
+ new_texts, level + 1, n_levels
202
+ )
203
+ results.update(next_level_results)
204
+
205
+ return results
206
+
207
+
208
+
209
+ page = st.title("Chat with Gemini")
210
+
211
+ if "gemini_api" not in st.session_state:
212
+ st.session_state.gemini_api = None
213
+
214
+ if "rag" not in st.session_state:
215
+ st.session_state.rag = None
216
+
217
+ if "llm" not in st.session_state:
218
+ st.session_state.llm = None
219
+
220
+ if "model" not in st.session_state:
221
+ st.session_state.model = None
222
+
223
+ if "embd" not in st.session_state:
224
+ st.session_state.embd = None
225
+
226
+ if "save_dir" not in st.session_state:
227
+ st.session_state.save_dir = None
228
+
229
+ if "uploaded_files" not in st.session_state:
230
+ st.session_state.uploaded_files = set()
231
+
232
+ @st.dialog("Setup Gemini")
233
+ def vote():
234
+ st.markdown(
235
+ """
236
+ Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
237
+ """
238
+ )
239
+ key = st.text_input("Key:", "")
240
+ if st.button("Save"):
241
+ st.session_state.gemini_api = key
242
+ st.rerun()
243
+
244
+ if st.session_state.gemini_api is None:
245
+ vote()
246
+ else:
247
+ os.environ["GOOGLE_API_KEY"] = st.session_state.gemini_api
248
+
249
+ st.session_state.model = ChatGoogleGenerativeAI(
250
+ model="gemini-1.5-flash",
251
+ temperature=0,
252
+ max_tokens=None,
253
+ timeout=None,
254
+ max_retries=2,
255
+ )
256
+
257
+ st.write(f"Key is set to: {st.session_state.gemini_api}")
258
+ model_name="bkai-foundation-models/vietnamese-bi-encoder"
259
+ model_kwargs = {'device': 'cpu'}
260
+ encode_kwargs = {'normalize_embeddings': False}
261
+
262
+ st.session_state.embd = HuggingFaceEmbeddings(
263
+ model_name=model_name,
264
+ model_kwargs=model_kwargs,
265
+ encode_kwargs=encode_kwargs
266
+ )
267
+
268
+ st.write(f"loaded vietnamese-bi-encoder")
269
+
270
+ if st.session_state.save_dir is None:
271
+ save_dir = "./Documents"
272
+ if not os.path.exists(save_dir):
273
+ os.makedirs(save_dir)
274
+ st.session_state.save_dir = save_dir
275
+
276
+ def load_txt(file_path):
277
+ loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
278
+ doc = loader_sv.load()
279
+ return doc
280
+
281
+ with st.sidebar:
282
+ uploaded_files = st.file_uploader("Chọn file CSV", accept_multiple_files=True, type=["txt"])
283
+ if st.session_state.gemini_api:
284
+ if uploaded_files:
285
+ documents = []
286
+ uploaded_file_names = set()
287
+ new_docs = False
288
+ for uploaded_file in uploaded_files:
289
+ uploaded_file_names.add(uploaded_file.name)
290
+ if uploaded_file.name not in st.session_state.uploaded_files:
291
+ file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
292
+ with open(file_path, mode='wb') as w:
293
+ w.write(uploaded_file.getvalue())
294
+ else:
295
+ continue
296
+
297
+ new_docs = True
298
+
299
+ doc = load_txt(file_path)
300
+
301
+ documents.extend([*doc])
302
+
303
+ if new_docs:
304
+ st.session_state.uploaded_files = uploaded_file_names
305
+ st.session_state.rag = None
306
+
307
+ if not uploaded_file_names:
308
+ st.session_state.uploaded_files = set()
309
+ st.session_state.rag = None
310
+
311
+ if st.session_state.uploaded_files:
312
+ if st.session_state.gemini_api is not None:
313
+ if st.session_state.rag is None:
314
+ docs_texts = [d.page_content for d in documents]
315
+
316
+ results = recursive_embed_cluster_summarize(st.session_state.model, st.session_state.embd, docs_texts, level=1, n_levels=3)
317
+
318
+ all_texts = docs_texts.copy()
319
+
320
+ for level in sorted(results.keys()):
321
+ summaries = results[level][1]["summaries"].tolist()
322
+ all_texts.extend(summaries)
323
+
324
+ vectorstore = Chroma.from_texts(texts=all_texts, embedding=st.session_state.embd)
325
+
326
+ retriever = vectorstore.as_retriever()
327
+
328
+ def format_docs(docs):
329
+ return "\n\n".join(doc.page_content for doc in docs)
330
+
331
+ template = """<|im_start|>system\nBản là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
332
+ \n
333
+
334
+ Dưới đây là thông tin liên quan mà bạn có thể sử dụng:\n{context}<|im_end|>\n<|im_start|>hãy trả lời: \n{question}<|im_end|>\n<|im_start|>assistant"""
335
+ prompt = PromptTemplate(template = template, input_variables=["context", "question"])
336
+ rag_chain = (
337
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
338
+ | prompt
339
+ | st.session_state.model
340
+ | StrOutputParser()
341
+ )
342
+ st.session_state.rag = rag_chain
343
+
344
+ if st.session_state.gemini_api is not None:
345
+ if st.session_state.llm is None:
346
+ mess = ChatPromptTemplate.from_messages(
347
+ [
348
+ (
349
+ "system",
350
+ "Bạn là 1 chatbot thông minh",
351
+ ),
352
+ ("human", "{input}"),
353
+ ]
354
+ )
355
+ chain = mess | st.session_state.model
356
+ st.session_state.llm = chain
357
+
358
+ if "chat_history" not in st.session_state:
359
+ st.session_state.chat_history = []
360
+
361
+ for message in st.session_state.chat_history:
362
+ with st.chat_message(message["role"]):
363
+ st.write(message["content"])
364
+
365
+ prompt = st.chat_input("Bạn muốn hỏi gì?")
366
+ if st.session_state.gemini_api:
367
+ if prompt:
368
+ st.session_state.chat_history.append({"role": "user", "content": prompt})
369
+
370
+ with st.chat_message("user"):
371
+ st.write(prompt)
372
+
373
+ with st.chat_message("assistant"):
374
+ if st.session_state.rag is not None:
375
+ respone = st.session_state.rag.invoke(prompt)
376
+ st.write(respone)
377
+ else:
378
+ respone = st.session_state.llm.invoke(prompt)
379
+ st.write(respone)
380
+
381
+ st.session_state.chat_history.append({"role": "assistant", "content": respone})
requirements.txt CHANGED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ langchain-google-genai
3
+ langchain-core
4
+ langchain-community
5
+ langchain-huggingface
6
+ numpy
7
+ pandas
8
+ umap-learn
9
+ scikit-learn
10
+ langchain-chroma
11
+