VoyagerYuan commited on
Commit
e13d6a0
·
1 Parent(s): 5e056d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -234,8 +234,12 @@ def train_signal_game(NUM_SENDERS, NUM_RECEIVERS, num_rounds):
234
  loss_plot_placeholder = st.empty() # 创建一个空位占位符来显示损失图
235
  interactions_placeholder = st.empty() # 创建一个空位占位符来显示交互
236
 
237
- for round in range(num_rounds):
238
- states = [torch.randint(VOCAB_SIZE, (BATCH_SIZE, 16)).to(device) for _ in range(NUM_SENDERS)]
 
 
 
 
239
  loss, recon_loss, kld_loss, interactions = game.play_round(states)
240
  losses.append(loss)
241
  recon_losses.append(recon_loss)
 
234
  loss_plot_placeholder = st.empty() # 创建一个空位占位符来显示损失图
235
  interactions_placeholder = st.empty() # 创建一个空位占位符来显示交互
236
 
237
+ for round, batch in enumerate(train_dataloader):
238
+ if round >= num_rounds:
239
+ break
240
+ states = [batch.to(device) for _ in range(NUM_SENDERS)]
241
+ # for round in range(num_rounds):
242
+ # states = [torch.randint(VOCAB_SIZE, (BATCH_SIZE, 16)).to(device) for _ in range(NUM_SENDERS)]
243
  loss, recon_loss, kld_loss, interactions = game.play_round(states)
244
  losses.append(loss)
245
  recon_losses.append(recon_loss)