{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import gradio as gr\n",
    "import warnings\n",
    "import streamlit as st\n",
    "from PIL import Image\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertForSequenceClassification, AutoModelForSeq2SeqLM\n",
    "from tqdm import tqdm\n",
    "from torchvision import models\n",
    "from torchvision.transforms import v2\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from keras.preprocessing import image\n",
    "from torchmetrics.classification import MultilabelF1Score\n",
    "from sklearn.metrics import average_precision_score, ndcg_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: 'Crime',\n",
       " 1: 'Thriller',\n",
       " 2: 'Fantasy',\n",
       " 3: 'Horror',\n",
       " 4: 'Sci-Fi',\n",
       " 5: 'Comedy',\n",
       " 6: 'Documentary',\n",
       " 7: 'Adventure',\n",
       " 8: 'Film-Noir',\n",
       " 9: 'Animation',\n",
       " 10: 'Romance',\n",
       " 11: 'Drama',\n",
       " 12: 'Western',\n",
       " 13: 'Musical',\n",
       " 14: 'Action',\n",
       " 15: 'Mystery',\n",
       " 16: 'War',\n",
       " 17: \"Children's\"}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "genres = [\"Crime\", \"Thriller\", \"Fantasy\", \"Horror\", \"Sci-Fi\", \"Comedy\", \"Documentary\", \"Adventure\", \"Film-Noir\", \"Animation\", \"Romance\", \"Drama\", \"Western\", \"Musical\", \"Action\", \"Mystery\", \"War\", \"Children\\'s\"]\n",
    "mapping = {}\n",
    "for i in range(len(genres)):\n",
    "    mapping[i] = genres[i]\n",
    "mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer_gen = AutoTokenizer.from_pretrained(\"MBZUAI/LaMini-Flan-T5-248M\")\n",
    "model_gen = AutoModelForSeq2SeqLM.from_pretrained(\"MBZUAI/LaMini-Flan-T5-248M\")\n",
    "\n",
    "tokenizer1 = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
    "model1 = DistilBertForSequenceClassification .from_pretrained(\"distilbert-base-uncased\", problem_type=\"multi_label_classification\", num_labels=18)\n",
    "model1.config.id2label = mapping\n",
    "\n",
    "tokenizer2 = AutoTokenizer.from_pretrained(\"dduy193/plot-classification\")\n",
    "model2 = AutoModelForSequenceClassification.from_pretrained(\"dduy193/plot-classification\")\n",
    "model2.config.id2label = mapping\n",
    "\n",
    "model3 = models.resnet101(pretrained=False)\n",
    "model3.fc = torch.nn.Linear(2048, len(genres))\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model1.to(device)\n",
    "model2.to(device)\n",
    "model3.to(device)\n",
    "model_gen.to(device)\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Multimodal(torch.nn.Module):\n",
    "    def __init__(self, model1, model2, model3):\n",
    "        super().__init__()\n",
    "        self.model1 = model1\n",
    "        self.model2 = model2\n",
    "        self.model3 = model3\n",
    "        self.fc1 = torch.nn.Linear(18, 18)\n",
    "        self.fc2 = torch.nn.Linear(18, 18)\n",
    "        self.fc3 = torch.nn.Linear(18, 18)\n",
    "\n",
    "    def forward(self, \n",
    "                title_input_ids, title_attention_mask,\n",
    "                plot_input_ids, plot_attention_mask,\n",
    "                image_input):\n",
    "        title_output = self.model1(title_input_ids, title_attention_mask)\n",
    "        plot_output = self.model2(plot_input_ids, plot_attention_mask)\n",
    "        image_output = self.model3(image_input)\n",
    "\n",
    "        title_output = self.fc1(title_output.logits)\n",
    "        plot_output = self.fc2(plot_output.logits)\n",
    "        image_output = self.fc3(image_output)\n",
    "        \n",
    "        output = torch.add(title_output, plot_output)\n",
    "        output = torch.add(output, image_output)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Multimodal(model1, model2, model3)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model.to(device)\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**_PLEASE INSTALL THE MODEL CHECKPOINT FROM THE LINK IN README.txt_**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Multimodal(\n",
       "  (model1): DistilBertForSequenceClassification(\n",
       "    (distilbert): DistilBertModel(\n",
       "      (embeddings): Embeddings(\n",
       "        (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "        (position_embeddings): Embedding(512, 768)\n",
       "        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (transformer): Transformer(\n",
       "        (layer): ModuleList(\n",
       "          (0-5): 6 x TransformerBlock(\n",
       "            (attention): MultiHeadSelfAttention(\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "              (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            )\n",
       "            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (ffn): FFN(\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "              (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (activation): GELUActivation()\n",
       "            )\n",
       "            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (classifier): Linear(in_features=768, out_features=18, bias=True)\n",
       "    (dropout): Dropout(p=0.2, inplace=False)\n",
       "  )\n",
       "  (model2): DistilBertForSequenceClassification(\n",
       "    (distilbert): DistilBertModel(\n",
       "      (embeddings): Embeddings(\n",
       "        (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "        (position_embeddings): Embedding(512, 768)\n",
       "        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (transformer): Transformer(\n",
       "        (layer): ModuleList(\n",
       "          (0-5): 6 x TransformerBlock(\n",
       "            (attention): MultiHeadSelfAttention(\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "              (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            )\n",
       "            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (ffn): FFN(\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "              (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (activation): GELUActivation()\n",
       "            )\n",
       "            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (classifier): Linear(in_features=768, out_features=18, bias=True)\n",
       "    (dropout): Dropout(p=0.2, inplace=False)\n",
       "  )\n",
       "  (model3): ResNet(\n",
       "    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu): ReLU(inplace=True)\n",
       "    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "    (layer1): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer2): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (3): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer3): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (3): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (4): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (5): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (6): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (7): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (8): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (9): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (10): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (11): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (12): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (13): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (14): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (15): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (16): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (17): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (18): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (19): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (20): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (21): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (22): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer4): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (fc): Linear(in_features=2048, out_features=18, bias=True)\n",
       "  )\n",
       "  (fc1): Linear(in_features=18, out_features=18, bias=True)\n",
       "  (fc2): Linear(in_features=18, out_features=18, bias=True)\n",
       "  (fc3): Linear(in_features=18, out_features=18, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('multimodel.pt'))\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_plot(title: str, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, device) -> str:\n",
    "    quote = 'What is the story of the movie {}?'\n",
    "    model_gen.to(device)\n",
    "    model_gen.eval()\n",
    "\n",
    "    input_ids = tokenizer(quote.format(title), return_tensors='pt').input_ids.to(device)\n",
    "    output = model.generate(input_ids, max_length=256, do_sample=True, temperature=0.09)\n",
    "    return tokenizer.decode(output[0], skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(title, image, \n",
    "              tokenizer1=tokenizer1, tokenizer2=tokenizer2, tokenizer_gen=tokenizer_gen,\n",
    "              model_gen=model_gen, model=model, \n",
    "              genres=genres, device=device):\n",
    "    title_input = tokenizer1(title, return_tensors='pt', padding=True, truncation=True)\n",
    "    title_input_ids = title_input['input_ids'].to(device)\n",
    "    title_attention_mask = title_input['attention_mask'].to(device)\n",
    "\n",
    "    plot = generate_plot(title, model_gen, tokenizer_gen, device)\n",
    "    plot_input = tokenizer2(plot, return_tensors='pt', padding=True, truncation=True)\n",
    "    plot_input_ids = plot_input['input_ids'].to(device)\n",
    "    plot_attention_mask = plot_input['attention_mask'].to(device)\n",
    "\n",
    "    # If image is not uploaded\n",
    "    if image is None:\n",
    "        image_input = torch.zeros((1, 3, 224, 224)).to(device)\n",
    "\n",
    "    else:\n",
    "        image_input = image.resize((224, 224))\n",
    "        image_input = v2.ToTensor()(image_input)\n",
    "        image_input = image_input.unsqueeze(0)\n",
    "        image_input = image_input.to(device)\n",
    "\n",
    "    output = model(title_input_ids, title_attention_mask, plot_input_ids, plot_attention_mask, image_input)\n",
    "    output = torch.sigmoid(output)\n",
    "    output = output.cpu().detach().numpy()\n",
    "    output = np.where(output > 0.5, 1, 0)\n",
    "    output = output.squeeze()\n",
    "    output = np.where(output == 1)[0]\n",
    "    output = [genres[i] for i in output]\n",
    "    return output"
   ]
  },
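  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (a minimal sketch): `inference` can also be called directly, without the Gradio UI. Passing `None` for the poster falls back to the zero image tensor handled above; the exact genres returned depend on the trained checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direct call without a poster image; returns a list of predicted genre names.\n",
    "inference(\"The Matrix\", None)"
   ]
  },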
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "app = gr.Interface(fn=inference, inputs=[\"text\", \"pil\"], outputs=\"text\", title=\"Movie Genre Classification\", \n",
    "                   description=\"This model classifies the genre of a movie based on its title and poster.\", \n",
    "                   examples=[[\"The Matrix\", \"https://upload.wikimedia.org/wikipedia/en/c/c1/The_Matrix_Poster.jpg\"],\n",
    "                             [\"The Dark Knight\", \"https://upload.wikimedia.org/wikipedia/en/1/1c/The_Dark_Knight_%282008_film%29.jpg\"],\n",
    "                             [\"The Godfather\", \"https://upload.wikimedia.org/wikipedia/en/1/1c/Godfather_ver1.jpg\"],\n",
    "                             [\"The Shawshank Redemption\", \"https://upload.wikimedia.org/wikipedia/en/8/81/ShawshankRedemptionMoviePoster.jpg\"],\n",
    "                             [\"The Lord of the Rings: The Return of the King\", \"https://upload.wikimedia.org/wikipedia/en/2/23/The_Lord_of_the_Rings%2C_TROTK_%282003%29.jpg\"],\n",
    "                             [\"The Godfather: Part II\", \"https://upload.wikimedia.org/wikipedia/en/0/03/Godfather_part_ii.jpg\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7860\n",
      "Running on public URL: https://9a1eeae0e7e33064f8.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://9a1eeae0e7e33064f8.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "app.launch(share=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}