Esmail-AGumaan committed
Commit 34bd885
Parent(s): 41eeca8

Upload 5 files

Files changed:
- Decoder.ipynb +165 -0
- Encoder.ipynb +156 -0
- PE.ipynb +0 -0
- TransformerBlock.ipynb +143 -0
- transformer.py +241 -0
Decoder.ipynb
ADDED
@@ -0,0 +1,165 @@
# ----- Cell 1: embeddings, positional encoding, layer normalization, feed-forward -----
import math
import torch
import torch.nn as nn

class InputEmbeddingsLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scale the embeddings by sqrt(d_model), as in the original Transformer.
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncodingLayer(nn.Module):
    def __init__(self, d_model: int, sequence_length: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.sequence_length = sequence_length
        self.dropout = nn.Dropout(dropout)

        # Fixed sinusoidal encodings: sin on even indices, cos on odd indices.
        PE = torch.zeros(sequence_length, d_model)
        Position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(1)
        deviation_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        PE[:, 0::2] = torch.sin(Position * deviation_term)
        PE[:, 1::2] = torch.cos(Position * deviation_term)
        PE = PE.unsqueeze(0)
        self.register_buffer("PE", PE)

    def forward(self, x):
        x = x + (self.PE[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)

class NormalizationLayer(nn.Module):
    def __init__(self, Epsilon: float = 10**-6) -> None:
        super().__init__()
        self.Epsilon = Epsilon
        self.Alpha = nn.Parameter(torch.ones(1))
        self.Bias = nn.Parameter(torch.ones(1))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.Alpha * (x - mean) / (std + self.Epsilon) + self.Bias

class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.Linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.Linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.Linear_2(self.dropout(torch.relu(self.Linear_1(x))))

# ----- Cell 2: multi-head attention and residual connection -----
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        assert d_model % heads == 0, "d_model is not divisible by heads"

        self.d_k = d_model // heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def Attention(Query, Key, Value, mask, dropout: nn.Module):
        d_k = Query.shape[-1]

        self_attention_score = (Query @ Key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            self_attention_score.masked_fill_(mask == 0, -1e9)
        self_attention_score = self_attention_score.softmax(dim=-1)

        if dropout is not None:
            self_attention_score = dropout(self_attention_score)
        # Return the weighted values together with the attention map so the caller can keep the scores.
        return self_attention_score @ Value, self_attention_score

    def forward(self, query, key, value, mask):
        Query = self.W_Q(query)
        Key = self.W_K(key)
        Value = self.W_V(value)

        # Split d_model into (heads, d_k) and move the head axis in front of the sequence axis.
        Query = Query.view(Query.shape[0], Query.shape[1], self.heads, self.d_k).transpose(1, 2)
        Key = Key.view(Key.shape[0], Key.shape[1], self.heads, self.d_k).transpose(1, 2)
        Value = Value.view(Value.shape[0], Value.shape[1], self.heads, self.d_k).transpose(1, 2)

        x, self.self_attention_score = MultiHeadAttentionBlock.Attention(Query, Key, Value, mask, self.dropout)
        # Merge the heads back into a single d_model-sized dimension.
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.heads * self.d_k)
        return self.W_O(x)

class ResidualConnection(nn.Module):
    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.normalization = NormalizationLayer()

    def forward(self, x, subLayer):
        # Pre-norm residual: x + Dropout(SubLayer(Norm(x))).
        return x + self.dropout(subLayer(self.normalization(x)))

# ----- Cell 3: decoder block, decoder stack, and output projection -----
# Building the decoder block
class DecoderBlock(nn.Module):
    def __init__(self, decoder_self_attention_block: MultiHeadAttentionBlock, decoder_cross_attention_block: MultiHeadAttentionBlock, decoder_feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.decoder_self_attention_block = decoder_self_attention_block
        self.decoder_cross_attention_block = decoder_cross_attention_block
        self.decoder_feed_forward_block = decoder_feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, Encoder_output, source_mask, target_mask):
        # Masked self-attention over the target sequence (causal target_mask),
        # then cross-attention over the encoder output (padding source_mask), then feed-forward.
        x = self.residual_connection[0](x, lambda x: self.decoder_self_attention_block(x, x, x, target_mask))
        x = self.residual_connection[1](x, lambda x: self.decoder_cross_attention_block(x, Encoder_output, Encoder_output, source_mask))
        x = self.residual_connection[2](x, self.decoder_feed_forward_block)
        return x

class Decoder(nn.Module):
    def __init__(self, Layers: nn.ModuleList) -> None:
        super().__init__()
        self.Layers = Layers
        self.normalization = NormalizationLayer()

    def forward(self, x, Encoder_output, source_mask, target_mask):
        for layer in self.Layers:
            x = layer(x, Encoder_output, source_mask, target_mask)
        return self.normalization(x)

class LinearLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.Linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.Linear(x)
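The notebook defines but does not run the decoder classes; a minimal smoke test is sketched below. It is not part of the commit, and the tensor sizes, single-layer stack, and causal mask construction are illustrative assumptions.

# Minimal sketch: exercise the decoder classes above with random tensors (sizes are assumptions).
d_model, heads, d_ff, dropout, seq_len, batch = 512, 8, 2048, 0.1, 10, 2

decoder = Decoder(nn.ModuleList([
    DecoderBlock(MultiHeadAttentionBlock(d_model, heads, dropout),   # masked self-attention
                 MultiHeadAttentionBlock(d_model, heads, dropout),   # cross-attention
                 FeedForwardBlock(d_model, d_ff, dropout),
                 dropout)
]))

x = torch.randn(batch, seq_len, d_model)                        # already-embedded target tokens
encoder_output = torch.randn(batch, seq_len, d_model)
source_mask = torch.ones(batch, 1, 1, seq_len)                  # no source padding
target_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))    # causal mask for the target

out = decoder(x, encoder_output, source_mask, target_mask)
print(out.shape)  # torch.Size([2, 10, 512])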
Encoder.ipynb
ADDED
@@ -0,0 +1,156 @@
# ----- Cells 1 and 2: identical to the corresponding cells in Decoder.ipynb above -----
# (InputEmbeddingsLayer, PositionalEncodingLayer, NormalizationLayer, FeedForwardBlock,
#  MultiHeadAttentionBlock, ResidualConnection)

# ----- Cell 3: encoder block and encoder stack -----
# Building the encoder block
class EncoderBlock(nn.Module):
    def __init__(self, encoder_self_attention_block: MultiHeadAttentionBlock, encoder_feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.encoder_self_attention_block = encoder_self_attention_block
        self.encoder_feed_forward_block = encoder_feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, source_mask):
        # Self-attention over the source sequence (source_mask hides padding), then feed-forward.
        x = self.residual_connection[0](x, lambda x: self.encoder_self_attention_block(x, x, x, source_mask))
        x = self.residual_connection[1](x, self.encoder_feed_forward_block)
        return x

class Encoder(nn.Module):
    def __init__(self, Layers: nn.ModuleList) -> None:
        super().__init__()
        self.Layers = Layers
        self.normalization = NormalizationLayer()

    def forward(self, x, source_mask):
        for layer in self.Layers:
            x = layer(x, source_mask)
        return self.normalization(x)
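For the encoder, the mask is a padding mask over the source rather than a causal mask. The sketch below is not part of the commit; the vocabulary size, pad_id, and tensor sizes are assumptions.

# Minimal sketch: encoder forward pass with a padding mask (vocab_size, pad_id, and sizes are assumptions).
d_model, heads, d_ff, dropout, vocab_size, seq_len, pad_id = 512, 8, 2048, 0.1, 1000, 6, 0

embedding = InputEmbeddingsLayer(d_model, vocab_size)
position = PositionalEncodingLayer(d_model, seq_len, dropout)
encoder = Encoder(nn.ModuleList([
    EncoderBlock(MultiHeadAttentionBlock(d_model, heads, dropout),
                 FeedForwardBlock(d_model, d_ff, dropout),
                 dropout)
    for _ in range(2)
]))

tokens = torch.tensor([[5, 7, 9, 2, pad_id, pad_id]])          # (batch=1, seq_len=6) token ids
source_mask = (tokens != pad_id).unsqueeze(1).unsqueeze(2)     # (1, 1, 1, 6); 0 where padded

x = position(embedding(tokens))
print(encoder(x, source_mask).shape)  # torch.Size([1, 6, 512])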
PE.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
TransformerBlock.ipynb
ADDED
@@ -0,0 +1,143 @@
# ----- Cell 1: imports (the Encoder and Decoder notebooks are loaded via import_ipynb) -----
import import_ipynb
from Encoder import Encoder, EncoderBlock, MultiHeadAttentionBlock, FeedForwardBlock, InputEmbeddingsLayer, PositionalEncodingLayer
from Decoder import Decoder, DecoderBlock, MultiHeadAttentionBlock, FeedForwardBlock, InputEmbeddingsLayer, PositionalEncodingLayer

import torch
import torch.nn as nn

# ----- Cell 2: output projection to the target vocabulary -----
class LinearLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.Linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.Linear(x)

# ----- Cell 3: full encoder-decoder wrapper -----
class TransformerBlock(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, source_embedding: InputEmbeddingsLayer, target_embedding: InputEmbeddingsLayer, source_position: PositionalEncodingLayer, target_position: PositionalEncodingLayer, Linear: LinearLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.source_embedding = source_embedding
        self.target_embedding = target_embedding
        self.source_position = source_position
        self.target_position = target_position
        self.Linear = Linear

    def encode(self, source_language, source_mask):
        source_language = self.source_embedding(source_language)
        source_language = self.source_position(source_language)
        return self.encoder(source_language, source_mask)

    def decode(self, Encoder_output, source_mask, target_language, target_mask):
        target_language = self.target_embedding(target_language)
        target_language = self.target_position(target_language)
        return self.decoder(target_language, Encoder_output, source_mask, target_mask)

    def linear(self, x):
        return self.Linear(x)

# ----- Cell 4: factory that assembles the full model -----
def Transformer_Model(source_vocab_size: int, target_vocab_size: int, source_sequence_length: int, target_sequence_length: int, d_model: int = 512, Layers: int = 6, heads: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> TransformerBlock:

    source_embedding = InputEmbeddingsLayer(d_model, source_vocab_size)
    target_embedding = InputEmbeddingsLayer(d_model, target_vocab_size)

    source_position = PositionalEncodingLayer(d_model, source_sequence_length, dropout)
    target_position = PositionalEncodingLayer(d_model, target_sequence_length, dropout)

    EncoderBlocks = []
    for _ in range(Layers):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, heads, dropout)
        encoder_feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, encoder_feed_forward_block, dropout)
        EncoderBlocks.append(encoder_block)

    DecoderBlocks = []
    for _ in range(Layers):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, heads, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, heads, dropout)
        decoder_feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, decoder_feed_forward_block, dropout)
        DecoderBlocks.append(decoder_block)

    encoder = Encoder(nn.ModuleList(EncoderBlocks))
    decoder = Decoder(nn.ModuleList(DecoderBlocks))

    linear = LinearLayer(d_model, target_vocab_size)

    Transformer = TransformerBlock(encoder, decoder, source_embedding, target_embedding, source_position, target_position, linear)

    # Xavier-initialise every weight matrix; 1-D parameters keep their defaults.
    for t in Transformer.parameters():
        if t.dim() > 1:
            nn.init.xavier_uniform_(t)

    return Transformer
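Transformer_Model wires the pieces together; the call below is an illustrative sketch that is not part of the commit (the vocabulary sizes, sequence lengths, reduced layer counts, and masks are assumptions).

# Minimal sketch: build a small model with Transformer_Model and run one encode/decode pass.
model = Transformer_Model(source_vocab_size=1000, target_vocab_size=1200,
                          source_sequence_length=20, target_sequence_length=20,
                          d_model=128, Layers=2, heads=4, dropout=0.1, d_ff=512)

source = torch.randint(0, 1000, (2, 20))                  # (batch, source_seq_len) token ids
target = torch.randint(0, 1200, (2, 20))                  # (batch, target_seq_len) token ids
source_mask = torch.ones(2, 1, 1, 20)                     # no padding
target_mask = torch.tril(torch.ones(1, 1, 20, 20))        # causal mask

encoder_output = model.encode(source, source_mask)        # (2, 20, 128)
decoder_output = model.decode(encoder_output, source_mask, target, target_mask)
logits = model.linear(decoder_output)                     # (2, 20, 1200) vocabulary logits
print(logits.shape)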
transformer.py
ADDED
@@ -0,0 +1,241 @@
from dataclasses import dataclass
import math
import torch
import torch.nn as nn

@dataclass
class Args:
    source_vocab_size: int
    target_vocab_size: int
    source_sequence_length: int
    target_sequence_length: int
    d_model: int = 512
    Layers: int = 6
    heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048

class InputEmbeddingLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scale the embeddings by sqrt(d_model), as in the original Transformer.
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncodingLayer(nn.Module):
    def __init__(self, d_model: int, sequence_length: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.sequence_length = sequence_length
        self.dropout = nn.Dropout(dropout)

        # Fixed sinusoidal encodings: sin on even indices, cos on odd indices.
        PE = torch.zeros(sequence_length, d_model)
        Position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(1)
        deviation_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        PE[:, 0::2] = torch.sin(Position * deviation_term)
        PE[:, 1::2] = torch.cos(Position * deviation_term)
        PE = PE.unsqueeze(0)
        self.register_buffer('PE', PE)

    def forward(self, x):
        x = x + (self.PE[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)

class NormalizationLayer(nn.Module):
    def __init__(self, Epsilon: float = 10**-4) -> None:
        super().__init__()
        self.Epsilon = Epsilon
        self.Alpha = nn.Parameter(torch.ones(1))
        self.Bias = nn.Parameter(torch.ones(1))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.Alpha * (x - mean) / (std + self.Epsilon) + self.Bias

class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.Linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.Linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.Linear_2(self.dropout(torch.relu(self.Linear_1(x))))

class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        assert d_model % heads == 0, "d_model is not divisible by heads"

        self.d_k = d_model // heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def Attention(Query, Key, Value, mask, dropout):
        d_k = Query.shape[-1]
        self_attention_scores = (Query @ Key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            self_attention_scores.masked_fill_(mask == 0, -1e9)
        self_attention_scores = self_attention_scores.softmax(dim=-1)

        if dropout is not None:
            self_attention_scores = dropout(self_attention_scores)
        # Return the weighted values together with the attention map so the caller can keep the scores.
        return self_attention_scores @ Value, self_attention_scores

    def forward(self, query, key, value, mask):
        Query = self.W_Q(query)
        Key = self.W_K(key)
        Value = self.W_V(value)

        # Split d_model into (heads, d_k) and move the head axis in front of the sequence axis.
        Query = Query.view(Query.shape[0], Query.shape[1], self.heads, self.d_k).transpose(1, 2)
        Key = Key.view(Key.shape[0], Key.shape[1], self.heads, self.d_k).transpose(1, 2)
        Value = Value.view(Value.shape[0], Value.shape[1], self.heads, self.d_k).transpose(1, 2)

        x, self.self_attention_scores = MultiHeadAttentionBlock.Attention(Query, Key, Value, mask, self.dropout)
        # Merge the heads back into a single d_model-sized dimension.
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.heads * self.d_k)
        return self.W_O(x)

class ResidualConnection(nn.Module):
    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.normalization_layer = NormalizationLayer()

    def forward(self, x, subLayer):
        # Pre-norm residual: x + Dropout(SubLayer(Norm(x))).
        return x + self.dropout(subLayer(self.normalization_layer(x)))

class EncoderBlock(nn.Module):
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, source_mask):
        x = self.residual_connection[0](x, lambda x: self.self_attention_block(x, x, x, source_mask))
        x = self.residual_connection[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):
    def __init__(self, Layers: nn.ModuleList) -> None:
        super().__init__()
        self.Layers = Layers
        self.normalization_layer = NormalizationLayer()

    def forward(self, x, source_mask):
        for layer in self.Layers:
            x = layer(x, source_mask)
        return self.normalization_layer(x)

class DecoderBlock(nn.Module):
    def __init__(self, masked_self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.masked_self_attention_block = masked_self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, Encoder_output, source_mask, target_mask):
        # Masked self-attention over the target sequence (causal target_mask),
        # then cross-attention over the encoder output (padding source_mask), then feed-forward.
        x = self.residual_connection[0](x, lambda x: self.masked_self_attention_block(x, x, x, target_mask))
        x = self.residual_connection[1](x, lambda x: self.cross_attention_block(x, Encoder_output, Encoder_output, source_mask))
        x = self.residual_connection[2](x, self.feed_forward_block)
        return x

class Decoder(nn.Module):
    def __init__(self, Layers: nn.ModuleList) -> None:
        super().__init__()
        self.Layers = Layers
        self.normalization_layer = NormalizationLayer()

    def forward(self, x, Encoder_output, source_mask, target_mask):
        for layer in self.Layers:
            x = layer(x, Encoder_output, source_mask, target_mask)
        return self.normalization_layer(x)

class LinearLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.Linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.Linear(x)

class TransformerBlock(nn.Module):
    def __init__(self, encoder: Encoder,
                 decoder: Decoder,
                 source_embedding: InputEmbeddingLayer,
                 target_embedding: InputEmbeddingLayer,
                 source_position: PositionalEncodingLayer,
                 target_position: PositionalEncodingLayer,
                 Linear: LinearLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.source_embedding = source_embedding
        self.target_embedding = target_embedding
        self.source_position = source_position
        self.target_position = target_position
        self.Linear = Linear

    def encode(self, source_language, source_mask):
        source_language = self.source_embedding(source_language)
        source_language = self.source_position(source_language)
        return self.encoder(source_language, source_mask)

    def decode(self, Encoder_output, source_mask, target_language, target_mask):
        target_language = self.target_embedding(target_language)
        target_language = self.target_position(target_language)
        return self.decoder(target_language, Encoder_output, source_mask, target_mask)

    def linear(self, x):
        return self.Linear(x)


def Transformer_model(args: Args) -> TransformerBlock:

    source_embedding = InputEmbeddingLayer(args.d_model, args.source_vocab_size)
    source_position = PositionalEncodingLayer(args.d_model, args.source_sequence_length, args.dropout)

    target_embedding = InputEmbeddingLayer(args.d_model, args.target_vocab_size)
    target_position = PositionalEncodingLayer(args.d_model, args.target_sequence_length, args.dropout)

    Encoder_Blocks = []
    for _ in range(args.Layers):
        encoder_self_attention_block = MultiHeadAttentionBlock(args.d_model, args.heads, args.dropout)
        encoder_feed_forward_block = FeedForwardBlock(args.d_model, args.d_ff, args.dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, encoder_feed_forward_block, args.dropout)
        Encoder_Blocks.append(encoder_block)

    Decoder_Blocks = []
    for _ in range(args.Layers):
        decoder_self_attention_block = MultiHeadAttentionBlock(args.d_model, args.heads, args.dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(args.d_model, args.heads, args.dropout)
        decoder_feed_forward_block = FeedForwardBlock(args.d_model, args.d_ff, args.dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, decoder_feed_forward_block, args.dropout)
        Decoder_Blocks.append(decoder_block)

    encoder = Encoder(nn.ModuleList(Encoder_Blocks))
    decoder = Decoder(nn.ModuleList(Decoder_Blocks))

    linear = LinearLayer(args.d_model, args.target_vocab_size)

    Transformer = TransformerBlock(encoder, decoder, source_embedding, target_embedding, source_position, target_position, linear)

    # Xavier-initialise every weight matrix; 1-D parameters keep their defaults.
    for t in Transformer.parameters():
        if t.dim() > 1:
            nn.init.xavier_uniform_(t)
    return Transformer
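transformer.py exposes the same construction through the Args dataclass; a minimal usage sketch follows (the concrete values, mask shapes, and the module-level import are assumptions, not part of the commit).

# Minimal sketch: build the model from an Args instance (values are illustrative assumptions).
import torch
from transformer import Args, Transformer_model

args = Args(source_vocab_size=1000, target_vocab_size=1200,
            source_sequence_length=32, target_sequence_length=32,
            d_model=256, Layers=2, heads=4, dropout=0.1, d_ff=1024)
model = Transformer_model(args)

source = torch.randint(0, args.source_vocab_size, (1, 32))
target = torch.randint(0, args.target_vocab_size, (1, 32))
source_mask = torch.ones(1, 1, 1, 32)                      # no padding
target_mask = torch.tril(torch.ones(1, 1, 32, 32))         # causal mask

memory = model.encode(source, source_mask)
logits = model.linear(model.decode(memory, source_mask, target, target_mask))
print(logits.shape)  # torch.Size([1, 32, 1200])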