import gradio as gr
import spaces
import torch
import os
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go

TOKEN = os.getenv("HF_TOKEN")

default_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
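# Tokenizer and model are cached in module globals and reloaded only
# when a different repo is requested.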
tokenizer = None
model = None

@spaces.GPU(duration=300)
def get_embedding(text, model_repo):
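    """Embed `text` with the given model repo.

    Returns the mean-pooled last hidden state as a 1-D numpy array,
    or an error string if the model fails to load.
    """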
    global tokenizer, model
    
    if tokenizer is None or model is None or model.name_or_path != model_repo:
        try:
            # Pass the HF token so gated repos (e.g. the Mistral models) can be downloaded
            tokenizer = AutoTokenizer.from_pretrained(model_repo, token=TOKEN)
            model = AutoModel.from_pretrained(model_repo, torch_dtype=torch.float16, token=TOKEN).cuda()
            
            # Set pad token to eos token if not defined
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            model.resize_token_embeddings(len(tokenizer))
        except Exception as e:
            return f"Error loading model: {str(e)}"
    
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

def reduce_to_3d(embedding):
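    """Naive 3-D projection: keep the first three embedding dimensions.

    This is a crude slice rather than a real dimensionality reduction;
    something like PCA over the batch of embeddings would preserve more
    structure, at the cost of the axes changing between runs.
    """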
    return embedding[:3]

@spaces.GPU
def compare_embeddings(model_repo, *texts):
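    """Embed each non-empty text, project to 3-D, and plot one vector per text."""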
    if not model_repo:
        model_repo = default_model_name
    
    embeddings = []
    for text in texts:
        if text.strip():
            emb = get_embedding(text, model_repo)
            if isinstance(emb, str):  # get_embedding returned an error message
                raise gr.Error(emb)  # surface it in the UI instead of feeding a string to gr.Plot
            embeddings.append(emb)
    
    embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
    
    fig = go.Figure()
    
    for i, emb in enumerate(embeddings_3d):
        fig.add_trace(go.Scatter3d(x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]], 
                                   mode='lines+markers', name=f'Text {i+1}'))
    
    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
    
    return fig

def generate_text_boxes(n):
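    """Create a fixed pool of 10 textboxes, with only the first n visible.

    Gradio event wiring needs a static component list, so all 10 boxes
    always exist and only their visibility is toggled.
    """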
    return [gr.Textbox(label=f"Text {i+1}", visible=(i < n)) for i in range(10)]

with gr.Blocks() as iface:
    gr.Markdown("# 3D Embedding Comparison")
    gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using a custom model.")
    
    model_repo_input = gr.Textbox(label="Model Repository", value=default_model_name, placeholder="Enter the model repository (e.g., mistralai/Mistral-7B-Instruct-v0.3)")
    num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
    
    with gr.Column() as input_column:
        text_boxes = generate_text_boxes(2)
    
    output = gr.Plot()
    
    compare_button = gr.Button("Compare Embeddings")
    
    def update_interface(n):
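        """Show the first n textboxes and hide the rest."""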
        return [gr.update(visible=(i < n)) for i in range(10)]

    num_texts.change(
        update_interface,
        inputs=[num_texts],
        outputs=text_boxes
    )
    
    compare_button.click(
        compare_embeddings,
        inputs=[model_repo_input] + text_boxes,
        outputs=output
    )

iface.launch()