|
import tensorflow as tf |
|
import copy |
|
import numpy as np |
|
|
|
|
|
def generate_from_saved():
    """Load a saved Keras model and generate three lines of text from a fixed seed.

    The tokenizer is not persisted with the model, so it is rebuilt from the
    original training corpus before generation.

    Returns:
        str: the generated lines, each terminated by a newline.
    """
    seed_text = 'я не глядя поддержу'
    weights_path = 'results/weights_lukash.h5'
    model_path = 'results/Lukashenko_tarakan'

    model = tf.keras.models.load_model(model_path)
    model.load_weights(weights_path)

    model.summary()

    # NOTE(review): encoding assumed UTF-8 for the Cyrillic corpus — confirm.
    with open('data/source_text_lukash.txt', 'r', encoding='utf-8') as source_text_file:
        data = source_text_file.read().splitlines()

    # BUG FIX: the original enumerated `data` while popping the same indices
    # from a copy; after the first removal the indices no longer lined up and
    # the wrong lines were dropped. A filtering comprehension is correct.
    data = [line for line in data if len(line) >= 5]
    sent_length = sum(len(line.split()) for line in data)
    # Average word count per kept line; guard against an empty corpus.
    lstm_length = int(sent_length / len(data)) if data else 0

    token = tf.keras.preprocessing.text.Tokenizer()
    token.fit_on_texts(data)
    encoded_text = token.texts_to_sequences(data)

    # Every prefix of length >= 2, as used during training; only needed here
    # to recover the input width the model expects.
    datalist = [d[:i] for d in encoded_text if len(d) > 1 for i in range(2, len(d))]

    max_length = 20
    sequences = tf.keras.preprocessing.sequence.pad_sequences(
        datalist, maxlen=max_length, padding='pre')

    # Dead code removed: the original also built one-hot targets
    # (`to_categorical`) that were never used during inference.
    X = sequences[:, :-1]
    seq_length = X.shape[1]
    print(f"Sequence length: {seq_length}")

    generated_text = ''
    number_lines = 3
    for _line in range(number_lines):
        text_word_list = []
        for _ in range(lstm_length * 2):
            encoded = token.texts_to_sequences([seed_text])
            encoded = tf.keras.preprocessing.sequence.pad_sequences(
                encoded, maxlen=seq_length, padding='pre')

            # argmax over the softmax output is the predicted word index;
            # take the scalar for batch element 0.
            y_pred = int(np.argmax(model.predict(encoded), axis=-1)[0])

            # O(1) reverse lookup instead of the original linear scan over
            # word_index on every generated word.
            predicted_word = token.index_word.get(y_pred, "")

            seed_text = seed_text + ' ' + predicted_word
            text_word_list.append(predicted_word)

        # Restart the next line from the last generated word (guarded against
        # an empty generation round, which crashed the original with IndexError).
        if text_word_list:
            seed_text = text_word_list[-1]
        # BUG FIX: the original used `=` here, so each loop iteration
        # overwrote the previous line and only the last one survived.
        generated_text += ' '.join(text_word_list) + '\n'

    print(f"Lukashenko are saying: {generated_text}")
    return generated_text
|
|