import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple
import re
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import StratifiedKFold
import torch
from transformers import AutoTokenizer
import logging
from tqdm import tqdm

class AdvancedNarrativeProcessor:
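    """End-to-end preprocessing for multi-label narrative classification.

    Loads tab-separated annotations, cleans the raw article text, extracts
    surface features with spaCy, builds TF-IDF vectors, and tokenizes the
    text for a transformer model (DeBERTa-v3 by default).
    """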
    def __init__(self, annotations_file: str, raw_dir: str, model_name: str = "microsoft/deberta-v3-large"):
        self.setup_logging()
        self.logger = logging.getLogger(__name__)
        
        self.annotations_file = Path(annotations_file)
        self.raw_dir = Path(raw_dir)
        self.model_name = model_name
        
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Initialize spaCy (download the model first: `python -m spacy download en_core_web_sm`)
        self.nlp = spacy.load("en_core_web_sm")
        self.stopwords = self.nlp.Defaults.stop_words
        
        # Initialize state
        self.df = None
        self.processed_data = None
        self.label_encodings = None
        self.tfidf_vectorizer = None
        
    def setup_logging(self):
        """Set up logging configuration"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    def load_and_process_data(self) -> Dict:
        """Main processing pipeline"""
        self.logger.info("Starting data processing pipeline...")
        
        # 1. Load Raw Data
        self.load_data()
        
        # 2. Process Text and Labels
        processed_articles = self.process_all_articles()
        
        # 3. Engineer Features
        self.add_features(processed_articles)
        
        # 4. Create Data Splits
        train_data, val_data = self.create_splits(processed_articles)
        
        # 5. Prepare Model Inputs
        train_inputs = self.prepare_model_inputs(train_data)
        val_inputs = self.prepare_model_inputs(val_data)
        
        self.logger.info("Data processing complete!")
        
        return {
            'train': train_inputs,
            'val': val_inputs,
            'label_encodings': self.label_encodings,
            'stats': self.get_statistics()
        }

    def load_data(self):
        """Load and prepare the annotation data"""
        self.logger.info(f"Loading annotations from {self.annotations_file}")
        
        # Load annotations file
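        # Passing `names=` makes pandas treat the file as header-less; add
        # header=0 as well if the annotations file contains a header row.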
        self.df = pd.read_csv(
            self.annotations_file,
            sep='\t',
            names=['article_id', 'narratives', 'subnarratives']
        )
        
        # Create label encodings (strip whitespace so padded labels map to a single id)
        all_subnarratives = set()
        for subnarrs in self.df['subnarratives'].dropna().str.split(';'):
            all_subnarratives.update(s.strip() for s in subnarrs)
        
        self.label_encodings = {
            label: idx for idx, label in enumerate(sorted(all_subnarratives))
        }
        
        self.logger.info(f"Loaded {len(self.df)} articles with {len(self.label_encodings)} unique labels")

    def read_article(self, article_id: str) -> str:
        """Read article content from file"""
        try:
            with open(self.raw_dir / article_id, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            self.logger.error(f"Error reading article {article_id}: {e}")
            return ""

    def process_text(self, text: str) -> str:
        """Enhanced text processing"""
        # Remove URLs and emails
        text = re.sub(r'http\S+|www\S+|\S+@\S+', '', text)
        
        # Handle numbers and special characters
        text = re.sub(r'\d+', ' NUM ', text)
        text = re.sub(r'[^\w\s.,!?-]', ' ', text)
        
        # Normalize whitespace last so the substitutions above don't leave double spaces
        text = ' '.join(text.split())
        
        return text.strip()

    def extract_features(self, text: str) -> Dict:
        """Extract rich text features using SpaCy."""
        # Process text with SpaCy
        doc = self.nlp(text)
        words = [token.text for token in doc if not token.is_space]
        sentences = list(doc.sents)
        
        return {
            'length': len(words),
            'avg_word_length': np.mean([len(w) for w in words]) if words else 0,
            'sentence_count': len(sentences),
            'avg_sentence_length': len(words) / len(sentences) if sentences else 0,
            'unique_words': len(set(words)),
            'density': len(set(words)) / len(words) if words else 0
        }

    def process_all_articles(self) -> List[Dict]:
        """Process all articles with rich features"""
        processed_articles = []
        
        for _, row in tqdm(self.df.iterrows(), total=len(self.df), desc="Processing articles"):
            # Read and process text
            text = self.read_article(row['article_id'])
            processed_text = self.process_text(text)
            
            # Extract features
            features = self.extract_features(processed_text)
            
            # Process labels
            labels = self.process_labels(row['subnarratives'])
            
            processed_articles.append({
                'id': row['article_id'],
                'text': processed_text,
                'features': features,
                'labels': labels,
                'domain': 'UA' if 'UA' in row['article_id'] else 'CC'
            })
            
        return processed_articles

    def process_labels(self, subnarratives: str) -> List[int]:
        """Convert a ';'-separated subnarratives string to a multi-hot label vector"""
        label_vector = [0] * len(self.label_encodings)
        for subnarr in subnarratives.split(';'):
            subnarr = subnarr.strip()
            if subnarr in self.label_encodings:
                label_vector[self.label_encodings[subnarr]] = 1
        return label_vector

    def add_features(self, articles: List[Dict]):
        """Add TF-IDF and additional features"""
        # Create TF-IDF features
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=5000,
            stop_words='english'
        )
        
        texts = [article['text'] for article in articles]
        tfidf_features = self.tfidf_vectorizer.fit_transform(texts)
        
        # Add to articles
        for idx, article in enumerate(articles):
            article['tfidf_features'] = tfidf_features[idx]

    def create_splits(self, articles: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
        """Create stratified splits"""
        # Use domain and label distribution for stratification
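        # Each unique domain + label-vector combination becomes its own stratum, so
        # StratifiedKFold may warn about strata with fewer members than n_splits.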
        stratify_labels = [f"{a['domain']}_{'-'.join(str(l) for l in a['labels'])}" 
                          for a in articles]
        
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        train_idx, val_idx = next(skf.split(articles, stratify_labels))
        
        return [articles[i] for i in train_idx], [articles[i] for i in val_idx]

    def prepare_model_inputs(self, articles: List[Dict]) -> Dict[str, torch.Tensor]:
        """Prepare inputs for the model"""
        # Tokenize texts
        encodings = self.tokenizer(
            [a['text'] for a in articles],
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        
        # Convert labels to tensor
        labels = torch.tensor([a['labels'] for a in articles])
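        # If training uses BCEWithLogitsLoss, cast these to float (labels.float())
        # in the training loop, since that loss expects float targets.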
        
        # Convert features to tensor with explicit float32 dtype
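        # Only five of the six extracted features are used; keep this list in sync
        # with `feature_names` in analyze_features ('unique_words' is left out).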
        features = torch.tensor([[
            a['features']['length'],
            a['features']['avg_word_length'],
            a['features']['sentence_count'],
            a['features']['avg_sentence_length'],
            a['features']['density']
        ] for a in articles], dtype=torch.float32)
        
        return {
            'input_ids': encodings['input_ids'],
            'attention_mask': encodings['attention_mask'],
            'labels': labels,
            'features': features
        }

    def get_label_distribution(self) -> Dict:
        """Calculate the distribution of labels in the dataset"""
        if self.df is None:
            return {}
        
        label_counts = {}
        for subnarrs in self.df['subnarratives'].dropna().str.split(';'):
            for subnarr in subnarrs:
                subnarr = subnarr.strip()
                if subnarr in self.label_encodings:
                    label_counts[subnarr] = label_counts.get(subnarr, 0) + 1
        
        return label_counts

    def get_statistics(self) -> Dict:
        """Get processing statistics"""
        return {
            'total_articles': len(self.df),
            'label_distribution': self.get_label_distribution(),
            'vocabulary_size': len(self.tfidf_vectorizer.vocabulary_),
            'domain_distribution': self.df['article_id'].apply(
                lambda x: 'UA' if 'UA' in x else 'CC'
            ).value_counts().to_dict()
        }

    def analyze_features(self, processed_data: Dict) -> Dict:
        """Analyze feature statistics from processed data"""
        train_features = processed_data['train']['features']
        feature_names = ['length', 'avg_word_length', 'sentence_count', 
                        'avg_sentence_length', 'density']
        
        feature_stats = {}
        for i, name in enumerate(feature_names):
            values = train_features[:, i]
            feature_stats[name] = {
                'mean': float(values.mean()),
                'std': float(values.std()),
                'min': float(values.min()),
                'max': float(values.max())
            }
        
        return feature_stats

# Usage example
if __name__ == "__main__":
    processor = AdvancedNarrativeProcessor(
        annotations_file="../../data/subtask-2-annotations.txt",
        raw_dir="../../data/raw"
    )
    
    processed_data = processor.load_and_process_data()
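    
    # The 'train' and 'val' dicts hold tensors that can be wrapped in a
    # torch.utils.data.TensorDataset / DataLoader for model training.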
    
    # Print statistics
    stats = processed_data['stats']
    print("\n=== Processing Statistics ===")
    print(f"Total Articles: {stats['total_articles']}")
    print(f"Vocabulary Size: {stats['vocabulary_size']}")
    print("\nDomain Distribution:")
    for domain, count in stats['domain_distribution'].items():
        print(f"{domain}: {count} articles")
    
    # Print feature analysis
    feature_stats = processor.analyze_features(processed_data)
    print("\n=== Feature Statistics ===")
    for name, fstats in feature_stats.items():
        print(f"{name}:")
        print(f"  Mean: {fstats['mean']:.2f}")
        print(f"  Std: {fstats['std']:.2f}")
        print(f"  Range: [{fstats['min']:.2f}, {fstats['max']:.2f}]")