Sergidev commited on
Commit
4250d36
·
verified ·
1 Parent(s): 5a0b505

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -19
app.py CHANGED
@@ -28,31 +28,52 @@ def reduce_to_3d(embedding):
28
  return embedding[:3]
29
 
30
  @spaces.GPU
31
- def compare_embeddings(text1, text2):
32
- emb1 = get_embedding(text1)
33
- emb2 = get_embedding(text2)
34
 
35
- emb1_3d = reduce_to_3d(emb1)
36
- emb2_3d = reduce_to_3d(emb2)
37
 
38
- fig = go.Figure(data=[
39
- go.Scatter3d(x=[0, emb1_3d[0]], y=[0, emb1_3d[1]], z=[0, emb1_3d[2]], mode='lines+markers', name='Text 1'),
40
- go.Scatter3d(x=[0, emb2_3d[0]], y=[0, emb2_3d[1]], z=[0, emb2_3d[2]], mode='lines+markers', name='Text 2')
41
- ])
42
 
43
  fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
44
 
45
  return fig
46
 
47
- iface = gr.Interface(
48
- fn=compare_embeddings,
49
- inputs=[
50
- gr.Textbox(label="Text 1"),
51
- gr.Textbox(label="Text 2")
52
- ],
53
- outputs=gr.Plot(),
54
- title="3D Embedding Comparison",
55
- description="Compare the embeddings of two strings visualized in 3D space using Mistral 7B."
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  iface.launch()
 
28
  return embedding[:3]
29
 
30
  @spaces.GPU
31
+ def compare_embeddings(*texts):
32
+ embeddings = [get_embedding(text) for text in texts]
33
+ embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
34
 
35
+ fig = go.Figure()
 
36
 
37
+ for i, emb in enumerate(embeddings_3d):
38
+ fig.add_trace(go.Scatter3d(x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
39
+ mode='lines+markers', name=f'Text {i+1}'))
 
40
 
41
  fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
42
 
43
  return fig
44
 
45
+ def generate_text_boxes(n):
46
+ return [gr.Textbox(label=f"Text {i+1}") for i in range(n)]
47
+
48
+ def update_interface(n):
49
+ new_inputs = generate_text_boxes(n)
50
+ return new_inputs, gr.Plot()
51
+
52
+ with gr.Blocks() as iface:
53
+ gr.Markdown("# 3D Embedding Comparison")
54
+ gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")
55
+
56
+ with gr.Row():
57
+ num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
58
+ update_button = gr.Button("Update")
59
+
60
+ with gr.Column() as input_column:
61
+ inputs = generate_text_boxes(2)
62
+
63
+ output = gr.Plot()
64
+
65
+ compare_button = gr.Button("Compare Embeddings")
66
+
67
+ update_button.click(
68
+ update_interface,
69
+ inputs=[num_texts],
70
+ outputs=[input_column, output]
71
+ )
72
+
73
+ compare_button.click(
74
+ compare_embeddings,
75
+ inputs=inputs,
76
+ outputs=output
77
+ )
78
 
79
  iface.launch()