Spaces:
Sleeping
Sleeping
File size: 7,030 Bytes
080146c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
# llm_handling.py
import logging
import os
from langchain_community.vectorstores import FAISS
import requests
from tenacity import retry, stop_after_attempt, wait_exponential
import json
from app.config import *
from app.configs.prompts import SYSTEM_PROMPTS
from app.utils.embedding_utils import get_embeddings
from app.utils.voice_utils import generate_speech
logging.basicConfig(level=logging.INFO)
# =====================================
# Funzioni relative al LLM
# =====================================
def get_llm_client(llm_type: LLMType):
"""Ottiene il client appropriato per il modello selezionato"""
config = LLM_CONFIGS.get(llm_type)
if not config:
raise ValueError(f"Modello {llm_type} non supportato")
return config["client"](), config["model"]
def get_system_prompt(prompt_type="tutor"):
"""Seleziona il prompt di sistema appropriato"""
return SYSTEM_PROMPTS.get(prompt_type, SYSTEM_PROMPTS["tutor"])
def test_local_connection():
"""Verifica la connessione al server LLM locale"""
try:
response = requests.get(f"http://192.168.82.5:1234/v1/health", timeout=5)
return response.status_code == 200
except:
return False
def read_metadata(db_path):
metadata_file = os.path.join(db_path, "metadata.json")
if os.path.exists(metadata_file):
with open(metadata_file, 'r') as f:
return json.load(f)
return []
def get_relevant_documents(vectorstore, question, min_similarity=0.7):
"""Recupera i documenti rilevanti dal vectorstore"""
try:
# Migliora la query prima della ricerca
enhanced_query = enhance_query(question)
# Ottieni documenti con punteggi di similarità
docs_and_scores = vectorstore.similarity_search_with_score(
enhanced_query,
k=8 # Aumenta il numero di documenti recuperati
)
# Filtra i documenti per similarità
filtered_docs = [
doc for doc, score in docs_and_scores
if score >= min_similarity
]
# Log dei risultati per debug
logging.info(f"Query: {question}")
logging.info(f"Documenti trovati: {len(filtered_docs)}")
# Restituisci almeno un documento o una lista vuota
return filtered_docs[:5] if filtered_docs else []
except Exception as e:
logging.error(f"Errore nel recupero dei documenti: {e}")
return [] # Restituisce lista vuota invece di None
def enhance_query(question):
# Rimuovi parole non significative
stop_words = set(['il', 'lo', 'la', 'i', 'gli', 'le', 'un', 'uno', 'una'])
words = [w for w in question.lower().split() if w not in stop_words]
# Estrai keywords chiave
enhanced_query = " ".join(words)
return enhanced_query
def log_search_results(question, docs_and_scores):
logging.info(f"Query: {question}")
for idx, (doc, score) in enumerate(docs_and_scores, 1):
logging.info(f"Doc {idx} - Score: {score:.4f}")
logging.info(f"Content: {doc.page_content[:100]}...")
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def answer_question(question, db_name, prompt_type="tutor", chat_history=None, llm_type=LLMType.OPENAI_GPT_4O_MINI):
if chat_history is None:
chat_history = []
try:
embeddings = get_embeddings() # Usa la funzione comune
db_path = os.path.join(BASE_DB_PATH, f"faiss_index_{db_name}")
# Leggi i metadati
metadata_list = read_metadata(db_path)
metadata_dict = {m["filename"]: m for m in metadata_list}
# Recupera i documenti rilevanti
vectorstore = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)
relevant_docs = get_relevant_documents(vectorstore, question)
if not relevant_docs:
return [
{"role": "user", "content": question},
{"role": "assistant", "content": "Mi dispiace, non ho trovato informazioni rilevanti per rispondere alla tua domanda. Prova a riformularla o a fare una domanda diversa."}
]
# Prepara le citazioni delle fonti con numerazione dei chunk
sources = []
for idx, doc in enumerate(relevant_docs, 1):
source_file = doc.metadata.get("source", "Unknown")
if source_file in metadata_dict:
meta = metadata_dict[source_file]
sources.append(f"📚 {meta['title']} (Autore: {meta['author']}) - Parte {idx} di {len(relevant_docs)}")
# Prepara il contesto con le fonti
context = "\n".join([
f"[Parte {idx+1} di {len(relevant_docs)}]\n{doc.page_content}"
for idx, doc in enumerate(relevant_docs)
])
sources_text = "\n\nFonti consultate:\n" + "\n".join(set(sources))
# Aggiorna il prompt per includere la richiesta di citare le fonti
prompt = SYSTEM_PROMPTS[prompt_type].format(context=context)
prompt += "\nCita sempre le fonti utilizzate per la tua risposta includendo il titolo del documento e l'autore."
# Costruisci il messaggio completo
messages = [
{"role": "system", "content": prompt},
*[{"role": m["role"], "content": m["content"]} for m in chat_history],
{"role": "user", "content": question}
]
# Ottieni la risposta dall'LLM
client, model = get_llm_client(llm_type)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=0.7,
max_tokens=2048
)
answer = response.choices[0].message.content + sources_text
# return [
# {"role": "user", "content": question, "audio": user_audio},
# {"role": "assistant", "content": answer, "audio": assistant_audio}
# ]
except Exception as e:
logging.error(f"Errore durante la generazione della risposta: {e}")
error_msg = "LLM locale non disponibile. Riprova più tardi o usa OpenAI." if "local" in str(llm_type) else str(e)
return [
{"role": "user", "content": question},
{"role": "assistant", "content": f"⚠️ {error_msg}"}
]
class DocumentRetriever:
def __init__(self, db_path):
self.embeddings = get_embeddings()
self.vectorstore = FAISS.load_local(
db_path,
self.embeddings,
allow_dangerous_deserialization=True
)
def get_relevant_chunks(self, question):
enhanced_query = enhance_query(question)
docs_and_scores = self.vectorstore.similarity_search_with_score(
enhanced_query,
k=8
)
log_search_results(question, docs_and_scores)
return self._filter_relevant_docs(docs_and_scores)
if __name__ == "__main__":
pass
|