sweetfelinity commited on
Commit
b786a12
·
verified ·
1 Parent(s): a72f353

Initial commit

Browse files
Files changed (4) hide show
  1. PoemGen.keras +0 -0
  2. PoemGeneration.py +88 -0
  3. app.py +21 -0
  4. poem.txt +0 -0
PoemGen.keras ADDED
Binary file (856 kB). View file
 
PoemGeneration.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.models import Sequential
5
+ from tensorflow.keras.layers import LSTM, Dense, Activation
6
+ from tensorflow.keras.optimizers import RMSprop
7
+ import os
8
+
9
+ SEQ_LENGTH = 50
10
+ STEP_SIZE = 2
11
+
12
+ model_output_path = "PoemGen.keras"
13
+ input_text_path = "poem.txt"
14
+ generated_text_length = 700
15
+ retrain = False
16
+ batch_size = 256
17
+ epoch_num = 150
18
+
19
+ text = open(input_text_path, "rb").read().decode(encoding="utf-8").lower()
20
+ character_set = sorted(set(text))
21
+ char_to_index = dict((c, i) for i, c in enumerate(character_set))
22
+ index_to_char = dict((i, c) for i, c in enumerate(character_set))
23
+
24
+ class PreloadedRNNModel:
25
+ def __init__(self):
26
+ self.model = tf.keras.models.load_model(model_output_path)
27
+
28
+ def generate_text(self, temperature, output_length):
29
+ start_index = random.randint(0, len(text) - SEQ_LENGTH - 1)
30
+ sentence = text[start_index: start_index + SEQ_LENGTH]
31
+ generated_text = sentence
32
+
33
+ for _ in range(output_length):
34
+ x = np.zeros((1, SEQ_LENGTH, len(character_set)))
35
+ for j, character in enumerate(sentence):
36
+ x[0, j, char_to_index[character]] = 1
37
+
38
+ predictions = self.model.predict(x, verbose=0)[0]
39
+ next_index = self.sample(predictions, temperature)
40
+ next_character = index_to_char[next_index]
41
+
42
+ generated_text += next_character
43
+ sentence = sentence[1:] + next_character
44
+
45
+ return generated_text
46
+
47
+ def sample(self, preds, temperature):
48
+ preds = np.asarray(preds).astype("float64")
49
+ preds = np.log(preds) / temperature
50
+ exp_preds = np.exp(preds)
51
+ preds = exp_preds / np.sum(exp_preds)
52
+ probs = np.random.multinomial(1, preds, 1)
53
+ return np.argmax(probs)
54
+
55
+
56
+ # Create RNN model
57
+ if not os.path.exists(model_output_path) or retrain:
58
+ sentences = []
59
+ next_characters = []
60
+
61
+ for i in range(0, len(text) - SEQ_LENGTH, STEP_SIZE):
62
+ sentences.append(text[i: i + SEQ_LENGTH])
63
+ next_characters.append(text[i + SEQ_LENGTH])
64
+
65
+ x = np.zeros((len(sentences), SEQ_LENGTH, len(character_set)), dtype=np.bool_)
66
+ y = np.zeros((len(sentences), len(character_set)), dtype=np.bool_)
67
+
68
+ for i, sentence in enumerate(sentences):
69
+ for j, character in enumerate(sentence):
70
+ x[i, j, char_to_index[character]] = 1
71
+ y[i, char_to_index[next_characters[i]]] = 1
72
+
73
+ if not os.path.exists(model_output_path) or retrain:
74
+ model = Sequential()
75
+ model.add(LSTM(128, input_shape=(SEQ_LENGTH, len(character_set))))
76
+ model.add(Dense(len(character_set)))
77
+ model.add(Activation("softmax"))
78
+ model.compile(loss="categorical_crossentropy", optimizer=RMSprop(0.01))
79
+ model.fit(x, y, batch_size=batch_size, epochs=epoch_num)
80
+ model.save(model_output_path)
81
+ print("Model saved to path", model_output_path)
82
+
83
+ if __name__ == "__main__":
84
+ model = PreloadedRNNModel()
85
+
86
+ for temperature in [0.2, 0.4, 0.6, 0.8, 1]:
87
+ print("\nGenerated text with temperature: ", temperature)
88
+ print(model.generate_text(temperature, 500))
app.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PoemGeneration import PreloadedRNNModel
3
+
4
+ model = PreloadedRNNModel()
5
+
6
+ def generate_poem(temperature, output_length):
7
+ return model.generate_text(temperature, output_length)
8
+
9
+ with gr.Blocks() as demo:
10
+ with gr.Row():
11
+ with gr.Column():
12
+ temperature_slider = gr.Slider(0.01, 1, label="Temperature", value=0.5)
13
+ output_length = gr.Number(value=700, label="Output Length")
14
+ start_button = gr.Button(variant="primary")
15
+ examples = gr.Examples([0.2, 0.4, 0.6, 0.8, 1], temperature_slider)
16
+ with gr.Column():
17
+ output_text = gr.Text(label="Generated text", interactive=False)
18
+
19
+ start_button.click(fn=generate_poem, inputs=[temperature_slider, output_length], outputs=output_text)
20
+
21
+ demo.launch()
poem.txt ADDED
The diff for this file is too large to render. See raw diff