ACMCMC committed
Commit 0331d85 · 1 Parent(s): 3b410d3

Update code

Files changed (1)
  1. main.ipynb +879 -17
main.ipynb CHANGED
@@ -2,10 +2,25 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "id": "d1daab37",
7
  "metadata": {},
8
- "outputs": [],
9
  "source": [
10
  "import ipywidgets\n",
11
  "\n",
@@ -20,14 +35,688 @@
20
  },
21
  {
22
  "cell_type": "code",
23
- "execution_count": null,
24
  "id": "09b3c097",
25
  "metadata": {},
26
- "outputs": [],
27
  "source": [
28
  "import transformers\n",
29
  "import torch\n",
30
  "import tqdm\n",
31
  "\n",
32
  "print(f\"Optimizing: {text_widget.value}\")\n",
33
  "\n",
@@ -79,7 +768,7 @@
79
  "\n",
80
  "# Define the number of optimization steps\n",
81
  "n_steps = 5000\n",
82
- "patience = 50\n",
83
  "patience_counter = 0\n",
84
  "epsilon = 1e-1\n",
85
  "\n",
@@ -112,8 +801,8 @@
112
  "\n",
113
  " # Compute the ranks of the input IDs, i.e. how many tokens would have been more likely than the correct one (the label, the input IDs)\n",
114
  " \n",
115
- " count_higher_logits = (logits[:, ss_prompt.size(-2) - 1 :-1] > logits[:, -1:]).float()\n",
116
- " ranks = count_higher_logits.sum(dim=-1)\n",
117
  "\n",
118
  " # Compute the loss\n",
119
  " loss = loss_fn(\n",
@@ -134,7 +823,7 @@
134
  " for optimized, original in zip(ss_prompt, original_ss_prompt)\n",
135
  " )\n",
136
  " print(\n",
137
- " f\"Step {step}, Loss: {loss.item()}, L2 norm: {l2_norm.item()}, avg rank: {ranks.mean().item()}\"\n",
138
  " )\n",
139
  "\n",
140
  " # Early stopping with patience\n",
@@ -147,12 +836,17 @@
147
  "\n",
148
  " if patience_counter >= patience:\n",
149
  " print(f\"Early stopping at step {step} with best loss {best_loss}\")\n",
150
  " break"
151
  ]
152
  },
153
  {
154
  "cell_type": "code",
155
- "execution_count": null,
156
  "id": "cc9a6a2f",
157
  "metadata": {},
158
  "outputs": [],
@@ -163,22 +857,190 @@
163
  },
164
  {
165
  "cell_type": "code",
166
- "execution_count": null,
167
  "id": "3186747d",
168
  "metadata": {},
169
- "outputs": [],
170
  "source": [
171
  "import os\n",
172
  "import transformers\n",
173
  "import torch\n",
174
  "import tqdm\n",
175
  "\n",
176
- "# tokenizer: transformers.PreTrainedTokenizer = (\n",
177
- "# transformers.AutoTokenizer.from_pretrained(\"openai-community/gpt2\")\n",
178
- "# )\n",
179
- "# model: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(\n",
180
- "# \"openai-community/gpt2\"\n",
181
- "# )\n",
182
  "\n",
183
  "# Load the best soft prompt from file\n",
184
  "best_ss_prompt = torch.load(\"best_ss_prompt.pt\")\n",
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "d1daab37",
7
  "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "application/vnd.jupyter.widget-view+json": {
12
+ "model_id": "6be175a38e1843ab83831cb3a9a06296",
13
+ "version_major": 2,
14
+ "version_minor": 0
15
+ },
16
+ "text/plain": [
17
+ "Textarea(value=\"Solomonoff's theory of inductive inference proposes that all problems of logical induction can…"
18
+ ]
19
+ },
20
+ "metadata": {},
21
+ "output_type": "display_data"
22
+ }
23
+ ],
24
  "source": [
25
  "import ipywidgets\n",
26
  "\n",
 
35
  },
36
  {
37
  "cell_type": "code",
38
+ "execution_count": 6,
39
  "id": "09b3c097",
40
  "metadata": {},
41
+ "outputs": [
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "Optimizing: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps!\n"
47
+ ]
48
+ },
49
+ {
50
+ "name": "stderr",
51
+ "output_type": "stream",
52
+ "text": [
53
+ " 0%| | 2/5000 [00:00<13:45, 6.05it/s]"
54
+ ]
55
+ },
56
+ {
57
+ "name": "stdout",
58
+ "output_type": "stream",
59
+ "text": [
60
+ "Step 0, Loss: 5.8886637687683105, L2 norm: 0.5542519688606262, avg rank: 655.2222290039062\n"
61
+ ]
62
+ },
63
+ {
64
+ "name": "stderr",
65
+ "output_type": "stream",
66
+ "text": [
67
+ " 0%| | 12/5000 [00:01<10:57, 7.58it/s]"
68
+ ]
69
+ },
70
+ {
71
+ "name": "stdout",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Step 10, Loss: 5.159086227416992, L2 norm: 3.46441388130188, avg rank: 439.6666564941406\n"
75
+ ]
76
+ },
77
+ {
78
+ "name": "stderr",
79
+ "output_type": "stream",
80
+ "text": [
81
+ " 0%| | 22/5000 [00:03<12:31, 6.62it/s]"
82
+ ]
83
+ },
84
+ {
85
+ "name": "stdout",
86
+ "output_type": "stream",
87
+ "text": [
88
+ "Step 20, Loss: 4.893388748168945, L2 norm: 5.287822246551514, avg rank: 380.629638671875\n"
89
+ ]
90
+ },
91
+ {
92
+ "name": "stderr",
93
+ "output_type": "stream",
94
+ "text": [
95
+ " 1%| | 32/5000 [00:04<10:37, 7.79it/s]"
96
+ ]
97
+ },
98
+ {
99
+ "name": "stdout",
100
+ "output_type": "stream",
101
+ "text": [
102
+ "Step 30, Loss: 4.576204299926758, L2 norm: 7.122143745422363, avg rank: 313.0740661621094\n"
103
+ ]
104
+ },
105
+ {
106
+ "name": "stderr",
107
+ "output_type": "stream",
108
+ "text": [
109
+ " 1%| | 42/5000 [00:05<10:42, 7.71it/s]"
110
+ ]
111
+ },
112
+ {
113
+ "name": "stdout",
114
+ "output_type": "stream",
115
+ "text": [
116
+ "Step 40, Loss: 4.330563545227051, L2 norm: 8.640031814575195, avg rank: 189.8518524169922\n"
117
+ ]
118
+ },
119
+ {
120
+ "name": "stderr",
121
+ "output_type": "stream",
122
+ "text": [
123
+ " 1%| | 52/5000 [00:07<13:39, 6.04it/s]"
124
+ ]
125
+ },
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "Step 50, Loss: 4.107393264770508, L2 norm: 10.037625312805176, avg rank: 130.6666717529297\n"
131
+ ]
132
+ },
133
+ {
134
+ "name": "stderr",
135
+ "output_type": "stream",
136
+ "text": [
137
+ " 1%| | 62/5000 [00:08<11:33, 7.12it/s]"
138
+ ]
139
+ },
140
+ {
141
+ "name": "stdout",
142
+ "output_type": "stream",
143
+ "text": [
144
+ "Step 60, Loss: 3.8990049362182617, L2 norm: 11.52373218536377, avg rank: 105.9259262084961\n"
145
+ ]
146
+ },
147
+ {
148
+ "name": "stderr",
149
+ "output_type": "stream",
150
+ "text": [
151
+ " 1%|▏ | 72/5000 [00:10<19:46, 4.15it/s]"
152
+ ]
153
+ },
154
+ {
155
+ "name": "stdout",
156
+ "output_type": "stream",
157
+ "text": [
158
+ "Step 70, Loss: 3.689441442489624, L2 norm: 12.885846138000488, avg rank: 96.77777862548828\n"
159
+ ]
160
+ },
161
+ {
162
+ "name": "stderr",
163
+ "output_type": "stream",
164
+ "text": [
165
+ " 2%|▏ | 82/5000 [00:13<14:27, 5.67it/s]"
166
+ ]
167
+ },
168
+ {
169
+ "name": "stdout",
170
+ "output_type": "stream",
171
+ "text": [
172
+ "Step 80, Loss: 3.664461374282837, L2 norm: 14.095556259155273, avg rank: 82.48148345947266\n"
173
+ ]
174
+ },
175
+ {
176
+ "name": "stderr",
177
+ "output_type": "stream",
178
+ "text": [
179
+ " 2%|▏ | 92/5000 [00:14<12:06, 6.76it/s]"
180
+ ]
181
+ },
182
+ {
183
+ "name": "stdout",
184
+ "output_type": "stream",
185
+ "text": [
186
+ "Step 90, Loss: 3.359441041946411, L2 norm: 15.264603614807129, avg rank: 76.62963104248047\n"
187
+ ]
188
+ },
189
+ {
190
+ "name": "stderr",
191
+ "output_type": "stream",
192
+ "text": [
193
+ " 2%|▏ | 102/5000 [00:15<11:25, 7.14it/s]"
194
+ ]
195
+ },
196
+ {
197
+ "name": "stdout",
198
+ "output_type": "stream",
199
+ "text": [
200
+ "Step 100, Loss: 3.2641189098358154, L2 norm: 16.456912994384766, avg rank: 59.55555725097656\n"
201
+ ]
202
+ },
203
+ {
204
+ "name": "stderr",
205
+ "output_type": "stream",
206
+ "text": [
207
+ " 2%|▏ | 111/5000 [00:17<10:21, 7.87it/s]"
208
+ ]
209
+ },
210
+ {
211
+ "name": "stdout",
212
+ "output_type": "stream",
213
+ "text": [
214
+ "Step 110, Loss: 3.4184069633483887, L2 norm: 17.494909286499023, avg rank: 45.22222137451172\n"
215
+ ]
216
+ },
217
+ {
218
+ "name": "stderr",
219
+ "output_type": "stream",
220
+ "text": [
221
+ " 2%|▏ | 122/5000 [00:19<14:21, 5.66it/s]"
222
+ ]
223
+ },
224
+ {
225
+ "name": "stdout",
226
+ "output_type": "stream",
227
+ "text": [
228
+ "Step 120, Loss: 3.057454824447632, L2 norm: 18.361045837402344, avg rank: 47.592594146728516\n"
229
+ ]
230
+ },
231
+ {
232
+ "name": "stderr",
233
+ "output_type": "stream",
234
+ "text": [
235
+ " 3%|▎ | 132/5000 [00:20<11:23, 7.12it/s]"
236
+ ]
237
+ },
238
+ {
239
+ "name": "stdout",
240
+ "output_type": "stream",
241
+ "text": [
242
+ "Step 130, Loss: 2.752321243286133, L2 norm: 19.076860427856445, avg rank: 37.185184478759766\n"
243
+ ]
244
+ },
245
+ {
246
+ "name": "stderr",
247
+ "output_type": "stream",
248
+ "text": [
249
+ " 3%|▎ | 142/5000 [00:21<10:12, 7.93it/s]"
250
+ ]
251
+ },
252
+ {
253
+ "name": "stdout",
254
+ "output_type": "stream",
255
+ "text": [
256
+ "Step 140, Loss: 2.492600679397583, L2 norm: 19.873262405395508, avg rank: 29.0\n"
257
+ ]
258
+ },
259
+ {
260
+ "name": "stderr",
261
+ "output_type": "stream",
262
+ "text": [
263
+ " 3%|▎ | 152/5000 [00:23<12:10, 6.64it/s]"
264
+ ]
265
+ },
266
+ {
267
+ "name": "stdout",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "Step 150, Loss: 2.2693941593170166, L2 norm: 20.689470291137695, avg rank: 25.22222137451172\n"
271
+ ]
272
+ },
273
+ {
274
+ "name": "stderr",
275
+ "output_type": "stream",
276
+ "text": [
277
+ " 3%|▎ | 162/5000 [00:24<11:26, 7.05it/s]"
278
+ ]
279
+ },
280
+ {
281
+ "name": "stdout",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "Step 160, Loss: 2.0094830989837646, L2 norm: 21.427452087402344, avg rank: 20.185184478759766\n"
285
+ ]
286
+ },
287
+ {
288
+ "name": "stderr",
289
+ "output_type": "stream",
290
+ "text": [
291
+ " 3%|▎ | 172/5000 [00:26<11:22, 7.08it/s]"
292
+ ]
293
+ },
294
+ {
295
+ "name": "stdout",
296
+ "output_type": "stream",
297
+ "text": [
298
+ "Step 170, Loss: 1.8882372379302979, L2 norm: 22.13208770751953, avg rank: 23.074073791503906\n"
299
+ ]
300
+ },
301
+ {
302
+ "name": "stderr",
303
+ "output_type": "stream",
304
+ "text": [
305
+ " 4%|▎ | 182/5000 [00:27<09:56, 8.08it/s]"
306
+ ]
307
+ },
308
+ {
309
+ "name": "stdout",
310
+ "output_type": "stream",
311
+ "text": [
312
+ "Step 180, Loss: 1.8304041624069214, L2 norm: 22.583101272583008, avg rank: 28.33333396911621\n"
313
+ ]
314
+ },
315
+ {
316
+ "name": "stderr",
317
+ "output_type": "stream",
318
+ "text": [
319
+ " 4%|▍ | 192/5000 [00:28<10:58, 7.30it/s]"
320
+ ]
321
+ },
322
+ {
323
+ "name": "stdout",
324
+ "output_type": "stream",
325
+ "text": [
326
+ "Step 190, Loss: 1.5684080123901367, L2 norm: 23.05557632446289, avg rank: 22.037036895751953\n"
327
+ ]
328
+ },
329
+ {
330
+ "name": "stderr",
331
+ "output_type": "stream",
332
+ "text": [
333
+ " 4%|▍ | 202/5000 [00:30<09:58, 8.02it/s]"
334
+ ]
335
+ },
336
+ {
337
+ "name": "stdout",
338
+ "output_type": "stream",
339
+ "text": [
340
+ "Step 200, Loss: 1.3705590963363647, L2 norm: 23.487506866455078, avg rank: 18.925926208496094\n"
341
+ ]
342
+ },
343
+ {
344
+ "name": "stderr",
345
+ "output_type": "stream",
346
+ "text": [
347
+ " 4%|▍ | 212/5000 [00:31<09:58, 8.00it/s]"
348
+ ]
349
+ },
350
+ {
351
+ "name": "stdout",
352
+ "output_type": "stream",
353
+ "text": [
354
+ "Step 210, Loss: 1.2061578035354614, L2 norm: 23.888507843017578, avg rank: 15.481481552124023\n"
355
+ ]
356
+ },
357
+ {
358
+ "name": "stderr",
359
+ "output_type": "stream",
360
+ "text": [
361
+ " 4%|▍ | 222/5000 [00:32<10:36, 7.50it/s]"
362
+ ]
363
+ },
364
+ {
365
+ "name": "stdout",
366
+ "output_type": "stream",
367
+ "text": [
368
+ "Step 220, Loss: 1.0748673677444458, L2 norm: 24.25291633605957, avg rank: 13.333333015441895\n"
369
+ ]
370
+ },
371
+ {
372
+ "name": "stderr",
373
+ "output_type": "stream",
374
+ "text": [
375
+ " 5%|▍ | 232/5000 [00:34<11:45, 6.76it/s]"
376
+ ]
377
+ },
378
+ {
379
+ "name": "stdout",
380
+ "output_type": "stream",
381
+ "text": [
382
+ "Step 230, Loss: 0.9695132374763489, L2 norm: 24.588470458984375, avg rank: 12.592592239379883\n"
383
+ ]
384
+ },
385
+ {
386
+ "name": "stderr",
387
+ "output_type": "stream",
388
+ "text": [
389
+ " 5%|▍ | 242/5000 [00:35<11:47, 6.73it/s]"
390
+ ]
391
+ },
392
+ {
393
+ "name": "stdout",
394
+ "output_type": "stream",
395
+ "text": [
396
+ "Step 240, Loss: 0.8693955540657043, L2 norm: 24.909320831298828, avg rank: 11.518518447875977\n"
397
+ ]
398
+ },
399
+ {
400
+ "name": "stderr",
401
+ "output_type": "stream",
402
+ "text": [
403
+ " 5%|▌ | 252/5000 [00:36<09:53, 7.99it/s]"
404
+ ]
405
+ },
406
+ {
407
+ "name": "stdout",
408
+ "output_type": "stream",
409
+ "text": [
410
+ "Step 250, Loss: 0.7695979475975037, L2 norm: 25.205760955810547, avg rank: 10.037036895751953\n"
411
+ ]
412
+ },
413
+ {
414
+ "name": "stderr",
415
+ "output_type": "stream",
416
+ "text": [
417
+ " 5%|▌ | 262/5000 [00:38<10:06, 7.81it/s]"
418
+ ]
419
+ },
420
+ {
421
+ "name": "stdout",
422
+ "output_type": "stream",
423
+ "text": [
424
+ "Step 260, Loss: 0.6803638935089111, L2 norm: 25.493518829345703, avg rank: 8.666666984558105\n"
425
+ ]
426
+ },
427
+ {
428
+ "name": "stderr",
429
+ "output_type": "stream",
430
+ "text": [
431
+ " 5%|▌ | 272/5000 [00:39<11:42, 6.73it/s]"
432
+ ]
433
+ },
434
+ {
435
+ "name": "stdout",
436
+ "output_type": "stream",
437
+ "text": [
438
+ "Step 270, Loss: 0.6074316501617432, L2 norm: 25.7545108795166, avg rank: 7.407407283782959\n"
439
+ ]
440
+ },
441
+ {
442
+ "name": "stderr",
443
+ "output_type": "stream",
444
+ "text": [
445
+ " 6%|▌ | 282/5000 [00:40<10:20, 7.60it/s]"
446
+ ]
447
+ },
448
+ {
449
+ "name": "stdout",
450
+ "output_type": "stream",
451
+ "text": [
452
+ "Step 280, Loss: 0.5503782629966736, L2 norm: 25.987356185913086, avg rank: 6.666666507720947\n"
453
+ ]
454
+ },
455
+ {
456
+ "name": "stderr",
457
+ "output_type": "stream",
458
+ "text": [
459
+ " 6%|▌ | 292/5000 [00:42<10:42, 7.33it/s]"
460
+ ]
461
+ },
462
+ {
463
+ "name": "stdout",
464
+ "output_type": "stream",
465
+ "text": [
466
+ "Step 290, Loss: 0.5049920678138733, L2 norm: 26.197214126586914, avg rank: 6.111111164093018\n"
467
+ ]
468
+ },
469
+ {
470
+ "name": "stderr",
471
+ "output_type": "stream",
472
+ "text": [
473
+ " 6%|▌ | 302/5000 [00:43<11:42, 6.68it/s]"
474
+ ]
475
+ },
476
+ {
477
+ "name": "stdout",
478
+ "output_type": "stream",
479
+ "text": [
480
+ "Step 300, Loss: 0.46782854199409485, L2 norm: 26.38979721069336, avg rank: 5.407407283782959\n"
481
+ ]
482
+ },
483
+ {
484
+ "name": "stderr",
485
+ "output_type": "stream",
486
+ "text": [
487
+ " 6%|▌ | 312/5000 [00:45<10:34, 7.39it/s]"
488
+ ]
489
+ },
490
+ {
491
+ "name": "stdout",
492
+ "output_type": "stream",
493
+ "text": [
494
+ "Step 310, Loss: 0.4365224838256836, L2 norm: 26.568010330200195, avg rank: 4.666666507720947\n"
495
+ ]
496
+ },
497
+ {
498
+ "name": "stderr",
499
+ "output_type": "stream",
500
+ "text": [
501
+ " 6%|▋ | 322/5000 [00:46<10:26, 7.46it/s]"
502
+ ]
503
+ },
504
+ {
505
+ "name": "stdout",
506
+ "output_type": "stream",
507
+ "text": [
508
+ "Step 320, Loss: 0.4172615706920624, L2 norm: 26.73436737060547, avg rank: 4.296296119689941\n"
509
+ ]
510
+ },
511
+ {
512
+ "name": "stderr",
513
+ "output_type": "stream",
514
+ "text": [
515
+ " 7%|▋ | 332/5000 [00:48<11:34, 6.73it/s]"
516
+ ]
517
+ },
518
+ {
519
+ "name": "stdout",
520
+ "output_type": "stream",
521
+ "text": [
522
+ "Step 330, Loss: 0.44890207052230835, L2 norm: 26.890453338623047, avg rank: 4.777777671813965\n"
523
+ ]
524
+ },
525
+ {
526
+ "name": "stderr",
527
+ "output_type": "stream",
528
+ "text": [
529
+ " 7%|▋ | 342/5000 [00:49<09:44, 7.97it/s]"
530
+ ]
531
+ },
532
+ {
533
+ "name": "stdout",
534
+ "output_type": "stream",
535
+ "text": [
536
+ "Step 340, Loss: 1.6345257759094238, L2 norm: 27.1015625, avg rank: 10.0\n"
537
+ ]
538
+ },
539
+ {
540
+ "name": "stderr",
541
+ "output_type": "stream",
542
+ "text": [
543
+ " 7%|▋ | 352/5000 [00:50<10:12, 7.59it/s]"
544
+ ]
545
+ },
546
+ {
547
+ "name": "stdout",
548
+ "output_type": "stream",
549
+ "text": [
550
+ "Step 350, Loss: 1.079402208328247, L2 norm: 27.411144256591797, avg rank: 6.592592716217041\n"
551
+ ]
552
+ },
553
+ {
554
+ "name": "stderr",
555
+ "output_type": "stream",
556
+ "text": [
557
+ " 7%|▋ | 362/5000 [00:52<12:10, 6.35it/s]"
558
+ ]
559
+ },
560
+ {
561
+ "name": "stdout",
562
+ "output_type": "stream",
563
+ "text": [
564
+ "Step 360, Loss: 0.8217418193817139, L2 norm: 27.712562561035156, avg rank: 4.703703880310059\n"
565
+ ]
566
+ },
567
+ {
568
+ "name": "stderr",
569
+ "output_type": "stream",
570
+ "text": [
571
+ " 7%|▋ | 372/5000 [00:53<11:04, 6.97it/s]"
572
+ ]
573
+ },
574
+ {
575
+ "name": "stdout",
576
+ "output_type": "stream",
577
+ "text": [
578
+ "Step 370, Loss: 0.6470327973365784, L2 norm: 28.024415969848633, avg rank: 1.8888888359069824\n"
579
+ ]
580
+ },
581
+ {
582
+ "name": "stderr",
583
+ "output_type": "stream",
584
+ "text": [
585
+ " 8%|▊ | 382/5000 [00:55<10:21, 7.43it/s]"
586
+ ]
587
+ },
588
+ {
589
+ "name": "stdout",
590
+ "output_type": "stream",
591
+ "text": [
592
+ "Step 380, Loss: 1.0997235774993896, L2 norm: 28.2260684967041, avg rank: 7.185184955596924\n"
593
+ ]
594
+ },
595
+ {
596
+ "name": "stderr",
597
+ "output_type": "stream",
598
+ "text": [
599
+ " 8%|▊ | 392/5000 [00:56<10:58, 7.00it/s]"
600
+ ]
601
+ },
602
+ {
603
+ "name": "stdout",
604
+ "output_type": "stream",
605
+ "text": [
606
+ "Step 390, Loss: 0.64601731300354, L2 norm: 28.461565017700195, avg rank: 2.9259259700775146\n"
607
+ ]
608
+ },
609
+ {
610
+ "name": "stderr",
611
+ "output_type": "stream",
612
+ "text": [
613
+ " 8%|▊ | 402/5000 [00:58<09:58, 7.68it/s]"
614
+ ]
615
+ },
616
+ {
617
+ "name": "stdout",
618
+ "output_type": "stream",
619
+ "text": [
620
+ "Step 400, Loss: 0.482523649930954, L2 norm: 28.71732521057129, avg rank: 1.5185185670852661\n"
621
+ ]
622
+ },
623
+ {
624
+ "name": "stderr",
625
+ "output_type": "stream",
626
+ "text": [
627
+ " 8%|▊ | 412/5000 [00:59<09:54, 7.71it/s]"
628
+ ]
629
+ },
630
+ {
631
+ "name": "stdout",
632
+ "output_type": "stream",
633
+ "text": [
634
+ "Step 410, Loss: 0.38174617290496826, L2 norm: 28.943510055541992, avg rank: 1.2592592239379883\n"
635
+ ]
636
+ },
637
+ {
638
+ "name": "stderr",
639
+ "output_type": "stream",
640
+ "text": [
641
+ " 8%|▊ | 422/5000 [01:01<12:41, 6.01it/s]"
642
+ ]
643
+ },
644
+ {
645
+ "name": "stdout",
646
+ "output_type": "stream",
647
+ "text": [
648
+ "Step 420, Loss: 0.29940587282180786, L2 norm: 29.132537841796875, avg rank: 1.1111111640930176\n"
649
+ ]
650
+ },
651
+ {
652
+ "name": "stderr",
653
+ "output_type": "stream",
654
+ "text": [
655
+ " 9%|▊ | 432/5000 [01:02<10:03, 7.57it/s]"
656
+ ]
657
+ },
658
+ {
659
+ "name": "stdout",
660
+ "output_type": "stream",
661
+ "text": [
662
+ "Step 430, Loss: 0.23488281667232513, L2 norm: 29.293710708618164, avg rank: 1.0740740299224854\n"
663
+ ]
664
+ },
665
+ {
666
+ "name": "stderr",
667
+ "output_type": "stream",
668
+ "text": [
669
+ " 9%|▉ | 442/5000 [01:03<09:35, 7.91it/s]"
670
+ ]
671
+ },
672
+ {
673
+ "name": "stdout",
674
+ "output_type": "stream",
675
+ "text": [
676
+ "Step 440, Loss: 0.19830304384231567, L2 norm: 29.436723709106445, avg rank: 1.0370370149612427\n"
677
+ ]
678
+ },
679
+ {
680
+ "name": "stderr",
681
+ "output_type": "stream",
682
+ "text": [
683
+ " 9%|▉ | 452/5000 [01:05<10:41, 7.09it/s]"
684
+ ]
685
+ },
686
+ {
687
+ "name": "stdout",
688
+ "output_type": "stream",
689
+ "text": [
690
+ "Step 450, Loss: 0.17161428928375244, L2 norm: 29.54309844970703, avg rank: 1.0370370149612427\n"
691
+ ]
692
+ },
693
+ {
694
+ "name": "stderr",
695
+ "output_type": "stream",
696
+ "text": [
697
+ " 9%|▉ | 454/5000 [01:05<10:54, 6.94it/s]"
698
+ ]
699
+ },
700
+ {
701
+ "name": "stdout",
702
+ "output_type": "stream",
703
+ "text": [
704
+ "Perfect ranks achieved at step 454, stopping optimization.\n"
705
+ ]
706
+ },
707
+ {
708
+ "name": "stderr",
709
+ "output_type": "stream",
710
+ "text": [
711
+ "\n"
712
+ ]
713
+ }
714
+ ],
715
  "source": [
716
  "import transformers\n",
717
  "import torch\n",
718
  "import tqdm\n",
719
+ "torch.manual_seed(42) # Set a fixed seed for reproducibility\n",
720
  "\n",
721
  "print(f\"Optimizing: {text_widget.value}\")\n",
722
  "\n",
 
768
  "\n",
769
  "# Define the number of optimization steps\n",
770
  "n_steps = 5000\n",
771
+ "patience = 100\n",
772
  "patience_counter = 0\n",
773
  "epsilon = 1e-1\n",
774
  "\n",
 
801
  "\n",
802
  " # Compute the ranks of the input IDs, i.e. how many tokens would have been more likely than the correct one (the label, the input IDs)\n",
803
  " \n",
804
+ " # Calculate the ranks by summing the probabilities of tokens with higher logits than the correct token\n",
805
+ " ranks = torch.sum(probs > probs.gather(2, tokens[\"input_ids\"].unsqueeze(-1)), dim=-1) + 1\n",
806
  "\n",
807
  " # Compute the loss\n",
808
  " loss = loss_fn(\n",
 
823
  " for optimized, original in zip(ss_prompt, original_ss_prompt)\n",
824
  " )\n",
825
  " print(\n",
826
+ " f\"Step {step}, Loss: {loss.item()}, L2 norm: {l2_norm.item()}, avg rank: {ranks.float().mean().item()}\"\n",
827
  " )\n",
828
  "\n",
829
  " # Early stopping with patience\n",
 
836
  "\n",
837
  " if patience_counter >= patience:\n",
838
  " print(f\"Early stopping at step {step} with best loss {best_loss}\")\n",
839
+ " break\n",
840
+ "\n",
841
+ " # If the ranks are perfect (all 1), stop\n",
842
+ " if torch.all(ranks == 1):\n",
843
+ " print(f\"Perfect ranks achieved at step {step}, stopping optimization.\")\n",
844
  " break"
845
  ]
846
  },
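(For readers skimming the diff: the added break conditions follow the usual early-stopping-with-patience pattern, plus a new stop once every token already has rank 1. A condensed, self-contained sketch with a dummy training step; the exact improvement test against epsilon is an assumption, while the variable names mirror the notebook.)

import torch

def train_step():
    # Stand-in for the notebook's optimization step; returns a fake loss and fake per-token ranks
    return torch.rand(1).item() * 5, torch.randint(1, 5, (27,))

n_steps, patience, epsilon = 5000, 100, 1e-1
best_loss = float("inf")
patience_counter = 0

for step in range(n_steps):
    loss, ranks = train_step()

    # Assumed criterion: reset the counter whenever the loss improves by more than epsilon
    if loss < best_loss - epsilon:
        best_loss = loss
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping at step {step} with best loss {best_loss}")
        break

    # If every token already has rank 1, the soft prompt reproduces the text exactly; stop
    if torch.all(ranks == 1):
        print(f"Perfect ranks achieved at step {step}, stopping optimization.")
        break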
847
  {
848
  "cell_type": "code",
849
+ "execution_count": 7,
850
  "id": "cc9a6a2f",
851
  "metadata": {},
852
  "outputs": [],
 
857
  },
858
  {
859
  "cell_type": "code",
860
+ "execution_count": 8,
861
  "id": "3186747d",
862
  "metadata": {},
863
+ "outputs": [
864
+ {
865
+ "name": "stderr",
866
+ "output_type": "stream",
867
+ "text": [
868
+ "100%|██████████| 150/150 [00:10<00:00, 14.83it/s]\n"
869
+ ]
870
+ },
871
+ {
872
+ "name": "stdout",
873
+ "output_type": "stream",
874
+ "text": [
875
+ "Reference: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps!\n",
876
+ "Generated: Hey this is a test text of me, Aldan, writing some random text in this text box. Will this work? Perhaps! I'll have to test this out. Maybe. I'll have to test this out some more. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe. Maybe\n",
877
+ "'Hey':\tRank 0.00, probability: 71.02%, Reference: 'Hey'\n",
878
+ "' this':\tRank 0.00, probability: 87.04%, Reference: ' this'\n",
879
+ "' is':\tRank 0.00, probability: 99.65%, Reference: ' is'\n",
880
+ "' a':\tRank 0.00, probability: 96.23%, Reference: ' a'\n",
881
+ "' test':\tRank 0.00, probability: 90.31%, Reference: ' test'\n",
882
+ "' text':\tRank 0.00, probability: 95.40%, Reference: ' text'\n",
883
+ "' of':\tRank 0.00, probability: 62.63%, Reference: ' of'\n",
884
+ "' me':\tRank 0.00, probability: 93.07%, Reference: ' me'\n",
885
+ "',':\tRank 0.00, probability: 95.56%, Reference: ','\n",
886
+ "' Ald':\tRank 0.00, probability: 71.15%, Reference: ' Ald'\n",
887
+ "'an':\tRank 0.00, probability: 84.45%, Reference: 'an'\n",
888
+ "',':\tRank 0.00, probability: 97.80%, Reference: ','\n",
889
+ "' writing':\tRank 0.00, probability: 95.00%, Reference: ' writing'\n",
890
+ "' some':\tRank 0.00, probability: 88.69%, Reference: ' some'\n",
891
+ "' random':\tRank 0.00, probability: 97.07%, Reference: ' random'\n",
892
+ "' text':\tRank 0.00, probability: 89.14%, Reference: ' text'\n",
893
+ "' in':\tRank 0.00, probability: 90.19%, Reference: ' in'\n",
894
+ "' this':\tRank 0.00, probability: 98.12%, Reference: ' this'\n",
895
+ "' text':\tRank 0.00, probability: 96.47%, Reference: ' text'\n",
896
+ "' box':\tRank 0.00, probability: 95.55%, Reference: ' box'\n",
897
+ "'.':\tRank 0.00, probability: 97.38%, Reference: '.'\n",
898
+ "' Will':\tRank 0.00, probability: 90.65%, Reference: ' Will'\n",
899
+ "' this':\tRank 0.00, probability: 93.50%, Reference: ' this'\n",
900
+ "' work':\tRank 0.00, probability: 49.78%, Reference: ' work'\n",
901
+ "'?':\tRank 0.00, probability: 96.33%, Reference: '?'\n",
902
+ "' Perhaps':\tRank 0.00, probability: 76.97%, Reference: ' Perhaps'\n",
903
+ "'!':\tRank 0.00, probability: 42.45%, Reference: '!'\n",
904
+ "' I':\tRank 0.00, probability: 11.67%, Reference: 'N/A'\n",
905
+ "''ll':\tRank 0.00, probability: 10.19%, Reference: 'N/A'\n",
906
+ "' have':\tRank 0.00, probability: 53.08%, Reference: 'N/A'\n",
907
+ "' to':\tRank 0.00, probability: 77.45%, Reference: 'N/A'\n",
908
+ "' test':\tRank 0.00, probability: 8.06%, Reference: 'N/A'\n",
909
+ "' this':\tRank 0.00, probability: 40.56%, Reference: 'N/A'\n",
910
+ "' out':\tRank 0.00, probability: 51.98%, Reference: 'N/A'\n",
911
+ "'.':\tRank 0.00, probability: 26.31%, Reference: 'N/A'\n",
912
+ "' Maybe':\tRank 0.00, probability: 27.11%, Reference: 'N/A'\n",
913
+ "'.':\tRank 0.00, probability: 22.16%, Reference: 'N/A'\n",
914
+ "' I':\tRank 0.00, probability: 20.00%, Reference: 'N/A'\n",
915
+ "''ll':\tRank 0.00, probability: 21.35%, Reference: 'N/A'\n",
916
+ "' have':\tRank 0.00, probability: 59.45%, Reference: 'N/A'\n",
917
+ "' to':\tRank 0.00, probability: 74.45%, Reference: 'N/A'\n",
918
+ "' test':\tRank 0.00, probability: 17.91%, Reference: 'N/A'\n",
919
+ "' this':\tRank 0.00, probability: 70.72%, Reference: 'N/A'\n",
920
+ "' out':\tRank 0.00, probability: 87.33%, Reference: 'N/A'\n",
921
+ "' some':\tRank 0.00, probability: 28.23%, Reference: 'N/A'\n",
922
+ "' more':\tRank 0.00, probability: 26.20%, Reference: 'N/A'\n",
923
+ "'.':\tRank 0.00, probability: 22.73%, Reference: 'N/A'\n",
924
+ "' Maybe':\tRank 0.00, probability: 36.98%, Reference: 'N/A'\n",
925
+ "'.':\tRank 0.00, probability: 54.91%, Reference: 'N/A'\n",
926
+ "' Maybe':\tRank 0.00, probability: 64.39%, Reference: 'N/A'\n",
927
+ "'.':\tRank 0.00, probability: 65.85%, Reference: 'N/A'\n",
928
+ "' Maybe':\tRank 0.00, probability: 65.80%, Reference: 'N/A'\n",
929
+ "'.':\tRank 0.00, probability: 51.26%, Reference: 'N/A'\n",
930
+ "' Maybe':\tRank 0.00, probability: 60.87%, Reference: 'N/A'\n",
931
+ "'.':\tRank 0.00, probability: 35.56%, Reference: 'N/A'\n",
932
+ "' Maybe':\tRank 0.00, probability: 48.26%, Reference: 'N/A'\n",
933
+ "'.':\tRank 0.00, probability: 28.23%, Reference: 'N/A'\n",
934
+ "' Maybe':\tRank 0.00, probability: 43.46%, Reference: 'N/A'\n",
935
+ "'.':\tRank 0.00, probability: 23.53%, Reference: 'N/A'\n",
936
+ "' Maybe':\tRank 0.00, probability: 42.47%, Reference: 'N/A'\n",
937
+ "'.':\tRank 0.00, probability: 21.56%, Reference: 'N/A'\n",
938
+ "' Maybe':\tRank 0.00, probability: 44.99%, Reference: 'N/A'\n",
939
+ "'.':\tRank 0.00, probability: 22.61%, Reference: 'N/A'\n",
940
+ "' Maybe':\tRank 0.00, probability: 48.24%, Reference: 'N/A'\n",
941
+ "'.':\tRank 0.00, probability: 19.47%, Reference: 'N/A'\n",
942
+ "' Maybe':\tRank 0.00, probability: 51.23%, Reference: 'N/A'\n",
943
+ "'.':\tRank 0.00, probability: 19.13%, Reference: 'N/A'\n",
944
+ "' Maybe':\tRank 0.00, probability: 54.27%, Reference: 'N/A'\n",
945
+ "'.':\tRank 0.00, probability: 20.28%, Reference: 'N/A'\n",
946
+ "' Maybe':\tRank 0.00, probability: 56.09%, Reference: 'N/A'\n",
947
+ "'.':\tRank 0.00, probability: 23.08%, Reference: 'N/A'\n",
948
+ "' Maybe':\tRank 0.00, probability: 60.23%, Reference: 'N/A'\n",
949
+ "'.':\tRank 0.00, probability: 25.86%, Reference: 'N/A'\n",
950
+ "' Maybe':\tRank 0.00, probability: 63.49%, Reference: 'N/A'\n",
951
+ "'.':\tRank 0.00, probability: 28.66%, Reference: 'N/A'\n",
952
+ "' Maybe':\tRank 0.00, probability: 65.52%, Reference: 'N/A'\n",
953
+ "'.':\tRank 0.00, probability: 29.81%, Reference: 'N/A'\n",
954
+ "' Maybe':\tRank 0.00, probability: 67.58%, Reference: 'N/A'\n",
955
+ "'.':\tRank 0.00, probability: 32.28%, Reference: 'N/A'\n",
956
+ "' Maybe':\tRank 0.00, probability: 69.40%, Reference: 'N/A'\n",
957
+ "'.':\tRank 0.00, probability: 33.50%, Reference: 'N/A'\n",
958
+ "' Maybe':\tRank 0.00, probability: 70.15%, Reference: 'N/A'\n",
959
+ "'.':\tRank 0.00, probability: 34.01%, Reference: 'N/A'\n",
960
+ "' Maybe':\tRank 0.00, probability: 71.20%, Reference: 'N/A'\n",
961
+ "'.':\tRank 0.00, probability: 36.83%, Reference: 'N/A'\n",
962
+ "' Maybe':\tRank 0.00, probability: 71.29%, Reference: 'N/A'\n",
963
+ "'.':\tRank 0.00, probability: 38.54%, Reference: 'N/A'\n",
964
+ "' Maybe':\tRank 0.00, probability: 72.09%, Reference: 'N/A'\n",
965
+ "'.':\tRank 0.00, probability: 40.10%, Reference: 'N/A'\n",
966
+ "' Maybe':\tRank 0.00, probability: 71.66%, Reference: 'N/A'\n",
967
+ "'.':\tRank 0.00, probability: 41.78%, Reference: 'N/A'\n",
968
+ "' Maybe':\tRank 0.00, probability: 72.04%, Reference: 'N/A'\n",
969
+ "'.':\tRank 0.00, probability: 42.55%, Reference: 'N/A'\n",
970
+ "' Maybe':\tRank 0.00, probability: 71.18%, Reference: 'N/A'\n",
971
+ "'.':\tRank 0.00, probability: 44.33%, Reference: 'N/A'\n",
972
+ "' Maybe':\tRank 0.00, probability: 70.62%, Reference: 'N/A'\n",
973
+ "'.':\tRank 0.00, probability: 44.84%, Reference: 'N/A'\n",
974
+ "' Maybe':\tRank 0.00, probability: 70.70%, Reference: 'N/A'\n",
975
+ "'.':\tRank 0.00, probability: 44.89%, Reference: 'N/A'\n",
976
+ "' Maybe':\tRank 0.00, probability: 71.80%, Reference: 'N/A'\n",
977
+ "'.':\tRank 0.00, probability: 46.26%, Reference: 'N/A'\n",
978
+ "' Maybe':\tRank 0.00, probability: 71.88%, Reference: 'N/A'\n",
979
+ "'.':\tRank 0.00, probability: 46.11%, Reference: 'N/A'\n",
980
+ "' Maybe':\tRank 0.00, probability: 72.24%, Reference: 'N/A'\n",
981
+ "'.':\tRank 0.00, probability: 44.98%, Reference: 'N/A'\n",
982
+ "' Maybe':\tRank 0.00, probability: 73.18%, Reference: 'N/A'\n",
983
+ "'.':\tRank 0.00, probability: 45.08%, Reference: 'N/A'\n",
984
+ "' Maybe':\tRank 0.00, probability: 75.74%, Reference: 'N/A'\n",
985
+ "'.':\tRank 0.00, probability: 45.33%, Reference: 'N/A'\n",
986
+ "' Maybe':\tRank 0.00, probability: 76.68%, Reference: 'N/A'\n",
987
+ "'.':\tRank 0.00, probability: 46.90%, Reference: 'N/A'\n",
988
+ "' Maybe':\tRank 0.00, probability: 78.19%, Reference: 'N/A'\n",
989
+ "'.':\tRank 0.00, probability: 46.31%, Reference: 'N/A'\n",
990
+ "' Maybe':\tRank 0.00, probability: 78.19%, Reference: 'N/A'\n",
991
+ "'.':\tRank 0.00, probability: 48.19%, Reference: 'N/A'\n",
992
+ "' Maybe':\tRank 0.00, probability: 79.38%, Reference: 'N/A'\n",
993
+ "'.':\tRank 0.00, probability: 48.56%, Reference: 'N/A'\n",
994
+ "' Maybe':\tRank 0.00, probability: 80.29%, Reference: 'N/A'\n",
995
+ "'.':\tRank 0.00, probability: 49.49%, Reference: 'N/A'\n",
996
+ "' Maybe':\tRank 0.00, probability: 79.51%, Reference: 'N/A'\n",
997
+ "'.':\tRank 0.00, probability: 51.13%, Reference: 'N/A'\n",
998
+ "' Maybe':\tRank 0.00, probability: 81.70%, Reference: 'N/A'\n",
999
+ "'.':\tRank 0.00, probability: 52.29%, Reference: 'N/A'\n",
1000
+ "' Maybe':\tRank 0.00, probability: 80.87%, Reference: 'N/A'\n",
1001
+ "'.':\tRank 0.00, probability: 52.84%, Reference: 'N/A'\n",
1002
+ "' Maybe':\tRank 0.00, probability: 81.61%, Reference: 'N/A'\n",
1003
+ "'.':\tRank 0.00, probability: 54.15%, Reference: 'N/A'\n",
1004
+ "' Maybe':\tRank 0.00, probability: 81.67%, Reference: 'N/A'\n",
1005
+ "'.':\tRank 0.00, probability: 55.17%, Reference: 'N/A'\n",
1006
+ "' Maybe':\tRank 0.00, probability: 82.23%, Reference: 'N/A'\n",
1007
+ "'.':\tRank 0.00, probability: 55.47%, Reference: 'N/A'\n",
1008
+ "' Maybe':\tRank 0.00, probability: 82.99%, Reference: 'N/A'\n",
1009
+ "'.':\tRank 0.00, probability: 57.04%, Reference: 'N/A'\n",
1010
+ "' Maybe':\tRank 0.00, probability: 83.32%, Reference: 'N/A'\n",
1011
+ "'.':\tRank 0.00, probability: 58.06%, Reference: 'N/A'\n",
1012
+ "' Maybe':\tRank 0.00, probability: 84.47%, Reference: 'N/A'\n",
1013
+ "'.':\tRank 0.00, probability: 59.85%, Reference: 'N/A'\n",
1014
+ "' Maybe':\tRank 0.00, probability: 83.82%, Reference: 'N/A'\n",
1015
+ "'.':\tRank 0.00, probability: 58.62%, Reference: 'N/A'\n",
1016
+ "' Maybe':\tRank 0.00, probability: 85.14%, Reference: 'N/A'\n",
1017
+ "'.':\tRank 0.00, probability: 60.26%, Reference: 'N/A'\n",
1018
+ "' Maybe':\tRank 0.00, probability: 85.67%, Reference: 'N/A'\n",
1019
+ "'.':\tRank 0.00, probability: 62.56%, Reference: 'N/A'\n",
1020
+ "' Maybe':\tRank 0.00, probability: 86.05%, Reference: 'N/A'\n",
1021
+ "'.':\tRank 0.00, probability: 64.28%, Reference: 'N/A'\n",
1022
+ "' Maybe':\tRank 0.00, probability: 86.93%, Reference: 'N/A'\n",
1023
+ "'.':\tRank 0.00, probability: 67.84%, Reference: 'N/A'\n",
1024
+ "' Maybe':\tRank 0.00, probability: 87.52%, Reference: 'N/A'\n",
1025
+ "'.':\tRank 0.00, probability: 73.01%, Reference: 'N/A'\n",
1026
+ "' Maybe':\tRank 0.00, probability: 88.17%, Reference: 'N/A'\n"
1027
+ ]
1028
+ }
1029
+ ],
1030
  "source": [
1031
  "import os\n",
1032
  "import transformers\n",
1033
  "import torch\n",
1034
  "import tqdm\n",
1035
  "\n",
1036
+ "if \"tokenizer\" not in locals():\n",
1037
+ " tokenizer: transformers.PreTrainedTokenizer = (\n",
1038
+ " transformers.AutoTokenizer.from_pretrained(\"openai-community/gpt2\")\n",
1039
+ " )\n",
1040
+ "if \"model\" not in locals():\n",
1041
+ " model: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(\n",
1042
+ " \"openai-community/gpt2\"\n",
1043
+ " )\n",
1044
  "\n",
1045
  "# Load the best soft prompt from file\n",
1046
  "best_ss_prompt = torch.load(\"best_ss_prompt.pt\")\n",