Delete pages/19_RNN_LSTM_Shakespeare.py
pages/19_RNN_LSTM_Shakespeare.py  +0 -113
DELETED
@@ -1,113 +0,0 @@
-import streamlit as st
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import numpy as np
-
-# Define the RNN or LSTM Model
-class LSTMModel(nn.Module):
-    def __init__(self, input_size, hidden_size, output_size, num_layers):
-        super(LSTMModel, self).__init__()
-        self.hidden_size = hidden_size
-        self.num_layers = num_layers
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
-        self.fc = nn.Linear(hidden_size, output_size)
-
-    def forward(self, x, h):
-        out, h = self.lstm(x, h)
-        out = self.fc(out[:, -1, :])
-        return out, h
-
-# Text generation function
-def generate_text(model, start_str, length, char_to_int, int_to_char, num_layers, hidden_size):
-    model.eval()
-    input_seq = [char_to_int[c] for c in start_str]
-    input_seq = torch.tensor(input_seq, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
-    h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
-    generated_text = start_str
-
-    for _ in range(length):
-        output, h = model(input_seq, h)
-        _, predicted = torch.max(output, 1)
-        predicted_char = int_to_char[predicted.item()]
-        generated_text += predicted_char
-        input_seq = torch.tensor([char_to_int[predicted_char]], dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
-
-    return generated_text
-
-# Streamlit interface
-st.title("RNN/LSTM Text Generation")
-
-# Inputs
-text_data = st.text_area("Enter your text data for training:", "To be, or not to be, that is the question:\nWhether 'tis nobler in the mind to suffer\nThe slings and arrows of outrageous fortune,\nOr to take arms against a sea of troubles\nAnd by opposing end them. To die: to sleep;\nNo more; and by a sleep to say we end\nThe heart-ache and the thousand natural shocks\nThat flesh is heir to, 'tis a consummation\nDevoutly to be wish'd. To die, to sleep;\nTo sleep: perchance to dream: ay, there's the rub;\nFor in that sleep of death what dreams may come\nWhen we have shuffled off this mortal coil,\nMust give us pause: there's the respect\nThat makes calamity of so long life;")
-start_string = st.text_input("Enter the start string for text generation:")
-seq_length = st.number_input("Sequence length:", min_value=10, value=100)
-hidden_size = st.number_input("Hidden size:", min_value=50, value=256)
-num_layers = st.number_input("Number of layers:", min_value=1, value=2)
-learning_rate = st.number_input("Learning rate:", min_value=0.0001, value=0.003, format="%.4f")
-num_epochs = st.number_input("Number of epochs:", min_value=1, value=20)
-generate_length = st.number_input("Generated text length:", min_value=50, value=500)
-
-if st.button("Train and Generate"):
-    # Data Preparation
-    text = text_data
-    if len(text) <= seq_length:
-        st.error("Text data is too short for the given sequence length. Please enter more text data.")
-    else:
-        chars = sorted(list(set(text)))
-        char_to_int = {c: i for i, c in enumerate(chars)}
-        int_to_char = {i: c for i, c in enumerate(chars)}
-
-        # Prepare input-output pairs
-        dataX = []
-        dataY = []
-        for i in range(0, len(text) - seq_length):
-            seq_in = text[i:i + seq_length]
-            seq_out = text[i + seq_length]
-            dataX.append([char_to_int[char] for char in seq_in])
-            dataY.append(char_to_int[seq_out])
-
-        if len(dataX) == 0:
-            st.error("Not enough data to create input-output pairs. Please provide more text data.")
-        else:
-            X = np.reshape(dataX, (len(dataX), seq_length, 1))
-            X = X / float(len(chars))
-            Y = np.array(dataY)
-
-            # Convert to PyTorch tensors
-            X_tensor = torch.tensor(X, dtype=torch.float32)
-            Y_tensor = torch.tensor(Y, dtype=torch.long)
-
-            # Model initialization
-            model = LSTMModel(input_size=1, hidden_size=hidden_size, output_size=len(chars), num_layers=num_layers)
-
-            # Loss and optimizer
-            criterion = nn.CrossEntropyLoss()
-            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-
-            # Training the model
-            for epoch in range(num_epochs):
-                h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
-                epoch_loss = 0
-                for i in range(len(dataX)):
-                    inputs = X_tensor[i].unsqueeze(0)  # Shape: (1, seq_length, 1)
-                    targets = Y_tensor[i].unsqueeze(0)  # Shape: (1,)
-
-                    # Forward pass
-                    outputs, h = model(inputs, (h[0].detach(), h[1].detach()))
-                    loss = criterion(outputs, targets)
-
-                    # Backward pass and optimization
-                    optimizer.zero_grad()
-                    loss.backward()
-                    optimizer.step()
-
-                    epoch_loss += loss.item()
-
-                avg_loss = epoch_loss / len(dataX)
-                st.write(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')
-
-            # Text generation
-            generated_text = generate_text(model, start_string, generate_length, char_to_int, int_to_char, num_layers, hidden_size)
-            st.subheader("Generated Text")
-            st.write(generated_text)