File size: 2,591 Bytes
503ec99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# coding=utf-8
# Copyright 2020 Beijing BluePulse Corp.
# Created by Zhang Guanqun on 2020/6/5


import matplotlib.pyplot as plt
import os
import tensorflow as tf
from typing import Union, List
import unicodedata


def preprocess_paths(paths: Union[List, str]):
    if isinstance(paths, list):
        return [os.path.abspath(os.path.expanduser(path)) for path in paths]
    return os.path.abspath(os.path.expanduser(paths)) if paths else None


def get_reduced_length(length, reduction_factor):
    return tf.cast(tf.math.ceil(tf.divide(length, tf.cast(reduction_factor, dtype=length.dtype))), dtype=tf.int32)


def merge_two_last_dims(x):
    b, _, f, c = shape_list(x)
    return tf.reshape(x, shape=[b, -1, f * c])


def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


# draw loss pic
def plot_metric(history, metric, pic_file_name):
    train_metrics = history.history[metric]
    val_metrics = history.history['val_'+metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.plot(epochs, val_metrics, 'ro-')
    plt.title('Training and validation '+ metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_"+metric, 'val_'+metric])
    plt.savefig(pic_file_name)


# against LAS loop decoding
def text_no_repeat(s):
    repeat_times = 0
    repeat_pattern = ''
    for i in range(1, len(s) // 2):
        pos = i
        if s[0 - 2 * pos:0 - pos] == s[0 - i:]:
            tmp_repeat_pattern = s[0 - i:]
            tmp_repeat_times = 1
            while pos * (tmp_repeat_times + 2) <= len(s) \
                    and s[0 - pos * (tmp_repeat_times + 2):0 - pos * (tmp_repeat_times + 1)] == s[0 - i:]:
                tmp_repeat_times += 1
            if tmp_repeat_times * len(tmp_repeat_pattern) > repeat_times * len(repeat_pattern):
                repeat_times = tmp_repeat_times
                repeat_pattern = tmp_repeat_pattern
    # print(repeat_pattern, '*', repeat_times)
    if len(repeat_pattern) != 1:
        s = s[:0 - repeat_times * len(repeat_pattern)] if repeat_times > 0 else s
    # print(s)
    return s

# Converts the unicode file to ascii
def unicode_to_ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

def log10(x):
    numerator = tf.math.log(x)
    denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator