from transformers import ClapModel, ClapProcessor
import gradio as gr
import torch
import torchaudio
import os
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
import dotenv
dotenv.load_dotenv()
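
# The clients below read their credentials from the environment; a .env file
# next to this script is assumed to provide them (variable names come from the
# code, values are placeholders):
#   HUGGINGFACE_API_TOKEN=hf_...
#   AZURE_SAS_TOKEN=...
#   QDRANT_URL=https://...
#   QDRANT_API_KEY=...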


class ClapSSGradio():

    def __init__(
        self,
        name,
        model="clap-2",
        k=10,
    ):
        self.name = name
        self.k = k

        # Load the CLAP checkpoint and its processor from the Hugging Face Hub
        self.model = ClapModel.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
        self.processor = ClapProcessor.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))

        # Azure Blob Storage credentials used to download the matched audio
        self.sas_token = os.environ['AZURE_SAS_TOKEN']
        self.account_name = 'Audiogen'
        self.storage_name = 'audiogentrainingdataeun'

        self._start_qdrant()

    def _start_qdrant(self):
        # Connect to the Qdrant instance that stores the precomputed embeddings
        self.client = QdrantClient(url=os.getenv(
            "QDRANT_URL"), api_key=os.getenv('QDRANT_API_KEY'))
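        # Collection creation is assumed to happen offline; a hypothetical
        # sketch of that setup with this client (CLAP projections are 512-d):
        #   self.client.recreate_collection(
        #       collection_name=self.name,
        #       vectors_config=VectorParams(size=512, distance=Distance.COSINE),
        #   )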

    @torch.no_grad()
    def _embed_query(self, query, audio_file):
        # An uploaded audio prompt takes precedence over the text query
        if audio_file is not None:
            waveform, sample_rate = torchaudio.load(audio_file.name)
            print("Waveform shape:", waveform.shape)

            # CLAP expects 48 kHz input
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, 48000)
            print("Resampled waveform shape:", waveform.shape)

            # Pad or trim to exactly 10 s (480000 samples at 48 kHz)
            if waveform.shape[-1] < 480000:
                waveform = torch.nn.functional.pad(
                    waveform, (0, 480000 - waveform.shape[-1]))
            elif waveform.shape[-1] > 480000:
                waveform = waveform[..., :480000]

            # Downmix to mono before feature extraction
            audio_prompt_features = self.processor(
                audios=waveform.mean(0), return_tensors='pt', sampling_rate=48000
            )['input_features']
            print("Audio prompt features shape:", audio_prompt_features.shape)

            e = self.model.get_audio_features(
                input_features=audio_prompt_features)[0]
            if torch.isnan(e).any():
                raise ValueError("Audio features are NaN")
            print("Embeddings: ", e.shape)
            # Return a plain list so it can be passed to Qdrant as a query vector
            return e.cpu().numpy().tolist()
        else:
            inputs = self.processor(
                query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
            return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]

    def _similarity_search(self, query, threshold, audio_file):
        results = self.client.search(
            collection_name=self.name,
            query_vector=self._embed_query(query, audio_file),
            limit=self.k,
            score_threshold=threshold,
        )

        containers = [result.payload['container'] for result in results]
        filenames = [result.id for result in results]
        captions = [result.payload['caption'] for result in results]
        scores = [result.score for result in results]

        # Log the matches to stdout
        print(f"\nQuery: {query}\n")
        for i, (container, filename, caption, score) in enumerate(zip(containers, filenames, captions, scores)):
            print(f"{i}: {container} - {caption}. Score: {score}")

        waveforms = self._download_results(containers, filenames)

        if len(waveforms) == 0:
            print("\nNo results found")

        # Pad the output with silence so every gr.Audio component gets a value
        if len(waveforms) < self.k:
            waveforms.extend([(48000, np.zeros((480000, 2)))
                              for _ in range(self.k - len(waveforms))])

        return waveforms

    def _download_results(self, containers: list, filenames: list):
        # Build SAS-signed blob URLs for each matched file
        urls = [
            f"https://{self.storage_name}.blob.core.windows.net/snake/{file_name}.flac?{self.sas_token}"
            for file_name in filenames
        ]

        # Load each file and convert it to the (sample_rate, samples) pair
        # that gr.Audio expects, with samples shaped (frames, channels)
        waveforms = []
        for url in urls:
            waveform, sample_rate = torchaudio.load(url)
            waveforms.append((sample_rate, waveform.numpy().T))
        return waveforms

    def launch(self, share=False):
        # Gradio app structure: query inputs on the left, k result players on the right
        with gr.Blocks(title='Clap Semantic Search') as ui:
            with gr.Row():
                with gr.Column(variant='panel'):
                    search = gr.Textbox(placeholder='Search Samples')
                    float_input = gr.Number(
                        label='Similarity threshold [min: 0.1 max: 1]', value=0.5, minimum=0.1, maximum=1)
                    audio_file = gr.File(
                        label='Upload an Audio File', type="file")
                    search_button = gr.Button("Search")
                with gr.Column():
                    audioboxes = []
                    gr.Markdown("Output")
                    for i in range(self.k):
                        t = gr.components.Audio(label=f"{i}", visible=True)
                        audioboxes.append(t)

            search_button.click(fn=self._similarity_search, inputs=[
                                search, float_input, audio_file], outputs=audioboxes)

        ui.launch(share=share)
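
# Hypothetical programmatic usage, bypassing the Gradio UI: embed a text query
# and fetch the k nearest clips directly, e.g.
#   app = ClapSSGradio("demo", k=5)
#   clips = app._similarity_search("dog barking", threshold=0.5, audio_file=None)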

if __name__ == "__main__":
    app = ClapSSGradio("demo")
    app.launch(share=False)