|
|
|
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from config.config import settings
|
|
|
# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)
|
|
|
class ConversationManager:
    """Manages per-session conversation history, trimmed to a maximum length."""

    def __init__(self, max_history: int = 10):
        """
        Args:
            max_history: Maximum number of interactions retained per session.
                Previously hard-coded to 10; the default preserves that behavior.
        """
        # session_id -> chronological list of interaction records
        self.conversations: Dict[str, List[Dict[str, Any]]] = {}
        self.max_history = max_history

    def add_interaction(
        self,
        session_id: str,
        user_input: str,
        response: str,
        context: Optional[Dict[str, Any]] = None
    ) -> None:
        """Append one user/assistant exchange to the session's history.

        Args:
            session_id: Key identifying the conversation.
            user_input: Raw user message.
            response: Assistant reply.
            context: Optional metadata stored alongside the exchange
                (e.g. the search results that informed the reply).
        """
        history = self.conversations.setdefault(session_id, [])
        history.append({
            # NOTE(review): naive local timestamp; consider a tz-aware
            # datetime if histories are ever compared across hosts.
            'timestamp': datetime.now().isoformat(),
            'user_input': user_input,
            'response': response,
            'context': context
        })

        # Keep only the most recent max_history entries.
        if len(history) > self.max_history:
            del history[:-self.max_history]

    def get_history(self, session_id: str) -> List[Dict[str, Any]]:
        """Return the stored history for *session_id* (empty list if unknown)."""
        return self.conversations.get(session_id, [])

    def clear_history(self, session_id: str) -> None:
        """Drop all history for *session_id*; no-op if the session is unknown."""
        self.conversations.pop(session_id, None)
|
|
|
class ChatService:
    """Main chat service coordinating retrieval, context assembly and
    response generation across product, PDF and FAQ data sources."""

    def __init__(
        self,
        model_service,
        data_service,
        pdf_service,
        faq_service
    ):
        # The model and tokenizer are owned by model_service; we keep
        # direct references purely for convenience.
        self.model = model_service.model
        self.tokenizer = model_service.tokenizer
        self.data_service = data_service
        self.pdf_service = pdf_service
        self.faq_service = faq_service
        self.conversation_manager = ConversationManager()

    async def search_all_sources(
        self,
        query: str,
        top_k: int = 3
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Search products, PDF documents and FAQs concurrently.

        Args:
            query: Free-text search query.
            top_k: Maximum number of hits requested from each source.

        Returns:
            Mapping with keys 'products', 'documents' and 'faqs'. On any
            failure all three lists are empty (deliberate best-effort
            semantics so one broken source does not abort the chat turn).
        """
        try:
            # asyncio.gather schedules coroutines as tasks itself, so the
            # previous explicit create_task wrappers were redundant.
            products, pdfs, faqs = await asyncio.gather(
                self.data_service.search(query, top_k),
                self.pdf_service.search(query, top_k),
                self.faq_service.search_faqs(query, top_k),
            )
            return {
                'products': products,
                'documents': pdfs,
                'faqs': faqs
            }
        except Exception:
            # logger.exception records the traceback; lazy formatting per
            # logging best practice.
            logger.exception("Error searching sources")
            return {'products': [], 'documents': [], 'faqs': []}

    def build_context(
        self,
        search_results: Dict[str, List[Dict[str, Any]]],
        chat_history: List[Dict[str, Any]]
    ) -> str:
        """Assemble the model context from search results and chat history.

        Only the two best hits per source and the last three interactions
        are included to keep the prompt short.

        Args:
            search_results: Output of :meth:`search_all_sources`.
            chat_history: Interaction records from ConversationManager.

        Returns:
            Context sections joined by blank lines; empty string when
            there is nothing to include.
        """
        context_parts: List[str] = []

        # `or []` guards against a key that is present but None.
        for product in (search_results.get('products') or [])[:2]:
            context_parts.append(
                f"Produkt: {product['Name']}\n"
                f"Beschreibung: {product['Description']}\n"
                f"Preis: {product['Price']}€\n"
                f"Kategorie: {product['ProductCategory']}"
            )

        for doc in (search_results.get('documents') or [])[:2]:
            context_parts.append(
                f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
                f"{doc['text']}"
            )

        for faq in (search_results.get('faqs') or [])[:2]:
            context_parts.append(
                f"FAQ:\n"
                f"Frage: {faq['question']}\n"
                f"Antwort: {faq['answer']}"
            )

        if chat_history:
            # Only the three most recent exchanges are replayed.
            history_text = "\n".join(
                f"User: {h['user_input']}\nAssistant: {h['response']}"
                for h in chat_history[-3:]
            )
            context_parts.append(f"Letzte Interaktionen:\n{history_text}")

        return "\n\n".join(context_parts)

    async def generate_response(
        self,
        prompt: str,
        max_length: int = 1000
    ) -> str:
        """Generate a response for *prompt* with the language model.

        Args:
            prompt: Fully assembled prompt text.
            max_length: Upper bound on the number of newly generated tokens.

        Raises:
            Exception: re-raised after logging; callers decide how to fail.
        """
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096
            ).to(settings.DEVICE)

            outputs = self.model.generate(
                **inputs,
                # Bug fix: `max_length` counts the prompt tokens too, so a
                # prompt longer than it would leave no room to generate.
                # `max_new_tokens` bounds only the generated continuation.
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                no_repeat_ngram_size=3,
                # `early_stopping=True` removed: it only applies to beam
                # search and is invalid alongside do_sample=True.
            )

            response = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )

            return response.strip()

        except Exception:
            logger.exception("Error generating response")
            raise

    async def chat(
        self,
        user_input: str,
        session_id: str,
        max_length: int = 1000
    ) -> Tuple[str, Dict[str, List[Dict[str, Any]]]]:
        """Run one full conversation turn.

        Flow: fetch history -> search all sources -> build context ->
        generate response -> record the interaction.

        Args:
            user_input: The user's message.
            session_id: Conversation identifier for history tracking.
            max_length: Token budget forwarded to :meth:`generate_response`.

        Returns:
            Tuple of (assistant response, search results mapping).
            NOTE: annotation fixed — the second element is the dict from
            search_all_sources, not a list.

        Raises:
            Exception: re-raised after logging.
        """
        try:
            chat_history = self.conversation_manager.get_history(session_id)

            search_results = await self.search_all_sources(user_input)

            context = self.build_context(search_results, chat_history)

            prompt = (
                f"Context:\n{context}\n\n"
                f"User: {user_input}\n"
                "Assistant:"
            )

            response = await self.generate_response(prompt, max_length)

            self.conversation_manager.add_interaction(
                session_id,
                user_input,
                response,
                {'search_results': search_results}
            )

            return response, search_results

        except Exception:
            logger.exception("Error in chat")
            raise