Update app.py
Browse files
app.py
CHANGED
@@ -1,39 +1,20 @@
|
|
1 |
import gradio as gr
|
2 |
-
import torch
|
3 |
-
from transformers import AutoTokenizer, AutoModel
|
4 |
import plotly.graph_objects as go
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
# Set pad token to eos token if not defined
|
11 |
-
if tokenizer.pad_token is None:
|
12 |
-
tokenizer.pad_token = tokenizer.eos_token
|
13 |
-
|
14 |
-
def get_embedding(text):
|
15 |
-
global model
|
16 |
-
if model is None:
|
17 |
-
model = AutoModel.from_pretrained(model_name)
|
18 |
-
model.resize_token_embeddings(len(tokenizer))
|
19 |
-
|
20 |
-
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
21 |
-
with torch.no_grad():
|
22 |
-
outputs = model(**inputs)
|
23 |
-
return outputs.last_hidden_state.mean(dim=1).squeeze().detach().numpy()
|
24 |
-
|
25 |
-
def reduce_to_3d(embedding):
|
26 |
-
return embedding[:3]
|
27 |
|
28 |
def compare_embeddings(*texts):
|
29 |
-
embeddings = [
|
30 |
-
embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
|
31 |
|
32 |
fig = go.Figure()
|
33 |
|
34 |
colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
|
35 |
|
36 |
-
for i, emb in enumerate(
|
37 |
color = colors[i % len(colors)]
|
38 |
fig.add_trace(go.Scatter3d(
|
39 |
x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
|
@@ -48,8 +29,9 @@ def compare_embeddings(*texts):
|
|
48 |
return fig
|
49 |
|
50 |
with gr.Blocks() as iface:
|
51 |
-
gr.Markdown("# 3D Embedding Comparison")
|
52 |
-
gr.Markdown("Compare
|
|
|
53 |
|
54 |
with gr.Row():
|
55 |
num_inputs = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
|
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
import plotly.graph_objects as go
|
3 |
+
import hashlib
|
4 |
|
5 |
+
def simple_embedding(text, dim=3):
|
6 |
+
"""A simple hash-based embedding function for demonstration purposes."""
|
7 |
+
hash_value = hashlib.md5(text.encode()).hexdigest()
|
8 |
+
return [int(hash_value[i:i+2], 16) / 255.0 for i in range(0, dim*2, 2)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def compare_embeddings(*texts):
|
11 |
+
embeddings = [simple_embedding(text) for text in texts if text.strip()] # Only process non-empty texts
|
|
|
12 |
|
13 |
fig = go.Figure()
|
14 |
|
15 |
colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
|
16 |
|
17 |
+
for i, emb in enumerate(embeddings):
|
18 |
color = colors[i % len(colors)]
|
19 |
fig.add_trace(go.Scatter3d(
|
20 |
x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
|
|
|
29 |
return fig
|
30 |
|
31 |
with gr.Blocks() as iface:
|
32 |
+
gr.Markdown("# 3D Embedding Comparison (Simplified)")
|
33 |
+
gr.Markdown("Compare simplified embeddings of multiple strings visualized in 3D space.")
|
34 |
+
gr.Markdown("Note: This is a demonstration using a basic hash-based embedding, not a real NLP model.")
|
35 |
|
36 |
with gr.Row():
|
37 |
num_inputs = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
|