from typing import List, Dict, Any, Optional
import aiohttp
from bs4 import BeautifulSoup
import faiss
import logging
from config.config import settings
import asyncio
from urllib.parse import urljoin

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

class FAQService:
    """Crawl FAQ pages under ``base_url``, embed them and serve similarity search.

    The service keeps two structures that must always describe the same
    entries: ``faq_data`` (one dict per FAQ with ``question``, ``answer`` and
    ``source``) and ``faiss_index`` (one vector per entry, same order).
    """

    def __init__(self, model_service):
        """Store the embedder taken from *model_service* and start empty.

        Args:
            model_service: Object exposing an ``embedder`` attribute whose
                ``encode`` method is used to embed FAQ texts and queries.
        """
        self.embedder = model_service.embedder
        self.faiss_index = None
        self.faq_data = []
        # URLs seen by the most recent crawl started via fetch_faq_pages().
        self.visited_urls = set()
        self.base_url = "https://www.bofrost.de/faq/"

    async def fetch_faq_pages(self) -> List[Dict[str, Any]]:
        """Crawl the FAQ section starting at ``base_url``.

        Every call starts from a fresh visited set, so a crawl can be repeated
        (re-indexing, or retrying after a failed attempt) instead of returning
        nothing because ``base_url`` was already marked as visited.

        Returns:
            The non-empty page dicts produced by ``parse_faq_content``; an
            empty list if the crawl raised.
        """
        visited = set()
        self.visited_urls = visited

        async with aiohttp.ClientSession() as session:
            try:
                # Start with the main FAQ page
                pages = await self.crawl_faq_pages(self.base_url, session, visited)
                return [page for page in pages if page]
            except Exception as e:
                logger.error(f"Error fetching FAQ pages: {e}")
                return []

    async def crawl_faq_pages(
        self,
        url: str,
        session: aiohttp.ClientSession,
        visited: Optional[set] = None,
    ) -> List[Dict[str, Any]]:
        """Fetch *url*, parse its FAQs and recursively follow in-scope links.

        Args:
            url: Page to fetch. Any ``#fragment`` is ignored because it
                addresses the same document.
            session: Open aiohttp session used for all requests.
            visited: Set of URLs already handled by this crawl; it is updated
                in place. Defaults to ``self.visited_urls``.

        Returns:
            Page dicts for this page and every page reachable from it that
            lies under ``base_url``. Errors are logged and yield whatever was
            collected so far (best effort).
        """
        if visited is None:
            visited = self.visited_urls

        # A fragment only selects a position inside the same document.
        url = url.split('#', 1)[0]
        if url in visited or not url.startswith(self.base_url):
            return []

        # Mark before the first await so concurrent sibling tasks skip this URL.
        visited.add(url)
        pages = []

        try:
            async with session.get(url, timeout=settings.TIMEOUT) as response:
                if response.status != 200:
                    logger.warning(f"Skipping FAQ page {url}: HTTP {response.status}")
                    return pages
                content = await response.text()

            soup = BeautifulSoup(content, 'html.parser')

            # Add current page content
            page_content = await self.parse_faq_content(soup, url)
            if page_content:
                pages.append(page_content)

            # Collect FAQ links once each, preserving document order.
            targets = {}
            for link in soup.find_all('a', href=True):
                full_url = urljoin(url, link['href']).split('#', 1)[0]
                if full_url.startswith(self.base_url) and full_url not in visited:
                    targets[full_url] = None

            if targets:
                results = await asyncio.gather(
                    *(self.crawl_faq_pages(target, session, visited) for target in targets)
                )
                for result in results:
                    pages.extend(result)

        except Exception as e:
            logger.error(f"Error crawling FAQ page {url}: {e}")

        return pages

    @staticmethod
    def _extract_answer_text(container) -> str:
        """Return the text of all ``<p>``/``<li>`` blocks inside *container*.

        Each block contributes only the strings whose nearest ``<p>``/``<li>``
        ancestor is that block, so nested markup such as ``<li><p>..</p></li>``
        is not emitted twice. Whitespace is collapsed instead of stripping
        every fragment, which would glue words across inline tags.
        """
        parts = []
        for block in container.find_all(['p', 'li']):
            own_strings = [
                s for s in block.strings
                if s.find_parent(['p', 'li']) is block
            ]
            text = ' '.join(''.join(own_strings).split())
            if text:
                parts.append(text)
        return ' '.join(parts)

    async def parse_faq_content(self, soup: BeautifulSoup, url: str) -> Optional[Dict[str, Any]]:
        """Extract question/answer pairs from one parsed page.

        Args:
            soup: Parsed HTML of the page.
            url: Address of the page, stored in the result and used in logs.

        Returns:
            ``{"url": url, "faqs": [{"question": ..., "answer": ...}, ...]}``
            when at least one complete pair was found, otherwise ``None``
            (also when parsing raised; the error is logged).
        """
        try:
            faqs = []

            for item in soup.find_all('div', class_='faq-item'):
                # Extract question
                question_elem = item.find('a', class_='headline-collapse')
                if not question_elem:
                    continue

                question = question_elem.find('span')
                if not question:
                    continue

                question_text = question.text.strip()

                # Extract answer
                content_elem = item.find('div', class_='content-collapse')
                if not content_elem:
                    continue

                wysiwyg = content_elem.find('div', class_='wysiwyg-content')
                if not wysiwyg:
                    continue

                answer_text = self._extract_answer_text(wysiwyg)

                if question_text and answer_text:
                    faqs.append({
                        "question": question_text,
                        "answer": answer_text
                    })

            if faqs:
                return {
                    "url": url,
                    "faqs": faqs
                }

        except Exception as e:
            logger.error(f"Error parsing FAQ content from {url}: {e}")

        return None

    async def index_faqs(self):
        """Crawl the FAQ pages and rebuild ``faq_data`` and ``faiss_index``.

        The new data and index are built in local variables and swapped in
        together, so a failure while embedding leaves the previous state
        untouched. If the crawl yields no FAQ content, both are reset so they
        stay consistent and the next search attempts to index again.
        """
        faq_pages = await self.fetch_faq_pages()

        faq_data = []
        all_texts = []

        for faq_page in faq_pages:
            for item in faq_page['faqs']:
                # Combine question and answer for better semantic search
                all_texts.append(f"{item['question']} {item['answer']}")
                faq_data.append({
                    "question": item['question'],
                    "answer": item['answer'],
                    "source": faq_page['url']
                })

        if not all_texts:
            logger.warning("No FAQ content found to index")
            self.faq_data = []
            self.faiss_index = None
            return

        # Create embeddings and index them
        embeddings = self.embedder.encode(all_texts, convert_to_tensor=True).cpu().detach().numpy()
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)

        # Publish both together: positions in the index map to faq_data entries.
        self.faq_data = faq_data
        self.faiss_index = index

    async def search_faqs(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return the FAQ entries closest to *query*.

        Builds the index on first use.

        Args:
            query: Text to search for.
            top_k: Maximum number of results; values ``<= 0`` yield ``[]``.

        Returns:
            Copies of the matching ``faq_data`` entries, each with an added
            ``"score"`` holding the distance reported by the index (an L2
            index, so smaller means closer), in the order the index returned
            them.
        """
        if self.faiss_index is None:
            await self.index_faqs()

        if self.faiss_index is None or not self.faq_data:
            logger.warning("No FAQ data available for search")
            return []

        if top_k <= 0:
            return []

        # Never ask for more neighbours than there are indexed entries.
        k = min(top_k, len(self.faq_data))

        query_embedding = self.embedder.encode([query], convert_to_tensor=True).cpu().detach().numpy()
        distances, indices = self.faiss_index.search(query_embedding, k)

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            idx = int(idx)
            # FAISS uses -1 for "no neighbour"; a plain `< len` check would
            # let -1 through and silently return the last entry.
            if 0 <= idx < len(self.faq_data):
                result = self.faq_data[idx].copy()
                result["score"] = float(distance)
                results.append(result)

        return results