Update app.py
Browse files
app.py
CHANGED
@@ -7,20 +7,26 @@ import plotly.graph_objects as go
|
|
7 |
|
8 |
TOKEN = os.getenv("HF_TOKEN")
|
9 |
|
10 |
-
|
11 |
-
tokenizer =
|
12 |
model = None
|
13 |
|
14 |
-
# Set pad token to eos token if not defined
|
15 |
-
if tokenizer.pad_token is None:
|
16 |
-
tokenizer.pad_token = tokenizer.eos_token
|
17 |
-
|
18 |
@spaces.GPU(duration=300)
|
19 |
-
def get_embedding(text):
|
20 |
-
global model
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
|
26 |
with torch.no_grad():
|
@@ -31,8 +37,18 @@ def reduce_to_3d(embedding):
|
|
31 |
return embedding[:3]
|
32 |
|
33 |
@spaces.GPU
|
34 |
-
def compare_embeddings(*texts):
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
|
37 |
|
38 |
fig = go.Figure()
|
@@ -50,8 +66,9 @@ def generate_text_boxes(n):
|
|
50 |
|
51 |
with gr.Blocks() as iface:
|
52 |
gr.Markdown("# 3D Embedding Comparison")
|
53 |
-
gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using
|
54 |
|
|
|
55 |
num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
|
56 |
|
57 |
with gr.Column() as input_column:
|
@@ -72,8 +89,8 @@ with gr.Blocks() as iface:
|
|
72 |
|
73 |
compare_button.click(
|
74 |
compare_embeddings,
|
75 |
-
inputs=text_boxes,
|
76 |
outputs=output
|
77 |
)
|
78 |
|
79 |
-
iface.launch()
|
|
|
# Hugging Face API token, read from the environment (None when not set).
TOKEN = os.getenv("HF_TOKEN")

# Repository used when the user leaves the model field blank; tokenizer and
# model are populated lazily inside the GPU worker on first use and cached
# at module level between calls.
default_model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = None
model = None
13 |
|
|
|
|
|
|
|
|
|
14 |
@spaces.GPU(duration=300)
|
15 |
+
def get_embedding(text, model_repo):
|
16 |
+
global tokenizer, model
|
17 |
+
|
18 |
+
if tokenizer is None or model is None or model.name_or_path != model_repo:
|
19 |
+
try:
|
20 |
+
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
21 |
+
model = AutoModel.from_pretrained(model_repo, torch_dtype=torch.float16).cuda()
|
22 |
+
|
23 |
+
# Set pad token to eos token if not defined
|
24 |
+
if tokenizer.pad_token is None:
|
25 |
+
tokenizer.pad_token = tokenizer.eos_token
|
26 |
+
|
27 |
+
model.resize_token_embeddings(len(tokenizer))
|
28 |
+
except Exception as e:
|
29 |
+
return f"Error loading model: {str(e)}"
|
30 |
|
31 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
|
32 |
with torch.no_grad():
|
|
|
37 |
return embedding[:3]
|
38 |
|
39 |
@spaces.GPU
|
40 |
+
def compare_embeddings(model_repo, *texts):
|
41 |
+
if not model_repo:
|
42 |
+
model_repo = default_model_name
|
43 |
+
|
44 |
+
embeddings = []
|
45 |
+
for text in texts:
|
46 |
+
if text.strip():
|
47 |
+
emb = get_embedding(text, model_repo)
|
48 |
+
if isinstance(emb, str): # Error message
|
49 |
+
return emb
|
50 |
+
embeddings.append(emb)
|
51 |
+
|
52 |
embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
|
53 |
|
54 |
fig = go.Figure()
|
|
|
66 |
|
67 |
with gr.Blocks() as iface:
|
68 |
gr.Markdown("# 3D Embedding Comparison")
|
69 |
+
gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using a custom model.")
|
70 |
|
71 |
+
model_repo_input = gr.Textbox(label="Model Repository", value=default_model_name, placeholder="Enter the model repository (e.g., mistralai/Mistral-7B-Instruct-v0.3)")
|
72 |
num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
|
73 |
|
74 |
with gr.Column() as input_column:
|
|
|
89 |
|
90 |
compare_button.click(
|
91 |
compare_embeddings,
|
92 |
+
inputs=[model_repo_input] + text_boxes,
|
93 |
outputs=output
|
94 |
)
|
# Start the Gradio server for the Blocks app built above
# (blocks the process until the app is shut down).
iface.launch()
|