# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import yaml
# import spaces
import gradio as gr
import librosa
from pydub import AudioSegment
import soundfile as sf
import numpy as np
import torch
import laion_clap

from inference_utils import prepare_tokenizer, prepare_model, inference
from data import AudioTextDataProcessor

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


# @spaces.GPU
def load_laionclap():
    # LAION-CLAP model used to rank generated answers by audio-text similarity.
    model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').to(device)
    model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
    model.eval()
    return model


# Helpers to round-trip the waveform through int16 before computing CLAP embeddings.
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


def load_audio(file_path, target_sr=44100, duration=33.25, start=0.0):
    # Load up to `duration` seconds of audio as a mono float32 waveform in [-1, 1],
    # resampled to `target_sr`.
    if file_path.endswith('.mp3'):
        audio = AudioSegment.from_file(file_path)
        if len(audio) > (start + duration) * 1000:
            audio = audio[start * 1000:(start + duration) * 1000]
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
        if audio.channels > 1:
            audio = audio.set_channels(1)
        data = np.array(audio.get_array_of_samples())
        if audio.sample_width == 2:
            data = data.astype(np.float32) / np.iinfo(np.int16).max
        elif audio.sample_width == 4:
            data = data.astype(np.float32) / np.iinfo(np.int32).max
        else:
            raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
    else:
        with sf.SoundFile(file_path) as audio:
            original_sr = audio.samplerate
            channels = audio.channels
            max_frames = int((start + duration) * original_sr)
            audio.seek(int(start * original_sr))
            frames_to_read = min(max_frames, len(audio))
            data = audio.read(frames_to_read)
            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))

        if original_sr != target_sr:
            if channels == 1:
                data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
            else:
                data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
        else:
            if channels != 1:
                data = data.T[0]

    # Normalize to [-1, 1].
    if data.min() >= 0:
        data = 2 * data / abs(data.max()) - 1.0
    else:
        data = data / max(abs(data.max()), abs(data.min()))
    return data


# @spaces.GPU
@torch.no_grad()
def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
    # Cosine similarity between the CLAP audio embedding and each candidate answer.
    try:
        data = load_audio(audio_file, target_sr=48000)
    except Exception as e:
        print(audio_file, 'unsuccessful due to', e)
        return [0.0] * len(outputs)

    audio_data = data.reshape(1, -1)
    audio_data_tensor = torch.from_numpy(
        int16_to_float32(float32_to_int16(audio_data))
    ).float().to(device)

    audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
    text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)
    return cos_similarity.squeeze().cpu().numpy()


inference_kwargs = {
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "num_return_sequences": 20
}

config = yaml.load(open('chat.yaml'), Loader=yaml.FullLoader)
clap_config = config['clap_config']
model_config = config['model_config']

text_tokenizer = prepare_tokenizer(model_config)
DataProcessor = AudioTextDataProcessor(
    data_root='./',
    clap_config=clap_config,
    tokenizer=text_tokenizer,
    max_tokens=512,
)

laionclap_model = load_laionclap()

model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path='chat.pt',
    device=device
)


# @spaces.GPU
def inference_item(name, prompt):
    # Sample multiple candidate answers and return the one that best matches
    # the audio according to LAION-CLAP.
    item = {
        'name': str(name),
        'prefix': 'The task is dialog.',
        'prompt': str(prompt)
    }
    processed_item = DataProcessor.process(item)
    outputs = inference(
        model, text_tokenizer, item, processed_item,
        inference_kwargs,
        device=device
    )
    laionclap_scores = compute_laionclap_text_audio_sim(
        item["name"],
        laionclap_model,
        outputs
    )
    outputs_joint = [
        (output, score) for output, score in zip(outputs, laionclap_scores)
    ]
    outputs_joint.sort(key=lambda x: -x[1])
    return outputs_joint[0][0]
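
# A minimal sketch of querying the model without the web UI (assumption: the demo
# assets referenced in this script, i.e. `chat.yaml`, `chat.pt`, and `audio/wav1.wav`,
# are present in the working directory). Uncomment to run a one-off query:
#
# print(inference_item('audio/wav1.wav',
#                      'Can you briefly describe what you hear in this audio?'))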

css = """
        a { color: inherit; text-decoration: underline; }
        .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
        .gr-button { color: white; border-color: #000000; background: #000000; }
        input[type='range'] { accent-color: #000000; }
        .dark input[type='range'] { accent-color: #dfdfdf; }
        .container { max-width: 730px; margin: auto; padding-top: 1.5rem; }
        #gallery {
            min-height: 22rem;
            margin-bottom: 15px;
            margin-left: auto;
            margin-right: auto;
            border-bottom-right-radius: .5rem !important;
            border-bottom-left-radius: .5rem !important;
        }
        #gallery>div>.h-full { min-height: 20rem; }
        .details:hover { text-decoration: underline; }
        .gr-button { white-space: nowrap; }
        .gr-button:focus {
            border-color: rgb(147 197 253 / var(--tw-border-opacity));
            outline: none;
            box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
            --tw-border-opacity: 1;
            --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
            --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
            --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
            --tw-ring-opacity: .5;
        }
        #advanced-btn {
            font-size: .7rem !important;
            line-height: 19px;
            margin-top: 12px;
            margin-bottom: 12px;
            padding: 2px 8px;
            border-radius: 14px !important;
        }
        #advanced-options { margin-bottom: 20px; }
        .footer {
            margin-bottom: 45px;
            margin-top: 35px;
            text-align: center;
            border-bottom: 1px solid #e5e5e5;
        }
        .footer>p {
            font-size: .8rem;
            display: inline-block;
            padding: 0 10px;
            transform: translateY(10px);
            background: white;
        }
        .dark .footer { border-color: #303030; }
        .dark .footer>p { background: #0b0f19; }
        .acknowledgments h4 {
            margin: 1.25em 0 .25em 0;
            font-weight: bold;
            font-size: 115%;
        }
        #container-advanced-btns {
            display: flex;
            flex-wrap: wrap;
            justify-content: space-between;
            align-items: center;
        }
        .animate-spin { animation: spin 1s linear infinite; }
        @keyframes spin {
            from { transform: rotate(0deg); }
            to { transform: rotate(360deg); }
        }
        #share-btn-container {
            display: flex;
            padding-left: 0.5rem !important;
            padding-right: 0.5rem !important;
            background-color: #000000;
            justify-content: center;
            align-items: center;
            border-radius: 9999px !important;
            width: 13rem;
            margin-top: 10px;
            margin-left: auto;
        }
        #share-btn {
            all: initial;
            color: #ffffff;
            font-weight: 600;
            cursor: pointer;
            font-family: 'IBM Plex Sans', sans-serif;
            margin-left: 0.5rem !important;
            padding-top: 0.25rem !important;
            padding-bottom: 0.25rem !important;
            right: 0;
        }
        #share-btn * { all: unset; }
        #share-btn-container div:nth-child(-n+2) {
            width: auto !important;
            min-height: 0px !important;
        }
        #share-btn-container .wrap { display: none !important; }
        .gr-form {
            flex: 1 1 50%;
            border-top-right-radius: 0;
            border-bottom-right-radius: 0;
        }
        #prompt-container { gap: 0; }
        #generated_id { min-height: 700px }
        #setting_id {
            margin-bottom: 12px;
            text-align: center;
            font-weight: 900;
        }
"""

ui = gr.Blocks(css=css, title="Audio Flamingo - Demo")

with ui:
    gr.HTML(
        """

        <h1>Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities</h1>

        <p>[Paper] [Code] [Demo Website] [Demo Video]</p>

""" ) gr.HTML( """

        <h2>Overview</h2>

        <p>
        Audio Flamingo is an audio language model that can understand sounds beyond speech.
        It can also answer questions about the sound in natural language.
        </p>
        <p>Examples of questions include:</p>
        <ul>
            <li>Can you briefly describe what you hear in this audio?</li>
            <li>What is the emotion conveyed in this music?</li>
            <li>Where is this audio usually heard?</li>
            <li>What place is this music usually played at?</li>
        </ul>
""" ) name = gr.Textbox( label="Audio file path (choose one from: audio/wav{1--6}.wav)", value="audio/wav1.wav" ) prompt = gr.Textbox( label="Instruction", value='Can you briefly describe what you hear in this audio?' ) with gr.Row(): play_audio_button = gr.Button("Play Audio") audio_output = gr.Audio(label="Playback") play_audio_button.click(fn=lambda x: x, inputs=name, outputs=audio_output) inference_button = gr.Button("Inference") output_text = gr.Textbox(label="Audio Flamingo output") inference_button.click( fn=inference_item, inputs=[name, prompt], outputs=output_text ) ui.queue() ui.launch()