import os
import io
import gradio as gr
import torch
import numpy as np
from transformers import (
    AutoModelForAudioClassification,
    AutoFeatureExtractor,
    AutoTokenizer,
    pipeline,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from huggingface_hub import login
from utils import (
    load_audio,
    extract_audio_duration,
    extract_mfcc_features,
    calculate_lyrics_length,
    format_genre_results,
    ensure_cuda_availability,
    preprocess_audio_for_model
)
# Log in to the Hugging Face Hub if a token is provided (required for gated
# models such as meta-llama/Llama-3.1-8B-Instruct)
if "HF_TOKEN" in os.environ:
    login(token=os.environ["HF_TOKEN"])
# Constants
GENRE_MODEL_NAME = "dima806/music_genres_classification"
MUSIC_DETECTION_MODEL = "MIT/ast-finetuned-audioset-10-10-0.4593"
LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
SAMPLE_RATE = 22050  # Standard sample rate for audio processing

# Check CUDA availability (for informational purposes)
CUDA_AVAILABLE = ensure_cuda_availability()
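# `ensure_cuda_availability` lives in utils.py (not shown here); presumably it
# reduces to a sketch like the following:
#
#     def ensure_cuda_availability() -> bool:
#         available = torch.cuda.is_available()
#         print(f"CUDA available: {available}")
#         return available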
# Create music detection pipeline
print(f"Loading music detection model: {MUSIC_DETECTION_MODEL}")
try:
    music_detector = pipeline(
        "audio-classification",
        model=MUSIC_DETECTION_MODEL,
        device=0 if CUDA_AVAILABLE else -1
    )
    print("Successfully loaded music detection pipeline")
except Exception as e:
    print(f"Error creating music detection pipeline: {str(e)}")
    # Fallback to manual loading
    try:
        music_processor = AutoFeatureExtractor.from_pretrained(MUSIC_DETECTION_MODEL)
        music_model = AutoModelForAudioClassification.from_pretrained(MUSIC_DETECTION_MODEL)
        print("Successfully loaded music detection model and feature extractor")
    except Exception as e2:
        print(f"Error loading music detection model components: {str(e2)}")
        raise RuntimeError(f"Could not load music detection model: {str(e2)}")
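# Note on the fallback design: if the pipeline call above fails, `music_detector`
# is never defined and only `music_processor`/`music_model` exist. detect_music()
# further down checks globals() for each name, so it automatically uses whichever
# loading path succeeded. The genre classifier below follows the same pattern.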
# Create genre classification pipeline
print(f"Loading audio classification model: {GENRE_MODEL_NAME}")
try:
    genre_classifier = pipeline(
        "audio-classification",
        model=GENRE_MODEL_NAME,
        device=0 if CUDA_AVAILABLE else -1
    )
    print("Successfully loaded audio classification pipeline")
except Exception as e:
    print(f"Error creating pipeline: {str(e)}")
    # Fallback to manual loading
    try:
        genre_processor = AutoFeatureExtractor.from_pretrained(GENRE_MODEL_NAME)
        genre_model = AutoModelForAudioClassification.from_pretrained(GENRE_MODEL_NAME)
        print("Successfully loaded audio classification model and feature extractor")
    except Exception as e2:
        print(f"Error loading model components: {str(e2)}")
        raise RuntimeError(f"Could not load genre classification model: {str(e2)}")
# Load the LLM with quantization appropriate for a T4 GPU
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
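# Rough memory math (an estimate, not measured): 8B parameters at 4 bits is
# about 4 GB of weights plus activation and KV-cache overhead, versus roughly
# 16 GB at fp16 — which is why NF4 quantization lets this model fit on a 16 GB T4.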
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)
# Create LLM pipeline
llm_pipeline = pipeline(
    "text-generation",
    model=llm_model,
    tokenizer=llm_tokenizer,
    max_new_tokens=512,
)
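# Illustrative call (output values are made up): with return_full_text=False,
# as set per call in generate_lyrics() below, the pipeline returns only the
# completion, e.g.
#     llm_pipeline("Write one line of a blues song:", return_full_text=False)
#     # -> [{"generated_text": "Woke up this morning with rain against my door"}]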
def extract_audio_features(audio_file):
    """Extract audio features from an audio file."""
    # Load the audio file using the utility function
    y, sr = load_audio(audio_file, SAMPLE_RATE)
    # Get audio duration in seconds
    duration = extract_audio_duration(y, sr)
    # Extract MFCCs for genre classification (may not be needed with the pipeline)
    mfccs_mean = extract_mfcc_features(y, sr, n_mfcc=20)
    return {
        "features": mfccs_mean,
        "duration": duration,
        "waveform": y,
        "sample_rate": sr,
        "path": audio_file  # Keep the path for the pipelines
    }
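# Illustrative return value for a ~3-minute clip (values are made up):
#     {"features": <20-dim MFCC mean vector>, "duration": 183.4,
#      "waveform": <np.ndarray>, "sample_rate": 22050, "path": "song.wav"}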
def classify_genre(audio_data):
    """Classify the genre of the audio using the loaded model."""
    try:
        # First attempt: use the pipeline if available
        if 'genre_classifier' in globals():
            results = genre_classifier(audio_data["path"])
            # Transform pipeline results to our expected format
            top_genres = [(result["label"], result["score"]) for result in results[:3]]
            return top_genres
        # Second attempt: use manually loaded model components
        elif 'genre_processor' in globals() and 'genre_model' in globals():
            # Process audio input with the feature extractor
            inputs = genre_processor(
                audio_data["waveform"],
                sampling_rate=audio_data["sample_rate"],
                return_tensors="pt"
            )
            with torch.no_grad():
                outputs = genre_model(**inputs)
                predictions = outputs.logits.softmax(dim=-1)
            # Get the top 3 genres and map their indices to labels
            values, indices = torch.topk(predictions, 3)
            genre_labels = genre_model.config.id2label
            top_genres = [
                (genre_labels[index.item()], value.item())
                for value, index in zip(values[0], indices[0])
            ]
            return top_genres
        else:
            raise ValueError("No genre classification model available")
    except Exception as e:
        print(f"Error in genre classification: {str(e)}")
        # Fallback: return a default genre if everything fails
        return [("rock", 1.0)]
def generate_lyrics(genre, duration):
    """Generate lyrics matching the genre, with length appropriate to the duration."""
    # Calculate appropriate lyrics length based on audio duration
    lines_count = calculate_lyrics_length(duration)

    # Pick per-section line counts from the total line budget
    if lines_count <= 6:
        # Very short song: 2-line verse and 2-line chorus
        verse_lines = 2
        chorus_lines = 2
    elif lines_count <= 10:
        # Medium song: 3-line verse and 2-line chorus
        verse_lines = 3
        chorus_lines = 2
    else:
        # Longer song: 3-line verse, 2-line chorus, and a 2-line bridge
        verse_lines = 3
        chorus_lines = 2
    # Only request a bridge for longer songs (avoids an empty bullet in the prompt)
    bridge_line = "\n    * Bridge: 2 lines" if lines_count > 10 else ""

    # Create prompt for the LLM
    prompt = f"""
You are a talented songwriter who specializes in {genre} music.
Write original {genre} song lyrics for a song that is {duration:.1f} seconds long.

The lyrics should:
- Perfectly capture the essence and style of {genre} music
- Be approximately {lines_count} lines long
- Have a coherent theme and flow
- Follow this structure:
    * Verse: {verse_lines} lines
    * Chorus: {chorus_lines} lines{bridge_line}
- Be completely original
- Match the song duration of {duration:.1f} seconds
- Keep each line concise and impactful

Your lyrics:
"""

    # Generate lyrics using the LLM
    response = llm_pipeline(
        prompt,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        return_full_text=False
    )

    # Extract and clean the generated lyrics
    lyrics = response[0]["generated_text"].strip()

    # Add section labels if the model did not include them
    if "Verse" not in lyrics and "Chorus" not in lyrics:
        lines = lyrics.split('\n')
        formatted_lyrics = []
        for i, line in enumerate(lines):
            if i == 0:
                formatted_lyrics.append("[Verse]")
            elif i == verse_lines:
                formatted_lyrics.append("\n[Chorus]")
            elif i == verse_lines + chorus_lines and lines_count > 10:
                formatted_lyrics.append("\n[Bridge]")
            formatted_lyrics.append(line)
        lyrics = '\n'.join(formatted_lyrics)

    return lyrics
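# Possible refinement (a sketch, not what runs above): Llama-3.1-8B-Instruct is
# a chat model, so wrapping the prompt in the tokenizer's chat template may
# improve how closely it follows the requested structure:
#
#     messages = [{"role": "user", "content": prompt}]
#     chat_prompt = llm_tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )
#     response = llm_pipeline(chat_prompt, do_sample=True, temperature=0.7,
#                             top_p=0.9, return_full_text=False)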
def detect_music(audio_data):
    """Detect whether the audio is music using the MIT AST model."""
    try:
        # First attempt: use the pipeline if available
        if 'music_detector' in globals():
            results = music_detector(audio_data["path"])
            # Look for music-related classes in the results
            music_confidence = 0.0
            for result in results:
                label = result["label"].lower()
                if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
                    music_confidence = max(music_confidence, result["score"])
            return music_confidence >= 0.5
        # Second attempt: use manually loaded model components
        elif 'music_processor' in globals() and 'music_model' in globals():
            # Process audio input with the feature extractor
            inputs = music_processor(
                audio_data["waveform"],
                sampling_rate=audio_data["sample_rate"],
                return_tensors="pt"
            )
            with torch.no_grad():
                outputs = music_model(**inputs)
                predictions = outputs.logits.softmax(dim=-1)
            # Get the top predictions and map their indices to labels
            values, indices = torch.topk(predictions, 5)
            labels = music_model.config.id2label
            # Check for music-related classes
            music_confidence = 0.0
            for value, index in zip(values[0], indices[0]):
                label = labels[index.item()].lower()
                if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
                    music_confidence = max(music_confidence, value.item())
            return music_confidence >= 0.5
        else:
            raise ValueError("No music detection model available")
    except Exception as e:
        print(f"Error in music detection: {str(e)}")
        return False
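# Caveats (assumptions worth verifying): the AST checkpoint is an AudioSet
# classifier trained multi-label, so sigmoid scores may be more appropriate than
# the softmax used in the manual fallback above — softmax over ~527 classes
# tends to depress individual scores. The 0.5 threshold is a heuristic; lower
# it if genuine music clips are being rejected.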
def process_audio(audio_file):
    """Main entry point: process the audio file, classify genre, and generate lyrics."""
    if audio_file is None:
        return "Please upload an audio file.", None
    try:
        # Extract audio features
        audio_data = extract_audio_features(audio_file)

        # First check whether the audio is music at all
        is_music = detect_music(audio_data)
        if not is_music:
            return "The uploaded audio does not appear to be music. Please upload a music file.", None

        # Classify genre
        top_genres = classify_genre(audio_data)

        # Format genre results using the utility function
        genre_results = format_genre_results(top_genres)

        # Generate lyrics based on the top genre
        primary_genre, _ = top_genres[0]
        lyrics = generate_lyrics(primary_genre, audio_data["duration"])

        return genre_results, lyrics
    except Exception as e:
        return f"Error processing audio: {str(e)}", None
# Create Gradio interface
with gr.Blocks(title="Music Genre Classifier & Lyrics Generator") as demo:
    gr.Markdown("# Music Genre Classifier & Lyrics Generator")
    gr.Markdown("Upload a music file to classify its genre and generate matching lyrics.")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Music", type="filepath")
            submit_btn = gr.Button("Analyze & Generate")
        with gr.Column():
            genre_output = gr.Textbox(label="Detected Genres", lines=5)
            lyrics_output = gr.Textbox(label="Generated Lyrics", lines=15)

    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[genre_output, lyrics_output]
    )

    gr.Markdown("### How it works")
    gr.Markdown("""
    1. Upload an audio file of your choice
    2. The system will classify the genre using the dima806/music_genres_classification model
    3. Based on the detected genre, it will generate appropriate lyrics using Llama-3.1-8B-Instruct
    4. The lyrics length is automatically adjusted based on your audio duration
    """)

# Launch the app
demo.launch()