eaglelandsonce commited on
Commit
5a1cec5
·
verified ·
1 Parent(s): 44c0170

Update pages/21_NLP_Transformer.py

Browse files
Files changed (1) hide show
  1. pages/21_NLP_Transformer.py +5 -2
pages/21_NLP_Transformer.py CHANGED
@@ -34,6 +34,9 @@ def load_and_preprocess_data():
34
 
35
  train_dataloader = load_and_preprocess_data()
36
 
 
 
 
37
  # Training loop
38
  if st.sidebar.button("Train"):
39
  optimizer = AdamW(model.parameters(), lr=learning_rate)
@@ -49,7 +52,7 @@ if st.sidebar.button("Train"):
49
  for epoch in range(num_epochs):
50
  for batch in train_dataloader:
51
  batch = {k: v.to(device) for k, v in batch.items()}
52
- outputs = model(**batch) # No need to pass labels explicitly if they are in the batch dictionary
53
  loss = outputs.loss
54
  loss.backward()
55
 
@@ -59,7 +62,7 @@ if st.sidebar.button("Train"):
59
  progress_bar.update(1)
60
  loss_values.append(loss.item())
61
 
62
- st.sidebar.success("Training completed")
63
 
64
  # Plot loss values
65
  st.write("### Training Loss")
 
34
 
35
  train_dataloader = load_and_preprocess_data()
36
 
37
+ # Initialize training status
38
+ training_completed = st.sidebar.empty()
39
+
40
  # Training loop
41
  if st.sidebar.button("Train"):
42
  optimizer = AdamW(model.parameters(), lr=learning_rate)
 
52
  for epoch in range(num_epochs):
53
  for batch in train_dataloader:
54
  batch = {k: v.to(device) for k, v in batch.items()}
55
+ outputs = model(**batch)
56
  loss = outputs.loss
57
  loss.backward()
58
 
 
62
  progress_bar.update(1)
63
  loss_values.append(loss.item())
64
 
65
+ training_completed.success("Training completed")
66
 
67
  # Plot loss values
68
  st.write("### Training Loss")