Sergidev committed on
Commit
673350b
·
verified ·
1 Parent(s): ecb8d51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -16
app.py CHANGED
@@ -7,20 +7,26 @@ import plotly.graph_objects as go
7
 
8
  TOKEN = os.getenv("HF_TOKEN")
9
 
10
- model_name = "mistralai/Mistral-7B-v0.3"
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = None
13
 
14
- # Set pad token to eos token if not defined
15
- if tokenizer.pad_token is None:
16
- tokenizer.pad_token = tokenizer.eos_token
17
-
18
  @spaces.GPU(duration=300)
19
- def get_embedding(text):
20
- global model
21
- if model is None:
22
- model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
23
- model.resize_token_embeddings(len(tokenizer))
 
 
 
 
 
 
 
 
 
 
24
 
25
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
26
  with torch.no_grad():
@@ -31,8 +37,18 @@ def reduce_to_3d(embedding):
31
  return embedding[:3]
32
 
33
  @spaces.GPU
34
- def compare_embeddings(*texts):
35
- embeddings = [get_embedding(text) for text in texts if text.strip()]
 
 
 
 
 
 
 
 
 
 
36
  embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
37
 
38
  fig = go.Figure()
@@ -50,8 +66,9 @@ def generate_text_boxes(n):
50
 
51
  with gr.Blocks() as iface:
52
  gr.Markdown("# 3D Embedding Comparison")
53
- gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Llama 3.1.")
54
 
 
55
  num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
56
 
57
  with gr.Column() as input_column:
@@ -72,8 +89,8 @@ with gr.Blocks() as iface:
72
 
73
  compare_button.click(
74
  compare_embeddings,
75
- inputs=text_boxes,
76
  outputs=output
77
  )
78
 
79
- iface.launch()
 
7
 
8
  TOKEN = os.getenv("HF_TOKEN")
9
 
10
+ default_model_name = "mistralai/Mistral-7B-Instruct-v0.3"
11
+ tokenizer = None
12
  model = None
13
 
 
 
 
 
14
  @spaces.GPU(duration=300)
15
+ def get_embedding(text, model_repo):
16
+ global tokenizer, model
17
+
18
+ if tokenizer is None or model is None or model.name_or_path != model_repo:
19
+ try:
20
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
21
+ model = AutoModel.from_pretrained(model_repo, torch_dtype=torch.float16).cuda()
22
+
23
+ # Set pad token to eos token if not defined
24
+ if tokenizer.pad_token is None:
25
+ tokenizer.pad_token = tokenizer.eos_token
26
+
27
+ model.resize_token_embeddings(len(tokenizer))
28
+ except Exception as e:
29
+ return f"Error loading model: {str(e)}"
30
 
31
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
32
  with torch.no_grad():
 
37
  return embedding[:3]
38
 
39
  @spaces.GPU
40
+ def compare_embeddings(model_repo, *texts):
41
+ if not model_repo:
42
+ model_repo = default_model_name
43
+
44
+ embeddings = []
45
+ for text in texts:
46
+ if text.strip():
47
+ emb = get_embedding(text, model_repo)
48
+ if isinstance(emb, str): # Error message
49
+ return emb
50
+ embeddings.append(emb)
51
+
52
  embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
53
 
54
  fig = go.Figure()
 
66
 
67
  with gr.Blocks() as iface:
68
  gr.Markdown("# 3D Embedding Comparison")
69
+ gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using a custom model.")
70
 
71
+ model_repo_input = gr.Textbox(label="Model Repository", value=default_model_name, placeholder="Enter the model repository (e.g., mistralai/Mistral-7B-Instruct-v0.3)")
72
  num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
73
 
74
  with gr.Column() as input_column:
 
89
 
90
  compare_button.click(
91
  compare_embeddings,
92
+ inputs=[model_repo_input] + text_boxes,
93
  outputs=output
94
  )
95
 
96
+ iface.launch()