anicolson committed
Commit
7f9decd
Parent: 6eee4a2

Upload model

Files changed (1)
  1. modelling_single.py +9 -7
modelling_single.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Optional, Tuple, Union
 import torch
 import transformers
 from torch.nn import CrossEntropyLoss
-from transformers import VisionEncoderDecoderModel
+from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from transformers.modeling_utils import PreTrainedModel
@@ -94,6 +94,10 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
         decoder: Optional[PreTrainedModel] = None,
     ):
 
+        if decoder:
+            assert decoder.config.add_cross_attention, '"add_cross_attention" must be True for the given decoder'
+            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'
+
         if config is None and (encoder is None or decoder is None):
             raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
         if config is None:
@@ -132,9 +136,6 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
         self.encoder.config = self.config.encoder
         self.decoder.config = self.config.decoder
 
-        # config.add_cross_attention = True
-        # config.is_decoder = True
-
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
@@ -317,7 +318,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
 
         return token_type_ids
 
-    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer):
+    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
         """
         Tokenize the reports and creates the inputs and targets for teacher forcing.
 
@@ -326,6 +327,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
             impression - impression section.
             return_token_type_ids - return the token type identifiers.
             tokenizer - Hugging Face tokenizer.
+            max_len - maximum number of tokens.
 
         Returns:
             decoder_input_ids - the token identifiers for the input of the decoder.
@@ -342,7 +344,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
             report,
             padding='longest',
             truncation=True,
-            max_length=self.decoder_max_len + 1, # +1 to account for the bias between input and target.
+            max_length=max_len + 1, # +1 to account for the bias between input and target.
             return_tensors='pt',
             return_token_type_ids=False,
             add_special_tokens=False,
@@ -363,7 +365,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
 
         return batch_dict
 
-    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer):
+    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
         """
         Split the token identifiers into sections, then convert the token identifiers into strings.
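
A minimal sketch of how the updated interface would be called. The ViT/BERT sub-models, the tokenizer checkpoint, and the max_len value below are illustrative assumptions, not part of this commit, and the tokenizer is assumed to carry whatever special tokens the report sections require:

import transformers

from modelling_single import SingleCXREncoderDecoderModel

# The decoder must now be built with cross-attention enabled; otherwise the
# assertions added to __init__ fail at construction time.
decoder = transformers.BertLMHeadModel(
    transformers.BertConfig(is_decoder=True, add_cross_attention=True),
)
encoder = transformers.ViTModel(transformers.ViTConfig())
model = SingleCXREncoderDecoderModel(encoder=encoder, decoder=decoder)

# The maximum report length is now passed explicitly, rather than being read
# from self.decoder_max_len inside the method.
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')
batch = model.tokenize_report_teacher_forcing(
    findings='Heart size is normal. Lungs are clear.',
    impression='No acute cardiopulmonary abnormality.',
    tokenizer=tokenizer,
    max_len=256,  # assumed token budget, not a value from this commit.
)

Replacing the commented-out config flags with assertions surfaces a misconfigured decoder as a clear error at construction time, instead of a shape mismatch deep inside cross-attention.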