# pages/17_RNN.py
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
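# Streamlit demo: a character-level RNN is trained on a short toy sequence
# and then used to generate text one character at a time.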
# Define the dataset
sequence = "hellohellohello"
chars = list(set(sequence))
data_size, vocab_size = len(sequence), len(chars)
# Create mappings from characters to indices and vice versa
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
# Convert the sequence to indices
indices = np.array([char_to_idx[ch] for ch in sequence])
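# Each character will be presented to the network as a one-hot vector of
# length vocab_size (built by char_tensor below).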
class RNN(nn.Module):
    # A simple recurrent cell: two linear layers over the concatenation of the
    # current input and the previous hidden state, with a log-softmax output.
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        # One step of the recurrence:
        #   h_t = i2h([x_t ; h_{t-1}]),  y_t = log_softmax(i2o([x_t ; h_{t-1}]))
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self):
        # Start each sequence with an all-zero hidden state.
        return torch.zeros(1, self.hidden_size)
# Hyperparameters
n_hidden = 128
learning_rate = 0.005
n_epochs = 500
# Initialize the model, loss function, and optimizer
rnn = RNN(vocab_size, n_hidden, vocab_size)
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
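# LogSoftmax in the model paired with NLLLoss here is equivalent to training
# with cross-entropy on the next-character targets.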
def char_tensor(char):
    # One-hot encode a single character as a (1, vocab_size) tensor.
    tensor = torch.zeros(1, vocab_size)
    tensor[0][char_to_idx[char]] = 1
    return tensor
# Training loop: feed the sequence one character at a time and train the
# network to predict the character that follows.
for epoch in range(n_epochs):
    hidden = rnn.init_hidden()
    rnn.zero_grad()
    loss = 0
    for i in range(data_size - 1):
        input_char = char_tensor(sequence[i])
        target_char = torch.tensor([char_to_idx[sequence[i + 1]]], dtype=torch.long)
        output, hidden = rnn(input_char, hidden)
        loss += criterion(output, target_char)
    # Backpropagate through the whole sequence once per epoch.
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f'Epoch {epoch} loss: {loss.item() / (data_size - 1):.4f}')
print("Training complete.")
def generate(start_char, predict_len=100):
    # Greedily generate predict_len characters, starting from start_char and
    # feeding each prediction back in as the next input.
    hidden = rnn.init_hidden()
    input_char = char_tensor(start_char)
    predicted_str = start_char
    with torch.no_grad():  # inference only, no gradients needed
        for _ in range(predict_len):
            output, hidden = rnn(input_char, hidden)
            topv, topi = output.topk(1)
            predicted_char_idx = topi[0][0].item()
            predicted_char = idx_to_char[predicted_char_idx]
            predicted_str += predicted_char
            input_char = char_tensor(predicted_char)
    return predicted_str
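# Example: after training, generate('h', 20) should continue the "hello"
# pattern, e.g. something close to "hellohellohellohello".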
# Streamlit interface
st.title('RNN Character Prediction')
st.write('This app uses a Recurrent Neural Network (RNN) to predict the next character in a given string.')
start_char = st.text_input('Enter a starting character:', 'h')
predict_len = st.slider('Select the length of the generated text:', min_value=10, max_value=200, value=50)
if st.button('Generate Text'):
    # Characters outside the training vocabulary would raise a KeyError in char_tensor.
    if len(start_char) != 1 or start_char not in char_to_idx:
        st.error(f'Please enter a single character from: {sorted(chars)}')
    else:
        generated_text = generate(start_char, predict_len)
        st.write('Generated Text:')
        st.text(generated_text)