Spaces:
Running
Running
Update pages/21_NLP_Transformer.py
Browse files
pages/21_NLP_Transformer.py
CHANGED
@@ -29,6 +29,7 @@ def load_and_preprocess_data():
|
|
29 |
return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
|
30 |
encoded_dataset = dataset.map(preprocess_function, batched=True)
|
31 |
encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
|
|
|
32 |
return DataLoader(encoded_dataset, shuffle=True, batch_size=batch_size)
|
33 |
|
34 |
train_dataloader = load_and_preprocess_data()
|
@@ -48,7 +49,7 @@ if st.sidebar.button("Train"):
|
|
48 |
for epoch in range(num_epochs):
|
49 |
for batch in train_dataloader:
|
50 |
batch = {k: v.to(device) for k, v in batch.items()}
|
51 |
-
outputs = model(**batch
|
52 |
loss = outputs.loss
|
53 |
loss.backward()
|
54 |
|
|
|
29 |
return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
|
30 |
encoded_dataset = dataset.map(preprocess_function, batched=True)
|
31 |
encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
|
32 |
+
encoded_dataset = encoded_dataset.rename_column("label", "labels") # Rename the column to 'labels'
|
33 |
return DataLoader(encoded_dataset, shuffle=True, batch_size=batch_size)
|
34 |
|
35 |
train_dataloader = load_and_preprocess_data()
|
|
|
49 |
for epoch in range(num_epochs):
|
50 |
for batch in train_dataloader:
|
51 |
batch = {k: v.to(device) for k, v in batch.items()}
|
52 |
+
outputs = model(**batch) # No need to pass labels explicitly if they are in the batch dictionary
|
53 |
loss = outputs.loss
|
54 |
loss.backward()
|
55 |
|