File size: 669 Bytes
986f758
b99e299
a041fbe
b99e299
32db49c
986f758
32db49c
 
986f758
32db49c
a041fbe
32db49c
a041fbe
b99e299
986f758
 
 
 
 
 
32db49c
986f758
b99e299
986f758
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from transformers import AutoConfig, AutoModel
from modeling_sagvit import SAGViTClassifier, SAGViTConfig


# Script: register the custom SAG-ViT classes with the transformers Auto*
# factories, build the model, load trained weights from a local checkpoint,
# save it to the current directory, and publish it to the Hugging Face Hub.

# Local checkpoint to load, and the destination Hub repository.
CHECKPOINT_PATH = 'SAG-ViT.pth'
HUB_REPO_ID = "shravvvv/SAG-ViT"

# Map model_type "sagvit" -> SAGViTConfig, and SAGViTConfig -> SAGViTClassifier,
# so AutoConfig/AutoModel can resolve these classes. This registration lives
# only in the current Python process.
# NOTE(review): it does NOT make the Hub repo loadable via trust_remote_code and
# does not upload modeling_sagvit.py; that needs
# SAGViTConfig.register_for_auto_class() and
# SAGViTClassifier.register_for_auto_class("AutoModel") — but then saving into
# the directory that already holds modeling_sagvit.py may fail when transformers
# copies the module file onto itself. Confirm the intended workflow.
print("Registering model...")
AutoConfig.register("sagvit", SAGViTConfig)
AutoModel.register(SAGViTConfig, SAGViTClassifier)
print("Registered model...")

# Build the architecture from the default config; weights are loaded next.
config = SAGViTConfig()
model = SAGViTClassifier(config)

# Load the trained weights into the freshly built model.
print("Loading model weights...")
# map_location="cpu": lets a checkpoint saved from a GPU run load on a CPU-only
# machine (device placement is irrelevant for saving/pushing).
# weights_only=True: restricts unpickling to tensors and plain containers so a
# tampered .pth cannot execute arbitrary code (also the torch >= 2.6 default).
# Assumes the file holds a bare state_dict, not a wrapped checkpoint — TODO confirm.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
# Default strict loading: raises if keys are missing or unexpected.
model.load_state_dict(state_dict)
print("Loaded model weights...")

# Write the config and weights into the current directory, then upload the
# model to the Hub (requires an authenticated Hugging Face session).
model.save_pretrained('.')
model.push_to_hub(HUB_REPO_ID)

print("Pushed model to Hugging Face hub...")