Update README.md
Browse files
README.md
CHANGED
@@ -9,7 +9,8 @@ To do:
|
|
9 |
4. Loss function.
|
10 |
|
11 |
To run the TE_Embedding model:
|
12 |
-
|
|
|
13 |
from transformers import (AutoConfig,
|
14 |
AutoTokenizer,AutoModelForCausalLM
|
15 |
)
|
@@ -30,7 +31,7 @@ class TEmbeddingModel(torch.nn.Module):
|
|
30 |
[torch.nn.Linear(self.hidden_size, self.hidden_size//len(self.prompt_suffixes))
|
31 |
for _ in range(len(self.prompt_suffixes))])
|
32 |
self.tokenizer, self.llama = self.load_llama()
|
33 |
-
self.device = torch.device('cuda')
|
34 |
self.tanh = torch.nn.Tanh()
|
35 |
self.suffixes_ids = []
|
36 |
self.suffixes_ids_len = []
|
@@ -79,12 +80,12 @@ class TEmbeddingModel(torch.nn.Module):
|
|
79 |
suffixes_ones = self.suffixes_ones.unsqueeze(0)
|
80 |
suffixes_ones = suffixes_ones.repeat(batch_size, 1)
|
81 |
device = next(self.parameters()).device
|
82 |
-
attention_mask = torch.cat([attention_mask, suffixes_ones], dim=-1).to(
|
83 |
|
84 |
suffixes_ids = self.suffixes_ids.unsqueeze(0)
|
85 |
suffixes_ids = suffixes_ids.repeat(batch_size, 1)
|
86 |
-
input_ids = torch.cat([input_ids, suffixes_ids], dim=-1)
|
87 |
-
last_hidden_state = self.llama.base_model(attention_mask=attention_mask, input_ids=input_ids).last_hidden_state.to(
|
88 |
index = -1
|
89 |
for i in range(len(self.suffixes_ids_len)):
|
90 |
embedding = last_hidden_state[:, index, :]
|
@@ -119,4 +120,5 @@ if __name__ == "__main__":
|
|
119 |
output = TE_model(["Hello", "Nice to meet you"])
|
120 |
cos_sim = F.cosine_similarity(output[0],output[1],dim=0)
|
121 |
print(cos_sim)
|
122 |
-
|
|
|
|
9 |
4. Loss function.
|
10 |
|
11 |
To run the TE_Embedding model:
|
12 |
+
```python
|
13 |
+
import os
|
14 |
from transformers import (AutoConfig,
|
15 |
AutoTokenizer,AutoModelForCausalLM
|
16 |
)
|
|
|
31 |
[torch.nn.Linear(self.hidden_size, self.hidden_size//len(self.prompt_suffixes))
|
32 |
for _ in range(len(self.prompt_suffixes))])
|
33 |
self.tokenizer, self.llama = self.load_llama()
|
34 |
+
# self.device = torch.device('cuda')
|
35 |
self.tanh = torch.nn.Tanh()
|
36 |
self.suffixes_ids = []
|
37 |
self.suffixes_ids_len = []
|
|
|
80 |
suffixes_ones = self.suffixes_ones.unsqueeze(0)
|
81 |
suffixes_ones = suffixes_ones.repeat(batch_size, 1)
|
82 |
device = next(self.parameters()).device
|
83 |
+
attention_mask = torch.cat([attention_mask, suffixes_ones], dim=-1).to(device)
|
84 |
|
85 |
suffixes_ids = self.suffixes_ids.unsqueeze(0)
|
86 |
suffixes_ids = suffixes_ids.repeat(batch_size, 1)
|
87 |
+
input_ids = torch.cat([input_ids, suffixes_ids], dim=-1) #to("cuda")
|
88 |
+
last_hidden_state = self.llama.base_model(attention_mask=attention_mask, input_ids=input_ids).last_hidden_state.to(device)
|
89 |
index = -1
|
90 |
for i in range(len(self.suffixes_ids_len)):
|
91 |
embedding = last_hidden_state[:, index, :]
|
|
|
120 |
output = TE_model(["Hello", "Nice to meet you"])
|
121 |
cos_sim = F.cosine_similarity(output[0],output[1],dim=0)
|
122 |
print(cos_sim)
|
123 |
+
|
124 |
+
```
|