File size: 2,608 Bytes
7103ccc
1ba32de
7103ccc
1ba32de
748826b
88b0945
7103ccc
714a27c
1ba32de
416fea8
748826b
2371338
 
 
 
1ba32de
748826b
d590a55
 
1ba32de
 
d590a55
1ba32de
748826b
 
1ba32de
748826b
416fea8
 
748826b
1ba32de
88b0945
fae735f
88b0945
1ba32de
88b0945
1ba32de
88b0945
 
 
 
 
 
 
 
 
 
 
1ba32de
 
 
 
5e78e4f
fae735f
88b0945
 
 
 
 
 
 
 
 
 
 
 
fae735f
88b0945
fae735f
 
 
 
 
 
 
9006e63
1ba32de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go
import numpy as np

# Hugging Face model id used for both the tokenizer and the embedding model.
model_name = "mistralai/Mistral-7B-v0.1"
# Tokenizer loads eagerly; the model itself stays None here and is loaded
# lazily inside get_embedding() so the GPU is only claimed on first use.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None

# Set pad token to eos token if not defined
# (Mistral's tokenizer ships without a pad token, but tokenizer(..., padding=True)
# below requires one.)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

@spaces.GPU
def get_embedding(text):
    """Embed *text* with Mistral-7B and return the mean-pooled last hidden
    state as a 1-D numpy array (moved to CPU).

    The model is loaded lazily on the first call so the GPU is only
    allocated when an embedding is actually requested.
    """
    global model
    if model is None:
        model = AutoModel.from_pretrained(model_name).cuda()
        # Keep the embedding matrix in sync with the (possibly extended) tokenizer.
        model.resize_token_embeddings(len(tokenizer))

    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to('cuda')
    with torch.no_grad():
        model_output = model(**encoded)
    # Mean-pool over the sequence dimension, then drop the batch dimension.
    pooled = model_output.last_hidden_state.mean(dim=1)
    return pooled.squeeze().cpu().numpy()

def reduce_to_3d(embedding):
    """Return the first three components of *embedding*.

    NOTE(review): this is a naive truncation to raw dimensions 0-2, not a
    real projection (PCA/UMAP) — presumably intentional for a quick plot.
    """
    leading_three = slice(0, 3)
    return embedding[leading_three]

@spaces.GPU
def compare_embeddings(*texts):
    """Build a 3-D plotly figure with one origin-anchored vector per
    non-empty input text, using the first three embedding dimensions.
    """
    palette = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
    fig = go.Figure()

    trace_idx = 0
    for text in texts:
        if not text.strip():
            continue  # blank inputs are skipped entirely
        point = reduce_to_3d(get_embedding(text))
        trace_color = palette[trace_idx % len(palette)]
        fig.add_trace(go.Scatter3d(
            x=[0, point[0]], y=[0, point[1]], z=[0, point[2]],
            mode='lines+markers',
            name=f'Text {trace_idx+1}',
            line=dict(color=trace_color),
            marker=dict(color=trace_color),
        ))
        trace_idx += 1

    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
    return fig

def create_interface(num_inputs):
    """Build and return a gr.Blocks UI with `num_inputs` textboxes wired
    to compare_embeddings through a single submit button.
    """
    with gr.Blocks() as new_interface:
        boxes = []
        for idx in range(num_inputs):
            boxes.append(gr.Textbox(label=f"Text {idx+1}"))
        plot_output = gr.Plot()
        compare_button = gr.Button("Compare Embeddings")
        compare_button.click(fn=compare_embeddings, inputs=boxes, outputs=plot_output)
    return new_interface

# Main UI.
#
# Fix: the previous version tried to build a dynamic interface by returning a
# fresh gr.Blocks from the slider's change handler into a gr.HTML output, and
# by calling interface_container.update(...) at build time. Neither works in
# Gradio — components cannot be created per-event, a Blocks cannot be rendered
# into an HTML component, and .update() is not a build-time initializer.
#
# Instead we pre-create the maximum number of textboxes and let the slider
# toggle their visibility. Hidden textboxes still submit their (empty) values,
# and compare_embeddings already skips empty strings, so the comparison
# behavior is unchanged.
MAX_TEXTS = 10

with gr.Blocks() as iface:
    gr.Markdown("# 3D Embedding Comparison")
    gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")

    num_inputs = gr.Slider(minimum=2, maximum=MAX_TEXTS, step=1, value=2, label="Number of texts to compare")
    # Start with the first two boxes visible to match the slider's default value.
    text_inputs = [gr.Textbox(label=f"Text {i+1}", visible=(i < 2)) for i in range(MAX_TEXTS)]
    output = gr.Plot()
    submit_btn = gr.Button("Compare Embeddings")

    def update_visibility(num):
        # Show the first `num` textboxes and hide the rest.
        return [gr.update(visible=(i < num)) for i in range(MAX_TEXTS)]

    num_inputs.change(fn=update_visibility, inputs=[num_inputs], outputs=text_inputs)
    submit_btn.click(fn=compare_embeddings, inputs=text_inputs, outputs=output)

iface.launch()