# llm/services/chat_service.py — uploaded by Chris4K (commit 257879f, 7.14 kB)
# NOTE: the lines above this module's code were Hugging Face file-viewer
# residue ("raw / history / blame" chrome); commented out so the file parses.
# services/chat_service.py
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from config.config import settings

logger = logging.getLogger(__name__)
class ConversationManager:
    """Manages per-session conversation history and context.

    History is kept in memory only; each session's list is capped at
    ``max_history`` entries (oldest discarded first).
    """

    def __init__(self, max_history: int = 10):
        """
        Args:
            max_history: Maximum number of interactions retained per session.
                Defaults to 10, matching the previous hard-coded limit.
        """
        # session_id -> chronologically ordered list of interaction records
        self.conversations: Dict[str, List[Dict[str, Any]]] = {}
        self.max_history = max_history

    def add_interaction(
        self,
        session_id: str,
        user_input: str,
        response: str,
        context: Optional[Dict[str, Any]] = None
    ) -> None:
        """Append one user/assistant exchange to the session's history.

        Args:
            session_id: Key identifying the conversation.
            user_input: Raw user message.
            response: Assistant reply.
            context: Optional metadata (e.g. search results) stored alongside.
        """
        history = self.conversations.setdefault(session_id, [])
        history.append({
            'timestamp': datetime.now().isoformat(),
            'user_input': user_input,
            'response': response,
            'context': context
        })
        # Trim in place so previously obtained references stay current.
        if len(history) > self.max_history:
            del history[:-self.max_history]

    def get_history(self, session_id: str) -> List[Dict[str, Any]]:
        """Return the stored interactions for a session ([] if unknown)."""
        return self.conversations.get(session_id, [])

    def clear_history(self, session_id: str) -> None:
        """Forget all history for a session; no-op if the session is unknown."""
        self.conversations.pop(session_id, None)
class ChatService:
    """Main chat service: retrieves context from all sources and generates
    a model response, maintaining per-session conversation history."""

    def __init__(
        self,
        model_service,
        data_service,
        pdf_service,
        faq_service
    ):
        """
        Args:
            model_service: Provides ``.model`` and ``.tokenizer`` attributes.
            data_service: Product search backend (``search(query, top_k)``).
            pdf_service: Document search backend (``search(query, top_k)``).
            faq_service: FAQ search backend (``search_faqs(query, top_k)``).
        """
        self.model = model_service.model
        self.tokenizer = model_service.tokenizer
        self.data_service = data_service
        self.pdf_service = pdf_service
        self.faq_service = faq_service
        self.conversation_manager = ConversationManager()

    async def search_all_sources(
        self,
        query: str,
        top_k: int = 3
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Search products, PDFs and FAQs concurrently.

        Returns:
            Dict with keys ``'products'``, ``'documents'`` and ``'faqs'``.
            Best-effort: on any failure, all three lists come back empty.
        """
        try:
            # Launch the three searches concurrently on the running loop.
            product_task = asyncio.create_task(
                self.data_service.search(query, top_k)
            )
            pdf_task = asyncio.create_task(
                self.pdf_service.search(query, top_k)
            )
            faq_task = asyncio.create_task(
                self.faq_service.search_faqs(query, top_k)
            )
            products, pdfs, faqs = await asyncio.gather(
                product_task, pdf_task, faq_task
            )
            return {
                'products': products,
                'documents': pdfs,
                'faqs': faqs
            }
        except Exception as e:
            logger.error(f"Error searching sources: {e}")
            return {'products': [], 'documents': [], 'faqs': []}

    def build_context(
        self,
        search_results: Dict[str, List[Dict[str, Any]]],
        chat_history: List[Dict[str, Any]]
    ) -> str:
        """Build the (German-labelled) prompt context from search results
        and recent chat history.

        Args:
            search_results: Output of :meth:`search_all_sources`; product
                entries are expected to carry 'Name', 'Description', 'Price'
                and 'ProductCategory' keys, document entries 'source', 'page'
                and 'text', FAQ entries 'question' and 'answer'.
            chat_history: Interaction records from ConversationManager.

        Returns:
            Context sections joined by blank lines ("" when nothing matched).
        """
        context_parts = []

        # Top 2 products only, to keep the prompt within budget.
        if search_results.get('products'):
            products = search_results['products'][:2]
            for product in products:
                context_parts.append(
                    f"Produkt: {product['Name']}\n"
                    f"Beschreibung: {product['Description']}\n"
                    f"Preis: {product['Price']}€\n"
                    f"Kategorie: {product['ProductCategory']}"
                )

        # Top 2 PDF passages with their provenance.
        if search_results.get('documents'):
            docs = search_results['documents'][:2]
            for doc in docs:
                context_parts.append(
                    f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
                    f"{doc['text']}"
                )

        # Top 2 FAQ matches.
        if search_results.get('faqs'):
            faqs = search_results['faqs'][:2]
            for faq in faqs:
                context_parts.append(
                    f"FAQ:\n"
                    f"Frage: {faq['question']}\n"
                    f"Antwort: {faq['answer']}"
                )

        # Last 3 interactions so the model sees the recent dialogue.
        if chat_history:
            recent_history = chat_history[-3:]
            history_text = "\n".join(
                f"User: {h['user_input']}\nAssistant: {h['response']}"
                for h in recent_history
            )
            context_parts.append(f"Letzte Interaktionen:\n{history_text}")

        return "\n\n".join(context_parts)

    async def generate_response(
        self,
        prompt: str,
        max_length: int = 1000
    ) -> str:
        """Generate a model response for ``prompt``.

        Args:
            prompt: Full prompt text (context + user turn).
            max_length: Upper bound on the number of newly generated tokens.

        Raises:
            Exception: re-raised after logging on tokenization/generation
                failure.
        """
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096
            ).to(settings.DEVICE)
            outputs = self.model.generate(
                **inputs,
                # BUGFIX: the previous `max_length=max_length` counted the
                # prompt tokens too — with prompts allowed up to 4096 tokens
                # above, a prompt longer than `max_length` left no room to
                # generate. `max_new_tokens` bounds only the continuation.
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                no_repeat_ngram_size=3,
                early_stopping=True
            )
            # NOTE(review): for decoder-only models outputs[0] includes the
            # prompt tokens, so the prompt is echoed back in the response —
            # confirm the model type and slice off inputs['input_ids'] length
            # if that is the case.
            response = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    async def chat(
        self,
        user_input: str,
        session_id: str,
        max_length: int = 1000
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Run one full conversation turn.

        Flow: fetch history -> search all sources -> build context ->
        generate -> store interaction.

        Args:
            user_input: The user's message.
            session_id: Conversation key for history tracking.
            max_length: Forwarded to :meth:`generate_response`.

        Returns:
            Tuple of (response text, search results dict).

        Raises:
            Exception: re-raised after logging on any failure in the flow.
        """
        try:
            chat_history = self.conversation_manager.get_history(session_id)
            search_results = await self.search_all_sources(user_input)
            context = self.build_context(search_results, chat_history)
            prompt = (
                f"Context:\n{context}\n\n"
                f"User: {user_input}\n"
                "Assistant:"
            )
            response = await self.generate_response(prompt, max_length)
            # Persist the turn (with its retrieval context) for future calls.
            self.conversation_manager.add_interaction(
                session_id,
                user_input,
                response,
                {'search_results': search_results}
            )
            return response, search_results
        except Exception as e:
            logger.error(f"Error in chat: {e}")
            raise