#syntax_analysis.py
import html
from collections import Counter

import matplotlib.pyplot as plt
import networkx as nx
import spacy
import streamlit as st

# No spaCy model is loaded at module level; callers pass a loaded `nlp` pipeline to visualize_syntax().

# Background colour for each Universal POS tag, used for graph nodes, the
# legend and repeated-word highlighting. Lookups elsewhere in this module
# fall back to '#FFFFFF' for tags missing from this table (e.g. PUNCT).
POS_COLORS = dict(
    ADJ='#FFA07A',    # Light Salmon
    ADP='#98FB98',    # Pale Green
    ADV='#87CEFA',    # Light Sky Blue
    AUX='#DDA0DD',    # Plum
    CCONJ='#F0E68C',  # Khaki
    DET='#FFB6C1',    # Light Pink
    INTJ='#FF6347',   # Tomato
    NOUN='#90EE90',   # Light Green
    NUM='#FAFAD2',    # Light Goldenrod Yellow
    PART='#D3D3D3',   # Light Gray
    PRON='#FFA500',   # Orange
    PROPN='#20B2AA',  # Light Sea Green
    SCONJ='#DEB887',  # Burlywood
    SYM='#7B68EE',    # Medium Slate Blue
    VERB='#FF69B4',   # Hot Pink
    X='#A9A9A9',      # Dark Gray
)

# Display names for every Universal POS tag, keyed first by UI language code
# and then by tag. Each language lists the same 16 tags, in the same order
# as POS_COLORS.
POS_TRANSLATIONS = {
    # Spanish
    'es': dict(
        ADJ='Adjetivo',
        ADP='Adposición',
        ADV='Adverbio',
        AUX='Auxiliar',
        CCONJ='Conjunción Coordinante',
        DET='Determinante',
        INTJ='Interjección',
        NOUN='Sustantivo',
        NUM='Número',
        PART='Partícula',
        PRON='Pronombre',
        PROPN='Nombre Propio',
        SCONJ='Conjunción Subordinante',
        SYM='Símbolo',
        VERB='Verbo',
        X='Otro',
    ),
    # English
    'en': dict(
        ADJ='Adjective',
        ADP='Adposition',
        ADV='Adverb',
        AUX='Auxiliary',
        CCONJ='Coordinating Conjunction',
        DET='Determiner',
        INTJ='Interjection',
        NOUN='Noun',
        NUM='Number',
        PART='Particle',
        PRON='Pronoun',
        PROPN='Proper Noun',
        SCONJ='Subordinating Conjunction',
        SYM='Symbol',
        VERB='Verb',
        X='Other',
    ),
    # French
    'fr': dict(
        ADJ='Adjectif',
        ADP='Adposition',
        ADV='Adverbe',
        AUX='Auxiliaire',
        CCONJ='Conjonction de Coordination',
        DET='Déterminant',
        INTJ='Interjection',
        NOUN='Nom',
        NUM='Nombre',
        PART='Particule',
        PRON='Pronom',
        PROPN='Nom Propre',
        SCONJ='Conjonction de Subordination',
        SYM='Symbole',
        VERB='Verbe',
        X='Autre',
    ),
}

def count_pos(doc):
    """Tally part-of-speech tags in *doc*, ignoring punctuation.

    Args:
        doc: An iterable of tokens exposing a ``pos_`` string (e.g. a spaCy Doc).

    Returns:
        A ``Counter`` mapping each tag other than ``'PUNCT'`` to its frequency.
    """
    tags = (tok.pos_ for tok in doc)
    return Counter(tag for tag in tags if tag != 'PUNCT')

def create_syntax_graph(doc, lang):
    """Build a directed dependency graph with one node per distinct word.

    Words are merged case-insensitively and punctuation tokens are skipped.
    Each edge points from a token's syntactic head to the token and carries
    the dependency label (``token.dep_``).

    Args:
        doc: A parsed spaCy ``Doc``; tokens must expose ``text``, ``pos_``,
            ``dep_`` and ``head``.
        lang: UI language code used to translate POS tags in node labels.
            Codes missing from ``POS_TRANSLATIONS`` fall back to the raw tag.

    Returns:
        ``(G, word_colors)``:

        - ``G`` is a ``networkx.DiGraph`` whose nodes carry ``label``,
          ``pos``, ``size`` and ``color`` attributes.
        - ``word_colors`` maps each lower-cased word to its node colour.
    """
    G = nx.DiGraph()
    pos_counts = count_pos(doc)
    translations = POS_TRANSLATIONS.get(lang, {})
    word_nodes = {}
    word_colors = {}
    tokens = [token for token in doc if token.pos_ != 'PUNCT']

    # Pass 1: create every node before adding any edge. A dependent often
    # precedes its head (determiner -> noun, subject -> verb). Creating edges
    # in the same pass would drop every edge whose head had not been seen yet.
    for token in tokens:
        lower_text = token.text.lower()
        if lower_text in word_nodes:
            continue
        node_id = len(word_nodes)
        word_nodes[lower_text] = node_id
        color = POS_COLORS.get(token.pos_, '#FFFFFF')
        word_colors[lower_text] = color
        G.add_node(node_id,
                   label=f"{token.text}\n[{translations.get(token.pos_, token.pos_)}]",
                   pos=token.pos_,
                   # Node size scales with how often this POS occurs in the doc.
                   size=pos_counts[token.pos_] * 500,
                   color=color)

    # Pass 2: connect each non-root token to its (non-punctuation) head.
    for token in tokens:
        if token.dep_ == "ROOT" or token.head.pos_ == 'PUNCT':
            continue
        head_id = word_nodes.get(token.head.text.lower())
        if head_id is not None:
            G.add_edge(head_id, word_nodes[token.text.lower()], label=token.dep_)

    return G, word_colors

def visualize_syntax_graph(doc, lang):
    """Draw the dependency graph of *doc* on a new matplotlib figure.

    Args:
        doc: A parsed spaCy ``Doc``.
        lang: UI language code ('en', 'fr', 'es', ...) for the title and legend.
            Any code other than 'en' or 'fr' gets the Spanish title. Tags
            without a translation are shown by their raw name.

    Returns:
        The ``matplotlib.pyplot`` module. The drawn figure is left as the
        current figure.
    """
    G, _ = create_syntax_graph(doc, lang)
    translations = POS_TRANSLATIONS.get(lang, {})
    titles = {'en': "Syntactic Analysis", 'fr': "Analyse Syntaxique"}

    plt.figure(figsize=(24, 18))  # Large canvas so labels stay readable
    layout = nx.spring_layout(G, k=0.9, iterations=50)

    node_attrs = list(G.nodes(data=True))
    node_colors = [data['color'] for _, data in node_attrs]
    node_sizes = [data['size'] for _, data in node_attrs]

    nx.draw(G, layout, with_labels=False, node_color=node_colors, node_size=node_sizes, arrows=True,
            arrowsize=20, width=2, edge_color='gray')

    nx.draw_networkx_labels(G, layout, {node: data['label'] for node, data in node_attrs},
                            font_size=10, font_weight='bold')

    edge_labels = nx.get_edge_attributes(G, 'label')
    nx.draw_networkx_edge_labels(G, layout, edge_labels=edge_labels, font_size=8)

    plt.title(titles.get(lang, "Análisis Sintáctico"), fontsize=20, fontweight='bold')
    plt.axis('off')

    # Compute these once. Inside the comprehension they were re-evaluated for
    # every legend entry.
    pos_counts = count_pos(doc)
    present_tags = set(nx.get_node_attributes(G, 'pos').values())
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, facecolor=color, edgecolor='none',
                      label=f"{translations.get(tag, tag)} ({pos_counts[tag]})")
        for tag, color in POS_COLORS.items() if tag in present_tags
    ]
    plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=12)

    return plt

def visualize_syntax(text, nlp, lang, max_tokens=5000):
    """Parse *text* with *nlp* and return its rendered syntax graph.

    Args:
        text: Raw input text.
        nlp: A loaded spaCy pipeline (callable returning a ``Doc``).
        lang: UI language code for the title and labels.
        max_tokens: Maximum number of tokens to visualise. Longer input is
            cut where token number ``max_tokens`` begins and re-parsed.

    Returns:
        The result of ``visualize_syntax_graph`` (the pyplot module).
    """
    doc = nlp(text)
    if len(doc) > max_tokens:
        # Truncate by tokens, not characters: slice the text at the character
        # offset of the first token past the limit, then re-parse so every
        # dependency head lies inside the kept window. Re-tokenising near the
        # cut may differ slightly from the full parse.
        cutoff = doc[max_tokens].idx
        doc = nlp(text[:cutoff])
        print(f"Warning: The input text is too long. Only the first {max_tokens} tokens will be visualized.")
    return visualize_syntax_graph(doc, lang)

def get_repeated_words_colors(doc):
    """Map each word that occurs more than once to a POS highlight colour.

    Words are compared case-insensitively. Punctuation is ignored both when
    counting and when assigning colours. If a word occurs with different POS
    tags, the colour of its last non-punctuation occurrence wins.

    Args:
        doc: Iterable of tokens exposing ``text`` and ``pos_``.

    Returns:
        dict mapping lower-cased word -> hex colour string.
    """
    words = [token for token in doc if token.pos_ != 'PUNCT']
    word_counts = Counter(token.text.lower() for token in words)
    repeated = {word for word, count in word_counts.items() if count > 1}

    word_colors = {}
    # Only iterate non-punctuation tokens. Otherwise a PUNCT token sharing a
    # repeated word's text would overwrite its colour with the white fallback.
    for token in words:
        lower_text = token.text.lower()
        if lower_text in repeated:
            word_colors[lower_text] = POS_COLORS.get(token.pos_, '#FFFFFF')

    return word_colors

def highlight_repeated_words(doc, word_colors):
    """Render *doc* as HTML, shading words that appear in *word_colors*.

    Args:
        doc: Iterable of spaCy tokens (uses ``text`` and ``whitespace_``).
        word_colors: Mapping of lower-cased word -> CSS colour, for example the
            output of ``get_repeated_words_colors``.

    Returns:
        An HTML string.
    """
    parts = []
    for token in doc:
        # Token text may come from user input and the result is HTML markup,
        # so escape it to avoid injecting tags or scripts.
        text = html.escape(token.text)
        key = token.text.lower()
        if key in word_colors:
            text = f'<span style="background-color: {word_colors[key]};">{text}</span>'
        # Re-attach the token's own trailing whitespace so the original spacing
        # survives (no spurious space before punctuation or inside contractions).
        parts.append(text + token.whitespace_)
    return ''.join(parts)