Sergidev committed on
Commit
416fea8
·
verified ·
1 Parent(s): d590a55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  import spaces
3
  import torch
4
  from transformers import AutoTokenizer, AutoModel
5
- from sklearn.decomposition import PCA
6
  import plotly.graph_objects as go
7
  from huggingface_hub import HfApi
8
  from huggingface_hub import hf_hub_download
@@ -11,7 +10,7 @@ import sys
11
 
12
  model_name = "sentence-transformers/all-MiniLM-L6-v2"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- model = None # We'll load the model inside the GPU-enabled function
15
 
16
  @spaces.GPU
17
  def get_embedding(text):
@@ -24,17 +23,17 @@ def get_embedding(text):
24
  outputs = model(**inputs)
25
  return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
26
 
27
- def compress_to_3d(embedding):
28
- pca = PCA(n_components=3)
29
- return pca.fit_transform(embedding.reshape(1, -1))[0]
30
 
31
  @spaces.GPU
32
  def compare_embeddings(text1, text2):
33
  emb1 = get_embedding(text1)
34
  emb2 = get_embedding(text2)
35
 
36
- emb1_3d = compress_to_3d(emb1)
37
- emb2_3d = compress_to_3d(emb2)
38
 
39
  fig = go.Figure(data=[
40
  go.Scatter3d(x=[0, emb1_3d[0]], y=[0, emb1_3d[1]], z=[0, emb1_3d[2]], mode='lines+markers', name='Text 1'),
 
2
  import spaces
3
  import torch
4
  from transformers import AutoTokenizer, AutoModel
 
5
  import plotly.graph_objects as go
6
  from huggingface_hub import HfApi
7
  from huggingface_hub import hf_hub_download
 
10
 
11
  model_name = "sentence-transformers/all-MiniLM-L6-v2"
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ model = None
14
 
15
  @spaces.GPU
16
  def get_embedding(text):
 
23
  outputs = model(**inputs)
24
  return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
25
 
26
+ def reduce_to_3d(embedding):
27
+ # Instead of PCA, we'll just take the first 3 dimensions
28
+ return embedding[:3]
29
 
30
  @spaces.GPU
31
  def compare_embeddings(text1, text2):
32
  emb1 = get_embedding(text1)
33
  emb2 = get_embedding(text2)
34
 
35
+ emb1_3d = reduce_to_3d(emb1)
36
+ emb2_3d = reduce_to_3d(emb2)
37
 
38
  fig = go.Figure(data=[
39
  go.Scatter3d(x=[0, emb1_3d[0]], y=[0, emb1_3d[1]], z=[0, emb1_3d[2]], mode='lines+markers', name='Text 1'),