Spaces:
Runtime error
Runtime error
Add threshold field
Browse files
app.py
CHANGED
@@ -23,9 +23,6 @@ class ClapSSGradio():
|
|
23 |
self.name = name
|
24 |
self.k = k
|
25 |
|
26 |
-
print("Env?!")
|
27 |
-
print(os.getenv('HUGGINGFACE_API_TOKEN')[:2])
|
28 |
-
|
29 |
self.model = ClapModel.from_pretrained(
|
30 |
f"Audiogen/{name}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
|
31 |
self.tokenizer = ClapProcessor.from_pretrained(
|
@@ -48,12 +45,12 @@ class ClapSSGradio():
|
|
48 |
query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
|
49 |
return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]
|
50 |
|
51 |
-
def _similarity_search(self, query):
|
52 |
results = self.client.search(
|
53 |
collection_name=self.name,
|
54 |
query_vector=self._embed_query(query),
|
55 |
limit=self.k,
|
56 |
-
score_threshold=
|
57 |
)
|
58 |
|
59 |
containers = [result.payload['container'] for result in results]
|
@@ -94,21 +91,17 @@ class ClapSSGradio():
|
|
94 |
def launch(self, share=False):
|
95 |
# gradio app structure
|
96 |
with gr.Blocks(title='Clap Semantic Search') as ui:
|
97 |
-
|
98 |
with gr.Row():
|
99 |
with gr.Column(variant='panel'):
|
100 |
search = gr.Textbox(placeholder='Search Samples')
|
101 |
-
|
102 |
with gr.Column():
|
103 |
audioboxes = []
|
104 |
gr.Markdown("Output")
|
105 |
for i in range(self.k):
|
106 |
t = gr.components.Audio(label=f"{i}", visible=True)
|
107 |
audioboxes.append(t)
|
108 |
-
|
109 |
-
search.submit(fn=self._similarity_search, inputs=[
|
110 |
-
search], outputs=audioboxes)
|
111 |
-
|
112 |
ui.launch(share=share)
|
113 |
|
114 |
|
|
|
23 |
self.name = name
|
24 |
self.k = k
|
25 |
|
|
|
|
|
|
|
26 |
self.model = ClapModel.from_pretrained(
|
27 |
f"Audiogen/{name}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
|
28 |
self.tokenizer = ClapProcessor.from_pretrained(
|
|
|
45 |
query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
|
46 |
return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]
|
47 |
|
48 |
+
def _similarity_search(self, query, threshold):
|
49 |
results = self.client.search(
|
50 |
collection_name=self.name,
|
51 |
query_vector=self._embed_query(query),
|
52 |
limit=self.k,
|
53 |
+
score_threshold=threshold,
|
54 |
)
|
55 |
|
56 |
containers = [result.payload['container'] for result in results]
|
|
|
91 |
def launch(self, share=False):
|
92 |
# gradio app structure
|
93 |
with gr.Blocks(title='Clap Semantic Search') as ui:
|
|
|
94 |
with gr.Row():
|
95 |
with gr.Column(variant='panel'):
|
96 |
search = gr.Textbox(placeholder='Search Samples')
|
97 |
+
float_input = gr.Number(label='Similarity threshold [min: 0.1 max: 1]', default=0.5, minimum=0.1, maximum=1)
|
98 |
with gr.Column():
|
99 |
audioboxes = []
|
100 |
gr.Markdown("Output")
|
101 |
for i in range(self.k):
|
102 |
t = gr.components.Audio(label=f"{i}", visible=True)
|
103 |
audioboxes.append(t)
|
104 |
+
search.submit(fn=self._similarity_search, inputs=[search, float_input], outputs=audioboxes)
|
|
|
|
|
|
|
105 |
ui.launch(share=share)
|
106 |
|
107 |
|