eaglelandsonce commited on
Commit
1535a31
·
verified ·
1 Parent(s): ae9ee23

Update pages/17_RNN_News.py

Browse files
Files changed (1) hide show
  1. pages/17_RNN_News.py +5 -4
pages/17_RNN_News.py CHANGED
@@ -52,6 +52,7 @@ def load_data():
52
  vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
53
  vocab.set_default_index(vocab["<unk>"])
54
 
 
55
  global text_pipeline, label_pipeline
56
  text_pipeline = lambda x: vocab(tokenizer(x))
57
  label_pipeline = lambda x: int(x) - 1
@@ -80,7 +81,7 @@ def train_network(net, iterator, optimizer, criterion, epochs):
80
  for texts, labels in iterator:
81
  texts, labels = texts.to(device), labels.to(device)
82
  optimizer.zero_grad()
83
- predictions = net(texts).squeeze(1)
84
  loss = criterion(predictions, labels)
85
  loss.backward()
86
  optimizer.step()
@@ -102,7 +103,7 @@ def evaluate_network(net, iterator, criterion):
102
  with torch.no_grad():
103
  for texts, labels in iterator:
104
  texts, labels = texts.to(device), labels.to(device)
105
- predictions = net(texts).squeeze(1)
106
  loss = criterion(predictions, labels)
107
  epoch_loss += loss.item()
108
  _, predicted = torch.max(predictions, 1)
@@ -168,8 +169,8 @@ if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
168
  # Display results in a table
169
  st.write('Ground Truth vs Predicted')
170
  results = pd.DataFrame({
171
- 'Ground Truth': [LABEL.vocab.itos[label] for label in all_labels],
172
- 'Predicted': [LABEL.vocab.itos[label] for label in all_predictions]
173
  })
174
  st.table(results.head(50)) # Display first 50 results for brevity
175
 
 
52
  vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
53
  vocab.set_default_index(vocab["<unk>"])
54
 
55
+ # Define the text and label processing pipelines globally
56
  global text_pipeline, label_pipeline
57
  text_pipeline = lambda x: vocab(tokenizer(x))
58
  label_pipeline = lambda x: int(x) - 1
 
81
  for texts, labels in iterator:
82
  texts, labels = texts.to(device), labels.to(device)
83
  optimizer.zero_grad()
84
+ predictions = net(texts)
85
  loss = criterion(predictions, labels)
86
  loss.backward()
87
  optimizer.step()
 
103
  with torch.no_grad():
104
  for texts, labels in iterator:
105
  texts, labels = texts.to(device), labels.to(device)
106
+ predictions = net(texts)
107
  loss = criterion(predictions, labels)
108
  epoch_loss += loss.item()
109
  _, predicted = torch.max(predictions, 1)
 
169
  # Display results in a table
170
  st.write('Ground Truth vs Predicted')
171
  results = pd.DataFrame({
172
+ 'Ground Truth': all_labels,
173
+ 'Predicted': all_predictions
174
  })
175
  st.table(results.head(50)) # Display first 50 results for brevity
176