import os
import re

import matplotlib.pyplot as plt
import networkx as nx
from cdlib import algorithms
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

from constants import DOCUMENTS


# Load real credentials from ".env" first. ".env.example" is conventionally a
# template committed alongside the code, so it should only act as a fallback.
# load_dotenv() does not override variables that are already set (override
# defaults to False), so values from ".env" win and ".env.example" only fills
# in whatever is still missing.
load_dotenv(".env")
load_dotenv(".env.example")

# Shared client for every LLM call below. If OPENAI_API_KEY is unset this
# passes None, and the SDK falls back to its own environment lookup.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def draw_graph(graph):
    """Plot ``graph`` with node names and the ``label`` attribute of each edge.

    The figure is written to ``graph.png`` (relative to the current working
    directory) and then displayed with ``plt.show()``.
    """
    layout = nx.spring_layout(graph)  # force-directed node positions
    draw_options = {
        "with_labels": True,
        "node_color": "skyblue",
        "edge_color": "gray",
        "node_size": 1500,
        "font_size": 10,
        "font_weight": "bold",
    }
    plt.figure(figsize=(12, 8))
    nx.draw(graph, layout, **draw_options)
    nx.draw_networkx_edge_labels(
        graph,
        layout,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=8,
    )
    plt.title("Graph Visualization of Extracted Entities and Relationships")
    plt.savefig("graph.png")
    plt.show()


# Source texts -> chunks
def get_chunks(documents, chunk_size=1000, overlap_size=200):
    """Split each document into overlapping fixed-size character windows.

    Consecutive windows start ``chunk_size - overlap_size`` characters apart,
    so neighbouring chunks share ``overlap_size`` characters. Chunking of a
    document stops as soon as a window reaches its end. This avoids emitting a
    trailing chunk that is only a suffix of the previous one, which would cost
    an extra LLM call for no new text.

    Args:
        documents: iterable of strings.
        chunk_size: maximum number of characters per chunk (must be > 0).
        overlap_size: number of characters shared by consecutive chunks
            (must satisfy ``0 <= overlap_size < chunk_size``).

    Returns:
        list[str]: the chunks of all documents, in document order. Empty
        documents contribute no chunks.

    Raises:
        ValueError: if the sizes would produce a non-positive step.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap_size < chunk_size:
        raise ValueError(
            "overlap_size must satisfy 0 <= overlap_size < chunk_size, "
            f"got overlap_size={overlap_size}, chunk_size={chunk_size}"
        )

    step = chunk_size - overlap_size
    chunks = []
    for doc in documents:
        for start in range(0, len(doc), step):
            end = start + chunk_size
            chunks.append(doc[start:end])
            if end >= len(doc):
                # This window already covers the tail of the document.
                break
    return chunks


# print(get_chunks(DOCUMENTS))


# Chunks -> Element instances
def extract_elements(chunks, model="gpt-4o"):
    """Ask the LLM to extract entities and relationships from each chunk.

    Args:
        chunks: sequence of text chunks (must support ``len()``).
        model: chat-completion model name; defaults to ``"gpt-4o"``.

    Returns:
        list[str]: the raw model response for each chunk. Chunks whose
        response carries no text are skipped with a printed warning.
    """
    elements = []
    for index, chunk in enumerate(chunks):
        print(f"Processing chunk {index + 1}/{len(chunks)}")
        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": "Extract entities and relationships from the following text.",
                },
                {"role": "user", "content": chunk},
            ],
        )
        entities_and_relations = response.choices[0].message.content
        print(entities_and_relations)
        if not entities_and_relations:
            # The SDK types message.content as Optional[str]. Forwarding None
            # would send an invalid message content in summarize_elements.
            print(f"Skipping chunk {index + 1}: empty model response")
            continue
        elements.append(entities_and_relations)
    return elements


# print(extract_elements(get_chunks(DOCUMENTS)))


# Element instances -> Element summaries
def summarize_elements(elements, model="gpt-4o"):
    """Rewrite each extracted element into the structure build_graph parses.

    The system prompt asks for a ``**Entities:**`` header followed by one
    entity per line, then a ``**Relationships:**`` header followed by one
    ``->`` relationship per line. These are the headers build_graph
    recognises; the original prompt never requested the entities header.

    Args:
        elements: sequence of raw extraction texts (must support ``len()``).
        model: chat-completion model name; defaults to ``"gpt-4o"``.

    Returns:
        list[str]: one structured summary per element. Elements whose
        response carries no text are skipped with a printed warning.
    """
    summaries = []
    for index, element in enumerate(elements):
        print(f"Summarizing element {index + 1}/{len(elements)}")
        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Summarize the following entities and relationships in a structured format. "
                        'Start with a line containing exactly "**Entities:**" followed by one entity per line. '
                        'Then add a line containing exactly "**Relationships:**" followed by one relationship per line, '
                        'using "->" to represent relationships.'
                    ),
                },
                {"role": "user", "content": element},
            ],
        )
        summary = response.choices[0].message.content
        print("Element summary:", summary)
        if not summary:
            # message.content is Optional[str]; build_graph needs a string.
            print(f"Skipping element {index + 1}: empty model response")
            continue
        summaries.append(summary)
    return summaries


# print(summarize_elements(extract_elements(get_chunks(DOCUMENTS))))


# Element summaries -> Graph communities
# A leading Markdown list marker: "1." / "12." (optionally followed by
# spaces), or a "-", "*" or "+" bullet followed by whitespace.
_LIST_MARKER_RE = re.compile(r"^(?:\d+\.\s*|[-*+]\s+)")
_ENTITY_HEADERS = ("### Entities:", "**Entities:**")
_RELATIONSHIP_HEADERS = ("### Relationships:", "**Relationships:**")


def _strip_list_marker(text):
    """Return ``text`` stripped of surrounding whitespace and one list marker."""
    return _LIST_MARKER_RE.sub("", text.strip(), count=1)


def build_graph(summaries):
    """Build an undirected entity graph from structured element summaries.

    Each summary is scanned line by line:

    * after an ``Entities:`` header ("### Entities:" or "**Entities:**"),
      every non-empty line becomes a node;
    * after a ``Relationships:`` header ("### Relationships:" or
      "**Relationships:**"), every line containing ``->`` becomes an edge.

    A relationship line may be ``source -> target`` or
    ``source -> relation -> target``. Any text between the first and last
    arrow is stored as the edge's ``label`` attribute ("" when absent).
    Leading list markers ("1.", "-", "*", "+") and Markdown ``**`` bold
    markers are removed from node names. This makes relationship endpoints
    match the names added from the entities section.

    Args:
        summaries: sequence of summary strings (must support ``len()``).
            Empty or ``None`` entries are skipped.

    Returns:
        networkx.Graph with entity names as nodes and a ``label`` on each edge.
    """
    G = nx.Graph()
    for index, summary in enumerate(summaries):
        print(f"Summary index {index + 1}/{len(summaries)}")
        if not summary:
            continue
        section = None  # None, "entities" or "relationships"
        for raw_line in tqdm(summary.split("\n")):
            line = raw_line.strip()
            if line.startswith(_ENTITY_HEADERS):
                section = "entities"
                continue
            if line.startswith(_RELATIONSHIP_HEADERS):
                section = "relationships"
                continue
            if not line:
                continue

            if section == "entities":
                entity = _strip_list_marker(line).replace("**", "").strip()
                if entity:
                    G.add_node(entity)
            elif section == "relationships":
                parts = [
                    part.replace("**", "").strip()
                    for part in _strip_list_marker(line).split("->")
                ]
                if len(parts) < 2:
                    continue  # not a relationship line
                source, target = parts[0], parts[-1]
                if source and target:
                    relation = " -> ".join(parts[1:-1])
                    G.add_edge(source, target, label=relation)

    return G


# Graph -> Graph communities (community detection)
def detect_communities(graph):
    """Partition ``graph`` into communities, one connected component at a time.

    Components with more than one node are split with cdlib's Leiden
    algorithm. Single-node components become their own community. If Leiden
    raises for a component, the error is printed and the whole component is
    kept as one community so that its nodes are not lost.

    Args:
        graph: undirected networkx graph.

    Returns:
        list[list]: each community as a list of node names.
    """
    communities = []
    # Materialise the components once; recomputing them inside the loop just
    # to print the total count made this quadratic in the component count.
    components = list(nx.connected_components(graph))
    for index, component in enumerate(components):
        print(f"Component index {index} of {len(components)}:")
        subgraph = graph.subgraph(component)
        if len(subgraph.nodes) > 1:
            try:
                sub_communities = algorithms.leiden(subgraph)
            except Exception as e:
                print(
                    f"Error processing community {index}: {e}; "
                    "keeping the component as a single community"
                )
                communities.append(list(subgraph.nodes))
            else:
                communities.extend(
                    list(community) for community in sub_communities.communities
                )
        else:
            communities.append(list(subgraph.nodes))
    print("Communities from detect_communities:", communities)
    return communities


# Graph communities -> Community summaries
def summarize_communities(communities, graph, model="gpt-4o"):
    """Ask the LLM to summarise each community of the entity graph.

    For every community, the induced subgraph is rendered as an
    ``Entities: ...`` line followed by a ``Relationships: ...`` line and sent
    to the model. Each edge is rendered as ``source -> label -> target``, or
    as ``source -> target`` when the edge has no (or an empty) ``label``.

    Args:
        communities: sequence of node lists, e.g. from detect_communities.
        graph: the networkx graph the communities were taken from.
        model: chat-completion model name; defaults to ``"gpt-4o"``.

    Returns:
        list[str]: stripped summaries. Communities whose response carries no
        text are skipped with a printed warning.
    """
    community_summaries = []
    for index, community in enumerate(communities):
        print(f"Summarize Community index {index+1}/{len(communities)}:")
        subgraph = graph.subgraph(community)

        relationships = []
        for source, target, data in subgraph.edges(data=True):
            # .get() instead of data["label"]: edges added without a label
            # would otherwise raise KeyError.
            relation = data.get("label", "")
            if relation:
                relationships.append(f"{source} -> {relation} -> {target}")
            else:
                relationships.append(f"{source} -> {target}")

        description = (
            "Entities: "
            + ", ".join(str(node) for node in subgraph.nodes)
            + "\nRelationships: "
            + ", ".join(relationships)
        )

        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": "Summarize the following community of entities and relationships.",
                },
                {"role": "user", "content": description},
            ],
        )
        # message.content is Optional[str]; guard before calling .strip().
        summary = (response.choices[0].message.content or "").strip()
        if not summary:
            print(f"Skipping community {index + 1}: empty model response")
            continue
        community_summaries.append(summary)
    return community_summaries


# Community Summaries → Community Answers → Global Answer
def generate_answer(community_summaries, query, model="gpt-4o"):
    """Answer ``query`` per community summary, then merge the partial answers.

    Args:
        community_summaries: sequence of summary strings (must support
            ``len()``).
        query: the user's question.
        model: chat-completion model name; defaults to ``"gpt-4o"``.

    Returns:
        str: the combined final answer ("" if the model returned no text).
    """
    intermediate_answers = []
    for index, summary in enumerate(community_summaries):
        print(f"Answering community {index+1}/{len(community_summaries)}:")
        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": "Answer the following query based on the provided summary.",
                },
                {"role": "user", "content": f"Query: {query}\nSummary: {summary}"},
            ],
        )
        answer = response.choices[0].message.content
        print("Intermediate answer:", answer)
        if answer:  # message.content is Optional[str]
            intermediate_answers.append(answer)

    # Number the answers and separate them with blank lines. Interpolating the
    # list directly would send its Python repr (brackets, quotes, escaped
    # newlines) to the model.
    joined_answers = "\n\n".join(
        f"Answer {number}:\n{answer}"
        for number, answer in enumerate(intermediate_answers, start=1)
    )
    final_response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": "Combine these answers into a final, concise response.",
            },
            {
                # Include the query so the combine step can judge relevance.
                "role": "user",
                "content": f"Query: {query}\nIntermediate answers:\n{joined_answers}",
            },
        ],
    )
    return final_response.choices[0].message.content or ""


def graphrag_pipeline(documents, query):
    """Run every GraphRAG stage over ``documents`` and answer ``query``.

    Stages, in order: chunking, element extraction, element summarisation,
    graph construction, graph drawing (``graph.png``), community detection,
    community summarisation and answer generation. Progress and intermediate
    results are printed along the way.
    """
    summaries = summarize_elements(extract_elements(get_chunks(documents)))
    graph = build_graph(summaries)
    print(f"Number of entities in the graph: {graph.number_of_nodes()}")
    draw_graph(graph)

    communities = detect_communities(graph)
    print(communities)

    return generate_answer(summarize_communities(communities, graph), query)


# Example query. Another one worth trying:
# "What are the main themes in these documents?"
query = "What factors in these articles can impact medical inflation in the UK in the short term?"

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse get_chunks or
    # build_graph) does not run the whole pipeline and its paid API calls.
    print("Query:", query)
    answer = graphrag_pipeline(DOCUMENTS, query)
    print("Answer:", answer)