# services/data_service.py
import io
import logging
from datetime import datetime
from typing import Any, Dict, List

import aiohttp
import faiss
import numpy as np
import pandas as pd

from config.config import settings

logger = logging.getLogger(__name__)


class DataService:
    def __init__(self, model_service):
        self.embedder = model_service.embedder
        self.cache = {}
        self.last_update = None
        self.faiss_index = None
        self.data_cleaned = None

    async def fetch_csv_data(self) -> pd.DataFrame:
        """Download the pipe-separated CSV, retrying up to MAX_RETRIES times."""
        async with aiohttp.ClientSession() as session:
            for attempt in range(settings.MAX_RETRIES):
                try:
                    async with session.get(settings.CSV_URL) as response:
                        # A non-200 status raises here, so it counts as a
                        # failed attempt instead of silently returning None.
                        response.raise_for_status()
                        content = await response.text()
                        # pandas has no StringIO of its own; use io.StringIO.
                        return pd.read_csv(io.StringIO(content), sep='|')
                except Exception as e:
                    logger.error(f"Attempt {attempt + 1} failed: {e}")
                    if attempt == settings.MAX_RETRIES - 1:
                        raise
        raise RuntimeError("fetch_csv_data exhausted all retries")

    async def prepare_data_and_index(self) -> tuple:
        current_time = datetime.now()

        # Serve from cache while it is still fresh. Use total_seconds():
        # timedelta.seconds wraps around after one day.
        if (self.last_update
                and (current_time - self.last_update).total_seconds() < settings.CACHE_DURATION
                and self.cache):
            return self.cache['data'], self.cache['index']

        data = await self.fetch_csv_data()

        # Data cleaning and preparation
        columns_to_keep = [
            'ID', 'Name', 'Description', 'Price', 'ProductCategory',
            'Grammage', 'BasePriceText', 'Rating', 'RatingCount',
            'Ingredients', 'CreationDate', 'Keywords', 'Brand'
        ]
        self.data_cleaned = data[columns_to_keep].copy()
        # Replace characters outside the expected set with spaces.
        self.data_cleaned['Description'] = self.data_cleaned['Description'].str.replace(
            r'[^\w\s.,;:\'/?!€$%&()\[\]{}<>|=+\\-]', ' ', regex=True
        )

        # Weighted text combination: the name is repeated so it carries
        # double weight in the embedding.
        self.data_cleaned['combined_text'] = self.data_cleaned.apply(
            lambda row: (
                f"{row['Name']} {row['Name']} "  # double weight for name
                f"{row['Description']} "
                f"{row['Keywords'] if pd.notnull(row['Keywords']) else ''} "
                f"{row['ProductCategory'] if pd.notnull(row['ProductCategory']) else ''}"
            ).strip(),
            axis=1
        )

        # Create the FAISS index over the combined-text embeddings.
        embeddings = self.embedder.encode(
            self.data_cleaned['combined_text'].tolist(),
            convert_to_tensor=True,
            show_progress_bar=True
        ).cpu().detach().numpy()
        # FAISS expects contiguous float32 input.
        embeddings = np.ascontiguousarray(embeddings, dtype='float32')

        d = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatL2(d)
        self.faiss_index.add(embeddings)

        # Update cache
        self.cache = {
            'data': self.data_cleaned,
            'index': self.faiss_index
        }
        self.last_update = current_time

        return self.data_cleaned, self.faiss_index

    async def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        # Build the index lazily on first use.
        if self.faiss_index is None:
            await self.prepare_data_and_index()

        query_embedding = self.embedder.encode(
            [query], convert_to_tensor=True
        ).cpu().detach().numpy()
        query_embedding = np.ascontiguousarray(query_embedding, dtype='float32')

        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        for i, idx in enumerate(indices[0]):
            product = self.data_cleaned.iloc[idx].to_dict()
            # IndexFlatL2 returns L2 distances: lower score = closer match.
            product['score'] = float(distances[0][i])
            results.append(product)

        return results
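

# --- Usage sketch ----------------------------------------------------------
# A minimal way to drive DataService end to end. The stub below is
# hypothetical (the real ModelService lives elsewhere in the app); it only
# assumes that .embedder is a sentence-transformers model, which matches the
# .encode(..., convert_to_tensor=True) calls above. The model name is an
# illustrative choice, not the project's configured one.
if __name__ == "__main__":
    import asyncio
    from sentence_transformers import SentenceTransformer

    class _StubModelService:  # hypothetical stand-in for the real ModelService
        def __init__(self):
            self.embedder = SentenceTransformer("all-MiniLM-L6-v2")

    async def _demo():
        service = DataService(_StubModelService())
        await service.prepare_data_and_index()  # fetch, clean, embed, index
        for hit in await service.search("gluten-free pasta", top_k=3):
            print(hit["Name"], hit["score"])  # lower score = closer match

    asyncio.run(_demo())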