fyp_start_space / src /lyric_generator.py
jacob-c's picture
.
8599ceb
raw
history blame
5.11 kB
from transformers import pipeline
import torch
from typing import Dict, List, Optional
class LyricGenerator:
    """Generate song lyrics with a Hugging Face text-generation pipeline.

    Prompts are an instruction prefix chosen by genre (see ``genre_prompts``)
    followed by a user-supplied theme.
    """

    # Instruction prefix used when the requested genre is not in genre_prompts.
    DEFAULT_PROMPT = "Write song lyrics about"

    # Number of new tokens style_transfer lets the model generate.
    STYLE_TRANSFER_NEW_TOKENS = 200

    # GPT-2's <|endoftext|> id; used only when the pipeline's tokenizer
    # exposes neither a pad nor an eos token id.
    _FALLBACK_PAD_TOKEN_ID = 50256

    def __init__(self, model_name: str = "gpt2-medium"):
        """
        Initialize the lyric generator with a specified language model.

        When CUDA is available the pipeline is first built with
        ``device_map="auto"``; if that fails a warning is printed and the model
        is loaded on the CPU instead. Without CUDA the model is loaded on the
        CPU directly (no misleading GPU-failure warning).

        Args:
            model_name: The name of the pre-trained model to use
        """
        self.generator = None
        if torch.cuda.is_available():
            try:
                self.generator = pipeline(
                    "text-generation",
                    model=model_name,
                    device_map="auto",  # Let transformers handle device mapping
                )
            except Exception as e:
                print(f"Warning: GPU initialization failed, falling back to CPU. Error: {str(e)}")
        if self.generator is None:
            self.generator = pipeline(
                "text-generation",
                model=model_name,
                device="cpu",
            )
        # Genre-specific prompts to guide generation (keys are lowercase).
        self.genre_prompts = {
            "rock": "Write energetic rock lyrics about",
            "pop": "Create catchy pop lyrics about",
            "hip hop": "Write hip hop verses about",
            "country": "Write country music lyrics about",
            "jazz": "Compose smooth jazz lyrics about",
            "classical": "Write classical music lyrics about",
            "electronic": "Create electronic dance music lyrics about",
            "blues": "Write soulful blues lyrics about",
            "reggae": "Write laid-back reggae lyrics about",
            "metal": "Write intense metal lyrics about",
        }

    def _pad_token_id(self) -> int:
        """Return the token id used to pad generated sequences.

        Prefers the pipeline tokenizer's pad token, then its eos token (GPT-2
        tokenizers typically define only the eos token). It falls back to
        GPT-2's ``<|endoftext|>`` id when no tokenizer id is available.
        """
        tokenizer = getattr(self.generator, "tokenizer", None)
        for attribute in ("pad_token_id", "eos_token_id"):
            token_id = getattr(tokenizer, attribute, None)
            if token_id is not None:
                return token_id
        return self._FALLBACK_PAD_TOKEN_ID

    def generate_lyrics(
        self,
        genre: str,
        theme: str,
        max_length: int = 200,
        num_return_sequences: int = 1,
        temperature: float = 0.9,
        top_p: float = 0.9,
        top_k: int = 50
    ) -> List[str]:
        """
        Generate lyrics based on genre and theme.

        Args:
            genre: The music genre to generate lyrics for (case- and
                surrounding-whitespace-insensitive; unknown genres use a
                generic prompt)
            theme: The theme or topic for the lyrics
            max_length: Maximum total length in tokens, prompt included
                (forwarded to the pipeline as ``max_length``)
            num_return_sequences: Number of different lyrics to generate
            temperature: Controls randomness (higher = more random)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter

        Returns:
            List of generated lyrics, one per returned sequence, with the
            prompt removed and surrounding whitespace stripped.

        Raises:
            ValueError: If prompt construction or generation fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            normalized_genre = genre.strip().lower()
            base_prompt = self.genre_prompts.get(normalized_genre, self.DEFAULT_PROMPT)
            prompt = f"{base_prompt} {theme}:\n\n"
            outputs = self.generator(
                prompt,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=self._pad_token_id(),
                # Ask for the continuation only: slicing the full text by
                # len(prompt) breaks when decoding does not reproduce the
                # prompt character-for-character.
                return_full_text=False,
            )
            return [
                output["generated_text"].replace("<|endoftext|>", "").strip()
                for output in outputs
            ]
        except Exception as e:
            raise ValueError(f"Lyric generation failed: {str(e)}") from e

    def style_transfer(
        self,
        original_lyrics: str,
        target_genre: str,
        temperature: float = 0.9
    ) -> str:
        """
        Attempt to transfer the style of existing lyrics to a target genre.

        Args:
            original_lyrics: The original lyrics to restyle
            target_genre: The target genre for the style transfer
            temperature: Controls randomness of generation

        Returns:
            Restyled lyrics in the target genre (at most
            ``STYLE_TRANSFER_NEW_TOKENS`` generated tokens).

        Raises:
            ValueError: If generation fails; the original exception is
                chained as ``__cause__``.
        """
        try:
            prompt = f"Rewrite these lyrics in {target_genre} style:\n\n{original_lyrics}\n\nNew version:\n"
            output = self.generator(
                prompt,
                # max_length is measured in tokens and includes the prompt, so
                # a character count cannot size it; bound the new tokens directly.
                # NOTE(review): very long original_lyrics can still exceed the
                # model's context window.
                max_new_tokens=self.STYLE_TRANSFER_NEW_TOKENS,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=self._pad_token_id(),
                return_full_text=False,
            )[0]
            continuation = output["generated_text"]
            # If the model echoes the "New version:" header, keep only the
            # text after its last occurrence.
            new_lyrics = continuation.split("New version:\n")[-1]
            return new_lyrics.replace("<|endoftext|>", "").strip()
        except Exception as e:
            raise ValueError(f"Style transfer failed: {str(e)}") from e