nkcong206 committed on
Commit caa9f7a · verified · 1 Parent(s): 9125dab

Create app.py

Files changed (1)
  1. app.py +200 -0
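
In brief: app.py is a single-file Streamlit app. It asks for a Google Gemini API key in a dialog, lets the user upload .txt files from the sidebar, embeds them with bkai-foundation-models/vietnamese-bi-encoder into a Chroma vector store, and answers chat questions through a gemini-1.5-flash RAG chain, falling back to a plain LLM chain when no documents are indexed.
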
app.py ADDED
@@ -0,0 +1,200 @@
+ import streamlit as st
+ import os
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_community.document_loaders import TextLoader
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain_chroma import Chroma
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+ page = st.title("Chat with AskUSTH")
+
+ if "gemini_api" not in st.session_state:
+     st.session_state.gemini_api = None
+
+ if "rag" not in st.session_state:
+     st.session_state.rag = None
+
+ if "llm" not in st.session_state:
+     st.session_state.llm = None
+
+ @st.cache_resource
+ def get_chat_google_model(api_key):
+     os.environ["GOOGLE_API_KEY"] = api_key
+     return ChatGoogleGenerativeAI(
+         model="gemini-1.5-flash",
+         temperature=0,
+         max_tokens=None,
+         timeout=None,
+         max_retries=2,
+     )
+
+ @st.cache_resource
+ def get_embedding_model():
+     model_name = "bkai-foundation-models/vietnamese-bi-encoder"
+     model_kwargs = {'device': 'cpu'}
+     encode_kwargs = {'normalize_embeddings': False}
+
+     model = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs
+     )
+     return model
+
+ if "embd" not in st.session_state:
+     st.session_state.embd = get_embedding_model()
+
+ if "model" not in st.session_state:
+     st.session_state.model = None
+
+ if "save_dir" not in st.session_state:
+     st.session_state.save_dir = None
+
+ if "uploaded_files" not in st.session_state:
+     st.session_state.uploaded_files = set()
+
+ @st.dialog("Setup Gemini")
+ def vote():
+     # Ask for a Gemini API key. Dialog text is Vietnamese: "To use Google
+     # Gemini you need to provide an API key. Create your key at the link
+     # and paste it below."
+     st.markdown(
+         """
+         Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
+         """
+     )
+     key = st.text_input("Key:", "")
+     if st.button("Save") and key != "":
+         st.session_state.gemini_api = key
+         st.rerun()
+
+ if st.session_state.gemini_api is None:
+     vote()
+
+ if st.session_state.gemini_api and st.session_state.model is None:
+     st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
+
+ if st.session_state.save_dir is None:
+     save_dir = "./Documents"
+     if not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+     st.session_state.save_dir = save_dir
+
+ def load_txt(file_path):
+     loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
+     doc = loader_sv.load()
+     return doc
+
+ with st.sidebar:
+     # "Chọn file txt" = "Choose txt files"
+     uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
+     if st.session_state.gemini_api:
+         if uploaded_files:
+             documents = []
+             uploaded_file_names = set()
+             new_docs = False
+             for uploaded_file in uploaded_files:
+                 uploaded_file_names.add(uploaded_file.name)
+                 file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
+                 # Only write files that were not saved in a previous rerun
+                 if uploaded_file.name not in st.session_state.uploaded_files:
+                     with open(file_path, mode='wb') as w:
+                         w.write(uploaded_file.getvalue())
+                     new_docs = True
+                 # Load every current file (not only the new ones), so a
+                 # rebuilt index covers the whole upload set
+                 documents.extend(load_txt(file_path))
+
+             if new_docs:
+                 st.session_state.uploaded_files = uploaded_file_names
+                 st.session_state.rag = None
+         else:
+             st.session_state.uploaded_files = set()
+             st.session_state.rag = None
+
+ def format_docs(docs):
+     return "\n\n".join(doc.page_content for doc in docs)
+
+ @st.cache_resource
+ def compute_rag_chain(_model, _embd, docs_texts):
+     # Split the documents into chunks. docs_texts is a list of page-content
+     # strings, so join it first: split_text expects a single string.
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0)
+     texts = text_splitter.split_text("\n\n".join(docs_texts))
+
+     # Index the chunks in a Chroma vector store for similarity search
+     vectorstore = Chroma.from_texts(texts=texts, embedding=_embd)
+     retriever = vectorstore.as_retriever()
+
+     # Vietnamese RAG prompt: "You are an AI assistant for admissions and
+     # students. Answer accurately, focusing on information relevant to the
+     # question. If you don't know the answer, say so; don't make one up.
+     # Here is the relevant information you should use: {context}.
+     # Please answer: {question}"
+     template = """
+     Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên.
+     Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi.
+     Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
+     Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:
+     {context}
+     hãy trả lời:
+     {question}
+     """
+     prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+
+     # RAG chain: the retriever receives the raw question, format_docs joins
+     # the retrieved chunks into {context}, and RunnablePassthrough() forwards
+     # the question itself
+     rag_chain = (
+         {"context": retriever | format_docs, "question": RunnablePassthrough()}
+         | prompt
+         | _model
+         | StrOutputParser()
+     )
+     return rag_chain
+
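+ # Usage sketch (for illustration only; mirrors how the app calls the chain
+ # further below):
+ #     rag_chain = compute_rag_chain(model, embd, docs_texts)
+ #     answer = rag_chain.invoke("Câu hỏi của sinh viên")  # plain string, via StrOutputParser
+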
+ @st.dialog("Setup RAG")
+ def load_rag():
+     # Build the RAG chain from the documents gathered in the sidebar
+     # (a module-level name, so this dialog only runs after an upload)
+     docs_texts = [d.page_content for d in documents]
+     st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
+     st.rerun()
+
+ if st.session_state.uploaded_files and st.session_state.model is not None:
+     if st.session_state.rag is None:
+         load_rag()
+
+ if st.session_state.model is not None:
+     if st.session_state.llm is None:
+         # Fallback chat chain used when no documents are indexed. System
+         # prompt: "You are an AI assistant for admissions and students."
+         mess = ChatPromptTemplate.from_messages(
+             [
+                 (
+                     "system",
+                     "Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên",
+                 ),
+                 ("human", "{input}"),
+             ]
+         )
+         chain = mess | st.session_state.model
+         st.session_state.llm = chain
+
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []
+
+ for message in st.session_state.chat_history:
+     with st.chat_message(message["role"]):
+         st.write(message["content"])
+
+ # "Bạn muốn hỏi gì?" = "What would you like to ask?"
+ prompt = st.chat_input("Bạn muốn hỏi gì?")
+ if st.session_state.model is not None:
+     if prompt:
+         st.session_state.chat_history.append({"role": "user", "content": prompt})
+
+         with st.chat_message("user"):
+             st.write(prompt)
+
+         with st.chat_message("assistant"):
+             # Prefer the RAG chain when documents are indexed; otherwise
+             # fall back to the plain LLM chain
+             if st.session_state.rag is not None:
+                 response = st.session_state.rag.invoke(prompt)
+                 st.write(response)
+             else:
+                 ans = st.session_state.llm.invoke(prompt)
+                 response = ans.content
+                 st.write(response)
+
+         st.session_state.chat_history.append({"role": "assistant", "content": response})
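
To try the app locally, the imports above suggest at least the following packages (an assumption; the commit ships no requirements.txt and exact version pins are unknown):

    pip install streamlit langchain langchain-core langchain-community langchain-google-genai langchain-huggingface langchain-chroma langchain-text-splitters sentence-transformers
    streamlit run app.py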