AskUSTH / app.py
nkcong206's picture
Update app.py
38a5afa verified
import streamlit as st
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.document_loaders import TextLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.vectorstores import Chroma
from chromadb.config import Settings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Tiêu đề ứng dụng
page = st.title("Chat with AskUSTH")
# Khởi tạo trạng thái phiên
if "gemini_api" not in st.session_state:
st.session_state.gemini_api = None
if "rag" not in st.session_state:
st.session_state.rag = None
if "llm" not in st.session_state:
st.session_state.llm = None
if "embd" not in st.session_state:
st.session_state.embd = None
if "model" not in st.session_state:
st.session_state.model = None
if "save_dir" not in st.session_state:
st.session_state.save_dir = None
if "uploaded_files" not in st.session_state:
st.session_state.uploaded_files = set()
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
# Hàm tải và xử lý file văn bản
def load_txt(file_path):
loader = TextLoader(file_path=file_path, encoding="utf-8")
doc = loader.load()
return doc
# Hàm định dạng văn bản
def format_docs(docs):
"""Định dạng các tài liệu thành chuỗi văn bản."""
return "\n\n".join(doc.page_content for doc in docs)
# Hàm thiết lập mô hình Google Gemini
@st.cache_resource
def get_chat_google_model(api_key):
os.environ["GOOGLE_API_KEY"] = api_key
return ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
)
# Hàm thiết lập mô hình embedding
@st.cache_resource
def get_embedding_model():
model_name = "bkai-foundation-models/vietnamese-bi-encoder"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
model = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
return model
# Hàm tạo RAG Chain
@st.cache_resource
def compute_rag_chain(_model, _embd, docs_texts):
if not docs_texts:
raise ValueError("Không có tài liệu nào để xử lý. Vui lòng tải lên các tệp hợp lệ.")
combined_text = "\n\n".join(docs_texts)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
texts = text_splitter.split_text(combined_text)
if len(texts) > 5000:
raise ValueError("Tài liệu tạo ra quá nhiều đoạn. Vui lòng sử dụng tài liệu nhỏ hơn.")
# Tạo thư mục lưu trữ
persist_dir = "./chromadb_store"
if not os.path.exists(persist_dir):
os.makedirs(persist_dir)
# Khởi tạo Chroma với cấu hình lưu trữ
settings = Settings(persist_directory=persist_dir)
# Khởi tạo Chroma và lưu dữ liệu
vectorstore = Chroma.from_texts(texts=texts, embedding=_embd, client_settings=settings)
retriever = vectorstore.as_retriever()
# Template cho prompt
template = """
Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên.
Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi.
Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:
{context}
hãy trả lời:
{question}
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| _model
| StrOutputParser()
)
return rag_chain
# Dialog cài đặt Google Gemini
@st.dialog("Setup Gemini")
def setup_gemini():
st.markdown(
"""
Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
"""
)
key = st.text_input("Key:", "")
if st.button("Save") and key != "":
st.session_state.gemini_api = key
st.rerun()
if st.session_state.gemini_api is None:
setup_gemini()
if st.session_state.gemini_api and st.session_state.model is None:
st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
if st.session_state.embd is None:
st.session_state.embd = get_embedding_model()
if st.session_state.save_dir is None:
save_dir = "./Documents"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
st.session_state.save_dir = save_dir
# Cập nhật xử lý Sidebar
with st.sidebar:
uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
max_file_size_mb = 5
if uploaded_files:
documents = []
for uploaded_file in uploaded_files:
if uploaded_file.size > max_file_size_mb * 1024 * 1024:
st.warning(f"Tệp {uploaded_file.name} vượt quá giới hạn {max_file_size_mb}MB.")
continue
file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
with open(file_path, mode='wb') as w:
w.write(uploaded_file.getvalue())
doc = load_txt(file_path)
documents.extend([*doc])
if documents:
docs_texts = [d.page_content for d in documents]
st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
# Giao diện chat
for message in st.session_state.chat_history:
with st.chat_message(message["role"]):
st.write(message["content"])
prompt = st.chat_input("Bạn muốn hỏi gì?")
if st.session_state.model is not None:
if prompt:
st.session_state.chat_history.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.write(prompt)
with st.chat_message("assistant"):
if st.session_state.rag is not None:
response = st.session_state.rag.invoke(prompt)
st.write(response)
else:
ans = st.session_state.llm.invoke(prompt)
response = ans.content
st.write(response)
st.session_state.chat_history.append({"role": "assistant", "content": response})