eaglelandsonce commited on
Commit
ae9ee23
·
verified ·
1 Parent(s): 284b3b6

Update pages/17_RNN_News.py

Browse files
Files changed (1) hide show
  1. pages/17_RNN_News.py +3 -3
pages/17_RNN_News.py CHANGED
@@ -6,6 +6,7 @@ from torchtext.data.utils import get_tokenizer
6
  from torchtext.vocab import build_vocab_from_iterator
7
  from torchtext.datasets import AG_NEWS
8
  from torch.utils.data import DataLoader, random_split
 
9
  import matplotlib.pyplot as plt
10
  import pandas as pd
11
  import numpy as np
@@ -28,12 +29,11 @@ class RNN(nn.Module):
28
 
29
  # Create a custom collate function to pad sequences
30
  def collate_batch(batch):
31
- label_list, text_list, lengths = [], [], []
32
  for _label, _text in batch:
33
  label_list.append(label_pipeline(_label))
34
  processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
35
  text_list.append(processed_text)
36
- lengths.append(processed_text.size(0))
37
  labels = torch.tensor(label_list, dtype=torch.int64)
38
  texts = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
39
  return texts, labels
@@ -191,4 +191,4 @@ if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'
191
  for i, (text, true_label, predicted) in enumerate(samples):
192
  st.write(f'Sample {i+1}')
193
  st.text(' '.join([vocab.get_itos()[token] for token in text]))
194
- st.write(f'Ground Truth: {LABEL.vocab.itos[true_label.item()]}, Predicted: {LABEL.vocab.itos[predicted.item()]}')
 
6
  from torchtext.vocab import build_vocab_from_iterator
7
  from torchtext.datasets import AG_NEWS
8
  from torch.utils.data import DataLoader, random_split
9
+ from torch.nn.utils.rnn import pad_sequence
10
  import matplotlib.pyplot as plt
11
  import pandas as pd
12
  import numpy as np
 
29
 
30
  # Create a custom collate function to pad sequences
31
  def collate_batch(batch):
32
+ label_list, text_list = [], []
33
  for _label, _text in batch:
34
  label_list.append(label_pipeline(_label))
35
  processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
36
  text_list.append(processed_text)
 
37
  labels = torch.tensor(label_list, dtype=torch.int64)
38
  texts = pad_sequence(text_list, batch_first=True, padding_value=vocab["<pad>"])
39
  return texts, labels
 
191
  for i, (text, true_label, predicted) in enumerate(samples):
192
  st.write(f'Sample {i+1}')
193
  st.text(' '.join([vocab.get_itos()[token] for token in text]))
194
+ st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')