Spaces:
Running
Running
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
# Introduction
st.title('RNN Character Prediction')
st.write("""
This app demonstrates how to train a Recurrent Neural Network (RNN) to predict the next character in a given string.
The RNN learns the sequence of characters from a provided text and generates text based on the learned patterns.
You can choose different options for training the RNN and see how it affects the generated text.
""")

# User input for the training string
sequence = st.text_area('Enter the training string:', 'In the vast expanse of the digital realm.')
if len(sequence) < 2:
    # Training needs at least one (current char, next char) pair; with fewer
    # characters the training loop has nothing to learn from. Stop rendering
    # the rest of the app until the user provides enough text.
    st.warning('Please enter a training string of at least two characters.')
    st.stop()

# Vocabulary. Sorted so the character <-> index mapping is deterministic:
# iteration order of a set of str depends on per-process hash randomization.
chars = sorted(set(sequence))
data_size, vocab_size = len(sequence), len(chars)

# Create mappings from characters to indices and vice versa
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

# Convert the sequence to indices
indices = np.array([char_to_idx[ch] for ch in sequence])
class RNN(nn.Module):
    """Single-layer Elman-style RNN cell with a log-softmax output head.

    At each step the input vector and the previous hidden state are
    concatenated; one linear layer produces the next hidden state (squashed
    with tanh) and another produces log-probabilities over the output classes
    (suitable for ``nn.NLLLoss``).

    Args:
        input_size: size of each input vector (the vocabulary size in this app).
        hidden_size: size of the hidden state.
        output_size: number of output classes (the vocabulary size in this app).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance the RNN by one step.

        Args:
            input: (batch, input_size) tensor for the current step.
            hidden: (batch, hidden_size) hidden state from the previous step.

        Returns:
            ``(output, hidden)``: log-probabilities of shape
            (batch, output_size) and the new hidden state of shape
            (batch, hidden_size).
        """
        combined = torch.cat((input, hidden), 1)
        # Without a nonlinearity the recurrence is purely linear: the hidden
        # state is just a linear function of all past inputs and can grow
        # without bound over long sequences. tanh (the default nonlinearity of
        # torch.nn.RNN) keeps it in [-1, 1].
        hidden = torch.tanh(self.i2h(combined))
        output = self.softmax(self.i2o(combined))
        return output, hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero initial hidden state of shape (batch_size, hidden_size)."""
        return torch.zeros(batch_size, self.hidden_size)
# Hyperparameters
n_hidden = 128
learning_rate = 0.005

# Initialize the model, loss function, and optimizer.
# Streamlit re-runs this whole script on every widget interaction, so creating
# the model with a plain assignment would replace the trained network with a
# fresh, untrained one on the very next click (e.g. "Generate Text"). Keep the
# model and its optimizer in st.session_state and rebuild them only when the
# vocabulary (character -> index mapping) or the hyperparameters change.
model_key = (tuple(chars), n_hidden, learning_rate)
if st.session_state.get('model_key') != model_key:
    model = RNN(vocab_size, n_hidden, vocab_size)
    st.session_state['model_key'] = model_key
    st.session_state['rnn'] = model
    st.session_state['optimizer'] = torch.optim.Adam(model.parameters(), lr=learning_rate)
rnn = st.session_state['rnn']
optimizer = st.session_state['optimizer']
criterion = nn.NLLLoss()

# Define training options
options = {
    'Quick Train (100 epochs)': 100,
    'Medium Train (500 epochs)': 500,
    'Long Train (1000 epochs)': 1000,
}
train_option = st.selectbox('Select training option:', list(options.keys()))
n_epochs = options[train_option]
def char_tensor(char):
    """One-hot encode ``char`` as a float tensor of shape (1, vocab_size).

    Raises:
        ValueError: if ``char`` is not part of the training vocabulary.
    """
    position = char_to_idx.get(char)
    if position is None:
        raise ValueError(f"Character '{char}' not in vocabulary.")
    one_hot = torch.zeros(1, vocab_size)
    one_hot[0, position] = 1
    return one_hot
# Training function
def train_model(n_epochs):
    """Train the module-level ``rnn`` on ``sequence`` for ``n_epochs`` epochs.

    Each epoch feeds the whole sequence through the RNN one character at a
    time, predicting the following character. It sums the ``criterion`` loss
    over all steps and takes a single optimizer step. Progress (the per-step
    average loss) is written to the page every 10 epochs and on the final
    epoch.

    Args:
        n_epochs: number of passes over the training sequence.
    """
    n_steps = data_size - 1
    if n_steps < 1:
        # No (input, target) pair exists; the summed loss would stay the
        # int 0 and ``loss.backward()`` would raise AttributeError.
        st.error('The training string must contain at least two characters.')
        return

    # Inputs and targets are identical every epoch, so build them once instead
    # of re-allocating the same tensors n_epochs times.
    inputs = [char_tensor(ch) for ch in sequence[:-1]]
    targets = [torch.tensor([char_to_idx[ch]], dtype=torch.long) for ch in sequence[1:]]

    for epoch in range(n_epochs):
        hidden = rnn.init_hidden()
        rnn.zero_grad()
        loss = 0
        for input_char, target_char in zip(inputs, targets):
            output, hidden = rnn(input_char, hidden)
            loss = loss + criterion(output, target_char)
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0 or epoch == n_epochs - 1:
            st.write(f'Epoch {epoch} loss: {loss.item() / n_steps}')
    st.write("Training complete.")
# Run training only on the rerun triggered by clicking the button
# (st.button reports True on the run that follows the click).
train_requested = st.button('Train Model')
if train_requested:
    train_model(n_epochs)
def generate(start_char, predict_len=100):
    """Generate text from the module-level ``rnn`` by greedy (top-1) decoding.

    Args:
        start_char: priming text. Usually a single character. A longer string
            is also accepted: every character but the last is fed through the
            RNN first to warm up the hidden state.
        predict_len: number of characters to generate after the prime.

    Returns:
        ``start_char`` followed by ``predict_len`` predicted characters.

    Raises:
        ValueError: if ``start_char`` is empty or contains a character that is
            not in the training vocabulary.
    """
    if not start_char:
        raise ValueError(f"Start character '{start_char}' not in vocabulary.")
    for ch in start_char:
        if ch not in char_to_idx:
            raise ValueError(f"Start character '{ch}' not in vocabulary.")

    pieces = [start_char]
    # Inference only: no_grad avoids building an autograd graph over every step.
    with torch.no_grad():
        hidden = rnn.init_hidden()
        # Warm up the hidden state on all but the last priming character.
        for ch in start_char[:-1]:
            _, hidden = rnn(char_tensor(ch), hidden)
        input_char = char_tensor(start_char[-1])
        for _ in range(predict_len):
            output, hidden = rnn(input_char, hidden)
            _, topi = output.topk(1)
            predicted_char = idx_to_char[topi[0][0].item()]
            pieces.append(predicted_char)
            input_char = char_tensor(predicted_char)
    return ''.join(pieces)
# Text generation UI. The default starting character is taken from the training
# text itself, so it is always in the vocabulary; a hard-coded default such as
# 'h' fails for any training string that does not contain it.
start_char = st.text_input('Enter a starting character:', sequence[:1])
predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
if st.button('Generate Text'):
    try:
        generated_text = generate(start_char, predict_len)
    except ValueError as e:
        # Unknown/empty starting character: show the message instead of crashing.
        st.error(str(e))
    else:
        st.write('Generated Text:')
        st.text(generated_text)