"""Gradio app: detect music, classify its genre, and generate matching lyrics."""
import os
import io
import gradio as gr
import torch
import numpy as np
from transformers import (
AutoModelForAudioClassification,
AutoFeatureExtractor,
AutoTokenizer,
pipeline,
AutoModelForCausalLM,
BitsAndBytesConfig
)
from huggingface_hub import login
from utils import (
load_audio,
extract_audio_duration,
extract_mfcc_features,
calculate_lyrics_length,
format_genre_results,
ensure_cuda_availability,
preprocess_audio_for_model
)
# Login to Hugging Face Hub if token is provided
# (required for gated models such as Llama 3.1 — TODO confirm token scope)
if "HF_TOKEN" in os.environ:
    login(token=os.environ["HF_TOKEN"])

# Constants: Hugging Face Hub model identifiers used throughout the app.
GENRE_MODEL_NAME = "dima806/music_genres_classification"          # genre classifier
MUSIC_DETECTION_MODEL = "MIT/ast-finetuned-audioset-10-10-0.4593"  # AudioSet AST, used for music-vs-non-music detection
LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"                # lyrics generator
SAMPLE_RATE = 22050 # Standard sample rate for audio processing

# Check CUDA availability (for informational purposes)
CUDA_AVAILABLE = ensure_cuda_availability()
# Create music detection pipeline.
# Preferred path: the high-level HF pipeline; device=0 targets the first GPU,
# -1 runs on CPU. On failure, fall back to loading the feature extractor and
# model manually (detect_music() checks which of the two loaded).
print(f"Loading music detection model: {MUSIC_DETECTION_MODEL}")
try:
    music_detector = pipeline(
        "audio-classification",
        model=MUSIC_DETECTION_MODEL,
        device=0 if CUDA_AVAILABLE else -1
    )
    print("Successfully loaded music detection pipeline")
except Exception as e:
    print(f"Error creating music detection pipeline: {str(e)}")
    # Fallback to manual loading of the two components.
    try:
        music_processor = AutoFeatureExtractor.from_pretrained(MUSIC_DETECTION_MODEL)
        music_model = AutoModelForAudioClassification.from_pretrained(MUSIC_DETECTION_MODEL)
        print("Successfully loaded music detection model and feature extractor")
    except Exception as e2:
        print(f"Error loading music detection model components: {str(e2)}")
        # Music detection is mandatory — abort startup if neither path worked.
        raise RuntimeError(f"Could not load music detection model: {str(e2)}")
# Create genre classification pipeline.
# Same two-stage strategy as the music detector above: try the high-level
# pipeline first, fall back to manual component loading, and abort if both fail
# (classify_genre() checks which variant ended up in module globals).
print(f"Loading audio classification model: {GENRE_MODEL_NAME}")
try:
    genre_classifier = pipeline(
        "audio-classification",
        model=GENRE_MODEL_NAME,
        device=0 if CUDA_AVAILABLE else -1
    )
    print("Successfully loaded audio classification pipeline")
except Exception as e:
    print(f"Error creating pipeline: {str(e)}")
    # Fallback to manual loading of the two components.
    try:
        genre_processor = AutoFeatureExtractor.from_pretrained(GENRE_MODEL_NAME)
        genre_model = AutoModelForAudioClassification.from_pretrained(GENRE_MODEL_NAME)
        print("Successfully loaded audio classification model and feature extractor")
    except Exception as e2:
        print(f"Error loading model components: {str(e2)}")
        # Genre classification is mandatory — abort startup if neither path worked.
        raise RuntimeError(f"Could not load genre classification model: {str(e2)}")
# Load LLM with appropriate quantization for T4 GPU.
# NF4 4-bit weights with fp16 compute keep the 8B model within a T4's VRAM
# budget — presumably ~16 GB; verify against the deployment hardware.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    device_map="auto",  # let accelerate place layers across available devices
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)
# Create LLM pipeline used by generate_lyrics().
llm_pipeline = pipeline(
    "text-generation",
    model=llm_model,
    tokenizer=llm_tokenizer,
    max_new_tokens=512,  # upper bound on generated lyrics length
)
def extract_audio_features(audio_file):
    """Load an audio file and collect everything the later stages need.

    Returns a dict with the MFCC feature vector, the clip duration in
    seconds, the raw waveform plus its sample rate, and the original
    file path (the HF pipelines can consume a path directly).
    """
    waveform, sample_rate = load_audio(audio_file, SAMPLE_RATE)
    duration = extract_audio_duration(waveform, sample_rate)
    # MFCCs for genre classification (may not be needed with the pipeline).
    mfcc_vector = extract_mfcc_features(waveform, sample_rate, n_mfcc=20)
    return {
        "features": mfcc_vector,
        "duration": duration,
        "waveform": waveform,
        "sample_rate": sample_rate,
        "path": audio_file,  # kept so pipelines can re-read the file
    }
def classify_genre(audio_data):
    """Classify the genre of the audio using whichever model loaded.

    Parameters
    ----------
    audio_data : dict
        Output of extract_audio_features; the "path", "waveform" and
        "sample_rate" keys are used here.

    Returns
    -------
    list[tuple[str, float]]
        Up to three (genre_label, confidence) pairs, best first.
        Falls back to [("rock", 1.0)] when no model is available or
        inference fails, so the UI always has something to display.
    """
    try:
        # First attempt: the high-level pipeline, if it loaded at import time.
        if 'genre_classifier' in globals():
            results = genre_classifier(audio_data["path"])
            # Transform pipeline results to our expected format.
            return [(result["label"], result["score"]) for result in results[:3]]
        # Second attempt: manually loaded model + feature extractor.
        elif 'genre_processor' in globals() and 'genre_model' in globals():
            inputs = genre_processor(
                audio_data["waveform"],
                sampling_rate=audio_data["sample_rate"],
                return_tensors="pt"
            )
            with torch.no_grad():
                outputs = genre_model(**inputs)
                predictions = outputs.logits.softmax(dim=-1)
            # Top 3 class probabilities and their label indices.
            values, indices = torch.topk(predictions, 3)
            genre_labels = genre_model.config.id2label
            return [
                (genre_labels[index.item()], value.item())
                for value, index in zip(values[0], indices[0])
            ]
        else:
            raise ValueError("No genre classification model available")
    except Exception as e:
        print(f"Error in genre classification: {str(e)}")
        # Fallback: return a default genre if everything fails.
        return [("rock", 1.0)]
def generate_lyrics(genre, duration):
    """Generate lyrics for *genre*, sized to the audio duration.

    Parameters
    ----------
    genre : str
        Primary genre label from classify_genre.
    duration : float
        Audio duration in seconds; drives the target line count.

    Returns
    -------
    str
        Generated lyrics, with [Verse]/[Chorus]/[Bridge] section labels
        inserted when the model did not produce its own.
    """
    # Calculate appropriate lyrics length based on audio duration.
    lines_count = calculate_lyrics_length(duration)
    # Map the line budget onto a simple song structure.
    if lines_count <= 6:
        # Very short song - one verse and chorus
        verse_lines = 2
        chorus_lines = 2
    elif lines_count <= 10:
        # Medium song - two verses and chorus
        verse_lines = 3
        chorus_lines = 2
    else:
        # Longer song - two verses, chorus, and bridge
        verse_lines = 3
        chorus_lines = 2
    # Create prompt for the LLM.
    prompt = f"""
You are a talented songwriter who specializes in {genre} music.
Write original {genre} song lyrics for a song that is {duration:.1f} seconds long.
The lyrics should:
- Perfectly capture the essence and style of {genre} music
- Be approximately {lines_count} lines long
- Have a coherent theme and flow
- Follow this structure:
* Verse: {verse_lines} lines
* Chorus: {chorus_lines} lines
* {'Bridge: 2 lines' if lines_count > 10 else ''}
- Be completely original
- Match the song duration of {duration:.1f} seconds
- Keep each line concise and impactful
Your lyrics:
"""
    # Generate lyrics using the LLM (sampling parameters add variety while
    # repetition_penalty discourages verbatim loops).
    response = llm_pipeline(
        prompt,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        return_full_text=False
    )
    # Extract and clean generated lyrics.
    lyrics = response[0]["generated_text"].strip()
    # If the model produced no section labels, insert them at the
    # boundaries implied by the structure requested in the prompt.
    if "Verse" not in lyrics and "Chorus" not in lyrics:
        lines = lyrics.split('\n')
        formatted_lyrics = []
        for i, line in enumerate(lines):
            if i == 0:
                formatted_lyrics.append("[Verse]")
            elif i == verse_lines:
                formatted_lyrics.append("\n[Chorus]")
            elif i == verse_lines + chorus_lines and lines_count > 10:
                formatted_lyrics.append("\n[Bridge]")
            formatted_lyrics.append(line)
        lyrics = '\n'.join(formatted_lyrics)
    return lyrics
def detect_music(audio_data):
    """Return True when the audio appears to contain music.

    Uses the MIT AST AudioSet classifier: predictions whose label
    mentions any music-related term contribute to a music confidence,
    which must reach 0.5 for the clip to count as music.

    Returns False (rather than raising) when no detector is available
    or inference fails, so the caller can show a friendly message.
    """
    # Labels containing any of these substrings count as music-related.
    music_terms = ("music", "song", "singing", "instrument")
    try:
        # First attempt: the high-level pipeline, if it loaded at import time.
        if 'music_detector' in globals():
            results = music_detector(audio_data["path"])
            # Keep the best score among music-related labels.
            music_confidence = 0.0
            for result in results:
                label = result["label"].lower()
                if any(term in label for term in music_terms):
                    music_confidence = max(music_confidence, result["score"])
            return music_confidence >= 0.5
        # Second attempt: manually loaded model + feature extractor.
        elif 'music_processor' in globals() and 'music_model' in globals():
            inputs = music_processor(
                audio_data["waveform"],
                sampling_rate=audio_data["sample_rate"],
                return_tensors="pt"
            )
            with torch.no_grad():
                outputs = music_model(**inputs)
                predictions = outputs.logits.softmax(dim=-1)
            # Inspect the five most probable AudioSet classes.
            values, indices = torch.topk(predictions, 5)
            labels = music_model.config.id2label
            music_confidence = 0.0
            for value, index in zip(values[0], indices[0]):
                label = labels[index.item()].lower()
                if any(term in label for term in music_terms):
                    music_confidence = max(music_confidence, value.item())
            return music_confidence >= 0.5
        else:
            raise ValueError("No music detection model available")
    except Exception as e:
        print(f"Error in music detection: {str(e)}")
        return False
def process_audio(audio_file):
    """End-to-end handler: detect music, classify genre, write lyrics.

    Returns a (genre_text, lyrics) tuple feeding the two Gradio output
    boxes; the second element is None whenever processing stops early.
    """
    if audio_file is None:
        return "Please upload an audio file.", None
    try:
        audio_data = extract_audio_features(audio_file)
        # Refuse non-music uploads before doing any heavier work.
        if not detect_music(audio_data):
            return "The uploaded audio does not appear to be music. Please upload a music file.", None
        top_genres = classify_genre(audio_data)
        genre_results = format_genre_results(top_genres)
        # Lyrics are driven by the single best genre and the clip length.
        primary_genre, _ = top_genres[0]
        return genre_results, generate_lyrics(primary_genre, audio_data["duration"])
    except Exception as e:
        return f"Error processing audio: {str(e)}", None
# Create Gradio interface: audio upload + button on the left,
# genre and lyrics text boxes on the right.
with gr.Blocks(title="Music Genre Classifier & Lyrics Generator") as demo:
    gr.Markdown("# Music Genre Classifier & Lyrics Generator")
    gr.Markdown("Upload a music file to classify its genre and generate matching lyrics.")
    with gr.Row():
        with gr.Column():
            # Input column: file picker (returns a filesystem path) and trigger button.
            audio_input = gr.Audio(label="Upload Music", type="filepath")
            submit_btn = gr.Button("Analyze & Generate")
        with gr.Column():
            # Output column: results filled in by process_audio.
            genre_output = gr.Textbox(label="Detected Genres", lines=5)
            lyrics_output = gr.Textbox(label="Generated Lyrics", lines=15)
    # Wire the button to the end-to-end handler.
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[genre_output, lyrics_output]
    )
    gr.Markdown("### How it works")
    gr.Markdown("""
1. Upload an audio file of your choice
2. The system will classify the genre using the dima806/music_genres_classification model
3. Based on the detected genre, it will generate appropriate lyrics using Llama-3.1-8B-Instruct
4. The lyrics length is automatically adjusted based on your audio duration
""")
# Launch the app
demo.launch()