Spaces:
Runtime error
Runtime error
from transformers import Trainer | |
import torch.nn.functional as F | |
from typing import Optional | |
import os | |
import torch | |
from transformers.utils import WEIGHTS_NAME | |
import json | |
class VideoBaseTrainer(Trainer): | |
def _save(self, output_dir: Optional[str] = None, state_dict=None): | |
output_dir = output_dir if output_dir is not None else self.args.output_dir | |
os.makedirs(output_dir, exist_ok=True) | |
if state_dict is None: | |
state_dict = self.model.state_dict() | |
# get model config | |
model_config = self.model.config.to_dict() | |
# add more information | |
model_config['model'] = self.model.__class__.__name__ | |
with open(os.path.join(output_dir, "config.json"), "w") as file: | |
json.dump(self.model.config.to_dict(), file) | |
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) | |
torch.save(self.args, os.path.join(output_dir, "training_args.bin")) | |