Spaces:
Running
Running
from transformers import pipeline | |
from typing import Dict, List, Optional | |
class LyricGenerator:
    """Generate and restyle song lyrics with a Hugging Face text-generation pipeline.

    Prompts are built from a per-genre instruction prefix (``genre_prompts``)
    followed by the user's theme, and completions are sampled from the
    configured causal language model.
    """

    # GPT-2's ``<|endoftext|>`` token id. Used as the pad token only when the
    # pipeline's tokenizer exposes neither a pad nor an EOS token id.
    _DEFAULT_PAD_TOKEN_ID = 50256

    def __init__(self, model_name: str = "gpt2-medium", device: Optional[int] = None):
        """
        Initialize the lyric generator with a specified language model.

        Args:
            model_name: The name of the pre-trained model to use
            device: Pipeline device index (``-1`` for CPU, ``0`` for the first
                GPU). When ``None`` (the default), the first CUDA GPU is used
                if PyTorch reports one as available, otherwise the CPU.
        """
        if device is None:
            device = self._resolve_device()
        self.generator = pipeline(
            "text-generation",
            model=model_name,
            device=device,
        )
        # Genre-specific prompts to guide generation
        self.genre_prompts = {
            "rock": "Write energetic rock lyrics about",
            "pop": "Create catchy pop lyrics about",
            "hip hop": "Write hip hop verses about",
            "country": "Write country music lyrics about",
            "jazz": "Compose smooth jazz lyrics about",
            "classical": "Write classical music lyrics about",
            "electronic": "Create electronic dance music lyrics about",
            "blues": "Write soulful blues lyrics about",
            "reggae": "Write laid-back reggae lyrics about",
            "metal": "Write intense metal lyrics about",
        }

    @staticmethod
    def _resolve_device() -> int:
        """Return ``0`` if PyTorch can see a CUDA GPU, otherwise ``-1`` (CPU)."""
        # ``pipeline`` is a factory function with no ``device`` attribute, so
        # GPU availability has to be queried from PyTorch directly. Import
        # lazily so a missing torch install degrades to CPU instead of failing.
        try:
            import torch
        except ImportError:
            return -1
        return 0 if torch.cuda.is_available() else -1

    def _resolve_pad_token_id(self) -> int:
        """Return the pad token id to pass to generation.

        Prefers the pipeline tokenizer's pad token, then its EOS token (models
        such as GPT-2 define no pad token), then GPT-2's ``<|endoftext|>`` id.
        """
        tokenizer = getattr(self.generator, "tokenizer", None)
        for attr in ("pad_token_id", "eos_token_id"):
            token_id = getattr(tokenizer, attr, None)
            if token_id is not None:
                return token_id
        return self._DEFAULT_PAD_TOKEN_ID

    @staticmethod
    def _clean_text(text: str) -> str:
        """Drop literal ``<|endoftext|>`` markers and surrounding whitespace."""
        return text.replace("<|endoftext|>", "").strip()

    def generate_lyrics(
        self,
        genre: str,
        theme: str,
        max_length: int = 200,
        num_return_sequences: int = 1,
        temperature: float = 0.9,
        top_p: float = 0.9,
        top_k: int = 50,
    ) -> List[str]:
        """
        Generate lyrics based on genre and theme.

        Args:
            genre: The music genre to generate lyrics for (case- and
                surrounding-whitespace-insensitive). Unknown genres fall back
                to a generic prompt.
            theme: The theme or topic for the lyrics
            max_length: Maximum total length in tokens, prompt included
                (forwarded to the pipeline as ``max_length``)
            num_return_sequences: Number of different lyrics to generate
            temperature: Controls randomness (higher = more random)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter

        Returns:
            List of generated lyrics, one per returned sequence, with the
            prompt removed.

        Raises:
            ValueError: If prompt construction or generation fails; the
                underlying exception is chained as ``__cause__``.
        """
        try:
            # Get genre-specific prompt or use default
            base_prompt = self.genre_prompts.get(
                genre.strip().lower(),
                "Write song lyrics about",
            )
            prompt = f"{base_prompt} {theme}:\n\n"
            outputs = self.generator(
                prompt,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=self._resolve_pad_token_id(),
                # Let the pipeline strip the prompt at the token level rather
                # than slicing by character count, which breaks whenever the
                # decoded text does not reproduce the prompt exactly.
                return_full_text=False,
            )
            return [self._clean_text(output["generated_text"]) for output in outputs]
        except Exception as e:
            raise ValueError(f"Lyric generation failed: {str(e)}") from e

    def style_transfer(
        self,
        original_lyrics: str,
        target_genre: str,
        temperature: float = 0.9,
        max_new_tokens: int = 200,
    ) -> str:
        """
        Attempt to transfer the style of existing lyrics to a target genre.

        Args:
            original_lyrics: The original lyrics to restyle
            target_genre: The target genre for the style transfer
            temperature: Controls randomness of generation
            max_new_tokens: Maximum number of tokens to generate after the
                prompt

        Returns:
            Restyled lyrics in the target genre

        Raises:
            ValueError: If generation fails; the underlying exception is
                chained as ``__cause__``.
        """
        try:
            prompt = f"Rewrite these lyrics in {target_genre} style:\n\n{original_lyrics}\n\nNew version:\n"
            output = self.generator(
                prompt,
                # Budget new tokens directly: ``len(prompt)`` counts characters,
                # not tokens, so it cannot be used to size ``max_length``.
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=self._resolve_pad_token_id(),
                return_full_text=False,
            )[0]
            # If the model echoes the marker, keep only the text after the last one.
            new_lyrics = output["generated_text"].split("New version:\n")[-1]
            return self._clean_text(new_lyrics)
        except Exception as e:
            raise ValueError(f"Style transfer failed: {str(e)}") from e