import os
from typing import List, Optional, Tuple, Any
from collections import OrderedDict

import pandas as pd
from loguru import logger
import pm4py
import plotly.graph_objects as go
import networkx as nx
import matplotlib.pyplot as plt
from PIL import Image
from pydantic import BaseModel


class ProcessMap(BaseModel):
    """Container bundling a process model with its activities and rendering."""

    # Model object; its type is left unconstrained (presumably a DFG, Petri net
    # or BPMN graph from pm4py — confirm against callers).
    net: Any
    # Start activities of the model, or None when not applicable.
    start_activities: Optional[List]
    # End activities of the model, or None when not applicable.
    end_activities: Optional[List]
    # Rendered image of the model (type unconstrained), or None.
    img: Optional[Any]


def dfg2networkx(dfg, start, end):
    """Convert a directly-follows graph into a networkx directed graph.

    Two sentinel nodes, ``'#Start#'`` and ``'#End#'``, are added: every start
    activity receives an edge from ``'#Start#'`` and every end activity an
    edge to ``'#End#'``. Nodes are inserted in first-seen order (sentinels,
    start activities, end activities, then remaining DFG activities).

    Argument
        dfg: iterable of ``(source, target)`` activity pairs; when it is a
            mapping keyed by such pairs (e.g. the ``dfg`` from
            ``pm4py.discover_dfg``) only the keys are used
        start: iterable of start activities (a dict contributes its keys)
        end: iterable of end activities (a dict contributes its keys)
    Return
        nx.DiGraph containing every activity plus the two sentinel nodes
    Raise
        ValueError: if a start/end activity is named like a sentinel node
    """
    PROCESS_START = '#Start#'
    PROCESS_END = '#End#'
    reserved = (PROCESS_START, PROCESS_END)

    # Materialise each input once so generators work and nothing is read twice.
    start = list(start)
    end = list(end)
    transitions = [(transition[0], transition[1]) for transition in dfg]

    # dict keeps insertion order, which fixes the node order of the graph.
    nodes = dict.fromkeys(reserved)
    for activity in start + end:
        # An activity can legitimately be both a start and an end activity
        # (e.g. single-event cases), so duplicates are merged; only a clash
        # with a sentinel name is an error.
        if activity in reserved:
            raise ValueError(f"#ERROR: {activity} exists")
        nodes.setdefault(activity, None)
    for left_activity, right_activity in transitions:
        nodes.setdefault(left_activity, None)
        nodes.setdefault(right_activity, None)

    edges = [(PROCESS_START, activity) for activity in start]
    edges += [(activity, PROCESS_END) for activity in end]
    edges += transitions

    nx_graph = nx.DiGraph()
    nx_graph.add_nodes_from(nodes)
    nx_graph.add_edges_from(edges)
    return nx_graph


def discover_process_map_variants(df, top_k: int = 0, type: str = 'dfg'):
    """Discover a directly-follows graph, optionally from the top-k variants only.

    The event log is built from the ``case_id``, ``activity`` and ``timestamp``
    columns of ``df``. The discovered graph is also displayed through
    ``pm4py.view_dfg`` as a side effect.

    Argument
        df: pandas DataFrame holding the raw event log
        top_k: when > 0, keep only the ``top_k`` most frequent variants
            before discovery; 0 keeps every case
        type: accepted for interface compatibility; currently ignored (a DFG
            is always discovered)
    Return
        (dfg, start_activities, end_activities) from ``pm4py.discover_dfg``
    """
    log = pm4py.format_dataframe(df, case_id='case_id', activity_key='activity', timestamp_key='timestamp')
    log = pm4py.filter_variants_top_k(log, k=top_k) if top_k > 0 else log

    dfg, starts, ends = pm4py.discover_dfg(log)
    pm4py.view_dfg(dfg, start_activities=starts, end_activities=ends)
    return dfg, starts, ends


def discover_process_map_activities_connections(df, activity_rank: int = 0, connection_rank: int = 0, state: Optional[dict] = None, type: str = 'dfg'):
    """Discover a directly-follows graph after filtering by connection frequency.

    The event log is built from the ``case_id``, ``activity`` and ``timestamp``
    columns of ``df``. The log is filtered with
    ``pm4py.filter_directly_follows_relation`` using the selected connections,
    and the DFG of the filtered log is displayed via ``pm4py.view_dfg``.

    Argument
        df: pandas DataFrame holding the raw event log
        activity_rank: reserved for activity-based filtering; not implemented
            yet and currently ignored
        connection_rank: when > 0, select the ``connection_rank`` most frequent
            directly-follows connections, plus any listed in
            ``state['top_variant_connections']``; 0 selects every connection
        state: optional dict; only the ``'top_variant_connections'`` key is read
        type: accepted for interface compatibility; currently ignored
    Return
        (dfg, start_activities, end_activities) discovered from the filtered log
    """
    # Avoid a shared mutable default argument.
    if state is None:
        state = {}

    event_log = pm4py.format_dataframe(df, case_id='case_id', activity_key='activity', timestamp_key='timestamp')
    full_dfg, _, _ = pm4py.discover_dfg(event_log)
    # Connections from most to least frequent; sorted() is stable, so ties
    # keep the discovery order.
    ranked_connections = sorted(full_dfg, key=full_dfg.get, reverse=True)

    # TODO: activity_rank-based filtering is not implemented yet.
    if connection_rank > 0:
        filtered_connections = ranked_connections[:connection_rank]
        # NOTE(review): connections of the top variant(s) supplied by the caller
        # are always retained, presumably so the map stays connected — confirm.
        # tuple() normalises list-encoded pairs (e.g. from JSON state).
        for connection in state.get('top_variant_connections', []):
            connection = tuple(connection)
            if connection not in filtered_connections:
                filtered_connections.append(connection)
    else:
        filtered_connections = ranked_connections

    event_log = pm4py.filter_directly_follows_relation(event_log, relations=filtered_connections)
    dfg, start_activities, end_activities = pm4py.discover_dfg(event_log)
    pm4py.view_dfg(dfg, start_activities=start_activities, end_activities=end_activities)
    return dfg, start_activities, end_activities


def _read_rendered_image(file_path):
    """Open the image at *file_path* and decode its pixel data immediately.

    ``Image.open`` is lazy and keeps the file open; ``load()`` reads the data
    now, which for a single-frame image also lets Pillow close the handle.
    """
    img = Image.open(file_path)
    img.load()
    return img


def discover_process_map(df: pd.DataFrame, type: str = 'dfg'):
    """Discover a process model from a raw event log and display it.

    The event log is built from the ``case_id``, ``activity`` and ``timestamp``
    columns of ``df``. Each model type is also shown through the matching
    ``pm4py.view_*`` call.

    Argument
        df: pandas DataFrame holding the raw event log
        type: ``'dfg'``, ``'petrinet'`` or ``'bpmn'``
    Return
        - ``'dfg'``: (dfg, start_activities, end_activities)
        - ``'petrinet'``: (net, img), where img is read back from
          ``output/petri_net.png`` after pm4py saves the visualisation there
        - ``'bpmn'``: (net, img), where img is read back from ``output/bpmn.png``
    Raise
        ValueError: if ``type`` is not one of the supported values
    """
    # Reject unsupported types before doing any discovery work.
    if type not in ('dfg', 'petrinet', 'bpmn'):
        raise ValueError(f"Invalid type: {type}")

    event_log = pm4py.format_dataframe(df, case_id='case_id', activity_key='activity', timestamp_key='timestamp')
    if type == 'dfg':
        dfg, start_activities, end_activities = pm4py.discover_dfg(event_log)
        pm4py.view_dfg(dfg, start_activities=start_activities, end_activities=end_activities)
        return dfg, start_activities, end_activities

    if type == 'petrinet':
        net, im, fm = pm4py.discover_petri_net_inductive(event_log)
        pm4py.view_petri_net(petri_net=net, initial_marking=im, final_marking=fm)
        file_path = 'output/petri_net.png'
        # Make sure the target directory exists before pm4py writes into it.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        pm4py.save_vis_petri_net(net, im, fm, file_path)
        return net, _read_rendered_image(file_path)

    # type == 'bpmn'
    net = pm4py.discover_bpmn_inductive(event_log)
    pm4py.view_bpmn(net, format='png')
    file_path = 'output/bpmn.png'
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    pm4py.save_vis_bpmn(net, file_path)
    return net, _read_rendered_image(file_path)


def view_networkx(nx_graph, layout):
    """Render a networkx graph as an interactive Plotly figure.

    Argument
        nx_graph: networkx graph; every node must be a key of ``layout``
        layout: mapping from node to an ``(x, y)`` position (e.g. the result
            of ``nx.nx_agraph.graphviz_layout``)
    Return
        plotly.graph_objects.Figure with a labelled marker per node, a grey
        line per edge and an arrow annotation from each edge's source to its
        target; grid, zero lines and axis tick labels are hidden
    """
    node_list = list(nx_graph.nodes)

    # Create node scatter plot
    node_trace = go.Scatter(
        x=[layout[n][0] for n in node_list],
        y=[layout[n][1] for n in node_list],
        text=list(node_list),
        mode='markers+text',
        hovertext=list(node_list),
        textposition='top center',
        marker=dict(size=5, color='LightSkyBlue', line=dict(width=2)),
    )

    # Collect all edge coordinates first and build the trace once; appending to
    # a plotly trace property per edge re-copies and re-validates the whole
    # array each time (quadratic). A None between segments breaks the line so
    # one trace can draw every edge.
    edge_x, edge_y, annotations = [], [], []
    for edge in nx_graph.edges:
        # edge[0]/edge[1] also works for multigraph edges that carry a key.
        x0, y0 = layout[edge[0]]
        x1, y1 = layout[edge[1]]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

        # Arrow drawn from the source position (ax, ay) to the target (x, y).
        annotations.append(
            dict(
                ax=x0,
                ay=y0,
                axref='x',
                ayref='y',
                x=x1,
                y=y1,
                xref='x',
                yref='y',
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=2,
                arrowcolor='Gray',
            )
        )

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1.5, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

    # Edges first so node markers and labels are drawn on top.
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            annotations=annotations,
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
        ),
    )
    fig = fig.update_xaxes(showticklabels=False)
    fig = fig.update_yaxes(showticklabels=False)
    return fig


def view_process_map(nx_graph, process_type: str = 'dfg', layout_type: str = 'sfdp'):
    """Lay out ``nx_graph`` with Graphviz and render it as a Plotly figure.

    Argument
        nx_graph: networkx graph to draw
        process_type: currently unused
        layout_type: Graphviz program passed as ``prog`` to
            ``nx.nx_agraph.graphviz_layout`` (default ``'sfdp'``)
    Return
        the figure produced by ``view_networkx``
    """
    positions = nx.nx_agraph.graphviz_layout(nx_graph, prog=layout_type)
    return view_networkx(nx_graph, positions)