Update app.py
app.py
CHANGED
@@ -11,6 +11,8 @@ from transformers import pipeline
 EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 LLM_MODEL_NAME = "google/flan-t5-small"
 
+MAX_INPUT_LENGTH = 512  # Maximum input length for the model
+
 # **Load and split documents**
 def load_and_split_docs(list_file_path):
     if not list_file_path:
@@ -63,7 +65,7 @@ def initialize_llm_chain(temperature, max_tokens, vector_db):
     )
     print(f"Modell {LLM_MODEL_NAME} erfolgreich geladen.")
     llm = HuggingFacePipeline(pipeline=local_pipeline)
-    memory = ConversationBufferMemory(memory_key="chat_history")
+    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer")  # Store only the answer
     retriever = vector_db.as_retriever()
     return ConversationalRetrievalChain.from_llm(
         llm,
@@ -73,6 +75,18 @@ def conversation(qa_chain, message, history):
     )
 
 # **Run the conversation with the QA chain**
+def truncate_history(history, max_length=MAX_INPUT_LENGTH):
+    total_length = 0
+    truncated_history = []
+
+    for message in reversed(history):
+        total_length += len(message[0]) + len(message[1])
+        if total_length > max_length:
+            break
+        truncated_history.insert(0, message)
+
+    return truncated_history
+
 def conversation(qa_chain, message, history):
     if qa_chain is None:
         return None, [{"role": "system", "content": "Der QA-Chain wurde nicht initialisiert!"}], history
@@ -80,7 +94,7 @@ def conversation(qa_chain, message, history):
         return qa_chain, [{"role": "system", "content": "Bitte eine Frage eingeben!"}], history
     try:
         print(f"Frage: {message}")
-        history = history
+        history = truncate_history(history)  # Limit the history to under 512 tokens
         response = qa_chain.invoke({"question": message, "chat_history": history})
         response_text = response["answer"]
         sources = [doc.metadata["source"] for doc in response["source_documents"]]
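
For reference, a minimal, self-contained sketch of the truncate_history helper added in this commit is shown below. It assumes each history entry is a (question, answer) tuple, as the message[0]/message[1] access in the diff implies, and note that the budget is counted in characters rather than actual model tokens, so MAX_INPUT_LENGTH = 512 only approximates the 512-token limit of google/flan-t5-small. Because the loop walks the history from newest to oldest, recent turns are kept and older turns are dropped first.

# Minimal sketch of the helper from this commit. Assumptions: history entries
# are (question, answer) tuples, and lengths are counted in characters, not tokens.
MAX_INPUT_LENGTH = 512

def truncate_history(history, max_length=MAX_INPUT_LENGTH):
    total_length = 0
    truncated_history = []
    # Walk the history from newest to oldest and keep turns until the budget is spent.
    for message in reversed(history):
        total_length += len(message[0]) + len(message[1])
        if total_length > max_length:
            break
        truncated_history.insert(0, message)
    return truncated_history

if __name__ == "__main__":
    # Hypothetical history: one long old turn and one short recent turn.
    history = [
        ("Was ist RAG?", "Retrieval-Augmented Generation " * 20),
        ("Welches Modell wird genutzt?", "google/flan-t5-small"),
    ]
    trimmed = truncate_history(history)
    print(len(trimmed), trimmed[0][0])  # -> 1 Welches Modell wird genutzt?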
|