from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import TrainingArguments, Trainer
import os
# Load dataset
ds = load_dataset("knkarthick/dialogsum")

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
# Preprocessing function: tokenize dialogues (inputs) and summaries (targets)
def preprocess_function(batch):
    source = batch['dialogue']
    target = batch['summary']
    source_enc = tokenizer(source, padding='max_length', truncation=True, max_length=128)
    target_enc = tokenizer(target, padding='max_length', truncation=True, max_length=128)
    # Replace pad token ids in the labels with -100 so the cross-entropy loss
    # ignores padding positions
    labels = target_enc['input_ids']
    labels = [[(token if token != tokenizer.pad_token_id else -100) for token in label] for label in labels]
    return {
        'input_ids': source_enc['input_ids'],
        'attention_mask': source_enc['attention_mask'],
        'labels': labels
    }
# Apply preprocessing to every split
tokenized_ds = ds.map(preprocess_function, batched=True)
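
# Optional sanity check (not in the original script): confirm the tokenized
# columns were added to every split before training.
print(tokenized_ds)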
# Training arguments
training_args = TrainingArguments(
    output_dir='/content/TextSummarizer_output',
    per_device_train_batch_size=8,
    num_train_epochs=2,
    save_total_limit=1,
    save_strategy="epoch",
    remove_unused_columns=True,
    logging_dir='/content/logs',
    logging_steps=50,
)
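
# Note: the plain Trainer/TrainingArguments pair trains and reports only the
# cross-entropy loss. For generation metrics such as ROUGE, the usual swap is
# Seq2SeqTrainer with Seq2SeqTrainingArguments(predict_with_generate=True) and
# a compute_metrics function; that variant is an assumption, not shown here.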
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['test'],
)
# Train
trainer.train()

# Evaluate
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)
# ===> Save to Google Drive path
save_path = "/content/drive/MyDrive/TextSummarizer2/model_directory"
os.makedirs(save_path, exist_ok=True)

# Save model and tokenizer (use safe_serialization for large model.safetensors)
model.save_pretrained(save_path, safe_serialization=True)
tokenizer.save_pretrained(save_path)
print(f"β Model and tokenizer saved to: {save_path}") | |
print("π¦ Files saved:", os.listdir(save_path)) | |