# Qwen2-Audio-7B / app.py
# desiree's picture
# Update app.py
# 21a98db verified
import base64
import functools
import logging
import os
from io import BytesIO

import gradio as gr
import numpy as np
import soundfile as sf
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Model and tokenizer loading.
# Model identifier handed to both `from_pretrained` calls in load_model().
MODEL_ID = "Qwen/Qwen-Audio-Chat"
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the causal-LM checkpoint and its tokenizer, once per process.

    The result is memoised so that repeated calls (one per request from
    ``analyze_audio``) reuse the already-loaded objects instead of
    re-instantiating the whole model every time.

    Returns:
        tuple: ``(model, tokenizer)`` — the model is requested in float16
        with ``device_map="auto"``; the tokenizer has its ``chat_template``
        replaced by the template defined below.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Jinja template used by `tokenizer.apply_chat_template`: a fixed system
    # prompt followed by one "role: content" line per message.
    # NOTE(review): the [INST]/<<SYS>> markers look like a Llama-2-style
    # format; confirm this matches what the loaded checkpoint was trained on.
    chat_template = """<s>[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>
{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}[/INST]"""
    tokenizer.chat_template = chat_template
    return model, tokenizer
def process_audio(audio_path):
    """Read an audio file and return it as a base64-encoded mono WAV payload.

    Args:
        audio_path: Path (or file-like object) accepted by ``soundfile.read``.

    Returns:
        dict | None: ``{"audio": <base64 str>, "sampling_rate": <int>}`` on
        success, or ``None`` if anything goes wrong while reading or
        re-encoding. Failures are logged with their traceback rather than
        being dropped silently, so the cause is visible in the app logs.
    """
    try:
        samples, sample_rate = sf.read(audio_path)

        # Multi-channel input: average across the channel axis to get mono.
        if samples.ndim > 1:
            samples = samples.mean(axis=1)
        samples = samples.astype(np.float32)

        # Re-encode in memory as WAV, then base64 the raw bytes.
        with BytesIO() as wav_buffer:
            sf.write(wav_buffer, samples, sample_rate, format='WAV')
            encoded = base64.b64encode(wav_buffer.getvalue()).decode('utf-8')

        return {
            "audio": encoded,
            "sampling_rate": sample_rate,
        }
    except Exception:
        # Best-effort contract is kept (callers test for None), but the
        # reason for the failure is recorded.
        logging.getLogger(__name__).exception(
            "Failed to process audio file %r", audio_path
        )
        return None
@spaces.GPU
def analyze_audio(audio_path: str, question: str = None) -> str:
    """Answer a question about an audio file using the loaded chat model.

    Args:
        audio_path: Filesystem path to the audio clip.
        question: Optional question about the clip. When missing, empty or
            whitespace-only, a generic "describe what you hear" prompt is
            used instead.

    Returns:
        str: The model's reply, or a human-readable error message when the
        input is invalid, the file cannot be decoded, or inference fails.
        This function does not raise; failures during inference are logged.
    """
    if not isinstance(audio_path, str):
        return "Please provide a valid audio file."
    if not os.path.exists(audio_path):
        return f"Audio file not found: {audio_path}"

    # Decoding the file up front acts as a validity check so that unreadable
    # uploads get a specific message instead of a generic inference error.
    audio_data = process_audio(audio_path)
    if not audio_data or "audio" not in audio_data or "sampling_rate" not in audio_data:
        return "Failed to process the audio file. Please ensure it's a valid audio format."

    prompt = question.strip() if isinstance(question, str) else ""
    if not prompt:
        prompt = "Please describe what you hear in this audio clip."

    try:
        model, tokenizer = load_model()
        # The audio clip must be part of the query, otherwise the model only
        # ever sees the text question. This follows the usage shown on the
        # Qwen-Audio-Chat model card: the tokenizer builds a combined
        # audio+text query and `model.chat` returns just the reply text
        # (no echoed prompt) together with the updated history.
        # NOTE(review): sampling settings now come from the checkpoint's own
        # generation config — confirm they are acceptable for this demo.
        query = tokenizer.from_list_format([
            {"audio": audio_path},
            {"text": prompt},
        ])
        response, _history = model.chat(tokenizer, query=query, history=None)
    except Exception:
        # UI boundary: never propagate, but keep the traceback in the logs.
        logging.getLogger(__name__).exception(
            "Audio analysis failed for %r", audio_path
        )
        return "An error occurred while processing. Please check your inputs and try again."

    if not response:
        return "The model failed to generate a response. Please try again."
    return response
# Gradio UI wiring: two inputs (audio clip + optional question) feeding
# analyze_audio, one text output for the result.
_audio_input = gr.Audio(
    label="Audio Input",
    type="filepath",
    format="mp3",
    sources=["upload", "microphone"],
)
_question_input = gr.Textbox(
    label="Question",
    value="",
    placeholder="Optional: Ask a specific question about the audio",
)
_analysis_output = gr.Textbox(label="Analysis")

# One example row: [audio file, question], in the same order as the inputs.
_example_rows = [
    ["example1.wav", "What instruments do you hear?"],
]

demo = gr.Interface(
    title="Qwen Audio Analysis Tool",
    description="Upload an audio file or record from microphone to get AI-powered analysis using Qwen-Audio-Chat model",
    fn=analyze_audio,
    inputs=[_audio_input, _question_input],
    outputs=_analysis_output,
    examples=_example_rows,
    cache_examples=False,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()