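"""3D Embedding Comparison: a Gradio Space that embeds input strings with
Mistral-7B and plots each embedding's first three dimensions as a vector
from the origin in an interactive Plotly 3D scene."""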
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go
import random
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
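# The model itself is loaded lazily inside the @spaces.GPU-decorated function
# below: on ZeroGPU Spaces a GPU is attached only while such a call is running.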
model = None
# Set pad token to eos token if not defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
@spaces.GPU
def get_embedding(text):
    global model
    if model is None:
        model = AutoModel.from_pretrained(model_name).cuda()
        model.resize_token_embeddings(len(tokenizer))
    inputs = tokenizer(text, return_tensors="pt", padding=True,
                       truncation=True, max_length=512).to('cuda')
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean-pool the last hidden state over the token dimension to get one
    # fixed-size vector (4096-dim for Mistral-7B) per input text.
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
def reduce_to_3d(embedding):
    # Naive reduction: keep only the first three of the embedding's dimensions.
    return embedding[:3]
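# Illustrative alternative (not wired into the app): a fitted projection such
# as PCA preserves far more variance than truncating to the first three
# dimensions. A minimal sketch, assuming scikit-learn is installed and at
# least three texts are embedded (PCA needs n_components <= n_samples):
#
#   import numpy as np
#   from sklearn.decomposition import PCA
#
#   def reduce_all_to_3d(embeddings):
#       return PCA(n_components=3).fit_transform(np.stack(embeddings))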
def random_color():
    return f'rgb({random.randint(0,255)}, {random.randint(0,255)}, {random.randint(0,255)})'
@spaces.GPU
def compare_embeddings(*texts):
    # Embed every non-empty textbox (hidden boxes submit empty strings).
    embeddings = [get_embedding(text) for text in texts if text.strip()]
    embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
    fig = go.Figure()
    # Add a black origin point as a reference
    fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers',
                               marker=dict(size=5, color='black'), name='Origin'))
    # Draw a line from the origin to each embedding's 3D point
    for i, emb in enumerate(embeddings_3d):
        color = random_color()
        fig.add_trace(go.Scatter3d(x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
                                   mode='lines+markers', line=dict(color=color),
                                   marker=dict(color=color), name=f'Text {i+1}'))
    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
    return fig
def add_textbox(num_textboxes):
    # Return a visibility update for every textbox plus the new count.
    new_count = min(10, num_textboxes + 1)
    return [gr.update(visible=(i < new_count)) for i in range(10)] + [new_count]

def remove_textbox(num_textboxes):
    new_count = max(2, num_textboxes - 1)
    return [gr.update(visible=(i < new_count)) for i in range(10)] + [new_count]
with gr.Blocks() as iface:
    gr.Markdown("# 3D Embedding Comparison")
    gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")
    num_textboxes = gr.State(value=2)
    with gr.Column() as textbox_container:
        # Create 10 textboxes up front; only the first two start visible.
        textboxes = [gr.Textbox(label=f"Text {i+1}", visible=(i < 2)) for i in range(10)]
    with gr.Row():
        add_btn = gr.Button("Add String")
        remove_btn = gr.Button("Remove String")
    plot_output = gr.Plot()
    submit_btn = gr.Button("Submit")
    clear_btn = gr.ClearButton(components=textboxes + [plot_output], value="Clear")
    add_btn.click(add_textbox, inputs=[num_textboxes], outputs=textboxes + [num_textboxes])
    remove_btn.click(remove_textbox, inputs=[num_textboxes], outputs=textboxes + [num_textboxes])
    submit_btn.click(compare_embeddings, inputs=textboxes, outputs=[plot_output])

iface.launch()
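# Note: outside Hugging Face Spaces the `spaces` GPU decorator should act as a
# no-op, but running this file locally still assumes a CUDA device with enough
# memory for Mistral-7B (roughly 28 GB in float32; load in float16 to halve it).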