swimmiing's picture
Add packages.txt
8b37f0a
raw
history blame
3.35 kB
import functools

import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image
# Side length (pixels) of the square image fed to the model; also passed to
# the model call as the requested heatmap resolution.
_IMG_SIZE = 352

# CLIP's per-channel RGB normalisation statistics.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the ACL model on CPU, load its checkpoint, and cache the result.

    Rebuilding the model and reloading the checkpoint on every request is
    wasteful, so this runs once per process and later calls reuse it.
    """
    device = torch.device('cpu')
    model = ACL('./config/model/ACL_ViT16.yaml', device)
    model.train(False)
    model.load('./pretrain/Param_best.pth')
    return model


def _prepare_audio(audio, desired_sample_rate=16000, set_length=10):
    """Turn a Gradio ``(sample_rate, samples)`` audio tuple into model input.

    Args:
        audio: ``(sample_rate, samples)`` as delivered by ``gr.Audio()``.
            ``samples`` is assumed to be ``(n,)`` for mono or
            ``(n, channels)`` for multi-channel audio (Gradio's numpy
            layout -- confirm if the Gradio version changes).
        desired_sample_rate: Sample rate the model expects, in Hz.
        set_length: Clip length the model expects, in seconds.

    Returns:
        A float32 tensor of shape ``(1, desired_sample_rate * set_length)``:
        resampled, channels concatenated end to end, then truncated or
        zero-padded to the fixed length.
    """
    sample_rate, samples = audio
    samples = np.asarray(samples)

    # Scale integer PCM to roughly [-1, 1); for int16 this is exactly x / 32768.
    # Float input is assumed to be normalised already and is left unscaled.
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        offset = (int(info.min) + int(info.max) + 1) / 2
        half_range = (int(info.max) - int(info.min) + 1) / 2
        samples = (samples.astype(np.float64) - offset) / half_range
    samples = samples.astype(np.float32)

    # Channels first, so resampling (which operates on the last axis) runs
    # along time rather than across channels.
    if samples.ndim == 2:
        samples = samples.T
    waveform = torch.from_numpy(np.ascontiguousarray(samples))

    if sample_rate != desired_sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, desired_sample_rate)

    # Multi-channel -> mono by concatenating the channels end to end
    # (stereo -> mono with x2 duration); a no-op for mono input.
    waveform = waveform.reshape(-1)

    target_len = desired_sample_rate * set_length
    waveform = waveform[:target_len]
    if waveform.shape[0] < target_len:
        padding = waveform.new_zeros(target_len - waveform.shape[0])
        waveform = torch.cat((waveform, padding), dim=0)
    return waveform.unsqueeze(0)


def _prepare_image(image):
    """Resize an RGB PIL image to the model resolution, CLIP-normalise it,
    and add a leading batch dimension."""
    transform = vt.Compose([
        vt.Resize((_IMG_SIZE, _IMG_SIZE), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize(_CLIP_MEAN, _CLIP_STD),  # CLIP
    ])
    return transform(image).unsqueeze(0)


def _render_overlay(image, heatmap):
    """Blend the first heatmap in ``heatmap`` 50/50 over ``image``.

    Args:
        image: RGB PIL image at its original resolution.
        heatmap: the model's ``'heatmap'`` output tensor; the first entry is
            used and squeezed to 2-D (values presumably in [0, 1] -- they are
            clipped before the uint8 conversion so stray values cannot wrap).

    Returns:
        A uint8 numpy array with the same height/width as ``image``.
    """
    seg = heatmap[0:1].squeeze().detach().cpu().numpy()
    seg_image = Image.fromarray(np.clip((1 - seg) * 255, 0, 255).astype(np.uint8))
    # The heatmap comes out at model resolution; match the input image so
    # cv2.addWeighted receives two arrays of identical shape.
    seg_image = seg_image.resize(image.size, Image.BICUBIC)
    # NOTE(review): cv2.applyColorMap returns BGR while `image` is RGB. Together
    # with the (1 - seg) inversion above this approximately reproduces a
    # standard JET rendering (high activation -> red). Do not change either
    # part on its own.
    # NOTE(review): `cv2` is not imported explicitly in this file; presumably it
    # arrives via `from modules.models import *` -- confirm.
    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    return cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)


def greet(image, audio):
    """Localise the sound source of ``audio`` in ``image`` and return an overlay.

    Args:
        image: PIL image from ``gr.Image(type='pil')``.
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio()``.

    Returns:
        A uint8 numpy array: the input image blended with a JET heatmap.

    Raises:
        gr.Error: if either input is missing.
    """
    if image is None or audio is None:
        raise gr.Error('Please provide both an image and an audio clip.')
    # Guarantee 3 channels: RGBA/greyscale input would break the 3-channel
    # Normalize and the final blend.
    image = image.convert('RGB')

    model = _load_model()
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    audio_file = _prepare_audio(audio)
    image_file = _prepare_image(image)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens,
                                                    text_pos_at_prompt, prompt_length)
        out_dict = model(image_file.to(model.device), audio_driven_embedding, _IMG_SIZE)

    return _render_overlay(image, out_dict['heatmap'])
# Text shown by the Gradio interface: page title, HTML description above the
# inputs, and an HTML footer with the paper and repository links.
title = "Zero-shot sound source localization with ACL"
# NOTE(review): this mentions "examples below", but the gr.Interface in this
# file configures no `examples=`; add some or reword.
description = """<p>This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?'
To use it, simply upload an image and corresponding audio to mask (identify in the image) or use one of
the examples below and click 'submit'. Results will show up in a few seconds.<br></p>"""
article = """<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source
Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official Github Repository</a></p>"""
# Assemble the Gradio UI: a PIL image plus an audio clip go in, the
# heatmap overlay produced by `greet` comes out. Launching happens at import
# time, exactly as before.
_interface_kwargs = dict(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
)
demo = gr.Interface(**_interface_kwargs)
demo.launch(debug=True)