|
|
|
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt |
|
import os |
|
import tensorflow as tf |
|
from typing import Union, List |
|
import unicodedata |
|
|
|
|
|
def preprocess_paths(paths: Union[List, str]):
    """Expand ``~`` and make one path, or a list of paths, absolute.

    Args:
        paths: A single path string or a ``list`` of path strings.

    Returns:
        For a ``list`` input, a new list with every entry expanded and made
        absolute (an empty list yields an empty list). For any other input,
        the expanded absolute path, or ``None`` when ``paths`` is falsy
        (for example ``None`` or ``""``).
    """

    def _absolute(one_path):
        # Expand the user-home shorthand first, then resolve against the cwd.
        return os.path.abspath(os.path.expanduser(one_path))

    if isinstance(paths, list):
        return list(map(_absolute, paths))
    if not paths:
        return None
    return _absolute(paths)
|
|
|
|
|
def get_reduced_length(length, reduction_factor):
    """Return ``ceil(length / reduction_factor)`` as an ``int32`` tensor.

    Args:
        length: Value exposing a ``dtype`` attribute (the code reads
            ``length.dtype``), e.g. a tensor of sequence lengths.
        reduction_factor: Divisor; it is cast to ``length``'s dtype before
            the division.

    Returns:
        The rounded-up quotient cast to ``tf.int32``.
    """
    # Bring the divisor to the same dtype as the length before dividing.
    factor = tf.cast(reduction_factor, dtype=length.dtype)
    quotient = tf.divide(length, factor)
    rounded_up = tf.math.ceil(quotient)
    return tf.cast(rounded_up, dtype=tf.int32)
|
|
|
|
|
def merge_two_last_dims(x):
    """Collapse the last two axes of a rank-4 tensor into a single axis.

    Args:
        x: Tensor whose shape unpacks into exactly four dimensions.

    Returns:
        ``x`` reshaped to ``[d0, -1, d2 * d3]``; the second axis is left for
        ``tf.reshape`` to infer.
    """
    # The second dimension is ignored here and inferred by reshape via -1.
    lead, _, third, fourth = shape_list(x)
    target_shape = [lead, -1, third * fourth]
    return tf.reshape(x, shape=target_shape)
|
|
|
|
|
def shape_list(x):
    """Return the shape of ``x`` as a list, one entry per axis.

    Axes whose size is known statically are reported as the static value;
    axes whose static size is ``None`` are reported as the matching element
    of ``tf.shape(x)`` instead.
    """
    static_dims = x.shape.as_list()
    runtime_shape = tf.shape(x)

    resolved = []
    for axis, size in enumerate(static_dims):
        # Fall back to the runtime shape only for statically unknown axes.
        resolved.append(runtime_shape[axis] if size is None else size)
    return resolved
|
|
|
|
|
|
|
def plot_metric(history, metric, pic_file_name):
    """Plot a training metric next to its validation counterpart and save it.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences (presumably a Keras ``History`` object;
            confirm against callers). Both ``metric`` and ``'val_' + metric``
            are looked up in it.
        metric: Name of the metric to plot; also used for the title, the
            y-axis label and the legend entries.
        pic_file_name: Destination handed to ``plt.savefig``.
    """
    train_metrics = history.history[metric]
    val_metrics = history.history['val_' + metric]
    epochs = range(1, len(train_metrics) + 1)

    # Draw on a dedicated figure. Plotting onto the implicit current figure
    # made curves from earlier calls show up in later saved images.
    fig = plt.figure()
    try:
        plt.plot(epochs, train_metrics, 'bo--')
        plt.plot(epochs, val_metrics, 'ro-')
        plt.title('Training and validation ' + metric)
        plt.xlabel("Epochs")
        plt.ylabel(metric)
        plt.legend(["train_" + metric, 'val_' + metric])
        plt.savefig(pic_file_name)
    finally:
        # Always release the figure so repeated calls do not accumulate
        # open figures, even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
|
|
def text_no_repeat(s):
    """Strip a repeated trailing pattern from ``s``, keeping one copy.

    Every pattern width from 1 up to ``len(s) // 2`` is tried. For each
    width whose last two chunks are equal, the number of redundant copies
    at the end of ``s`` is counted. The candidate covering the most
    redundant characters wins; on a tie the shortest width is kept.

    A winning pattern of length 1 is deliberately left alone, so runs of a
    single repeated character are never shortened.

    Args:
        s: Sequence to clean up (sliced and compared; typically a string).

    Returns:
        ``s`` with the redundant trailing copies removed, or ``s`` unchanged
        when no multi-character trailing repetition is found.
    """
    best_extra = 0
    best_pattern = ''
    total = len(s)

    # The upper bound is inclusive: a pattern of width len(s) // 2 can still
    # occur twice. Excluding it missed inputs such as "abab" or "abcabc".
    for width in range(1, total // 2 + 1):
        pattern = s[-width:]
        if s[-2 * width:-width] != pattern:
            continue

        # `extra` counts copies beyond the one that is kept.
        extra = 1
        while (width * (extra + 2) <= total
               and s[-width * (extra + 2):-width * (extra + 1)] == pattern):
            extra += 1

        if extra * width > best_extra * len(best_pattern):
            best_extra = extra
            best_pattern = pattern

    if best_extra > 0 and len(best_pattern) != 1:
        s = s[:-best_extra * len(best_pattern)]

    return s
|
|
|
|
|
def unicode_to_ascii(s):
    """Remove combining marks from ``s`` after NFD normalization.

    The string is decomposed with ``unicodedata.normalize('NFD', ...)`` and
    every character whose Unicode category is ``'Mn'`` is dropped. Characters
    in any other category are kept as they are, so the result is not
    guaranteed to be pure ASCII despite the function's name.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)
|
|
|
def log10(x):
    """Return the base-10 logarithm of ``x`` using the change-of-base rule.

    Computes ``tf.math.log(x)`` and divides it by the natural log of 10,
    where the constant 10 is created with the same dtype as ``log(x)``.
    """
    natural_log = tf.math.log(x)
    # Build the constant in the result's dtype so the division types match.
    ten = tf.constant(10, dtype=natural_log.dtype)
    ln_ten = tf.math.log(ten)
    return natural_log / ln_ten