File size: 6,902 Bytes
caa9f7a
 
 
 
 
 
 
 
 
26305ff
 
caa9f7a
 
8df61df
caa9f7a
 
8df61df
caa9f7a
 
 
 
 
af5ec80
caa9f7a
 
 
af5ec80
 
 
 
 
 
 
 
 
 
 
 
8df61df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caa9f7a
 
 
 
26305ff
caa9f7a
 
 
 
 
 
8df61df
caa9f7a
 
 
 
 
 
 
 
 
 
af5ec80
caa9f7a
 
8df61df
af5ec80
 
 
8df61df
af5ec80
 
dca18ab
af5ec80
 
dca18ab
8df61df
af5ec80
f321a77
 
 
 
 
26305ff
 
 
f321a77
26305ff
af5ec80
caa9f7a
f321a77
af5ec80
 
 
 
 
 
 
 
 
 
caa9f7a
af5ec80
 
 
 
 
 
 
 
8df61df
caa9f7a
af5ec80
caa9f7a
 
 
 
 
 
 
 
af5ec80
caa9f7a
 
af5ec80
caa9f7a
 
 
 
af5ec80
 
 
caa9f7a
 
 
 
 
 
f321a77
caa9f7a
af5ec80
dca18ab
af5ec80
 
 
dca18ab
 
 
 
 
 
 
 
 
 
af5ec80
 
 
 
8df61df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38a5afa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import streamlit as st
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.document_loaders import TextLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.vectorstores import Chroma
from chromadb.config import Settings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Tiêu đề ứng dụng
page = st.title("Chat with AskUSTH")

# Khởi tạo trạng thái phiên
if "gemini_api" not in st.session_state:
    st.session_state.gemini_api = None

if "rag" not in st.session_state:
    st.session_state.rag = None

if "llm" not in st.session_state:
    st.session_state.llm = None

if "embd" not in st.session_state:
    st.session_state.embd = None

if "model" not in st.session_state:
    st.session_state.model = None

if "save_dir" not in st.session_state:
    st.session_state.save_dir = None 

if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = set()

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Hàm tải và xử lý file văn bản
def load_txt(file_path):
    loader = TextLoader(file_path=file_path, encoding="utf-8")
    doc = loader.load()
    return doc

# Hàm định dạng văn bản
def format_docs(docs):
    """Định dạng các tài liệu thành chuỗi văn bản."""
    return "\n\n".join(doc.page_content for doc in docs)

# Hàm thiết lập mô hình Google Gemini
@st.cache_resource
def get_chat_google_model(api_key):
    os.environ["GOOGLE_API_KEY"] = api_key
    return ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )

# Hàm thiết lập mô hình embedding
@st.cache_resource
def get_embedding_model():
    model_name = "bkai-foundation-models/vietnamese-bi-encoder"
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}
    
    model = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )
    return model

# Hàm tạo RAG Chain
@st.cache_resource
def compute_rag_chain(_model, _embd, docs_texts):
    if not docs_texts:
        raise ValueError("Không có tài liệu nào để xử lý. Vui lòng tải lên các tệp hợp lệ.")
    
    combined_text = "\n\n".join(docs_texts)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_text(combined_text)
    
    if len(texts) > 5000:
        raise ValueError("Tài liệu tạo ra quá nhiều đoạn. Vui lòng sử dụng tài liệu nhỏ hơn.")
    
    # Tạo thư mục lưu trữ
    persist_dir = "./chromadb_store"
    if not os.path.exists(persist_dir):
        os.makedirs(persist_dir)
    
    # Khởi tạo Chroma với cấu hình lưu trữ
    settings = Settings(persist_directory=persist_dir)

    # Khởi tạo Chroma và lưu dữ liệu
    vectorstore = Chroma.from_texts(texts=texts, embedding=_embd, client_settings=settings)
    retriever = vectorstore.as_retriever()

    # Template cho prompt
    template = """
        Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. 
        Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi.
        Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
        Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:
        {context}
        hãy trả lời:
        {question}
    """
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
    
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | _model
        | StrOutputParser()
    )
    return rag_chain

# Dialog cài đặt Google Gemini
@st.dialog("Setup Gemini")
def setup_gemini():
    st.markdown(
        """
        Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
        """
    )
    key = st.text_input("Key:", "")
    if st.button("Save") and key != "":
        st.session_state.gemini_api = key
        st.rerun()

if st.session_state.gemini_api is None:
    setup_gemini()

if st.session_state.gemini_api and st.session_state.model is None:
    st.session_state.model = get_chat_google_model(st.session_state.gemini_api)

if st.session_state.embd is None:
    st.session_state.embd = get_embedding_model()

if st.session_state.save_dir is None:
    save_dir = "./Documents"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    st.session_state.save_dir = save_dir

# Cập nhật xử lý Sidebar
with st.sidebar:
    uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
    max_file_size_mb = 5
    if uploaded_files:
        documents = []
        for uploaded_file in uploaded_files:
            if uploaded_file.size > max_file_size_mb * 1024 * 1024:
                st.warning(f"Tệp {uploaded_file.name} vượt quá giới hạn {max_file_size_mb}MB.")
                continue
            
            file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
            with open(file_path, mode='wb') as w:
                w.write(uploaded_file.getvalue())
            
            doc = load_txt(file_path)
            documents.extend([*doc])
        
        if documents:
            docs_texts = [d.page_content for d in documents]
            st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)

# Giao diện chat
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.write(message["content"])

prompt = st.chat_input("Bạn muốn hỏi gì?")
if st.session_state.model is not None:
    if prompt:
        st.session_state.chat_history.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)
            
        with st.chat_message("assistant"):
            if st.session_state.rag is not None:
                response = st.session_state.rag.invoke(prompt)
                st.write(response)
            else: 
                ans = st.session_state.llm.invoke(prompt)
                response = ans.content
                st.write(response)
                
        st.session_state.chat_history.append({"role": "assistant", "content": response})