# Lyrics Generation Model Development
# Author: [Your Name]
# Project: Opentunes.ai

import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    GPT2LMHeadModel,
    GPT2Tokenizer
)
import pandas as pd
import numpy as np
from pathlib import Path
import json
import wandb
from tqdm import tqdm

# 1. Data Loading and Preprocessing
class LyricsDataset(torch.utils.data.Dataset):
    """
    Dataset of lyrics text files tokenized for causal language modelling.

    Every ``*.txt`` file directly inside ``data_dir`` is one example. Genre,
    style and structure tags are parsed from the file name
    (``<genre>_<style>_<structure>.txt``) and the genre/style tags are
    prepended to the lyrics text before tokenization.

    Each item is a dict of 1-D tensors of length ``max_length``:
    ``input_ids``, ``attention_mask`` and ``labels`` (a copy of ``input_ids``
    with padded positions set to ``IGNORE_INDEX``).
    """

    # Label value skipped by PyTorch's cross-entropy loss (its default
    # ``ignore_index``), used here so padding does not contribute to the loss.
    IGNORE_INDEX = -100

    def __init__(self, data_dir, max_length=512, tokenizer=None):
        """
        Args:
            data_dir: Directory containing the ``*.txt`` lyrics files.
            max_length: Sequence length every example is padded/truncated to.
            tokenizer: Optional tokenizer to reuse. When omitted, the
                pretrained GPT-2 tokenizer is loaded.
        """
        self.data_dir = Path(data_dir)
        self.max_length = max_length

        if tokenizer is None:
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        # GPT-2 ships without a padding token; reuse EOS so that
        # padding='max_length' works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        # Sort so that index -> file mapping is the same on every run and
        # on every file system.
        self.lyrics_files = sorted(self.data_dir.glob("*.txt"))
        self.lyrics_data = self._load_lyrics_data()

    def _load_lyrics_data(self):
        """Read every lyrics file and pair its text with filename metadata."""
        data = []
        for file in self.lyrics_files:
            with open(file, 'r', encoding='utf-8') as f:
                lyrics = f.read()

            metadata = self._extract_metadata(file)
            data.append({
                'lyrics': lyrics,
                'genre': metadata['genre'],
                'style': metadata['style'],
                'structure': metadata['structure'],
            })
        return data

    def _extract_metadata(self, file):
        """
        Parse ``genre``, ``style`` and ``structure`` from the file stem.

        Example: ``pop_love_verse-chorus.txt`` ->
        ``{'genre': 'pop', 'style': 'love', 'structure': 'verse-chorus'}``.
        Missing or empty fields fall back to ``'unknown'`` (genre, style) or
        ``'verse-chorus'`` (structure).
        """
        parts = file.stem.split('_')
        defaults = (
            ('genre', 'unknown'),
            ('style', 'unknown'),
            ('structure', 'verse-chorus'),
        )
        return {
            key: parts[i] if len(parts) > i and parts[i] else default
            for i, (key, default) in enumerate(defaults)
        }

    def __len__(self):
        return len(self.lyrics_data)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return its tensors."""
        item = self.lyrics_data[idx]

        # NOTE(review): the <|...|> markers are not registered with the
        # tokenizer anywhere in this class - confirm whether they should be
        # added as special tokens.
        input_text = f"<|genre|>{item['genre']}<|style|>{item['style']}<|lyrics|>{item['lyrics']}"

        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when max_length == 1.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Padding shares its id with EOS, so mask it through the attention
        # mask rather than by token id.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

# 2. Model Architecture
class LyricsTransformer(nn.Module):
    """
    Pretrained GPT-2 language model with optional style conditioning.

    The wrapped ``GPT2LMHeadModel`` does the language modelling. When
    ``style_ids`` are supplied, a learned style vector is added to the token
    embeddings before they enter GPT-2, so the conditioning has the same
    width as the hidden states.

    Notes:
        - ``vocab_size``, ``nhead`` and ``num_layers`` are accepted for
          interface compatibility but are not used: the architecture comes
          from the pretrained ``'gpt2'`` checkpoint.
        - ``genre_embedding`` is created but not used by ``forward``.
    """

    def __init__(self,
                 vocab_size=50257,  # GPT-2 vocabulary size
                 d_model=768,
                 nhead=12,
                 num_layers=6):
        super().__init__()

        # Load pretrained GPT-2
        self.transformer = GPT2LMHeadModel.from_pretrained('gpt2')
        hidden_size = self.transformer.config.n_embd

        # Style lookup table (up to 100 distinct style ids)
        self.style_embedding = nn.Embedding(100, d_model)

        # Project the style vector to GPT-2's hidden width so it can be added
        # to the token embeddings even when d_model differs from it.
        self.style_projection = nn.Linear(d_model, hidden_size)
        self.genre_embedding = nn.Embedding(50, d_model)  # 50 different genres

    def forward(self, input_ids, attention_mask=None, style_ids=None, labels=None):
        """
        Run GPT-2, optionally conditioned on a style id.

        Args:
            input_ids: Token ids, ``(batch, seq_len)``.
            attention_mask: Optional mask marking real (1) vs padded (0)
                positions.
            style_ids: Optional style ids, either ``(batch,)`` (one style per
                sequence) or ``(batch, seq_len)`` (one style per position).
            labels: Optional target ids. When given, the returned output
                carries the language-modelling ``loss``; GPT-2 shifts the
                labels internally and ignores positions equal to ``-100``.

        Returns:
            The ``GPT2LMHeadModel`` output object (``logits``, and ``loss``
            when ``labels`` is provided).
        """
        if style_ids is None:
            return self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True
            )

        token_embeds = self.transformer.get_input_embeddings()(input_ids)
        style_vec = self.style_projection(self.style_embedding(style_ids))
        if style_vec.dim() == token_embeds.dim() - 1:
            # One style per sequence: broadcast it over every position.
            style_vec = style_vec.unsqueeze(1)

        # Conditioning happens in embedding space. Adding the style vector to
        # the logits (vocabulary-sized) is not shape-compatible.
        return self.transformer(
            inputs_embeds=token_embeds + style_vec,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )

    def generate(self, *args, **kwargs):
        """
        Delegate text generation to the wrapped GPT-2 model.

        All arguments are forwarded unchanged to
        ``GPT2LMHeadModel.generate``. Style conditioning is not applied on
        this path.
        """
        return self.transformer.generate(*args, **kwargs)

# 3. Training Pipeline
class LyricsTrainer:
    """
    Minimal training/evaluation loop for a causal lyrics language model.

    The model is called as ``model(input_ids, attention_mask=...)`` and must
    return an object with ``logits``. If that object also carries a non-None
    ``loss`` it is used directly; otherwise the next-token cross-entropy is
    computed here from the logits, with padded positions excluded.

    Expected ``config`` keys: ``learning_rate``, ``epochs``,
    ``steps_per_epoch`` and optionally ``weight_decay`` (default ``0.01``).
    ``epochs * steps_per_epoch`` must cover every optimizer step taken,
    because the one-cycle scheduler is sized from those two values.
    """

    def __init__(self, model, config, device='cuda'):
        self.model = model.to(device)
        self.config = config
        self.device = device

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config.get('weight_decay', 0.01)
        )

        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=config['learning_rate'],
            epochs=config['epochs'],
            steps_per_epoch=config['steps_per_epoch']
        )

    def _prepare_batch(self, batch):
        """Move a batch to the training device, adding a batch dim if absent."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        if input_ids.dim() == 1:
            # A single un-collated dataset item: treat it as a batch of one.
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        return input_ids, attention_mask

    def _compute_loss(self, outputs, input_ids, attention_mask):
        """Return the model's own loss, or next-token cross-entropy from logits."""
        loss = getattr(outputs, 'loss', None)
        if loss is not None:
            return loss

        # Position t predicts token t + 1, so drop the last logit and the
        # first target.
        logits = outputs.logits[:, :-1, :]
        targets = input_ids[:, 1:].clone()
        # Exclude padded targets from the loss.
        targets[attention_mask[:, 1:] == 0] = -100
        return torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=-100
        )

    def train_epoch(self, train_loader):
        """
        Train for one epoch.

        Args:
            train_loader: Iterable of batches; each batch is a dict with
                ``input_ids`` and ``attention_mask`` tensors.

        Returns:
            float: Mean loss over the batches seen.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0

        for batch in tqdm(train_loader):
            input_ids, attention_mask = self._prepare_batch(batch)

            outputs = self.model(input_ids, attention_mask=attention_mask)
            loss = self._compute_loss(outputs, input_ids, attention_mask)

            self.optimizer.zero_grad()
            loss.backward()
            # Clip to a global norm of 1.0 before stepping.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()

            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        return epoch_loss / num_batches

    def evaluate(self, val_loader):
        """
        Compute the mean loss over a validation set without updating weights.

        Args:
            val_loader: Iterable of batches shaped like those of
                ``train_epoch``.

        Returns:
            float: Mean loss over the batches seen.

        Raises:
            ValueError: If ``val_loader`` yields no batches.
        """
        self.model.eval()
        val_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                input_ids, attention_mask = self._prepare_batch(batch)

                outputs = self.model(input_ids, attention_mask=attention_mask)
                loss = self._compute_loss(outputs, input_ids, attention_mask)

                val_loss += loss.item()
                num_batches += 1

        if num_batches == 0:
            raise ValueError("val_loader yielded no batches")
        return val_loss / num_batches

# 4. Generation and Inference
class LyricsGenerator:
    """
    Prompt-based lyrics generation around a model exposing ``generate``.

    The prompt (optionally prefixed with a style tag) is encoded with the
    tokenizer, continued by ``model.generate`` using sampling, and decoded
    back to text.
    """

    def __init__(self, model, tokenizer, device='cuda'):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device

    def generate(self, prompt, style=None, max_length=200,
                 temperature=0.7, top_p=0.95, top_k=50, **generate_kwargs):
        """
        Generate lyrics from a prompt.

        Args:
            prompt: Text to continue.
            style: Optional style tag; when given, the model input becomes
                ``<|style|>{style}<|prompt|>{prompt}``.
            max_length: Maximum total length in tokens (prompt included).
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            top_k: Number of highest-probability tokens kept for sampling.
            **generate_kwargs: Extra keyword arguments forwarded to
                ``model.generate``; they override the defaults set here.

        Returns:
            str: The decoded sequence (prompt plus continuation) with special
            tokens removed.
        """
        self.model.eval()

        # NOTE(review): this prompt layout differs from the
        # <|genre|>...<|style|>...<|lyrics|>... layout used to build training
        # examples in LyricsDataset - confirm which one is intended.
        input_text = prompt if style is None else f"<|style|>{style}<|prompt|>{prompt}"
        input_ids = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)

        options = {
            'max_length': max_length,
            'num_return_sequences': 1,
            'no_repeat_ngram_size': 3,
            'do_sample': True,
            'top_k': top_k,
            'top_p': top_p,
            'temperature': temperature,
            # A single un-padded prompt: every position is a real token.
            'attention_mask': torch.ones_like(input_ids),
        }

        # Tell generate() explicitly which id pads finished sequences; fall
        # back to EOS for tokenizers (like GPT-2's) without a pad token.
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token_id is None:
            pad_token_id = getattr(self.tokenizer, 'eos_token_id', None)
        if pad_token_id is not None:
            options['pad_token_id'] = pad_token_id

        options.update(generate_kwargs)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model.generate(input_ids, **options)

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# 5. Evaluation
class LyricsEvaluator:
    """
    Holder for lyrics evaluation hooks.

    The three ``evaluate_*`` methods are placeholders: they accept their
    inputs and return ``None`` until a real metric is implemented.
    """

    def __init__(self, model, tokenizer):
        self.model, self.tokenizer = model, tokenizer

    def evaluate_rhyme(self, lyrics):
        """Placeholder for rhyme detection and scoring; returns ``None``."""
        return None

    def evaluate_structure(self, lyrics):
        """Placeholder for lyrical structure analysis; returns ``None``."""
        return None

    def evaluate_theme(self, lyrics, prompt):
        """Placeholder for theme-consistency analysis; returns ``None``."""
        return None

# Example Usage
def main():
    """
    Train the lyrics model on the on-disk datasets and print one sample.

    Reads hyper-parameters from ``models/lyrics-gen/config/model_config.json``
    (``learning_rate`` and ``epochs`` are required; ``batch_size`` defaults
    to 8), trains for ``epochs`` epochs while printing train/validation loss,
    then generates lyrics for a fixed prompt.

    Raises:
        RuntimeError: If either dataset directory contains no ``*.txt`` files.
    """
    # Load config
    with open('models/lyrics-gen/config/model_config.json', encoding='utf-8') as f:
        config = json.load(f)

    # Initialize model and tokenizer
    model = LyricsTransformer()
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # GPT-2 has no padding token; reuse EOS, matching LyricsDataset.
    tokenizer.pad_token = tokenizer.eos_token

    # Create datasets
    train_dataset = LyricsDataset('datasets/lyrics/train')
    val_dataset = LyricsDataset('datasets/lyrics/val')
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise RuntimeError(
            "No lyrics files found: expected *.txt files in "
            "'datasets/lyrics/train' and 'datasets/lyrics/val'"
        )

    # Batch the examples; the trainer iterates over batches, not raw items.
    batch_size = config.get('batch_size', 8)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)

    # The one-cycle scheduler must be sized to the number of optimizer steps
    # actually taken per epoch, i.e. the number of training batches.
    config['steps_per_epoch'] = len(train_loader)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize trainer
    trainer = LyricsTrainer(model, config, device=device)

    # Train model
    for epoch in range(config['epochs']):
        train_loss = trainer.train_epoch(train_loader)
        val_loss = trainer.evaluate(val_loader)
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    # Generate sample. LyricsGenerator needs an object exposing generate();
    # fall back to the wrapped GPT-2 when the wrapper does not provide one.
    generation_model = model if hasattr(model, 'generate') else model.transformer
    generator = LyricsGenerator(generation_model, tokenizer, device=device)
    lyrics = generator.generate(
        prompt="Write a love song about summer",
        style="pop"
    )
    print("Generated Lyrics:", lyrics)

# Script entry point: run the train-then-generate example in main().
# NOTE(review): this guard sits before the class and function definitions
# that follow it, and a second `__main__` guard appears at the end of the
# file; both run when the module is executed as a script.
if __name__ == "__main__":
    main()




# Additional Features for Lyrics Generation

# 1. Enhanced Style Control and Structure
class EnhancedLyricsGenerator(LyricsGenerator):
    """
    Lyrics generator with section-, emotion- and metaphor-driven entry points.

    Public entry points:
    - ``generate_structured_song``: one generated part per song section
    - ``generate_with_emotion``: prompt and sampling settings derived from
      emotion parameters
    - ``generate_with_metaphors``: prompt extended with metaphor phrases

    NOTE(review): ``_create_section_prompt``, ``generate_section``,
    ``_compile_song``, ``_add_emotional_context``,
    ``_get_emotion_temperature``, ``_get_emotion_top_p`` and
    ``_select_metaphors`` are called below but are not defined in this class
    or in ``LyricsGenerator`` - confirm where they are meant to come from.
    """

    def generate_structured_song(self, prompt, structure_dict):
        """
        Generate every section listed in ``structure_dict`` and compile them.

        Args:
            prompt (str): Main theme/topic
            structure_dict (dict): Maps a section name to its specification.
                Each specification must provide ``lines``,
                ``syllables_per_line`` and ``rhyme_scheme``.
                Example:
                {
                    'verse1': {'lines': 4, 'syllables_per_line': 8, 'rhyme_scheme': 'AABB'},
                    'chorus': {'lines': 4, 'syllables_per_line': 6, 'rhyme_scheme': 'ABAB'},
                    'verse2': {'lines': 4, 'syllables_per_line': 8, 'rhyme_scheme': 'AABB'},
                    'bridge': {'lines': 2, 'syllables_per_line': 10, 'rhyme_scheme': 'AA'},
                }

        Returns:
            Whatever ``_compile_song`` returns for the mapping of section
            name to generated section.
        """
        def render(section_name, spec):
            # Build the section-specific prompt first, then generate from it.
            tailored_prompt = self._create_section_prompt(
                base_prompt=prompt,
                section_type=section_name,
                specifications=spec
            )
            return self.generate_section(
                prompt=tailored_prompt,
                rhyme_scheme=spec['rhyme_scheme'],
                syllables=spec['syllables_per_line'],
                num_lines=spec['lines']
            )

        rendered = {
            section_name: render(section_name, spec)
            for section_name, spec in structure_dict.items()
        }
        return self._compile_song(rendered)

    def generate_with_emotion(self, prompt, emotion_params):
        """
        Generate lyrics steered by emotion parameters.

        The prompt is rewritten by ``_add_emotional_context`` and the
        ``temperature`` / ``top_p`` keyword arguments passed to ``generate``
        are derived from ``emotion_params``.

        Args:
            prompt (str): Base prompt
            emotion_params (dict): Emotional parameters
                Example:
                {
                    'primary_emotion': 'joy',
                    'intensity': 0.8,
                    'tone': 'uplifting',
                    'imagery_type': 'nature',
                    'word_choices': 'positive'
                }
        """
        emotional_prompt = self._add_emotional_context(prompt, emotion_params)
        sampling = {
            'temperature': self._get_emotion_temperature(emotion_params),
            'top_p': self._get_emotion_top_p(emotion_params),
        }
        return self.generate(prompt=emotional_prompt, **sampling)

    def generate_with_metaphors(self, prompt, theme_params):
        """
        Generate lyrics from a prompt extended with metaphor phrases.

        Args:
            prompt (str): Base prompt
            theme_params (dict): Theme and metaphor specifications
                Example:
                {
                    'primary_theme': 'love',
                    'metaphor_source': 'ocean',
                    'imagery_type': 'visual',
                    'complexity_level': 'advanced'
                }
        """
        return self.generate(
            prompt=self._enhance_with_metaphors(prompt, theme_params)
        )

    def _enhance_with_metaphors(self, prompt, theme_params):
        """Return ``prompt`` followed by the metaphors chosen for the theme."""
        ocean_phrases = [
            "deep as the ocean",
            "waves of emotion",
            "tidal force of feeling",
        ]
        fire_phrases = [
            "burning passion",
            "flame of desire",
            "scorching intensity",
        ]
        # Theme -> metaphor source -> candidate phrases.
        # Add more themes and metaphor sources here.
        catalogue = {
            'love': {
                'ocean': ocean_phrases,
                'fire': fire_phrases,
            },
        }

        picked = self._select_metaphors(catalogue, theme_params)
        suffix = ' '.join(picked)
        return f"{prompt} {suffix}"

class RhymeController:
    """
    Rhyme-scheme helper built on a word -> pronunciation table.

    ``self.pronunciations`` is filled by ``_load_pronunciations`` at
    construction time and is used to look up rhyming words.

    NOTE(review): ``_load_pronunciations``, ``_create_rhyme_groups``,
    ``_modify_line_endings`` and ``_is_rhyme`` are called here but are not
    defined in this class - confirm where they are meant to come from.
    """

    def __init__(self):
        self.pronunciations = self._load_pronunciations()

    def enforce_rhyme_scheme(self, lines, scheme):
        """
        Rewrite line endings so the lines follow a rhyme scheme.

        Args:
            lines (list): Generated lines
            scheme (str): Rhyme scheme (e.g., 'AABB', 'ABAB')

        Returns:
            The result of ``_modify_line_endings`` for ``lines`` and the
            rhyme groups derived from ``scheme``.
        """
        return self._modify_line_endings(lines, self._create_rhyme_groups(scheme))

    def _get_rhyming_words(self, word):
        """
        Return every known word whose pronunciation rhymes with ``word``.

        The lookup is case-insensitive. A word with no (or an empty)
        pronunciation entry yields an empty list.
        """
        target = self.pronunciations.get(word.lower())
        if not target:
            return []

        return [
            candidate
            for candidate, candidate_pron in self.pronunciations.items()
            if self._is_rhyme(target, candidate_pron)
        ]

# Advanced Evaluation Metrics
class EnhancedLyricsEvaluator(LyricsEvaluator):
    """
    Evaluation suite that groups lyrics metrics into four categories.

    Categories (keys of the ``evaluate_comprehensive`` result):
    1. ``technical`` - rhyme, syllables, vocabulary, grammar
    2. ``musical``   - rhythm, singability, phrase length, stress
    3. ``content``   - theme, emotion, imagery, narrative
    4. ``style``     - adherence to a target style

    NOTE(review): ``evaluate_style`` and the ``_analyze_*`` / ``_evaluate_*``
    / ``_calculate_*`` / ``_check_*`` helpers called below are not defined in
    this class or in ``LyricsEvaluator`` - confirm where they are meant to
    come from.
    """

    def evaluate_comprehensive(self, lyrics, target_style=None):
        """
        Run all four evaluation categories on ``lyrics``.

        Returns:
            dict: One entry per category, in the order technical, musical,
            content, style.
        """
        report = {}
        report['technical'] = self.evaluate_technical(lyrics)
        report['musical'] = self.evaluate_musical(lyrics)
        report['content'] = self.evaluate_content(lyrics)
        report['style'] = self.evaluate_style(lyrics, target_style)
        return report

    def evaluate_technical(self, lyrics):
        """Collect the technical metrics for ``lyrics`` into a dict."""
        return dict(
            rhyme_quality=self._analyze_rhyme_patterns(lyrics),
            syllable_consistency=self._analyze_syllable_patterns(lyrics),
            vocabulary_richness=self._calculate_vocabulary_metrics(lyrics),
            grammar_score=self._check_grammar(lyrics),
        )

    def evaluate_musical(self, lyrics):
        """Collect the musical-compatibility metrics for ``lyrics`` into a dict."""
        return dict(
            rhythm_score=self._analyze_rhythm(lyrics),
            singability=self._evaluate_singability(lyrics),
            phrase_length=self._analyze_phrase_length(lyrics),
            stress_patterns=self._analyze_stress_patterns(lyrics),
        )

    def evaluate_content(self, lyrics):
        """Collect the content-quality metrics for ``lyrics`` into a dict."""
        return dict(
            theme_coherence=self._analyze_theme_consistency(lyrics),
            emotional_impact=self._analyze_emotional_content(lyrics),
            imagery_score=self._evaluate_imagery(lyrics),
            narrative_strength=self._analyze_narrative(lyrics),
        )

# Example Usage and Scenarios

def demonstrate_lyrics_generation(model=None, tokenizer=None):
    """
    Demonstrate three lyrics generation scenarios.

    Args:
        model: Model exposing ``generate`` to drive the generator. Defaults
            to the pretrained ``'gpt2'`` checkpoint.
        tokenizer: Tokenizer matching ``model``. Defaults to the pretrained
            GPT-2 tokenizer.

    Returns:
        dict: Results under the keys ``'pop_song'`` (structured song),
        ``'ballad'`` (emotion-driven) and ``'metaphorical'``
        (metaphor-enriched).
    """
    # Fall back to the stock GPT-2 checkpoint instead of relying on
    # module-level `model` / `tokenizer` globals, which this file never defines.
    if model is None:
        model = GPT2LMHeadModel.from_pretrained('gpt2')
    if tokenizer is None:
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    # Initialize generator on the GPU when one is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = EnhancedLyricsGenerator(model, tokenizer, device=device)

    # 1. Generate a complete pop song
    pop_structure = {
        'verse1': {
            'lines': 4,
            'syllables_per_line': 8,
            'rhyme_scheme': 'AABB'
        },
        'chorus': {
            'lines': 4,
            'syllables_per_line': 6,
            'rhyme_scheme': 'ABAB'
        },
        'verse2': {
            'lines': 4,
            'syllables_per_line': 8,
            'rhyme_scheme': 'AABB'
        }
    }

    pop_song = generator.generate_structured_song(
        prompt="A summer love story",
        structure_dict=pop_structure
    )

    # 2. Generate emotional ballad
    emotion_params = {
        'primary_emotion': 'longing',
        'intensity': 0.9,
        'tone': 'melancholic',
        'imagery_type': 'nature',
        'word_choices': 'poetic'
    }

    ballad = generator.generate_with_emotion(
        prompt="Lost love and memories",
        emotion_params=emotion_params
    )

    # 3. Generate metaphorical lyrics
    theme_params = {
        'primary_theme': 'love',
        'metaphor_source': 'ocean',
        'imagery_type': 'visual',
        'complexity_level': 'advanced'
    }

    metaphorical = generator.generate_with_metaphors(
        prompt="Finding inner strength",
        theme_params=theme_params
    )

    return {
        'pop_song': pop_song,
        'ballad': ballad,
        'metaphorical': metaphorical
    }


# Integration with Melody Generation
class SongIntegrator:
    """
    Combines a lyrics generator and a melody generator into one song.

    Lyrics are generated first; the melody is then generated for the form
    extracted from those lyrics, and the two are paired section by section.

    NOTE(review): ``_get_structure_for_style``, ``_extract_form_from_lyrics``,
    ``_adjust_melody_to_lyrics`` and ``_count_syllables`` are called here but
    are not defined in this class - confirm where they are meant to come from.
    """

    def __init__(self, lyrics_generator, melody_generator):
        self.lyrics_generator, self.melody_generator = lyrics_generator, melody_generator

    def generate_complete_song(self, prompt, style):
        """
        Generate lyrics and a matching melody, then align them.

        Args:
            prompt (str): Song theme/topic
            style (dict): Musical and lyrical style parameters

        Returns:
            dict: The aligned song produced by ``_align_lyrics_and_melody``.
        """
        # Step 1: lyrics, using the structure chosen for this style.
        song_lyrics = self.lyrics_generator.generate_structured_song(
            prompt=prompt,
            structure_dict=self._get_structure_for_style(style)
        )

        # Step 2: melody, following the form of the lyrics just generated.
        song_melody = self.melody_generator.generate_with_structure(
            prompt=prompt,
            form=self._extract_form_from_lyrics(song_lyrics)
        )

        # Step 3: pair them up section by section.
        return self._align_lyrics_and_melody(song_lyrics, song_melody)

    def _align_lyrics_and_melody(self, lyrics, melody):
        """
        Pair each lyrics section with its syllable-adjusted melody.

        Every section name found in ``lyrics`` must also be a key of
        ``melody``; otherwise the lookup raises ``KeyError``.

        Returns:
            dict: ``{section: {'lyrics': ..., 'melody': ...}}``.
        """
        paired = {}

        for name in lyrics:
            melody_part = melody[name]
            lyric_part = lyrics[name]

            # Fit the melody's note durations to the section's syllable count.
            fitted_melody = self._adjust_melody_to_lyrics(
                melody_part,
                self._count_syllables(lyric_part)
            )

            paired[name] = dict(lyrics=lyric_part, melody=fitted_melody)

        return paired

def example_complete_song(lyrics_model=None, tokenizer=None, melody_generator=None):
    """
    Generate a complete song (lyrics plus melody) and evaluate its lyrics.

    Args:
        lyrics_model: Model exposing ``generate`` for the lyrics generator.
            Defaults to the pretrained ``'gpt2'`` checkpoint.
        tokenizer: Tokenizer matching ``lyrics_model``. Defaults to the
            pretrained GPT-2 tokenizer.
        melody_generator: Object exposing
            ``generate_with_structure(prompt=..., form=...)``. Required,
            because this module defines no melody generator of its own.

    Returns:
        dict: ``{'song': aligned_song, 'evaluation': per_section_metrics}``
        where ``evaluation`` maps each section name to its
        ``evaluate_comprehensive`` result.

    Raises:
        ValueError: If ``melody_generator`` is not provided.
    """
    if melody_generator is None:
        raise ValueError(
            "example_complete_song() requires a melody_generator exposing "
            "generate_with_structure(prompt=..., form=...)"
        )
    if lyrics_model is None:
        lyrics_model = GPT2LMHeadModel.from_pretrained('gpt2')
    if tokenizer is None:
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    # Initialize components
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    lyrics_gen = EnhancedLyricsGenerator(lyrics_model, tokenizer, device=device)
    integrator = SongIntegrator(lyrics_gen, melody_generator)

    # Generate complete song
    song = integrator.generate_complete_song(
        prompt="A hopeful song about new beginnings",
        style={
            'genre': 'pop',
            'mood': 'uplifting',
            'tempo': 'moderate',
            'complexity': 'medium'
        }
    )

    # Evaluate the result. The aligned song is keyed by section name, each
    # entry holding that section's 'lyrics' and 'melody', so the lyrics are
    # evaluated section by section.
    evaluator = EnhancedLyricsEvaluator(lyrics_model, tokenizer)
    evaluation = {
        section: evaluator.evaluate_comprehensive(
            part['lyrics'],
            target_style='pop'
        )
        for section, part in song.items()
    }

    return {
        'song': song,
        'evaluation': evaluation
    }

if __name__ == "__main__":
    # Run both demonstrations up front, then report them in order.
    lyrics_examples = demonstrate_lyrics_generation()
    complete_song = example_complete_song()

    # Each report is a heading line followed by the result as indented JSON.
    for heading, payload in (
        ("Generated Lyrics Examples:", lyrics_examples),
        ("\nComplete Song Generation:", complete_song),
    ):
        print(heading)
        print(json.dumps(payload, indent=2))