# SAG-ViT / push_model_to_hfhub.py
# shravvvv's picture
# Updated stuff
# 986f758
# raw
# history blame contribute delete
# 669 Bytes
import torch
from transformers import AutoConfig, AutoModel
from modeling_sagvit import SAGViTClassifier, SAGViTConfig
# Script body: register the SAG-ViT classes with the transformers Auto*
# factories, rebuild the model from a default config plus the local
# 'SAG-ViT.pth' checkpoint, then save it locally and push it to the
# "shravvvv/SAG-ViT" repository on the Hugging Face Hub.

print("Registering model...")
# Make the "sagvit" model type resolvable through AutoConfig / AutoModel in
# this process.
# NOTE(review): these calls only affect the local Auto* mappings; if the
# modeling code is also meant to be uploaded with the weights, confirm that
# the classes in modeling_sagvit take care of that.
AutoConfig.register("sagvit", SAGViTConfig)
AutoModel.register(SAGViTConfig, SAGViTClassifier)
print("Registered model...")

# Load config and model
config = SAGViTConfig()
model = SAGViTClassifier(config)

# Load the state dict into the model
print("Loading model weights...")
# map_location="cpu" lets a checkpoint that was saved from a GPU device load
# on a machine without CUDA; this script never moves the model to another
# device, so the weights are wanted on the CPU anyway.
# NOTE: torch.load unpickles the file -- only run this on a trusted checkpoint.
state_dict = torch.load('SAG-ViT.pth', map_location="cpu")
model.load_state_dict(state_dict)
print("Loaded model weights...")

# Push model and code
model.save_pretrained('.')  # writes the serialized model into the current directory
model.push_to_hub("shravvvv/SAG-ViT")
print("Pushed model to Hugging Face hub...")