eaglelandsonce commited on
Commit
713f0f4
·
verified ·
1 Parent(s): e84d316

Update pages/17_RNN.py

Browse files
Files changed (1) hide show
  1. pages/17_RNN.py +13 -10
pages/17_RNN.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
5
  from torchtext.data.utils import get_tokenizer
6
- from torchtext.vocab import build_vocab_from_iterator, GloVe
7
  from torchtext.datasets import IMDB
8
  from torch.utils.data import DataLoader, random_split
9
  import matplotlib.pyplot as plt
@@ -29,6 +29,13 @@ class RNN(nn.Module):
29
  out = self.fc(out[:, -1, :])
30
  return out
31
 
 
 
 
 
 
 
 
32
  # Function to load the data
33
  @st.cache_data
34
  def load_data():
@@ -39,7 +46,7 @@ def load_data():
39
  for _, text in data_iter:
40
  yield tokenizer(text)
41
 
42
- vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
43
  vocab.set_default_index(vocab["<unk>"])
44
 
45
  # Define the text and label processing pipelines
@@ -57,13 +64,6 @@ def load_data():
57
  train_texts, train_labels = process_data(train_iter)
58
  test_texts, test_labels = process_data(test_iter)
59
 
60
- # Create a custom collate function to pad sequences
61
- def collate_batch(batch):
62
- texts, labels = zip(*batch)
63
- text_lengths = [len(text) for text in texts]
64
- texts_padded = pad_sequence(texts, batch_first=True, padding_value=vocab["<pad>"])
65
- return texts_padded, torch.tensor(labels, dtype=torch.float), text_lengths
66
-
67
  # Create DataLoaders
68
  train_dataset = list(zip(train_texts, train_labels))
69
  test_dataset = list(zip(test_texts, test_labels))
@@ -124,6 +124,9 @@ def evaluate_network(net, iterator, criterion):
124
 
125
  # Load the data
126
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
127
  vocab, train_loader, valid_loader, test_loader = load_data()
128
 
129
  # Streamlit interface
@@ -198,5 +201,5 @@ if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'
198
  st.write('Ground Truth vs Predicted for Sample Texts')
199
  for i, (text, true_label, predicted) in enumerate(samples):
200
  st.write(f'Sample {i+1}')
201
- st.text(' '.join([vocab.itos[token] for token in text]))
202
  st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')
 
3
  import torch.nn as nn
4
  import torch.optim as optim
5
  from torchtext.data.utils import get_tokenizer
6
+ from torchtext.vocab import build_vocab_from_iterator
7
  from torchtext.datasets import IMDB
8
  from torch.utils.data import DataLoader, random_split
9
  import matplotlib.pyplot as plt
 
29
  out = self.fc(out[:, -1, :])
30
  return out
31
 
32
+ # Create a custom collate function to pad sequences
33
+ def collate_batch(batch):
34
+ texts, labels = zip(*batch)
35
+ text_lengths = [len(text) for text in texts]
36
+ texts_padded = pad_sequence(texts, batch_first=True, padding_value=vocab["<pad>"])
37
+ return texts_padded, torch.tensor(labels, dtype=torch.float), text_lengths
38
+
39
  # Function to load the data
40
  @st.cache_data
41
  def load_data():
 
46
  for _, text in data_iter:
47
  yield tokenizer(text)
48
 
49
+ vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
50
  vocab.set_default_index(vocab["<unk>"])
51
 
52
  # Define the text and label processing pipelines
 
64
  train_texts, train_labels = process_data(train_iter)
65
  test_texts, test_labels = process_data(test_iter)
66
 
 
 
 
 
 
 
 
67
  # Create DataLoaders
68
  train_dataset = list(zip(train_texts, train_labels))
69
  test_dataset = list(zip(test_texts, test_labels))
 
124
 
125
  # Load the data
126
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
127
+
128
+ # Display a loading message with some vertical space
129
+ st.markdown("<div style='margin-top: 50px;'><b>Loading data...</b></div>", unsafe_allow_html=True)
130
  vocab, train_loader, valid_loader, test_loader = load_data()
131
 
132
  # Streamlit interface
 
201
  st.write('Ground Truth vs Predicted for Sample Texts')
202
  for i, (text, true_label, predicted) in enumerate(samples):
203
  st.write(f'Sample {i+1}')
204
+ st.text(' '.join([vocab.get_itos()[token] for token in text]))
205
  st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')