theOnlyJaco commited on
Commit
de3512d
·
unverified ·
1 Parent(s): 6667d8a

Add threshold field

Browse files
Files changed (1) hide show
  1. app.py +4 -11
app.py CHANGED
@@ -23,9 +23,6 @@ class ClapSSGradio():
23
  self.name = name
24
  self.k = k
25
 
26
- print("Env?!")
27
- print(os.getenv('HUGGINGFACE_API_TOKEN')[:2])
28
-
29
  self.model = ClapModel.from_pretrained(
30
  f"Audiogen/{name}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
31
  self.tokenizer = ClapProcessor.from_pretrained(
@@ -48,12 +45,12 @@ class ClapSSGradio():
48
  query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
49
  return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]
50
 
51
- def _similarity_search(self, query):
52
  results = self.client.search(
53
  collection_name=self.name,
54
  query_vector=self._embed_query(query),
55
  limit=self.k,
56
- score_threshold=0.5,
57
  )
58
 
59
  containers = [result.payload['container'] for result in results]
@@ -94,21 +91,17 @@ class ClapSSGradio():
94
  def launch(self, share=False):
95
  # gradio app structure
96
  with gr.Blocks(title='Clap Semantic Search') as ui:
97
-
98
  with gr.Row():
99
  with gr.Column(variant='panel'):
100
  search = gr.Textbox(placeholder='Search Samples')
101
-
102
  with gr.Column():
103
  audioboxes = []
104
  gr.Markdown("Output")
105
  for i in range(self.k):
106
  t = gr.components.Audio(label=f"{i}", visible=True)
107
  audioboxes.append(t)
108
-
109
- search.submit(fn=self._similarity_search, inputs=[
110
- search], outputs=audioboxes)
111
-
112
  ui.launch(share=share)
113
 
114
 
 
23
  self.name = name
24
  self.k = k
25
 
 
 
 
26
  self.model = ClapModel.from_pretrained(
27
  f"Audiogen/{name}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
28
  self.tokenizer = ClapProcessor.from_pretrained(
 
45
  query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
46
  return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]
47
 
48
+ def _similarity_search(self, query, threshold):
49
  results = self.client.search(
50
  collection_name=self.name,
51
  query_vector=self._embed_query(query),
52
  limit=self.k,
53
+ score_threshold=threshold,
54
  )
55
 
56
  containers = [result.payload['container'] for result in results]
 
91
  def launch(self, share=False):
92
  # gradio app structure
93
  with gr.Blocks(title='Clap Semantic Search') as ui:
 
94
  with gr.Row():
95
  with gr.Column(variant='panel'):
96
  search = gr.Textbox(placeholder='Search Samples')
97
+ float_input = gr.Number(label='Similarity threshold [min: 0.1 max: 1]', default=0.5, minimum=0.1, maximum=1)
98
  with gr.Column():
99
  audioboxes = []
100
  gr.Markdown("Output")
101
  for i in range(self.k):
102
  t = gr.components.Audio(label=f"{i}", visible=True)
103
  audioboxes.append(t)
104
+ search.submit(fn=self._similarity_search, inputs=[search, float_input], outputs=audioboxes)
 
 
 
105
  ui.launch(share=share)
106
 
107