# Commit 1a25f7e (unverified) by theOnlyJaco: "Override container for demo"
from transformers import ClapModel, ClapProcessor, AutoFeatureExtractor
import gradio as gr
import torch
import torchaudio
import os
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http import models
import dotenv
# Load variables from a .env file into the process environment so that the
# os.getenv / os.environ lookups in ClapSSGradio (HUGGINGFACE_API_TOKEN,
# AZURE_SAS_TOKEN, QDRANT_URL, QDRANT_API_KEY) can be satisfied from it.
dotenv.load_dotenv()
class ClapSSGradio:
    """Gradio demo for semantic audio search with CLAP embeddings.

    A text query, or an uploaded audio file, is embedded with a CLAP model.
    The nearest neighbours are looked up in a Qdrant collection, and the
    matching FLAC files are downloaded from Azure Blob Storage and shown in
    ``k`` audio players.

    Args:
        name: Name of the Qdrant collection to search.
        model: Model repository under the ``Audiogen`` Hugging Face namespace,
            loaded as ``f"Audiogen/{model}"``.
        k: Number of results requested from Qdrant (and audio players shown).

    Environment:
        HUGGINGFACE_API_TOKEN, QDRANT_URL, QDRANT_API_KEY (read with
        ``os.getenv``), and AZURE_SAS_TOKEN (required; ``KeyError`` if unset).
    """

    # Sample rate the audio prompt is resampled to before feature extraction.
    SAMPLE_RATE = 48000
    # Audio prompts are padded/truncated to exactly this many samples
    # (10 s at SAMPLE_RATE). Also the length of the silent placeholders.
    MAX_SAMPLES = 480000

    def __init__(
        self,
        name,
        model="clap-2",
        k=10,
    ):
        self.name = name
        self.k = k
        hf_token = os.getenv('HUGGINGFACE_API_TOKEN')
        self.model = ClapModel.from_pretrained(
            f"Audiogen/{model}", use_auth_token=hf_token)
        self.processor = ClapProcessor.from_pretrained(
            f"Audiogen/{model}", use_auth_token=hf_token)
        self.sas_token = os.environ['AZURE_SAS_TOKEN']
        # Not read anywhere in this class; kept for external callers.
        self.account_name = 'Audiogen'
        # Azure storage account that hosts the audio files.
        self.storage_name = 'audiogentrainingdataeun'
        self._start_qdrant()

    def _start_qdrant(self):
        """Create the Qdrant client from QDRANT_URL and QDRANT_API_KEY."""
        self.client = QdrantClient(url=os.getenv(
            "QDRANT_URL"), api_key=os.getenv('QDRANT_API_KEY'))

    @staticmethod
    def _fit_length(waveform, num_samples=MAX_SAMPLES):
        """Zero-pad on the right or truncate the last axis to ``num_samples``.

        The pad amount is relative to the target length. An earlier version
        padded by ``48000 - length``, which is negative for 1-10 s clips and
        made ``F.pad`` crop them to 1 s instead of padding them to 10 s.
        """
        length = waveform.shape[-1]
        if length < num_samples:
            return torch.nn.functional.pad(waveform, (0, num_samples - length))
        return waveform[..., :num_samples]

    @torch.no_grad()
    def _embed_audio(self, audio_file):
        """Return the CLAP audio embedding of an uploaded file as a list of floats.

        Raises:
            ValueError: if the model produces NaN features.
        """
        waveform, sample_rate = torchaudio.load(audio_file.name)
        print("Waveform shape:", waveform.shape)
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, self.SAMPLE_RATE)
        print("Resampled waveform shape:", waveform.shape)
        waveform = self._fit_length(waveform, self.MAX_SAMPLES)
        # Average over axis 0 (channels in torchaudio's default channels-first
        # layout) to get a single mono signal for the processor.
        audio_prompt_features = self.processor(
            audios=waveform.mean(0), return_tensors='pt',
            sampling_rate=self.SAMPLE_RATE,
        )['input_features']
        print("Audio prompt features shape:", audio_prompt_features.shape)
        embedding = self.model.get_audio_features(
            input_features=audio_prompt_features)[0]
        if torch.isnan(embedding).any():
            raise ValueError("Audio features are NaN")
        print("Embeddings: ", embedding.shape)
        # Same return type as _embed_text so Qdrant always gets a plain list.
        return embedding.cpu().numpy().tolist()

    @torch.no_grad()
    def _embed_text(self, query):
        """Return the CLAP text embedding of ``query`` as a list of floats."""
        inputs = self.processor(
            query, return_tensors="pt", padding='max_length', max_length=77,
            truncation=True)
        return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]

    @torch.no_grad()
    def _embed_query(self, query, audio_file):
        """Embed the uploaded audio if one was given, otherwise the text query."""
        if audio_file is not None:
            return self._embed_audio(audio_file)
        return self._embed_text(query)

    def _similarity_search(self, query, threshold, audio_file):
        """Gradio callback: search Qdrant and return exactly ``k`` audio values.

        Args:
            query: Text query (ignored when ``audio_file`` is given).
            threshold: Passed to Qdrant as ``score_threshold``.
            audio_file: Uploaded file object with a ``.name`` path, or None.

        Returns:
            A list of ``k`` ``(sample_rate, ndarray)`` tuples. Slots without a
            result are filled with silence, because Gradio needs one value per
            output component.
        """
        results = self.client.search(
            collection_name=self.name,
            query_vector=self._embed_query(query, audio_file),
            limit=self.k,
            score_threshold=threshold,
        )
        containers = [result.payload['container'] for result in results]
        filenames = [result.id for result in results]
        captions = [result.payload['caption'] for result in results]
        scores = [result.score for result in results]
        # print to stdout
        print(f"\nQuery: {query}\n")
        for i, (container, filename, caption, score) in enumerate(
                zip(containers, filenames, captions, scores)):
            print(f"{i}: {container} - {caption}. Score: {score}")
        waveforms = self._download_results(containers, filenames)
        if not waveforms:
            print("\nNo results found")
        missing = self.k - len(waveforms)
        if missing > 0:
            waveforms.extend(
                (self.SAMPLE_RATE, np.zeros((self.MAX_SAMPLES, 2)))
                for _ in range(missing))
        return waveforms

    def _download_results(self, containers: list, filenames: list):
        """Download each result's FLAC file from Azure Blob Storage.

        NOTE(review): ``containers`` is currently ignored. Every file is read
        from the hard-coded "snake" container (a demo override).

        Returns:
            A list of ``(sample_rate, waveform.numpy().T)`` tuples, in the
            same order as ``filenames``.
        """
        waveforms = []
        for file_name in filenames:
            url = (f"https://{self.storage_name}.blob.core.windows.net/snake/"
                   f"{file_name}.flac?{self.sas_token}")
            waveform, sample_rate = torchaudio.load(url)
            waveforms.append((sample_rate, waveform.numpy().T))
        return waveforms

    def launch(self, share=False):
        """Build the Gradio UI and start serving it.

        Args:
            share: Forwarded to ``ui.launch(share=...)``.
        """
        with gr.Blocks(title='Clap Semantic Search') as ui:
            with gr.Row():
                with gr.Column(variant='panel'):
                    search = gr.Textbox(placeholder='Search Samples')
                    float_input = gr.Number(
                        label='Similarity threshold [min: 0.1 max: 1]',
                        value=0.5, minimum=0.1, maximum=1)
                    audio_file = gr.File(
                        label='Upload an Audio File', type="file")
                    search_button = gr.Button("Search", label='Search')
                with gr.Column():
                    gr.Markdown("Output")
                    # One audio player per result slot; _similarity_search
                    # always returns exactly self.k values to fill them.
                    audioboxes = [
                        gr.components.Audio(label=f"{i}", visible=True)
                        for i in range(self.k)
                    ]
            search_button.click(
                fn=self._similarity_search,
                inputs=[search, float_input, audio_file],
                outputs=audioboxes)
        ui.launch(share=share)
if __name__ == "__main__":
    # Serve the demo against the "demo" Qdrant collection without a public share link.
    ClapSSGradio("demo").launch(share=False)