# chat_service.py
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
import asyncio
import logging

import torch

from config.config import settings

logger = logging.getLogger(__name__)

class ConversationManager:
    """Manages per-session conversation history and context."""

    def __init__(self):
        self.conversations: Dict[str, List[Dict[str, Any]]] = {}
        # Only the most recent interaction is kept per session; raise this
        # value to give the model a longer memory.
        self.max_history = 1

    def add_interaction(
        self,
        session_id: str,
        user_input: str,
        response: str,
        context: Optional[Dict[str, Any]] = None
    ) -> None:
        """Append one user/assistant exchange, trimming to max_history."""
        if session_id not in self.conversations:
            self.conversations[session_id] = []
        self.conversations[session_id].append({
            'timestamp': datetime.now().isoformat(),
            'user_input': user_input,
            'response': response,
            'context': context
        })
        if len(self.conversations[session_id]) > self.max_history:
            self.conversations[session_id] = self.conversations[session_id][-self.max_history:]

    def get_history(self, session_id: str) -> List[Dict[str, Any]]:
        return self.conversations.get(session_id, [])

    def clear_history(self, session_id: str) -> None:
        if session_id in self.conversations:
            del self.conversations[session_id]
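
# Each stored interaction is a plain dict, so callers can inspect the history
# directly. For illustration, one entry looks like this (values hypothetical):
#
#   {
#       'timestamp': '2024-01-01T12:00:00',
#       'user_input': 'Habt ihr vegane Pizza?',
#       'response': 'Ja, wir führen ...',
#       'context': {'search_results': {...}}
#   }
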
class ChatService:
    """Coordinates retrieval (products, PDFs, FAQs) and LLM generation."""

    def __init__(
        self,
        model_service,
        data_service,
        pdf_service,
        faq_service
    ):
        self.model = model_service.model
        self.tokenizer = model_service.tokenizer
        self.data_service = data_service
        self.pdf_service = pdf_service
        self.faq_service = faq_service
        self.conversation_manager = ConversationManager()

    async def search_all_sources(
        self,
        query: str,
        top_k: int = 3
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Search across all available data sources."""
        try:
            # Await the search calls since they're coroutines.
            products = await self.data_service.search(query, top_k)
            pdfs = await self.pdf_service.search(query, top_k)
            faqs = await self.faq_service.search_faqs(query, top_k)
            results = {
                'products': products or [],
                'documents': pdfs or [],
                'faqs': faqs or []
            }
            logger.debug("Search results: %s", results)
            return results
        except Exception as e:
            logger.error(f"Error searching sources: {e}")
            return {'products': [], 'documents': [], 'faqs': []}
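
    # The searches above run one after another. Because the three backends are
    # independent, they could also run concurrently; the variant below is a
    # sketch under that assumption, not a drop-in replacement.
    async def search_all_sources_concurrent(
        self,
        query: str,
        top_k: int = 3
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Sketch: run the three backend searches in parallel."""
        try:
            products, pdfs, faqs = await asyncio.gather(
                self.data_service.search(query, top_k),
                self.pdf_service.search(query, top_k),
                self.faq_service.search_faqs(query, top_k),
            )
            return {
                'products': products or [],
                'documents': pdfs or [],
                'faqs': faqs or []
            }
        except Exception as e:
            logger.error(f"Error searching sources concurrently: {e}")
            return {'products': [], 'documents': [], 'faqs': []}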

    def construct_system_prompt(self, context: str) -> str:
        """Constructs the system message."""
        return (
            "You are a friendly bot named Oma Erna, specializing in Bofrost products and content. Use only the context from this prompt. "
            "Return comprehensive German answers. If possible, add product IDs from the context. Do not make up information. The context is the truth. "
            "Use the following context (product descriptions and information) for answers:\n\n"
            f"{context}\n\n"
        )

    def construct_prompt(
        self,
        user_input: str,
        context: str,
        chat_history: List[Tuple[str, str]],
        max_history_turns: int = 1
    ) -> str:
        """Constructs the full prompt in the Llama 3 instruct format."""
        system_message = self.construct_system_prompt(context)
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
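
    # Hard-coding the special tokens ties this method to one model family.
    # Recent transformers tokenizers can render the prompt from the model's
    # own chat template instead; the method below is a sketch assuming the
    # tokenizer ships such a template.
    def construct_prompt_via_template(
        self,
        user_input: str,
        context: str,
        chat_history: List[Tuple[str, str]],
        max_history_turns: int = 1
    ) -> str:
        """Sketch of construct_prompt using tokenizer.apply_chat_template."""
        messages = [
            {"role": "system", "content": self.construct_system_prompt(context)}
        ]
        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": user_input})
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )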

    def build_context(
        self,
        search_results: Dict[str, List[Dict[str, Any]]],
        chat_history: List[Dict[str, Any]]
    ) -> str:
        """Build context for the model from search results and chat history."""
        context_parts = []
        # Add relevant products (limit to top 2)
        if search_results.get('products'):
            products = search_results['products'][:2]
            for product in products:
                context_parts.append(
                    f"Produkt: {product['Name']}\n"
                    f"Beschreibung: {product['Description']}\n"
                    f"Preis: {product['Price']}€\n"
                    f"Kategorie: {product['ProductCategory']}"
                )
        # Add relevant PDF content
        if search_results.get('documents'):
            docs = search_results['documents'][:2]
            for doc in docs:
                context_parts.append(
                    f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
                    f"{doc['text']}"
                )
        # Add relevant FAQs
        if search_results.get('faqs'):
            faqs = search_results['faqs'][:2]
            for faq in faqs:
                context_parts.append(
                    f"FAQ:\n"
                    f"Frage: {faq['question']}\n"
                    f"Antwort: {faq['answer']}"
                )
        # Chat history is deliberately not appended here: construct_prompt
        # already inserts the recent turns, so repeating them in the context
        # would duplicate tokens.
        context = "\n\n".join(context_parts)
        logger.debug("Built context:\n%s", context)
        return context

    async def chat(
        self,
        user_input: str,
        session_id: Any,
        max_length: int = 8000
    ) -> Tuple[str, List[Tuple[str, str]], Dict[str, List[Dict[str, Any]]]]:
        """Main chat method that coordinates the entire conversation flow."""
        try:
            if not isinstance(session_id, str):
                session_id = str(session_id)
            chat_history_raw = self.conversation_manager.get_history(session_id)
            chat_history = [
                (entry['user_input'], entry['response']) for entry in chat_history_raw
            ]
            search_results = await self.search_all_sources(user_input)
            context = self.build_context(search_results, chat_history_raw)
            prompt = self.construct_prompt(user_input, context, chat_history)
            response = self.generate_response(prompt, max_length)
            self.conversation_manager.add_interaction(
                session_id,
                user_input,
                response,
                {'search_results': search_results}
            )
            formatted_history = [
                (entry['user_input'], entry['response'])
                for entry in self.conversation_manager.get_history(session_id)
            ]
            return response, formatted_history, search_results
        except Exception as e:
            logger.error(f"Error in chat: {e}")
            raise
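
    # Example call site (sketch; the service wiring is application-specific):
    #
    #   service = ChatService(model_service, data_service, pdf_service, faq_service)
    #   response, history, results = await service.chat(
    #       "Welche veganen Produkte gibt es?", session_id="session-1"
    #   )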

    def generate_response(
        self,
        prompt: str,
        max_length: int = 1000
    ) -> str:
        """Generate a response using the language model."""
        try:
            logger.debug("Prompt:\n%s", prompt)
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096
            ).to(settings.DEVICE)
            # Note: max_length counts prompt plus generated tokens.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    num_return_sequences=1,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    no_repeat_ngram_size=3
                )
            # Strip the prompt tokens from the output. Reuse the already
            # computed input ids instead of re-encoding the prompt a second time.
            input_length = inputs["input_ids"].shape[-1]
            response = self.tokenizer.decode(
                outputs[0][input_length:], skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise
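
    # A streaming variant is often nicer for chat UIs. The sketch below uses
    # transformers' TextIteratorStreamer with generation in a background
    # thread; it assumes the same model/tokenizer setup as generate_response
    # and is an illustration, not part of the original flow.
    def generate_response_stream(self, prompt: str, max_new_tokens: int = 512):
        """Sketch: yield the response text incrementally as it is generated."""
        from threading import Thread
        from transformers import TextIteratorStreamer

        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=4096
        ).to(settings.DEVICE)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        thread = Thread(
            target=self.model.generate,
            kwargs=dict(
                **inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            ),
        )
        thread.start()
        for text_chunk in streamer:
            yield text_chunk
        thread.join()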