from llama_index.core import VectorStoreIndex
from llama_index.core import StorageContext
from pinecone import Pinecone, ServerlessSpec
from llama_index.vector_stores.pinecone import PineconeVectorStore
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from config import PINECONE_CONFIG
from math import ceil
import numpy as np
import logging


class IndexManager:
    """Manage a Pinecone-backed LlamaIndex vector index.

    On construction, a Pinecone client is created from ``PINECONE_CONFIG`` and a
    serverless index named ``index_name`` is created if it does not already
    exist. The class can build a ``VectorStoreIndex`` from nodes and reload it.
    It can also delete or re-tag every vector whose ``title`` metadata matches a
    given reference.
    """

    def __init__(
        self,
        index_name: str = "summarizer-semantic-index",
        dimension: int = 1536,
    ):
        """Connect to Pinecone and ensure the target index exists.

        Args:
            index_name: Name of the Pinecone index to use (created if missing).
            dimension: Vector dimension used when creating a new index and when
                generating random probe vectors for id enumeration. It must
                match the embedding model that produces the stored vectors.

        Raises:
            ValueError: If the Pinecone API key is not configured.
        """
        self.vector_index = None
        self.index_name = index_name
        # Must be set before _create_pinecone_index(), which reads it.
        self.dimension = dimension
        self.client = self._get_pinecone_client()
        self.pinecone_index = self._create_pinecone_index()

    def _get_pinecone_client(self):
        """Return a Pinecone client authenticated with the configured API key.

        Raises:
            ValueError: If ``PINECONE_CONFIG.PINECONE_API_KEY`` is empty or unset.
        """
        api_key = PINECONE_CONFIG.PINECONE_API_KEY
        if not api_key:
            raise ValueError(
                "Pinecone API key is missing. Please set it in environment variables."
            )
        return Pinecone(api_key=api_key)

    def _create_pinecone_index(self):
        """Create the Pinecone index if it doesn't exist and return a handle to it.

        A newly created index is serverless (AWS ``us-east-1``), uses the cosine
        metric and has ``self.dimension`` dimensions. An existing index is used
        as-is; its dimension is not checked here.
        """
        if self.index_name not in self.client.list_indexes().names():
            self.client.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric="cosine",
                spec=ServerlessSpec(cloud="aws", region="us-east-1"),
            )
        return self.client.Index(self.index_name)

    def _initialize_vector_store(self) -> "StorageContext":
        """Wrap the Pinecone index in a LlamaIndex ``StorageContext``."""
        vector_store = PineconeVectorStore(pinecone_index=self.pinecone_index)
        return StorageContext.from_defaults(vector_store=vector_store)

    def build_indexes(self, nodes):
        """Build the Pinecone-backed vector index from ``nodes``.

        The resulting ``VectorStoreIndex`` is stored on ``self.vector_index``
        with index id ``"vector"``.

        Returns:
            None on success. If building fails with anything other than
            ``HTTPException``, a 500 ``JSONResponse``.

        Raises:
            HTTPException: Re-raised unchanged so FastAPI can handle it.
        """
        try:
            storage_context = self._initialize_vector_store()
            self.vector_index = VectorStoreIndex(nodes, storage_context=storage_context)
            self.vector_index.set_index_id("vector")
        except HTTPException:
            # Let FastAPI's own HTTP errors propagate untouched.
            raise
        except Exception as e:
            logging.exception("Failed to build vector index %r", self.index_name)
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content=f"Error building indexes: {str(e)}",
            )

    def get_ids_from_query(self, input_vector, title, top_k: int = 10000):
        """Return ids of up to ``top_k`` vectors whose ``title`` metadata equals ``title``.

        A single filtered similarity query is issued. When more than ``top_k``
        vectors share the title, ``input_vector`` decides which of them are
        returned.

        Args:
            input_vector: Query vector of length equal to the index dimension.
            title: Metadata value to match exactly (formatted as a string).
            top_k: Maximum number of matches to request. The default of 10000 is
                presumably Pinecone's per-query maximum; confirm against the
                Pinecone docs before raising it.

        Returns:
            A (possibly empty) set of vector ids.
        """
        logging.debug("Searching Pinecone for ids with title %r", title)
        results = self.pinecone_index.query(
            vector=input_vector,
            top_k=top_k,
            filter={"title": {"$eq": f"{title}"}},
        )
        return {match["id"] for match in results["matches"]}

    def get_all_ids_from_index(self, title, top_k: int = 10000):
        """Collect the ids of all vectors whose ``title`` metadata equals ``title``.

        Each query returns at most ``top_k`` matches. When a query comes back
        full there may be more, so further queries are issued with fresh random
        probe vectors, each surfacing a different nearest-neighbour subset.
        This continues until a query is not full or adds no new ids.

        The enumeration is best-effort. For titles with more than ``top_k``
        vectors, an id that no probe ranks in its top ``top_k`` can be missed.

        Returns:
            A set of vector ids (empty if nothing matches).
        """
        all_ids = set()
        while True:
            probe = np.random.rand(self.dimension).tolist()
            ids = self.get_ids_from_query(probe, title, top_k=top_k)
            new_ids = ids - all_ids
            all_ids |= ids
            # A non-full page presumably means the filter matched fewer than
            # top_k vectors, so everything has been seen. An unproductive round
            # means further random probes are unlikely to uncover more ids.
            if len(ids) < top_k or not new_ids:
                return all_ids

    def delete_vector_database(self, old_reference, batch_size: int = 1000):
        """Delete every vector whose ``title`` metadata equals ``old_reference['title']``.

        Args:
            old_reference: Mapping that contains a ``"title"`` key.
            batch_size: Number of ids sent per ``delete`` call. Must be positive.

        Returns:
            None on success. If any step fails (including a missing ``"title"``
            key), the error is logged and a 500 ``JSONResponse`` is returned.
        """
        try:
            all_ids = list(self.get_all_ids_from_index(old_reference["title"]))
            for start in range(0, len(all_ids), batch_size):
                batch_ids = all_ids[start:start + batch_size]
                self.pinecone_index.delete(ids=batch_ids)
                logging.info(
                    "delete from id %d to %d successful",
                    start,
                    start + len(batch_ids),
                )
        except Exception:
            logging.exception(
                "Failed to delete vectors for reference %r", old_reference
            )
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content="An error occurred while delete metadata",
            )

    def update_vector_database(self, old_reference, new_reference):
        """Set ``new_reference`` as metadata on every vector titled ``old_reference['title']``.

        One Pinecone ``update`` call is made per matching vector, passing
        ``new_reference`` as ``set_metadata``. Exceptions propagate to the
        caller.
        """
        for vector_id in self.get_all_ids_from_index(old_reference["title"]):
            self.pinecone_index.update(id=vector_id, set_metadata=new_reference)

    def load_existing_indexes(self):
        """Return a ``VectorStoreIndex`` backed by the existing Pinecone index.

        The index handle created in ``__init__`` is reused rather than
        constructing a new Pinecone client on every call.

        Returns:
            The ``VectorStoreIndex``. On failure, the error is logged and a 500
            ``JSONResponse`` is returned.
        """
        try:
            vector_store = PineconeVectorStore(pinecone_index=self.pinecone_index)
            return VectorStoreIndex.from_vector_store(vector_store)
        except Exception as e:
            logging.exception("Failed to load existing index %r", self.index_name)
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content=f"Error loading existing indexes: {str(e)}",
            )