# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ml_collections


def get_IRENE_config():
    """Return the IRENE (pathological classifier) model configuration.

    Builds and returns a fresh ``ml_collections.ConfigDict`` on every call, so
    callers may mutate the result without affecting other callers.

    Top-level fields:
        patches: ConfigDict with ``size`` = (16, 16).
        hidden_size: 512.
        transformer: ConfigDict with ``mlp_dim``, ``num_heads``,
            ``num_layers``, ``attention_dropout_rate`` and ``dropout_rate``.
        classifier: 'token'.
        representation_size: None.
        cc_len: 40.
        lab_len: 92.

    Returns:
        ml_collections.ConfigDict: the populated configuration.
    """
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 512

    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1024
    # hidden_size must be divisible by num_heads.
    config.transformer.num_heads = 1
    # 4 layers were used for the follow-up models trained on the other three
    # hospitals' data; the TCGA model used 2 layers instead.
    config.transformer.num_layers = 4
    # Values explored for tuning: 0.0 - 0.2.
    config.transformer.attention_dropout_rate = 0.2
    # Values explored for tuning: 0.1 - 0.3.
    config.transformer.dropout_rate = 0.3

    config.classifier = 'token'
    config.representation_size = None
    # NOTE(review): presumably the fixed sequence lengths for the "cc" and
    # "lab" input streams consumed by the model — confirm against the model code.
    config.cc_len = 40
    config.lab_len = 92
    return config