import tensorflow as tf
from tensorflow.keras import layers, Input, Model


def build_dce_net() -> Model:
    """Build a seven-layer fully convolutional network with symmetric skips.

    Layout:
      * conv1..conv4: four 3x3 convolutions (32 filters, ReLU) applied in
        sequence to the input image.
      * conv5..conv7: each takes its predecessor concatenated (channel axis)
        with a mirrored early feature map: [conv4, conv3] -> conv5,
        [conv5, conv2] -> conv6, [conv6, conv1] -> conv7.
      * conv5 and conv6 use 32 filters with ReLU. The final conv7 uses 24
        filters with tanh.

    All convolutions use stride 1 and "same" padding, so height and width are
    preserved throughout.

    NOTE(review): the name suggests this is the DCE-Net of Zero-DCE, where the
    24 output channels would be 8 curve iterations x 3 RGB channels. That
    interpretation is not established here; confirm against the training or
    enhancement code.

    Returns:
        A functional Keras ``Model`` mapping an input of shape
        ``(batch, H, W, 3)`` (H and W unconstrained) to an output of shape
        ``(batch, H, W, 24)`` with values in ``(-1, 1)`` (tanh range).
    """

    def same_conv(filters, activation):
        # A 3x3 kernel with stride 1 and "same" padding keeps H and W fixed.
        # That is what makes the channel-wise concatenations below valid.
        return layers.Conv2D(
            filters, (3, 3), strides=(1, 1), activation=activation, padding="same"
        )

    image = Input(shape=[None, None, 3])

    # Stage 1: plain stack of four ReLU convolutions.
    # Keep each output, because stage 2 reuses them as skip inputs.
    stem_outputs = []
    features = image
    for _ in range(4):
        features = same_conv(32, "relu")(features)
        stem_outputs.append(features)
    first, second, third, _fourth = stem_outputs

    # Stage 2: concatenate the running features with the mirrored stem output
    # (third, then second, then first), then convolve. The last step emits the
    # 24-channel tanh output.
    skip_plan = (
        (third, 32, "relu"),
        (second, 32, "relu"),
        (first, 24, "tanh"),
    )
    for skip, filters, activation in skip_plan:
        merged = layers.Concatenate(axis=-1)([features, skip])
        features = same_conv(filters, activation)(merged)

    return Model(inputs=image, outputs=features)