File size: 36,538 Bytes
ed00004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import gradio as gr "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from src.data.embs import ImageDataset\n",
    "from src.model.blip_embs import blip_embs\n",
    "\n",
    "from demo_chat import Chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data.transforms import transform_test\n",
    "\n",
    "transform = transform_test(384)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_blip_config(model=\"base\"):\n",
    "    config = dict()\n",
    "    if model == \"base\":\n",
    "        config[\n",
    "            \"pretrained\"\n",
    "        ] = \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth\"\n",
    "        config[\"vit\"] = \"base\"\n",
    "        config[\"batch_size_train\"] = 32\n",
    "        config[\"batch_size_test\"] = 16\n",
    "        config[\"vit_grad_ckpt\"] = True\n",
    "        config[\"vit_ckpt_layer\"] = 4\n",
    "        config[\"init_lr\"] = 1e-5\n",
    "    elif model == \"large\":\n",
    "        config[\n",
    "            \"pretrained\"\n",
    "        ] = \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth\"\n",
    "        config[\"vit\"] = \"large\"\n",
    "        config[\"batch_size_train\"] = 16\n",
    "        config[\"batch_size_test\"] = 32\n",
    "        config[\"vit_grad_ckpt\"] = True\n",
    "        config[\"vit_ckpt_layer\"] = 12\n",
    "        config[\"init_lr\"] = 5e-6\n",
    "\n",
    "    config[\"image_size\"] = 384\n",
    "    config[\"queue_size\"] = 57600\n",
    "    config[\"alpha\"] = 0.4\n",
    "    config[\"k_test\"] = 256\n",
    "    config[\"negative_all_rank\"] = True\n",
    "\n",
    "    return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating model\n",
      "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth\n",
      "missing keys:\n",
      "[]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "BLIPEmbs(\n",
       "  (visual_encoder): VisionTransformer(\n",
       "    (patch_embed): PatchEmbed(\n",
       "      (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))\n",
       "      (norm): Identity()\n",
       "    )\n",
       "    (pos_drop): Dropout(p=0.0, inplace=False)\n",
       "    (blocks): ModuleList(\n",
       "      (0): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.004)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (2): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.009)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (3): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.013)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (4): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.017)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (5): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.022)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (6): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.026)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (7): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.030)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (8): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.035)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (9): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.039)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (10): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.043)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (11): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.048)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (12): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.052)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (13): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.057)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (14): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.061)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (15): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.065)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (16): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.070)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (17): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.074)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (18): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.078)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (19): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.083)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (20): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.087)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (21): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.091)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (22): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.096)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (23): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.100)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "  )\n",
       "  (text_encoder): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(30524, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (crossattention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=1024, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=1024, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (vision_proj): Linear(in_features=1024, out_features=256, bias=True)\n",
       "  (text_proj): Linear(in_features=768, out_features=256, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Creating model\")\n",
    "config = get_blip_config(\"large\")\n",
    "\n",
    "model = blip_embs(\n",
    "        pretrained=config[\"pretrained\"],\n",
    "        image_size=config[\"image_size\"],\n",
    "        vit=config[\"vit\"],\n",
    "        vit_grad_ckpt=config[\"vit_grad_ckpt\"],\n",
    "        vit_ckpt_layer=config[\"vit_ckpt_layer\"],\n",
    "        queue_size=config[\"queue_size\"],\n",
    "        negative_all_rank=config[\"negative_all_rank\"],\n",
    "    )\n",
    "\n",
    "model = model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_json(\"datasets/sidechef/my_recipes.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Target Embedding\n"
     ]
    }
   ],
   "source": [
    "print(\"Loading Target Embedding\")\n",
    "tar_img_feats = []\n",
    "for _id in df[\"id_\"].tolist():     \n",
    "    tar_img_feats.append(torch.load(\"datasets/sidechef/blip-embs-large/{:07d}.pth\".format(_id)).unsqueeze(0))\n",
    "\n",
    "tar_img_feats = torch.cat(tar_img_feats, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7866\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"650px\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Custom CSS for the demo page: styles the header banner and the footer\n",
    "custom_css = \"\"\"\n",
    "/* Footer style */\n",
    ".gradio-footer {\n",
    "    display: flex;\n",
    "    justify-content: center;\n",
    "    align-items: center;\n",
    "    padding: 10px;\n",
    "    background-color: #f8f9fa;\n",
    "    color: #333;\n",
    "    font-size: 0.9em;\n",
    "}\n",
    "\n",
    ".custom-header {\n",
    "    text-align: center;\n",
    "    padding: 12px;\n",
    "    background-color: #333; \n",
    "    color: white;\n",
    "    position: bottom;\n",
    "    bottom: 0;\n",
    "    width: 100%;\n",
    "    font-size: 0.8em;\n",
    "}\n",
    "\n",
    ".footer {\n",
    "    width: 100%;\n",
    "    background-color: #f2f2f2;\n",
    "    color: #555;\n",
    "    text-align: center;\n",
    "    padding: 10px 0;\n",
    "    position: absolute;\n",
    "    bottom: 0;\n",
    "    left: 0;\n",
    "}\n",
    "\n",
    "/* Make sure the interface leaves space for the footer */\n",
    ".body {\n",
    "    margin-bottom: 50px;\n",
    "}\n",
    "\"\"\"\n",
    "\n",
    "# Footer HTML intended to be passed as the Interface description (that argument is currently commented out below)\n",
    "custom_footer_html = \"\"\"\n",
    "<footer> <p> Reach out to us at {omkar.thawakar, muzammal.naseer}@mbzuai.ac.ae </p> </footer>\n",
    "\"\"\"\n",
    "\n",
    "custom_header_html = \"\"\"\n",
    "<div class='custom-header'>Nutrition-GPT Demo</div>\n",
    "\"\"\"\n",
    "\n",
    "def respond_to_user(image, message):\n",
    "    # Create a Chat over the loaded model, transform, recipe dataframe and target embeddings,\n",
    "    # encode the uploaded image, then ask the user's message and return the reply.\n",
    "    chat = Chat(model,transform,df,tar_img_feats)\n",
    "    chat.encode_image(image)\n",
    "    response = chat.ask(message)\n",
    "    return response\n",
    "\n",
    "iface = gr.Interface(\n",
    "    fn=respond_to_user,\n",
    "    inputs=[gr.Image(height=\"70%\"), gr.Textbox(label=\"Ask Query\"),],\n",
    "    outputs=[gr.Textbox(label=\"Nutrition-GPT\")],\n",
    "    title=custom_header_html,  \n",
    "    # description=\"Upload a food image and ask queries!\",\n",
    "    css=custom_css,\n",
    "    # description=custom_footer_html   \n",
    ")\n",
    "\n",
    "iface.launch(show_error=True, height=\"650px\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# example_texts = gr.Dataset(components=[gr.Textbox(visible=False)],\n",
    "            # label=\"Prompt Examples\",\n",
    "            # samples=[\n",
    "            #     [\"Provide nutritional information for given food image.\"],\n",
    "            #     [\"What are the nutrients available in given food image.\"],\n",
    "            #     [\"Could you provide a detailed nutritional data of the given food image?\"],\n",
    "            #     [\"Describe the instructions to prepare given food.\"],\n",
    "            #     [\"What are the key ingredients in this food image?\"],\n",
    "            #     [\"Could you highlight the dietary tags for this food image?\"],\n",
    "            # ],)\n",
    "\n",
    "# example_images = gr.Dataset(components=[image], label=\"Food Examples\",\n",
    "#                     samples=[\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000018.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000021.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000035.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000038.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000090.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000122.png\")],\n",
    "#                     ])\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "covr",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}