Sergidev commited on
Commit
748826b
·
verified ·
1 Parent(s): 7103ccc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -5
app.py CHANGED
@@ -1,14 +1,56 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
 
 
 
 
 
 
 
4
 
5
  zero = torch.Tensor([0]).cuda()
6
  print(zero.device) # <-- 'cpu' 🤔
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' 🤗
11
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from sklearn.decomposition import PCA
6
+ import plotly.graph_objects as go
7
+ from huggingface_hub import HfApi
8
+ from huggingface_hub import hf_hub_download
9
+ import os
10
+ import sys
11
 
12
  zero = torch.Tensor([0]).cuda()
13
  print(zero.device) # <-- 'cpu' 🤔
14
 
15
+ model_name = "sentence-transformers/all-MiniLM-L6-v2"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModel.from_pretrained(model_name)
18
+
19
+ @spaces.GPU
20
+ def get_embedding(text):
21
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
22
+ with torch.no_grad():
23
+ outputs = model(**inputs)
24
+ return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
25
+
26
+ def compress_to_3d(embedding):
27
+ pca = PCA(n_components=3)
28
+ return pca.fit_transform(embedding.reshape(1, -1))[0]
29
+
30
  @spaces.GPU
31
+ def compare_embeddings(text1, text2):
32
+ emb1 = get_embedding(text1)
33
+ emb2 = get_embedding(text2)
34
+
35
+ emb1_3d = compress_to_3d(emb1)
36
+ emb2_3d = compress_to_3d(emb2)
37
+
38
+ fig = go.Figure(data=[
39
+ go.Scatter3d(x=[0, emb1_3d[0]], y=[0, emb1_3d[1]], z=[0, emb1_3d[2]], mode='lines+markers', name='Text 1'),
40
+ go.Scatter3d(x=[0, emb2_3d[0]], y=[0, emb2_3d[1]], z=[0, emb2_3d[2]], mode='lines+markers', name='Text 2')
41
+ ])
42
+
43
+ fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
44
+
45
+ return fig
46
 
47
+ iface = gr.Interface(
48
+ fn=compare_embeddings,
49
+ inputs=[
50
+ gr.Textbox(label="Text 1"),
51
+ gr.Textbox(label="Text 2")
52
+ ],
53
+ outputs=gr.Plot(),
54
+ title="3D Embedding Comparison",
55
+ description="Compare the embeddings of two strings visualized in 3D space."
56
+ )