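# NOTE: The three lines below are residue from the hosting page this file was
# copied from (uploader avatar caption, commit message, commit id). They are
# kept as comments so that the module remains valid Python.
# NCTCMumbai's picture
# Upload 2583 files
# 97b6013 verified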
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core model definition of YAMNet."""
import csv
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers
import features as features_lib
import params
def _batch_norm(name):
    """Return a function that batch-normalizes its input with a named layer.

    The center/scale/epsilon settings of the BatchNormalization layer are
    read from the `params` module.

    Args:
      name: Name given to the BatchNormalization layer.

    Returns:
      A callable that takes a tensor and returns the normalized tensor. The
      layer itself is created each time the callable is invoked.
    """
    def _bn_layer(layer_input):
        normalizer = layers.BatchNormalization(
            name=name,
            epsilon=params.BATCHNORM_EPSILON,
            center=params.BATCHNORM_CENTER,
            scale=params.BATCHNORM_SCALE)
        return normalizer(layer_input)

    return _bn_layer
def _conv(name, kernel, stride, filters):
    """Return a function applying a Conv2D -> BatchNorm -> ReLU stack.

    Args:
      name: Prefix used to name the created layers ('<name>/conv',
        '<name>/conv/bn', '<name>/relu').
      kernel: Kernel size passed to Conv2D.
      stride: Stride passed to Conv2D.
      filters: Number of output channels of the convolution.

    Returns:
      A callable mapping an input tensor to the ReLU-activated output.
    """
    def _conv_layer(layer_input):
        # The convolution has no bias and no activation of its own; it is
        # followed by batch normalization and then an explicit ReLU layer.
        convolved = layers.Conv2D(
            name='{}/conv'.format(name),
            filters=filters,
            kernel_size=kernel,
            strides=stride,
            padding=params.CONV_PADDING,
            use_bias=False,
            activation=None)(layer_input)
        normalized = _batch_norm(name='{}/conv/bn'.format(name))(convolved)
        return layers.ReLU(name='{}/relu'.format(name))(normalized)

    return _conv_layer
def _separable_conv(name, kernel, stride, filters):
    """Return a function applying a depthwise-separable convolution block.

    The block is a depthwise stage (DepthwiseConv2D -> BatchNorm -> ReLU)
    followed by a pointwise stage (1x1 Conv2D -> BatchNorm -> ReLU).

    Args:
      name: Prefix used to name every created layer.
      kernel: Kernel size of the depthwise convolution.
      stride: Stride of the depthwise convolution (the pointwise conv uses 1).
      filters: Number of output channels of the pointwise convolution.

    Returns:
      A callable mapping an input tensor to the block's activated output.
    """
    def _separable_conv_layer(layer_input):
        # Depthwise stage: per-channel spatial filtering (depth_multiplier=1
        # keeps the channel count unchanged).
        depthwise = layers.DepthwiseConv2D(
            name='{}/depthwise_conv'.format(name),
            kernel_size=kernel,
            strides=stride,
            depth_multiplier=1,
            padding=params.CONV_PADDING,
            use_bias=False,
            activation=None)(layer_input)
        depthwise = _batch_norm(
            name='{}/depthwise_conv/bn'.format(name))(depthwise)
        depthwise = layers.ReLU(
            name='{}/depthwise_conv/relu'.format(name))(depthwise)

        # Pointwise stage: 1x1 convolution projecting to `filters` channels.
        pointwise = layers.Conv2D(
            name='{}/pointwise_conv'.format(name),
            filters=filters,
            kernel_size=(1, 1),
            strides=1,
            padding=params.CONV_PADDING,
            use_bias=False,
            activation=None)(depthwise)
        pointwise = _batch_norm(
            name='{}/pointwise_conv/bn'.format(name))(pointwise)
        return layers.ReLU(
            name='{}/pointwise_conv/relu'.format(name))(pointwise)

    return _separable_conv_layer
# Architecture table for the YAMNet trunk, consumed in order by `yamnet()`.
# Each entry is (layer_function, kernel, stride, num_filters). Layers with
# stride 2 downsample; there are five of them, for a cumulative stride of 32.
_YAMNET_LAYER_DEFS = [
    # layer_function   kernel  stride  num_filters
    (_conv,            [3, 3], 2,      32),
    (_separable_conv,  [3, 3], 1,      64),
    (_separable_conv,  [3, 3], 2,      128),
    (_separable_conv,  [3, 3], 1,      128),
    (_separable_conv,  [3, 3], 2,      256),
    (_separable_conv,  [3, 3], 1,      256),
    (_separable_conv,  [3, 3], 2,      512),
    (_separable_conv,  [3, 3], 1,      512),
    (_separable_conv,  [3, 3], 1,      512),
    (_separable_conv,  [3, 3], 1,      512),
    (_separable_conv,  [3, 3], 1,      512),
    (_separable_conv,  [3, 3], 1,      512),
    (_separable_conv,  [3, 3], 2,      1024),
    (_separable_conv,  [3, 3], 1,      1024),
]
def yamnet(features):
    """Define the core YAMNet model in Keras.

    Builds the convolutional trunk described by `_YAMNET_LAYER_DEFS` on top of
    `features`, then applies global average pooling, a dense classifier layer
    and the configured output activation.

    Args:
      features: Input tensor of feature patches. It is reshaped to
        (params.PATCH_FRAMES, params.PATCH_BANDS, 1) per example, so it is
        presumably shaped (batch, PATCH_FRAMES, PATCH_BANDS). Confirm this
        against the caller.

    Returns:
      Tensor of class scores with params.NUM_CLASSES units per example,
      produced by an Activation layer named
      params.EXAMPLE_PREDICTIONS_LAYER_NAME using
      params.CLASSIFIER_ACTIVATION.
    """
    # Add a trailing channel axis so the 2-D convolutions can consume the
    # patches. Since the layer is called on an existing tensor, `input_shape`
    # is presumably informational only; it is kept unchanged for config
    # compatibility.
    net = layers.Reshape(
        (params.PATCH_FRAMES, params.PATCH_BANDS, 1),
        input_shape=(params.PATCH_FRAMES, params.PATCH_BANDS))(features)
    # Layer names are 1-based ('layer1', 'layer2', ...).
    for (i, (layer_fun, kernel, stride, filters)) in enumerate(_YAMNET_LAYER_DEFS):
        net = layer_fun('layer{}'.format(i + 1), kernel, stride, filters)(net)
    net = layers.GlobalAveragePooling2D()(net)
    logits = layers.Dense(units=params.NUM_CLASSES, use_bias=True)(net)
    predictions = layers.Activation(
        name=params.EXAMPLE_PREDICTIONS_LAYER_NAME,
        activation=params.CLASSIFIER_ACTIVATION)(logits)
    return predictions
def yamnet_frames_model(feature_params):
    """Build the YAMNet waveform-to-class-scores model.

    Args:
      feature_params: Object with parameter fields controlling the feature
        calculation; it is forwarded to the `features_lib` helpers.

    Returns:
      A Keras Model named 'yamnet_frames' that accepts a (1, num_samples)
      waveform and emits two outputs:
        - a (num_patches, num_classes) matrix of class scores per time frame;
        - the (num_spectrogram_frames, num_mel_bins) spectrogram feature
          matrix the scores were computed from.
    """
    waveform = layers.Input(batch_shape=(1, None))
    # The batch dimension is fixed at 1 above; drop it to get a 1-D signal.
    signal = tf.squeeze(waveform, axis=0)
    # The spectrogram is also returned as a model output so it can be used
    # for visualization.
    log_mel = features_lib.waveform_to_log_mel_spectrogram(
        signal, feature_params)
    patches = features_lib.spectrogram_to_patches(log_mel, feature_params)
    scores = yamnet(patches)
    return Model(name='yamnet_frames',
                 inputs=waveform,
                 outputs=[scores, log_mel])
def class_names(class_map_csv):
    """Read the class name definition file and return its display names.

    The first row is treated as a header and skipped. Blank rows are ignored.
    Every other row must have exactly three columns, and the third column is
    the display name.

    Args:
      class_map_csv: Path to the class map CSV file.

    Returns:
      A 1-D NumPy array of display-name strings, in file order.

    Raises:
      ValueError: If a non-blank data row does not have exactly three columns.
    """
    # newline='' is what the csv module requires so that quoted fields are
    # parsed correctly; an explicit UTF-8 encoding avoids depending on the
    # platform's locale encoding.
    with open(class_map_csv, newline='', encoding='utf-8') as csv_file:
        reader = csv.reader(csv_file)
        next(reader)  # Skip header
        # csv.reader yields [] for blank lines (e.g. trailing ones); drop them
        # rather than failing the three-way unpacking below.
        rows = [row for row in reader if row]
    return np.array([display_name for (_, _, display_name) in rows])