import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go
import random

model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None

# Set pad token to eos token if not defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

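# Note: @spaces.GPU requests ZeroGPU hardware per call on Hugging Face Spaces.
# Outside a Space the decorator should be a no-op (an assumption about the
# `spaces` package), in which case model.cuda() below still needs a local GPU.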
@spaces.GPU
def get_embedding(text):
    global model
    if model is None:
        # Lazy-load on first GPU call; half precision keeps the 7B model within
        # memory. No token-embedding resize is needed: pad_token reuses the
        # existing eos token, so the vocabulary size is unchanged.
        model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
    
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean-pool the final hidden states over the sequence dimension.
    return outputs.last_hidden_state.mean(dim=1).squeeze().float().cpu().numpy()
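
# Hedged sketch: the mean above averages over every token position, which
# would include padding if several strings were ever batched together. An
# attention-mask-weighted pool is a common alternative; `masked_mean_pool`
# is a hypothetical helper, not wired into the app.
def masked_mean_pool(last_hidden_state, attention_mask):
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (B, T, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                   # (B, H)
    counts = mask.sum(dim=1).clamp(min=1)                            # (B, 1)
    return summed / counts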

def reduce_to_3d(embedding):
    # Naive reduction: keep only the first three coordinates of the embedding.
    return embedding[:3]
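
# Hedged sketch: keeping the first three coordinates discards nearly all of
# the embedding. A joint PCA projection via numpy's SVD preserves more of the
# geometry; `reduce_all_to_3d_pca` is a hypothetical alternative, unused below.
def reduce_all_to_3d_pca(embeddings):
    import numpy as np
    X = np.stack(embeddings).astype(np.float64)
    Xc = X - X.mean(axis=0)                    # center before PCA
    _, _, vt = np.linalg.svd(Xc, full_matrices=False)
    coords = Xc @ vt[:min(3, vt.shape[0])].T   # project onto top principal axes
    out = np.zeros((coords.shape[0], 3))       # zero-pad when rank < 3
    out[:, :coords.shape[1]] = coords
    return list(out)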

def random_color():
    return f'rgb({random.randint(0,255)}, {random.randint(0,255)}, {random.randint(0,255)})'

@spaces.GPU
def compare_embeddings(*texts):
    embeddings = [get_embedding(text) for text in texts if text.strip()]
    embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
    
    fig = go.Figure()

    # Add black origin point
    fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', marker=dict(size=5, color='black'), name='Origin'))

    # Add lines and points for each text
    for i, emb in enumerate(embeddings_3d):
        color = random_color()
        fig.add_trace(go.Scatter3d(x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]], mode='lines+markers', 
                                   line=dict(color=color), marker=dict(color=color), name=f'Text {i+1}'))
    
    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
    
    return fig
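
# Hedged sketch: similarity in the full embedding space is usually more
# informative than the truncated 3-D picture. `pairwise_cosine` is a
# hypothetical helper for inspecting that directly; it is not part of the UI.
def pairwise_cosine(embeddings):
    import numpy as np
    X = np.stack(embeddings).astype(np.float64)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # L2-normalize rows
    return X @ X.T  # (n, n) matrix of cosine similarities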

def add_textbox(num_textboxes):
    # gr.Textbox.update was removed in Gradio 4, and the old callback always
    # targeted the last textbox; return a visibility update for each instead.
    new_count = min(10, num_textboxes + 1)
    return [gr.update(visible=(i < new_count)) for i in range(10)] + [new_count]

def remove_textbox(num_textboxes):
    # Hide the last visible textbox, never dropping below two inputs.
    new_count = max(2, num_textboxes - 1)
    return [gr.update(visible=(i < new_count)) for i in range(10)] + [new_count]

with gr.Blocks() as iface:
    gr.Markdown("# 3D Embedding Comparison")
    gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")
    
    num_textboxes = gr.State(value=2)
    
    with gr.Column() as textbox_container:
        # Create all 10 textboxes up front; only the first two start visible.
        # (Mutating .visible after construction has no effect inside Blocks.)
        textboxes = [gr.Textbox(label=f"Text {i+1}", visible=(i < 2)) for i in range(10)]
    
    with gr.Row():
        add_btn = gr.Button("Add String")
        remove_btn = gr.Button("Remove String")
    
    plot_output = gr.Plot()
    
    submit_btn = gr.Button("Submit")
    clear_btn = gr.ClearButton(components=textboxes + [plot_output], value="Clear")
    
    add_btn.click(add_textbox, inputs=[num_textboxes], outputs=textboxes + [num_textboxes])
    remove_btn.click(remove_textbox, inputs=[num_textboxes], outputs=textboxes + [num_textboxes])
    submit_btn.click(compare_embeddings, inputs=textboxes, outputs=[plot_output])

iface.launch()