"""
This script runs a cross-validation task, starting with two-fold validation:
the shuffled data set is split in half, the model is trained on each half and
evaluated on the other, and the pooled predictions are scored.
"""

import os
import random

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import roc_curve
from scipy.interpolate import interp1d
from scipy.optimize import brentq


def eer(x_test, y_test, model):
    """Equal error rate from a micro-averaged ROC over one-hot labels."""
    preds = model.predict(x_test)
    fpr, tpr, thresholds = roc_curve(np.asarray(y_test).ravel(), preds.ravel())
    return brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
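
# Note: eer() is not called anywhere below; the script recomputes the EER
# inline at the end instead. A minimal usage sketch, assuming the fold-2
# arrays X2 / y2 and the fold-1 checkpoint defined further down:
#
#     model.load_weights("./data-task0/train1.keras")
#     print("EER on fold 2:", eer(X2, y2, model))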


# Load the pre-extracted features and labels.
data = np.load("/home/fazhong/Github/czx/data.npy", allow_pickle=True)
labels = np.load("/home/fazhong/Github/czx/labels.npy", allow_pickle=True)

# Build one record per sample: [feature vector, one-hot class label,
# labels[i][0], labels[i][2]] (the last two fields are kept but unused below).
data_all = []
data = data.tolist()
labels = labels.tolist()
for i in range(len(data)):
    tmp = []
    tmp.append(np.array(data[i][0]))
    tmp.extend(to_categorical([labels[i][1]], num_classes=4).tolist())
    tmp.extend([labels[i][0]])
    tmp.extend([labels[i][2]])
    data_all.append(tmp)
random.shuffle(data_all)
data = data_all

# Training configuration.
batch_size = 10
feature_len = 110
loss_function = categorical_crossentropy
no_epochs = 150
optimizer = Adam()
verbosity = 1

# Small fully connected classifier over the fixed-length feature vector,
# ending in a 4-way softmax.
model = Sequential()
model.add(Dense(64, input_dim=feature_len, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(16, activation='relu'))
model.add(Dense(4, activation='softmax'))
model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])

# Keep a copy of the freshly initialised weights so each fold can be trained
# from the same starting point.
initial_weights = model.get_weights()

# Two-fold split: the first half of the shuffled data is fold 1, the rest is fold 2.
data_train = data[:int(0.5 * (len(data)))]
print(len(data_train))
X1 = np.asarray([x[0] for x in data_train])
print(X1.shape)
temp = [x[1] for x in data_train]
print(len(temp))
print(len(temp[1]))
y1 = np.asarray([x[1] for x in data_train])
print(y1.shape)

data_test = data[int(0.5 * (len(data))):]
X2 = np.asarray([x[0] for x in data_test])
y2 = np.asarray([x[1] for x in data_test])

# Make sure the checkpoint directory exists before training.
os.makedirs("./data-task0", exist_ok=True)

# Fold 1: train on the first half, checkpointing the best model by validation loss.
checkpointer = ModelCheckpoint(filepath="./data-task0/train1.keras",
                               verbose=verbosity, save_best_only=True)
print('-' * 30)
print('Training on fold 1')
history = model.fit(X1, y1,
                    validation_split=0.1,
                    batch_size=batch_size,
                    epochs=no_epochs,
                    verbose=verbosity,
                    callbacks=[checkpointer,
                               tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)])


# Fold 2: reset to the initial weights so this fold does not continue from the
# model already fitted on fold 1, then train on the second half.
model.set_weights(initial_weights)
checkpointer = ModelCheckpoint(filepath="./data-task0/train2.keras",
                               verbose=verbosity, save_best_only=True)
print('-' * 30)
print('Training on fold 2')
history = model.fit(X2, y2,
                    validation_split=0.1,
                    batch_size=batch_size,
                    epochs=no_epochs,
                    verbose=verbosity,
                    callbacks=[checkpointer,
                               tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)])

# Cross-evaluate: the model trained on fold 1 is scored on fold 2 and vice
# versa, then the out-of-fold predictions are pooled.
model.load_weights("./data-task0/train1.keras")
scores = model.evaluate(X2, y2)
y_pred2 = model.predict(X2)
print(y_pred2.shape)

model.load_weights("./data-task0/train2.keras")
scores = model.evaluate(X1, y1)
y_pred1 = model.predict(X1)

y_pred = np.concatenate((y_pred1, y_pred2))
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = np.concatenate((y1, y2))
y_label_classes = np.argmax(y_true, axis=1)
print(y_pred_classes)

ACCU = np.sum(y_pred_classes == y_label_classes) / len(y_label_classes)
print("ACCU is " + str(100 * ACCU))

# roc_curve expects binary targets, so compute a micro-averaged ROC over the
# flattened one-hot labels and softmax scores, then take the equal error rate.
fpr, tpr, thresholds = roc_curve(y_true.ravel(), y_pred.ravel())
EER = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
print("EER is " + str(EER))