Sergidev commited on
Commit
edd101a
·
verified ·
1 Parent(s): d7977e8

Beta. Revert if trash.

Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -4,7 +4,8 @@ import torch
4
  from transformers import AutoTokenizer, AutoModel
5
  import plotly.graph_objects as go
6
 
7
- model_name = "mistralai/Mistral-7B-v0.1"
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = None
10
 
@@ -16,7 +17,7 @@ if tokenizer.pad_token is None:
16
  def get_embedding(text):
17
  global model
18
  if model is None:
19
- model = AutoModel.from_pretrained(model_name).cuda()
20
  model.resize_token_embeddings(len(tokenizer))
21
 
22
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
@@ -47,7 +48,7 @@ def generate_text_boxes(n):
47
 
48
  with gr.Blocks() as iface:
49
  gr.Markdown("# 3D Embedding Comparison")
50
- gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.")
51
 
52
  num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
53
 
@@ -73,4 +74,4 @@ with gr.Blocks() as iface:
73
  outputs=output
74
  )
75
 
76
- iface.launch()
 
4
  from transformers import AutoTokenizer, AutoModel
5
  import plotly.graph_objects as go
6
 
7
+ # Update the model name to Llama 3.1
8
+ model_name = "meta-llama/Meta-Llama-3.1-405B-Instruct-FP8"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = None
11
 
 
17
  def get_embedding(text):
18
  global model
19
  if model is None:
20
+ model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
21
  model.resize_token_embeddings(len(tokenizer))
22
 
23
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
 
48
 
49
  with gr.Blocks() as iface:
50
  gr.Markdown("# 3D Embedding Comparison")
51
+ gr.Markdown("Compare the embeddings of multiple strings visualized in 3D space using Llama 3.1.")
52
 
53
  num_texts = gr.Slider(minimum=2, maximum=10, step=1, value=2, label="Number of texts to compare")
54
 
 
74
  outputs=output
75
  )
76
 
77
+ iface.launch()