# llm_handling.py
import logging
import os
from langchain_community.vectorstores import FAISS
import requests
from tenacity import retry, stop_after_attempt, wait_exponential
import json
from collections import defaultdict

from app.config import (
    BASE_DB_PATH,
    LLM_CONFIGS, 
    LLMType,
    EMBEDDING_CONFIG,
    LLM_CONFIGS_EXTENDED
)
from app.configs.prompts import SYSTEM_PROMPTS
from app.utils.embedding_utils import get_embeddings


logging.basicConfig(level=logging.INFO)

# =====================================
# Functions related to LLM
# =====================================

def get_llm_client(llm_type: LLMType):
    """Obtains the appropriate client for the selected model"""
    config = LLM_CONFIGS.get(llm_type)
    if not config:
        raise ValueError(f"Model {llm_type} not supported")
    client_class = config["client"]
    model = config["model"]
    client = client_class()  # Ensure no arguments are needed
    return client, model

def get_system_prompt(prompt_type="tutor"):
    """Selects the appropriate system prompt"""
    return SYSTEM_PROMPTS.get(prompt_type, SYSTEM_PROMPTS["tutor"])

def test_local_connection():
    """Checks connection to the local LLM server"""
    try:
        response = requests.get(f"http://192.168.43.199:1234/v1/health", timeout=5)
        return response.status_code == 200
    except:
        return False

def read_metadata(db_path):
    metadata_file = os.path.join(db_path, "metadata.json")
    if os.path.exists(metadata_file):
        with open(metadata_file, 'r') as f:
            return json.load(f)
    return []

def get_relevant_documents(vectorstore, question):
    """Retrieves relevant documents from the vectorstore based on similarity threshold"""
    try:
        enhanced_query = enhance_query(question)
        # Get all documents with their similarity scores
        docs_and_scores = vectorstore.similarity_search_with_score(enhanced_query)
        # Filter documents based on similarity threshold
        filtered_docs = [
            doc for doc, score in docs_and_scores 
            if score >= EMBEDDING_CONFIG['min_similarity']
        ]
        logging.info(f"Query: {question}")
        logging.info(f"Documents found: {len(filtered_docs)}")
        return filtered_docs if filtered_docs else []
    except Exception as e:
        logging.error(f"Error retrieving documents: {e}")
        return []

def enhance_query(question):
    stop_words = set(['il', 'lo', 'la', 'i', 'gli', 'le', 'un', 'uno', 'una'])
    words = [w for w in question.lower().split() if w not in stop_words]
    return " ".join(words)

def log_search_results(question, docs_and_scores):
    logging.info(f"Query: {question}")
    for idx, (doc, score) in enumerate(docs_and_scores, 1):
        logging.info(f"Doc {idx} - Score: {score:.4f}")
        logging.info(f"Content: {doc.page_content[:100]}...")

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def summarize_context(messages):
    """Crea un riassunto del contesto mantenendo le informazioni chiave"""
    summary = []
    key_info = set()
    
    for msg in messages:
        if msg["role"] == "system":
            continue
            
        # Estrai informazioni chiave
        content = msg["content"]
        if "fonte" in content.lower() or "fonti" in content.lower():
            key_info.add(content)
        elif "importante" in content.lower() or "nota" in content.lower():
            key_info.add(content)
            
    if key_info:
        summary.append({
            "role": "system",
            "content": "Contesto riassunto:\n" + "\n".join(f"- {info}" for info in key_info)
        })
        
    return summary

def answer_question(question, db_name, prompt_type="tutor", chat_history=None, llm_type=LLMType.OPENAI_GPT_4O_MINI):
    if chat_history is None:
        chat_history = []
        
    # Configurazione dinamica della cronologia
    MAX_HISTORY_TOKENS = int(LLM_CONFIGS_EXTENDED["max_tokens"] * 0.4)  # 40% dei token totali
    MIN_HISTORY_ITEMS = 2  # Mantieni almeno l'ultimo scambio
    
    # Calcola la lunghezza della cronologia attuale
    current_tokens = sum(len(m["content"].split()) for m in chat_history)
    
    # Se superiamo il limite, creiamo un riassunto
    if current_tokens > MAX_HISTORY_TOKENS:
        summary = summarize_context(chat_history)
        # Manteniamo l'ultimo scambio completo
        last_exchange = chat_history[-MIN_HISTORY_ITEMS:]
        chat_history = summary + last_exchange
    
    try:
        # Setup e recupero documenti
        db_path = os.path.join(BASE_DB_PATH, f"faiss_index_{db_name}")
        embeddings = get_embeddings()
        vectorstore = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)
        relevant_docs = get_relevant_documents(vectorstore, question)
        
        if not relevant_docs:
            return [
                {"role": "user", "content": question},
                {"role": "assistant", "content": "Mi dispiace, non ho trovato informazioni rilevanti."}
            ]

        # Leggi metadata.json per il totale dei chunks
        metadata_path = os.path.join("db", f"faiss_index_{db_name}", "metadata.json")
        with open(metadata_path, 'r') as f:
            metadata_list = json.load(f)
        
        # Crea dizionario titolo -> chunks
        total_chunks = {doc['title']: doc['chunks'] for doc in metadata_list}
        
        # Prepara le fonti
        sources = []
        for doc in relevant_docs:
            meta = doc.metadata
            title = meta.get('title', 'Unknown')
            author = meta.get('author', 'Unknown')
            filename = meta.get('filename', 'Unknown')
            chunk_id = meta.get('chunk_id', 0)  # Usa l'ID univoco del chunk
            total_doc_chunks = total_chunks.get(title, 0)
            
            # Usa lo stesso formato di chunks_viewer_tab.py
            chunk_info = f"📚 Chunk {chunk_id} - {title} ({filename})"
            if author != 'Unknown':
                chunk_info += f" - Author: {author}"
            
            sources.append(chunk_info)

        # Prepara contesto e prompt
        context = "\n".join([doc.page_content for doc in relevant_docs])
        sources_text = "\n\nFonti consultate:\n" + "\n".join(set(sources))
        prompt = SYSTEM_PROMPTS[prompt_type].format(context=context)
        prompt += "\nCita sempre le fonti utilizzate nella risposta, inclusi titolo e autore."

        # Crea messaggio e ottieni risposta
        messages = [
            {"role": "system", "content": prompt},
            *[{"role": m["role"], "content": m["content"]} for m in chat_history],
            {"role": "user", "content": question}
        ]
        
        client, model = get_llm_client(llm_type)
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature= LLM_CONFIGS_EXTENDED["temperature"],
            max_tokens=LLM_CONFIGS_EXTENDED["max_tokens"]
        )
        
        answer = response.choices[0].message.content + sources_text
        return [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer}
        ]
    except Exception as e:
        logging.error(f"Error in answer_question: {e}")
        error_msg = "LLM locale non disponibile." if "local" in str(llm_type) else str(e)
        return [
            {"role": "user", "content": question},
            {"role": "assistant", "content": f"⚠️ {error_msg}"}
        ]

class DocumentRetriever:
    def __init__(self, db_path):
        self.embeddings = get_embeddings()
        self.vectorstore = FAISS.load_local(
            db_path,
            self.embeddings,
            allow_dangerous_deserialization=True
        )
        
    def get_relevant_chunks(self, question):
        enhanced_query = enhance_query(question)
        docs_and_scores = self.vectorstore.similarity_search_with_score(enhanced_query)
        log_search_results(question, docs_and_scores)
        return [
            doc for doc, score in docs_and_scores
            if score >= EMBEDDING_CONFIG['min_similarity']
        ]

if __name__ == "__main__":
    pass