import gradio as gr
import spaces
import torch
import os
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go
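
# Hugging Face access token, read from the environment (defined here but not used in the calls below)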
TOKEN = os.getenv("HF_TOKEN")
default_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = None
model = None
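

# Lazily load the tokenizer/model on the GPU worker and embed `text` by mean-pooling the last hidden states.
# The model is reloaded only when the requested repository differs from the one currently cached.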
@spaces.GPU(duration=300)
def get_embedding(text, model_repo):
    global tokenizer, model
    if tokenizer is None or model is None or model.name_or_path != model_repo:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_repo)
            model = AutoModel.from_pretrained(model_repo, torch_dtype=torch.float16).cuda()
            # Set pad token to eos token if not defined
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                model.resize_token_embeddings(len(tokenizer))
        except Exception as e:
            return f"Error loading model: {str(e)}"
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
    with torch.no_grad():
        outputs = model(**inputs)
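    # Mean-pool the last hidden states over the sequence dimension to get one vector for the input text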
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
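

# Reduce an embedding to 3-D by keeping only its first three dimensions (a simple truncation, not PCA)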
def reduce_to_3d(embedding):
    return embedding[:3]
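

# Embed every non-empty input text and plot each 3-D projection as a vector from the origin.
# If the model fails to load, an error string is returned instead of a figure.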
@spaces.GPU
def compare_embeddings(model_repo, *texts):
    if not model_repo:
        model_repo = default_model_name
    embeddings = []
    for text in texts:
        if text.strip():
            emb = get_embedding(text, model_repo)
            if isinstance(emb, str):  # Error message
                return emb
            embeddings.append(emb)
    embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
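    # Draw each embedding as a line from the origin to its (x, y, z) point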
    fig = go.Figure()
    for i, emb in enumerate(embeddings_3d):
        fig.add_trace(go.Scatter3d(x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
                                   mode='lines+markers', name=f'Text {i+1}'))
    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
    return fig
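

# Pre-create 10 textboxes; only the first n are visible, so the click handler always receives 10 inputs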
def generate_text_boxes(n):
    return [gr.Textbox(label=f"Text {i+1}", visible=(i < n)) for i in range(10)]
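

# Gradio UI: a model repository input, a slider controlling how many texts to compare, and a 3-D plot output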
with gr.Blocks() as iface:
    gr.Markdown("# 3D Embedding Comparison")
    gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using a custom model.")
    model_repo_input = gr.Textbox(label="Model Repository", value=default_model_name, placeholder="Enter the model repository (e.g., mistralai/Mistral-7B-Instruct-v0.3)")
    num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
    with gr.Column() as input_column:
        text_boxes = generate_text_boxes(2)
    output = gr.Plot()
    compare_button = gr.Button("Compare Embeddings")
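
    # Toggle visibility so that only the first n textboxes are shown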
    def update_interface(n):
        return [gr.update(visible=(i < n)) for i in range(10)]
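
    # Update textbox visibility when the slider changes, and build the plot when the button is clicked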
    num_texts.change(
        update_interface,
        inputs=[num_texts],
        outputs=text_boxes
    )

    compare_button.click(
        compare_embeddings,
        inputs=[model_repo_input] + text_boxes,
        outputs=output
    )

iface.launch()