eaglelandsonce commited on
Commit
44c0170
·
verified ·
1 Parent(s): 8ea4159

Update pages/21_NLP_Transformer.py

Browse files
Files changed (1) hide show
  1. pages/21_NLP_Transformer.py +2 -1
pages/21_NLP_Transformer.py CHANGED
@@ -29,6 +29,7 @@ def load_and_preprocess_data():
29
  return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
30
  encoded_dataset = dataset.map(preprocess_function, batched=True)
31
  encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
 
32
  return DataLoader(encoded_dataset, shuffle=True, batch_size=batch_size)
33
 
34
  train_dataloader = load_and_preprocess_data()
@@ -48,7 +49,7 @@ if st.sidebar.button("Train"):
48
  for epoch in range(num_epochs):
49
  for batch in train_dataloader:
50
  batch = {k: v.to(device) for k, v in batch.items()}
51
- outputs = model(**batch, labels=batch["labels"]) # Fixed input key
52
  loss = outputs.loss
53
  loss.backward()
54
 
 
29
  return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
30
  encoded_dataset = dataset.map(preprocess_function, batched=True)
31
  encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
32
+ encoded_dataset = encoded_dataset.rename_column("label", "labels") # Rename the column to 'labels'
33
  return DataLoader(encoded_dataset, shuffle=True, batch_size=batch_size)
34
 
35
  train_dataloader = load_and_preprocess_data()
 
49
  for epoch in range(num_epochs):
50
  for batch in train_dataloader:
51
  batch = {k: v.to(device) for k, v in batch.items()}
52
+ outputs = model(**batch) # No need to pass labels explicitly if they are in the batch dictionary
53
  loss = outputs.loss
54
  loss.backward()
55