matth committed
Commit 173f9b2 · 1 Parent(s): 10bbf96

Upload Flowformer

Files changed (2):
  1. configuration_flowformer.py +33 -1
  2. model_flowformer.py +37 -11
configuration_flowformer.py CHANGED
@@ -1,8 +1,40 @@
 from transformers import PretrainedConfig
 
 class FlowformerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Flowformer`]. It is used to instantiate a
+    Flowformer model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a configuration similar to that of our model for ALL data (https://arxiv.org/abs/2108.10072).
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    Args:
+        dim_hidden (`int`, *optional*, defaults to 32):
+            The dimensionality of the hidden states. dim_hidden must be divisible by num_heads, i.e. dim_hidden % num_heads == 0.
+        num_heads (`int`, *optional*, defaults to 4):
+            The number of attention heads.
+        num_inds (`int`, *optional*, defaults to 16):
+            The number of inducing points.
+        hidden_layers (`int`, *optional*, defaults to 3):
+            The number of hidden layers.
+        layer_norm (`bool`, *optional*, defaults to True):
+            Whether to apply layer normalization.
+        dim_input (`int`, *optional*, defaults to 11):
+            The dimensionality of the input.
+        markers (`list`, *optional*, defaults to ["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]):
+            The list of markers.
+    Example:
+    ```python
+    >>> from transformers import FlowformerConfig, FlowformerModel
+    >>> # Initializing a Flowformer configuration
+    >>> configuration = FlowformerConfig()
+    >>> # Initializing a model (with random weights) from the Flowformer configuration
+    >>> model = FlowformerModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
     def __init__(self,
-                 dim_hidden: int=32, # dim_hidden must be divisible by num_heads i.e. dim_hidden%num_heads = 0
+                 dim_hidden: int=32,
                  num_heads: int=4,
                  num_inds: int=16,
                  hidden_layers: int=3,
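
For orientation, here is a minimal sketch of how the documented configuration arguments could be overridden, assuming `__init__` accepts all of the keyword arguments listed in the docstring above. The smaller marker panel, the `dim_input` choice, and the local import path are illustrative assumptions, not part of the commit.

```python
# Hypothetical local import; on the Hub this class is normally loaded via trust_remote_code.
from configuration_flowformer import FlowformerConfig

# Override a few documented defaults; this five-marker panel is a made-up example.
config = FlowformerConfig(
    dim_hidden=64,        # must stay divisible by num_heads
    num_heads=8,
    num_inds=16,
    hidden_layers=3,
    layer_norm=True,
    dim_input=5,          # kept equal to the number of markers below
    markers=["TIME", "FSC-A", "SSC-A", "CD19", "CD45"],
)
print(config.dim_hidden, config.markers)
```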
model_flowformer.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from torch.nn.functional import binary_cross_entropy_with_logits
 import math
 from transformers import PreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from .configuration_flowformer import FlowformerConfig
 
 
@@ -11,7 +12,7 @@ class MAB(nn.Module):
     """
     Multihead attention Block (MAB) from https://arxiv.org/abs/1810.00825.
     """
-    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
+    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool=False):
         super(MAB, self).__init__()
 
         self.dim_V = dim_V
@@ -47,7 +48,7 @@ class ISAB(nn.Module):
     """
     The Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.
     """
-    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
+    def __init__(self, dim_in: int, dim_out: int, num_heads: int, num_inds: int, ln: bool=False):
         super(ISAB, self).__init__()
 
         self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
@@ -60,8 +61,30 @@ class ISAB(nn.Module):
 
         return self.mab1(X, H)
 
+FLOWFORMER_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+    Parameters:
+        config ([`FlowformerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+FLOWFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        tensor (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_markers)`):
+            The sample used as a basis for the prediction.
+        labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Optional ground truth labels for computing the loss.
+        markers (`list` of length `num_markers`):
+            The list of markers in the same order as the last dimension of the input tensor.
+"""
+
+
+@add_start_docstrings(FLOWFORMER_START_DOCSTRING)
 class Flowformer(PreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: FlowformerConfig):
         super().__init__(config)
 
         # Load config
@@ -72,7 +95,7 @@ class Flowformer(PreTrainedModel):
         hidden_layers = config.hidden_layers
         layer_norm = config.layer_norm
         dim_output = 1
-        self._pretrained_markers = config.markers or ["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]
+        self._markers = config.markers
 
         # Define encoder
         enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
@@ -85,17 +108,18 @@ class Flowformer(PreTrainedModel):
         dec_layers = [nn.Linear(dim_input, dim_output)]
         self.dec = nn.Sequential(*dec_layers)
 
-    def pretrained_markers(self):
-        return self._pretrained_markers
+    def markers(self):
+        return self._markers
 
-    def forward(self, tensor, labels=None, markers: list=None):
+    @add_start_docstrings_to_model_forward(FLOWFORMER_INPUTS_DOCSTRING)
+    def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
         B, L, M = tensor.shape
         if markers is not None:
             assert len(markers) == M, "Number of markers in x and markers must be identical"
 
-            zeros = torch.zeros((B, L, len(self._pretrained_markers)), device=tensor.device)
-            valid_markers = [m for m in markers if m in set(self._pretrained_markers).intersection(markers)]
-            idx = [self._pretrained_markers.index(m) for m in valid_markers]
+            zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
+            valid_markers = [m for m in markers if m in set(self.markers()).intersection(markers)]
+            idx = [self.markers().index(m) for m in valid_markers]
             zeros[:, :, idx] = tensor  # select only the markers that are in the pretrained model
             tensor = zeros
 
@@ -105,10 +129,12 @@
         if labels is not None:
             return {
                 'loss': binary_cross_entropy_with_logits(output, labels),
-                'logits': output
+                'logits': output,
+                'prediction': torch.where(output > 0, 1, 0)
             }
         else:
             return {
-                'logits': output
+                'logits': output,
+                'prediction': torch.where(output > 0, 1, 0)
             }
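
To show how the reworked forward signature and the new 'prediction' output are meant to be used, here is a rough usage sketch with random weights. The dummy shapes, the marker subset, and the direct module imports are illustrative assumptions; the input and label shapes follow FLOWFORMER_INPUTS_DOCSTRING.

```python
import torch
# Hypothetical direct imports; in practice these files ship as Hub remote code.
from configuration_flowformer import FlowformerConfig
from model_flowformer import Flowformer

config = FlowformerConfig()   # default marker panel from the config
model = Flowformer(config)    # randomly initialized weights

B, L = 2, 1024                                   # 2 samples, 1024 events each
markers = ["TIME", "FSC-A", "SSC-A", "CD19"]     # illustrative subset of the pretrained panel
tensor = torch.randn(B, L, len(markers))         # (batch_size, sequence_length, num_markers)
labels = torch.randint(0, 2, (B, L)).float()     # (batch_size, sequence_length) ground truth

out = model(tensor, labels=labels, markers=markers)
out["loss"]        # binary cross-entropy with logits against the labels
out["logits"]      # raw per-event scores
out["prediction"]  # hard 0/1 calls via torch.where(logits > 0, 1, 0)
```

When markers is passed, the forward pass scatters the provided columns into a zero tensor ordered like the pretrained marker panel, so a sample measured with fewer or reordered markers can still be fed to the pretrained model.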