shredder-31 committed on
Commit f463280 · verified · 1 Parent(s): 00e7036

Create trainning.py

Files changed (1)
  1. trainning.py +301 -0
trainning.py ADDED
@@ -0,0 +1,301 @@
+ import torch
+ import string
+ import torch.nn.functional as F
+ import torch.nn as nn
+ import torchvision.models as models
+
+ def decoder(indices, vocab):
+     """Convert a list of token indices back into a caption string."""
+     tokens = [vocab.lookup_token(idx) for idx in indices]
+     words = []
+     current_word = []
+     for token in tokens:
+         # Group consecutive single lowercase characters into a word; any other
+         # token (special tokens, spaces, punctuation) is appended as-is.
+         if len(token) == 1 and token in string.ascii_lowercase:
+             current_word.append(token)
+         else:
+             if current_word:
+                 words.append("".join(current_word))
+                 current_word = []
+             words.append(token)
+
+     if current_word:
+         words.append(" " + "".join(current_word))
+
+     return "".join(words)
+
+ def beam_search_caption(model, images, vocab, decoder, device="cpu",
+                         start_token="<sos>", end_token="<eos>",
+                         beam_width=3, max_seq_length=100):
+     """
+     Generates a caption for an image using beam search.
+
+     Args:
+         model (ImgCap): The image captioning model.
+         images (torch.Tensor): Batch of images (batch size must be 1).
+         vocab (Vocab): Vocabulary object.
+         decoder (function): Function to decode indices to words.
+         device (str): Device to perform computation on.
+         start_token (str): Start-of-sequence token.
+         end_token (str): End-of-sequence token.
+         beam_width (int): Number of beams to keep.
+         max_seq_length (int): Maximum length of the generated caption.
+
+     Returns:
+         str: Generated caption for the image.
+     """
+     model.eval()
+
+     with torch.no_grad():
+         start_index = vocab[start_token]
+         end_index = vocab[end_token]
+         images = images.to(device)
+         batch_size = images.size(0)
+
+         # Ensure batch_size is 1 for beam search (one image at a time)
+         if batch_size != 1:
+             raise ValueError("Beam search currently supports batch_size=1.")
+
+         cnn_feature = model.cnn(images)                               # (1, 1024)
+         lstm_input = model.lstm.projection(cnn_feature).unsqueeze(1)  # (1, 1, hidden_size)
+         state = None  # Initial LSTM state
+
+         # Initialize the beam with the start token
+         sequences = [([start_index], 0.0, lstm_input, state)]  # (sequence, score, input, state)
+
+         completed_sequences = []
+
+         for _ in range(max_seq_length):
+             all_candidates = []
+
+             # Iterate over all current sequences in the beam
+             for seq, score, lstm_input, state in sequences:
+                 # If the last token is the end token, the sequence is complete
+                 if seq[-1] == end_index:
+                     completed_sequences.append((seq, score))
+                     continue
+
+                 # Pass the current input and state through the LSTM
+                 lstm_out, state_new = model.lstm.lstm(lstm_input, state)  # (1, 1, hidden_size)
+
+                 # Project the LSTM output to vocabulary logits
+                 output = model.lstm.fc(lstm_out.squeeze(1))  # (1, vocab_size)
+
+                 # Compute log probabilities
+                 log_probs = F.log_softmax(output, dim=1)  # (1, vocab_size)
+
+                 # Get the top beam_width tokens and their log probabilities
+                 top_log_probs, top_indices = log_probs.topk(beam_width, dim=1)  # (1, beam_width)
+
+                 # Expand the beam with each of the top tokens
+                 for i in range(beam_width):
+                     token = top_indices[0, i].item()
+                     token_log_prob = top_log_probs[0, i].item()
+
+                     # Extend the sequence and accumulate its log probability
+                     new_seq = seq + [token]
+                     new_score = score + token_log_prob
+
+                     # Embed the new token as the next LSTM input
+                     token_tensor = torch.tensor([token], device=device)
+                     new_lstm_input = model.lstm.embedding(token_tensor).unsqueeze(1)  # (1, 1, hidden_size)
+
+                     # Clone the new state so each beam keeps its own state
+                     if state_new is not None:
+                         new_state = (state_new[0].clone(), state_new[1].clone())
+                     else:
+                         new_state = None
+
+                     all_candidates.append((new_seq, new_score, new_lstm_input, new_state))
+
+             # If no candidates are left to process, stop expanding
+             if not all_candidates:
+                 break
+
+             # Sort all candidates by score in descending order
+             ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
+
+             # Keep the top beam_width sequences as the new beam
+             sequences = ordered[:beam_width]
+
+             # If enough completed sequences are found, stop early
+             if len(completed_sequences) >= beam_width:
+                 break
+
+         # If no sequence reached the end token, fall back to the current beam
+         if len(completed_sequences) == 0:
+             completed_sequences = [(seq, score) for seq, score, _, _ in sequences]
+
+         # Select the sequence with the highest score
+         best_seq, best_score = max(completed_sequences, key=lambda x: x[1])
+
+         # Drop the leading start token before decoding
+         if best_seq[0] == start_index:
+             best_seq = best_seq[1:]
+
+         best_caption = decoder(best_seq, vocab)
+
+     return best_caption
+
+
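+ # Illustrative usage sketch (assumptions, not part of the training pipeline in this
+ # file: an ImgCap instance `model`, a torchtext-style `vocab` containing "<sos>" and
+ # "<eos>", and one preprocessed image tensor `image` of shape (1, 3, 224, 224)):
+ #
+ #     caption = beam_search_caption(model, image, vocab, decoder,
+ #                                   device="cpu", beam_width=3, max_seq_length=50)
+ #     print(caption)
+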
+ def generate_caption(model, images, vocab, decoder, device="cpu", start_token="<sos>", end_token="<eos>", max_seq_length=100, top_k=2):
+     """Generate captions for a batch of images by sampling from the top-k tokens at each step."""
+     model.eval()
+
+     with torch.no_grad():
+         start_index = vocab[start_token]
+         end_index = vocab[end_token]
+         images = images.to(device)
+         batch_size = images.size(0)
+
+         end_token_appear = {i: False for i in range(batch_size)}
+         captions = [[] for _ in range(batch_size)]
+
+         cnn_feature = model.cnn(images)
+         lstm_input = model.lstm.projection(cnn_feature).unsqueeze(1)  # (B, 1, hidden_size)
+
+         state = None
+
+         for i in range(max_seq_length):
+             lstm_out, state = model.lstm.lstm(lstm_input, state)
+             output = model.lstm.fc(lstm_out.squeeze(1))
+
+             # Renormalize the top-k probabilities and sample one token per image
+             top_k_probs, top_k_indices = torch.topk(F.softmax(output, dim=1), top_k, dim=1)
+             top_k_probs = top_k_probs / torch.sum(top_k_probs, dim=1, keepdim=True)
+             top_k_samples = torch.multinomial(top_k_probs, 1).squeeze(1)  # squeeze dim 1 so batch_size == 1 keeps its batch dim
+
+             predicted_word_indices = top_k_indices[range(batch_size), top_k_samples]
+
+             # Feed the sampled tokens back in as the next LSTM input
+             lstm_input = model.lstm.embedding(predicted_word_indices).unsqueeze(1)  # (B, 1, hidden_size)
+
+             for j in range(batch_size):
+                 if end_token_appear[j]:
+                     continue
+
+                 word = vocab.lookup_token(predicted_word_indices[j].item())
+                 if word == end_token:
+                     end_token_appear[j] = True
+
+                 captions[j].append(predicted_word_indices[j].item())
+
+         captions = [decoder(caption, vocab) for caption in captions]
+
+     return captions
+
+
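+ # Worked shape example for the sampling step above (illustrative values only):
+ # with batch_size = 2 and top_k = 2, suppose
+ #     top_k_indices = [[ 5, 12],
+ #                      [ 7,  3]]   # token ids of the 2 most likely words per image
+ #     top_k_samples = [0, 1]       # sampled column for each image
+ # then top_k_indices[range(2), top_k_samples] selects [5, 3]: one token id per image.
+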
+ class ResNet50(nn.Module):
+     def __init__(self):
+         super(ResNet50, self).__init__()
+         self.ResNet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+
+         # Replace the classification head with a 1024-d feature projection
+         self.ResNet50.fc = nn.Sequential(
+             nn.Linear(2048, 1024),
+             nn.ReLU(),
+             nn.Dropout(0.5),
+             nn.Linear(1024, 1024),
+             nn.ReLU(),
+         )
+
+         # Freeze the pretrained backbone; train only the new fc head
+         for k, v in self.ResNet50.named_parameters(recurse=True):
+             if 'fc' in k:
+                 v.requires_grad = True
+             else:
+                 v.requires_grad = False
+
+     def forward(self, x):
+         return self.ResNet50(x)
+
+ ## LSTM (Decoder)
+
+ class lstm(nn.Module):
+     def __init__(self, input_size, hidden_size, number_layers, embedding_dim, vocab_size):
+         super(lstm, self).__init__()
+
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         self.number_layers = number_layers
+         self.embedding_dim = embedding_dim
+         self.vocab_size = vocab_size
+
+         # NOTE: embeddings use hidden_size (embedding_dim is stored but unused)
+         # so word embeddings match the projected image feature fed to the LSTM.
+         self.embedding = nn.Embedding(vocab_size, hidden_size)
+         self.projection = nn.Linear(input_size, hidden_size)
+         self.relu = nn.ReLU()
+
+         self.lstm = nn.LSTM(
+             input_size=hidden_size,
+             hidden_size=hidden_size,
+             num_layers=number_layers,
+             dropout=0.5,
+             batch_first=True,
+         )
+
+         self.fc = nn.Linear(hidden_size, vocab_size)
+
+     def forward(self, x, captions):
+         projected_image = self.projection(x).unsqueeze(dim=1)
+         embeddings = self.embedding(captions[:, :-1])
+
+         # Concatenate the image feature as the first step with the word embeddings
+         lstm_input = torch.cat((projected_image, embeddings), dim=1)
+         # print(torch.all(projected_image[:, 0, :] == lstm_input[:, 0, :]))  # check
+
+         lstm_out, _ = self.lstm(lstm_input)
+         logits = self.fc(lstm_out)
+
+         return logits
+
+ ## ImgCap
+
+ class ImgCap(nn.Module):
+     def __init__(self, cnn_feature_size, lstm_hidden_size, num_layers, vocab_size, embedding_dim):
+         super(ImgCap, self).__init__()
+
+         self.cnn = ResNet50()
+
+         self.lstm = lstm(input_size=cnn_feature_size,
+                          hidden_size=lstm_hidden_size,
+                          number_layers=num_layers,
+                          embedding_dim=embedding_dim,
+                          vocab_size=vocab_size)
+
+     def forward(self, images, captions):
+         cnn_features = self.cnn(images)
+         output = self.lstm(cnn_features, captions)
+         return output
+
+     def generate_caption(self, images, vocab, decoder, device="cpu", start_token="<sos>", end_token="<eos>", max_seq_length=100):
+         """Greedy (argmax) caption generation for a batch of images."""
+         self.eval()
+
+         with torch.no_grad():
+             start_index = vocab[start_token]
+             end_index = vocab[end_token]
+             images = images.to(device)
+             batch_size = images.size(0)
+
+             end_token_appear = {i: False for i in range(batch_size)}
+             captions = [[] for _ in range(batch_size)]
+
+             cnn_feature = self.cnn(images)
+             lstm_input = self.lstm.projection(cnn_feature).unsqueeze(1)  # (B, 1, hidden_size)
+
+             state = None
+
+             for i in range(max_seq_length):
+                 lstm_out, state = self.lstm.lstm(lstm_input, state)
+                 output = self.lstm.fc(lstm_out.squeeze(1))
+                 # Greedy decoding: take the most likely token at each step
+                 predicted_word_indices = torch.argmax(output, dim=1)
+                 lstm_input = self.lstm.embedding(predicted_word_indices).unsqueeze(1)  # (B, 1, hidden_size)
+
+                 for j in range(batch_size):
+                     if end_token_appear[j]:
+                         continue
+
+                     word = vocab.lookup_token(predicted_word_indices[j].item())
+                     if word == end_token:
+                         end_token_appear[j] = True
+
+                     captions[j].append(predicted_word_indices[j].item())
+
+             captions = [decoder(caption, vocab) for caption in captions]
+
+         return captions
+
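+ # End-to-end sketch (illustrative; the hyperparameters and tensor sizes below are
+ # assumptions, not values fixed by this file; vocab_size comes from a vocabulary
+ # built elsewhere in the pipeline):
+ #
+ #     vocab_size = 5000
+ #     model = ImgCap(cnn_feature_size=1024, lstm_hidden_size=1024,
+ #                    num_layers=2, vocab_size=vocab_size, embedding_dim=1024)
+ #     images = torch.randn(4, 3, 224, 224)          # dummy image batch
+ #     caps = torch.randint(0, vocab_size, (4, 20))  # dummy caption index batch
+ #     logits = model(images, caps)                  # -> (4, 20, vocab_size)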