Sergidev commited on
Commit
d3f7084
·
verified ·
1 Parent(s): 4250d36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -29,7 +29,7 @@ def reduce_to_3d(embedding):
29
 
30
  @spaces.GPU
31
  def compare_embeddings(*texts):
32
- embeddings = [get_embedding(text) for text in texts]
33
  embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
34
 
35
  fig = go.Figure()
@@ -45,34 +45,31 @@ def compare_embeddings(*texts):
45
  def generate_text_boxes(n):
46
  return [gr.Textbox(label=f"Text {i+1}") for i in range(n)]
47
 
48
- def update_interface(n):
49
- new_inputs = generate_text_boxes(n)
50
- return new_inputs, gr.Plot()
51
-
52
  with gr.Blocks() as iface:
53
  gr.Markdown("# 3D Embedding Comparison")
54
  gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")
55
 
56
- with gr.Row():
57
- num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
58
- update_button = gr.Button("Update")
59
 
60
  with gr.Column() as input_column:
61
- inputs = generate_text_boxes(2)
62
 
63
  output = gr.Plot()
64
 
65
  compare_button = gr.Button("Compare Embeddings")
66
 
67
- update_button.click(
 
 
 
68
  update_interface,
69
  inputs=[num_texts],
70
- outputs=[input_column, output]
71
  )
72
 
73
  compare_button.click(
74
  compare_embeddings,
75
- inputs=inputs,
76
  outputs=output
77
  )
78
 
 
29
 
30
  @spaces.GPU
31
  def compare_embeddings(*texts):
32
+ embeddings = [get_embedding(text) for text in texts if text.strip()]
33
  embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
34
 
35
  fig = go.Figure()
 
45
  def generate_text_boxes(n):
46
  return [gr.Textbox(label=f"Text {i+1}") for i in range(n)]
47
 
 
 
 
 
48
  with gr.Blocks() as iface:
49
  gr.Markdown("# 3D Embedding Comparison")
50
  gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")
51
 
52
+ num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
 
 
53
 
54
  with gr.Column() as input_column:
55
+ text_boxes = gr.Group(generate_text_boxes(2))
56
 
57
  output = gr.Plot()
58
 
59
  compare_button = gr.Button("Compare Embeddings")
60
 
61
+ def update_interface(n):
62
+ return gr.Group.update(visible=False), gr.Group.update(visible=True, components=generate_text_boxes(n))
63
+
64
+ num_texts.change(
65
  update_interface,
66
  inputs=[num_texts],
67
+ outputs=[text_boxes, text_boxes]
68
  )
69
 
70
  compare_button.click(
71
  compare_embeddings,
72
+ inputs=text_boxes,
73
  outputs=output
74
  )
75