nkcong206 committed on
Commit
af5ec80
·
verified ·
1 Parent(s): 70997f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -121
app.py CHANGED
@@ -10,17 +10,32 @@ from langchain_core.runnables import RunnablePassthrough
10
  from langchain_chroma import Chroma
11
  from langchain_text_splitters import RecursiveCharacterTextSplitter
12
 
 
13
  page = st.title("Chat with AskUSTH")
14
 
 
15
  if "gemini_api" not in st.session_state:
16
  st.session_state.gemini_api = None
17
 
18
  if "rag" not in st.session_state:
19
  st.session_state.rag = None
20
-
21
  if "llm" not in st.session_state:
22
  st.session_state.llm = None
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  @st.cache_resource
25
  def get_chat_google_model(api_key):
26
  os.environ["GOOGLE_API_KEY"] = api_key
@@ -42,23 +57,56 @@ def get_embedding_model():
42
  model_name=model_name,
43
  model_kwargs=model_kwargs,
44
  encode_kwargs=encode_kwargs
45
- )
46
  return model
47
 
48
- if "embd" not in st.session_state:
49
- st.session_state.embd = get_embedding_model()
 
 
 
50
 
51
- if "model" not in st.session_state:
52
- st.session_state.model = None
53
 
54
- if "save_dir" not in st.session_state:
55
- st.session_state.save_dir = None
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- if "uploaded_files" not in st.session_state:
58
- st.session_state.uploaded_files = set()
 
 
 
 
 
 
 
 
59
 
 
 
 
 
 
 
 
 
 
60
  @st.dialog("Setup Gemini")
61
- def vote():
62
  st.markdown(
63
  """
64
  Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
@@ -67,115 +115,44 @@ def vote():
67
  key = st.text_input("Key:", "")
68
  if st.button("Save") and key != "":
69
  st.session_state.gemini_api = key
70
- st.rerun()
71
 
72
  if st.session_state.gemini_api is None:
73
- vote()
74
 
75
  if st.session_state.gemini_api and st.session_state.model is None:
76
  st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
77
 
 
 
 
78
  if st.session_state.save_dir is None:
79
  save_dir = "./Documents"
80
  if not os.path.exists(save_dir):
81
  os.makedirs(save_dir)
82
  st.session_state.save_dir = save_dir
83
-
84
- def load_txt(file_path):
85
- loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
86
- doc = loader_sv.load()
87
- return doc
88
 
 
89
  with st.sidebar:
90
- uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
91
- if st.session_state.gemini_api:
92
- if uploaded_files:
93
- documents = []
94
- uploaded_file_names = set()
95
- new_docs = False
96
- for uploaded_file in uploaded_files:
97
- uploaded_file_names.add(uploaded_file.name)
98
- if uploaded_file.name not in st.session_state.uploaded_files:
99
- file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
100
- with open(file_path, mode='wb') as w:
101
- w.write(uploaded_file.getvalue())
102
- else:
103
- continue
104
-
105
- new_docs = True
106
-
107
  doc = load_txt(file_path)
108
-
109
  documents.extend([*doc])
110
-
111
- if new_docs:
112
- st.session_state.uploaded_files = uploaded_file_names
113
- st.session_state.rag = None
114
- else:
115
- st.session_state.uploaded_files = set()
116
- st.session_state.rag = None
117
-
118
- def format_docs(docs):
119
- return "\n\n".join(doc.page_content for doc in docs)
120
-
121
- @st.cache_resource
122
- def compute_rag_chain(_model, _embd, docs_texts):
123
- # Combine all texts into one large string
124
- combined_text = "\n\n".join(docs_texts) # Join all document texts into one string
125
-
126
- # Use RecursiveCharacterTextSplitter to split text into chunks
127
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0)
128
- texts = text_splitter.split_text(combined_text) # Now this will work as 'combined_text' is a string
129
-
130
- # Create vector store for similarity search
131
- vectorstore = Chroma.from_texts(texts=texts, embedding=_embd)
132
- retriever = vectorstore.as_retriever()
133
-
134
- # Prepare the prompt for context and question
135
- template = """
136
- Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. \n
137
- Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. \n
138
- Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.\n
139
- Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:\n
140
- {context}\n
141
- hãy trả lời:\n
142
- {question}
143
- """
144
- prompt = PromptTemplate(template=template, input_variables=["context", "question"])
145
-
146
- # Chain for RAG
147
- rag_chain = (
148
- {"context": retriever | format_docs, "question": RunnablePassthrough()}
149
- | prompt
150
- | _model
151
- | StrOutputParser()
152
- )
153
- return rag_chain
154
-
155
- @st.dialog("Setup RAG")
156
- def load_rag():
157
- docs_texts = [d.page_content for d in documents]
158
- st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
159
- st.rerun()
160
-
161
- if st.session_state.uploaded_files and st.session_state.model is not None:
162
- if st.session_state.rag is None:
163
- load_rag()
164
-
165
- if st.session_state.model is not None:
166
- if st.session_state.llm is None:
167
- mess = ChatPromptTemplate.from_messages(
168
- [
169
- (
170
- "system",
171
- "Bản là một trợ lí AI hỗ trợ tuyển sinh và sinh viên",
172
- ),
173
- ("human", "{input}"),
174
- ]
175
- )
176
- chain = mess | st.session_state.model
177
- st.session_state.llm = chain
178
 
 
179
  if "chat_history" not in st.session_state:
180
  st.session_state.chat_history = []
181
 
@@ -184,20 +161,14 @@ for message in st.session_state.chat_history:
184
  st.write(message["content"])
185
 
186
  prompt = st.chat_input("Bạn muốn hỏi gì?")
187
- if st.session_state.model is not None:
188
- if prompt:
189
- st.session_state.chat_history.append({"role": "user", "content": prompt})
190
-
191
- with st.chat_message("user"):
192
- st.write(prompt)
193
-
194
- with st.chat_message("assistant"):
195
- if st.session_state.rag is not None:
196
- respone = st.session_state.rag.invoke(prompt)
197
- st.write(respone)
198
- else:
199
- ans = st.session_state.llm.invoke(prompt)
200
- respone = ans.content
201
- st.write(respone)
202
-
203
- st.session_state.chat_history.append({"role": "assistant", "content": respone})
 
10
  from langchain_chroma import Chroma
11
  from langchain_text_splitters import RecursiveCharacterTextSplitter
12
 
13
# App title
page = st.title("Chat with AskUSTH")

# Initialize every session-state key the app relies on, so later code can
# read them without existence checks.
_SESSION_DEFAULTS = {
    "gemini_api": None,
    "rag": None,
    "llm": None,
    "embd": None,
    "model": None,
    "save_dir": None,
    "uploaded_files": set(),
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
37
+
38
+ # Caching functions
39
  @st.cache_resource
40
  def get_chat_google_model(api_key):
41
  os.environ["GOOGLE_API_KEY"] = api_key
 
57
  model_name=model_name,
58
  model_kwargs=model_kwargs,
59
  encode_kwargs=encode_kwargs
60
+ )
61
  return model
62
 
63
+ # Load and process text files
64
def load_txt(file_path):
    """Load a UTF-8 text file and return its LangChain Document list."""
    return TextLoader(file_path=file_path, encoding="utf-8").load()
68
 
69
def format_docs(docs):
    """Join the page contents of *docs*, separated by blank lines."""
    contents = [item.page_content for item in docs]
    return "\n\n".join(contents)
71
 
72
# Build (and cache) the retrieval-augmented-generation chain.
@st.cache_resource
def compute_rag_chain(_model, _embd, docs_texts):
    """Index *docs_texts* into a Chroma store and return a RAG chain.

    Raises ValueError when there is nothing to index or the splitter
    produces no chunks.
    """
    if not docs_texts:
        raise ValueError("No documents to process. Please upload valid text files.")

    merged_text = "\n\n".join(docs_texts)
    splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
    chunks = splitter.split_text(merged_text)

    if not chunks:
        raise ValueError("Text splitter did not generate any text chunks. Check your input.")

    # Vector store over the chunks; the retriever feeds the prompt context.
    retriever = Chroma.from_texts(texts=chunks, embedding=_embd).as_retriever()

    template = """
    Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên.
    Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi.
    Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
    Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:
    {context}
    hãy trả lời:
    {question}
    """
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    # LCEL pipeline: retrieve -> fill prompt -> model -> plain string.
    return (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | _model
        | StrOutputParser()
    )
106
+
107
+ # Dialog to setup Gemini
108
  @st.dialog("Setup Gemini")
109
+ def setup_gemini():
110
  st.markdown(
111
  """
112
  Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
 
115
  key = st.text_input("Key:", "")
116
  if st.button("Save") and key != "":
117
  st.session_state.gemini_api = key
118
+ st.rerun()
119
 
120
# Prompt for the API key until one is stored.
if st.session_state.gemini_api is None:
    setup_gemini()

# Lazily construct the chat model once a key is available.
if st.session_state.gemini_api and st.session_state.model is None:
    st.session_state.model = get_chat_google_model(st.session_state.gemini_api)

# Embedding model is key-independent; build it once per session.
if st.session_state.embd is None:
    st.session_state.embd = get_embedding_model()

# Ensure the on-disk document directory exists and remember it.
if st.session_state.save_dir is None:
    save_dir = "./Documents"
    os.makedirs(save_dir, exist_ok=True)
    st.session_state.save_dir = save_dir
 
 
 
 
 
134
 
135
# Sidebar: upload .txt files and (re)build the RAG chain over them.
with st.sidebar:
    uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
    if uploaded_files:
        documents = []
        uploaded_file_names = set()
        for uploaded_file in uploaded_files:
            uploaded_file_names.add(uploaded_file.name)
            # BUGFIX: always resolve the path for the CURRENT file. Previously
            # file_path was only assigned inside the "new file" branch, so an
            # already-uploaded file reused a stale path from a prior iteration
            # (or raised NameError when it was the first file).
            file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
            if uploaded_file.name not in st.session_state.uploaded_files:
                # Persist only files we have not saved before.
                with open(file_path, mode='wb') as w:
                    w.write(uploaded_file.getvalue())
            doc = load_txt(file_path)
            documents.extend([*doc])

        if documents:
            docs_texts = [d.page_content for d in documents]
            st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
            st.session_state.uploaded_files = uploaded_file_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
# Chat interface: create the history list on the first run of the session.
st.session_state.setdefault("chat_history", [])
158
 
 
161
  st.write(message["content"])
162
 
163
prompt = st.chat_input("Bạn muốn hỏi gì?")
# Handle a new user turn only when a model is ready.
if prompt and st.session_state.model:
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    # Prefer the RAG chain when documents have been indexed; otherwise fall
    # back to the plain chat model.
    with st.chat_message("assistant"):
        if st.session_state.rag:
            response = st.session_state.rag.invoke(prompt)
        else:
            response = st.session_state.model.invoke(prompt).content
        st.write(response)

    st.session_state.chat_history.append({"role": "assistant", "content": response})