Spaces:
Sleeping
Sleeping
adding application
Browse files
app.py
CHANGED
@@ -14,7 +14,7 @@ from abc import ABC, abstractmethod
|
|
14 |
|
15 |
from enums import MultiTokenKind, RetrievalTechniques
|
16 |
from processor import RetrievalProcessor
|
17 |
-
|
18 |
from model_utils import extract_token_i_hidden_states
|
19 |
|
20 |
|
@@ -118,15 +118,15 @@ class PatchscopesRetriever(WordRetrieverBase):
|
|
118 |
return patchscopes_description_by_layers, last_token_hidden_states
|
119 |
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
|
131 |
|
132 |
class AnalysisWordRetriever:
|
|
|
14 |
|
15 |
from enums import MultiTokenKind, RetrievalTechniques
|
16 |
from processor import RetrievalProcessor
|
17 |
+
from logit_lens import ReverseLogitLens
|
18 |
from model_utils import extract_token_i_hidden_states
|
19 |
|
20 |
|
|
|
118 |
return patchscopes_description_by_layers, last_token_hidden_states
|
119 |
|
120 |
|
121 |
+
class ReverseLogitLensRetriever(WordRetrieverBase):
|
122 |
+
def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
|
123 |
+
super().__init__(model, tokenizer)
|
124 |
+
self.reverse_logit_lens = ReverseLogitLens.from_model(model).to(device).to(dtype)
|
125 |
+
|
126 |
+
def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
|
127 |
+
result = self.reverse_logit_lens(hidden_states, layer_idx)
|
128 |
+
token = self.tokenizer.decode(torch.argmax(result, dim=-1).item())
|
129 |
+
return token
|
130 |
|
131 |
|
132 |
class AnalysisWordRetriever:
|