codeShare commited on
Commit
38a8e51
·
verified ·
1 Parent(s): 19fa3ec

Upload sd_token_similarity_calculator.ipynb

Browse files
Files changed (1) hide show
  1. sd_token_similarity_calculator.ipynb +426 -24
sd_token_similarity_calculator.ipynb CHANGED
@@ -82,10 +82,29 @@
82
  "mix_method = \"None\""
83
  ],
84
  "metadata": {
85
- "id": "Ch9puvwKH1s3"
 
 
 
 
 
86
  },
87
- "execution_count": null,
88
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  },
90
  {
91
  "cell_type": "code",
@@ -106,10 +125,22 @@
106
  "#You can leave the 'prompt' field empty to get a random value tensor. Since the tensor is random value, it will not correspond to any tensor in the vocab.json list , and this it will have no ID."
107
  ],
108
  "metadata": {
109
- "id": "RPdkYzT2_X85"
 
 
 
 
110
  },
111
- "execution_count": null,
112
- "outputs": []
 
 
 
 
 
 
 
 
113
  },
114
  {
115
  "cell_type": "code",
@@ -136,7 +167,7 @@
136
  "metadata": {
137
  "id": "YqdiF8DIz9Wu"
138
  },
139
- "execution_count": null,
140
  "outputs": []
141
  },
142
  {
@@ -189,10 +220,24 @@
189
  "#OPTIONAL : Add/subtract + normalize above result with another token. Leave field empty to get a random value tensor"
190
  ],
191
  "metadata": {
192
- "id": "oXbNSRSKPgRr"
 
 
 
 
 
193
  },
194
- "execution_count": null,
195
- "outputs": []
 
 
 
 
 
 
 
 
 
196
  },
197
  {
198
  "cell_type": "code",
@@ -230,10 +275,30 @@
230
  "#Produce a list id IDs that are most similiar to the prompt ID at positiion 1 based on above result"
231
  ],
232
  "metadata": {
233
- "id": "juxsvco9B0iV"
 
 
 
 
 
234
  },
235
- "execution_count": null,
236
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  },
238
  {
239
  "cell_type": "code",
@@ -260,10 +325,321 @@
260
  ],
261
  "metadata": {
262
  "id": "YIEmLAzbHeuo",
263
- "collapsed": true
 
 
 
 
264
  },
265
- "execution_count": null,
266
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  },
268
  {
269
  "cell_type": "code",
@@ -280,10 +656,23 @@
280
  "#Valid ID ranges for id_for_token_A / id_for_token_B are between 0 and 49407"
281
  ],
282
  "metadata": {
283
- "id": "MwmOdC9cNZty"
 
 
 
 
 
284
  },
285
- "execution_count": null,
286
- "outputs": []
 
 
 
 
 
 
 
 
287
  },
288
  {
289
  "cell_type": "code",
@@ -292,7 +681,7 @@
292
  "\n",
293
  "prompt_A = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
294
  "prompt_B = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
295
- "use_token_padding = False # @param {type:\"boolean\"}\n",
296
  "\n",
297
  "from transformers import CLIPProcessor, CLIPModel\n",
298
  "\n",
@@ -307,7 +696,7 @@
307
  "ids_B = processor.tokenizer(text=prompt_B, padding=use_token_padding, return_tensors=\"pt\")\n",
308
  "text_encoding_B = model.get_text_features(**ids_B)\n",
309
  "\n",
310
- "similarity_str = 'The similarity between the text_encoding for A and B is ' + token_similarity(text_encoding_A[0] , text_encoding_B[0])\n",
311
  "\n",
312
  "\n",
313
  "print(similarity_str)\n",
@@ -319,10 +708,23 @@
319
  "\n"
320
  ],
321
  "metadata": {
322
- "id": "QQOjh5BvnG8M"
 
 
 
 
 
323
  },
324
- "execution_count": null,
325
- "outputs": []
 
 
 
 
 
 
 
 
326
  },
327
  {
328
  "cell_type": "markdown",
 
82
  "mix_method = \"None\""
83
  ],
84
  "metadata": {
85
+ "id": "Ch9puvwKH1s3",
86
+ "collapsed": true,
87
+ "colab": {
88
+ "base_uri": "https://localhost:8080/"
89
+ },
90
+ "outputId": "982a9210-a3fd-4d90-bef7-5aa6f5864797"
91
  },
92
+ "execution_count": 2,
93
+ "outputs": [
94
+ {
95
+ "output_type": "stream",
96
+ "name": "stdout",
97
+ "text": [
98
+ "Cloning into 'sd_tokens'...\n",
99
+ "remote: Enumerating objects: 10, done.\u001b[K\n",
100
+ "remote: Counting objects: 100% (7/7), done.\u001b[K\n",
101
+ "remote: Compressing objects: 100% (7/7), done.\u001b[K\n",
102
+ "remote: Total 10 (delta 1), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n",
103
+ "Unpacking objects: 100% (10/10), 306.93 KiB | 4.72 MiB/s, done.\n",
104
+ "/content/sd_tokens\n"
105
+ ]
106
+ }
107
+ ]
108
  },
109
  {
110
  "cell_type": "code",
 
125
  "#You can leave the 'prompt' field empty to get a random value tensor. Since the tensor is random value, it will not correspond to any tensor in the vocab.json list , and this it will have no ID."
126
  ],
127
  "metadata": {
128
+ "id": "RPdkYzT2_X85",
129
+ "colab": {
130
+ "base_uri": "https://localhost:8080/"
131
+ },
132
+ "outputId": "86f2f01e-6a04-4292-cee7-70fd8398e07f"
133
  },
134
+ "execution_count": 3,
135
+ "outputs": [
136
+ {
137
+ "output_type": "stream",
138
+ "name": "stdout",
139
+ "text": [
140
+ "[49406, 8922, 49407]\n"
141
+ ]
142
+ }
143
+ ]
144
  },
145
  {
146
  "cell_type": "code",
 
167
  "metadata": {
168
  "id": "YqdiF8DIz9Wu"
169
  },
170
+ "execution_count": 4,
171
  "outputs": []
172
  },
173
  {
 
220
  "#OPTIONAL : Add/subtract + normalize above result with another token. Leave field empty to get a random value tensor"
221
  ],
222
  "metadata": {
223
+ "id": "oXbNSRSKPgRr",
224
+ "collapsed": true,
225
+ "colab": {
226
+ "base_uri": "https://localhost:8080/"
227
+ },
228
+ "outputId": "76f8ec94-d29c-46d9-893b-49875f3a1949"
229
  },
230
+ "execution_count": 5,
231
+ "outputs": [
232
+ {
233
+ "output_type": "stream",
234
+ "name": "stdout",
235
+ "text": [
236
+ "Tokenized prompt 'mix_with' tensor C is a random valued tensor with no ID\n",
237
+ "No operation\n"
238
+ ]
239
+ }
240
+ ]
241
  },
242
  {
243
  "cell_type": "code",
 
275
  "#Produce a list id IDs that are most similiar to the prompt ID at positiion 1 based on above result"
276
  ],
277
  "metadata": {
278
+ "id": "juxsvco9B0iV",
279
+ "collapsed": true,
280
+ "colab": {
281
+ "base_uri": "https://localhost:8080/"
282
+ },
283
+ "outputId": "dc893bbf-e9cb-425c-95b8-ffafd3ab2fbc"
284
  },
285
+ "execution_count": 6,
286
+ "outputs": [
287
+ {
288
+ "output_type": "stream",
289
+ "name": "stdout",
290
+ "text": [
291
+ "Calculated all cosine-similarities between the token banana</w> with Id_A = 8922 with the the rest of the 49407 tokens as a 1x49407 tensor\n"
292
+ ]
293
+ }
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "markdown",
298
+ "source": [],
299
+ "metadata": {
300
+ "id": "cYYu5C5C6MHH"
301
+ }
302
  },
303
  {
304
  "cell_type": "code",
 
325
  ],
326
  "metadata": {
327
  "id": "YIEmLAzbHeuo",
328
+ "collapsed": true,
329
+ "colab": {
330
+ "base_uri": "https://localhost:8080/"
331
+ },
332
+ "outputId": "4a2fa70f-16ff-4bba-fb01-d39ad697d4ff"
333
  },
334
+ "execution_count": 7,
335
+ "outputs": [
336
+ {
337
+ "output_type": "stream",
338
+ "name": "stdout",
339
+ "text": [
340
+ "banana</w>\n",
341
+ "similiarity = 100.0 %\n",
342
+ "--------\n",
343
+ "bananas</w>\n",
344
+ "similiarity = 38.93 %\n",
345
+ "--------\n",
346
+ "banan\n",
347
+ "similiarity = 30.8 %\n",
348
+ "--------\n",
349
+ "ðŁįĮ</w>\n",
350
+ "similiarity = 27.12 %\n",
351
+ "--------\n",
352
+ "pineapple</w>\n",
353
+ "similiarity = 19.7 %\n",
354
+ "--------\n",
355
+ "chicken</w>\n",
356
+ "similiarity = 19.24 %\n",
357
+ "--------\n",
358
+ "potassium</w>\n",
359
+ "similiarity = 19.21 %\n",
360
+ "--------\n",
361
+ "sausage</w>\n",
362
+ "similiarity = 19.07 %\n",
363
+ "--------\n",
364
+ "lemon</w>\n",
365
+ "similiarity = 18.82 %\n",
366
+ "--------\n",
367
+ "orange</w>\n",
368
+ "similiarity = 18.42 %\n",
369
+ "--------\n",
370
+ "peanut</w>\n",
371
+ "similiarity = 17.84 %\n",
372
+ "--------\n",
373
+ "parachute</w>\n",
374
+ "similiarity = 17.19 %\n",
375
+ "--------\n",
376
+ "duck\n",
377
+ "similiarity = 16.8 %\n",
378
+ "--------\n",
379
+ "yellow</w>\n",
380
+ "similiarity = 16.21 %\n",
381
+ "--------\n",
382
+ "grape</w>\n",
383
+ "similiarity = 16.19 %\n",
384
+ "--------\n",
385
+ "kangaroo</w>\n",
386
+ "similiarity = 16.13 %\n",
387
+ "--------\n",
388
+ "apple</w>\n",
389
+ "similiarity = 16.13 %\n",
390
+ "--------\n",
391
+ "tangerine</w>\n",
392
+ "similiarity = 16.08 %\n",
393
+ "--------\n",
394
+ "giraffe</w>\n",
395
+ "similiarity = 16.04 %\n",
396
+ "--------\n",
397
+ "mango</w>\n",
398
+ "similiarity = 16.03 %\n",
399
+ "--------\n",
400
+ "rubber</w>\n",
401
+ "similiarity = 15.95 %\n",
402
+ "--------\n",
403
+ "bamboo</w>\n",
404
+ "similiarity = 15.88 %\n",
405
+ "--------\n",
406
+ "umbrella</w>\n",
407
+ "similiarity = 15.82 %\n",
408
+ "--------\n",
409
+ "nutella</w>\n",
410
+ "similiarity = 15.69 %\n",
411
+ "--------\n",
412
+ "ferrari</w>\n",
413
+ "similiarity = 15.69 %\n",
414
+ "--------\n",
415
+ "oranges</w>\n",
416
+ "similiarity = 15.65 %\n",
417
+ "--------\n",
418
+ "peanuts</w>\n",
419
+ "similiarity = 15.62 %\n",
420
+ "--------\n",
421
+ "ali</w>\n",
422
+ "similiarity = 15.49 %\n",
423
+ "--------\n",
424
+ "cucumber</w>\n",
425
+ "similiarity = 15.32 %\n",
426
+ "--------\n",
427
+ "potato</w>\n",
428
+ "similiarity = 15.22 %\n",
429
+ "--------\n",
430
+ "monkey</w>\n",
431
+ "similiarity = 15.2 %\n",
432
+ "--------\n",
433
+ "croissant</w>\n",
434
+ "similiarity = 15.18 %\n",
435
+ "--------\n",
436
+ "papaya</w>\n",
437
+ "similiarity = 15.17 %\n",
438
+ "--------\n",
439
+ "christmas</w>\n",
440
+ "similiarity = 15.12 %\n",
441
+ "--------\n",
442
+ "sandwich</w>\n",
443
+ "similiarity = 15.0 %\n",
444
+ "--------\n",
445
+ "rainbow</w>\n",
446
+ "similiarity = 14.98 %\n",
447
+ "--------\n",
448
+ "tomato</w>\n",
449
+ "similiarity = 14.96 %\n",
450
+ "--------\n",
451
+ "martini</w>\n",
452
+ "similiarity = 14.93 %\n",
453
+ "--------\n",
454
+ "cabaret</w>\n",
455
+ "similiarity = 14.83 %\n",
456
+ "--------\n",
457
+ "ginger</w>\n",
458
+ "similiarity = 14.82 %\n",
459
+ "--------\n",
460
+ "animal</w>\n",
461
+ "similiarity = 14.76 %\n",
462
+ "--------\n",
463
+ "vanilla</w>\n",
464
+ "similiarity = 14.73 %\n",
465
+ "--------\n",
466
+ "mustache</w>\n",
467
+ "similiarity = 14.64 %\n",
468
+ "--------\n",
469
+ "lime</w>\n",
470
+ "similiarity = 14.62 %\n",
471
+ "--------\n",
472
+ "sickle</w>\n",
473
+ "similiarity = 14.6 %\n",
474
+ "--------\n",
475
+ "vista</w>\n",
476
+ "similiarity = 14.53 %\n",
477
+ "--------\n",
478
+ "coconut</w>\n",
479
+ "similiarity = 14.52 %\n",
480
+ "--------\n",
481
+ "kara</w>\n",
482
+ "similiarity = 14.46 %\n",
483
+ "--------\n",
484
+ "alligator</w>\n",
485
+ "similiarity = 14.39 %\n",
486
+ "--------\n",
487
+ "blueberry</w>\n",
488
+ "similiarity = 14.34 %\n",
489
+ "--------\n",
490
+ "squirrel</w>\n",
491
+ "similiarity = 14.29 %\n",
492
+ "--------\n",
493
+ "atore</w>\n",
494
+ "similiarity = 14.19 %\n",
495
+ "--------\n",
496
+ "watermelon</w>\n",
497
+ "similiarity = 14.13 %\n",
498
+ "--------\n",
499
+ "nana</w>\n",
500
+ "similiarity = 14.09 %\n",
501
+ "--------\n",
502
+ "latex</w>\n",
503
+ "similiarity = 14.08 %\n",
504
+ "--------\n",
505
+ "agricultural</w>\n",
506
+ "similiarity = 14.02 %\n",
507
+ "--------\n",
508
+ "zucchini</w>\n",
509
+ "similiarity = 14.0 %\n",
510
+ "--------\n",
511
+ "saxophone</w>\n",
512
+ "similiarity = 13.93 %\n",
513
+ "--------\n",
514
+ "mozzarella</w>\n",
515
+ "similiarity = 13.91 %\n",
516
+ "--------\n",
517
+ "eggplant</w>\n",
518
+ "similiarity = 13.9 %\n",
519
+ "--------\n",
520
+ "pickle</w>\n",
521
+ "similiarity = 13.89 %\n",
522
+ "--------\n",
523
+ "tortilla</w>\n",
524
+ "similiarity = 13.88 %\n",
525
+ "--------\n",
526
+ "maniac</w>\n",
527
+ "similiarity = 13.84 %\n",
528
+ "--------\n",
529
+ "milk</w>\n",
530
+ "similiarity = 13.83 %\n",
531
+ "--------\n",
532
+ "cellphone</w>\n",
533
+ "similiarity = 13.78 %\n",
534
+ "--------\n",
535
+ "duck</w>\n",
536
+ "similiarity = 13.73 %\n",
537
+ "--------\n",
538
+ "umbrel\n",
539
+ "similiarity = 13.71 %\n",
540
+ "--------\n",
541
+ "fanny</w>\n",
542
+ "similiarity = 13.69 %\n",
543
+ "--------\n",
544
+ "twister</w>\n",
545
+ "similiarity = 13.67 %\n",
546
+ "--------\n",
547
+ "moustache</w>\n",
548
+ "similiarity = 13.66 %\n",
549
+ "--------\n",
550
+ "manafort</w>\n",
551
+ "similiarity = 13.66 %\n",
552
+ "--------\n",
553
+ "grapefruit</w>\n",
554
+ "similiarity = 13.6 %\n",
555
+ "--------\n",
556
+ "broom</w>\n",
557
+ "similiarity = 13.59 %\n",
558
+ "--------\n",
559
+ "scorpion</w>\n",
560
+ "similiarity = 13.59 %\n",
561
+ "--------\n",
562
+ "fruit\n",
563
+ "similiarity = 13.57 %\n",
564
+ "--------\n",
565
+ "agan\n",
566
+ "similiarity = 13.53 %\n",
567
+ "--------\n",
568
+ "sunflower</w>\n",
569
+ "similiarity = 13.49 %\n",
570
+ "--------\n",
571
+ "banc\n",
572
+ "similiarity = 13.46 %\n",
573
+ "--------\n",
574
+ "literature</w>\n",
575
+ "similiarity = 13.45 %\n",
576
+ "--------\n",
577
+ "pelican</w>\n",
578
+ "similiarity = 13.43 %\n",
579
+ "--------\n",
580
+ "breakfast</w>\n",
581
+ "similiarity = 13.42 %\n",
582
+ "--------\n",
583
+ "pear</w>\n",
584
+ "similiarity = 13.42 %\n",
585
+ "--------\n",
586
+ "orange\n",
587
+ "similiarity = 13.4 %\n",
588
+ "--------\n",
589
+ "monet</w>\n",
590
+ "similiarity = 13.4 %\n",
591
+ "--------\n",
592
+ "snake</w>\n",
593
+ "similiarity = 13.32 %\n",
594
+ "--------\n",
595
+ "vampire</w>\n",
596
+ "similiarity = 13.32 %\n",
597
+ "--------\n",
598
+ "cinnamon</w>\n",
599
+ "similiarity = 13.3 %\n",
600
+ "--------\n",
601
+ "strawberries</w>\n",
602
+ "similiarity = 13.29 %\n",
603
+ "--------\n",
604
+ "butternut</w>\n",
605
+ "similiarity = 13.22 %\n",
606
+ "--------\n",
607
+ "sausages</w>\n",
608
+ "similiarity = 13.22 %\n",
609
+ "--------\n",
610
+ "iphone</w>\n",
611
+ "similiarity = 13.21 %\n",
612
+ "--------\n",
613
+ "egg\n",
614
+ "similiarity = 13.2 %\n",
615
+ "--------\n",
616
+ "capu\n",
617
+ "similiarity = 13.2 %\n",
618
+ "--------\n",
619
+ "mannequin</w>\n",
620
+ "similiarity = 13.19 %\n",
621
+ "--------\n",
622
+ "cucumbers</w>\n",
623
+ "similiarity = 13.16 %\n",
624
+ "--------\n",
625
+ "champagne</w>\n",
626
+ "similiarity = 13.15 %\n",
627
+ "--------\n",
628
+ "triangle</w>\n",
629
+ "similiarity = 13.14 %\n",
630
+ "--------\n",
631
+ "apples</w>\n",
632
+ "similiarity = 13.09 %\n",
633
+ "--------\n",
634
+ "dynamite</w>\n",
635
+ "similiarity = 13.08 %\n",
636
+ "--------\n",
637
+ "chocolate</w>\n",
638
+ "similiarity = 13.08 %\n",
639
+ "--------\n"
640
+ ]
641
+ }
642
+ ]
643
  },
644
  {
645
  "cell_type": "code",
 
656
  "#Valid ID ranges for id_for_token_A / id_for_token_B are between 0 and 49407"
657
  ],
658
  "metadata": {
659
+ "id": "MwmOdC9cNZty",
660
+ "collapsed": true,
661
+ "colab": {
662
+ "base_uri": "https://localhost:8080/"
663
+ },
664
+ "outputId": "0dd984d0-e253-4981-d72f-40aa83d57d8b"
665
  },
666
+ "execution_count": 8,
667
+ "outputs": [
668
+ {
669
+ "output_type": "stream",
670
+ "name": "stdout",
671
+ "text": [
672
+ "The similarity between tokens A and B is 3.671 %\n"
673
+ ]
674
+ }
675
+ ]
676
  },
677
  {
678
  "cell_type": "code",
 
681
  "\n",
682
  "prompt_A = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
683
  "prompt_B = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
684
+ "use_token_padding = True # @param {type:\"boolean\"}\n",
685
  "\n",
686
  "from transformers import CLIPProcessor, CLIPModel\n",
687
  "\n",
 
696
  "ids_B = processor.tokenizer(text=prompt_B, padding=use_token_padding, return_tensors=\"pt\")\n",
697
  "text_encoding_B = model.get_text_features(**ids_B)\n",
698
  "\n",
699
+ "similarity_str = 'The similarity between the text_encoding for A:\"' + prompt_A + '\" and B: \"' + prompt_B +'\" is ' + token_similarity(text_encoding_A[0] , text_encoding_B[0])\n",
700
  "\n",
701
  "\n",
702
  "print(similarity_str)\n",
 
708
  "\n"
709
  ],
710
  "metadata": {
711
+ "id": "QQOjh5BvnG8M",
712
+ "collapsed": true,
713
+ "colab": {
714
+ "base_uri": "https://localhost:8080/"
715
+ },
716
+ "outputId": "8bd6eb94-c5a7-47e6-913b-346941b144a6"
717
  },
718
+ "execution_count": 11,
719
+ "outputs": [
720
+ {
721
+ "output_type": "stream",
722
+ "name": "stdout",
723
+ "text": [
724
+ "The similarity between the text_encoding for A:\"one ripe banana\" and B: \"a long yellow fruit\" is 83.696 %\n"
725
+ ]
726
+ }
727
+ ]
728
  },
729
  {
730
  "cell_type": "markdown",