"""Parseval wide residual network builder (tf.keras)."""
import warnings

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Activation, Add, AveragePooling2D, BatchNormalization
from tensorflow.keras.layers import Convolution2D, Dense, Dropout, Flatten, Input, MaxPooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.regularizers import l2

from constraint import tight_frame
from convexity_constraint import convex_add

warnings.filterwarnings("ignore")
class ParsevalNetwork(Model):
    """Builder for a Parseval wide residual network.

    Convolutions use orthogonal initialization and a ``tight_frame`` kernel
    constraint, and residual branches are merged through ``convex_add``
    (a trainable convex combination) instead of a plain sum — presumably the
    Parseval-regularization scheme of the companion ``constraint`` /
    ``convexity_constraint`` modules (TODO: confirm against those modules).

    Call :meth:`create_wide_residual_network` to obtain the functional
    ``Model`` instance to train.
    """

    # Retraction strength passed to the tight-frame (Parseval) kernel constraint.
    TIGHT_FRAME_BETA = 0.001

    def __init__(
        self,
        input_dim,
        weight_decay,
        momentum,
        nb_classes=4,
        N=2,
        k=1,
        dropout=0.0,
        verbose=1,
    ):
        """Store the hyper-parameters of the wide residual network.

        Args:
            input_dim (tuple): input image dimensions (without batch axis).
            weight_decay (float): l2 regularization factor for conv/dense kernels.
            momentum (float): optimizer momentum (stored for later optimizer setup).
            nb_classes (int, optional): number of output classes. Defaults to 4.
            N (int, optional): number of residual blocks per stage. Defaults to 2.
            k (int, optional): network width multiplier. Defaults to 1.
            dropout (float, optional): dropout rate inside residual blocks
                (0.0 disables dropout). Defaults to 0.0.
            verbose (int, optional): if truthy, print a summary line when the
                model is created. Defaults to 1.
        """
        # Required before assigning attributes on a tf.keras.Model subclass;
        # the original code omitted this, which raises a RuntimeError in TF2.
        super().__init__()
        self.weight_decay = weight_decay
        self.input_dim = input_dim
        # Previously accepted but silently discarded — keep it for optimizer setup.
        self.momentum = momentum
        self.nb_classes = nb_classes
        self.N = N
        self.k = k
        self.dropout = dropout
        self.verbose = verbose

    def _batch_norm(self, x):
        """Apply BatchNormalization with the settings used throughout the net."""
        channel_axis = 1 if K.image_data_format() == "channels_first" else -1
        return BatchNormalization(
            axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
        )(x)

    def _conv(self, x, filters, kernel=(3, 3), strides=(1, 1)):
        """Parseval convolution: orthogonal init, l2 decay, tight-frame constraint."""
        return Convolution2D(
            filters,
            kernel,
            padding="same",
            strides=strides,
            kernel_initializer="orthogonal",
            kernel_regularizer=l2(self.weight_decay),
            kernel_constraint=tight_frame(self.TIGHT_FRAME_BETA),
            use_bias=False,
        )(x)

    def initial_conv(self, input):
        """Stem: 16-filter 3x3 Parseval conv followed by BN and ReLU.

        Args:
            input: input tensor (the ``Input`` placeholder).

        Returns:
            Output tensor of the stem.
        """
        x = self._conv(input, 16)
        x = self._batch_norm(x)
        return Activation("relu")(x)

    def expand_conv(self, init, base, k, strides=(1, 1)):
        """Projection block that changes width (and optionally resolution).

        Main path: conv -> BN -> ReLU -> conv; shortcut: 1x1 strided conv.
        The two paths are merged with a plain ``Add`` (not ``convex_add``),
        matching the original architecture.

        Args:
            init: input tensor.
            base (int): base filter count; actual width is ``base * k``.
            k (int): width multiplier.
            strides (tuple, optional): strides of the first conv and of the
                shortcut. Defaults to (1, 1).

        Returns:
            Output tensor of the block.
        """
        filters = base * k
        x = self._conv(init, filters, strides=strides)
        x = self._batch_norm(x)
        x = Activation("relu")(x)
        x = self._conv(x, filters)
        # 1x1 projection so the shortcut matches the main path's shape.
        skip = self._conv(init, filters, kernel=(1, 1), strides=strides)
        return Add()([x, skip])

    def _residual_block(self, input, filters, dropout):
        """Pre-activation residual block merged via a trainable convex combination.

        BN -> ReLU -> conv -> [dropout] -> BN -> ReLU -> conv, then
        ``convex_add`` of the input and the branch output.
        """
        x = self._batch_norm(input)
        x = Activation("relu")(x)
        x = self._conv(x, filters)
        if dropout > 0.0:
            x = Dropout(dropout)(x)
        x = self._batch_norm(x)
        x = Activation("relu")(x)
        x = self._conv(x, filters)
        return convex_add(input, x, initial_convex_par=0.5, trainable=True)

    def conv1_block(self, input, k=1, dropout=0.0):
        """Stage-1 residual block (16*k filters).

        Args:
            input: input tensor.
            k (int, optional): width multiplier. Defaults to 1.
            dropout (float, optional): dropout rate. Defaults to 0.0.

        Returns:
            Output tensor of the block.
        """
        return self._residual_block(input, 16 * k, dropout)

    def conv2_block(self, input, k=1, dropout=0.0):
        """Stage-2 residual block (32*k filters).

        Args:
            input: input tensor.
            k (int, optional): width multiplier. Defaults to 1.
            dropout (float, optional): dropout rate. Defaults to 0.0.

        Returns:
            Output tensor of the block.
        """
        return self._residual_block(input, 32 * k, dropout)

    def conv3_block(self, input, k=1, dropout=0.0):
        """Stage-3 residual block (64*k filters).

        Args:
            input: input tensor.
            k (int, optional): width multiplier. Defaults to 1.
            dropout (float, optional): dropout rate. Defaults to 0.0.

        Returns:
            Output tensor of the block.
        """
        return self._residual_block(input, 64 * k, dropout)

    def create_wide_residual_network(self):
        """Create the wide residual network model.

        Returns:
            Model: the assembled functional Keras model.
        """
        ip = Input(shape=self.input_dim)
        x = self.initial_conv(ip)
        nb_conv = 4  # conv-layer counter used only for the summary message

        # (base filters, block builder, strides of the stage's expand conv)
        stages = (
            (16, self.conv1_block, (1, 1)),
            (32, self.conv2_block, (2, 2)),
            (64, self.conv3_block, (2, 2)),
        )
        for base, block, strides in stages:
            x = self.expand_conv(x, base, self.k, strides=strides)
            nb_conv += 2
            for _ in range(self.N - 1):
                x = block(x, self.k, self.dropout)
                nb_conv += 2
            x = self._batch_norm(x)
            x = Activation("relu")(x)

        x = AveragePooling2D((8, 8))(x)
        x = Flatten()(x)
        x = Dense(
            self.nb_classes,
            kernel_regularizer=l2(self.weight_decay),
            activation="softmax",
        )(x)

        model = Model(ip, x)
        if self.verbose:
            print("Parseval Network-%d-%d created." % (nb_conv, self.k))
        return model