eaglelandsonce committed on
Commit
829c774
·
verified ·
1 Parent(s): 59482c2

Update pages/17_RNN_News.py

Browse files
Files changed (1) hide show
  1. pages/17_RNN_News.py +13 -17
pages/17_RNN_News.py CHANGED
@@ -51,25 +51,22 @@ def load_data():
51
  vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
52
  vocab.set_default_index(vocab["<unk>"])
53
 
54
- # Define the text and label processing pipelines globally
55
- global text_pipeline, label_pipeline
56
- text_pipeline = lambda x: vocab(tokenizer(x))
57
- label_pipeline = lambda x: int(x) - 1
58
 
59
- # Create DataLoaders
60
- train_dataset = list(train_iter)
61
- test_dataset = list(test_iter)
 
62
 
63
- train_size = int(0.8 * len(train_dataset))
64
- valid_size = len(train_dataset) - train_size
65
- train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
 
66
 
67
- BATCH_SIZE = 64
68
- train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
69
- valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
70
- test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
71
-
72
- return vocab, train_loader, valid_loader, test_loader
73
 
74
  # Function to train the network
75
  def train_network(net, iterator, optimizer, criterion, epochs):
@@ -116,7 +113,6 @@ def evaluate_network(net, iterator, criterion):
116
 
117
  # Load data
118
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
119
- vocab, train_loader, valid_loader, test_loader = load_data()
120
 
121
  # Streamlit interface
122
  st.title("RNN for Text Classification on AG News Dataset")
 
51
  vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
52
  vocab.set_default_index(vocab["<unk>"])
53
 
54
+ return vocab, tokenizer, list(train_iter), list(test_iter)
 
 
 
55
 
56
+ # Initialize global pipelines
57
+ vocab, tokenizer, train_dataset, test_dataset = load_data()
58
+ text_pipeline = lambda x: vocab(tokenizer(x))
59
+ label_pipeline = lambda x: int(x) - 1
60
 
61
+ # Create DataLoaders
62
+ train_size = int(0.8 * len(train_dataset))
63
+ valid_size = len(train_dataset) - train_size
64
+ train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
65
 
66
+ BATCH_SIZE = 64
67
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
68
+ valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
69
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
 
 
70
 
71
  # Function to train the network
72
  def train_network(net, iterator, optimizer, criterion, epochs):
 
113
 
114
  # Load data
115
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
116
 
117
  # Streamlit interface
118
  st.title("RNN for Text Classification on AG News Dataset")