File size: 2,608 Bytes
7103ccc 1ba32de 7103ccc 1ba32de 748826b 88b0945 7103ccc 714a27c 1ba32de 416fea8 748826b 2371338 1ba32de 748826b d590a55 1ba32de d590a55 1ba32de 748826b 1ba32de 748826b 416fea8 748826b 1ba32de 88b0945 fae735f 88b0945 1ba32de 88b0945 1ba32de 88b0945 1ba32de 5e78e4f fae735f 88b0945 fae735f 88b0945 fae735f 9006e63 1ba32de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go
import numpy as np
# Hugging Face model id used for both the tokenizer (loaded eagerly here)
# and the model (loaded lazily on first GPU call in get_embedding).
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Placeholder: the 7B model is only materialized inside get_embedding so
# module import stays cheap and the GPU is claimed on demand.
model = None
# Set pad token to eos token if not defined
# (Mistral's tokenizer ships without a pad token; padding=True in
# get_embedding requires one.)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
@spaces.GPU
def get_embedding(text):
    """Embed `text` as the mean of the model's last-layer hidden states.

    Returns a 1-D numpy array (the hidden dimension of Mistral-7B).
    The model is loaded lazily on the first call so that importing this
    module does not pull 7B parameters onto the GPU.
    """
    global model
    if model is None:
        model = AutoModel.from_pretrained(model_name).cuda()
        # Keep the embedding matrix in sync with the tokenizer's vocab
        # (no-op here since pad_token reuses eos_token, but harmless).
        model.resize_token_embeddings(len(tokenizer))
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to('cuda')
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    # Mean-pool over the sequence axis, then move to CPU for plotting.
    return hidden.mean(dim=1).squeeze().cpu().numpy()
def reduce_to_3d(embedding):
    """Project an embedding to 3-D by keeping its first three components.

    NOTE(review): this is a naive projection, not PCA/t-SNE — the first
    three hidden dimensions carry no special meaning. Kept as-is for
    behavior compatibility.
    """
    return embedding[0:3]
@spaces.GPU
def compare_embeddings(*texts):
    """Return a plotly 3-D figure with one origin-anchored vector per
    non-empty input text, colored from a fixed cycling palette."""
    palette = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
    # Embed only the texts the user actually filled in, then project to 3-D.
    vectors = [reduce_to_3d(get_embedding(t)) for t in texts if t.strip()]
    fig = go.Figure()
    for idx, vec in enumerate(vectors):
        shade = palette[idx % len(palette)]
        fig.add_trace(go.Scatter3d(
            x=[0, vec[0]],
            y=[0, vec[1]],
            z=[0, vec[2]],
            mode='lines+markers',
            name=f'Text {idx+1}',
            line=dict(color=shade),
            marker=dict(color=shade),
        ))
    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
    return fig
def create_interface(num_inputs):
    """Build a standalone Blocks UI with `num_inputs` textboxes wired to
    compare_embeddings, returning the (unlaunched) Blocks object."""
    with gr.Blocks() as new_interface:
        boxes = [gr.Textbox(label=f"Text {i+1}") for i in range(num_inputs)]
        plot = gr.Plot()
        button = gr.Button("Compare Embeddings")
        button.click(fn=compare_embeddings, inputs=boxes, outputs=plot)
    return new_interface
# Maximum number of comparable texts; must match the slider's `maximum`.
MAX_INPUTS = 10

with gr.Blocks() as iface:
    gr.Markdown("# 3D Embedding Comparison")
    gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")
    num_inputs = gr.Slider(minimum=2, maximum=MAX_INPUTS, step=1, value=2, label="Number of texts to compare")

    # BUG FIX: the previous version tried to render a freshly-built
    # gr.Blocks into a gr.HTML output, which Gradio cannot do — the app
    # showed no textboxes and the slider had no effect. The working
    # pattern is to pre-create the maximum number of textboxes and toggle
    # their visibility from the slider. compare_embeddings already skips
    # empty strings, so hidden/blank boxes are ignored.
    text_inputs = [
        gr.Textbox(label=f"Text {i+1}", visible=(i < 2))
        for i in range(MAX_INPUTS)
    ]
    output = gr.Plot()
    submit_btn = gr.Button("Compare Embeddings")

    def update_visibility(num):
        """Show the first `num` textboxes and hide the rest."""
        return [gr.update(visible=(i < num)) for i in range(MAX_INPUTS)]

    num_inputs.change(fn=update_visibility, inputs=num_inputs, outputs=text_inputs)
    submit_btn.click(fn=compare_embeddings, inputs=text_inputs, outputs=output)

iface.launch()