Guy24 commited on
Commit
58852d6
·
1 Parent(s): 3e871e6

adding application

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -14,7 +14,7 @@ from abc import ABC, abstractmethod
14
 
15
  from enums import MultiTokenKind, RetrievalTechniques
16
  from processor import RetrievalProcessor
17
- # from .utils.logit_lens import ReverseLogitLens
18
  from model_utils import extract_token_i_hidden_states
19
 
20
 
@@ -118,15 +118,15 @@ class PatchscopesRetriever(WordRetrieverBase):
118
  return patchscopes_description_by_layers, last_token_hidden_states
119
 
120
 
121
- # class ReverseLogitLensRetriever(WordRetrieverBase):
122
- # def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
123
- # super().__init__(model, tokenizer)
124
- # self.reverse_logit_lens = ReverseLogitLens.from_model(model).to(device).to(dtype)
125
- #
126
- # def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
127
- # result = self.reverse_logit_lens(hidden_states, layer_idx)
128
- # token = self.tokenizer.decode(torch.argmax(result, dim=-1).item())
129
- # return token
130
 
131
 
132
  class AnalysisWordRetriever:
 
14
 
15
  from enums import MultiTokenKind, RetrievalTechniques
16
  from processor import RetrievalProcessor
17
+ from logit_lens import ReverseLogitLens
18
  from model_utils import extract_token_i_hidden_states
19
 
20
 
 
118
  return patchscopes_description_by_layers, last_token_hidden_states
119
 
120
 
121
+ class ReverseLogitLensRetriever(WordRetrieverBase):
122
+ def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
123
+ super().__init__(model, tokenizer)
124
+ self.reverse_logit_lens = ReverseLogitLens.from_model(model).to(device).to(dtype)
125
+
126
+ def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
127
+ result = self.reverse_logit_lens(hidden_states, layer_idx)
128
+ token = self.tokenizer.decode(torch.argmax(result, dim=-1).item())
129
+ return token
130
 
131
 
132
  class AnalysisWordRetriever: