import tensorflow as tf
import tensorflow.keras as k

# UNet 3+ model family: plain, with deep supervision, and with deep
# supervision plus a classification-guided module (CGM).
#
# All variants use a 5-stage encoder (4 max-pooling steps), so the spatial
# input dimensions should be divisible by 16; otherwise the resampled
# full-scale skip branches do not line up and concatenation fails.

# Encoder widths; entry 0 is also the width of every full-scale skip branch.
_FILTERS = (64, 128, 256, 512, 1024)
_CAT_CHANNELS = _FILTERS[0]
# Width of each fused decoder stage: one cat_channels branch per level.
_UPSAMPLE_CHANNELS = len(_FILTERS) * _CAT_CHANNELS


def conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same',
               is_bn=True, is_relu=True, n=2):
    """Apply ``n`` stacked Conv2D layers, each optionally followed by BN and ReLU.

    Args:
        x: Input tensor.
        kernels: Number of filters of every Conv2D layer.
        kernel_size: Convolution window size (3x3 by default).
        strides: Convolution strides, applied to *every* one of the ``n``
            convolutions.
        padding: Conv2D padding mode.
        is_bn: Append a BatchNormalization layer after each convolution.
        is_relu: Apply ReLU after each convolution (after BN when enabled).
        n: Number of Conv[-BN][-ReLU] repetitions.

    Returns:
        The output tensor of the last repetition.
    """
    for _ in range(n):
        # NOTE(review): every layer gets a fresh he_normal initializer with the
        # same fixed seed, so layers with identical kernel shapes presumably
        # start from identical weights - confirm this is intended.
        x = k.layers.Conv2D(
            filters=kernels,
            kernel_size=kernel_size,
            padding=padding,
            strides=strides,
            kernel_regularizer=tf.keras.regularizers.l2(1e-4),
            kernel_initializer=k.initializers.he_normal(seed=5),
        )(x)
        if is_bn:
            x = k.layers.BatchNormalization()(x)
        if is_relu:
            x = k.activations.relu(x)
    return x


def dotProduct(seg, cls):
    """Scale segmentation maps per sample by a classification gate.

    Args:
        seg: Rank-4 tensor ``(batch, H, W, N)``.
        cls: Rank-2 tensor ``(batch, C)``. Use ``C == 1`` for one gate shared
            by all ``N`` channels, or ``C == N`` for one gate per channel.

    Returns:
        ``seg`` multiplied by ``cls`` broadcast over the spatial axes, with
        shape ``(batch, H, W, N)``.
    """
    # Insert singleton H/W axes so cls broadcasts against seg. The earlier
    # reshape + einsum("ijk,ik->ijk") form needed static H/W and failed
    # whenever C != N, which is the case for the CGM gate with N > 1.
    gate = tf.cast(cls, seg.dtype)[:, tf.newaxis, tf.newaxis, :]
    return seg * gate


def _encoder(input_layer, filters):
    """Build the encoder: one conv_block per stage, 2x2 max-pooling in between.

    Returns:
        A list ``[e1, ..., e5]`` with one feature map per entry of ``filters``,
        from finest to coarsest. For a 320x320 input these are
        320x320x64, 160x160x128, 80x80x256, 40x40x512 and 20x20x1024.
    """
    x = conv_block(input_layer, filters[0])
    features = [x]
    for n_filters in filters[1:]:
        x = k.layers.MaxPool2D(pool_size=(2, 2))(x)
        x = conv_block(x, n_filters)
        features.append(x)
    return features


def _decoder(encoder_features, cat_channels, upsample_channels):
    """Build the UNet 3+ decoder with full-scale skip connections.

    Decoder stage ``d_i`` (built in the order d4, d3, d2, d1) concatenates one
    branch per level ``j`` of the five levels. Each branch is resampled to
    level ``i``'s resolution and projected to ``cat_channels`` by a single
    conv_block:

      * ``j < i``: encoder map ``e_j``, max-pooled by ``2**(i - j)``;
      * ``j == i``: encoder map ``e_i``, unchanged;
      * ``j > i``: the already-built decoder map ``d_j`` (or the bottleneck
        ``e5`` for the coarsest level), bilinearly upsampled by ``2**(j - i)``.

    The concatenation (``5 * cat_channels`` wide) is fused by one conv_block
    of ``upsample_channels`` filters.

    Returns:
        ``[d1, d2, d3, d4, e5]``: decoder maps from finest to coarsest,
        followed by the untouched bottleneck map.
    """
    # levels[j] holds the map currently serving as level j's source. Going
    # from the coarsest stage down, each finished decoder map replaces the
    # encoder map at its level, so later stages pick up d_j instead of e_j.
    levels = list(encoder_features)
    for i in reversed(range(len(levels) - 1)):
        branches = []
        for j, x in enumerate(levels):
            factor = 2 ** abs(i - j)
            if j < i:
                x = k.layers.MaxPool2D(pool_size=(factor, factor))(x)
            elif j > i:
                x = k.layers.UpSampling2D(size=(factor, factor),
                                          interpolation='bilinear')(x)
            branches.append(conv_block(x, cat_channels, n=1))
        fused = k.layers.concatenate(branches)
        levels[i] = conv_block(fused, upsample_channels, n=1)
    return levels


def _output_heads(features, output_channels):
    """Project feature maps to ``output_channels`` logits at input resolution.

    ``features[i]`` is expected at 1/2**i of the input resolution and is
    bilinearly upsampled by ``2**i``. The first map is left as is.
    """
    # last layer does not have batchnorm and relu
    logits = [conv_block(f, output_channels, n=1, is_bn=False, is_relu=False)
              for f in features]
    return [logits[0]] + [
        k.layers.UpSampling2D(size=(2 ** i, 2 ** i), interpolation='bilinear')(x)
        for i, x in enumerate(logits[1:], start=1)
    ]


def _final_activation(x, output_channels):
    """Apply sigmoid for a single output channel, otherwise softmax over channels."""
    if output_channels == 1:
        return k.activations.sigmoid(x)
    return k.activations.softmax(x)


def _classification_gate(bottleneck):
    """Classification-guided module, part 1: derive a per-sample 0/1 gate.

    Returns:
        A float32 tensor ``(batch, 1)`` holding the argmax over two class
        scores computed from the bottleneck map.
    """
    cls = k.layers.Dropout(rate=0.5)(bottleneck)
    cls = k.layers.Conv2D(2, kernel_size=(1, 1), padding="same", strides=(1, 1))(cls)
    cls = k.layers.GlobalMaxPooling2D()(cls)
    cls = k.activations.sigmoid(cls)  # monotonic, so it does not change the argmax
    # NOTE(review): argmax has no gradient, and no model output depends on
    # cls except through this gate. The Dropout/Conv2D branch above therefore
    # appears to receive no training signal. Confirm whether an auxiliary
    # classification output/loss is intended.
    cls = tf.argmax(cls, axis=-1)
    cls = cls[..., tf.newaxis]
    return tf.cast(cls, dtype=tf.float32)


def _build_model(input_layer, outputs, name, pretrained_weights):
    """Wrap the graph in a ``tf.keras.Model`` and optionally load weights."""
    model = tf.keras.Model(inputs=input_layer, outputs=outputs, name=name)
    if pretrained_weights:
        model.load_weights(pretrained_weights)
    return model


def UNet_3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights=None):
    """Build the plain UNet 3+ model with a single full-resolution output.

    Args:
        INPUT_SHAPE: Input shape without the batch axis, e.g. ``(320, 320, 3)``.
            Spatial dims should be divisible by 16.
        OUTPUT_CHANNELS: Number of output channels. A value of 1 uses sigmoid,
            any other value uses softmax.
        pretrained_weights: Optional path passed to ``model.load_weights``.

    Returns:
        A ``tf.keras.Model`` named ``'UNet_3Plus'``.
    """
    input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")
    features = _decoder(_encoder(input_layer, _FILTERS),
                        _CAT_CHANNELS, _UPSAMPLE_CHANNELS)
    # last layer does not have batchnorm and relu
    logits = conv_block(features[0], OUTPUT_CHANNELS, n=1,
                        is_bn=False, is_relu=False)
    output = _final_activation(logits, OUTPUT_CHANNELS)
    return _build_model(input_layer, output, 'UNet_3Plus', pretrained_weights)


def UNet_3Plus_DeepSup(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights=None):
    """Build UNet 3+ with deep supervision.

    The model has five full-resolution outputs, derived from d1, d2, d3, d4
    and e5 in that order. Arguments are as for :func:`UNet_3Plus`.

    Returns:
        A ``tf.keras.Model`` named ``'UNet_3Plus_DeepSup'``.
    """
    input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")
    features = _decoder(_encoder(input_layer, _FILTERS),
                        _CAT_CHANNELS, _UPSAMPLE_CHANNELS)
    heads = _output_heads(features, OUTPUT_CHANNELS)
    outputs = [_final_activation(h, OUTPUT_CHANNELS) for h in heads]
    return _build_model(input_layer, outputs, 'UNet_3Plus_DeepSup',
                        pretrained_weights)


def UNet_3Plus_DeepSup_CGM(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights=None):
    """Build UNet 3+ with deep supervision and a classification-guided module.

    The model has the same five outputs as :func:`UNet_3Plus_DeepSup`. Each
    output's logits are multiplied by a per-sample 0/1 gate predicted from
    the bottleneck before the final activation. Arguments are as for
    :func:`UNet_3Plus`.

    Returns:
        A ``tf.keras.Model`` named ``'UNet_3Plus_DeepSup_CGM'``.
    """
    input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")
    encoder_features = _encoder(input_layer, _FILTERS)
    cls = _classification_gate(encoder_features[-1])
    features = _decoder(encoder_features, _CAT_CHANNELS, _UPSAMPLE_CHANNELS)
    heads = _output_heads(features, OUTPUT_CHANNELS)
    # Classification-guided module, part 2: gate the logits of every output.
    gated = [dotProduct(h, cls) for h in heads]
    outputs = [_final_activation(g, OUTPUT_CHANNELS) for g in gated]
    return _build_model(input_layer, outputs, 'UNet_3Plus_DeepSup_CGM',
                        pretrained_weights)