import spaces  # ZeroGPU helper; conventionally imported before torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import soundfile as sf
import numpy as np
import os
import tempfile
# Model and tokenizer loading
MODEL_ID = "Qwen/Qwen-Audio-Chat"

_model = None
_tokenizer = None

def load_model():
    """Load Qwen-Audio-Chat once and cache it across requests."""
    global _model, _tokenizer
    if _model is None:
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        ).eval()
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        # No hand-rolled chat template here: a Llama-style "<s>[INST] <<SYS>>"
        # template does not match Qwen's format, and Qwen-Audio-Chat ships its
        # own chat interface (model.chat, used below).
    return _model, _tokenizer
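# For reference, the Qwen-Audio-Chat model card documents this minimal usage
# pattern; the audio path below is the card's example asset, not a file in
# this Space:
#
#     query = tokenizer.from_list_format([
#         {"audio": "assets/audio/1272-128104-0000.flac"},
#         {"text": "what does the person say?"},
#     ])
#     response, history = model.chat(tokenizer, query=query, history=None)
#
# analyze_audio() below follows the same pattern.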
def process_audio(audio_path):
    """Validate and normalize the input: read it with soundfile, downmix to
    mono float32, and write it as a WAV to a temporary file. Returns the temp
    file path, or None if the file cannot be read as audio."""
    try:
        audio_data, sample_rate = sf.read(audio_path)
        if audio_data.ndim > 1:
            # Downmix multi-channel audio to mono.
            audio_data = audio_data.mean(axis=1)
        audio_data = audio_data.astype(np.float32)
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        sf.write(tmp.name, audio_data, sample_rate, format="WAV")
        return tmp.name
    except Exception:
        return None
@spaces.GPU  # ZeroGPU Spaces only allocate a GPU inside functions marked with this decorator
def analyze_audio(audio_path: str, question: str = None) -> str:
    if not audio_path or not isinstance(audio_path, str):
        return "Please provide a valid audio file."
    if not os.path.exists(audio_path):
        return f"Audio file not found: {audio_path}"
    wav_path = process_audio(audio_path)
    if wav_path is None:
        return "Failed to process the audio file. Please ensure it's a valid audio format."
    try:
        model, tokenizer = load_model()
        query_text = question.strip() if question and question.strip() else (
            "Please describe what you hear in this audio clip."
        )
        # Embed the audio in the query with the tokenizer's list format, then
        # let model.chat() apply Qwen-Audio-Chat's own chat template and
        # generation settings.
        query = tokenizer.from_list_format([
            {"audio": wav_path},
            {"text": query_text},
        ])
        with torch.no_grad():
            response, _history = model.chat(tokenizer, query=query, history=None)
        if not response:
            return "The model failed to generate a response. Please try again."
        return response
    except Exception as e:
        return f"An error occurred while processing: {e}"
demo = gr.Interface(
    fn=analyze_audio,
    inputs=[
        gr.Audio(
            type="filepath",
            label="Audio Input",
            sources=["upload", "microphone"],
            format="wav",  # WAV is readable by soundfile regardless of libsndfile's mp3 support
        ),
        gr.Textbox(
            label="Question",
            placeholder="Optional: ask a specific question about the audio",
            value="",
        ),
    ],
    outputs=gr.Textbox(label="Analysis"),
    title="Qwen Audio Analysis Tool",
    description="Upload an audio file or record from the microphone to get AI-powered analysis from the Qwen-Audio-Chat model.",
    examples=[
        ["example1.wav", "What instruments do you hear?"]
    ],
    cache_examples=False,
)
if __name__ == "__main__":
    demo.launch()
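# A minimal sketch of calling the running app programmatically with the
# gradio_client package; the URL assumes demo.launch() on its default local
# port, "sample.wav" is a hypothetical test file, and "/predict" is the
# default endpoint name for a gr.Interface:
#
#     from gradio_client import Client, handle_file
#     client = Client("http://127.0.0.1:7860")
#     result = client.predict(
#         handle_file("sample.wav"),
#         "What instruments do you hear?",
#         api_name="/predict",
#     )
#     print(result)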