|
import torch |
|
from transformers import AutoConfig, AutoModel |
|
from modeling_sagvit import SAGViTClassifier, SAGViTConfig |
|
|
|
|
|
# Hook the custom SAG-ViT classes into the transformers Auto* registries:
# the "sagvit" model type is bound to SAGViTConfig, and SAGViTConfig is in
# turn bound to SAGViTClassifier for AutoModel.
print("Registering model...")
sagvit_model_type = "sagvit"
AutoConfig.register(sagvit_model_type, SAGViTConfig)
AutoModel.register(SAGViTConfig, SAGViTClassifier)
print("Registered model...")
|
|
|
|
|
# Build the classifier from a default-constructed configuration; the trained
# parameters are loaded into it in a later step of this script.
config = SAGViTConfig()
model = SAGViTClassifier(
    config,
)
|
|
|
|
|
# Restore the trained parameters from the local checkpoint into `model`.
print("Loading model weights...")
# map_location="cpu" lets a checkpoint that was saved from GPU tensors be
# deserialized on a machine without CUDA; `model` was just constructed and has
# not been moved to another device, so CPU tensors are what it expects.
# NOTE(review): torch.load unpickles the file. If the checkpoint could ever
# come from an untrusted source, consider passing weights_only=True (available
# in recent PyTorch releases) -- confirm the minimum supported torch version.
state_dict = torch.load('SAG-ViT.pth', map_location="cpu")
# load_state_dict is strict by default, so a key mismatch between the
# checkpoint and the model raises instead of being silently ignored.
model.load_state_dict(state_dict)
print("Loaded model weights...")
|
|
|
|
|
# Write the model in save_pretrained format to the current directory, then
# upload it to the Hugging Face Hub repository named below.
local_output_dir = '.'
hub_repo_id = "shravvvv/SAG-ViT"
model.save_pretrained(local_output_dir)
model.push_to_hub(hub_repo_id)
print("Pushed model to Hugging Face hub...")
|
|