# pytorch / pages/17_RNN.py
# Author: eaglelandsonce — "Update pages/17_RNN.py" (commit 68c5fd5, verified; 4.02 kB)
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
# Introduction
st.title('RNN Character Prediction')
st.write("""
This app demonstrates how to train a Recurrent Neural Network (RNN) to predict the next character in a given string.
The RNN learns the sequence of characters from a provided text and generates text based on the learned patterns.
You can choose different options for training the RNN and see how it affects the generated text.
""")
# User input for the training string
sequence = st.text_area('Enter the training string:', 'In the vast expanse of the digital realm.')
# Vocabulary: the distinct characters of the training string. They are sorted
# so that the character <-> index mapping is the same on every run; iterating a
# raw set of str yields an order that depends on the per-process hash seed.
chars = sorted(set(sequence))
data_size, vocab_size = len(sequence), len(chars)
# Create mappings from characters to indices and vice versa
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))
# The training string expressed as vocabulary indices
indices = np.array([char_to_idx[ch] for ch in sequence])
class RNN(nn.Module):
    """Minimal Elman-style recurrent cell for next-character prediction.

    At every step the input row is concatenated with the previous hidden
    state; one linear layer produces the next hidden state and another
    produces log-probabilities over the output vocabulary.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run a single time step.

        Args:
            input: ``(batch, input_size)`` tensor (a one-hot row in this app).
            hidden: ``(batch, hidden_size)`` hidden state from the previous step.

        Returns:
            ``(log_probs, new_hidden)`` of shapes ``(batch, output_size)`` and
            ``(batch, hidden_size)``.
        """
        combined = torch.cat((input, hidden), 1)
        # tanh keeps the recurrent state bounded in [-1, 1]. Without it the
        # hidden state is an unbounded linear function of every past input and
        # can grow without limit over long sequences.
        hidden = torch.tanh(self.i2h(combined))
        output = self.softmax(self.i2o(combined))
        return output, hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero hidden state of shape ``(batch_size, hidden_size)``.

        The tensor is created on the same device and with the same dtype as
        the model's parameters, so it can be fed straight into ``forward``.
        """
        weight = self.i2h.weight
        return torch.zeros(batch_size, self.hidden_size, device=weight.device, dtype=weight.dtype)
# Hyperparameters
n_hidden = 128
learning_rate = 0.005
# Streamlit re-runs this whole script on every widget interaction. Building the
# model unconditionally here would replace the trained network with a fresh,
# untrained one as soon as "Generate Text" is clicked. The model and its
# optimizer therefore live in st.session_state. They are rebuilt only when the
# vocabulary or a hyperparameter changes, because that changes the model's
# input/output sizes and the meaning of each one-hot index.
_model_key = (tuple(chars), n_hidden, learning_rate)
if st.session_state.get('rnn_model_key') != _model_key:
    _model = RNN(vocab_size, n_hidden, vocab_size)
    st.session_state['rnn_model'] = _model
    st.session_state['rnn_optimizer'] = torch.optim.Adam(_model.parameters(), lr=learning_rate)
    st.session_state['rnn_model_key'] = _model_key
# Initialize the model, loss function, and optimizer
rnn = st.session_state['rnn_model']
optimizer = st.session_state['rnn_optimizer']
criterion = nn.NLLLoss()
# Define training options
options = {
    'Quick Train (100 epochs)': 100,
    'Medium Train (500 epochs)': 500,
    'Long Train (1000 epochs)': 1000,
}
train_option = st.selectbox('Select training option:', list(options.keys()))
n_epochs = options[train_option]
def char_tensor(char):
    """Encode ``char`` as a ``(1, vocab_size)`` float one-hot row.

    Raises:
        ValueError: if ``char`` is not a key of ``char_to_idx``.
    """
    position = char_to_idx.get(char)
    if position is None:
        raise ValueError(f"Character '{char}' not in vocabulary.")
    one_hot = torch.zeros(1, vocab_size)
    one_hot[0, position] = 1
    return one_hot
# Training function
def train_model(n_epochs):
    """Train the global ``rnn`` on ``sequence`` with full backprop through time.

    Each epoch feeds the whole training string through the network one
    character at a time and sums the next-character NLL loss over every step.
    It then takes a single optimizer step on that summed loss. The average
    per-character loss is written to the page every 10 epochs and after the
    final epoch.

    Args:
        n_epochs: Number of passes over the training string.
    """
    n_steps = data_size - 1
    if n_steps < 1:
        # At least one (input, target) pair is needed. Otherwise ``loss`` would
        # stay the int 0 (which has no ``.backward()``) and the average loss
        # would divide by zero.
        st.warning('The training string needs at least two characters to train on.')
        return
    # Inputs and targets are identical every epoch, so build them only once.
    inputs = [char_tensor(ch) for ch in sequence[:-1]]
    targets = torch.tensor([char_to_idx[ch] for ch in sequence[1:]], dtype=torch.long)
    for epoch in range(n_epochs):
        hidden = rnn.init_hidden()
        rnn.zero_grad()
        loss = 0
        for i, input_char in enumerate(inputs):
            output, hidden = rnn(input_char, hidden)
            loss += criterion(output, targets[i:i + 1])
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0 or epoch == n_epochs - 1:
            st.write(f'Epoch {epoch} loss: {loss.item() / n_steps}')
    st.write("Training complete.")
# Train the model when the "Train Model" button is clicked
train_requested = st.button('Train Model')
if train_requested:
    train_model(n_epochs)
def generate(start_char, predict_len=100):
    """Greedily generate text from the current ``rnn``.

    ``start_char`` seeds the generation. A single character behaves exactly as
    before. A longer string is accepted as a prime: its characters are fed
    through the network to warm up the hidden state, and generation continues
    from its last character.

    Args:
        start_char: Non-empty seed string; every character must be in the vocabulary.
        predict_len: Number of characters to generate after the seed.

    Returns:
        The seed followed by ``predict_len`` predicted characters.

    Raises:
        ValueError: if ``start_char`` is empty or contains a character that is
            not in the vocabulary.
    """
    if not start_char:
        raise ValueError(f"Start character '{start_char}' not in vocabulary.")
    for ch in start_char:
        if ch not in char_to_idx:
            raise ValueError(f"Start character '{ch}' not in vocabulary.")
    pieces = [start_char]
    # Inference only: skip recording an autograd graph for every step.
    with torch.no_grad():
        hidden = rnn.init_hidden()
        for ch in start_char[:-1]:
            _, hidden = rnn(char_tensor(ch), hidden)
        input_char = char_tensor(start_char[-1])
        for _ in range(predict_len):
            output, hidden = rnn(input_char, hidden)
            # Greedy decoding: always take the most likely next character.
            predicted_char = idx_to_char[output.argmax(dim=1).item()]
            pieces.append(predicted_char)
            input_char = char_tensor(predicted_char)
    return ''.join(pieces)
# Generation controls: seed, output length, and a button that runs generate()
start_char = st.text_input('Enter a starting character:', 'h')
predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
if st.button('Generate Text'):
    try:
        generated_text = generate(start_char, predict_len)
    except ValueError as err:
        # An unknown seed character is reported in the UI instead of crashing the app.
        st.error(str(err))
    else:
        st.write('Generated Text:')
        st.text(generated_text)