# gen_ai_simple_space / model_from_saved.py
# Author: NebulasBellum
# Last commit: "update with model" (c8ab7ef, verified)
import tensorflow as tf
import copy
import numpy as np
def _load_corpus(path, min_chars=5):
    """Return the lines of ``path`` with at least ``min_chars`` characters and their mean word count.

    Raises:
        ValueError: if no line is long enough (avoids a division by zero).
    """
    # NOTE(review): the tokenizer fitted on these lines must assign the same
    # word indices as the one used when the model was trained. If the training
    # script used the old pop-while-iterating filter, the kept lines (and so
    # the indices) may differ slightly — confirm against the training code.
    with open(path, 'r', encoding='utf-8') as source_text_file:
        lines = [line for line in source_text_file.read().splitlines()
                 if len(line) >= min_chars]
    if not lines:
        raise ValueError(f"No lines with at least {min_chars} characters in {path}")
    total_words = sum(len(line.split()) for line in lines)
    return lines, total_words // len(lines)


def _predict_next_word(model, token, text, seq_length):
    """Return the model's most likely next word after ``text`` ('' for the padding index)."""
    encoded = token.texts_to_sequences([text])
    # Default truncating='pre' keeps only the most recent seq_length tokens.
    encoded = tf.keras.preprocessing.sequence.pad_sequences(encoded, maxlen=seq_length, padding='pre')
    # Assumes predict returns shape (1, vocab_size) — TODO confirm against the model.
    predicted_index = int(np.argmax(model.predict(encoded, verbose=0), axis=-1)[0])
    # index_word is the inverse of word_index; index 0 (padding) has no word.
    return token.index_word.get(predicted_index, "")


def generate_from_saved():
    """Generate three lines of text from the saved Lukashenko-speech model.

    Loads the Keras model from ``results/Lukashenko_tarakan`` and the weights
    from ``results/weights_lukash.h5``, then prints the model summary. Rebuilds
    the tokenizer from ``data/source_text_lukash.txt``. Each generated line has
    twice the corpus's average words-per-line. After each line the seed becomes
    that line's last word.

    Returns:
        The generated lines joined together, each terminated by ``'\\n'``.
    """
    # Start generation from a fixed Russian seed phrase.
    seed_text = 'я не глядя поддержу'
    weights_path = 'results/weights_lukash.h5'
    model_path = 'results/Lukashenko_tarakan'
    model = tf.keras.models.load_model(model_path)
    # Presumably the .h5 file holds newer weights than the saved model — verify.
    model.load_weights(weights_path)
    # Show the model summary
    model.summary()

    data, avg_words = _load_corpus('data/source_text_lukash.txt')
    # Guard against 0 so every line yields at least one word.
    lstm_length = max(1, avg_words)
    token = tf.keras.preprocessing.text.Tokenizer()
    token.fit_on_texts(data)

    # Training sequences are padded to max_length and the last token is split
    # off as the target, so the model's input length is max_length - 1.
    max_length = 20
    seq_length = max_length - 1
    print(f"Sequence length: {seq_length}")

    number_lines = 3
    generated_lines = []
    for _line_no in range(number_lines):
        text_word_list = []
        for _ in range(lstm_length * 2):
            predicted_word = _predict_next_word(model, token, seed_text, seq_length)
            seed_text = seed_text + ' ' + predicted_word
            text_word_list.append(predicted_word)
        # Seed the next line with only the last generated word.
        seed_text = text_word_list[-1]
        generated_lines.append(' '.join(text_word_list) + '\n')
    generated_text = ''.join(generated_lines)
    print(f"Lukashenko are saying: {generated_text}")
    return generated_text