Spaces:
Sleeping
Sleeping
File size: 4,691 Bytes
7f96312 2c20468 0dda7f4 7b3bf1d 812f60c 7b3bf1d 2c20468 0dda7f4 2c20468 0dda7f4 2c20468 812f60c 7b3bf1d 812f60c 7b3bf1d 812f60c 0dda7f4 812f60c 2c20468 812f60c 2c20468 812f60c 2c20468 812f60c 2c20468 812f60c 2c20468 812f60c 2c20468 812f60c 2c20468 7b3bf1d 2c20468 812f60c 2c20468 812f60c 2c20468 812f60c 2c20468 812f60c 2c20468 812f60c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import os

import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS  # direct module import
# Hugging Face Hub repo ids offered in the UI. They are passed to
# HuggingFaceHub in initialize_llmchain, which only serves generative
# (text-generation / text2text-generation) models; encoder-only checkpoints
# such as "distilbert-base-uncased" (fill-mask) are rejected there.
list_llm = ["google/flan-t5-small", "google/flan-t5-base"]
# Short display names for the model radio buttons, e.g. "flan-t5-small".
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
def load_doc(list_file_path, chunk_size=512, chunk_overlap=32):
    """Load PDF files and split their content into overlapping text chunks.

    Args:
        list_file_path: Iterable of filesystem paths to PDF files (None is
            treated as empty).
        chunk_size: Maximum chunk length passed to the splitter (default 512).
        chunk_overlap: Overlap between consecutive chunks (default 32).

    Returns:
        A list of LangChain documents, one per chunk; an empty list when no
        paths are given.
    """
    list_file_path = list(list_file_path or [])
    if not list_file_path:
        # Nothing to load: skip building loaders and the splitter entirely.
        return []

    pages = []
    for path in list_file_path:
        pages.extend(PyPDFLoader(path).load())

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(pages)
def create_db(splits):
    """Embed document chunks and index them in a FAISS vector store.

    Args:
        splits: Non-empty list of LangChain documents (e.g. from load_doc).

    Returns:
        A FAISS vector store built with HuggingFaceEmbeddings() (its default
        embedding model).

    Raises:
        ValueError: If ``splits`` is empty; building a FAISS index from zero
            documents otherwise fails with an opaque low-level error.
    """
    if not splits:
        # Fail fast, before the (slow) embedding model is loaded.
        raise ValueError("create_db() needs at least one document chunk to index.")
    embeddings = HuggingFaceEmbeddings()
    return FAISS.from_documents(splits, embeddings)
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Create a conversational retrieval QA chain over ``vector_db``.

    Args:
        llm_model: Hugging Face Hub repo id, e.g. "google/flan-t5-small".
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Value sent as the "max_length" generation parameter.
        top_k: Top-k sampling cutoff forwarded to the model.
        vector_db: Vector store whose retriever supplies context documents.

    Returns:
        A ConversationalRetrievalChain ("stuff" chain type) with a
        ConversationBufferMemory, returning both the answer and the source
        documents.
    """
    llm = HuggingFaceHub(
        repo_id=llm_model,
        model_kwargs={
            "temperature": float(temperature),
            # Gradio sliders may deliver floats (e.g. 3.0); these parameters
            # are integer-valued, so coerce them explicitly.
            "max_length": int(max_tokens),
            "top_k": int(top_k),
        },
    )
    # The chain returns two outputs ("answer" and "source_documents").
    # Without output_key the memory cannot tell which one to store and raises
    # "One output key expected" on the first question.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
def initialize_database(list_file_obj):
    """Build the vector database from the files uploaded in step 1.

    Args:
        list_file_obj: Value of the gr.Files component. Each entry is either a
            path string or an object exposing the path as ``.name`` (older
            Gradio versions). None or empty when nothing was uploaded.

    Returns:
        Tuple ``(vector_db, status_message)``. ``vector_db`` is None when
        there was nothing to index.
    """
    # getattr(..., "name", x) accepts both path strings and file wrappers.
    list_file_path = [
        getattr(x, "name", x) for x in (list_file_obj or []) if x is not None
    ]
    if not list_file_path:
        return None, "Bitte zuerst mindestens ein PDF-Dokument hochladen."

    doc_splits = load_doc(list_file_path)
    if not doc_splits:
        # e.g. scanned PDFs without an extractable text layer.
        return None, "Kein Text in den PDF-Dokumenten gefunden."

    vector_db = create_db(doc_splits)
    return vector_db, "Datenbank erfolgreich erstellt!"
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Create the QA chain for the model selected in the UI.

    Args:
        llm_option: Index into ``list_llm`` (the model Radio uses type="index").
        llm_temperature: Temperature slider value.
        max_tokens: Max-tokens slider value.
        top_k: Top-k slider value.
        vector_db: Vector store from initialize_database, or None if step 1
            has not completed.

    Returns:
        Tuple ``(qa_chain, status_message)``. ``qa_chain`` is None when no
        vector database exists yet.
    """
    if vector_db is None:
        # Without this guard the chain build fails with an AttributeError.
        return None, "Bitte zuerst die Vektordatenbank erstellen (Schritt 1)."
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "LLM erfolgreich initialisiert! Chatbot ist bereit."
def format_chat_history(message, chat_history):
    """Render chat history as a list of "User: ..." / "Assistant: ..." lines.

    Accepts both history shapes a Gradio Chatbot can hold:
    message dicts ``{"role": ..., "content": ...}`` (``type="messages"``)
    and legacy ``(user_message, bot_message)`` pairs.

    Args:
        message: The current user message. Unused; kept for interface
            compatibility.
        chat_history: Prior turns in either shape, or None.

    Returns:
        A list of strings, one per message, in input order.
    """
    role_labels = {"user": "User", "assistant": "Assistant"}
    formatted_chat_history = []
    for entry in chat_history or []:
        if isinstance(entry, dict):
            role = entry.get("role")
            # Unknown roles (e.g. "system") are shown capitalized.
            label = role_labels.get(role, str(role).capitalize())
            formatted_chat_history.append(f"{label}: {entry.get('content')}")
        else:
            user_message, bot_message = entry
            formatted_chat_history.append(f"User: {user_message}")
            formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
def conversation(qa_chain, message, history):
    """Answer one user message with the QA chain and append the turn to the chat.

    Args:
        qa_chain: Chain from initialize_LLM, or None if step 2 has not run.
        message: The user's question from the textbox.
        history: Chatbot value; the Chatbot in demo() uses type="messages",
            so this is a list of ``{"role", "content"}`` dicts (or None).

    Returns:
        ``(qa_chain, textbox_update, new_history)``, matching the event's
        outputs ``[qa_chain, msg, chatbot]``.
    """
    history = list(history or [])

    if not message or not message.strip():
        # Ignore empty submissions instead of querying the model.
        return qa_chain, gr.update(value=""), history

    if qa_chain is None:
        # Keep the question in the textbox so it can be resent after step 2.
        warning = "Bitte zuerst den QA-Chatbot initialisieren (Schritt 2)."
        new_history = history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": warning},
        ]
        return qa_chain, gr.update(value=message), new_history

    # NOTE(review): the chain's ConversationBufferMemory also supplies
    # "chat_history"; presumably it takes precedence over this value — verify.
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]

    # Append in the messages format expected by gr.Chatbot(type="messages").
    new_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_answer},
    ]
    return qa_chain, gr.update(value=""), new_history
def demo():
    """Assemble the RAG PDF chatbot UI and launch the Gradio server.

    Left column: step 1 (upload PDFs, build the vector database) and step 2
    (choose a model and generation settings, initialize the QA chain).
    Right column: step 3 (chat). The vector database and the QA chain are
    carried between events in gr.State holders.
    """
    with gr.Blocks() as app:
        db_state = gr.State()
        chain_state = gr.State()
        gr.HTML("<center><h1>RAG PDF Chatbot</h1></center>")

        with gr.Row():
            # Left column: steps 1 and 2.
            with gr.Column():
                gr.Markdown("### Schritt 1: Lade PDF-Dokument hoch")
                pdf_files = gr.Files(
                    height=300,
                    file_count="multiple",
                    file_types=[".pdf"],
                    interactive=True,
                )
                build_db_button = gr.Button("Erstelle Vektordatenbank")
                db_status = gr.Textbox(value="Nicht initialisiert", show_label=False)

                gr.Markdown("### Schritt 2: Wähle LLM und Einstellungen")
                # type="index" hands the callback a position into list_llm.
                model_choice = gr.Radio(
                    list_llm_simple,
                    label="Verfügbare Modelle",
                    value=list_llm_simple[0],
                    type="index",
                )
                temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Temperature")
                max_tokens = gr.Slider(64, 512, value=256, step=64, label="Max Tokens")
                top_k = gr.Slider(1, 10, value=3, step=1, label="Top-k")
                build_chain_button = gr.Button("Initialisiere QA-Chatbot")
                chain_status = gr.Textbox(value="Nicht initialisiert", show_label=False)

            # Right column: step 3.
            with gr.Column():
                gr.Markdown("### Schritt 3: Stelle Fragen an dein Dokument")
                chat_window = gr.Chatbot(height=400, type="messages")
                question_box = gr.Textbox(placeholder="Frage stellen...")
                send_button = gr.Button("Absenden")

        build_db_button.click(initialize_database, [pdf_files], [db_state, db_status])
        build_chain_button.click(
            initialize_LLM,
            [model_choice, temperature, max_tokens, top_k, db_state],
            [chain_state, chain_status],
        )
        # Pressing Enter in the textbox and clicking "Absenden" do the same thing.
        for trigger in (question_box.submit, send_button.click):
            trigger(
                conversation,
                [chain_state, question_box, chat_window],
                [chain_state, question_box, chat_window],
            )

    app.launch(debug=True)
if __name__ == "__main__":
    # Build and serve the UI only when executed as a script, not on import.
    demo()
|