from transformers import ClapModel, ClapProcessor
import gradio as gr
import torch
import torchaudio
import os
import numpy as np
from qdrant_client import QdrantClient


import dotenv
dotenv.load_dotenv()
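
# Required environment variables (used below): HUGGINGFACE_API_TOKEN,
# AZURE_SAS_TOKEN, QDRANT_URL, QDRANT_API_KEY.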


class ClapSSGradio:
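    """Gradio UI for CLAP-based semantic search over a Qdrant collection."""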

    def __init__(
        self,
        name,
        model = "clap-2",
        k=10,
    ):

        self.name = name
        self.k = k

        self.model = ClapModel.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
        self.processor = ClapProcessor.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))

        self.sas_token = os.environ['AZURE_SAS_TOKEN']
        self.account_name = 'Audiogen'
        self.storage_name = 'audiogentrainingdataeun'

        self._start_qdrant()

    def _start_qdrant(self):
        self.client = QdrantClient(url=os.getenv(
            "QDRANT_URL"), api_key=os.getenv('QDRANT_API_KEY'))
        # print(self.client.get_collection(collection_name=self.name))

    @torch.no_grad()
    def _embed_query(self, query, audio_file):
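        """Embed either the uploaded audio file (if given) or the text query
        with CLAP and return the embedding as a vector."""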
        if audio_file is not None:
            waveform, sample_rate = torchaudio.load(audio_file.name)
            print("Waveform shape:", waveform.shape)
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, 48000)
            print("Resampled waveform shape:", waveform.shape)

            # Pad or trim to exactly 10 s at 48 kHz (480,000 samples).
            if waveform.shape[-1] < 480000:
                waveform = torch.nn.functional.pad(
                    waveform, (0, 480000 - waveform.shape[-1]))
            elif waveform.shape[-1] > 480000:
                waveform = waveform[..., :480000]

            audio_prompt_features = self.processor(
                audios=waveform.mean(0), return_tensors='pt', sampling_rate=48000
            )['input_features']
            print("Audio prompt features shape:", audio_prompt_features.shape)
            e = self.model.get_audio_features(
                input_features=audio_prompt_features)[0]

            if torch.isnan(e).any():
                raise ValueError("Audio features are NaN")
            print("Embeddings: ", e.shape)
            # Return a plain list, matching the text branch below, so the
            # Qdrant client accepts the vector.
            return e.cpu().numpy().tolist()
        else:
            inputs = self.processor(
                query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)

            return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]

    def _similarity_search(self, query, threshold, audio_file):
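        """Query Qdrant with the embedded prompt and return up to k
        (sample_rate, waveform) tuples for the Gradio Audio outputs."""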
        results = self.client.search(
            collection_name=self.name,
            query_vector=self._embed_query(query, audio_file),
            limit=self.k,
            score_threshold=threshold,
        )

        containers = [result.payload['container'] for result in results]
        filenames = [result.id for result in results]
        captions = [result.payload['caption'] for result in results]
        scores = [result.score for result in results]

        # print to stdout
        print(f"\nQuery: {query}\n")
        for i, (container, filename, caption, score) in enumerate(zip(containers, filenames, captions, scores)):
            print(f"{i}: {container} - {caption}. Score: {score}")

        waveforms = self._download_results(containers, filenames)

        if len(waveforms) == 0:
            print("\nNo results found")

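        # Pad with silence so every one of the k output Audio components
        # receives a value.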
        if len(waveforms) < self.k:
            waveforms.extend([(int(48000), np.zeros((480000, 2)))
                             for _ in range(self.k - len(waveforms))])

        return waveforms

    def _download_results(self, containers: list, filenames: list):
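        """Stream each matching file from Azure Blob Storage and return
        (sample_rate, channels-last waveform) tuples for Gradio."""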
        # construct url
        urls = [f"https://{self.storage_name}.blob.core.windows.net/snake/{file_name}.flac?{self.sas_token}" for file_name in filenames]

        # make requests
        waveforms = []
        for url in urls:
            waveform, sample_rate = torchaudio.load(url)
            waveforms.append((sample_rate, waveform.numpy().T))

        return waveforms

    def launch(self, share=False):
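        """Build the Gradio Blocks interface and launch it."""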
        # gradio app structure
        with gr.Blocks(title='Clap Semantic Search') as ui:
            with gr.Row():
                with gr.Column(variant='panel'):
                    search = gr.Textbox(placeholder='Search Samples')
                    float_input = gr.Number(
                        label='Similarity threshold [min: 0.1 max: 1]', value=0.5, minimum=0.1, maximum=1)
                    audio_file = gr.File(
                        label='Upload an Audio File', type="file")
                    search_button = gr.Button("Search", label='Search')
                with gr.Column():
                    audioboxes = []
                    gr.Markdown("Output")
                    for i in range(self.k):
                        t = gr.components.Audio(label=f"{i}", visible=True)
                        audioboxes.append(t)
            search_button.click(fn=self._similarity_search, inputs=[
                search, float_input, audio_file], outputs=audioboxes)
            ui.launch(share=share)


if __name__ == "__main__":
    app = ClapSSGradio("demo")
    app.launch(share=False)