# llm_handling.py
import logging
import os
from langchain_community.vectorstores import FAISS
import requests
from tenacity import retry, stop_after_attempt, wait_exponential
import json

from app.config import BASE_DB_PATH, LLM_CONFIGS, LLMType
from app.configs.prompts import SYSTEM_PROMPTS
from app.utils.embedding_utils import get_embeddings
from app.utils.voice_utils import generate_speech  # Unused in this module; retained intentionally

logging.basicConfig(level=logging.INFO)

# =====================================
# Functions related to LLM
# =====================================

def get_llm_client(llm_type: LLMType):
    """Obtains the appropriate client for the selected model"""
    config = LLM_CONFIGS.get(llm_type)
    if not config:
        raise ValueError(f"Model {llm_type} not supported")
    client_class = config["client"]
    model = config["model"]
    client = client_class()  # LLM_CONFIGS is expected to provide client classes that take no constructor arguments
    return client, model

def get_system_prompt(prompt_type="tutor"):
    """Selects the appropriate system prompt"""
    return SYSTEM_PROMPTS.get(prompt_type, SYSTEM_PROMPTS["tutor"])

def test_local_connection():
    """Checks connection to the local LLM server"""
    try:
        response = requests.get("http://192.168.82.5:1234/v1/health", timeout=5)
        return response.status_code == 200
    except requests.RequestException:
        return False

def read_metadata(db_path):
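    """Loads the metadata.json saved alongside a FAISS index; returns an empty list if missing"""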
    metadata_file = os.path.join(db_path, "metadata.json")
    if os.path.exists(metadata_file):
        with open(metadata_file, 'r') as f:
            return json.load(f)
    return []

def get_relevant_documents(vectorstore, question, min_similarity=0.7):
    """Retrieves relevant documents from the vectorstore"""
    try:
        enhanced_query = enhance_query(question)
        docs_and_scores = vectorstore.similarity_search_with_score(
            enhanced_query,
            k=8
        )
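        # Note: with a default FAISS index, similarity_search_with_score returns distances
        # (lower = more similar); the >= threshold below assumes the index/embeddings are
        # configured to yield normalized similarity scores instead.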
        filtered_docs = [
            doc for doc, score in docs_and_scores if score >= min_similarity
        ]
        logging.info(f"Query: {question}")
        logging.info(f"Documents found: {len(filtered_docs)}")
        return filtered_docs[:5] if filtered_docs else []
    except Exception as e:
        logging.error(f"Error retrieving documents: {e}")
        return []

def enhance_query(question):
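    """Strips common Italian stop words (articles) from the query before retrieval"""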
    stop_words = set(['il', 'lo', 'la', 'i', 'gli', 'le', 'un', 'uno', 'una'])
    words = [w for w in question.lower().split() if w not in stop_words]
    enhanced_query = " ".join(words)
    return enhanced_query

def log_search_results(question, docs_and_scores):
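    """Logs the similarity score and a short content preview for each retrieved document"""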
    logging.info(f"Query: {question}")
    for idx, (doc, score) in enumerate(docs_and_scores, 1):
        logging.info(f"Doc {idx} - Score: {score:.4f}")
        logging.info(f"Content: {doc.page_content[:100]}...")

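# Note: the broad try/except inside answer_question returns an error message instead of
# re-raising, so the retry decorator will not re-invoke the function for errors caught there.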
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def answer_question(question, db_name, prompt_type="tutor", chat_history=None, llm_type=None):
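    """Answers a question via RAG over the FAISS index for db_name; returns a [user, assistant] message pair"""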
    if chat_history is None:
        chat_history = []
    try:
        embeddings = get_embeddings()
        db_path = os.path.join(BASE_DB_PATH, f"faiss_index_{db_name}")
        metadata_list = read_metadata(db_path)
        metadata_dict = {m["filename"]: m for m in metadata_list}
        vectorstore = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)
        relevant_docs = get_relevant_documents(vectorstore, question)
        if not relevant_docs:
            return [
                {"role": "user", "content": question},
                {"role": "assistant", "content": "Sorry, no relevant information found to answer your question. Try rephrasing or asking a different question."}
            ]
        sources = []
        for idx, doc in enumerate(relevant_docs, 1):
            source_file = doc.metadata.get("source", "Unknown")
            if source_file in metadata_dict:
                meta = metadata_dict[source_file]
                sources.append(f"📚 {meta['title']} (Author: {meta['author']}) - Part {idx} of {len(relevant_docs)}")
        context = "\n".join([
            f"[Part {idx+1} of {len(relevant_docs)}]\n{doc.page_content}"
            for idx, doc in enumerate(relevant_docs)
        ])
        sources_text = ("\n\nSources consulted:\n" + "\n".join(sorted(set(sources)))) if sources else ""
        prompt = SYSTEM_PROMPTS[prompt_type].format(context=context)
        prompt += "\nAlways cite the sources used for your response, including the document title and author."
        messages = [
            {"role": "system", "content": prompt},
            *[{"role": m["role"], "content": m["content"]} for m in chat_history],
            {"role": "user", "content": question}
        ]
        client, model = get_llm_client(llm_type)
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.7,
            max_tokens=2048
        )
        answer = response.choices[0].message.content + sources_text
        return [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer}
        ]
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        error_msg = "Local LLM not available. Try again later or use OpenAI." if "local" in str(llm_type) else str(e)
        return [
            {"role": "user", "content": question},
            {"role": "assistant", "content": f"⚠️ {error_msg}"}
        ]

class DocumentRetriever:
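    """Wraps a FAISS index loaded from db_path and reuses the module-level query helpers"""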
    def __init__(self, db_path):
        self.embeddings = get_embeddings()
        self.vectorstore = FAISS.load_local(
            db_path,
            self.embeddings,
            allow_dangerous_deserialization=True
        )
        
    def get_relevant_chunks(self, question, min_similarity=0.7):
        """Returns the most relevant chunks for a question, mirroring get_relevant_documents"""
        enhanced_query = enhance_query(question)
        docs_and_scores = self.vectorstore.similarity_search_with_score(
            enhanced_query,
            k=8
        )
        log_search_results(question, docs_and_scores)
        filtered_docs = [doc for doc, score in docs_and_scores if score >= min_similarity]
        return filtered_docs[:5]

if __name__ == "__main__":
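    # Minimal smoke test of the pure helpers; the sample question below is hypothetical
    # and no FAISS index or LLM server is required for these two calls.
    sample_question = "Qual è il tema principale del documento?"
    print("Enhanced query:", enhance_query(sample_question))
    print("Local LLM reachable:", test_local_connection())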