Upload model
modelling_single.py CHANGED (+9 -7)
@@ -4,7 +4,7 @@ from typing import Any, Optional, Tuple, Union
 import torch
 import transformers
 from torch.nn import CrossEntropyLoss
-from transformers import VisionEncoderDecoderModel
+from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from transformers.modeling_utils import PreTrainedModel
@@ -94,6 +94,10 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
         decoder: Optional[PreTrainedModel] = None,
     ):

+        if decoder:
+            assert decoder.config.add_cross_attention, '"add_cross_attention" must be True for the given decoder'
+            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'
+
         if config is None and (encoder is None or decoder is None):
             raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
         if config is None:
@@ -132,9 +136,6 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
         self.encoder.config = self.config.encoder
         self.decoder.config = self.config.decoder

-        # config.add_cross_attention = True
-        # config.is_decoder = True
-
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
@@ -317,7 +318,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):

         return token_type_ids

-    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer):
+    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
         """
         Tokenize the reports and creates the inputs and targets for teacher forcing.

@@ -326,6 +327,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
             impression - impression section.
             return_token_type_ids - return the token type identifiers.
             tokenizer - Hugging Face tokenizer.
+            max_len - maximum number of tokens.

         Returns:
             decoder_input_ids - the token identifiers for the input of the decoder.
@@ -342,7 +344,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
             report,
             padding='longest',
             truncation=True,
-            max_length=
+            max_length=max_len + 1,  # +1 to account for the bias between input and target.
             return_tensors='pt',
             return_token_type_ids=False,
             add_special_tokens=False,
@@ -363,7 +365,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):

         return batch_dict

-    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer):
+    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
         """
         Split the token identifiers into sections, then convert the token identifiers into strings.

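
The new assertions in the constructor mean that any decoder handed to SingleCXREncoderDecoderModel must already be configured for cross-attention, rather than being silently patched inside the model (the old, commented-out `# config.add_cross_attention = True` / `# config.is_decoder = True` lines are removed in the same commit). A minimal sketch of what that looks like on the caller's side, assuming the class is importable from a local copy of modelling_single.py and using an arbitrary ViT/GPT-2 pairing purely for illustration (the diff does not specify which architectures are actually used):

```python
import transformers

from modelling_single import SingleCXREncoderDecoderModel  # assumes a local copy of this file

# Hypothetical encoder/decoder choices, used only to illustrate the new assertions.
encoder = transformers.ViTModel(transformers.ViTConfig())

decoder_config = transformers.GPT2Config()
decoder_config.add_cross_attention = True  # must be True, or the new assertion fails
decoder_config.is_decoder = True           # must be True, or the new assertion fails
decoder = transformers.GPT2LMHeadModel(decoder_config)

model = SingleCXREncoderDecoderModel(encoder=encoder, decoder=decoder)
```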
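The `max_length=max_len + 1` change ("the bias between input and target") refers to the one-token shift of teacher forcing: the tokenized report needs max_len + 1 tokens so that the decoder input (which drops the last token) and the target (which drops the first) both end up max_len tokens long. A toy illustration of that shift, with made-up token identifiers and variable names rather than anything produced by tokenize_report_teacher_forcing itself:

```python
import torch

max_len = 4
# A report tokenized with max_length=max_len + 1, e.g. [BOS] t1 t2 t3 [EOS].
token_ids = torch.tensor([[101, 7, 8, 9, 102]])

decoder_input_ids = token_ids[:, :-1]  # [[101, 7, 8, 9]] -> max_len tokens fed to the decoder
label_ids = token_ids[:, 1:]           # [[7, 8, 9, 102]] -> max_len tokens used as targets
```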