#semantic_analysis.py
import streamlit as st
import spacy
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter

# No spaCy model is loaded at import time; callers pass an `nlp` pipeline to perform_semantic_analysis().

# Fill colour for each Universal POS tag (the value of spaCy's `token.pos_`).
# visualize_semantic_relations uses it for node colours and legend swatches.
# Tags not listed here (e.g. PUNCT, SPACE) fall back to '#CCCCCC' at the
# lookup site.
POS_COLORS = {
    'ADJ': '#FFA07A',    # Light Salmon
    'ADP': '#98FB98',    # Pale Green
    'ADV': '#87CEFA',    # Light Sky Blue
    'AUX': '#DDA0DD',    # Plum
    'CCONJ': '#F0E68C',  # Khaki
    'DET': '#FFB6C1',    # Light Pink
    'INTJ': '#FF6347',   # Tomato
    'NOUN': '#90EE90',   # Light Green
    'NUM': '#FAFAD2',    # Light Goldenrod Yellow
    'PART': '#D3D3D3',   # Light Gray
    'PRON': '#FFA500',   # Orange
    'PROPN': '#20B2AA',  # Light Sea Green
    'SCONJ': '#DEB887',  # Burlywood
    'SYM': '#7B68EE',    # Medium Slate Blue
    'VERB': '#FF69B4',   # Hot Pink
    'X': '#A9A9A9',      # Dark Gray
}

# Human-readable name of each POS tag, per UI language ('es', 'en', 'fr').
# visualize_semantic_relations looks names up as
# POS_TRANSLATIONS[lang].get(tag, tag) for its legend, so a tag missing
# from a language's table is shown verbatim.
POS_TRANSLATIONS = {
    'es': {
        'ADJ': 'Adjetivo',
        'ADP': 'Adposición',
        'ADV': 'Adverbio',
        'AUX': 'Auxiliar',
        'CCONJ': 'Conjunción Coordinante',
        'DET': 'Determinante',
        'INTJ': 'Interjección',
        'NOUN': 'Sustantivo',
        'NUM': 'Número',
        'PART': 'Partícula',
        'PRON': 'Pronombre',
        'PROPN': 'Nombre Propio',
        'SCONJ': 'Conjunción Subordinante',
        'SYM': 'Símbolo',
        'VERB': 'Verbo',
        'X': 'Otro',
    },
    'en': {
        'ADJ': 'Adjective',
        'ADP': 'Adposition',
        'ADV': 'Adverb',
        'AUX': 'Auxiliary',
        'CCONJ': 'Coordinating Conjunction',
        'DET': 'Determiner',
        'INTJ': 'Interjection',
        'NOUN': 'Noun',
        'NUM': 'Number',
        'PART': 'Particle',
        'PRON': 'Pronoun',
        'PROPN': 'Proper Noun',
        'SCONJ': 'Subordinating Conjunction',
        'SYM': 'Symbol',
        'VERB': 'Verb',
        'X': 'Other',
    },
    'fr': {
        'ADJ': 'Adjectif',
        'ADP': 'Adposition',
        'ADV': 'Adverbe',
        'AUX': 'Auxiliaire',
        'CCONJ': 'Conjonction de Coordination',
        'DET': 'Déterminant',
        'INTJ': 'Interjection',
        'NOUN': 'Nom',
        'NUM': 'Nombre',
        'PART': 'Particule',
        'PRON': 'Pronom',
        'PROPN': 'Nom Propre',
        'SCONJ': 'Conjonction de Subordination',
        'SYM': 'Symbole',
        'VERB': 'Verbe',
        'X': 'Autre',
    }
}
########################################################################################################################################

# Entity display categories and their colours, for each language.
# Key ORDER matters: extract_entities picks categories by position
# (0 = people, 1 = concepts / catch-all, 2 = places, 3 = dates), so every
# language must list its four keys in this same order.
# The colours are matplotlib colour names used for the nodes and the legend
# in visualize_context_graph.
ENTITY_LABELS = {
    'es': {
        "Personas": "lightblue",
        "Conceptos": "lightgreen",
        "Lugares": "lightcoral",
        "Fechas": "lightyellow"
    },
    'en': {
        "People": "lightblue",
        "Concepts": "lightgreen",
        "Places": "lightcoral",
        "Dates": "lightyellow"
    },
    'fr': {
        "Personnes": "lightblue",
        "Concepts": "lightgreen",
        "Lieux": "lightcoral",
        "Dates": "lightyellow"
    }
}

#########################################################################################################
def count_pos(doc):
    """Return a ``Counter`` of POS tags over ``doc``, skipping punctuation.

    Args:
        doc: an iterable of tokens exposing ``.pos_`` (e.g. a spaCy ``Doc``).

    Returns:
        Counter mapping each tag other than 'PUNCT' to how many tokens carry it.
    """
    tally = Counter()
    for tok in doc:
        if tok.pos_ == 'PUNCT':
            continue  # punctuation is deliberately excluded from the tally
        tally[tok.pos_] += 1
    return tally

import spacy
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter

# POS_COLORS and POS_TRANSLATIONS are defined near the top of this module;
# the imports just above repeat ones already made there.

#############################################################################################################################
def extract_entities(doc, lang):
    """Group the named entities of ``doc`` into the display categories of ``lang``.

    The category names are the keys of ``ENTITY_LABELS[lang]`` and are chosen
    by position: key 0 = people, 1 = concepts (catch-all), 2 = places,
    3 = dates.

    Args:
        doc: a processed spaCy ``Doc`` (anything whose ``.ents`` items expose
            ``.text`` and ``.label_``).
        lang: language code; must be a key of ``ENTITY_LABELS``.

    Returns:
        dict mapping every category name (in ``ENTITY_LABELS[lang]`` order) to
        the list of entity texts in document order, duplicates included.
        Categories without entities map to an empty list.

    Raises:
        KeyError: if ``lang`` is not a key of ``ENTITY_LABELS``.
    """
    categories = list(ENTITY_LABELS[lang])
    people, concepts, places, dates = categories[0], categories[1], categories[2], categories[3]

    # spaCy NER label -> display category, built once instead of per entity.
    # English pipelines tag people as PERSON, while the Spanish and French
    # pipelines use PER; matching only PERSON sent every es/fr person to
    # "concepts". Unlisted labels (ORG, MISC, NORP, EVENT, ...) fall into the
    # catch-all concepts category.
    category_for_label = {
        "PERSON": people,
        "PER": people,
        "LOC": places,
        "GPE": places,
        "DATE": dates,
    }

    entities = {category: [] for category in categories}
    for ent in doc.ents:
        entities[category_for_label.get(ent.label_, concepts)].append(ent.text)
    return entities

#####################################################################################################################

def visualize_context_graph(doc, lang):
    """Draw a sentence co-occurrence graph of the named entities in ``doc``.

    Each distinct entity text becomes a node coloured by its category
    (from ``extract_entities`` / ``ENTITY_LABELS[lang]``). Two different
    entities are joined by an edge when they appear in the same sentence.

    Args:
        doc: a processed spaCy ``Doc``; ``doc.sents`` is iterated, so sentence
            boundaries must be available.
        lang: language code; must be a key of ``ENTITY_LABELS``. It selects
            the category names and the title ('es', 'en', anything else gets
            the French title).

    Returns:
        The ``matplotlib.figure.Figure`` holding the drawing. Returning the
        figure itself, rather than the ``pyplot`` module, keeps it distinct
        from figures created later (e.g. by the relations graph), so the
        caller can render it with ``fig.savefig`` / ``st.pyplot(fig)``. The
        figure stays registered with pyplot; call ``plt.close(fig)`` after
        rendering to free it.

    Raises:
        KeyError: if ``lang`` is not a key of ``ENTITY_LABELS``.
    """
    G = nx.Graph()
    entities = extract_entities(doc, lang)
    color_map = ENTITY_LABELS[lang]

    # One node per distinct entity text. If the same text occurs under several
    # categories, re-adding the node overwrites its attribute, so the last
    # category in dict order wins.
    for category, items in entities.items():
        for item in items:
            G.add_node(item, category=category)

    # Connect every pair of entities sharing a sentence. Identical texts are
    # skipped so repeated mentions do not create self-loops.
    for sent in doc.sents:
        names = [ent.text for ent in sent.ents if ent.text in G]
        for i, first in enumerate(names):
            for second in names[i + 1:]:
                if first != second:
                    G.add_edge(first, second)

    fig = plt.figure(figsize=(30, 22))
    # Full-figure axes, mirroring what nx.draw adds to an empty figure.
    ax = fig.add_axes((0, 0, 1, 1))
    pos = nx.spring_layout(G, k=0.7, iterations=50)

    node_colors = [color_map[G.nodes[node]['category']] for node in G.nodes()]

    # No arrowsize: G is undirected, so no arrowheads are drawn and recent
    # networkx versions warn about the unused keyword.
    nx.draw(G, pos, ax=ax, node_color=node_colors, with_labels=True,
            node_size=10000,
            font_size=18,
            font_weight='bold',
            width=2)

    legend_elements = [plt.Rectangle((0, 0), 1, 1, fc=color, edgecolor='none', label=category)
                       for category, color in color_map.items()]
    ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1, 1), fontsize=16)

    titles = {
        'es': "Análisis del Contexto",
        'en': "Context Analysis",
    }
    ax.set_title(titles.get(lang, "Analyse du Contexte"), fontsize=24)
    ax.axis('off')

    return fig

############################################################################################################################################

def visualize_semantic_relations(doc, lang):
    """Draw the dependency links among the 20 most frequent words of ``doc``.

    Frequency is counted case-insensitively over all tokens except
    punctuation and whitespace. Every token whose lower-cased text is one of
    the top 20 becomes a node, keyed by its original text and coloured by
    its POS tag. An edge labelled with the dependency relation joins a token
    to its syntactic head when both are top words.

    Args:
        doc: a processed spaCy ``Doc`` with POS tags and a dependency parse.
        lang: 'es', 'en' or 'fr'; selects the title and legend translations.

    Returns:
        The ``matplotlib.figure.Figure`` holding the drawing. Returning the
        figure itself, rather than the ``pyplot`` module, keeps it distinct
        from any other figure, so the caller can render it with
        ``fig.savefig`` / ``st.pyplot(fig)``. The figure stays registered
        with pyplot; call ``plt.close(fig)`` after rendering to free it.

    Raises:
        KeyError: if ``lang`` has no title or POS translation table.
    """
    G = nx.Graph()
    word_freq = Counter(token.text.lower() for token in doc if token.pos_ not in ['PUNCT', 'SPACE'])
    # A set, because it is only used for membership tests below.
    top_words = {word for word, _ in word_freq.most_common(20)}

    # NOTE(review): nodes are keyed by the original text while frequency is
    # case-insensitive, so e.g. "El" and "el" become separate nodes.
    for token in doc:
        if token.text.lower() in top_words:
            G.add_node(token.text, pos=token.pos_)

    for token in doc:
        head = token.head
        # spaCy makes a sentence root its own head. Skipping same-text pairs
        # avoids "ROOT" self-loops cluttering the drawing.
        if token.text == head.text:
            continue
        if token.text.lower() in top_words and head.text.lower() in top_words:
            G.add_edge(token.text, head.text, label=token.dep_)

    fig = plt.figure(figsize=(36, 27))
    # Full-figure axes, mirroring what nx.draw adds to an empty figure.
    ax = fig.add_axes((0, 0, 1, 1))
    layout = nx.spring_layout(G, k=0.7, iterations=50)

    node_colors = [POS_COLORS.get(G.nodes[node]['pos'], '#CCCCCC') for node in G.nodes()]

    # NOTE(review): G is undirected, so networkx draws no arrowheads despite
    # arrows=True; switch to nx.DiGraph if token->head direction should show.
    nx.draw(G, layout, ax=ax, node_color=node_colors, with_labels=True,
            node_size=10000,
            font_size=16,
            font_weight='bold',
            arrows=True,
            arrowsize=30,
            width=3,
            edge_color='gray')

    edge_labels = nx.get_edge_attributes(G, 'label')
    nx.draw_networkx_edge_labels(G, layout, edge_labels=edge_labels, font_size=14, ax=ax)

    title = {
        'es': "Relaciones Semánticas Relevantes",
        'en': "Relevant Semantic Relations",
        'fr': "Relations Sémantiques Pertinentes"
    }
    ax.set_title(title[lang], fontsize=24, fontweight='bold')
    ax.axis('off')

    # Sorted so the legend order is stable across runs (set iteration order
    # depends on string hashing).
    tags = sorted(set(nx.get_node_attributes(G, 'pos').values()))
    legend_elements = [plt.Rectangle((0, 0), 1, 1, facecolor=POS_COLORS.get(tag, '#CCCCCC'), edgecolor='none',
                                     label=f"{POS_TRANSLATIONS[lang].get(tag, tag)}")
                       for tag in tags]
    ax.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=16)

    return fig

    
############################################################################################################################################
def perform_semantic_analysis(text, nlp, lang):
    """Run ``nlp`` on ``text`` and build both semantic visualizations.

    Args:
        text: the raw text to analyse.
        nlp: a loaded spaCy pipeline (any callable returning a ``Doc``).
        lang: language code forwarded to both visualizers.

    Returns:
        A 2-tuple: the result of ``visualize_context_graph(doc, lang)``
        followed by the result of ``visualize_semantic_relations(doc, lang)``.
    """
    doc = nlp(text)

    # Debug output: dump every recognised entity and its label to stdout.
    print(f"Entidades encontradas ({lang}):")
    for entity in doc.ents:
        print(f"{entity.text} - {entity.label_}")

    # Tuple elements evaluate left to right, so the context graph is built first.
    return (
        visualize_context_graph(doc, lang),
        visualize_semantic_relations(doc, lang),
    )