# chat_service.py
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from config.config import settings

logger = logging.getLogger(__name__)
class ConversationManager:
    """Manages conversation history and context."""

    def __init__(self):
        self.conversations: Dict[str, List[Dict[str, Any]]] = {}
        # Keep only the most recent interaction per session to bound the
        # prompt size; raise this value to retain more conversational context.
        self.max_history = 1
def add_interaction(
self,
session_id: str,
user_input: str,
response: str,
context: Optional[Dict[str, Any]] = None
) -> None:
if session_id not in self.conversations:
self.conversations[session_id] = []
self.conversations[session_id].append({
'timestamp': datetime.now().isoformat(),
'user_input': user_input,
'response': response,
'context': context
})
if len(self.conversations[session_id]) > self.max_history:
self.conversations[session_id] = self.conversations[session_id][-self.max_history:]
def get_history(self, session_id: str) -> List[Dict[str, Any]]:
return self.conversations.get(session_id, [])
def clear_history(self, session_id: str) -> None:
if session_id in self.conversations:
del self.conversations[session_id]
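
# A minimal usage sketch for ConversationManager (the session id and messages
# are illustrative, not part of the application flow):
#
#   manager = ConversationManager()
#   manager.add_interaction("session-1", "Hallo", "Hallo! Wie kann ich helfen?")
#   manager.get_history("session-1")
#   # -> [{'timestamp': '...', 'user_input': 'Hallo',
#   #      'response': 'Hallo! Wie kann ich helfen?', 'context': None}]
#   manager.clear_history("session-1")
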
class ChatService:
def __init__(
self,
model_service,
data_service,
pdf_service,
faq_service
):
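        # model_service is expected to expose `.model` and `.tokenizer`;
        # the remaining services provide the async search APIs used in
        # search_all_sources below.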
self.model = model_service.model
self.tokenizer = model_service.tokenizer
self.data_service = data_service
self.pdf_service = pdf_service
self.faq_service = faq_service
self.conversation_manager = ConversationManager()
async def search_all_sources(
self,
query: str,
top_k: int = 3
) -> Dict[str, List[Dict[str, Any]]]:
"""Search across all available data sources"""
try:
print("-----------------------------")
print("starting searches .... ")
# Await the search calls since they're coroutines
products = await self.data_service.search(query, top_k)
pdfs = await self.pdf_service.search(query, top_k)
faqs = await self.faq_service.search_faqs(query, top_k)
results = {
'products': products or [],
'documents': pdfs or [],
'faqs': faqs or []
}
print("Search results:", results)
return results
except Exception as e:
logger.error(f"Error searching sources: {e}")
return {'products': [], 'documents': [], 'faqs': []}
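
    # Shape of the result returned by search_all_sources (values are
    # illustrative; the keys mirror what build_context reads below):
    #   {'products':  [{'Name': ..., 'Description': ..., 'Price': ..., 'ProductCategory': ...}],
    #    'documents': [{'source': ..., 'page': ..., 'text': ...}],
    #    'faqs':      [{'question': ..., 'answer': ...}]}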
def construct_system_prompt(self, context: str) -> str:
"""Constructs the system message."""
return (
"You are a friendly bot named: Oma Erna, specializing in Bofrost products and content. Use only the context from this prompt. "
"Return comprehensive German answers. If possible add product IDs from context. Do not make up information. The context is is truth. "
"Use the following context (product descriptions and information) for answers:\n\n"
f"{context}\n\n"
)
def construct_prompt(
self,
user_input: str,
context: str,
chat_history: List[Tuple[str, str]],
max_history_turns: int = 1
) -> str:
"""Constructs the full prompt."""
system_message = self.construct_system_prompt(context)
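        # Llama 3 chat format: a system block, alternating user/assistant
        # turns from the history, then an open assistant header for the
        # model to complete.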
prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
for user_msg, assistant_msg in chat_history[-max_history_turns:]:
prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
return prompt
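
    # With an empty history the rendered prompt is, schematically:
    #   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>
    #   <|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>
    #   <|start_header_id|>assistant<|end_header_id|>\n\n
    # (split across lines here only for readability; the actual string has no
    # newlines between segments)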
def build_context(
self,
search_results: Dict[str, List[Dict[str, Any]]],
chat_history: List[Dict[str, Any]]
) -> str:
"""Build context for the model from search results and chat history"""
context_parts = []
# Add relevant products
if search_results.get('products'):
products = search_results['products'][:2] # Limit to top 2 products
for product in products:
context_parts.append(
f"Produkt: {product['Name']}\n"
f"Beschreibung: {product['Description']}\n"
f"Preis: {product['Price']}€\n"
f"Kategorie: {product['ProductCategory']}"
)
# Add relevant PDF content
if search_results.get('documents'):
docs = search_results['documents'][:2]
for doc in docs:
context_parts.append(
f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
f"{doc['text']}"
)
# Add relevant FAQs
if search_results.get('faqs'):
faqs = search_results['faqs'][:2]
for faq in faqs:
context_parts.append(
f"FAQ:\n"
f"Frage: {faq['question']}\n"
f"Antwort: {faq['answer']}"
)
        # Add recent chat history. This is currently disabled: the
        # ConversationManager keeps only the most recent turn, so the
        # block below would add little context.
        if chat_history:
            logger.debug("--- history ---")
            # recent_history = chat_history[-3:]  # Last 3 interactions
            # history_text = "\n".join(
            #     f"User: {h['user_input']}\nAssistant: {h['response']}"
            #     for h in recent_history
            # )
            # context_parts.append(f"Letzte Interaktionen:\n{history_text}")
        context = "\n\n".join(context_parts)
        logger.debug(context)
        return context
async def chat(
self,
user_input: str,
session_id: Any,
max_length: int = 8000
) -> Tuple[str, List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]]]:
"""Main chat method that coordinates the entire conversation flow."""
try:
if not isinstance(session_id, str):
session_id = str(session_id)
chat_history_raw = self.conversation_manager.get_history(session_id)
chat_history = [
(entry['user_input'], entry['response']) for entry in chat_history_raw
]
            search_results = await self.search_all_sources(user_input)
            context = self.build_context(search_results, chat_history_raw)
prompt = self.construct_prompt(user_input, context, chat_history)
response = self.generate_response(prompt, max_length)
self.conversation_manager.add_interaction(
session_id,
user_input,
response,
{'search_results': search_results}
)
formatted_history = [
(entry['user_input'], entry['response'])
for entry in self.conversation_manager.get_history(session_id)
]
return response, formatted_history, search_results
except Exception as e:
logger.error(f"Error in chat: {e}")
raise
def generate_response(
self,
prompt: str,
max_length: int = 1000
) -> str:
"""Generate response using the language model"""
try:
print(prompt)
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=4096
).to(settings.DEVICE)
outputs = self.model.generate(
**inputs,
max_length=max_length,
num_return_sequences=1,
temperature=0.7,
top_p=0.9,
do_sample=True,
no_repeat_ngram_size=3,
early_stopping=False
)
input_ids = self.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=4096).to("cpu")
response = self.tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
# Remove potential repeated assistant text
response = response.replace("<|assistant|>", "").strip()
return response.strip()
except Exception as e:
logger.error(f"Error generating response: {e}")
raise
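
# Rough wiring sketch (the *_service arguments are placeholders for the
# application's concrete service objects; shown only to illustrate the
# call flow):
#
#   chat_service = ChatService(model_service, data_service, pdf_service, faq_service)
#   response, history, sources = await chat_service.chat(
#       "Welche Produkte habt ihr im Angebot?", session_id="abc"
#   )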