File size: 6,902 Bytes
caa9f7a 26305ff caa9f7a 8df61df caa9f7a 8df61df caa9f7a af5ec80 caa9f7a af5ec80 8df61df caa9f7a 26305ff caa9f7a 8df61df caa9f7a af5ec80 caa9f7a 8df61df af5ec80 8df61df af5ec80 dca18ab af5ec80 dca18ab 8df61df af5ec80 f321a77 26305ff f321a77 26305ff af5ec80 caa9f7a f321a77 af5ec80 caa9f7a af5ec80 8df61df caa9f7a af5ec80 caa9f7a af5ec80 caa9f7a af5ec80 caa9f7a af5ec80 caa9f7a f321a77 caa9f7a af5ec80 dca18ab af5ec80 dca18ab af5ec80 8df61df 38a5afa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import streamlit as st
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.document_loaders import TextLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.vectorstores import Chroma
from chromadb.config import Settings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Tiêu đề ứng dụng
page = st.title("Chat with AskUSTH")
# Khởi tạo trạng thái phiên
if "gemini_api" not in st.session_state:
st.session_state.gemini_api = None
if "rag" not in st.session_state:
st.session_state.rag = None
if "llm" not in st.session_state:
st.session_state.llm = None
if "embd" not in st.session_state:
st.session_state.embd = None
if "model" not in st.session_state:
st.session_state.model = None
if "save_dir" not in st.session_state:
st.session_state.save_dir = None
if "uploaded_files" not in st.session_state:
st.session_state.uploaded_files = set()
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
# Hàm tải và xử lý file văn bản
def load_txt(file_path):
loader = TextLoader(file_path=file_path, encoding="utf-8")
doc = loader.load()
return doc
# Hàm định dạng văn bản
def format_docs(docs):
"""Định dạng các tài liệu thành chuỗi văn bản."""
return "\n\n".join(doc.page_content for doc in docs)
# Hàm thiết lập mô hình Google Gemini
@st.cache_resource
def get_chat_google_model(api_key):
os.environ["GOOGLE_API_KEY"] = api_key
return ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
)
# Hàm thiết lập mô hình embedding
@st.cache_resource
def get_embedding_model():
model_name = "bkai-foundation-models/vietnamese-bi-encoder"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
model = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
return model
# Hàm tạo RAG Chain
@st.cache_resource
def compute_rag_chain(_model, _embd, docs_texts):
if not docs_texts:
raise ValueError("Không có tài liệu nào để xử lý. Vui lòng tải lên các tệp hợp lệ.")
combined_text = "\n\n".join(docs_texts)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
texts = text_splitter.split_text(combined_text)
if len(texts) > 5000:
raise ValueError("Tài liệu tạo ra quá nhiều đoạn. Vui lòng sử dụng tài liệu nhỏ hơn.")
# Tạo thư mục lưu trữ
persist_dir = "./chromadb_store"
if not os.path.exists(persist_dir):
os.makedirs(persist_dir)
# Khởi tạo Chroma với cấu hình lưu trữ
settings = Settings(persist_directory=persist_dir)
# Khởi tạo Chroma và lưu dữ liệu
vectorstore = Chroma.from_texts(texts=texts, embedding=_embd, client_settings=settings)
retriever = vectorstore.as_retriever()
# Template cho prompt
template = """
Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên.
Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi.
Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:
{context}
hãy trả lời:
{question}
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| _model
| StrOutputParser()
)
return rag_chain
# Dialog cài đặt Google Gemini
@st.dialog("Setup Gemini")
def setup_gemini():
st.markdown(
"""
Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
"""
)
key = st.text_input("Key:", "")
if st.button("Save") and key != "":
st.session_state.gemini_api = key
st.rerun()
if st.session_state.gemini_api is None:
setup_gemini()
if st.session_state.gemini_api and st.session_state.model is None:
st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
if st.session_state.embd is None:
st.session_state.embd = get_embedding_model()
if st.session_state.save_dir is None:
save_dir = "./Documents"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
st.session_state.save_dir = save_dir
# Cập nhật xử lý Sidebar
with st.sidebar:
uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
max_file_size_mb = 5
if uploaded_files:
documents = []
for uploaded_file in uploaded_files:
if uploaded_file.size > max_file_size_mb * 1024 * 1024:
st.warning(f"Tệp {uploaded_file.name} vượt quá giới hạn {max_file_size_mb}MB.")
continue
file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
with open(file_path, mode='wb') as w:
w.write(uploaded_file.getvalue())
doc = load_txt(file_path)
documents.extend([*doc])
if documents:
docs_texts = [d.page_content for d in documents]
st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
# Giao diện chat
for message in st.session_state.chat_history:
with st.chat_message(message["role"]):
st.write(message["content"])
prompt = st.chat_input("Bạn muốn hỏi gì?")
if st.session_state.model is not None:
if prompt:
st.session_state.chat_history.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.write(prompt)
with st.chat_message("assistant"):
if st.session_state.rag is not None:
response = st.session_state.rag.invoke(prompt)
st.write(response)
else:
ans = st.session_state.llm.invoke(prompt)
response = ans.content
st.write(response)
st.session_state.chat_history.append({"role": "assistant", "content": response}) |