# services/chat_service.py
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from config.config import settings

logger = logging.getLogger(__name__)

class ConversationManager:
"""Manages conversation history and context"""
def __init__(self):
self.conversations: Dict[str, List[Dict[str, Any]]] = {}
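        # Rolling window: only the most recent `max_history` interactions are kept per session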
self.max_history = 10
def add_interaction(
self,
session_id: str,
user_input: str,
response: str,
context: Optional[Dict[str, Any]] = None
) -> None:
if session_id not in self.conversations:
self.conversations[session_id] = []
self.conversations[session_id].append({
'timestamp': datetime.now().isoformat(),
'user_input': user_input,
'response': response,
'context': context
})
# Trim history if needed
if len(self.conversations[session_id]) > self.max_history:
self.conversations[session_id] = self.conversations[session_id][-self.max_history:]
def get_history(self, session_id: str) -> List[Dict[str, Any]]:
return self.conversations.get(session_id, [])
def clear_history(self, session_id: str) -> None:
if session_id in self.conversations:
del self.conversations[session_id]
class ChatService:
"""Main chat service that coordinates responses"""
"""Main chat service that coordinates responses"""
def __init__(
self,
model_service,
data_service,
pdf_service,
faq_service
):
self.model = model_service.model
self.tokenizer = model_service.tokenizer
self.data_service = data_service
self.pdf_service = pdf_service
self.faq_service = faq_service
self.conversation_manager = ConversationManager()
def construct_system_prompt(self, context: str) -> str:
"""Constructs the system message."""
return (
"You are a friendly bot specializing in Bofrost products. "
"Return comprehensive German answers. Always add product IDs. "
"Use the following product descriptions:\n\n"
f"{context}\n\n"
)
def construct_prompt(
self,
user_input: str,
context: str,
chat_history: List[Tuple[str, str]],
max_history_turns: int = 1
) -> str:
"""Constructs the full prompt."""
# System message
system_message = self.construct_system_prompt(context)
# Start with system message
prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
# Add chat history (limit to last `max_history_turns` interactions)
for user_msg, assistant_msg in chat_history[-max_history_turns:]:
prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
# Add the current user input
prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
return prompt
def build_context(
self,
search_results: Dict[str, List[Dict[str, Any]]],
chat_history: List[Dict[str, Any]]
) -> str:
"""Build context for the model from search results and chat history"""
context_parts = []
# Add relevant products
if search_results.get('products'):
products = search_results['products'][:2] # Limit to top 2 products
for product in products:
context_parts.append(
f"Produkt: {product['Name']}\n"
f"Beschreibung: {product['Description']}\n"
f"Preis: {product['Price']}€\n"
f"Kategorie: {product['ProductCategory']}"
)
# Add relevant PDF content
if search_results.get('documents'):
docs = search_results['documents'][:2]
for doc in docs:
context_parts.append(
f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
f"{doc['text']}"
)
# Add relevant FAQs
if search_results.get('faqs'):
faqs = search_results['faqs'][:2]
for faq in faqs:
context_parts.append(
f"FAQ:\n"
f"Frage: {faq['question']}\n"
f"Antwort: {faq['answer']}"
)
# Add recent chat history
if chat_history:
recent_history = chat_history[-3:] # Last 3 interactions
history_text = "\n".join(
f"User: {h['user_input']}\nAssistant: {h['response']}"
for h in recent_history
)
context_parts.append(f"Letzte Interaktionen:\n{history_text}")
return "\n\n".join(context_parts)
async def chat(
self,
user_input: str,
session_id: Any,
max_length: int = 1000
) -> Tuple[str, List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]]]:
"""Main chat method that coordinates the entire conversation flow."""
try:
# Ensure session_id is a string
if not isinstance(session_id, str):
session_id = str(session_id)
# Get chat history
chat_history_raw = self.conversation_manager.get_history(session_id)
chat_history = [
(entry['user_input'], entry['response']) for entry in chat_history_raw
]
# Search all sources
search_results = await self.search_all_sources(user_input)
# Build context
context = self.build_context(search_results, chat_history_raw)
# Construct the prompt
prompt = self.construct_prompt(user_input, context, chat_history)
# Generate response
response = await self.generate_response(prompt, max_length)
# Store interaction
self.conversation_manager.add_interaction(
session_id,
user_input,
response,
{'search_results': search_results}
)
# Prepare the chat history for Gradio
            formatted_history = [
                (entry['user_input'], entry['response'])
                for entry in self.conversation_manager.get_history(session_id)
            ]
return response, formatted_history, search_results
except Exception as e:
logger.error(f"Error in chat: {e}")
raise
async def search_all_sources(
self,
query: str,
top_k: int = 3
) -> Dict[str, List[Dict[str, Any]]]:
"""Search across all available data sources"""
try:
# Run searches in parallel
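            # create_task schedules each coroutine immediately, so the three searches overlap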
product_task = asyncio.create_task(
self.data_service.search(query, top_k)
)
pdf_task = asyncio.create_task(
self.pdf_service.search(query, top_k)
)
faq_task = asyncio.create_task(
self.faq_service.search_faqs(query, top_k)
)
# Gather results
products, pdfs, faqs = await asyncio.gather(
product_task, pdf_task, faq_task
)
return {
'products': products,
'documents': pdfs,
'faqs': faqs
}
except Exception as e:
logger.error(f"Error searching sources: {e}")
return {'products': [], 'documents': [], 'faqs': []}
    async def generate_response(
        self,
        prompt: str,
        max_length: int = 1000
    ) -> str:
        """Generate a response using the language model"""
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096
            ).to(settings.DEVICE)
            # Note: model.generate is blocking; on a busy event loop consider
            # offloading it with asyncio.to_thread.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,  # cap generated tokens; max_length would count the prompt too
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                no_repeat_ngram_size=3,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # Decode only the newly generated tokens, not the echoed prompt
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            response = self.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise
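

if __name__ == "__main__":
    # Usage sketch (illustrative): wires up ChatService and runs a single turn.
    # The service import paths and no-argument constructors below are
    # assumptions; the real classes elsewhere in this repo may differ.
    async def _demo() -> None:
        from services.data_service import DataService      # assumed path
        from services.faq_service import FAQService        # assumed path
        from services.model_service import ModelService    # assumed path
        from services.pdf_service import PDFService        # assumed path

        service = ChatService(
            model_service=ModelService(),
            data_service=DataService(),
            pdf_service=PDFService(),
            faq_service=FAQService(),
        )
        response, history, sources = await service.chat(
            "Welche veganen Produkte gibt es?", session_id="demo"
        )
        print(response)

    asyncio.run(_demo())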