Update app.py
Browse files
app.py
CHANGED
@@ -10,10 +10,10 @@ from langchain_core.runnables import RunnablePassthrough
|
|
10 |
from langchain_chroma import Chroma
|
11 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
12 |
|
13 |
-
#
|
14 |
page = st.title("Chat with AskUSTH")
|
15 |
|
16 |
-
#
|
17 |
if "gemini_api" not in st.session_state:
|
18 |
st.session_state.gemini_api = None
|
19 |
|
@@ -35,6 +35,21 @@ if "save_dir" not in st.session_state:
|
|
35 |
if "uploaded_files" not in st.session_state:
|
36 |
st.session_state.uploaded_files = set()
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
@st.cache_resource
|
39 |
def get_chat_google_model(api_key):
|
40 |
os.environ["GOOGLE_API_KEY"] = api_key
|
@@ -46,6 +61,7 @@ def get_chat_google_model(api_key):
|
|
46 |
max_retries=2,
|
47 |
)
|
48 |
|
|
|
49 |
@st.cache_resource
|
50 |
def get_embedding_model():
|
51 |
model_name = "bkai-foundation-models/vietnamese-bi-encoder"
|
@@ -59,26 +75,18 @@ def get_embedding_model():
|
|
59 |
)
|
60 |
return model
|
61 |
|
62 |
-
|
63 |
-
loader = TextLoader(file_path=file_path, encoding="utf-8")
|
64 |
-
doc = loader.load()
|
65 |
-
return doc
|
66 |
-
|
67 |
-
def format_docs(docs):
|
68 |
-
"""Format documents into a single string for prompt input."""
|
69 |
-
return "\n\n".join(doc.page_content for doc in docs)
|
70 |
-
|
71 |
@st.cache_resource
|
72 |
def compute_rag_chain(_model, _embd, docs_texts):
|
73 |
if not docs_texts:
|
74 |
-
raise ValueError("
|
75 |
|
76 |
combined_text = "\n\n".join(docs_texts)
|
77 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
|
78 |
texts = text_splitter.split_text(combined_text)
|
79 |
|
80 |
if len(texts) > 5000:
|
81 |
-
raise ValueError("
|
82 |
|
83 |
vectorstore = Chroma.from_texts(texts=texts, embedding=_embd)
|
84 |
retriever = vectorstore.as_retriever()
|
@@ -102,7 +110,7 @@ def compute_rag_chain(_model, _embd, docs_texts):
|
|
102 |
)
|
103 |
return rag_chain
|
104 |
|
105 |
-
# Dialog
|
106 |
@st.dialog("Setup Gemini")
|
107 |
def setup_gemini():
|
108 |
st.markdown(
|
@@ -130,6 +138,7 @@ if st.session_state.save_dir is None:
|
|
130 |
os.makedirs(save_dir)
|
131 |
st.session_state.save_dir = save_dir
|
132 |
|
|
|
133 |
with st.sidebar:
|
134 |
uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
|
135 |
max_file_size_mb = 5
|
@@ -150,3 +159,26 @@ with st.sidebar:
|
|
150 |
if documents:
|
151 |
docs_texts = [d.page_content for d in documents]
|
152 |
st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from langchain_chroma import Chroma
|
11 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
12 |
|
13 |
+
# Tiêu đề ứng dụng
|
14 |
page = st.title("Chat with AskUSTH")
|
15 |
|
16 |
+
# Khởi tạo trạng thái phiên
|
17 |
if "gemini_api" not in st.session_state:
|
18 |
st.session_state.gemini_api = None
|
19 |
|
|
|
35 |
if "uploaded_files" not in st.session_state:
|
36 |
st.session_state.uploaded_files = set()
|
37 |
|
38 |
+
if "chat_history" not in st.session_state:
|
39 |
+
st.session_state.chat_history = []
|
40 |
+
|
41 |
+
# Hàm tải và xử lý file văn bản
|
42 |
+
def load_txt(file_path):
    """Load a text file from disk through LangChain's ``TextLoader``.

    Args:
        file_path: Path of the file to read; it is decoded as UTF-8.

    Returns:
        Whatever ``TextLoader.load()`` returns (presumably a list of
        LangChain ``Document`` objects -- confirm against the installed
        LangChain version).
    """
    return TextLoader(file_path=file_path, encoding="utf-8").load()
|
46 |
+
|
47 |
+
# Hàm định dạng văn bản
|
48 |
+
def format_docs(docs):
    """Join the documents' text into a single string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The ``page_content`` values in input order, separated by one blank
        line (``"\\n\\n"``); an empty string when ``docs`` is empty.
    """
    page_texts = [document.page_content for document in docs]
    return "\n\n".join(page_texts)
|
51 |
+
|
52 |
+
# Hàm thiết lập mô hình Google Gemini
|
53 |
@st.cache_resource
|
54 |
def get_chat_google_model(api_key):
|
55 |
os.environ["GOOGLE_API_KEY"] = api_key
|
|
|
61 |
max_retries=2,
|
62 |
)
|
63 |
|
64 |
+
# Hàm thiết lập mô hình embedding
|
65 |
@st.cache_resource
|
66 |
def get_embedding_model():
|
67 |
model_name = "bkai-foundation-models/vietnamese-bi-encoder"
|
|
|
75 |
)
|
76 |
return model
|
77 |
|
78 |
+
# Hàm tạo RAG Chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
@st.cache_resource
|
80 |
def compute_rag_chain(_model, _embd, docs_texts):
|
81 |
if not docs_texts:
|
82 |
+
raise ValueError("Không có tài liệu nào để xử lý. Vui lòng tải lên các tệp hợp lệ.")
|
83 |
|
84 |
combined_text = "\n\n".join(docs_texts)
|
85 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
|
86 |
texts = text_splitter.split_text(combined_text)
|
87 |
|
88 |
if len(texts) > 5000:
|
89 |
+
raise ValueError("Tài liệu tạo ra quá nhiều đoạn. Vui lòng sử dụng tài liệu nhỏ hơn.")
|
90 |
|
91 |
vectorstore = Chroma.from_texts(texts=texts, embedding=_embd)
|
92 |
retriever = vectorstore.as_retriever()
|
|
|
110 |
)
|
111 |
return rag_chain
|
112 |
|
113 |
+
# Dialog cài đặt Google Gemini
|
114 |
@st.dialog("Setup Gemini")
|
115 |
def setup_gemini():
|
116 |
st.markdown(
|
|
|
138 |
os.makedirs(save_dir)
|
139 |
st.session_state.save_dir = save_dir
|
140 |
|
141 |
+
# Sidebar: Upload file và xử lý
|
142 |
with st.sidebar:
|
143 |
uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
|
144 |
max_file_size_mb = 5
|
|
|
159 |
if documents:
|
160 |
docs_texts = [d.page_content for d in documents]
|
161 |
st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
|
162 |
+
|
163 |
+
# Giao diện chat
|
164 |
+
# Chat interface: replay the stored conversation, then handle a new prompt.
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# The input box is always rendered; a prompt is only answered once a chat
# model has been configured in the session.
prompt = st.chat_input("Bạn muốn hỏi gì?")
if prompt and st.session_state.model is not None:
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        if st.session_state.rag is not None:
            # A RAG chain exists (documents were uploaded): its output is
            # written and stored as-is.
            response = st.session_state.rag.invoke(prompt)
        else:
            # No RAG chain yet: query the chat model directly. The guard
            # above only verifies `model`, so use `llm` when the session
            # defines a non-None one and otherwise fall back to `model`,
            # instead of failing on a missing or None `llm` attribute.
            chat_model = getattr(st.session_state, "llm", None)
            if chat_model is None:
                chat_model = st.session_state.model
            ans = chat_model.invoke(prompt)
            response = ans.content
        st.write(response)

    # Persist the answer so it is replayed on the next script rerun.
    st.session_state.chat_history.append({"role": "assistant", "content": response})
|