liwii committed
Commit 26d475a · verified · 1 Parent(s): 611bfa1

Training in progress, epoch 1
added_tokens.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "<pad>": 0,
3
+ "<unk>": 1,
4
+ "[CLS]": 2,
5
+ "[MASK]": 4,
6
+ "[SEP]": 3
7
+ }
config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "_name_or_path": "line-corporation/line-distilbert-base-japanese",
3
+ "activation": "gelu",
4
+ "architectures": [
5
+ "ConsistentSentenceRegressor"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "dim": 768,
9
+ "dropout": 0.1,
10
+ "hidden_dim": 3072,
11
+ "id2label": {
12
+ "0": "LABEL_0"
13
+ },
14
+ "initializer_range": 0.02,
15
+ "label2id": {
16
+ "LABEL_0": 0
17
+ },
18
+ "max_position_embeddings": 512,
19
+ "model_type": "distilbert",
20
+ "n_heads": 12,
21
+ "n_layers": 6,
22
+ "output_hidden_states": true,
23
+ "pad_token_id": 0,
24
+ "problem_type": "regression",
25
+ "qa_dropout": 0.1,
26
+ "seq_classif_dropout": 0.2,
27
+ "sinusoidal_pos_embds": true,
28
+ "tie_weights_": true,
29
+ "torch_dtype": "float32",
30
+ "transformers_version": "4.34.0",
31
+ "vocab_size": 32768
32
+ }
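For reference, a minimal sketch of inspecting this configuration with the transformers API (the local path is an assumption; the "ConsistentSentenceRegressor" listed under "architectures" is the custom class added in utils.py below, not a built-in transformers model):

from transformers import AutoConfig

# Assumes the repository has been cloned locally; the Hub repo id would work the same way.
config = AutoConfig.from_pretrained("./factual-consistency-regression-ja")
print(config.model_type)       # "distilbert"
print(config.problem_type)     # "regression" -- the head predicts one continuous consistency score
print(config.dim, config.n_layers, config.n_heads)   # 768, 6, 12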
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1f295326993252bd9df044d41ab49e3250373aee8f37f46cb0072b73e52d1f7
3
+ size 274752173
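This entry is a Git LFS pointer rather than the weights themselves; the oid/size lines describe the real ~275 MB checkpoint. A minimal sketch of loading it once the binary has been fetched (e.g. via git lfs pull), assuming the ConsistentSentenceRegressor class from utils.py below:

import torch
from utils import ConsistentSentenceRegressor

model = ConsistentSentenceRegressor(freeze_bert=True)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)   # strict=False may be needed depending on how the checkpoint was saved
model.eval()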
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "bos_token": "[CLS]",
3
+ "cls_token": "[CLS]",
4
+ "eos_token": "[SEP]",
5
+ "mask_token": "[MASK]",
6
+ "pad_token": "<pad>",
7
+ "sep_token": "[SEP]",
8
+ "unk_token": "<unk>"
9
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bcfafc8c0662d9c8f39621a64c74260f2ad120310c8dd24886de2dddaf599b4e
3
+ size 439391
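As with the weights, this is an LFS pointer to the SentencePiece vocabulary used for subword tokenization. Once the real file is pulled it can be inspected directly; a sketch using the sentencepiece package (an assumed but standard dependency of this tokenizer):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="spiece.model")
print(sp.get_piece_size())                                     # vocabulary size
print(sp.encode("川を小さなボートが進んで行きます。", out_type=str))  # subword pieces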
tokenizer_config.json ADDED
@@ -0,0 +1,76 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<unk>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "[MASK]",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "additional_special_tokens": [],
45
+ "auto_map": {
46
+ "AutoTokenizer": [
47
+ "line-corporation/line-distilbert-base-japanese--distilbert_japanese_tokenizer.DistilBertJapaneseTokenizer",
48
+ null
49
+ ]
50
+ },
51
+ "bos_token": "[CLS]",
52
+ "clean_up_tokenization_spaces": true,
53
+ "cls_token": "[CLS]",
54
+ "do_lower_case": true,
55
+ "do_subword_tokenize": true,
56
+ "do_word_tokenize": true,
57
+ "eos_token": "[SEP]",
58
+ "jumanpp_kwargs": null,
59
+ "keep_accents": true,
60
+ "mask_token": "[MASK]",
61
+ "mecab_kwargs": {
62
+ "mecab_dic": "unidic_lite"
63
+ },
64
+ "model_max_length": 1000000000000000019884624838656,
65
+ "never_split": null,
66
+ "pad_token": "<pad>",
67
+ "remove_space": true,
68
+ "sep_token": "[SEP]",
69
+ "subword_tokenizer_type": "sentencepiece",
70
+ "sudachi_kwargs": null,
71
+ "tokenize_chinese_chars": false,
72
+ "tokenizer_class": "BertJapaneseTokenizer",
73
+ "tokenizer_file": null,
74
+ "unk_token": "<unk>",
75
+ "word_tokenizer_type": "mecab"
76
+ }
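Because auto_map points at a custom DistilBertJapaneseTokenizer hosted in the line-corporation repository, loading this tokenizer generally needs trust_remote_code=True (or it falls back to the BertJapaneseTokenizer class named above), plus the MeCab and SentencePiece dependencies. A hedged sketch; the package names are the usual ones for this tokenizer and are an assumption here:

# pip install transformers fugashi unidic-lite sentencepiece
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "line-corporation/line-distilbert-base-japanese", trust_remote_code=True
)
print(tokenizer.tokenize("カーキ色の服を着た男性がリンゴを食べています。"))
print(tokenizer.convert_tokens_to_ids(["<pad>", "<unk>", "[CLS]", "[SEP]", "[MASK]"]))  # [0, 1, 2, 3, 4], matching added_tokens.json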
train-v1.1.json ADDED
The diff for this file is too large to render. See raw diff
train_factual_consistency.ipynb ADDED
@@ -0,0 +1,1489 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "b12ae8a3-9e08-402c-894c-31697fad6c56",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "application/vnd.jupyter.widget-view+json": {
12
+ "model_id": "6e13508dc55b4712a4d6e91647a932a3",
13
+ "version_major": 2,
14
+ "version_minor": 0
15
+ },
16
+ "text/plain": [
17
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
18
+ ]
19
+ },
20
+ "metadata": {},
21
+ "output_type": "display_data"
22
+ }
23
+ ],
24
+ "source": [
25
+ "from huggingface_hub import notebook_login\n",
26
+ "\n",
27
+ "notebook_login()"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "id": "160c80c1-0ca4-45df-8171-87cd3c88a223",
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "\n",
38
+ "from transformers import (\n",
39
+ " AutoTokenizer,\n",
40
+ " DataCollatorWithPadding,\n",
41
+ " Trainer,\n",
42
+ " TrainingArguments,\n",
43
+ ")\n",
44
+ "from utils import ConsistentSentenceRegressor, get_metrics, load_dataset"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 3,
50
+ "id": "25800588-5d42-4524-9dc6-a6a0c180b8b0",
51
+ "metadata": {},
52
+ "outputs": [
53
+ {
54
+ "name": "stdout",
55
+ "output_type": "stream",
56
+ "text": [
57
+ " text label\n",
58
+ "512 カーキ色の服を着た男性が、口元にリンゴを当てています。[SEP]カーキ色の服を着た男性が、口... 0.0\n",
59
+ "513 男性がグラウンドでボールを投げています。[SEP]白い髯を生やした男性がボールを投げています。 0.5\n",
60
+ "514 椅子に座った子供が、手づかみで食事をしています。[SEP]椅子に座った子供が手づかみで、食事... 1.0\n",
61
+ "515 プロペラ機が何台も駐機しています。[SEP]プロペラ機が何台も連なって飛んでいます。 0.0\n",
62
+ "516 消火栓から水が勢いよく噴き出しています。[SEP]水が噴き出している消火栓の水を浴びるように... 0.5\n",
63
+ "517 冷蔵庫のないキッチンにナイフとフォークが置かれています。[SEP]冷蔵庫の置かれたキッチンに... 0.0\n",
64
+ "518 うみでサーフィンをしているひとがいます。[SEP]黒いウェットスーツを着た人がサーフボードに... 0.5\n",
65
+ "519 池から白い鳥が飛び立っています。[SEP]森にある水の上を鳥が飛んでいます。 0.5\n",
66
+ "520 丈夫なビーチパラソルが立っています。[SEP]ビーチパラソルの支柱が折れ曲がっています。 0.0\n",
67
+ "521 白髪の男性が少女から花束を受け取っています。[SEP]花束を持った男性の前に多くの子供たちが... 0.5\n",
68
+ " text label\n",
69
+ "0 赤いひとつの傘に、二人の人が入っています。[SEP]歩道を歩く通行人が傘をさして歩いています。 0.5\n",
70
+ "1 川を小さなボートが進んで行きます。[SEP]川を豪華客船が進んでいきます。 0.0\n",
71
+ "2 ゲレンデのこぶでスキージャンプしています。[SEP]雪上でモーグルを楽しむ水色のウェアを着た女性。 0.5\n",
72
+ "3 黒いお皿に乗っているピザをカットしています。[SEP]黒い皿の上にピザが盛られています。 1.0\n",
73
+ "4 女性が目を細めて携帯電話で話をしています。[SEP]目を細めた女性が携帯電話で話をしています。 1.0\n",
74
+ "5 バナナやパパイヤなどの果物が売られている。[SEP]台の上にはバナナなどの青果が並べられています。 0.5\n",
75
+ "6 ヘッドライトを点灯させた白いバスが駐車場に止まっています。[SEP]ライトを点灯させているバ... 1.0\n",
76
+ "7 水面の上に、カイトサーフィンの凧が揚がっています。[SEP]海の上に水上スポーツ用の凧が揚が... 0.5\n",
77
+ "8 ホットドッグを野外で食べている人たちです。[SEP]家の中でホットドッグを食べている。 0.0\n",
78
+ "9 草が生い茂っている所に、3頭のゾウがいます。[SEP]草むらの中に三頭のゾウが立っているとこ... 0.5\n"
79
+ ]
80
+ },
81
+ {
82
+ "data": {
83
+ "application/vnd.jupyter.widget-view+json": {
84
+ "model_id": "37636d1b642c4b5382572caabd6f7466",
85
+ "version_major": 2,
86
+ "version_minor": 0
87
+ },
88
+ "text/plain": [
89
+ "Map: 0%| | 0/19561 [00:00<?, ? examples/s]"
90
+ ]
91
+ },
92
+ "metadata": {},
93
+ "output_type": "display_data"
94
+ },
95
+ {
96
+ "name": "stderr",
97
+ "output_type": "stream",
98
+ "text": [
99
+ "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.\n",
100
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
101
+ ]
102
+ },
103
+ {
104
+ "data": {
105
+ "application/vnd.jupyter.widget-view+json": {
106
+ "model_id": "901f21c168624db8aa6e8881dd30df60",
107
+ "version_major": 2,
108
+ "version_minor": 0
109
+ },
110
+ "text/plain": [
111
+ "Map: 0%| | 0/512 [00:00<?, ? examples/s]"
112
+ ]
113
+ },
114
+ "metadata": {},
115
+ "output_type": "display_data"
116
+ }
117
+ ],
118
+ "source": [
119
+ "tokenizer = AutoTokenizer.from_pretrained(\"line-corporation/line-distilbert-base-japanese\")\n",
120
+ "dataset = load_dataset('train-v1.1.json')\n",
121
+ "tokenized_dataset = dataset.map(\n",
122
+ " lambda examples: tokenizer(examples[\"text\"], padding='max_length', truncation=True), batched=True\n",
123
+ ")"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 7,
129
+ "id": "6bc83d4c-378c-4313-b641-8ead0c02f715",
130
+ "metadata": {},
131
+ "outputs": [
132
+ {
133
+ "name": "stdout",
134
+ "output_type": "stream",
135
+ "text": [
136
+ "torch.Size([64, 60])\n",
137
+ "torch.Size([64, 1])\n",
138
+ "torch.Size([64])\n"
139
+ ]
140
+ },
141
+ {
142
+ "data": {
143
+ "text/html": [
144
+ "\n",
145
+ " <div>\n",
146
+ " \n",
147
+ " <progress value='406' max='30600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
148
+ " [ 406/30600 00:45 < 56:06, 8.97 it/s, Epoch 1.32/100]\n",
149
+ " </div>\n",
150
+ " <table border=\"1\" class=\"dataframe\">\n",
151
+ " <thead>\n",
152
+ " <tr style=\"text-align: left;\">\n",
153
+ " <th>Epoch</th>\n",
154
+ " <th>Training Loss</th>\n",
155
+ " <th>Validation Loss</th>\n",
156
+ " </tr>\n",
157
+ " </thead>\n",
158
+ " <tbody>\n",
159
+ " <tr>\n",
160
+ " <td>1</td>\n",
161
+ " <td>No log</td>\n",
162
+ " <td>-3.658799</td>\n",
163
+ " </tr>\n",
164
+ " </tbody>\n",
165
+ "</table><p>"
166
+ ],
167
+ "text/plain": [
168
+ "<IPython.core.display.HTML object>"
169
+ ]
170
+ },
171
+ "metadata": {},
172
+ "output_type": "display_data"
173
+ },
174
+ {
175
+ "name": "stdout",
176
+ "output_type": "stream",
177
+ "text": [
178
+ "torch.Size([64, 63])\n",
179
+ "torch.Size([64, 1])\n",
180
+ "torch.Size([64])\n",
181
+ "torch.Size([64, 59])\n",
182
+ "torch.Size([64, 1])\n",
183
+ "torch.Size([64])\n",
184
+ "torch.Size([64, 52])\n",
185
+ "torch.Size([64, 1])\n",
186
+ "torch.Size([64])\n",
187
+ "torch.Size([64, 56])\n",
188
+ "torch.Size([64, 1])\n",
189
+ "torch.Size([64])\n",
190
+ "torch.Size([64, 63])\n",
191
+ "torch.Size([64, 1])\n",
192
+ "torch.Size([64])\n",
193
+ "torch.Size([64, 63])\n",
194
+ "torch.Size([64, 1])\n",
195
+ "torch.Size([64])\n",
196
+ "torch.Size([64, 57])\n",
197
+ "torch.Size([64, 1])\n",
198
+ "torch.Size([64])\n",
199
+ "torch.Size([64, 56])\n",
200
+ "torch.Size([64, 1])\n",
201
+ "torch.Size([64])\n",
202
+ "torch.Size([64, 77])\n",
203
+ "torch.Size([64, 1])\n",
204
+ "torch.Size([64])\n",
205
+ "torch.Size([64, 60])\n",
206
+ "torch.Size([64, 1])\n",
207
+ "torch.Size([64])\n",
208
+ "torch.Size([64, 72])\n",
209
+ "torch.Size([64, 1])\n",
210
+ "torch.Size([64])\n",
211
+ "torch.Size([64, 60])\n",
212
+ "torch.Size([64, 1])\n",
213
+ "torch.Size([64])\n",
214
+ "torch.Size([64, 56])\n",
215
+ "torch.Size([64, 1])\n",
216
+ "torch.Size([64])\n",
217
+ "torch.Size([64, 50])\n",
218
+ "torch.Size([64, 1])\n",
219
+ "torch.Size([64])\n",
220
+ "torch.Size([64, 61])\n",
221
+ "torch.Size([64, 1])\n",
222
+ "torch.Size([64])\n",
223
+ "torch.Size([64, 69])\n",
224
+ "torch.Size([64, 1])\n",
225
+ "torch.Size([64])\n",
226
+ "torch.Size([64, 62])\n",
227
+ "torch.Size([64, 1])\n",
228
+ "torch.Size([64])\n",
229
+ "torch.Size([64, 56])\n",
230
+ "torch.Size([64, 1])\n",
231
+ "torch.Size([64])\n",
232
+ "torch.Size([64, 50])\n",
233
+ "torch.Size([64, 1])\n",
234
+ "torch.Size([64])\n",
235
+ "torch.Size([64, 63])\n",
236
+ "torch.Size([64, 1])\n",
237
+ "torch.Size([64])\n",
238
+ "torch.Size([64, 52])\n",
239
+ "torch.Size([64, 1])\n",
240
+ "torch.Size([64])\n",
241
+ "torch.Size([64, 80])\n",
242
+ "torch.Size([64, 1])\n",
243
+ "torch.Size([64])\n",
244
+ "torch.Size([64, 71])\n",
245
+ "torch.Size([64, 1])\n",
246
+ "torch.Size([64])\n",
247
+ "torch.Size([64, 52])\n",
248
+ "torch.Size([64, 1])\n",
249
+ "torch.Size([64])\n",
250
+ "torch.Size([64, 54])\n",
251
+ "torch.Size([64, 1])\n",
252
+ "torch.Size([64])\n",
253
+ "torch.Size([64, 51])\n",
254
+ "torch.Size([64, 1])\n",
255
+ "torch.Size([64])\n",
256
+ "torch.Size([64, 51])\n",
257
+ "torch.Size([64, 1])\n",
258
+ "torch.Size([64])\n",
259
+ "torch.Size([64, 51])\n",
260
+ "torch.Size([64, 1])\n",
261
+ "torch.Size([64])\n",
262
+ "torch.Size([64, 70])\n",
263
+ "torch.Size([64, 1])\n",
264
+ "torch.Size([64])\n",
265
+ "torch.Size([64, 64])\n",
266
+ "torch.Size([64, 1])\n",
267
+ "torch.Size([64])\n",
268
+ "torch.Size([64, 54])\n",
269
+ "torch.Size([64, 1])\n",
270
+ "torch.Size([64])\n",
271
+ "torch.Size([64, 55])\n",
272
+ "torch.Size([64, 1])\n",
273
+ "torch.Size([64])\n",
274
+ "torch.Size([64, 53])\n",
275
+ "torch.Size([64, 1])\n",
276
+ "torch.Size([64])\n",
277
+ "torch.Size([64, 76])\n",
278
+ "torch.Size([64, 1])\n",
279
+ "torch.Size([64])\n",
280
+ "torch.Size([64, 53])\n",
281
+ "torch.Size([64, 1])\n",
282
+ "torch.Size([64])\n",
283
+ "torch.Size([64, 55])\n",
284
+ "torch.Size([64, 1])\n",
285
+ "torch.Size([64])\n",
286
+ "torch.Size([64, 70])\n",
287
+ "torch.Size([64, 1])\n",
288
+ "torch.Size([64])\n",
289
+ "torch.Size([64, 59])\n",
290
+ "torch.Size([64, 1])\n",
291
+ "torch.Size([64])\n",
292
+ "torch.Size([64, 59])\n",
293
+ "torch.Size([64, 1])\n",
294
+ "torch.Size([64])\n",
295
+ "torch.Size([64, 68])\n",
296
+ "torch.Size([64, 1])\n",
297
+ "torch.Size([64])\n",
298
+ "torch.Size([64, 71])\n",
299
+ "torch.Size([64, 1])\n",
300
+ "torch.Size([64])\n",
301
+ "torch.Size([64, 58])\n",
302
+ "torch.Size([64, 1])\n",
303
+ "torch.Size([64])\n",
304
+ "torch.Size([64, 47])\n",
305
+ "torch.Size([64, 1])\n",
306
+ "torch.Size([64])\n",
307
+ "torch.Size([64, 65])\n",
308
+ "torch.Size([64, 1])\n",
309
+ "torch.Size([64])\n",
310
+ "torch.Size([64, 67])\n",
311
+ "torch.Size([64, 1])\n",
312
+ "torch.Size([64])\n",
313
+ "torch.Size([64, 67])\n",
314
+ "torch.Size([64, 1])\n",
315
+ "torch.Size([64])\n",
316
+ "torch.Size([64, 77])\n",
317
+ "torch.Size([64, 1])\n",
318
+ "torch.Size([64])\n",
319
+ "torch.Size([64, 55])\n",
320
+ "torch.Size([64, 1])\n",
321
+ "torch.Size([64])\n",
322
+ "torch.Size([64, 51])\n",
323
+ "torch.Size([64, 1])\n",
324
+ "torch.Size([64])\n",
325
+ "torch.Size([64, 64])\n",
326
+ "torch.Size([64, 1])\n",
327
+ "torch.Size([64])\n",
328
+ "torch.Size([64, 61])\n",
329
+ "torch.Size([64, 1])\n",
330
+ "torch.Size([64])\n",
331
+ "torch.Size([64, 79])\n",
332
+ "torch.Size([64, 1])\n",
333
+ "torch.Size([64])\n",
334
+ "torch.Size([64, 47])\n",
335
+ "torch.Size([64, 1])\n",
336
+ "torch.Size([64])\n",
337
+ "torch.Size([64, 59])\n",
338
+ "torch.Size([64, 1])\n",
339
+ "torch.Size([64])\n",
340
+ "torch.Size([64, 63])\n",
341
+ "torch.Size([64, 1])\n",
342
+ "torch.Size([64])\n",
343
+ "torch.Size([64, 53])\n",
344
+ "torch.Size([64, 1])\n",
345
+ "torch.Size([64])\n",
346
+ "torch.Size([64, 79])\n",
347
+ "torch.Size([64, 1])\n",
348
+ "torch.Size([64])\n",
349
+ "torch.Size([64, 55])\n",
350
+ "torch.Size([64, 1])\n",
351
+ "torch.Size([64])\n",
352
+ "torch.Size([64, 77])\n",
353
+ "torch.Size([64, 1])\n",
354
+ "torch.Size([64])\n",
355
+ "torch.Size([64, 67])\n",
356
+ "torch.Size([64, 1])\n",
357
+ "torch.Size([64])\n",
358
+ "torch.Size([64, 57])\n",
359
+ "torch.Size([64, 1])\n",
360
+ "torch.Size([64])\n",
361
+ "torch.Size([64, 67])\n",
362
+ "torch.Size([64, 1])\n",
363
+ "torch.Size([64])\n",
364
+ "torch.Size([64, 70])\n",
365
+ "torch.Size([64, 1])\n",
366
+ "torch.Size([64])\n",
367
+ "torch.Size([64, 48])\n",
368
+ "torch.Size([64, 1])\n",
369
+ "torch.Size([64])\n",
370
+ "torch.Size([64, 80])\n",
371
+ "torch.Size([64, 1])\n",
372
+ "torch.Size([64])\n",
373
+ "torch.Size([64, 54])\n",
374
+ "torch.Size([64, 1])\n",
375
+ "torch.Size([64])\n",
376
+ "torch.Size([64, 50])\n",
377
+ "torch.Size([64, 1])\n",
378
+ "torch.Size([64])\n",
379
+ "torch.Size([64, 64])\n",
380
+ "torch.Size([64, 1])\n",
381
+ "torch.Size([64])\n",
382
+ "torch.Size([64, 52])\n",
383
+ "torch.Size([64, 1])\n",
384
+ "torch.Size([64])\n",
385
+ "torch.Size([64, 55])\n",
386
+ "torch.Size([64, 1])\n",
387
+ "torch.Size([64])\n",
388
+ "torch.Size([64, 61])\n",
389
+ "torch.Size([64, 1])\n",
390
+ "torch.Size([64])\n",
391
+ "torch.Size([64, 73])\n",
392
+ "torch.Size([64, 1])\n",
393
+ "torch.Size([64])\n",
394
+ "torch.Size([64, 69])\n",
395
+ "torch.Size([64, 1])\n",
396
+ "torch.Size([64])\n",
397
+ "torch.Size([64, 54])\n",
398
+ "torch.Size([64, 1])\n",
399
+ "torch.Size([64])\n",
400
+ "torch.Size([64, 59])\n",
401
+ "torch.Size([64, 1])\n",
402
+ "torch.Size([64])\n",
403
+ "torch.Size([64, 74])\n",
404
+ "torch.Size([64, 1])\n",
405
+ "torch.Size([64])\n",
406
+ "torch.Size([64, 49])\n",
407
+ "torch.Size([64, 1])\n",
408
+ "torch.Size([64])\n",
409
+ "torch.Size([64, 52])\n",
410
+ "torch.Size([64, 1])\n",
411
+ "torch.Size([64])\n",
412
+ "torch.Size([64, 62])\n",
413
+ "torch.Size([64, 1])\n",
414
+ "torch.Size([64])\n",
415
+ "torch.Size([64, 58])\n",
416
+ "torch.Size([64, 1])\n",
417
+ "torch.Size([64])\n",
418
+ "torch.Size([64, 72])\n",
419
+ "torch.Size([64, 1])\n",
420
+ "torch.Size([64])\n",
421
+ "torch.Size([64, 69])\n",
422
+ "torch.Size([64, 1])\n",
423
+ "torch.Size([64])\n",
424
+ "torch.Size([64, 50])\n",
425
+ "torch.Size([64, 1])\n",
426
+ "torch.Size([64])\n",
427
+ "torch.Size([64, 74])\n",
428
+ "torch.Size([64, 1])\n",
429
+ "torch.Size([64])\n",
430
+ "torch.Size([64, 54])\n",
431
+ "torch.Size([64, 1])\n",
432
+ "torch.Size([64])\n",
433
+ "torch.Size([64, 59])\n",
434
+ "torch.Size([64, 1])\n",
435
+ "torch.Size([64])\n",
436
+ "torch.Size([64, 63])\n",
437
+ "torch.Size([64, 1])\n",
438
+ "torch.Size([64])\n",
439
+ "torch.Size([64, 79])\n",
440
+ "torch.Size([64, 1])\n",
441
+ "torch.Size([64])\n",
442
+ "torch.Size([64, 52])\n",
443
+ "torch.Size([64, 1])\n",
444
+ "torch.Size([64])\n",
445
+ "torch.Size([64, 60])\n",
446
+ "torch.Size([64, 1])\n",
447
+ "torch.Size([64])\n",
448
+ "torch.Size([64, 58])\n",
449
+ "torch.Size([64, 1])\n",
450
+ "torch.Size([64])\n",
451
+ "torch.Size([64, 64])\n",
452
+ "torch.Size([64, 1])\n",
453
+ "torch.Size([64])\n",
454
+ "torch.Size([64, 52])\n",
455
+ "torch.Size([64, 1])\n",
456
+ "torch.Size([64])\n",
457
+ "torch.Size([64, 61])\n",
458
+ "torch.Size([64, 1])\n",
459
+ "torch.Size([64])\n",
460
+ "torch.Size([64, 68])\n",
461
+ "torch.Size([64, 1])\n",
462
+ "torch.Size([64])\n",
463
+ "torch.Size([64, 70])\n",
464
+ "torch.Size([64, 1])\n",
465
+ "torch.Size([64])\n",
466
+ "torch.Size([64, 48])\n",
467
+ "torch.Size([64, 1])\n",
468
+ "torch.Size([64])\n",
469
+ "torch.Size([64, 69])\n",
470
+ "torch.Size([64, 1])\n",
471
+ "torch.Size([64])\n",
472
+ "torch.Size([64, 52])\n",
473
+ "torch.Size([64, 1])\n",
474
+ "torch.Size([64])\n",
475
+ "torch.Size([64, 75])\n",
476
+ "torch.Size([64, 1])\n",
477
+ "torch.Size([64])\n",
478
+ "torch.Size([64, 67])\n",
479
+ "torch.Size([64, 1])\n",
480
+ "torch.Size([64])\n",
481
+ "torch.Size([64, 57])\n",
482
+ "torch.Size([64, 1])\n",
483
+ "torch.Size([64])\n",
484
+ "torch.Size([64, 88])\n",
485
+ "torch.Size([64, 1])\n",
486
+ "torch.Size([64])\n",
487
+ "torch.Size([64, 64])\n",
488
+ "torch.Size([64, 1])\n",
489
+ "torch.Size([64])\n",
490
+ "torch.Size([64, 63])\n",
491
+ "torch.Size([64, 1])\n",
492
+ "torch.Size([64])\n",
493
+ "torch.Size([64, 64])\n",
494
+ "torch.Size([64, 1])\n",
495
+ "torch.Size([64])\n",
496
+ "torch.Size([64, 56])\n",
497
+ "torch.Size([64, 1])\n",
498
+ "torch.Size([64])\n",
499
+ "torch.Size([64, 52])\n",
500
+ "torch.Size([64, 1])\n",
501
+ "torch.Size([64])\n",
502
+ "torch.Size([64, 71])\n",
503
+ "torch.Size([64, 1])\n",
504
+ "torch.Size([64])\n",
505
+ "torch.Size([64, 57])\n",
506
+ "torch.Size([64, 1])\n",
507
+ "torch.Size([64])\n",
508
+ "torch.Size([64, 74])\n",
509
+ "torch.Size([64, 1])\n",
510
+ "torch.Size([64])\n",
511
+ "torch.Size([64, 62])\n",
512
+ "torch.Size([64, 1])\n",
513
+ "torch.Size([64])\n",
514
+ "torch.Size([64, 63])\n",
515
+ "torch.Size([64, 1])\n",
516
+ "torch.Size([64])\n",
517
+ "torch.Size([64, 76])\n",
518
+ "torch.Size([64, 1])\n",
519
+ "torch.Size([64])\n",
520
+ "torch.Size([64, 60])\n",
521
+ "torch.Size([64, 1])\n",
522
+ "torch.Size([64])\n",
523
+ "torch.Size([64, 62])\n",
524
+ "torch.Size([64, 1])\n",
525
+ "torch.Size([64])\n",
526
+ "torch.Size([64, 55])\n",
527
+ "torch.Size([64, 1])\n",
528
+ "torch.Size([64])\n",
529
+ "torch.Size([64, 65])\n",
530
+ "torch.Size([64, 1])\n",
531
+ "torch.Size([64])\n",
532
+ "torch.Size([64, 62])\n",
533
+ "torch.Size([64, 1])\n",
534
+ "torch.Size([64])\n",
535
+ "torch.Size([64, 57])\n",
536
+ "torch.Size([64, 1])\n",
537
+ "torch.Size([64])\n",
538
+ "torch.Size([64, 58])\n",
539
+ "torch.Size([64, 1])\n",
540
+ "torch.Size([64])\n",
541
+ "torch.Size([64, 65])\n",
542
+ "torch.Size([64, 1])\n",
543
+ "torch.Size([64])\n",
544
+ "torch.Size([64, 74])\n",
545
+ "torch.Size([64, 1])\n",
546
+ "torch.Size([64])\n",
547
+ "torch.Size([64, 56])\n",
548
+ "torch.Size([64, 1])\n",
549
+ "torch.Size([64])\n",
550
+ "torch.Size([64, 77])\n",
551
+ "torch.Size([64, 1])\n",
552
+ "torch.Size([64])\n",
553
+ "torch.Size([64, 50])\n",
554
+ "torch.Size([64, 1])\n",
555
+ "torch.Size([64])\n",
556
+ "torch.Size([64, 63])\n",
557
+ "torch.Size([64, 1])\n",
558
+ "torch.Size([64])\n",
559
+ "torch.Size([64, 72])\n",
560
+ "torch.Size([64, 1])\n",
561
+ "torch.Size([64])\n",
562
+ "torch.Size([64, 60])\n",
563
+ "torch.Size([64, 1])\n",
564
+ "torch.Size([64])\n",
565
+ "torch.Size([64, 59])\n",
566
+ "torch.Size([64, 1])\n",
567
+ "torch.Size([64])\n",
568
+ "torch.Size([64, 73])\n",
569
+ "torch.Size([64, 1])\n",
570
+ "torch.Size([64])\n",
571
+ "torch.Size([64, 54])\n",
572
+ "torch.Size([64, 1])\n",
573
+ "torch.Size([64])\n",
574
+ "torch.Size([64, 65])\n",
575
+ "torch.Size([64, 1])\n",
576
+ "torch.Size([64])\n",
577
+ "torch.Size([64, 51])\n",
578
+ "torch.Size([64, 1])\n",
579
+ "torch.Size([64])\n",
580
+ "torch.Size([64, 50])\n",
581
+ "torch.Size([64, 1])\n",
582
+ "torch.Size([64])\n",
583
+ "torch.Size([64, 54])\n",
584
+ "torch.Size([64, 1])\n",
585
+ "torch.Size([64])\n",
586
+ "torch.Size([64, 67])\n",
587
+ "torch.Size([64, 1])\n",
588
+ "torch.Size([64])\n",
589
+ "torch.Size([64, 60])\n",
590
+ "torch.Size([64, 1])\n",
591
+ "torch.Size([64])\n",
592
+ "torch.Size([64, 63])\n",
593
+ "torch.Size([64, 1])\n",
594
+ "torch.Size([64])\n",
595
+ "torch.Size([64, 77])\n",
596
+ "torch.Size([64, 1])\n",
597
+ "torch.Size([64])\n",
598
+ "torch.Size([64, 62])\n",
599
+ "torch.Size([64, 1])\n",
600
+ "torch.Size([64])\n",
601
+ "torch.Size([64, 70])\n",
602
+ "torch.Size([64, 1])\n",
603
+ "torch.Size([64])\n",
604
+ "torch.Size([64, 79])\n",
605
+ "torch.Size([64, 1])\n",
606
+ "torch.Size([64])\n",
607
+ "torch.Size([64, 67])\n",
608
+ "torch.Size([64, 1])\n",
609
+ "torch.Size([64])\n",
610
+ "torch.Size([64, 57])\n",
611
+ "torch.Size([64, 1])\n",
612
+ "torch.Size([64])\n",
613
+ "torch.Size([64, 54])\n",
614
+ "torch.Size([64, 1])\n",
615
+ "torch.Size([64])\n",
616
+ "torch.Size([64, 77])\n",
617
+ "torch.Size([64, 1])\n",
618
+ "torch.Size([64])\n",
619
+ "torch.Size([64, 87])\n",
620
+ "torch.Size([64, 1])\n",
621
+ "torch.Size([64])\n",
622
+ "torch.Size([64, 56])\n",
623
+ "torch.Size([64, 1])\n",
624
+ "torch.Size([64])\n",
625
+ "torch.Size([64, 62])\n",
626
+ "torch.Size([64, 1])\n",
627
+ "torch.Size([64])\n",
628
+ "torch.Size([64, 47])\n",
629
+ "torch.Size([64, 1])\n",
630
+ "torch.Size([64])\n",
631
+ "torch.Size([64, 58])\n",
632
+ "torch.Size([64, 1])\n",
633
+ "torch.Size([64])\n",
634
+ "torch.Size([64, 51])\n",
635
+ "torch.Size([64, 1])\n",
636
+ "torch.Size([64])\n",
637
+ "torch.Size([64, 60])\n",
638
+ "torch.Size([64, 1])\n",
639
+ "torch.Size([64])\n",
640
+ "torch.Size([64, 53])\n",
641
+ "torch.Size([64, 1])\n",
642
+ "torch.Size([64])\n",
643
+ "torch.Size([64, 54])\n",
644
+ "torch.Size([64, 1])\n",
645
+ "torch.Size([64])\n",
646
+ "torch.Size([64, 47])\n",
647
+ "torch.Size([64, 1])\n",
648
+ "torch.Size([64])\n",
649
+ "torch.Size([64, 55])\n",
650
+ "torch.Size([64, 1])\n",
651
+ "torch.Size([64])\n",
652
+ "torch.Size([64, 55])\n",
653
+ "torch.Size([64, 1])\n",
654
+ "torch.Size([64])\n",
655
+ "torch.Size([64, 63])\n",
656
+ "torch.Size([64, 1])\n",
657
+ "torch.Size([64])\n",
658
+ "torch.Size([64, 58])\n",
659
+ "torch.Size([64, 1])\n",
660
+ "torch.Size([64])\n",
661
+ "torch.Size([64, 60])\n",
662
+ "torch.Size([64, 1])\n",
663
+ "torch.Size([64])\n",
664
+ "torch.Size([64, 55])\n",
665
+ "torch.Size([64, 1])\n",
666
+ "torch.Size([64])\n",
667
+ "torch.Size([64, 79])\n",
668
+ "torch.Size([64, 1])\n",
669
+ "torch.Size([64])\n",
670
+ "torch.Size([64, 53])\n",
671
+ "torch.Size([64, 1])\n",
672
+ "torch.Size([64])\n",
673
+ "torch.Size([64, 68])\n",
674
+ "torch.Size([64, 1])\n",
675
+ "torch.Size([64])\n",
676
+ "torch.Size([64, 56])\n",
677
+ "torch.Size([64, 1])\n",
678
+ "torch.Size([64])\n",
679
+ "torch.Size([64, 53])\n",
680
+ "torch.Size([64, 1])\n",
681
+ "torch.Size([64])\n",
682
+ "torch.Size([64, 88])\n",
683
+ "torch.Size([64, 1])\n",
684
+ "torch.Size([64])\n",
685
+ "torch.Size([64, 50])\n",
686
+ "torch.Size([64, 1])\n",
687
+ "torch.Size([64])\n",
688
+ "torch.Size([64, 62])\n",
689
+ "torch.Size([64, 1])\n",
690
+ "torch.Size([64])\n",
691
+ "torch.Size([64, 67])\n",
692
+ "torch.Size([64, 1])\n",
693
+ "torch.Size([64])\n",
694
+ "torch.Size([64, 79])\n",
695
+ "torch.Size([64, 1])\n",
696
+ "torch.Size([64])\n",
697
+ "torch.Size([64, 80])\n",
698
+ "torch.Size([64, 1])\n",
699
+ "torch.Size([64])\n",
700
+ "torch.Size([64, 69])\n",
701
+ "torch.Size([64, 1])\n",
702
+ "torch.Size([64])\n",
703
+ "torch.Size([64, 67])\n",
704
+ "torch.Size([64, 1])\n",
705
+ "torch.Size([64])\n",
706
+ "torch.Size([64, 72])\n",
707
+ "torch.Size([64, 1])\n",
708
+ "torch.Size([64])\n",
709
+ "torch.Size([64, 60])\n",
710
+ "torch.Size([64, 1])\n",
711
+ "torch.Size([64])\n",
712
+ "torch.Size([64, 57])\n",
713
+ "torch.Size([64, 1])\n",
714
+ "torch.Size([64])\n",
715
+ "torch.Size([64, 55])\n",
716
+ "torch.Size([64, 1])\n",
717
+ "torch.Size([64])\n",
718
+ "torch.Size([64, 116])\n",
719
+ "torch.Size([64, 1])\n",
720
+ "torch.Size([64])\n",
721
+ "torch.Size([64, 54])\n",
722
+ "torch.Size([64, 1])\n",
723
+ "torch.Size([64])\n",
724
+ "torch.Size([64, 50])\n",
725
+ "torch.Size([64, 1])\n",
726
+ "torch.Size([64])\n",
727
+ "torch.Size([64, 64])\n",
728
+ "torch.Size([64, 1])\n",
729
+ "torch.Size([64])\n",
730
+ "torch.Size([64, 51])\n",
731
+ "torch.Size([64, 1])\n",
732
+ "torch.Size([64])\n",
733
+ "torch.Size([64, 70])\n",
734
+ "torch.Size([64, 1])\n",
735
+ "torch.Size([64])\n",
736
+ "torch.Size([64, 72])\n",
737
+ "torch.Size([64, 1])\n",
738
+ "torch.Size([64])\n",
739
+ "torch.Size([64, 59])\n",
740
+ "torch.Size([64, 1])\n",
741
+ "torch.Size([64])\n",
742
+ "torch.Size([64, 61])\n",
743
+ "torch.Size([64, 1])\n",
744
+ "torch.Size([64])\n",
745
+ "torch.Size([64, 54])\n",
746
+ "torch.Size([64, 1])\n",
747
+ "torch.Size([64])\n",
748
+ "torch.Size([64, 54])\n",
749
+ "torch.Size([64, 1])\n",
750
+ "torch.Size([64])\n",
751
+ "torch.Size([64, 63])\n",
752
+ "torch.Size([64, 1])\n",
753
+ "torch.Size([64])\n",
754
+ "torch.Size([64, 57])\n",
755
+ "torch.Size([64, 1])\n",
756
+ "torch.Size([64])\n",
757
+ "torch.Size([64, 60])\n",
758
+ "torch.Size([64, 1])\n",
759
+ "torch.Size([64])\n",
760
+ "torch.Size([64, 77])\n",
761
+ "torch.Size([64, 1])\n",
762
+ "torch.Size([64])\n",
763
+ "torch.Size([64, 67])\n",
764
+ "torch.Size([64, 1])\n",
765
+ "torch.Size([64])\n",
766
+ "torch.Size([64, 54])\n",
767
+ "torch.Size([64, 1])\n",
768
+ "torch.Size([64])\n",
769
+ "torch.Size([64, 87])\n",
770
+ "torch.Size([64, 1])\n",
771
+ "torch.Size([64])\n",
772
+ "torch.Size([64, 58])\n",
773
+ "torch.Size([64, 1])\n",
774
+ "torch.Size([64])\n",
775
+ "torch.Size([64, 59])\n",
776
+ "torch.Size([64, 1])\n",
777
+ "torch.Size([64])\n",
778
+ "torch.Size([64, 67])\n",
779
+ "torch.Size([64, 1])\n",
780
+ "torch.Size([64])\n",
781
+ "torch.Size([64, 64])\n",
782
+ "torch.Size([64, 1])\n",
783
+ "torch.Size([64])\n",
784
+ "torch.Size([64, 62])\n",
785
+ "torch.Size([64, 1])\n",
786
+ "torch.Size([64])\n",
787
+ "torch.Size([64, 55])\n",
788
+ "torch.Size([64, 1])\n",
789
+ "torch.Size([64])\n",
790
+ "torch.Size([64, 65])\n",
791
+ "torch.Size([64, 1])\n",
792
+ "torch.Size([64])\n",
793
+ "torch.Size([64, 70])\n",
794
+ "torch.Size([64, 1])\n",
795
+ "torch.Size([64])\n",
796
+ "torch.Size([64, 63])\n",
797
+ "torch.Size([64, 1])\n",
798
+ "torch.Size([64])\n",
799
+ "torch.Size([64, 59])\n",
800
+ "torch.Size([64, 1])\n",
801
+ "torch.Size([64])\n",
802
+ "torch.Size([64, 59])\n",
803
+ "torch.Size([64, 1])\n",
804
+ "torch.Size([64])\n",
805
+ "torch.Size([64, 59])\n",
806
+ "torch.Size([64, 1])\n",
807
+ "torch.Size([64])\n",
808
+ "torch.Size([64, 65])\n",
809
+ "torch.Size([64, 1])\n",
810
+ "torch.Size([64])\n",
811
+ "torch.Size([64, 71])\n",
812
+ "torch.Size([64, 1])\n",
813
+ "torch.Size([64])\n",
814
+ "torch.Size([64, 61])\n",
815
+ "torch.Size([64, 1])\n",
816
+ "torch.Size([64])\n",
817
+ "torch.Size([64, 56])\n",
818
+ "torch.Size([64, 1])\n",
819
+ "torch.Size([64])\n",
820
+ "torch.Size([64, 50])\n",
821
+ "torch.Size([64, 1])\n",
822
+ "torch.Size([64])\n",
823
+ "torch.Size([64, 61])\n",
824
+ "torch.Size([64, 1])\n",
825
+ "torch.Size([64])\n",
826
+ "torch.Size([64, 74])\n",
827
+ "torch.Size([64, 1])\n",
828
+ "torch.Size([64])\n",
829
+ "torch.Size([64, 59])\n",
830
+ "torch.Size([64, 1])\n",
831
+ "torch.Size([64])\n",
832
+ "torch.Size([64, 57])\n",
833
+ "torch.Size([64, 1])\n",
834
+ "torch.Size([64])\n",
835
+ "torch.Size([64, 52])\n",
836
+ "torch.Size([64, 1])\n",
837
+ "torch.Size([64])\n",
838
+ "torch.Size([64, 49])\n",
839
+ "torch.Size([64, 1])\n",
840
+ "torch.Size([64])\n",
841
+ "torch.Size([64, 57])\n",
842
+ "torch.Size([64, 1])\n",
843
+ "torch.Size([64])\n",
844
+ "torch.Size([64, 61])\n",
845
+ "torch.Size([64, 1])\n",
846
+ "torch.Size([64])\n",
847
+ "torch.Size([64, 52])\n",
848
+ "torch.Size([64, 1])\n",
849
+ "torch.Size([64])\n",
850
+ "torch.Size([64, 58])\n",
851
+ "torch.Size([64, 1])\n",
852
+ "torch.Size([64])\n",
853
+ "torch.Size([64, 56])\n",
854
+ "torch.Size([64, 1])\n",
855
+ "torch.Size([64])\n",
856
+ "torch.Size([64, 60])\n",
857
+ "torch.Size([64, 1])\n",
858
+ "torch.Size([64])\n",
859
+ "torch.Size([64, 54])\n",
860
+ "torch.Size([64, 1])\n",
861
+ "torch.Size([64])\n",
862
+ "torch.Size([64, 63])\n",
863
+ "torch.Size([64, 1])\n",
864
+ "torch.Size([64])\n",
865
+ "torch.Size([64, 56])\n",
866
+ "torch.Size([64, 1])\n",
867
+ "torch.Size([64])\n",
868
+ "torch.Size([64, 57])\n",
869
+ "torch.Size([64, 1])\n",
870
+ "torch.Size([64])\n",
871
+ "torch.Size([64, 61])\n",
872
+ "torch.Size([64, 1])\n",
873
+ "torch.Size([64])\n",
874
+ "torch.Size([64, 73])\n",
875
+ "torch.Size([64, 1])\n",
876
+ "torch.Size([64])\n",
877
+ "torch.Size([64, 65])\n",
878
+ "torch.Size([64, 1])\n",
879
+ "torch.Size([64])\n",
880
+ "torch.Size([64, 51])\n",
881
+ "torch.Size([64, 1])\n",
882
+ "torch.Size([64])\n",
883
+ "torch.Size([64, 69])\n",
884
+ "torch.Size([64, 1])\n",
885
+ "torch.Size([64])\n",
886
+ "torch.Size([64, 79])\n",
887
+ "torch.Size([64, 1])\n",
888
+ "torch.Size([64])\n",
889
+ "torch.Size([64, 80])\n",
890
+ "torch.Size([64, 1])\n",
891
+ "torch.Size([64])\n",
892
+ "torch.Size([64, 79])\n",
893
+ "torch.Size([64, 1])\n",
894
+ "torch.Size([64])\n",
895
+ "torch.Size([64, 63])\n",
896
+ "torch.Size([64, 1])\n",
897
+ "torch.Size([64])\n",
898
+ "torch.Size([64, 59])\n",
899
+ "torch.Size([64, 1])\n",
900
+ "torch.Size([64])\n",
901
+ "torch.Size([64, 51])\n",
902
+ "torch.Size([64, 1])\n",
903
+ "torch.Size([64])\n",
904
+ "torch.Size([64, 55])\n",
905
+ "torch.Size([64, 1])\n",
906
+ "torch.Size([64])\n",
907
+ "torch.Size([64, 55])\n",
908
+ "torch.Size([64, 1])\n",
909
+ "torch.Size([64])\n",
910
+ "torch.Size([64, 50])\n",
911
+ "torch.Size([64, 1])\n",
912
+ "torch.Size([64])\n",
913
+ "torch.Size([64, 75])\n",
914
+ "torch.Size([64, 1])\n",
915
+ "torch.Size([64])\n",
916
+ "torch.Size([64, 58])\n",
917
+ "torch.Size([64, 1])\n",
918
+ "torch.Size([64])\n",
919
+ "torch.Size([64, 54])\n",
920
+ "torch.Size([64, 1])\n",
921
+ "torch.Size([64])\n",
922
+ "torch.Size([64, 54])\n",
923
+ "torch.Size([64, 1])\n",
924
+ "torch.Size([64])\n",
925
+ "torch.Size([64, 57])\n",
926
+ "torch.Size([64, 1])\n",
927
+ "torch.Size([64])\n",
928
+ "torch.Size([64, 77])\n",
929
+ "torch.Size([64, 1])\n",
930
+ "torch.Size([64])\n",
931
+ "torch.Size([64, 55])\n",
932
+ "torch.Size([64, 1])\n",
933
+ "torch.Size([64])\n",
934
+ "torch.Size([64, 58])\n",
935
+ "torch.Size([64, 1])\n",
936
+ "torch.Size([64])\n",
937
+ "torch.Size([64, 56])\n",
938
+ "torch.Size([64, 1])\n",
939
+ "torch.Size([64])\n",
940
+ "torch.Size([64, 70])\n",
941
+ "torch.Size([64, 1])\n",
942
+ "torch.Size([64])\n",
943
+ "torch.Size([64, 56])\n",
944
+ "torch.Size([64, 1])\n",
945
+ "torch.Size([64])\n",
946
+ "torch.Size([64, 55])\n",
947
+ "torch.Size([64, 1])\n",
948
+ "torch.Size([64])\n",
949
+ "torch.Size([64, 51])\n",
950
+ "torch.Size([64, 1])\n",
951
+ "torch.Size([64])\n",
952
+ "torch.Size([64, 69])\n",
953
+ "torch.Size([64, 1])\n",
954
+ "torch.Size([64])\n",
955
+ "torch.Size([64, 64])\n",
956
+ "torch.Size([64, 1])\n",
957
+ "torch.Size([64])\n",
958
+ "torch.Size([64, 64])\n",
959
+ "torch.Size([64, 1])\n",
960
+ "torch.Size([64])\n",
961
+ "torch.Size([64, 71])\n",
962
+ "torch.Size([64, 1])\n",
963
+ "torch.Size([64])\n",
964
+ "torch.Size([64, 67])\n",
965
+ "torch.Size([64, 1])\n",
966
+ "torch.Size([64])\n",
967
+ "torch.Size([64, 54])\n",
968
+ "torch.Size([64, 1])\n",
969
+ "torch.Size([64])\n",
970
+ "torch.Size([64, 63])\n",
971
+ "torch.Size([64, 1])\n",
972
+ "torch.Size([64])\n",
973
+ "torch.Size([64, 67])\n",
974
+ "torch.Size([64, 1])\n",
975
+ "torch.Size([64])\n",
976
+ "torch.Size([64, 54])\n",
977
+ "torch.Size([64, 1])\n",
978
+ "torch.Size([64])\n",
979
+ "torch.Size([64, 67])\n",
980
+ "torch.Size([64, 1])\n",
981
+ "torch.Size([64])\n",
982
+ "torch.Size([64, 50])\n",
983
+ "torch.Size([64, 1])\n",
984
+ "torch.Size([64])\n",
985
+ "torch.Size([64, 62])\n",
986
+ "torch.Size([64, 1])\n",
987
+ "torch.Size([64])\n",
988
+ "torch.Size([64, 57])\n",
989
+ "torch.Size([64, 1])\n",
990
+ "torch.Size([64])\n",
991
+ "torch.Size([64, 57])\n",
992
+ "torch.Size([64, 1])\n",
993
+ "torch.Size([64])\n",
994
+ "torch.Size([64, 50])\n",
995
+ "torch.Size([64, 1])\n",
996
+ "torch.Size([64])\n",
997
+ "torch.Size([64, 59])\n",
998
+ "torch.Size([64, 1])\n",
999
+ "torch.Size([64])\n",
1000
+ "torch.Size([64, 58])\n",
1001
+ "torch.Size([64, 1])\n",
1002
+ "torch.Size([64])\n",
1003
+ "torch.Size([64, 63])\n",
1004
+ "torch.Size([64, 1])\n",
1005
+ "torch.Size([64])\n",
1006
+ "torch.Size([64, 59])\n",
1007
+ "torch.Size([64, 1])\n",
1008
+ "torch.Size([64])\n",
1009
+ "torch.Size([64, 49])\n",
1010
+ "torch.Size([64, 1])\n",
1011
+ "torch.Size([64])\n",
1012
+ "torch.Size([64, 53])\n",
1013
+ "torch.Size([64, 1])\n",
1014
+ "torch.Size([64])\n",
1015
+ "torch.Size([64, 50])\n",
1016
+ "torch.Size([64, 1])\n",
1017
+ "torch.Size([64])\n",
1018
+ "torch.Size([64, 49])\n",
1019
+ "torch.Size([64, 1])\n",
1020
+ "torch.Size([64])\n",
1021
+ "torch.Size([64, 72])\n",
1022
+ "torch.Size([64, 1])\n",
1023
+ "torch.Size([64])\n",
1024
+ "torch.Size([64, 74])\n",
1025
+ "torch.Size([64, 1])\n",
1026
+ "torch.Size([64])\n",
1027
+ "torch.Size([64, 67])\n",
1028
+ "torch.Size([64, 1])\n",
1029
+ "torch.Size([64])\n",
1030
+ "torch.Size([64, 50])\n",
1031
+ "torch.Size([64, 1])\n",
1032
+ "torch.Size([64])\n",
1033
+ "torch.Size([64, 54])\n",
1034
+ "torch.Size([64, 1])\n",
1035
+ "torch.Size([64])\n",
1036
+ "torch.Size([64, 52])\n",
1037
+ "torch.Size([64, 1])\n",
1038
+ "torch.Size([64])\n",
1039
+ "torch.Size([64, 74])\n",
1040
+ "torch.Size([64, 1])\n",
1041
+ "torch.Size([64])\n",
1042
+ "torch.Size([64, 63])\n",
1043
+ "torch.Size([64, 1])\n",
1044
+ "torch.Size([64])\n",
1045
+ "torch.Size([64, 51])\n",
1046
+ "torch.Size([64, 1])\n",
1047
+ "torch.Size([64])\n",
1048
+ "torch.Size([64, 63])\n",
1049
+ "torch.Size([64, 1])\n",
1050
+ "torch.Size([64])\n",
1051
+ "torch.Size([64, 56])\n",
1052
+ "torch.Size([64, 1])\n",
1053
+ "torch.Size([64])\n",
1054
+ "torch.Size([64, 65])\n",
1055
+ "torch.Size([64, 1])\n",
1056
+ "torch.Size([64])\n",
1057
+ "torch.Size([64, 58])\n",
1058
+ "torch.Size([64, 1])\n",
1059
+ "torch.Size([64])\n",
1060
+ "torch.Size([64, 54])\n",
1061
+ "torch.Size([64, 1])\n",
1062
+ "torch.Size([64])\n",
1063
+ "torch.Size([64, 67])\n",
1064
+ "torch.Size([64, 1])\n",
1065
+ "torch.Size([64])\n",
1066
+ "torch.Size([64, 56])\n",
1067
+ "torch.Size([64, 1])\n",
1068
+ "torch.Size([64])\n",
1069
+ "torch.Size([64, 65])\n",
1070
+ "torch.Size([64, 1])\n",
1071
+ "torch.Size([64])\n",
1072
+ "torch.Size([64, 55])\n",
1073
+ "torch.Size([64, 1])\n",
1074
+ "torch.Size([64])\n",
1075
+ "torch.Size([64, 55])\n",
1076
+ "torch.Size([64, 1])\n",
1077
+ "torch.Size([64])\n",
1078
+ "torch.Size([64, 73])\n",
1079
+ "torch.Size([64, 1])\n",
1080
+ "torch.Size([64])\n",
1081
+ "torch.Size([64, 75])\n",
1082
+ "torch.Size([64, 1])\n",
1083
+ "torch.Size([64])\n",
1084
+ "torch.Size([64, 59])\n",
1085
+ "torch.Size([64, 1])\n",
1086
+ "torch.Size([64])\n",
1087
+ "torch.Size([64, 58])\n",
1088
+ "torch.Size([64, 1])\n",
1089
+ "torch.Size([64])\n",
1090
+ "torch.Size([41, 48])\n",
1091
+ "torch.Size([41, 1])\n",
1092
+ "torch.Size([41])\n",
1093
+ "torch.Size([512, 75])\n",
1094
+ "torch.Size([512, 1])\n",
1095
+ "torch.Size([512])\n",
1096
+ "torch.Size([64, 73])\n",
1097
+ "torch.Size([64, 1])\n",
1098
+ "torch.Size([64])\n",
1099
+ "torch.Size([64, 60])\n",
1100
+ "torch.Size([64, 1])\n",
1101
+ "torch.Size([64])\n",
1102
+ "torch.Size([64, 71])\n",
1103
+ "torch.Size([64, 1])\n",
1104
+ "torch.Size([64])\n",
1105
+ "torch.Size([64, 55])\n",
1106
+ "torch.Size([64, 1])\n",
1107
+ "torch.Size([64])\n",
1108
+ "torch.Size([64, 59])\n",
1109
+ "torch.Size([64, 1])\n",
1110
+ "torch.Size([64])\n",
1111
+ "torch.Size([64, 74])\n",
1112
+ "torch.Size([64, 1])\n",
1113
+ "torch.Size([64])\n",
1114
+ "torch.Size([64, 54])\n",
1115
+ "torch.Size([64, 1])\n",
1116
+ "torch.Size([64])\n",
1117
+ "torch.Size([64, 51])\n",
1118
+ "torch.Size([64, 1])\n",
1119
+ "torch.Size([64])\n",
1120
+ "torch.Size([64, 73])\n",
1121
+ "torch.Size([64, 1])\n",
1122
+ "torch.Size([64])\n",
1123
+ "torch.Size([64, 76])\n",
1124
+ "torch.Size([64, 1])\n",
1125
+ "torch.Size([64])\n",
1126
+ "torch.Size([64, 53])\n",
1127
+ "torch.Size([64, 1])\n",
1128
+ "torch.Size([64])\n",
1129
+ "torch.Size([64, 51])\n",
1130
+ "torch.Size([64, 1])\n",
1131
+ "torch.Size([64])\n",
1132
+ "torch.Size([64, 60])\n",
1133
+ "torch.Size([64, 1])\n",
1134
+ "torch.Size([64])\n",
1135
+ "torch.Size([64, 58])\n",
1136
+ "torch.Size([64, 1])\n",
1137
+ "torch.Size([64])\n",
1138
+ "torch.Size([64, 74])\n",
1139
+ "torch.Size([64, 1])\n",
1140
+ "torch.Size([64])\n",
1141
+ "torch.Size([64, 69])\n",
1142
+ "torch.Size([64, 1])\n",
1143
+ "torch.Size([64])\n",
1144
+ "torch.Size([64, 52])\n",
1145
+ "torch.Size([64, 1])\n",
1146
+ "torch.Size([64])\n",
1147
+ "torch.Size([64, 72])\n",
1148
+ "torch.Size([64, 1])\n",
1149
+ "torch.Size([64])\n",
1150
+ "torch.Size([64, 62])\n",
1151
+ "torch.Size([64, 1])\n",
1152
+ "torch.Size([64])\n",
1153
+ "torch.Size([64, 54])\n",
1154
+ "torch.Size([64, 1])\n",
1155
+ "torch.Size([64])\n",
1156
+ "torch.Size([64, 52])\n",
1157
+ "torch.Size([64, 1])\n",
1158
+ "torch.Size([64])\n",
1159
+ "torch.Size([64, 67])\n",
1160
+ "torch.Size([64, 1])\n",
1161
+ "torch.Size([64])\n",
1162
+ "torch.Size([64, 54])\n",
1163
+ "torch.Size([64, 1])\n",
1164
+ "torch.Size([64])\n",
1165
+ "torch.Size([64, 53])\n",
1166
+ "torch.Size([64, 1])\n",
1167
+ "torch.Size([64])\n",
1168
+ "torch.Size([64, 58])\n",
1169
+ "torch.Size([64, 1])\n",
1170
+ "torch.Size([64])\n",
1171
+ "torch.Size([64, 58])\n",
1172
+ "torch.Size([64, 1])\n",
1173
+ "torch.Size([64])\n",
1174
+ "torch.Size([64, 56])\n",
1175
+ "torch.Size([64, 1])\n",
1176
+ "torch.Size([64])\n",
1177
+ "torch.Size([64, 67])\n",
1178
+ "torch.Size([64, 1])\n",
1179
+ "torch.Size([64])\n",
1180
+ "torch.Size([64, 55])\n",
1181
+ "torch.Size([64, 1])\n",
1182
+ "torch.Size([64])\n",
1183
+ "torch.Size([64, 71])\n",
1184
+ "torch.Size([64, 1])\n",
1185
+ "torch.Size([64])\n",
1186
+ "torch.Size([64, 71])\n",
1187
+ "torch.Size([64, 1])\n",
1188
+ "torch.Size([64])\n",
1189
+ "torch.Size([64, 68])\n",
1190
+ "torch.Size([64, 1])\n",
1191
+ "torch.Size([64])\n",
1192
+ "torch.Size([64, 63])\n",
1193
+ "torch.Size([64, 1])\n",
1194
+ "torch.Size([64])\n",
1195
+ "torch.Size([64, 49])\n",
1196
+ "torch.Size([64, 1])\n",
1197
+ "torch.Size([64])\n",
1198
+ "torch.Size([64, 52])\n",
1199
+ "torch.Size([64, 1])\n",
1200
+ "torch.Size([64])\n",
1201
+ "torch.Size([64, 54])\n",
1202
+ "torch.Size([64, 1])\n",
1203
+ "torch.Size([64])\n",
1204
+ "torch.Size([64, 72])\n",
1205
+ "torch.Size([64, 1])\n",
1206
+ "torch.Size([64])\n",
1207
+ "torch.Size([64, 77])\n",
1208
+ "torch.Size([64, 1])\n",
1209
+ "torch.Size([64])\n",
1210
+ "torch.Size([64, 59])\n",
1211
+ "torch.Size([64, 1])\n",
1212
+ "torch.Size([64])\n",
1213
+ "torch.Size([64, 58])\n",
1214
+ "torch.Size([64, 1])\n",
1215
+ "torch.Size([64])\n",
1216
+ "torch.Size([64, 72])\n",
1217
+ "torch.Size([64, 1])\n",
1218
+ "torch.Size([64])\n",
1219
+ "torch.Size([64, 65])\n",
1220
+ "torch.Size([64, 1])\n",
1221
+ "torch.Size([64])\n",
1222
+ "torch.Size([64, 79])\n",
1223
+ "torch.Size([64, 1])\n",
1224
+ "torch.Size([64])\n",
1225
+ "torch.Size([64, 65])\n",
1226
+ "torch.Size([64, 1])\n",
1227
+ "torch.Size([64])\n",
1228
+ "torch.Size([64, 59])\n",
1229
+ "torch.Size([64, 1])\n",
1230
+ "torch.Size([64])\n",
1231
+ "torch.Size([64, 79])\n",
1232
+ "torch.Size([64, 1])\n",
1233
+ "torch.Size([64])\n",
1234
+ "torch.Size([64, 54])\n",
1235
+ "torch.Size([64, 1])\n",
1236
+ "torch.Size([64])\n",
1237
+ "torch.Size([64, 50])\n",
1238
+ "torch.Size([64, 1])\n",
1239
+ "torch.Size([64])\n",
1240
+ "torch.Size([64, 55])\n",
1241
+ "torch.Size([64, 1])\n",
1242
+ "torch.Size([64])\n",
1243
+ "torch.Size([64, 65])\n",
1244
+ "torch.Size([64, 1])\n",
1245
+ "torch.Size([64])\n",
1246
+ "torch.Size([64, 60])\n",
1247
+ "torch.Size([64, 1])\n",
1248
+ "torch.Size([64])\n",
1249
+ "torch.Size([64, 59])\n",
1250
+ "torch.Size([64, 1])\n",
1251
+ "torch.Size([64])\n",
1252
+ "torch.Size([64, 60])\n",
1253
+ "torch.Size([64, 1])\n",
1254
+ "torch.Size([64])\n",
1255
+ "torch.Size([64, 59])\n",
1256
+ "torch.Size([64, 1])\n",
1257
+ "torch.Size([64])\n",
1258
+ "torch.Size([64, 54])\n",
1259
+ "torch.Size([64, 1])\n",
1260
+ "torch.Size([64])\n",
1261
+ "torch.Size([64, 50])\n",
1262
+ "torch.Size([64, 1])\n",
1263
+ "torch.Size([64])\n",
1264
+ "torch.Size([64, 69])\n",
1265
+ "torch.Size([64, 1])\n",
1266
+ "torch.Size([64])\n",
1267
+ "torch.Size([64, 55])\n",
1268
+ "torch.Size([64, 1])\n",
1269
+ "torch.Size([64])\n",
1270
+ "torch.Size([64, 57])\n",
1271
+ "torch.Size([64, 1])\n",
1272
+ "torch.Size([64])\n",
1273
+ "torch.Size([64, 63])\n",
1274
+ "torch.Size([64, 1])\n",
1275
+ "torch.Size([64])\n",
1276
+ "torch.Size([64, 72])\n",
1277
+ "torch.Size([64, 1])\n",
1278
+ "torch.Size([64])\n",
1279
+ "torch.Size([64, 63])\n",
1280
+ "torch.Size([64, 1])\n",
1281
+ "torch.Size([64])\n",
1282
+ "torch.Size([64, 65])\n",
1283
+ "torch.Size([64, 1])\n",
1284
+ "torch.Size([64])\n",
1285
+ "torch.Size([64, 77])\n",
1286
+ "torch.Size([64, 1])\n",
1287
+ "torch.Size([64])\n",
1288
+ "torch.Size([64, 57])\n",
1289
+ "torch.Size([64, 1])\n",
1290
+ "torch.Size([64])\n",
1291
+ "torch.Size([64, 56])\n",
1292
+ "torch.Size([64, 1])\n",
1293
+ "torch.Size([64])\n",
1294
+ "torch.Size([64, 52])\n",
1295
+ "torch.Size([64, 1])\n",
1296
+ "torch.Size([64])\n",
1297
+ "torch.Size([64, 54])\n",
1298
+ "torch.Size([64, 1])\n",
1299
+ "torch.Size([64])\n",
1300
+ "torch.Size([64, 72])\n",
1301
+ "torch.Size([64, 1])\n",
1302
+ "torch.Size([64])\n",
1303
+ "torch.Size([64, 70])\n",
1304
+ "torch.Size([64, 1])\n",
1305
+ "torch.Size([64])\n",
1306
+ "torch.Size([64, 60])\n",
1307
+ "torch.Size([64, 1])\n",
1308
+ "torch.Size([64])\n",
1309
+ "torch.Size([64, 67])\n",
1310
+ "torch.Size([64, 1])\n",
1311
+ "torch.Size([64])\n",
1312
+ "torch.Size([64, 64])\n",
1313
+ "torch.Size([64, 1])\n",
1314
+ "torch.Size([64])\n",
1315
+ "torch.Size([64, 56])\n",
1316
+ "torch.Size([64, 1])\n",
1317
+ "torch.Size([64])\n",
1318
+ "torch.Size([64, 55])\n",
1319
+ "torch.Size([64, 1])\n",
1320
+ "torch.Size([64])\n",
1321
+ "torch.Size([64, 64])\n",
1322
+ "torch.Size([64, 1])\n",
1323
+ "torch.Size([64])\n",
1324
+ "torch.Size([64, 88])\n",
1325
+ "torch.Size([64, 1])\n",
1326
+ "torch.Size([64])\n",
1327
+ "torch.Size([64, 80])\n",
1328
+ "torch.Size([64, 1])\n",
1329
+ "torch.Size([64])\n",
1330
+ "torch.Size([64, 62])\n",
1331
+ "torch.Size([64, 1])\n",
1332
+ "torch.Size([64])\n",
1333
+ "torch.Size([64, 48])\n",
1334
+ "torch.Size([64, 1])\n",
1335
+ "torch.Size([64])\n",
1336
+ "torch.Size([64, 60])\n",
1337
+ "torch.Size([64, 1])\n",
1338
+ "torch.Size([64])\n",
1339
+ "torch.Size([64, 79])\n",
1340
+ "torch.Size([64, 1])\n",
1341
+ "torch.Size([64])\n",
1342
+ "torch.Size([64, 56])\n",
1343
+ "torch.Size([64, 1])\n",
1344
+ "torch.Size([64])\n",
1345
+ "torch.Size([64, 59])\n",
1346
+ "torch.Size([64, 1])\n",
1347
+ "torch.Size([64])\n",
1348
+ "torch.Size([64, 57])\n",
1349
+ "torch.Size([64, 1])\n",
1350
+ "torch.Size([64])\n",
1351
+ "torch.Size([64, 60])\n",
1352
+ "torch.Size([64, 1])\n",
1353
+ "torch.Size([64])\n",
1354
+ "torch.Size([64, 54])\n",
1355
+ "torch.Size([64, 1])\n",
1356
+ "torch.Size([64])\n",
1357
+ "torch.Size([64, 51])\n",
1358
+ "torch.Size([64, 1])\n",
1359
+ "torch.Size([64])\n",
1360
+ "torch.Size([64, 70])\n",
1361
+ "torch.Size([64, 1])\n",
1362
+ "torch.Size([64])\n",
1363
+ "torch.Size([64, 53])\n",
1364
+ "torch.Size([64, 1])\n",
1365
+ "torch.Size([64])\n",
1366
+ "torch.Size([64, 59])\n",
1367
+ "torch.Size([64, 1])\n",
1368
+ "torch.Size([64])\n",
1369
+ "torch.Size([64, 79])\n",
1370
+ "torch.Size([64, 1])\n",
1371
+ "torch.Size([64])\n",
1372
+ "torch.Size([64, 59])\n",
1373
+ "torch.Size([64, 1])\n",
1374
+ "torch.Size([64])\n",
1375
+ "torch.Size([64, 79])\n",
1376
+ "torch.Size([64, 1])\n",
1377
+ "torch.Size([64])\n",
1378
+ "torch.Size([64, 63])\n",
1379
+ "torch.Size([64, 1])\n",
1380
+ "torch.Size([64])\n",
1381
+ "torch.Size([64, 53])\n",
1382
+ "torch.Size([64, 1])\n",
1383
+ "torch.Size([64])\n",
1384
+ "torch.Size([64, 55])\n",
1385
+ "torch.Size([64, 1])\n",
1386
+ "torch.Size([64])\n",
1387
+ "torch.Size([64, 57])\n",
1388
+ "torch.Size([64, 1])\n",
1389
+ "torch.Size([64])\n",
1390
+ "torch.Size([64, 53])\n",
1391
+ "torch.Size([64, 1])\n",
1392
+ "torch.Size([64])\n",
1393
+ "torch.Size([64, 59])\n",
1394
+ "torch.Size([64, 1])\n",
1395
+ "torch.Size([64])\n"
1396
+ ]
1397
+ },
1398
+ {
1399
+ "ename": "KeyboardInterrupt",
1400
+ "evalue": "",
1401
+ "output_type": "error",
1402
+ "traceback": [
1403
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1404
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1405
+ "Cell \u001b[0;32mIn[7], line 28\u001b[0m\n\u001b[1;32m 18\u001b[0m data_collator \u001b[38;5;241m=\u001b[39m DataCollatorWithPadding(tokenizer\u001b[38;5;241m=\u001b[39mtokenizer)\n\u001b[1;32m 19\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m 20\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 21\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 25\u001b[0m data_collator\u001b[38;5;241m=\u001b[39mdata_collator,\n\u001b[1;32m 26\u001b[0m )\n\u001b[0;32m---> 28\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 29\u001b[0m trainer\u001b[38;5;241m.\u001b[39mpush_to_hub(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfactual-consistency-regression-ja\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
1406
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1582\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1579\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1580\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 1581\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 1582\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1583\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1584\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1585\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1586\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1587\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1588\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1589\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n",
1407
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1950\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1945\u001b[0m nn\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mclip_grad_norm_(\n\u001b[1;32m 1946\u001b[0m amp\u001b[38;5;241m.\u001b[39mmaster_params(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer),\n\u001b[1;32m 1947\u001b[0m args\u001b[38;5;241m.\u001b[39mmax_grad_norm,\n\u001b[1;32m 1948\u001b[0m )\n\u001b[1;32m 1949\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1950\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclip_grad_norm_\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1951\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1952\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_grad_norm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1953\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1955\u001b[0m \u001b[38;5;66;03m# Optimizer step\u001b[39;00m\n\u001b[1;32m 1956\u001b[0m optimizer_was_run \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
1408
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py:2121\u001b[0m, in \u001b[0;36mAccelerator.clip_grad_norm_\u001b[0;34m(self, parameters, max_norm, norm_type)\u001b[0m\n\u001b[1;32m 2119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 2120\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39munscale_gradients()\n\u001b[0;32m-> 2121\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclip_grad_norm_\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparameters\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnorm_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnorm_type\u001b[49m\u001b[43m)\u001b[49m\n",
1409
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch_xla/_patched_functions.py:49\u001b[0m, in \u001b[0;36mclip_grad_norm_\u001b[0;34m(parameters, max_norm, norm_type, error_if_nonfinite, foreach)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m error_if_nonfinite \u001b[38;5;129;01mand\u001b[39;00m (total_norm\u001b[38;5;241m.\u001b[39misnan() \u001b[38;5;129;01mor\u001b[39;00m total_norm\u001b[38;5;241m.\u001b[39misinf()):\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 46\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mThe norm of order \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnorm_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for a gradient from `parameters` \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mis non-finite, so it cannot be clipped. This error can be \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdisabled with `error_if_nonfinite=False`\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 49\u001b[0m clip_coef \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m/\u001b[39m (total_norm \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1e-6\u001b[39m)\n\u001b[1;32m 50\u001b[0m clip_value \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(clip_coef \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m1\u001b[39m, clip_coef,\n\u001b[1;32m 51\u001b[0m torch\u001b[38;5;241m.\u001b[39mtensor(\u001b[38;5;241m1.\u001b[39m, device\u001b[38;5;241m=\u001b[39mdevice))\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m parameters:\n",
1410
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1411
+ ]
1412
+ }
1413
+ ],
1414
+ "source": [
1415
+ "model = ConsistentSentenceRegressor(\n",
1416
+ " freeze_bert=True)\n",
1417
+ "\n",
1418
+ "training_args = TrainingArguments(\n",
1419
+ " output_dir=\".\",\n",
1420
+ " learning_rate=1e-5,\n",
1421
+ " per_device_train_batch_size=64,\n",
1422
+ " num_train_epochs=100,\n",
1423
+ " weight_decay=0.02,\n",
1424
+ " evaluation_strategy=\"epoch\",\n",
1425
+ " eval_accumulation_steps=1,\n",
1426
+ " save_strategy=\"epoch\",\n",
1427
+ " load_best_model_at_end=True,\n",
1428
+ " push_to_hub=True,\n",
1429
+ ")\n",
1430
+ "\n",
1431
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
1432
+ "trainer = Trainer(\n",
1433
+ " model=model,\n",
1434
+ " args=training_args,\n",
1435
+ " train_dataset=tokenized_dataset[\"train\"],\n",
1436
+ " eval_dataset=tokenized_dataset[\"test\"],\n",
1437
+ " tokenizer=tokenizer,\n",
1438
+ " data_collator=data_collator,\n",
1439
+ ")\n",
1440
+ "\n",
1441
+ "trainer.train()\n",
1442
+ "trainer.push_to_hub('factual-consistency-regression-ja')"
1443
+ ]
1444
+ },
1445
+ {
1446
+ "cell_type": "code",
1447
+ "execution_count": null,
1448
+ "id": "a6eb93f7-5a38-49a2-be0d-e42267e23a0a",
1449
+ "metadata": {},
1450
+ "outputs": [],
1451
+ "source": []
1452
+ },
1453
+ {
1454
+ "cell_type": "code",
1455
+ "execution_count": null,
1456
+ "id": "3638c8d8-fc85-4caf-83a4-4fd2ad6fb95d",
1457
+ "metadata": {},
1458
+ "outputs": [],
1459
+ "source": []
1460
+ }
1461
+ ],
1462
+ "metadata": {
1463
+ "environment": {
1464
+ "kernel": "python3",
1465
+ "name": "pytorch-gpu.2-0.m112",
1466
+ "type": "gcloud",
1467
+ "uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.2-0:m112"
1468
+ },
1469
+ "kernelspec": {
1470
+ "display_name": "Python 3",
1471
+ "language": "python",
1472
+ "name": "python3"
1473
+ },
1474
+ "language_info": {
1475
+ "codemirror_mode": {
1476
+ "name": "ipython",
1477
+ "version": 3
1478
+ },
1479
+ "file_extension": ".py",
1480
+ "mimetype": "text/x-python",
1481
+ "name": "python",
1482
+ "nbconvert_exporter": "python",
1483
+ "pygments_lexer": "ipython3",
1484
+ "version": "3.10.12"
1485
+ }
1486
+ },
1487
+ "nbformat": 4,
1488
+ "nbformat_minor": 5
1489
+ }
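The run above was interrupted partway through epoch 2, so the checkpoint in this commit reflects roughly one epoch of training with the DistilBERT backbone frozen. A hedged sketch of scoring a premise/hypothesis pair with the saved regressor, using the same [SEP]-joined input format that load_dataset builds (local paths and trust_remote_code are assumptions):

import torch
from transformers import AutoTokenizer
from utils import ConsistentSentenceRegressor

tokenizer = AutoTokenizer.from_pretrained(
    "line-corporation/line-distilbert-base-japanese", trust_remote_code=True
)
model = ConsistentSentenceRegressor(freeze_bert=True)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

pair = "川を小さなボートが進んで行きます。[SEP]川を豪華客船が進んでいきます。"
inputs = tokenizer(pair, return_tensors="pt")
with torch.no_grad():
    # forward() returns the squeezed logits tensor directly, not a ModelOutput
    score = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(float(score))   # closer to 1.0 = consistent, closer to 0.0 = contradictory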
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d7456b16ac0d734668b10f0a43291751cb4c4aa6ce7c6112c5e87aaf79a0413
3
+ size 4027
utils.py ADDED
@@ -0,0 +1,108 @@
1
+ import json
2
+ import pandas as pd
3
+ import datasets
4
+ import numpy as np
5
+ import evaluate
6
+ import torch
7
+ from transformers import AutoModel, DistilBertForSequenceClassification
8
+ from transformers.modeling_outputs import SequenceClassifierOutput
9
+ from typing import Optional
10
+
11
+ SEP_TOKEN = '[SEP]'
12
+ LABEL2NUM = {'entailment': 1, 'neutral': 0.5, 'contradiction': 0}
13
+
14
+ def format_dataset(arr):
15
+ text = [el['sentence1'] + SEP_TOKEN + el['sentence2'] for el in arr]
16
+ label = [LABEL2NUM[el['label']] for el in arr]
17
+ new_df = pd.DataFrame({'text': text, 'label': label})
18
+ return new_df.sample(frac=1, random_state=42).reset_index(drop=True)
19
+
20
+ # Load dataset
21
+ def load_dataset(path):
22
+ train_array = []
23
+ with open(path) as f:
24
+ for line in f.readlines():
25
+ if line:
26
+ train_array.append(json.loads(line))
27
+ df = format_dataset(train_array)
28
+ # Split dataset into train and val
29
+ df_train = df.iloc[512:, :]
30
+ # We do not need much test data
31
+ df_test = df.iloc[:512, :]
32
+ print(df_train[:10])
33
+ print(df_test[:10])
34
+
35
+ factual_consistency_dataset = datasets.dataset_dict.DatasetDict()
36
+ factual_consistency_dataset["train"] = datasets.dataset_dict.Dataset.from_pandas(
37
+ df_train[["text", "label"]])
38
+ factual_consistency_dataset["test"] = datasets.dataset_dict.Dataset.from_pandas(
39
+ df_test[["text", "label"]])
40
+
41
+ return factual_consistency_dataset
42
+
43
+
44
+ class ConsistentSentenceRegressor(DistilBertForSequenceClassification):
45
+
46
+ def __init__(self, freeze_bert=True):
47
+ base_model = AutoModel.from_pretrained(
48
+ 'line-corporation/line-distilbert-base-japanese')
49
+
50
+ config = base_model.config
51
+ config.problem_type = "regression"
52
+ config.num_labels = 1
53
+ super(ConsistentSentenceRegressor, self).__init__(config=config)
54
+
55
+ self.distilbert = base_model
56
+
57
+ # Replace the classifier with a single-neuron linear layer for regression
58
+ self.classifier = torch.nn.Linear(config.dim, config.num_labels)
59
+
60
+ if not freeze_bert:
61
+ return
62
+
63
+ for param in self.distilbert.parameters():
64
+ param.requires_grad = False
65
+
66
+ def forward(
67
+ self,
68
+ input_ids: Optional[torch.Tensor] = None,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ head_mask: Optional[torch.Tensor] = None,
71
+ inputs_embeds: Optional[torch.Tensor] = None,
72
+ labels: Optional[torch.LongTensor] = None,
73
+ output_attentions: Optional[bool] = None,
74
+ output_hidden_states: Optional[bool] = None,
75
+ return_dict: Optional[bool] = None,
76
+ ):
77
+ print(input_ids.shape)
78
+ outputs = super().forward(
79
+ input_ids=input_ids,
80
+ attention_mask=attention_mask,
81
+ head_mask=head_mask,
82
+ inputs_embeds=inputs_embeds,
83
+ labels=labels,
84
+ output_attentions=output_attentions,
85
+ output_hidden_states=output_hidden_states,
86
+ return_dict=return_dict
87
+ )
88
+ print(outputs.logits.shape)
89
+ logits = outputs.logits.squeeze(-1) # Remove the last dimension to match target tensor shape
90
+
91
+ print(logits.shape)
92
+
93
+
94
+ return logits
95
+
96
+
97
+ # Set up evaluation metric
98
+
99
+ def get_metrics():
100
+ metric = evaluate.load("mse")
101
+
102
+ def compute_metrics(eval_pred):
103
+ predictions, labels = eval_pred
104
+ print(predictions.shape)
105
+ print(labels.shape)
106
+ return metric.compute(predictions=predictions, references=labels)
107
+
108
+ return compute_metrics
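get_metrics() returns a closure compatible with the Trainer's compute_metrics hook, which the notebook's Trainer call leaves unset; a hedged sketch of wiring it in, reusing the tokenizer, model and tokenized_dataset objects prepared in the notebook above:

from transformers import DataCollatorWithPadding, Trainer, TrainingArguments
from utils import get_metrics

trainer = Trainer(
    model=model,                                   # from the notebook cell above
    args=TrainingArguments(output_dir=".", evaluation_strategy="epoch"),
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=get_metrics(),                 # reports MSE on the eval split
)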