import cv2
import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image


def greet(image, audio):
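    """Localize the sound source from `audio` in `image` and return the image overlaid with a heatmap."""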
    device = torch.device('cpu')

    # Get model
    model_conf_file = './config/model/ACL_ViT16.yaml'
    model = ACL(model_conf_file, device)
    model.train(False)  # Inference mode
    model.load('./pretrain/Param_best.pth')

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre-processing
    sample_rate, audio = audio
    audio = audio.astype(np.float32, order='C') / 32768.0  # int16 PCM -> float in [-1, 1]
    desired_sample_rate = 16000
    set_length = 10  # seconds

    audio_file = torch.from_numpy(audio)

    if audio_file.ndim == 2:  # Gradio delivers stereo as (samples, channels); make it channels-first
        audio_file = audio_file.T

    if desired_sample_rate != sample_rate:
        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)

    if audio_file.ndim == 2 and audio_file.shape[0] == 2:
        audio_file = torch.concat([audio_file[0], audio_file[1]], dim=0)  # Stereo -> mono (x2 duration)

    audio_file = audio_file.squeeze(0)

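    # Truncate to at most `set_length` seconds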
    if audio_file.shape[0] > (desired_sample_rate * set_length):
        audio_file = audio_file[:desired_sample_rate * set_length]

    # zero padding
    if audio_file.shape[0] < (desired_sample_rate * set_length):
        pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
        pad_val = torch.zeros(pad_len)
        audio_file = torch.cat((audio_file, pad_val), dim=0)

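    # Add batch dimension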
    audio_file = audio_file.unsqueeze(0)

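    # CLIP-style image preprocessing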
    image_transform = vt.Compose([
        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
    ])

    image = image.convert('RGB')  # Uploads may be RGBA or grayscale; the model expects 3 channels
    image_file = image_transform(image).unsqueeze(0)

    # Inference
    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
    audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens, text_pos_at_prompt,
                                                prompt_length)

    # Localization result
    out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)
    seg = out_dict['heatmap'][0:1]
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    seg_image = Image.fromarray(seg_image).resize(image.size)  # Upsample heatmap to the input image size
    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    heatmap_image = cv2.cvtColor(heatmap_image, cv2.COLOR_BGR2RGB)  # applyColorMap returns BGR; Gradio expects RGB
    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)

    return overlaid_image


title = "Zero-shot sound source localization with ACL"
description = """<p>This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?'.
                To use it, upload an image and its corresponding audio to localize (mask) the sound source in the image,
                or use one of the examples below, then click 'Submit'. Results will show up in a few seconds.<br></p>"""
article = """<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source
            Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official Github Repository</a></p>"""

demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
)

if __name__ == '__main__':
    demo.launch(debug=True)