|
import streamlit as st |
|
import os |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_community.document_loaders import TextLoader |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain.prompts import PromptTemplate |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain.vectorstores import Chroma |
|
from chromadb.config import Settings |
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
|
|
|
page = st.title("Chat with AskUSTH") |
|
|
|
|
|
if "gemini_api" not in st.session_state: |
|
st.session_state.gemini_api = None |
|
|
|
if "rag" not in st.session_state: |
|
st.session_state.rag = None |
|
|
|
if "llm" not in st.session_state: |
|
st.session_state.llm = None |
|
|
|
if "embd" not in st.session_state: |
|
st.session_state.embd = None |
|
|
|
if "model" not in st.session_state: |
|
st.session_state.model = None |
|
|
|
if "save_dir" not in st.session_state: |
|
st.session_state.save_dir = None |
|
|
|
if "uploaded_files" not in st.session_state: |
|
st.session_state.uploaded_files = set() |
|
|
|
if "chat_history" not in st.session_state: |
|
st.session_state.chat_history = [] |
|
|
|
|
|
def load_txt(file_path): |
|
loader = TextLoader(file_path=file_path, encoding="utf-8") |
|
doc = loader.load() |
|
return doc |
|
|
|
|
|
def format_docs(docs): |
|
"""Định dạng các tài liệu thành chuỗi văn bản.""" |
|
return "\n\n".join(doc.page_content for doc in docs) |
|
|
|
|
|
@st.cache_resource |
|
def get_chat_google_model(api_key): |
|
os.environ["GOOGLE_API_KEY"] = api_key |
|
return ChatGoogleGenerativeAI( |
|
model="gemini-1.5-pro", |
|
temperature=0, |
|
max_tokens=None, |
|
timeout=None, |
|
max_retries=2, |
|
) |
|
|
|
|
|
@st.cache_resource |
|
def get_embedding_model(): |
|
model_name = "bkai-foundation-models/vietnamese-bi-encoder" |
|
model_kwargs = {'device': 'cpu'} |
|
encode_kwargs = {'normalize_embeddings': False} |
|
|
|
model = HuggingFaceEmbeddings( |
|
model_name=model_name, |
|
model_kwargs=model_kwargs, |
|
encode_kwargs=encode_kwargs |
|
) |
|
return model |
|
|
|
|
|
@st.cache_resource |
|
def compute_rag_chain(_model, _embd, docs_texts): |
|
if not docs_texts: |
|
raise ValueError("Không có tài liệu nào để xử lý. Vui lòng tải lên các tệp hợp lệ.") |
|
|
|
combined_text = "\n\n".join(docs_texts) |
|
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) |
|
texts = text_splitter.split_text(combined_text) |
|
|
|
if len(texts) > 5000: |
|
raise ValueError("Tài liệu tạo ra quá nhiều đoạn. Vui lòng sử dụng tài liệu nhỏ hơn.") |
|
|
|
|
|
persist_dir = "./chromadb_store" |
|
if not os.path.exists(persist_dir): |
|
os.makedirs(persist_dir) |
|
|
|
|
|
settings = Settings(persist_directory=persist_dir) |
|
|
|
|
|
vectorstore = Chroma.from_texts(texts=texts, embedding=_embd, client_settings=settings) |
|
retriever = vectorstore.as_retriever() |
|
|
|
|
|
template = """ |
|
Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. |
|
Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. |
|
Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời. |
|
Dưới đây là thông tin liên quan mà bạn cần sử dụng tới: |
|
{context} |
|
hãy trả lời: |
|
{question} |
|
""" |
|
prompt = PromptTemplate(template=template, input_variables=["context", "question"]) |
|
|
|
rag_chain = ( |
|
{"context": retriever | format_docs, "question": RunnablePassthrough()} |
|
| prompt |
|
| _model |
|
| StrOutputParser() |
|
) |
|
return rag_chain |
|
|
|
|
|
@st.dialog("Setup Gemini") |
|
def setup_gemini(): |
|
st.markdown( |
|
""" |
|
Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới. |
|
""" |
|
) |
|
key = st.text_input("Key:", "") |
|
if st.button("Save") and key != "": |
|
st.session_state.gemini_api = key |
|
st.rerun() |
|
|
|
if st.session_state.gemini_api is None: |
|
setup_gemini() |
|
|
|
if st.session_state.gemini_api and st.session_state.model is None: |
|
st.session_state.model = get_chat_google_model(st.session_state.gemini_api) |
|
|
|
if st.session_state.embd is None: |
|
st.session_state.embd = get_embedding_model() |
|
|
|
if st.session_state.save_dir is None: |
|
save_dir = "./Documents" |
|
if not os.path.exists(save_dir): |
|
os.makedirs(save_dir) |
|
st.session_state.save_dir = save_dir |
|
|
|
|
|
with st.sidebar: |
|
uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"]) |
|
max_file_size_mb = 5 |
|
if uploaded_files: |
|
documents = [] |
|
for uploaded_file in uploaded_files: |
|
if uploaded_file.size > max_file_size_mb * 1024 * 1024: |
|
st.warning(f"Tệp {uploaded_file.name} vượt quá giới hạn {max_file_size_mb}MB.") |
|
continue |
|
|
|
file_path = os.path.join(st.session_state.save_dir, uploaded_file.name) |
|
with open(file_path, mode='wb') as w: |
|
w.write(uploaded_file.getvalue()) |
|
|
|
doc = load_txt(file_path) |
|
documents.extend([*doc]) |
|
|
|
if documents: |
|
docs_texts = [d.page_content for d in documents] |
|
st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts) |
|
|
|
|
|
for message in st.session_state.chat_history: |
|
with st.chat_message(message["role"]): |
|
st.write(message["content"]) |
|
|
|
prompt = st.chat_input("Bạn muốn hỏi gì?") |
|
if st.session_state.model is not None: |
|
if prompt: |
|
st.session_state.chat_history.append({"role": "user", "content": prompt}) |
|
with st.chat_message("user"): |
|
st.write(prompt) |
|
|
|
with st.chat_message("assistant"): |
|
if st.session_state.rag is not None: |
|
response = st.session_state.rag.invoke(prompt) |
|
st.write(response) |
|
else: |
|
ans = st.session_state.llm.invoke(prompt) |
|
response = ans.content |
|
st.write(response) |
|
|
|
st.session_state.chat_history.append({"role": "assistant", "content": response}) |