Guy24 committed on
Commit
682420f
·
1 Parent(s): 87899e6

adding application

Browse files
Files changed (1) hide show
  1. app.py +190 -4
app.py CHANGED
@@ -8,10 +8,196 @@ from functools import lru_cache
8
  # IMPORTANT: This version uses the PatchscopesRetriever implementation
9
  # from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words)
10
  # ----------------------------------------------------------------------
11
- try:
12
- from inner_lexicon.word_retriever import PatchscopesRetriever # pip install tokens2words
13
- except ImportError:
14
- PatchscopesRetriever = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  DEFAULT_MODEL = "meta-llama/Llama-3.1-8B" # light default so the demo boots everywhere
17
  DEVICE = (
 
8
  # IMPORTANT: This version uses the PatchscopesRetriever implementation
9
  # from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words)
10
  # ----------------------------------------------------------------------
11
+ import torch
12
+ from tqdm import tqdm
13
+ from abc import ABC, abstractmethod
14
+
15
+ from .utils.enums import MultiTokenKind, RetrievalTechniques
16
+ from .processor import RetrievalProcessor
17
+ from .utils.logit_lens import ReverseLogitLens
18
+ from .utils.model_utils import extract_token_i_hidden_states
19
+
20
+
21
class WordRetrieverBase(ABC):
    """Abstract interface for turning hidden states back into text.

    The base class only stores the model and tokenizer it is given;
    concrete subclasses provide the decoding strategy by implementing
    :meth:`retrieve_word`.
    """

    def __init__(self, model, tokenizer):
        """Keep references to the model and tokenizer used for decoding."""
        self.tokenizer = tokenizer
        self.model = model

    @abstractmethod
    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
        """Decode ``hidden_states`` into text (must be overridden).

        Args:
            hidden_states: The representation(s) to decode.
            layer_idx: Optional index of the layer the states come from.
            num_tokens_to_generate: Upper bound on tokens to produce.

        Returns:
            ``None`` in this base implementation; subclasses define the
            actual return value.
        """
        return None
29
+
30
+
31
class PatchscopesRetriever(WordRetrieverBase):
    """Decode hidden states by patching them into a fixed text prompt.

    The patchscopes prompt is tokenized once at construction time, with a
    dummy token id wherever the placeholder string occurs. At retrieval
    time the embeddings at those positions are overwritten with the hidden
    states to decode, and the model generates a continuation.
    """

    def __init__(
        self,
        model,
        tokenizer,
        representation_prompt: str = "{word}",
        patchscopes_prompt: str = "Next is the same word twice: 1) {word} 2)",
        prompt_target_placeholder: str = "{word}",
        representation_token_idx_to_extract: int = -1,
        num_tokens_to_generate: int = 10,
    ):
        """Pre-compute the prompt template used for patching.

        Args:
            model: Model exposing ``eval``, ``device``,
                ``get_input_embeddings`` and ``generate``.
            tokenizer: Tokenizer exposing ``bos_token_id``, ``eos_token_id``,
                ``encode`` and ``batch_decode``.
            representation_prompt: Template the word is substituted into
                before its hidden states are extracted.
            patchscopes_prompt: Template whose placeholder positions receive
                the patched hidden states.
            prompt_target_placeholder: Placeholder string used in both
                templates.
            representation_token_idx_to_extract: Token index forwarded to
                ``extract_token_i_hidden_states``.
            num_tokens_to_generate: Default generation budget for
                :meth:`retrieve_word`.
        """
        super().__init__(model, tokenizer)
        self.prompt_input_ids, self.prompt_target_idx = \
            self._build_prompt_input_ids_template(patchscopes_prompt, prompt_target_placeholder)
        self._prepare_representation_prompt = \
            self._build_representation_prompt_func(representation_prompt, prompt_target_placeholder)
        self.representation_token_idx = representation_token_idx_to_extract
        self.num_tokens_to_generate = num_tokens_to_generate

    def _build_prompt_input_ids_template(self, prompt, target_placeholder):
        """Tokenize ``prompt`` and locate the positions to be patched.

        Every occurrence of ``target_placeholder`` becomes a single dummy
        token (id 0) whose position is recorded. An empty prompt yields a
        template consisting of the optional BOS token plus one dummy token.

        Returns:
            A ``(prompt_input_ids, target_idx)`` pair of 1-D ``torch.long``
            tensors: the template ids and the indices of the dummy tokens.
        """
        bos_token_id = self.tokenizer.bos_token_id
        prompt_input_ids = [bos_token_id] if bos_token_id is not None else []
        target_idx = []

        if prompt:
            assert target_placeholder is not None, \
                "Trying to set a prompt for Patchscopes without defining the prompt's target placeholder string, e.g., [MASK]"

            prompt_parts = prompt.split(target_placeholder)
            last_part_i = len(prompt_parts) - 1
            for part_i, prompt_part in enumerate(prompt_parts):
                prompt_input_ids += self.tokenizer.encode(prompt_part, add_special_tokens=False)
                if part_i < last_part_i:
                    # A placeholder sat between this part and the next one:
                    # remember its position, then reserve the slot.
                    target_idx.append(len(prompt_input_ids))
                    prompt_input_ids.append(0)
        else:
            # Record the index BEFORE appending the dummy token. Recording it
            # afterwards (as the previous code did) points one past the end
            # of the template and breaks the patching step.
            target_idx.append(len(prompt_input_ids))
            prompt_input_ids.append(0)

        prompt_input_ids = torch.tensor(prompt_input_ids, dtype=torch.long)
        target_idx = torch.tensor(target_idx, dtype=torch.long)
        return prompt_input_ids, target_idx

    def _build_representation_prompt_func(self, prompt, target_placeholder):
        """Return a callable that substitutes a word into ``prompt``."""
        return lambda word: prompt.replace(target_placeholder, word)

    def generate_states(self, tokenizer, word='Wakanda', with_prompt=True):
        """Encode either a generated prompt or the bare ``word``.

        Despite its name this returns input ids, not hidden states.

        NOTE(review): ``generate_prompt`` is not defined in this class or
        its visible base, so ``with_prompt=True`` raises ``AttributeError``
        unless a subclass provides it. Confirm the intended prompt.
        """
        prompt = self.generate_prompt() if with_prompt else word
        input_ids = tokenizer.encode(prompt, return_tensors='pt')
        return input_ids

    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=None):
        """Patch ``hidden_states`` into the prompt and decode the output.

        Args:
            hidden_states: A single vector (1-D) or a batch of vectors; a
                1-D input is treated as a batch of one.
            layer_idx: Accepted for interface compatibility; not used here.
            num_tokens_to_generate: Generation budget; falls back to the
                value given at construction when falsy.

        Returns:
            The list of strings produced by ``tokenizer.batch_decode``, one
            per row of ``hidden_states``.
        """
        self.model.eval()

        if hidden_states.dim() == 1:
            hidden_states = hidden_states.unsqueeze(0)

        # Embed the template once, replicate it per row, then overwrite the
        # placeholder positions with the states to decode.
        inputs_embeds = self.model.get_input_embeddings()(
            self.prompt_input_ids.to(self.model.device)).unsqueeze(0)
        batched_patchscope_inputs = inputs_embeds.repeat(len(hidden_states), 1, 1).to(hidden_states.dtype)
        batched_patchscope_inputs[:, self.prompt_target_idx] = hidden_states.unsqueeze(1).to(self.model.device)

        num_tokens_to_generate = num_tokens_to_generate or self.num_tokens_to_generate

        # No attention mask is passed to generate(); the mask the previous
        # code computed was never used, so it is no longer built.
        with torch.no_grad():
            patchscope_outputs = self.model.generate(
                do_sample=False, num_beams=1, top_p=1.0, temperature=None,
                inputs_embeds=batched_patchscope_inputs,
                max_new_tokens=num_tokens_to_generate, pad_token_id=self.tokenizer.eos_token_id)

        return self.tokenizer.batch_decode(patchscope_outputs)

    def extract_hidden_states(self, word):
        """Extract hidden states for ``word`` via the representation prompt.

        Returns whatever ``extract_token_i_hidden_states`` yields for the
        configured token index (presumably one vector per layer - confirm
        against that helper).
        """
        representation_input = self._prepare_representation_prompt(word)

        last_token_hidden_states = extract_token_i_hidden_states(
            self.model, self.tokenizer, representation_input,
            token_idx_to_extract=self.representation_token_idx, return_dict=False, verbose=False)

        return last_token_hidden_states

    def get_hidden_states_and_retrieve_word(self, word, num_tokens_to_generate=None):
        """Extract the hidden states of ``word`` and decode them.

        Returns:
            A ``(decoded_outputs, hidden_states)`` tuple: the strings from
            :meth:`retrieve_word` and the states they were decoded from.
        """
        last_token_hidden_states = self.extract_hidden_states(word)
        patchscopes_description_by_layers = self.retrieve_word(
            last_token_hidden_states, num_tokens_to_generate=num_tokens_to_generate)

        return patchscopes_description_by_layers, last_token_hidden_states
119
+
120
+
121
class ReverseLogitLensRetriever(WordRetrieverBase):
    """Decode a hidden state to a single token with a reverse logit lens."""

    def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
        """Build the lens from ``model`` and move it to ``device``/``dtype``."""
        super().__init__(model, tokenizer)
        lens = ReverseLogitLens.from_model(model)
        lens = lens.to(device)
        self.reverse_logit_lens = lens.to(dtype)

    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
        """Return the decoded highest-scoring token for ``hidden_states``.

        ``num_tokens_to_generate`` is accepted for interface compatibility
        and is not used. The argmax over the last dimension is converted
        with ``.item()``, so it must contain exactly one element.
        """
        scores = self.reverse_logit_lens(hidden_states, layer_idx)
        best_token_id = torch.argmax(scores, dim=-1).item()
        return self.tokenizer.decode(best_token_id)
130
+
131
+
132
class AnalysisWordRetriever:
    """Run word retrieval over a dataset and collect per-layer results.

    Depending on ``multi_token_kind`` the hidden states of multi-token
    words are decoded either with a ``PatchscopesRetriever`` (natural
    words) or with a ``ReverseLogitLensRetriever`` (all other kinds).
    """

    def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate=1, add_context=True,
                 model_name='LLaMa-2B', device='cuda', dataset=None):
        """Store the configuration and build the retriever and processor.

        Args:
            model: Model to analyse; it is moved to ``device``.
            tokenizer: Tokenizer matching ``model``.
            multi_token_kind: ``MultiTokenKind`` member selecting the word
                extraction routine and the retrieval technique.
            num_tokens_to_generate: Forwarded to ``RetrievalProcessor``.
            add_context: Forwarded to ``RetrievalProcessor``; also decides
                the ``'context'`` label written into each result row.
            model_name: Used for the progress-bar label and to choose the
                whitespace marker.
            device: Device for the model, tokenized inputs and retriever.
            dataset: Object indexed as ``dataset['train']['text']`` by
                :meth:`retrieve_words_in_dataset`.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.multi_token_kind = multi_token_kind
        self.num_tokens_to_generate = num_tokens_to_generate
        self.add_context = add_context
        self.model_name = model_name
        self.device = device
        self.dataset = dataset
        self.retriever = self._initialize_retriever()
        self.RetrievalTechniques = (RetrievalTechniques.Patchscopes if self.multi_token_kind == MultiTokenKind.Natural
                                    else RetrievalTechniques.ReverseLogitLens)
        # NOTE(review): verify this model list against each tokenizer's
        # actual word-boundary marker.
        self.whitespace_token = 'Ġ' if model_name in ['gemma-2-9b', 'pythia-6.9b', 'LLaMA3-8B', 'Yi-6B'] else '▁'
        self.processor = RetrievalProcessor(self.model, self.tokenizer, self.multi_token_kind,
                                            self.num_tokens_to_generate, self.add_context, self.model_name,
                                            self.whitespace_token)

    def _initialize_retriever(self):
        """Create the retriever that matches ``multi_token_kind``."""
        if self.multi_token_kind == MultiTokenKind.Natural:
            return PatchscopesRetriever(self.model, self.tokenizer)
        # Forward the configured device; otherwise the lens would silently
        # fall back to its hard-coded 'cuda' default.
        return ReverseLogitLensRetriever(self.model, self.tokenizer, device=self.device)

    def _select_word_extractor(self):
        """Pick the processor method that extracts the next word."""
        if self.multi_token_kind == MultiTokenKind.Natural:
            return self.processor.get_next_word
        if self.multi_token_kind == MultiTokenKind.Typo:
            return self.processor.get_next_full_word_typo
        return self.processor.get_next_full_word_separated

    def retrieve_words_in_dataset(self, number_of_examples_to_retrieve=2, max_length=1000):
        """Decode every multi-token word of the first dataset examples.

        Args:
            number_of_examples_to_retrieve: How many entries of
                ``dataset['train']['text']`` to process.
            max_length: Truncation length passed to the tokenizer.

        Returns:
            A list of dicts, one per (multi-token word, layer) pair, with
            the keys ``text``, ``original_word``, ``word``, ``word_tokens``,
            ``num_tokens``, ``layer``, ``retrieved_word_str`` and
            ``context``.
        """
        self.model.eval()
        results = []

        get_next = self._select_word_extractor()
        context_label = "With Context" if self.add_context else "Without Context"
        texts = self.dataset['train']['text'][:number_of_examples_to_retrieve]

        for text in tqdm(texts, self.model_name):
            tokenized_input = self.tokenizer(
                text, return_tensors='pt', truncation=True, max_length=max_length).to(self.device)
            tokens = tokenized_input.input_ids[0]
            print(f'Processing text: {text}')
            i = 5  # scanning starts at token index 5, as in the original code
            while i < len(tokens):
                (j, word_tokens, word, _context, tokenized_combined_text,
                 combined_text, original_word) = get_next(tokens, i, device=self.device)

                if len(word_tokens) > 1:
                    with torch.no_grad():
                        outputs = self.model(**tokenized_combined_text, output_hidden_states=True)

                    word_token_strs = self.tokenizer.convert_ids_to_tokens(word_tokens)
                    for layer_idx, hidden_state in enumerate(outputs.hidden_states):
                        # Last position of the first (only) sequence.
                        postfix_hidden_state = hidden_state[0, -1, :].unsqueeze(0)
                        retrieved_word_str = self.retriever.retrieve_word(
                            postfix_hidden_state, layer_idx=layer_idx,
                            num_tokens_to_generate=len(word_tokens))
                        results.append({
                            'text': combined_text,
                            'original_word': original_word,
                            'word': word,
                            'word_tokens': list(word_token_strs),
                            'num_tokens': len(word_tokens),
                            'layer': layer_idx,
                            'retrieved_word_str': retrieved_word_str,
                            'context': context_label,
                        })

                # Always move the cursor forward. Previously it was advanced
                # only in an `else` branch, so some iterations left `i`
                # unchanged and the loop could never terminate.
                i = j if j > i else i + 1
        return results
200
+
201
 
202
  DEFAULT_MODEL = "meta-llama/Llama-3.1-8B" # light default so the demo boots everywhere
203
  DEVICE = (