fyp_start_space / app.py
jacob-c's picture
.
c9ff2a7
raw
history blame
7.76 kB
import requests
import gradio as gr
import os
import torch
import json
import time
import tempfile
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM
# Torch device: the GPU when torch reports CUDA support, the CPU otherwise.
# NOTE(review): not referenced by any code visible in this file.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hosted Hugging Face Inference API endpoints: audio tagging and text generation.
AUDIO_API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
LYRICS_API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"

# Auth header shared by both endpoints. It is built once at import time, so an
# unset HF_TOKEN yields the literal value "Bearer None".
headers = {"Authorization": "Bearer " + str(os.environ.get('HF_TOKEN'))}
def format_error(message):
    """Wrap ``message`` in a JSON-serialisable error payload ``{"error": message}``."""
    payload = dict(error=message)
    return payload
def create_lyrics_prompt(classification_results):
    """Create a lyrics-generation prompt from audio classification results.

    Args:
        classification_results: Non-empty sequence of dicts with ``'label'`` and
            ``'score'`` keys, best match first. ``'score'`` may be a percentage
            string such as ``"87.50%"`` (the form ``classify_and_generate``
            builds) or a numeric fraction in ``[0, 1]`` (the raw form that
            ``classify_and_generate`` multiplies by 100 before formatting).

    Returns:
        The prompt text. It ends with an open ``[Verse 1]`` header for the
        model to continue.

    Raises:
        ValueError: If ``classification_results`` is empty, or the top score
            cannot be parsed as a number.
    """
    if not classification_results:
        raise ValueError("classification_results must contain at least one entry")

    # The best-scoring label drives the overall style of the lyrics.
    top_result = classification_results[0]
    genre = top_result['label']
    raw_score = top_result['score']
    if isinstance(raw_score, str):
        # Percentage string, e.g. "87.50%" -> 0.875
        confidence = float(raw_score.strip().strip('%')) / 100
    else:
        # Already a numeric fraction.
        confidence = float(raw_score)

    # The next two labels (when present) are woven in as secondary elements.
    additional_elements = [r['label'] for r in classification_results[1:3]]
    secondary_list = ', '.join(additional_elements)
    secondary_phrase = ' and '.join(additional_elements)

    return (
        "Write creative and original song lyrics that capture the following musical elements:\n"
        f"Primary Style: {genre} ({confidence * 100:.1f}% confidence)\n"
        f"Secondary Elements: {secondary_list}\n"
        "Requirements:\n"
        f"1. Create lyrics that strongly reflect the {genre} style\n"
        f"2. Incorporate elements of {secondary_phrase}\n"
        "3. Include both verses and a chorus\n"
        "4. Match the mood and atmosphere typical of this genre\n"
        "5. Use appropriate musical terminology and style\n"
        "Lyrics:\n"
        "[Verse 1]\n"
    )
def generate_lyrics_with_retry(prompt, max_retries=5, initial_wait=2, timeout=120):
    """Generate lyrics using the hosted GPT-J-6B Inference API, with retry logic.

    A request is retried when the endpoint answers 503 (model still loading) or
    when a network / JSON-decoding error occurs. The wait between attempts starts
    at ``initial_wait`` seconds and grows by a factor of 1.5 each time. No sleep
    happens after the final attempt.

    Args:
        prompt: Text sent as the model input.
        max_retries: Maximum number of request attempts.
        initial_wait: Seconds to wait before the second attempt.
        timeout: Per-request timeout in seconds, forwarded to ``requests.post``.
            It stops a stalled connection from blocking the caller forever.

    Returns:
        The generated text, with blank lines and lines starting with ``###`` or
        ``` removed, on success. On API or network failure it returns a
        human-readable error string instead of raising.
    """
    wait_time = initial_wait
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(
                LYRICS_API_URL,
                headers=headers,
                json={
                    "inputs": prompt,
                    "parameters": {
                        "max_new_tokens": 250,
                        "temperature": 0.8,
                        "top_p": 0.92,
                        "do_sample": True,
                        "return_full_text": False,
                        "stop": ["[End]", "\n\n\n"]
                    }
                },
                timeout=timeout,
            )
            print(f"Response status: {response.status_code}")
            print(f"Response content: {response.content.decode('utf-8', errors='ignore')}")

            if response.status_code == 200:
                result = response.json()
                if isinstance(result, list) and result and isinstance(result[0], dict):
                    generated_text = result[0].get("generated_text") or ""
                    # Drop blank lines and markdown-ish headings/fences the model may emit.
                    stripped = (line.strip() for line in generated_text.split('\n'))
                    cleaned_lines = [
                        line for line in stripped
                        if line and not line.startswith(('###', '```'))
                    ]
                    return "\n".join(cleaned_lines)
                return "Error: No text generated"

            if response.status_code != 503:
                return f"Error generating lyrics: {response.text}"

            # 503: the model is still being loaded server-side. Retry unless out of attempts.
            if is_last_attempt:
                break
            print(f"Model loading, attempt {attempt + 1}/{max_retries}. Waiting {wait_time} seconds...")
        except (requests.RequestException, ValueError) as e:
            # Network failures and undecodable JSON bodies are worth retrying;
            # requests' JSONDecodeError subclasses ValueError.
            if is_last_attempt:
                return f"Error after {max_retries} attempts: {e}"

        time.sleep(wait_time)
        wait_time *= 1.5  # Back off more on each retry
    return "Failed to generate lyrics after multiple attempts. Please try again."
def format_results(classification_results, lyrics, prompt):
    """Render the ranked classification labels and the lyrics as one text report.

    ``prompt`` is accepted to keep the call signature stable. It is not
    included in the output.
    """
    ranked_lines = [
        f"{rank}. {entry['label']}: {entry['score']}\n"
        for rank, entry in enumerate(classification_results, start=1)
    ]
    classification_text = "Classification Results:\n" + "".join(ranked_lines)
    # Same layout as before: leading newline, blank lines around the separator,
    # trailing newline.
    return (
        "\n"
        + classification_text
        + "\n\n---Generated Lyrics---\n\n"
        + f"{lyrics}"
        + "\n"
    )
def classify_and_generate(audio_file):
    """
    Classify the audio and generate matching lyrics.

    Args:
        audio_file: Path to the uploaded audio file, or a tuple whose first
            element is that path. ``None`` when nothing was uploaded.

    Returns:
        A display string. On success it holds the formatted classification
        results plus the generated lyrics. Otherwise it holds a human-readable
        error message; exceptions are caught and reported in the returned text
        rather than raised.
    """
    if audio_file is None:
        return "Please upload an audio file."
    try:
        token = os.environ.get('HF_TOKEN')
        if not token:
            return "Error: HF_TOKEN environment variable is not set. Please set your Hugging Face API token."

        # Accept either a bare path or a (path, ...) tuple.
        audio_path = audio_file[0] if isinstance(audio_file, tuple) else audio_file

        # The raw bytes are posted unchanged. Read the upload directly instead of
        # copying it to a temporary file, which leaked when the copy/read failed.
        with open(audio_path, "rb") as f:
            data = f.read()

        print("Sending request to Audio Classification API...")
        # Timeout so a stalled request cannot hang the UI indefinitely.
        response = requests.post(AUDIO_API_URL, headers=headers, data=data, timeout=120)

        if response.status_code == 200:
            classification_results = response.json()
            if not isinstance(classification_results, list) or not classification_results:
                return f"Error: Unexpected response from Audio Classification API: {classification_results}"

            # Scores come back as fractions; show them as percentages.
            formatted_results = [
                {
                    'label': result['label'],
                    'score': f"{result['score']*100:.2f}%"
                }
                for result in classification_results
            ]

            # Generate lyrics based on classification with retry logic
            print("Generating lyrics based on classification...")
            prompt = create_lyrics_prompt(formatted_results)
            lyrics = generate_lyrics_with_retry(prompt)
            return format_results(formatted_results, lyrics, prompt)
        elif response.status_code == 401:
            return "Error: Invalid or missing API token. Please check your Hugging Face API token."
        elif response.status_code == 503:
            return "Error: Model is loading. Please try again in a few seconds."
        else:
            return f"Error: API returned status code {response.status_code}\nResponse: {response.text}"
    except Exception as e:
        # Top-level UI boundary: surface the failure in the output box instead of crashing.
        import traceback
        error_details = traceback.format_exc()
        return f"Error processing request: {str(e)}\nDetails:\n{error_details}"
# Gradio UI: one audio upload (passed to the handler as a file path) in,
# one 15-line text box out.
iface = gr.Interface(
    title="Music Genre Classifier + Lyric Generator",
    description="Upload an audio file to classify its genre and generate matching lyrics using AI.",
    fn=classify_and_generate,
    inputs=gr.Audio(label="Upload Audio File", type="filepath"),
    outputs=gr.Textbox(
        placeholder="Upload an audio file to see classification results and generated lyrics...",
        label="Results",
        lines=15,
    ),
    examples=[],
)

# Start the web server only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # 0.0.0.0 binds every network interface; 7860 is the port served on.
    iface.launch(server_port=7860, server_name="0.0.0.0")