s-a-malik committed on
Commit fa78257 · 1 Parent(s): b874271
Files changed (2)
  1. app.py +396 -211
  2. app_sep.py +36 -42
app.py CHANGED
@@ -1,228 +1,413 @@
  from pathlib import Path
- from typing import List, Optional, Tuple

  import spaces
  import gradio as gr
  import numpy as np
  import torch
- from sudachipy import dictionary
- from sudachipy import tokenizer as sudachi_tokenizer
- from transformers import AutoModelForCausalLM, PreTrainedTokenizer, T5Tokenizer
-
-
- model_dir = Path(__file__).parents[0] / "model"
- device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
- tokenizer = T5Tokenizer.from_pretrained(model_dir)
- tokenizer.do_lower_case = True
- trained_model = AutoModelForCausalLM.from_pretrained(model_dir)
- trained_model.to(device)
-
- # baseline model
- baseline_model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
- baseline_model.to(device)
-
- sudachi_tokenizer_obj = dictionary.Dictionary().create()
- mode = sudachi_tokenizer.Tokenizer.SplitMode.C
-
-
- def sudachi_tokenize(input_text: str) -> List[str]:
-     morphemes = sudachi_tokenizer_obj.tokenize(input_text, mode)
-     return [morpheme.surface() for morpheme in morphemes]
-
-
- def calc_offsets(tokens: List[str]) -> List[int]:
-     offsets = [0]
-     for token in tokens:
-         offsets.append(offsets[-1] + len(token))
-     return offsets
-
-
- def distribute_surprisals_to_characters(
-     tokens2surprisal: List[Tuple[str, float]]
- ) -> List[Tuple[str, float]]:
-     tokens2surprisal_by_character: List[Tuple[str, float]] = []
-     for token, surprisal in tokens2surprisal:
-         token_len = len(token)
-         for character in token:
-             tokens2surprisal_by_character.append((character, surprisal / token_len))
-     return tokens2surprisal_by_character
-
-
- def calculate_surprisals_by_character(
-     input_text: str, model: AutoModelForCausalLM, tokenizer: PreTrainedTokenizer
- ) -> Tuple[float, List[Tuple[str, float]]]:
-     input_tokens = [
-         token.replace("▁", "")
-         for token in tokenizer.tokenize(input_text)
-         if token != "▁"
-     ]
-     input_ids = tokenizer.encode(
-         "<s>" + input_text, add_special_tokens=False, return_tensors="pt"
-     ).to(device)
-
-     logits = model(input_ids)["logits"].squeeze(0)
-
-     surprisals = []
-     for i in range(logits.shape[0] - 1):
-         if input_ids[0][i + 1] == 9:
-             continue
-         logit = logits[i]
-         prob = torch.softmax(logit, dim=0)
-         neg_logprob = -torch.log(prob)
-         surprisals.append(neg_logprob[input_ids[0][i + 1]].item())
-     mean_surprisal = np.mean(surprisals)
-
-     tokens2surprisal: List[Tuple[str, float]] = []
-     for token, surprisal in zip(input_tokens, surprisals):
-         tokens2surprisal.append((token, surprisal))
-
-     char2surprisal = distribute_surprisals_to_characters(tokens2surprisal)
-
-     return mean_surprisal, char2surprisal
-
-
- def aggregate_surprisals_by_offset(
-     char2surprisal: List[Tuple[str, float]], offsets: List[int]
- ) -> List[Tuple[str, float]]:
-     tokens2surprisal = []
-     for i in range(len(offsets) - 1):
-         start = offsets[i]
-         end = offsets[i + 1]
-         surprisal = sum([surprisal for _, surprisal in char2surprisal[start:end]])
-         token = "".join([char for char, _ in char2surprisal[start:end]])
-         tokens2surprisal.append((token, surprisal))
-
-     return tokens2surprisal
-
-
- def highlight_token(token: str, score: float):
-     if score > 0:
          html_color = "#%02X%02X%02X" % (
              255,
-             int(255 * (1 - score)),
-             int(255 * (1 - score)),
          )
      else:
          html_color = "#%02X%02X%02X" % (
-             int(255 * (1 + score)),
              255,
-             int(255 * (1 + score)),
          )
      return '<span style="background-color: {}; color: black">{}</span>'.format(
-         html_color, token
-     )
-
-
- def create_highlighted_text(
-     label: str,
-     tokens2scores: List[Tuple[str, float]],
-     mean_surprisal: Optional[float] = None,
- ):
-     if mean_surprisal is None:
-         highlighted_text = "<h2><b>" + label + "</b></h2>"
-     else:
-         highlighted_text = (
-             "<h2><b>" + label + f"</b>(サプライザル平均値: {mean_surprisal:.3f})</h2>"
-         )
-     for token, score in tokens2scores:
-         highlighted_text += highlight_token(token, score)
-     return highlighted_text
-
-
- def normalize_surprisals(
-     tokens2surprisal: List[Tuple[str, float]], log_scale: bool = False
- ) -> List[Tuple[str, float]]:
-     if log_scale:
-         surprisals = [np.log(surprisal) for _, surprisal in tokens2surprisal]
-     else:
-         surprisals = [surprisal for _, surprisal in tokens2surprisal]
-     min_surprisal = np.min(surprisals)
-     max_surprisal = np.max(surprisals)
-     surprisals = [
-         (surprisal - min_surprisal) / (max_surprisal - min_surprisal)
-         for surprisal in surprisals
-     ]
-     assert min(surprisals) >= 0
-     assert max(surprisals) <= 1
-     return [
-         (token, surprisal)
-         for (token, _), surprisal in zip(tokens2surprisal, surprisals)
-     ]
-
-
- def calculate_surprisal_diff(
-     tokens2surprisal: List[Tuple[str, float]],
-     baseline_tokens2surprisal: List[Tuple[str, float]],
-     scale: float = 100.0,
- ):
-     diff_tokens2surprisal = [
-         (token, (surprisal - baseline_surprisal) * 100)
-         for (token, surprisal), (_, baseline_surprisal) in zip(
-             tokens2surprisal, baseline_tokens2surprisal
-         )
-     ]
-     return diff_tokens2surprisal
-
- @spaces.GPU
- def main(input_text: str) -> Tuple[str, str, str]:
-     mean_surprisal, char2surprisal = calculate_surprisals_by_character(
-         input_text, trained_model, tokenizer
-     )
-     offsets = calc_offsets(sudachi_tokenize(input_text))
-     tokens2surprisal = aggregate_surprisals_by_offset(char2surprisal, offsets)
-     tokens2surprisal = normalize_surprisals(tokens2surprisal)
-
-     highlighted_text = create_highlighted_text(
-         "学習後モデル", tokens2surprisal, mean_surprisal
      )
-
-     (
-         baseline_mean_surprisal,
-         baseline_char2surprisal,
-     ) = calculate_surprisals_by_character(input_text, baseline_model, tokenizer)
-     baseline_tokens2surprisal = aggregate_surprisals_by_offset(
-         baseline_char2surprisal, offsets
-     )
-     baseline_tokens2surprisal = normalize_surprisals(baseline_tokens2surprisal)
-     baseline_highlighted_text = create_highlighted_text(
-         "学習前モデル", baseline_tokens2surprisal, baseline_mean_surprisal
-     )
-
-     diff_tokens2surprisal = calculate_surprisal_diff(
-         tokens2surprisal, baseline_tokens2surprisal, 100.0
-     )
-     diff_highlighted_text = create_highlighted_text(
-         "学習前後の差分", diff_tokens2surprisal, None
-     )
-     return (
-         baseline_highlighted_text,
-         highlighted_text,
-         diff_highlighted_text,
-     )
-
-
- if __name__ == "__main__":
-     demo = gr.Interface(
-         fn=main,
-         title="文章の読みやすさを自動評価するAI",
-         description="文章を入力すると、読みづらい表現は赤く、読みやすい表現は青くハイライトされて出力されます。",
-         # show_label=True,
-         inputs=gr.Textbox(
-             lines=5,
-             label="文章",
-             placeholder="ここに文章を入力してください。",
          ),
-         outputs=[
-             gr.HTML(label="学習前モデル", show_label=True),
-             gr.HTML(label="学習後モデル", show_label=True),
-             gr.HTML(label="学習前後の差分", show_label=True),
-         ],
-         examples=[
-             "太郎が二郎を殴った。",
-             "太郎が二郎に殴った。",
-             "サイエンスインパクトラボは、国立研究開発法人科学技術振興機構(JST)の「科学と社会」推進部が行う共創プログラムです。「先端の研究開発を行う研究者」と「社会課題解決に取り組むプレイヤー」が約3ヶ月に渡って共創活動を行います。",
-             "近年、ニューラル言語モデルが自然言語の統語知識をどれほど有しているかを、容認性判断課題を通して検証する研究が行われてきている。しかし、このような言語モデルの統語的評価を行うためのデータセットは、主に英語を中心とした欧米の諸言語を対象に構築されてきた。本研究では、既存のデータセットの問題点を克服しつつ、このようなデータセットが構築されてこなかった日本語を対象とした初めてのデータセットである JCoLA (JapaneseCorpus of Linguistic Acceptability) を構築した上で、それを用いた言語モデルの統語的評価を行った。",
-         ],
-     )

- demo.launch()
+ # from pathlib import Path
+ # from typing import List, Optional, Tuple
+
+ # import spaces
+ # import gradio as gr
+ # import numpy as np
+ # import torch
+ # from sudachipy import dictionary
+ # from sudachipy import tokenizer as sudachi_tokenizer
+ # from transformers import AutoModelForCausalLM, PreTrainedTokenizer, T5Tokenizer
+
+
+ # model_dir = Path(__file__).parents[0] / "model"
+ # device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+ # tokenizer = T5Tokenizer.from_pretrained(model_dir)
+ # tokenizer.do_lower_case = True
+ # trained_model = AutoModelForCausalLM.from_pretrained(model_dir)
+ # trained_model.to(device)
+
+ # # baseline model
+ # baseline_model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
+ # baseline_model.to(device)
+
+ # sudachi_tokenizer_obj = dictionary.Dictionary().create()
+ # mode = sudachi_tokenizer.Tokenizer.SplitMode.C
+
+
+ # def sudachi_tokenize(input_text: str) -> List[str]:
+ #     morphemes = sudachi_tokenizer_obj.tokenize(input_text, mode)
+ #     return [morpheme.surface() for morpheme in morphemes]
+
+
+ # def calc_offsets(tokens: List[str]) -> List[int]:
+ #     offsets = [0]
+ #     for token in tokens:
+ #         offsets.append(offsets[-1] + len(token))
+ #     return offsets
+
+
+ # def distribute_surprisals_to_characters(
+ #     tokens2surprisal: List[Tuple[str, float]]
+ # ) -> List[Tuple[str, float]]:
+ #     tokens2surprisal_by_character: List[Tuple[str, float]] = []
+ #     for token, surprisal in tokens2surprisal:
+ #         token_len = len(token)
+ #         for character in token:
+ #             tokens2surprisal_by_character.append((character, surprisal / token_len))
+ #     return tokens2surprisal_by_character
+
+
+ # def calculate_surprisals_by_character(
+ #     input_text: str, model: AutoModelForCausalLM, tokenizer: PreTrainedTokenizer
+ # ) -> Tuple[float, List[Tuple[str, float]]]:
+ #     input_tokens = [
+ #         token.replace("▁", "")
+ #         for token in tokenizer.tokenize(input_text)
+ #         if token != "▁"
+ #     ]
+ #     input_ids = tokenizer.encode(
+ #         "<s>" + input_text, add_special_tokens=False, return_tensors="pt"
+ #     ).to(device)
+
+ #     logits = model(input_ids)["logits"].squeeze(0)
+
+ #     surprisals = []
+ #     for i in range(logits.shape[0] - 1):
+ #         if input_ids[0][i + 1] == 9:
+ #             continue
+ #         logit = logits[i]
+ #         prob = torch.softmax(logit, dim=0)
+ #         neg_logprob = -torch.log(prob)
+ #         surprisals.append(neg_logprob[input_ids[0][i + 1]].item())
+ #     mean_surprisal = np.mean(surprisals)
+
+ #     tokens2surprisal: List[Tuple[str, float]] = []
+ #     for token, surprisal in zip(input_tokens, surprisals):
+ #         tokens2surprisal.append((token, surprisal))
+
+ #     char2surprisal = distribute_surprisals_to_characters(tokens2surprisal)
+
+ #     return mean_surprisal, char2surprisal
+
+
+ # def aggregate_surprisals_by_offset(
+ #     char2surprisal: List[Tuple[str, float]], offsets: List[int]
+ # ) -> List[Tuple[str, float]]:
+ #     tokens2surprisal = []
+ #     for i in range(len(offsets) - 1):
+ #         start = offsets[i]
+ #         end = offsets[i + 1]
+ #         surprisal = sum([surprisal for _, surprisal in char2surprisal[start:end]])
+ #         token = "".join([char for char, _ in char2surprisal[start:end]])
+ #         tokens2surprisal.append((token, surprisal))
+
+ #     return tokens2surprisal
+
+
+ # def highlight_token(token: str, score: float):
+ #     if score > 0:
+ #         html_color = "#%02X%02X%02X" % (
+ #             255,
+ #             int(255 * (1 - score)),
+ #             int(255 * (1 - score)),
+ #         )
+ #     else:
+ #         html_color = "#%02X%02X%02X" % (
+ #             int(255 * (1 + score)),
+ #             255,
+ #             int(255 * (1 + score)),
+ #         )
+ #     return '<span style="background-color: {}; color: black">{}</span>'.format(
+ #         html_color, token
+ #     )
+
+
+ # def create_highlighted_text(
+ #     label: str,
+ #     tokens2scores: List[Tuple[str, float]],
+ #     mean_surprisal: Optional[float] = None,
+ # ):
+ #     if mean_surprisal is None:
+ #         highlighted_text = "<h2><b>" + label + "</b></h2>"
+ #     else:
+ #         highlighted_text = (
+ #             "<h2><b>" + label + f"</b>(サプライザル平均値: {mean_surprisal:.3f})</h2>"
+ #         )
+ #     for token, score in tokens2scores:
+ #         highlighted_text += highlight_token(token, score)
+ #     return highlighted_text
+
+
+ # def normalize_surprisals(
+ #     tokens2surprisal: List[Tuple[str, float]], log_scale: bool = False
+ # ) -> List[Tuple[str, float]]:
+ #     if log_scale:
+ #         surprisals = [np.log(surprisal) for _, surprisal in tokens2surprisal]
+ #     else:
+ #         surprisals = [surprisal for _, surprisal in tokens2surprisal]
+ #     min_surprisal = np.min(surprisals)
+ #     max_surprisal = np.max(surprisals)
+ #     surprisals = [
+ #         (surprisal - min_surprisal) / (max_surprisal - min_surprisal)
+ #         for surprisal in surprisals
+ #     ]
+ #     assert min(surprisals) >= 0
+ #     assert max(surprisals) <= 1
+ #     return [
+ #         (token, surprisal)
+ #         for (token, _), surprisal in zip(tokens2surprisal, surprisals)
+ #     ]
+
+
+ # def calculate_surprisal_diff(
+ #     tokens2surprisal: List[Tuple[str, float]],
+ #     baseline_tokens2surprisal: List[Tuple[str, float]],
+ #     scale: float = 100.0,
+ # ):
+ #     diff_tokens2surprisal = [
+ #         (token, (surprisal - baseline_surprisal) * 100)
+ #         for (token, surprisal), (_, baseline_surprisal) in zip(
+ #             tokens2surprisal, baseline_tokens2surprisal
+ #         )
+ #     ]
+ #     return diff_tokens2surprisal
+
+ # @spaces.GPU
+ # def main(input_text: str) -> Tuple[str, str, str]:
+ #     mean_surprisal, char2surprisal = calculate_surprisals_by_character(
+ #         input_text, trained_model, tokenizer
+ #     )
+ #     offsets = calc_offsets(sudachi_tokenize(input_text))
+ #     tokens2surprisal = aggregate_surprisals_by_offset(char2surprisal, offsets)
+ #     tokens2surprisal = normalize_surprisals(tokens2surprisal)
+
+ #     highlighted_text = create_highlighted_text(
+ #         "学習後モデル", tokens2surprisal, mean_surprisal
+ #     )
+
+ #     (
+ #         baseline_mean_surprisal,
+ #         baseline_char2surprisal,
+ #     ) = calculate_surprisals_by_character(input_text, baseline_model, tokenizer)
+ #     baseline_tokens2surprisal = aggregate_surprisals_by_offset(
+ #         baseline_char2surprisal, offsets
+ #     )
+ #     baseline_tokens2surprisal = normalize_surprisals(baseline_tokens2surprisal)
+ #     baseline_highlighted_text = create_highlighted_text(
+ #         "学習前モデル", baseline_tokens2surprisal, baseline_mean_surprisal
+ #     )
+
+ #     diff_tokens2surprisal = calculate_surprisal_diff(
+ #         tokens2surprisal, baseline_tokens2surprisal, 100.0
+ #     )
+ #     diff_highlighted_text = create_highlighted_text(
+ #         "学習前後の差分", diff_tokens2surprisal, None
+ #     )
+ #     return (
+ #         baseline_highlighted_text,
+ #         highlighted_text,
+ #         diff_highlighted_text,
+ #     )
+
+
+ # if __name__ == "__main__":
+ #     demo = gr.Interface(
+ #         fn=main,
+ #         title="文章の読みやすさを自動評価するAI",
+ #         description="文章を入力すると、読みづらい表現は赤く、読みやすい表現は青くハイライトされて出力されます。",
+ #         # show_label=True,
+ #         inputs=gr.Textbox(
+ #             lines=5,
+ #             label="文章",
+ #             placeholder="ここに文章を入力してください。",
+ #         ),
+ #         outputs=[
+ #             gr.HTML(label="学習前モデル", show_label=True),
+ #             gr.HTML(label="学習後モデル", show_label=True),
+ #             gr.HTML(label="学習前後の差分", show_label=True),
+ #         ],
+ #         examples=[
+ #             "太郎が二郎を殴った。",
+ #             "太郎が二郎に殴った。",
+ #             "サイエンスインパクトラボは、国立研究開発法人科学技術振興機構(JST)の「科学と社会」推進部が行う共創プログラムです。「先端の研究開発を行う研究者」と「社会課題解決に取り組むプレイヤー」が約3ヶ月に渡って共創活動を行います。",
+ #             "近年、ニューラル言語モデルが自然言語の統語知識をどれほど有しているかを、容認性判断課題を通して検証する研究が行われてきている。しかし、このような言語モデルの統語的評価を行うためのデータセットは、主に英語を中心とした欧米の諸言語を対象に構築されてきた。本研究では、既存のデータセットの問題点を克服しつつ、このようなデータセットが構築されてこなかった日本語を対象とした初めてのデータセットである JCoLA (JapaneseCorpus of Linguistic Acceptability) を構築した上で、それを用いた言語モデルの統語的評価を行った。",
+ #         ],
+ #     )
+
+ #     demo.launch()
+
+
+ import os
+ import pickle as pkl
  from pathlib import Path
+ from threading import Thread
+ from typing import List, Optional, Tuple, Iterator

  import spaces
  import gradio as gr
  import numpy as np
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
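+ # MAX_MAX_NEW_TOKENS caps the "Max new tokens" slider below; prompts longer than
+ # MAX_INPUT_TOKEN_LENGTH (overridable via the environment, e.g. MAX_INPUT_TOKEN_LENGTH=2048,
+ # an illustrative value) are truncated from the left before generation.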
+
+ DESCRIPTION = """\
+ # Llama-2 7B Chat with Streamable Semantic Uncertainty Probe
+ This Space demonstrates the Llama-2-7b-chat model with an added semantic uncertainty probe.
+ The highlighted text shows the model's uncertainty in real-time, with more intense yellow indicating higher uncertainty.
+ """
+
+ if torch.cuda.is_available():
+     model_id = "meta-llama/Llama-2-7b-chat-hf"
+     # TODO load the full model?
+     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     tokenizer.use_default_system_prompt = False
+
+     # load the probe data
+     # TODO load accuracy and SE probe and compare in different tabs
+     with open("./model/20240625-131035_demo.pkl", "rb") as f:
+         probe_data = pkl.load(f)
+     # take the NQ open one
+     probe_data = probe_data[-2]
+     probe = probe_data['t_bmodel']
+     layer_range = probe_data['sep_layer_range']
+     acc_probe = probe_data['t_amodel']
+     acc_layer_range = probe_data['ap_layer_range']
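+     # Assumed semantics of the pickled bundle: 't_bmodel'/'sep_layer_range' are taken to be
+     # the semantic-entropy (SE) probe and the hidden-layer slice it was trained on, and
+     # 't_amodel'/'ap_layer_range' the accuracy probe and its layer slice.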
+
+ @spaces.GPU
+ def generate(
+     message: str,
+     chat_history: List[Tuple[str, str]],
+     system_prompt: str,
+     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+     for user, assistant in chat_history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = dict(
+         input_ids=input_ids,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         repetition_penalty=repetition_penalty,
+         streamer=streamer,
+         output_hidden_states=True,
+         return_dict_in_generate=True,
+     )
+
+     # Generate without threading
+     with torch.no_grad():
+         outputs = model.generate(**generation_kwargs)
+     print(outputs.sequences.shape, input_ids.shape)
+     generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
+     print("Generated tokens:", generated_tokens, generated_tokens.shape)
+     generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+     print("Generated text:", generated_text)
+     # hidden states
+     hidden = outputs.hidden_states  # list of tensors, one for each token, then (batch size, sequence length, hidden size)
+     print(len(hidden))
+     print(len(hidden[1]))  # layers
+     print(hidden[1][0].shape)  # (sequence length, hidden size)
+     # stack token embeddings
+
+     # TODO do this loop on the fly instead of waiting for the whole generation
+     highlighted_text = ""
+     for i in range(1, len(hidden)):
+         token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]])  # (num_layers, hidden_size)
+         # print(token_embeddings.shape)
+         # probe the model
+         # print(token_embeddings.numpy()[layer_range].shape)
+         concat_layers = token_embeddings.numpy()[layer_range[0]:layer_range[1]].reshape(-1)  # (num_layers * hidden_size)
+         # print(concat_layers.shape)
+         # or prob?
+         probe_pred = probe.predict_log_proba(concat_layers.reshape(1, -1))[0][1]  # prob of high SE
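+         # Note: with a scikit-learn style classifier, predict_log_proba returns
+         # log-probabilities, so probe_pred is the log-probability of the high-SE class
+         # (a value <= 0) rather than a probability in [0, 1].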
+         # print(probe_pred.shape, probe_pred)
+         # decode one token at a time
+         output_id = outputs.sequences[0, input_ids.shape[1]+i]
+         output_word = tokenizer.decode(output_id)
+         print(output_id, output_word, probe_pred)
+         new_highlighted_text = highlight_text(output_word, probe_pred)
+         highlighted_text += new_highlighted_text
+
+         yield highlighted_text
+
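+ # Colour mapping: positive scores shade the background towards red (more intense with
+ # higher scores), negative scores shade it towards green, and a score of 0 leaves it white.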
+ def highlight_text(text: str, uncertainty_score: float) -> str:
+     if uncertainty_score > 0:
          html_color = "#%02X%02X%02X" % (
              255,
+             int(255 * (1 - uncertainty_score)),
+             int(255 * (1 - uncertainty_score)),
          )
      else:
          html_color = "#%02X%02X%02X" % (
+             int(255 * (1 + uncertainty_score)),
              255,
+             int(255 * (1 + uncertainty_score)),
          )
      return '<span style="background-color: {}; color: black">{}</span>'.format(
+         html_color, text
      )
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Textbox(label="System prompt", lines=6),
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
        ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["What is the capital of France?"],
+         ["Explain the theory of relativity in simple terms."],
+         ["Write a short poem about artificial intelligence."]
+     ],
+     title="Llama-2 7B Chat with Streamable Semantic Uncertainty Probe",
+     description=DESCRIPTION,
+ )

+ if __name__ == "__main__":
+     chat_interface.launch()
app_sep.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
  from threading import Thread
  from typing import List, Optional, Tuple, Iterator

  import gradio as gr
  import numpy as np
  import torch
@@ -33,11 +34,12 @@ if torch.cuda.is_available():
          probe_data = pkl.load(f)
      # take the NQ open one
      probe_data = probe_data[-2]
-     model = probe_data['t_bmodel']
      layer_range = probe_data['sep_layer_range']
-     acc_model = probe_data['t_amodel']
      acc_layer_range = probe_data['ap_layer_range']

  def generate(
      message: str,
      chat_history: List[Tuple[str, str]],
@@ -75,50 +77,42 @@ def generate(
          return_dict_in_generate=True,
      )

-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     generated_text = ""
      highlighted_text = ""
-     for output in streamer:
-         print(output)
-         generated_text += output
-
-         yield generated_text
-
-     # this is doing it twice... just do autoregressive generation instead
-     for new_text in streamer:
-         generated_text += new_text
-         current_input_ids = tokenizer.encode(generated_text, return_tensors="pt").to(model.device)
-
-         with torch.no_grad():
-             outputs = model(current_input_ids, output_hidden_states=True)
-             hidden = outputs.hidden_states
-             # Stack second last token embeddings from all layers
-             # if len(hidden) == 1:  # FIX: runtime error for mistral-7b on bioasq
-             #     sec_last_input = hidden[0]
-             # elif ((n_generated - 2) >= len(hidden)):
-             #     sec_last_input = hidden[-2]
-             # else:
-             #     sec_last_input = hidden[n_generated - 2]
-             last_hidden_state = torch.stack([layer[:, -1, :].cpu() for layer in hidden[-1]]).cpu().numpy()
-             # print(sec_last_token_embedding.shape)
-             # last_hidden_state = outputs.hidden_states[-1][:, -1, :].cpu().numpy()
-             print(last_hidden_state.shape)
-             # TODO potentially need to only compute uncertainty for the last token in sentence?
-
-             # concatenate the hidden states from the specified layers
-             probe_input = np.concatenate(last_hidden_state[layer_range], axis=1)
-             print(probe_input.shape)
-             uncertainty_score = model.predict(probe_input)
-             print(uncertainty_score)
-             new_highlighted_text = highlight_text(new_text, uncertainty_score[0])
-             print(new_highlighted_text)
          highlighted_text += new_highlighted_text
-
          yield highlighted_text

-
  def highlight_text(text: str, uncertainty_score: float) -> str:
      if uncertainty_score > 0:
          html_color = "#%02X%02X%02X" % (
 
  from threading import Thread
  from typing import List, Optional, Tuple, Iterator

+ import spaces
  import gradio as gr
  import numpy as np
  import torch
 
          probe_data = pkl.load(f)
      # take the NQ open one
      probe_data = probe_data[-2]
+     probe = probe_data['t_bmodel']
      layer_range = probe_data['sep_layer_range']
+     acc_probe = probe_data['t_amodel']
      acc_layer_range = probe_data['ap_layer_range']

+ @spaces.GPU
  def generate(
      message: str,
      chat_history: List[Tuple[str, str]],
 
          return_dict_in_generate=True,
      )

+     # Generate without threading
+     with torch.no_grad():
+         outputs = model.generate(**generation_kwargs)
+     print(outputs.sequences.shape, input_ids.shape)
+     generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
+     print("Generated tokens:", generated_tokens, generated_tokens.shape)
+     generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+     print("Generated text:", generated_text)
+     # hidden states
+     hidden = outputs.hidden_states  # list of tensors, one for each token, then (batch size, sequence length, hidden size)
+     print(len(hidden))
+     print(len(hidden[1]))  # layers
+     print(hidden[1][0].shape)  # (sequence length, hidden size)
+     # stack token embeddings
+
+     # TODO do this loop on the fly instead of waiting for the whole generation
      highlighted_text = ""
+     for i in range(1, len(hidden)):
+         token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]])  # (num_layers, hidden_size)
+         # print(token_embeddings.shape)
+         # probe the model
+         # print(token_embeddings.numpy()[layer_range].shape)
+         concat_layers = token_embeddings.numpy()[layer_range[0]:layer_range[1]].reshape(-1)  # (num_layers * hidden_size)
+         # print(concat_layers.shape)
+         # or prob?
+         probe_pred = probe.predict_log_proba(concat_layers.reshape(1, -1))[0][1]  # prob of high SE
+         # print(probe_pred.shape, probe_pred)
+         # decode one token at a time
+         output_id = outputs.sequences[0, input_ids.shape[1]+i]
+         output_word = tokenizer.decode(output_id)
+         print(output_id, output_word, probe_pred)
+         new_highlighted_text = highlight_text(output_word, probe_pred)
          highlighted_text += new_highlighted_text
+
          yield highlighted_text

  def highlight_text(text: str, uncertainty_score: float) -> str:
      if uncertainty_score > 0:
          html_color = "#%02X%02X%02X" % (