import json
import os
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from langchain_huggingface import HuggingFaceEndpoint

from setup.easy_imports import (
    HuggingFaceEmbeddings,
    PyPDFLoader,
    Chroma,
    ChatOpenAI,
    create_extraction_chain,
    PromptTemplate,
    RecursiveCharacterTextSplitter,
)
from setup.environment import default_model

# LangSmith tracing configuration, applied at import time and overriding any
# values already present in the environment.
# NOTE(review): LANGCHAIN_API_KEY is intentionally not set here; it is expected
# to already be present in the environment for traces to be uploaded — confirm.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "VELLA"


@dataclass
class DocumentChunk:
    """A single piece of split PDF text together with its origin.

    Instances are plain records; the values below are what
    ``DocumentSummarizer.load_and_split_document`` stores in them.
    """

    # Text of the chunk exactly as returned by the text splitter.
    content: str
    # 1-based page number the chunk was taken from.
    page_number: int
    # Identifier of the chunk (a ``uuid4`` string when built by the summarizer).
    chunk_id: str
    # Offset of the first character within the PDF's concatenated page text.
    start_char: int
    # Offset one past the last character (``start_char + len(content)``).
    end_char: int


class DocumentSummarizer:
    """Chunk PDFs, index the chunks in Chroma, and summarize them with sources.

    Every chunk produced by ``load_and_split_document`` gets a random UUID whose
    page/offset metadata is remembered in ``self.chunk_metadata`` so that
    ``get_source_context`` can look it up later.
    """

    def __init__(
        self, openai_api_key: str, model, embedding, chunk_config, system_prompt
    ):
        """Configure the summarizer.

        Args:
            openai_api_key: Key passed to ``ChatOpenAI`` when ``model`` equals
                ``default_model``.
            model: ``default_model`` (selects OpenAI ``gpt-4o-mini``) or a
                Hugging Face repo id used with ``HuggingFaceEndpoint``.
            embedding: Model name passed to ``HuggingFaceEmbeddings``.
            chunk_config: Mapping with ``"size"`` and ``"overlap"`` keys for the
                text splitter.
            system_prompt: Prompt template formatted with a single ``context``
                variable (the retrieved chunks), so it is expected to contain a
                ``{context}`` placeholder.
        """
        self.model = model
        self.system_prompt = system_prompt
        self.openai_api_key = openai_api_key
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_config["size"], chunk_overlap=chunk_config["overlap"]
        )
        # chunk_id -> {"page", "start_char", "end_char"}; filled by
        # load_and_split_document and read by get_source_context.
        self.chunk_metadata = {}

    def load_and_split_document(self, pdf_path: str) -> List[DocumentChunk]:
        """Load a PDF and split every page into chunks with position metadata.

        Args:
            pdf_path: Path of the PDF to read with ``PyPDFLoader``.

        Returns:
            The chunks of all pages in document order. ``page_number`` is
            1-based; ``start_char``/``end_char`` are offsets into the
            concatenation of this PDF's page texts.
        """
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()
        chunks = []
        char_count = 0  # total length of the preceding pages' text

        for page_index, page in enumerate(pages):
            text = page.page_content
            page_no = page.metadata.get("page")
            if page_no is None:
                # Loader recorded no page number; fall back to load order.
                page_no = page_index
            page_number = page_no + 1  # 1-based page numbering

            search_from = 0
            for chunk in self.text_splitter.split_text(text):
                # Search forward from the previous chunk's start so repeated
                # passages and overlapping chunks resolve to their own
                # occurrence instead of always to the first one on the page.
                start_char = text.find(chunk, search_from)
                if start_char == -1:
                    start_char = text.find(chunk)
                if start_char == -1:
                    # NOTE(review): the splitter normally yields substrings of
                    # the page; if not, anchor at the cursor rather than
                    # producing negative offsets.
                    start_char = search_from
                search_from = start_char + 1
                end_char = start_char + len(chunk)

                chunk_id = str(uuid.uuid4())
                doc_chunk = DocumentChunk(
                    content=chunk,
                    page_number=page_number,
                    chunk_id=chunk_id,
                    start_char=char_count + start_char,
                    end_char=char_count + end_char,
                )
                chunks.append(doc_chunk)

                # Remember where the chunk came from for get_source_context.
                self.chunk_metadata[chunk_id] = {
                    "page": doc_chunk.page_number,
                    "start_char": doc_chunk.start_char,
                    "end_char": doc_chunk.end_char,
                }

            char_count += len(text)

        return chunks

    def create_vector_store(self, chunks: List[DocumentChunk]) -> Chroma:
        """Embed ``chunks`` into a new Chroma store, one document per chunk.

        Each stored document carries ``chunk_id``, ``page``, ``start_char`` and
        ``end_char`` metadata so search hits can be traced back to the PDF.
        """
        texts = [chunk.content for chunk in chunks]
        metadatas = [
            {
                "chunk_id": chunk.chunk_id,
                "page": chunk.page_number,
                "start_char": chunk.start_char,
                "end_char": chunk.end_char,
            }
            for chunk in chunks
        ]

        vector_store = Chroma.from_texts(
            texts=texts, metadatas=metadatas, embedding=self.embeddings
        )
        return vector_store

    def _build_llm(self):
        """Return the OpenAI chat model or HF endpoint selected by ``self.model``."""
        if self.model == default_model:
            return ChatOpenAI(
                temperature=0, model_name="gpt-4o-mini", api_key=self.openai_api_key
            )
        return HuggingFaceEndpoint(
            repo_id=self.model,
            task="text-generation",
            max_new_tokens=1100,
            do_sample=False,
            huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
        )

    def generate_summary_with_sources(
        self,
        vector_store: Chroma,
        query: str = "Summarize the main points of this document",
    ) -> List[Dict]:
        """Summarize the chunks most relevant to ``query`` and attach sources.

        The 5 best-scoring chunks are joined into the prompt's ``context``. The
        LLM reply is split on blank lines into paragraphs, and paragraph ``i``
        is paired positionally with retrieved chunk ``i``. The last chunk is
        reused when there are more paragraphs than chunks.

        Args:
            vector_store: Store built by ``create_vector_store``.
            query: Retrieval query used to select the context chunks.

        Returns:
            ``[{"content": str, "source": {"page", "text", "relevance_score"}}]``;
            ``source["text"]`` is the first 200 characters of the chunk followed
            by ``"..."``. An empty list is returned (without calling the LLM)
            when retrieval finds nothing.
        """
        relevant_docs = vector_store.similarity_search_with_score(query, k=5)
        if not relevant_docs:
            # Nothing to ground a summary on, and the positional pairing below
            # would index into an empty source list.
            return []

        contexts = []
        sources = []
        for doc, score in relevant_docs:
            contexts.append(doc.page_content)
            sources.append(
                {
                    "content": doc.page_content,
                    "page": doc.metadata["page"],
                    "chunk_id": doc.metadata["chunk_id"],
                    # Raw score from similarity_search_with_score; for Chroma
                    # this is a distance (lower = more similar).
                    "relevance_score": score,
                }
            )

        prompt = PromptTemplate(
            template=self.system_prompt, input_variables=["context"]
        )
        llm = self._build_llm()
        result = llm.invoke(prompt.format(context="\n\n".join(contexts)))
        # Chat models (ChatOpenAI) return a message exposing ``.content``; plain
        # LLMs such as HuggingFaceEndpoint return the generated ``str`` itself.
        response = result if isinstance(result, str) else result.content

        # One summary item per non-empty paragraph of the reply.
        summaries = [p.strip() for p in response.split("\n\n") if p.strip()]

        structured_output = []
        for idx, summary in enumerate(summaries):
            source = sources[min(idx, len(sources) - 1)]
            structured_output.append(
                {
                    "content": summary,
                    "source": {
                        "page": source["page"],
                        "text": source["content"][:200] + "...",
                        "relevance_score": source["relevance_score"],
                    },
                }
            )

        return structured_output

    def get_source_context(
        self, chunk_id: str, window: int = 100
    ) -> Optional[Dict]:
        """Return the stored page and character span of ``chunk_id``.

        Args:
            chunk_id: An id produced by ``load_and_split_document`` on this
                instance.
            window: Currently unused; kept for backward compatibility
                (presumably meant as an extra-context size — TODO).

        Returns:
            ``{"page", "start_char", "end_char"}`` for a known id, else ``None``.
        """
        metadata = self.chunk_metadata.get(chunk_id)
        if metadata is None:
            return None

        return {
            "page": metadata["page"],
            "start_char": metadata["start_char"],
            "end_char": metadata["end_char"],
        }


def get_llm_summary_answer_by_cursor(serializer, listaPDFs):
    """Summarize a set of PDFs and return the structured summary.

    Args:
        serializer: Mapping with the keys ``"hf_embedding"``, ``"chunk_size"``,
            ``"chunk_overlap"``, ``"system_prompt"`` and ``"model"`` used to
            configure ``DocumentSummarizer``.
        listaPDFs: Iterable of PDF file paths; all of them are indexed into a
            single vector store.

    Returns:
        A list suitable for sending to the frontend, in the form::

            [
                {
                    "content": "Summary point 1...",
                    "source": {
                        "page": 1,
                        "text": "Source text...",
                        "relevance_score": 0.95,
                    },
                },
                ...
            ]

        The same data is also printed to stdout as JSON. An empty list is
        returned, without building a vector store or calling the LLM, when the
        PDFs yield no text chunks.
    """
    # By Luan
    summarizer = DocumentSummarizer(
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
        embedding=serializer["hf_embedding"],
        chunk_config={
            "size": serializer["chunk_size"],
            "overlap": serializer["chunk_overlap"],
        },
        system_prompt=serializer["system_prompt"],
        model=serializer["model"],
    )

    # Chunk every PDF. extend() keeps this linear in the total chunk count,
    # whereas `chunks = chunks + new` re-copied the whole list for each PDF.
    all_pdfs_chunks = []
    for pdf_path in listaPDFs:
        all_pdfs_chunks.extend(summarizer.load_and_split_document(pdf_path))

    if not all_pdfs_chunks:
        # Nothing to index or summarize (no PDFs, or no extractable text).
        # NOTE(review): Chroma presumably rejects an empty text list — verify.
        return []

    vector_store = summarizer.create_vector_store(all_pdfs_chunks)
    structured_summaries = summarizer.generate_summary_with_sources(vector_store)

    json_data = json.dumps(structured_summaries)
    print("\n\n")
    print(json_data)
    return structured_summaries


if __name__ == "__main__":
    import sys

    # Usage: python <this file> <serializer.json> <pdf> [<pdf> ...]
    # The JSON file must contain the keys get_llm_summary_answer_by_cursor
    # reads: "hf_embedding", "chunk_size", "chunk_overlap", "system_prompt",
    # "model". (Calling the function with no arguments always raised TypeError.)
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <serializer.json> <pdf> [<pdf> ...]")
    with open(sys.argv[1], encoding="utf-8") as config_file:
        cli_serializer = json.load(config_file)
    get_llm_summary_answer_by_cursor(cli_serializer, sys.argv[2:])