Sergidev committed on
Commit
35042da
·
verified ·
1 Parent(s): 5cea6ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -28
app.py CHANGED
@@ -1,39 +1,20 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoTokenizer, AutoModel
4
  import plotly.graph_objects as go
 
5
 
6
- model_name = "mistralai/Mistral-7B-v0.1"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = None
9
-
10
- # Set pad token to eos token if not defined
11
- if tokenizer.pad_token is None:
12
- tokenizer.pad_token = tokenizer.eos_token
13
-
14
- def get_embedding(text):
15
- global model
16
- if model is None:
17
- model = AutoModel.from_pretrained(model_name)
18
- model.resize_token_embeddings(len(tokenizer))
19
-
20
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
21
- with torch.no_grad():
22
- outputs = model(**inputs)
23
- return outputs.last_hidden_state.mean(dim=1).squeeze().detach().numpy()
24
-
25
- def reduce_to_3d(embedding):
26
- return embedding[:3]
27
 
28
  def compare_embeddings(*texts):
29
- embeddings = [get_embedding(text) for text in texts if text.strip()] # Only process non-empty texts
30
- embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
31
 
32
  fig = go.Figure()
33
 
34
  colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
35
 
36
- for i, emb in enumerate(embeddings_3d):
37
  color = colors[i % len(colors)]
38
  fig.add_trace(go.Scatter3d(
39
  x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
@@ -48,8 +29,9 @@ def compare_embeddings(*texts):
48
  return fig
49
 
50
  with gr.Blocks() as iface:
51
- gr.Markdown("# 3D Embedding Comparison")
52
- gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")
 
53
 
54
  with gr.Row():
55
  num_inputs = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
 
1
  import gradio as gr
 
 
2
  import plotly.graph_objects as go
3
+ import hashlib
4
 
5
def simple_embedding(text, dim=3):
    """Return a deterministic, hash-based pseudo-embedding of *text*.

    This is a demonstration stand-in for a real embedding model: the
    components are derived from MD5 digest bytes and carry no semantic
    meaning.

    Args:
        text: The string to embed. It is encoded as UTF-8 before hashing.
        dim: Number of components to return. Values <= 0 yield an empty list.

    Returns:
        A list of ``dim`` floats, each in the range [0.0, 1.0] (one digest
        byte divided by 255).
    """
    if dim <= 0:
        return []

    data = text.encode()
    # The first 16 components come straight from md5(text), which matches
    # slicing the hex digest two characters at a time.
    stream = hashlib.md5(data).digest()

    # A single MD5 digest is only 16 bytes. For larger dimensions, extend the
    # byte stream with counter-suffixed digests instead of running off the end
    # of the hash (which used to raise ValueError on an empty hex slice).
    counter = 1
    while len(stream) < dim:
        stream += hashlib.md5(data + counter.to_bytes(4, "big")).digest()
        counter += 1

    return [byte / 255.0 for byte in stream[:dim]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def compare_embeddings(*texts):
11
+ embeddings = [simple_embedding(text) for text in texts if text.strip()] # Only process non-empty texts
 
12
 
13
  fig = go.Figure()
14
 
15
  colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
16
 
17
+ for i, emb in enumerate(embeddings):
18
  color = colors[i % len(colors)]
19
  fig.add_trace(go.Scatter3d(
20
  x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
 
29
  return fig
30
 
31
  with gr.Blocks() as iface:
32
+ gr.Markdown("# 3D Embedding Comparison (Simplified)")
33
+ gr.Markdown("Compare simplified embeddings of multiple strings visualized in 3D space.")
34
+ gr.Markdown("Note: This is a demonstration using a basic hash-based embedding, not a real NLP model.")
35
 
36
  with gr.Row():
37
  num_inputs = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")