File size: 669 Bytes
986f758 b99e299 a041fbe b99e299 32db49c 986f758 32db49c 986f758 32db49c a041fbe 32db49c a041fbe b99e299 986f758 32db49c 986f758 b99e299 986f758 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
from transformers import AutoConfig, AutoModel
from modeling_sagvit import SAGViTClassifier, SAGViTConfig
# Register the custom SAG-ViT classes with the transformers Auto* factories so
# that AutoConfig / AutoModel can resolve the "sagvit" model type in this
# process.
# NOTE(review): this registration only affects the current Python session. For
# the Hub repo to be loadable via AutoModel.from_pretrained(...,
# trust_remote_code=True), the classes presumably also need
# register_for_auto_class() before saving so the modeling code and an
# `auto_map` entry are uploaded — confirm before relying on remote loading.
print("Registering model...")
AutoConfig.register("sagvit", SAGViTConfig)
AutoModel.register(SAGViTConfig, SAGViTClassifier)
print("Registered model...")

# Local checkpoint to convert, and the Hub repository it is published to.
CHECKPOINT_PATH = "SAG-ViT.pth"
HUB_REPO_ID = "shravvvv/SAG-ViT"

# Build the model from the default config; trained weights are loaded below.
config = SAGViTConfig()
model = SAGViTClassifier(config)

print("Loading model weights...")
# map_location="cpu": load every tensor onto the CPU regardless of the device it
# was saved from, so the script also works on machines without a GPU;
# load_state_dict copies the values into the model's own parameters anyway.
# weights_only=True: restrict unpickling to tensors and plain containers (the
# PyTorch default since 2.6) instead of executing arbitrary pickled objects.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
# Strict (default) load: raises if checkpoint keys do not match the model.
model.load_state_dict(state_dict)
print("Loaded model weights...")

# Write config + weights to the current directory, then upload to the Hub.
model.save_pretrained(".")
model.push_to_hub(HUB_REPO_ID)
print("Pushed model to Hugging Face hub...")
|