ZhongYing
first commit
503ec99
raw
history blame
2.59 kB
# coding=utf-8
# Copyright 2020 Beijing BluePulse Corp.
# Created by Zhang Guanqun on 2020/6/5
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from typing import Union, List
import unicodedata
def preprocess_paths(paths: Union[List, str]):
if isinstance(paths, list):
return [os.path.abspath(os.path.expanduser(path)) for path in paths]
return os.path.abspath(os.path.expanduser(paths)) if paths else None
def get_reduced_length(length, reduction_factor):
return tf.cast(tf.math.ceil(tf.divide(length, tf.cast(reduction_factor, dtype=length.dtype))), dtype=tf.int32)
def merge_two_last_dims(x):
b, _, f, c = shape_list(x)
return tf.reshape(x, shape=[b, -1, f * c])
def shape_list(x):
"""Deal with dynamic shape in tensorflow cleanly."""
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
# draw loss pic
def plot_metric(history, metric, pic_file_name):
train_metrics = history.history[metric]
val_metrics = history.history['val_'+metric]
epochs = range(1, len(train_metrics) + 1)
plt.plot(epochs, train_metrics, 'bo--')
plt.plot(epochs, val_metrics, 'ro-')
plt.title('Training and validation '+ metric)
plt.xlabel("Epochs")
plt.ylabel(metric)
plt.legend(["train_"+metric, 'val_'+metric])
plt.savefig(pic_file_name)
# against LAS loop decoding
def text_no_repeat(s):
repeat_times = 0
repeat_pattern = ''
for i in range(1, len(s) // 2):
pos = i
if s[0 - 2 * pos:0 - pos] == s[0 - i:]:
tmp_repeat_pattern = s[0 - i:]
tmp_repeat_times = 1
while pos * (tmp_repeat_times + 2) <= len(s) \
and s[0 - pos * (tmp_repeat_times + 2):0 - pos * (tmp_repeat_times + 1)] == s[0 - i:]:
tmp_repeat_times += 1
if tmp_repeat_times * len(tmp_repeat_pattern) > repeat_times * len(repeat_pattern):
repeat_times = tmp_repeat_times
repeat_pattern = tmp_repeat_pattern
# print(repeat_pattern, '*', repeat_times)
if len(repeat_pattern) != 1:
s = s[:0 - repeat_times * len(repeat_pattern)] if repeat_times > 0 else s
# print(s)
return s
# Converts the unicode file to ascii
def unicode_to_ascii(s):
return ''.join(c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn')
def log10(x):
numerator = tf.math.log(x)
denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
return numerator / denominator