Upload folder using huggingface_hub
Browse files- modelling_medicap.py +21 -264
modelling_medicap.py
CHANGED
@@ -136,98 +136,6 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
|
|
136 |
self.encoder.config = self.config.encoder
|
137 |
self.decoder.config = self.config.decoder
|
138 |
|
139 |
-
@classmethod
|
140 |
-
def from_encoder_decoder_pretrained(
|
141 |
-
cls,
|
142 |
-
encoder_pretrained_model_name_or_path: str = None,
|
143 |
-
decoder_pretrained_model_name_or_path: str = None,
|
144 |
-
*model_args,
|
145 |
-
**kwargs,
|
146 |
-
) -> PreTrainedModel:
|
147 |
-
kwargs_encoder = {
|
148 |
-
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
|
149 |
-
}
|
150 |
-
|
151 |
-
kwargs_decoder = {
|
152 |
-
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
153 |
-
}
|
154 |
-
|
155 |
-
# remove encoder, decoder kwargs from kwargs
|
156 |
-
for key in kwargs_encoder.keys():
|
157 |
-
del kwargs["encoder_" + key]
|
158 |
-
for key in kwargs_decoder.keys():
|
159 |
-
del kwargs["decoder_" + key]
|
160 |
-
|
161 |
-
# Load and initialize the encoder and decoder
|
162 |
-
# The distinction between encoder and decoder at the model level is made
|
163 |
-
# by the value of the flag `is_decoder` that we need to set correctly.
|
164 |
-
encoder = kwargs_encoder.pop("model", None)
|
165 |
-
if encoder is None:
|
166 |
-
if encoder_pretrained_model_name_or_path is None:
|
167 |
-
raise ValueError(
|
168 |
-
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
169 |
-
"to be defined."
|
170 |
-
)
|
171 |
-
|
172 |
-
if "config" not in kwargs_encoder:
|
173 |
-
encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
|
174 |
-
encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
|
175 |
-
)
|
176 |
-
|
177 |
-
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
178 |
-
logger.info(
|
179 |
-
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
180 |
-
"from a decoder model. Cross-attention and casual mask are disabled."
|
181 |
-
)
|
182 |
-
encoder_config.is_decoder = False
|
183 |
-
encoder_config.add_cross_attention = False
|
184 |
-
|
185 |
-
kwargs_encoder["config"] = encoder_config
|
186 |
-
|
187 |
-
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
188 |
-
|
189 |
-
decoder = kwargs_decoder.pop("model", None)
|
190 |
-
if decoder is None:
|
191 |
-
if decoder_pretrained_model_name_or_path is None:
|
192 |
-
raise ValueError(
|
193 |
-
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
194 |
-
"to be defined."
|
195 |
-
)
|
196 |
-
|
197 |
-
if "config" not in kwargs_decoder:
|
198 |
-
decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
|
199 |
-
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
|
200 |
-
)
|
201 |
-
|
202 |
-
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
203 |
-
logger.info(
|
204 |
-
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
|
205 |
-
f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
|
206 |
-
f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
207 |
-
)
|
208 |
-
decoder_config.is_decoder = True
|
209 |
-
decoder_config.add_cross_attention = True
|
210 |
-
|
211 |
-
kwargs_decoder["config"] = decoder_config
|
212 |
-
|
213 |
-
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
214 |
-
logger.warning(
|
215 |
-
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
216 |
-
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
217 |
-
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
218 |
-
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
219 |
-
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
220 |
-
)
|
221 |
-
|
222 |
-
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
223 |
-
|
224 |
-
# instantiate config with corresponding kwargs
|
225 |
-
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
226 |
-
|
227 |
-
# make sure the input & output embeddings are not tied
|
228 |
-
config.tie_word_embeddings = False
|
229 |
-
return cls(encoder=encoder, decoder=decoder, config=config)
|
230 |
-
|
231 |
def forward(
|
232 |
self,
|
233 |
pixel_values: Optional[torch.FloatTensor] = None,
|
@@ -265,11 +173,6 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
|
|
265 |
elif isinstance(encoder_outputs, tuple):
|
266 |
encoder_outputs = BaseModelOutput(*encoder_outputs)
|
267 |
|
268 |
-
# encoder_hidden_states = encoder_outputs[0]
|
269 |
-
# encoder_attention_mask = None
|
270 |
-
|
271 |
-
# image_features = self.encoder(images).projected_last_hidden_state
|
272 |
-
|
273 |
embeddings = self.decoder.transformer.wte(decoder_input_ids)
|
274 |
embeddings = torch.cat([encoder_outputs[0], embeddings], dim=1)
|
275 |
|
@@ -314,143 +217,43 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
|
|
314 |
decoder_attentions=decoder_outputs.attentions,
|
315 |
cross_attentions=decoder_outputs.cross_attentions,
|
316 |
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
317 |
-
# encoder_hidden_states=encoder_outputs.hidden_states,
|
318 |
-
# encoder_attentions=encoder_outputs.attentions,
|
319 |
)
|
320 |
|
321 |
-
def
|
322 |
-
self,
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
attention_mask=None,
|
327 |
-
use_cache=None,
|
328 |
-
encoder_outputs=None,
|
329 |
-
**kwargs,
|
330 |
):
|
331 |
"""
|
332 |
-
|
333 |
-
https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
|
334 |
-
"""
|
335 |
-
|
336 |
-
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
|
337 |
-
decoder_attention_mask = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
|
338 |
-
|
339 |
-
if not past_key_values:
|
340 |
-
token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids)
|
341 |
-
else:
|
342 |
-
token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids)
|
343 |
-
|
344 |
-
input_dict = {
|
345 |
-
'attention_mask': attention_mask,
|
346 |
-
'decoder_attention_mask': decoder_attention_mask,
|
347 |
-
'decoder_input_ids': decoder_inputs['input_ids'],
|
348 |
-
'decoder_token_type_ids': token_type_ids,
|
349 |
-
'encoder_outputs': encoder_outputs,
|
350 |
-
'past_key_values': decoder_inputs['past_key_values'],
|
351 |
-
'use_cache': use_cache,
|
352 |
-
}
|
353 |
-
return input_dict
|
354 |
-
|
355 |
-
def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
|
356 |
-
"""
|
357 |
-
Extract token type identifiers from the token identifiers.
|
358 |
-
|
359 |
-
Argument/s:
|
360 |
-
token_ids - token identifiers.
|
361 |
-
special_token_ids - special token identifiers that indicate the separation between sections.
|
362 |
-
token_type_id_sections - token type identifier for each section.
|
363 |
-
|
364 |
-
Returns:
|
365 |
-
token_type_ids - token type identifiers.
|
366 |
-
"""
|
367 |
-
|
368 |
-
token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
|
369 |
-
|
370 |
-
mbatch_size, seq_len = token_ids.shape
|
371 |
-
token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
|
372 |
-
|
373 |
-
for i, j in enumerate(special_token_ids):
|
374 |
-
# Find first occurrence of special tokens that indicate the boundary between sections:
|
375 |
-
cols = (token_ids == j).int().argmax(dim=1)
|
376 |
-
rows = torch.arange(mbatch_size, device=token_ids.device)
|
377 |
-
|
378 |
-
# https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
|
379 |
-
cols += 1
|
380 |
-
|
381 |
-
# Ensure that the column index is not out of bounds. If 0, then token_id not present.
|
382 |
-
# This is safe as index 0 is always a special token (now equal to 1 due to +1):
|
383 |
-
rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
|
384 |
-
cols = cols[torch.logical_and(cols != 1, cols < seq_len)]
|
385 |
-
|
386 |
-
# Indices that correspond to the second sequence:
|
387 |
-
if rows.nelement() != 0:
|
388 |
-
ids = torch.stack([
|
389 |
-
torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
|
390 |
-
y, seq_len, device=token_ids.device,
|
391 |
-
)
|
392 |
-
])
|
393 |
-
|
394 |
-
token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]
|
395 |
-
|
396 |
-
return token_type_ids
|
397 |
-
|
398 |
-
def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
|
399 |
-
"""
|
400 |
-
Extract token type identifiers from the token identifiers if past != None.
|
401 |
|
402 |
Argument/s:
|
403 |
-
|
404 |
-
special_token_ids - special token identifiers that indicate the separation between sections.
|
405 |
-
|
406 |
-
Returns:
|
407 |
-
token_type_ids - token type identifiers.
|
408 |
-
"""
|
409 |
-
|
410 |
-
token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
|
411 |
-
token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
|
412 |
-
|
413 |
-
# https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
|
414 |
-
token_ids = token_ids[:, :-1]
|
415 |
-
|
416 |
-
for i, j in enumerate(special_token_ids):
|
417 |
-
|
418 |
-
# Find first occurrence of special token, which indicates the boundary between sections:
|
419 |
-
exists = torch.any(token_ids == j, dim=1, keepdim=True)
|
420 |
-
token_type_ids[exists] = token_type_id_sections[i + 1]
|
421 |
-
|
422 |
-
return token_type_ids
|
423 |
-
|
424 |
-
def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
|
425 |
-
"""
|
426 |
-
Tokenize the reports and creates the inputs and targets for teacher forcing.
|
427 |
-
|
428 |
-
Argument/s:
|
429 |
-
findings - findings section.
|
430 |
-
impression - impression section.
|
431 |
-
return_token_type_ids - return the token type identifiers.
|
432 |
tokenizer - Hugging Face tokenizer.
|
433 |
max_len - maximum number of tokens.
|
434 |
|
435 |
Returns:
|
436 |
-
|
437 |
-
|
438 |
-
|
|
|
|
|
|
|
439 |
"""
|
440 |
|
441 |
-
# Prepare the
|
442 |
-
|
443 |
-
zip(findings, impression)]
|
444 |
|
445 |
-
# Tokenize the
|
446 |
-
tokenized = tokenizer(
|
447 |
-
|
448 |
padding='longest',
|
449 |
truncation=True,
|
450 |
-
max_length=max_len + 1, # +1 to account for the
|
451 |
return_tensors='pt',
|
452 |
return_token_type_ids=False,
|
453 |
-
add_special_tokens=False,
|
454 |
).to(self.device)
|
455 |
|
456 |
# Modify for language modelling:
|
@@ -466,50 +269,4 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
|
|
466 |
'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
|
467 |
}
|
468 |
|
469 |
-
return batch_dict
|
470 |
-
|
471 |
-
def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
|
472 |
-
"""
|
473 |
-
Split the token identifiers into sections, then convert the token identifiers into strings.
|
474 |
-
|
475 |
-
Argument/s:
|
476 |
-
token_ids - token identifiers.
|
477 |
-
special_token_ids - special token identifiers that indicate the end of each section.
|
478 |
-
tokenizer - Hugging Face tokenizer.
|
479 |
-
|
480 |
-
Returns:
|
481 |
-
sections - tuple with one list per special token; each list holds the decoded section string for every row of token_ids.
|
482 |
-
"""
|
483 |
-
|
484 |
-
_, seq_len = token_ids.shape
|
485 |
-
|
486 |
-
# The number of sections is the same as the number of special_token_ids:
|
487 |
-
num_sections = len(special_token_ids)
|
488 |
-
|
489 |
-
sections = {k: [] for k in range(num_sections)}
|
490 |
-
|
491 |
-
for i in token_ids:
|
492 |
-
prev_col = 0
|
493 |
-
for j, k in enumerate(special_token_ids):
|
494 |
-
|
495 |
-
# The maximum sequence length was exceeded, thus no more tokens:
|
496 |
-
if prev_col >= seq_len:
|
497 |
-
sections[j].append('')
|
498 |
-
continue
|
499 |
-
|
500 |
-
# Find first occurrence of special tokens that indicate the boundary between sections:
|
501 |
-
col = (i == k).int().argmax().item()
|
502 |
-
|
503 |
-
# If equal to 0, token was not found, set the column to the sequence length (as the decoder exceeded
|
504 |
-
# the maximum sequence length):
|
505 |
-
if col == 0:
|
506 |
-
col = seq_len
|
507 |
-
|
508 |
-
# Extract section token identifiers:
|
509 |
-
section_token_ids = i[prev_col:col]
|
510 |
-
prev_col = col
|
511 |
-
section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)
|
512 |
-
|
513 |
-
sections[j].append(section_string)
|
514 |
-
|
515 |
-
return tuple(sections.values())
|
|
|
136 |
self.encoder.config = self.config.encoder
|
137 |
self.decoder.config = self.config.decoder
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
def forward(
|
140 |
self,
|
141 |
pixel_values: Optional[torch.FloatTensor] = None,
|
|
|
173 |
elif isinstance(encoder_outputs, tuple):
|
174 |
encoder_outputs = BaseModelOutput(*encoder_outputs)
|
175 |
|
|
|
|
|
|
|
|
|
|
|
176 |
embeddings = self.decoder.transformer.wte(decoder_input_ids)
|
177 |
embeddings = torch.cat([encoder_outputs[0], embeddings], dim=1)
|
178 |
|
|
|
217 |
decoder_attentions=decoder_outputs.attentions,
|
218 |
cross_attentions=decoder_outputs.cross_attentions,
|
219 |
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
|
|
|
|
220 |
)
|
221 |
|
222 |
+
def tokenize_captions_teacher_forcing(
|
223 |
+
self,
|
224 |
+
captions: str,
|
225 |
+
tokenizer: PreTrainedTokenizerFast,
|
226 |
+
max_len: int,
|
|
|
|
|
|
|
|
|
227 |
):
|
228 |
"""
|
229 |
+
Tokenizes the captions and creates the inputs and targets for teacher forcing.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
Argument/s:
|
232 |
+
captions - the captions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
tokenizer - Hugging Face tokenizer.
|
234 |
max_len - maximum number of tokens.
|
235 |
|
236 |
Returns:
|
237 |
+
batch_dict = {
|
238 |
+
decoder_input_ids - the token identifiers for the input of the decoder.
|
239 |
+
decoder_attention_mask - the attention mask for the decoder_input_ids.
|
240 |
+
decoder_token_type_ids - the token type identifiers for the decoder_input_ids.
|
241 |
+
label_ids - the label token identifiers for the decoder.
|
242 |
+
}
|
243 |
"""
|
244 |
|
245 |
+
# Prepare the caption for the tokenizer by placing the special tokens:
|
246 |
+
caption = [f'{tokenizer.bos_token}{i}{tokenizer.eos_token}' for i in captions]
|
|
|
247 |
|
248 |
+
# Tokenize the caption:
|
249 |
+
tokenized = self.tokenizer(
|
250 |
+
caption,
|
251 |
padding='longest',
|
252 |
truncation=True,
|
253 |
+
max_length=max_len + 1, # +1 to account for the shift between input and target.
|
254 |
return_tensors='pt',
|
255 |
return_token_type_ids=False,
|
256 |
+
add_special_tokens=False, # The BOS/EOS tokens are already placed around each caption above.
|
257 |
).to(self.device)
|
258 |
|
259 |
# Modify for language modelling:
|
|
|
269 |
'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
|
270 |
}
|
271 |
|
272 |
+
return batch_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|