from threading import Thread

import gradio as gr
import librosa
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer


def split_audio(audio_array, chunk_limit=480000):
    # Split the waveform into chunks of at most `chunk_limit` samples
    # (480000 samples at 16 kHz = 30 s, the chunk length the model expects).
    audio_splits = []
    for i in range(0, len(audio_array), chunk_limit):
        audio_splits.append(audio_array[i : i + chunk_limit])
    return audio_splits
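
# A quick sanity check of the chunking, as a sketch (not wired into the app;
# assumes numpy is available): at 16 kHz, a 70 s clip yields two full 30 s
# chunks plus a 10 s remainder.
#
#   import numpy as np
#   chunks = split_audio(np.zeros(70 * 16000))
#   assert [len(c) for c in chunks] == [480000, 480000, 160000]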


def user(audio, text, chat_history):
    # Echo the user's turn into the chat: the audio clip (if any) as a
    # playable message, then the text. The textbox is cleared on return.
    if audio is not None:
        chat_history.append(gr.ChatMessage(role="user", content={"path": audio, "alt_text": "Audio"}))
    chat_history.append({"role": "user", "content": text})
    return "", chat_history
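
# After `user` runs with both inputs set, the history ends with two entries
# shaped like the following (values are hypothetical, for illustration only):
#
#   [gr.ChatMessage(role="user", content={"path": "clip.wav", "alt_text": "Audio"}),
#    {"role": "user", "content": "Transcribe this"}]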


@spaces.GPU
def process_audio(audio, text, chat_history):
    # By the time this handler runs, `user` has already cleared the textbox,
    # so the `text` argument is empty; recover the typed text from the last
    # chat message instead.
    text = chat_history[-1]["content"]

    conversation = [{"role": "user", "content": []}]
    audio_chunks = []
    if audio is not None:
        # Load at 16 kHz mono, the sampling rate the processor expects.
        audio_array = librosa.load(audio, sr=16000)[0]
        audio_chunks = split_audio(audio_array)
        # One audio placeholder per 30 s chunk; the processor pairs each
        # placeholder with the matching entry in `audios` below.
        for _ in audio_chunks:
            conversation[0]["content"].append({"type": "audio_url", "audio": "placeholder"})
    conversation[0]["content"].append({"type": "text", "text": text})

    # Stream decoded tokens as they are produced instead of waiting for the
    # full completion.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(
        text=prompt,
        audios=audio_chunks if audio_chunks else None,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    )
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    generation_args = {
        "max_new_tokens": 4096,
        "streamer": streamer,
        "eos_token_id": 151645,  # <|im_end|> in the Qwen2 tokenizer
        "pad_token_id": 151643,  # <|endoftext|>
        **inputs,
    }

    # Run generation on a worker thread so this generator can yield partial
    # replies while tokens are still being produced.
    chat_history.append({"role": "assistant", "content": ""})
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()
    for new_text in streamer:
        chat_history[-1]["content"] += new_text
        yield chat_history
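
# The streaming pattern above, in isolation (a minimal sketch; `inputs` is
# assumed to be a prepared batch already on the right device):
#
#   streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True)
#   Thread(target=model.generate, kwargs={"streamer": streamer, **inputs}).start()
#   for piece in streamer:  # blocks until the next decoded span is ready
#       print(piece, end="", flush=True)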


with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Aero-1-Audio")
    gr.Markdown(
        """
Aero-1-Audio is a compact audio model. With only 1.5B parameters and 50k hours of training data, it can perform a variety of tasks, including ASR, basic audio understanding, audio instruction following, and scene analysis.

We provide several examples, such as:

- an NVIDIA conference talk and an Elon Musk show for long-form ASR
- simple audio instruction following
- audio understanding for weather and music

The model may fail to follow your instructions in some cases and may often be wrong.
"""
    )
    chatbot = gr.Chatbot(type="messages")
    with gr.Row(variant="compact", equal_height=True):
        audio_input = gr.Audio(label="Speak Here", type="filepath")
        text_input = gr.Textbox(label="Text Input", placeholder="Type here", interactive=True)
    with gr.Row():
        chatbot_clear = gr.ClearButton([text_input, audio_input, chatbot], value="Clear")
        chatbot_submit = gr.Button("Submit", variant="primary")

    chatbot_submit.click(
        user,
        inputs=[audio_input, text_input, chatbot],
        outputs=[text_input, chatbot],
        queue=False,
    ).then(
        process_audio,
        inputs=[audio_input, text_input, chatbot],
        outputs=[chatbot],
    )
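
    # Note: `.then` fires after `user` has written its outputs, so
    # `process_audio` sees an already-cleared `text_input`; that is why it
    # recovers the user text from the chat history rather than from its
    # `text` argument.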
    gr.Examples(
        [
            ["Please transcribe the audio for me", "./examples/elon_musk.mp3"],
            ["Please transcribe the audio for me", "./examples/nvidia_conference.mp3"],
            ["Please follow the instruction in the audio", "./examples/audio_instruction.wav"],
            ["What is the primary instrument featured in the solo of this track?", "./examples/music_under.wav"],
            ["What weather condition can be heard in the audio?", "./examples/audio_understand.wav"],
        ],
        inputs=[text_input, audio_input],
        label="Examples",
    )


if __name__ == "__main__":
    processor = AutoProcessor.from_pretrained("lmms-lab/Aero-1-Audio-1.5B", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "lmms-lab/Aero-1-Audio-1.5B",
        device_map="cuda",
        torch_dtype="auto",
        attn_implementation="sdpa",
        trust_remote_code=True,
    )
    demo.launch()