# services/chat_service.py
import asyncio
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple

from config.config import settings

logger = logging.getLogger(__name__)

class ConversationManager:
    """Keep a bounded, in-memory history of chat interactions per session.

    Each session id maps to a list of interaction records, oldest first.
    Only the most recent ``max_history`` records are retained per session.
    """

    def __init__(self, max_history: int = 10):
        """Create an empty conversation store.

        Args:
            max_history: Maximum number of interactions kept per session.
                ``0`` disables retention entirely.

        Raises:
            ValueError: If ``max_history`` is negative.
        """
        if max_history < 0:
            raise ValueError("max_history must be non-negative")
        self.conversations: Dict[str, List[Dict[str, Any]]] = {}
        self.max_history = max_history

    def add_interaction(
        self,
        session_id: str,
        user_input: str,
        response: str,
        context: Optional[Dict[str, Any]] = None
    ) -> None:
        """Append one user/assistant exchange to a session's history.

        Args:
            session_id: Session the interaction belongs to; created on first use.
            user_input: The user's message.
            response: The assistant's reply.
            context: Optional extra data stored alongside the interaction.
        """
        history = self.conversations.setdefault(session_id, [])
        history.append({
            # Local, timezone-naive timestamp (datetime.now() without tz).
            'timestamp': datetime.now().isoformat(),
            'user_input': user_input,
            'response': response,
            'context': context
        })

        # Drop the oldest entries in place. A slice like [-max_history:] would
        # keep *everything* when max_history is 0, because -0 == 0.
        excess = len(history) - self.max_history
        if excess > 0:
            del history[:excess]

    def get_history(self, session_id: str) -> List[Dict[str, Any]]:
        """Return the session's interactions, oldest first.

        Returns a shallow copy, so callers cannot accidentally reorder or
        truncate the stored history. Unknown sessions yield an empty list.
        """
        return list(self.conversations.get(session_id, []))

    def clear_history(self, session_id: str) -> None:
        """Forget all interactions for ``session_id``; a no-op if it is unknown."""
        self.conversations.pop(session_id, None)

class ChatService:
    """Coordinate retrieval, prompt building and generation for chat turns."""

    def __init__(
        self,
        model_service,
        data_service,
        pdf_service,
        faq_service
    ):
        """Wire the service to its collaborators.

        Args:
            model_service: Object exposing ``model`` and ``tokenizer`` attributes.
            data_service: Product search source with an async ``search(query, top_k)``.
            pdf_service: Document search source with an async ``search(query, top_k)``.
            faq_service: FAQ source with an async ``search_faqs(query, top_k)``.
        """
        self.model = model_service.model
        self.tokenizer = model_service.tokenizer
        self.data_service = data_service
        self.pdf_service = pdf_service
        self.faq_service = faq_service
        self.conversation_manager = ConversationManager()

    async def search_all_sources(
        self,
        query: str,
        top_k: int = 3
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Search products, PDF documents and FAQs concurrently.

        Each source is isolated: if one search raises, the error is logged and
        that source contributes an empty list, while the other sources'
        results are still returned.

        Args:
            query: Free-text search query.
            top_k: Maximum number of hits requested from each source.

        Returns:
            Dict with keys ``'products'``, ``'documents'`` and ``'faqs'``.
        """
        try:
            searches = (
                ('products', self.data_service.search),
                ('documents', self.pdf_service.search),
                ('faqs', self.faq_service.search_faqs),
            )
            # return_exceptions=True: one failing source must not discard the
            # others' results or leave their tasks running unobserved.
            results = await asyncio.gather(
                *(search(query, top_k) for _, search in searches),
                return_exceptions=True,
            )
        except Exception as e:
            logger.exception(f"Error searching sources: {e}")
            return {'products': [], 'documents': [], 'faqs': []}

        combined: Dict[str, List[Dict[str, Any]]] = {}
        for (key, _), result in zip(searches, results):
            if isinstance(result, BaseException):
                logger.error("Error searching %s", key, exc_info=result)
                combined[key] = []
            else:
                combined[key] = result
        return combined

    def build_context(
        self,
        search_results: Dict[str, List[Dict[str, Any]]],
        chat_history: List[Dict[str, Any]]
    ) -> str:
        """Build the model's context text from search hits and recent chat turns.

        Uses at most two hits per source and the last three interactions.
        Sections are separated by blank lines. Products need the keys
        ``Name``/``Description``/``Price``/``ProductCategory``, documents need
        ``source``/``page``/``text``, and FAQs need ``question``/``answer``.
        """
        context_parts = []

        # Add relevant products
        if search_results.get('products'):
            products = search_results['products'][:2]  # Limit to top 2 products
            for product in products:
                context_parts.append(
                    f"Produkt: {product['Name']}\n"
                    f"Beschreibung: {product['Description']}\n"
                    f"Preis: {product['Price']}€\n"
                    f"Kategorie: {product['ProductCategory']}"
                )

        # Add relevant PDF content
        if search_results.get('documents'):
            docs = search_results['documents'][:2]
            for doc in docs:
                context_parts.append(
                    f"Aus Dokument '{doc['source']}' (Seite {doc['page']}):\n"
                    f"{doc['text']}"
                )

        # Add relevant FAQs
        if search_results.get('faqs'):
            faqs = search_results['faqs'][:2]
            for faq in faqs:
                context_parts.append(
                    f"FAQ:\n"
                    f"Frage: {faq['question']}\n"
                    f"Antwort: {faq['answer']}"
                )

        # Add recent chat history
        if chat_history:
            recent_history = chat_history[-3:]  # Last 3 interactions
            history_text = "\n".join(
                f"User: {h['user_input']}\nAssistant: {h['response']}"
                for h in recent_history
            )
            context_parts.append(f"Letzte Interaktionen:\n{history_text}")

        return "\n\n".join(context_parts)

    async def generate_response(
        self,
        prompt: str,
        max_length: int = 1000
    ) -> str:
        """Generate a completion for ``prompt`` with the language model.

        Args:
            prompt: Full prompt text; tokenized input is truncated to 4096 tokens.
            max_length: Maximum number of *newly generated* tokens. It is passed
                as ``max_new_tokens`` so that a long prompt cannot use up the
                whole generation budget.

        Returns:
            The generated text only, without the echoed prompt, stripped of
            surrounding whitespace.

        Raises:
            Exception: Any tokenizer/model error is logged and re-raised.
        """
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096
            ).to(settings.DEVICE)
            prompt_length = inputs["input_ids"].shape[-1]

            # NOTE(review): generate() is synchronous and blocks the event loop
            # for the whole generation; consider offloading it to an executor.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                no_repeat_ngram_size=3,
            )

            generated = outputs[0]
            # Decoder-only models return prompt + continuation. Drop the prompt
            # so it is neither returned nor stored (and re-fed) as history.
            config = getattr(self.model, "config", None)
            if not getattr(config, "is_encoder_decoder", False):
                generated = generated[prompt_length:]

            response = self.tokenizer.decode(
                generated,
                skip_special_tokens=True
            )

            return response.strip()

        except Exception as e:
            logger.exception(f"Error generating response: {e}")
            raise

    async def chat(
        self,
        user_input: str,
        session_id: str,
        max_length: int = 1000
    ) -> Tuple[str, Dict[str, List[Dict[str, Any]]]]:
        """Run one chat turn: retrieve, build the prompt, generate and record.

        Args:
            user_input: The user's message.
            session_id: Conversation whose history is read and extended.
            max_length: Generation budget forwarded to ``generate_response``.

        Returns:
            ``(response, search_results)``, where ``search_results`` is the dict
            returned by ``search_all_sources``.

        Raises:
            Exception: Errors from generation are logged and re-raised.
        """
        try:
            chat_history = self.conversation_manager.get_history(session_id)

            search_results = await self.search_all_sources(user_input)

            context = self.build_context(search_results, chat_history)

            prompt = (
                f"Context:\n{context}\n\n"
                f"User: {user_input}\n"
                "Assistant:"
            )

            response = await self.generate_response(prompt, max_length)

            # Record the turn only after generation succeeded.
            self.conversation_manager.add_interaction(
                session_id,
                user_input,
                response,
                {'search_results': search_results}
            )

            return response, search_results

        except Exception as e:
            logger.exception(f"Error in chat: {e}")
            raise