resnet-train / upload_to_hub.py
Sreekanth Tangirala
remove pth from tracking
9c8dfb8
from huggingface_hub import HfApi, login
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import json
import os
# Per-variant stage layouts, keyed by normalised architecture name and expressed
# with the Hugging Face ``ResNetConfig`` field semantics:
# (blocks per stage, output channels per stage, residual block type).
_RESNET_LAYOUTS = {
    "resnet18": ([2, 2, 2, 2], [64, 128, 256, 512], "basic"),
    "resnet34": ([3, 4, 6, 3], [64, 128, 256, 512], "basic"),
    "resnet50": ([3, 4, 6, 3], [256, 512, 1024, 2048], "bottleneck"),
    "resnet101": ([3, 4, 23, 3], [256, 512, 1024, 2048], "bottleneck"),
    "resnet152": ([3, 8, 36, 3], [256, 512, 1024, 2048], "bottleneck"),
}


def _resnet_layout(model_architecture):
    """Return ``(depths, hidden_sizes, layer_type)`` for a ResNet variant name.

    The name is matched case-insensitively, ignoring ``-`` and ``_``
    (so ``"ResNet-50"`` and ``"resnet50"`` are equivalent).

    Raises:
        ValueError: If the variant is not in ``_RESNET_LAYOUTS``.
    """
    key = model_architecture.lower().replace("-", "").replace("_", "")
    try:
        return _RESNET_LAYOUTS[key]
    except KeyError:
        raise ValueError(
            f"Unsupported model_architecture {model_architecture!r}; "
            f"expected one of {sorted(_RESNET_LAYOUTS)}"
        ) from None


def _build_model_config(num_labels, label2id, id2label, model_architecture):
    """Build the ``config.json`` payload for ``ResNetForImageClassification``."""
    depths, hidden_sizes, layer_type = _resnet_layout(model_architecture)
    return {
        "architectures": ["ResNetForImageClassification"],
        "model_type": "resnet",
        "num_labels": num_labels,
        "id2label": id2label,
        "label2id": label2id,
        "num_channels": 3,
        "embedding_size": 64,
        # ResNetConfig expects one entry per encoder stage in both lists;
        # a single-entry hidden_sizes would describe a one-stage network.
        "depths": depths,
        "hidden_sizes": hidden_sizes,
        "layer_type": layer_type,
        "image_size": [224, 224],
    }


def _build_model_card(repo_name, model_architecture, task):
    """Build README.md text whose YAML front matter begins on the first line."""
    return "\n".join([
        "---",
        "language: en",
        "tags:",
        "- pytorch",
        f"- {model_architecture}",
        f"- {task}",
        "---",
        f"# Model Card for {repo_name}",
        f"This model is a fine-tuned version of {model_architecture} for {task}.",
        "",
    ])


def upload_model_to_hub(
    model_path: str,
    repo_name: str,
    token: str,
    num_labels: int,
    label2id: dict,
    id2label: dict,
    model_architecture: str = "resnet50",
    task: str = "image-classification",
):
    """
    Upload a PyTorch model to Hugging Face Hub with proper configuration.

    Creates (or reuses) the public repository ``repo_name``. It then uploads
    ``config.json``, ``preprocessor_config.json``, the weights file at
    ``model_path`` (stored as ``pytorch_model.bin``) and a ``README.md`` model
    card, each as its own commit.

    Args:
        model_path: Local path to the weights file to upload.
        repo_name: Hub repository id, e.g. ``"user/model-name"``.
        token: Hugging Face access token, passed to ``login``.
        num_labels: Number of output classes.
        label2id: Mapping of class name -> class index.
        id2label: Mapping of class index -> class name.
        model_architecture: ResNet variant name (e.g. ``"resnet50"``). It
            selects the stage layout written to ``config.json`` and is also
            used as a model-card tag.
        task: Task name used as a model-card tag.

    Raises:
        ValueError: If the label mappings disagree with ``num_labels`` or
            ``model_architecture`` is not a known ResNet variant.
        FileNotFoundError: If ``model_path`` is not an existing file.
    """
    # Validate everything locally before touching the Hub, so a bad call
    # does not leave a half-populated repository behind.
    if len(label2id) != num_labels or len(id2label) != num_labels:
        raise ValueError(
            f"num_labels={num_labels} but label2id has {len(label2id)} entries "
            f"and id2label has {len(id2label)}"
        )
    config = _build_model_config(num_labels, label2id, id2label, model_architecture)
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model weights not found: {model_path}")

    # Login to Hugging Face
    login(token=token)
    api = HfApi()

    # Create the repository (no-op if it already exists)
    api.create_repo(
        repo_id=repo_name,
        exist_ok=True,
        private=False
    )

    # Preprocessing settings: ImageNet mean/std normalisation at 224px.
    feature_extractor = {
        "image_mean": [0.485, 0.456, 0.406],
        "image_std": [0.229, 0.224, 0.225],
        "do_normalize": True,
        "do_resize": True,
        "size": 224,
        "resample": 2
    }

    # Upload config files
    api.upload_file(
        path_or_fileobj=json.dumps(config, indent=2).encode(),
        path_in_repo="config.json",
        repo_id=repo_name,
        commit_message="Upload model config"
    )
    api.upload_file(
        path_or_fileobj=json.dumps(feature_extractor, indent=2).encode(),
        path_in_repo="preprocessor_config.json",
        repo_id=repo_name,
        commit_message="Upload preprocessor config"
    )

    # Upload the model file
    api.upload_file(
        path_or_fileobj=model_path,
        path_in_repo="pytorch_model.bin",
        repo_id=repo_name,
        commit_message="Upload model weights"
    )

    # Create and upload model card
    api.upload_file(
        path_or_fileobj=_build_model_card(repo_name, model_architecture, task).encode(),
        path_in_repo="README.md",
        repo_id=repo_name,
        commit_message="Upload model card"
    )
    print(f"Model uploaded successfully to: https://huggingface.co/{repo_name}")
if __name__ == "__main__":
    # The access token comes from the environment so it never lives in source.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("Please set the HF_TOKEN environment variable")

    # Placeholder class names -- replace with the labels the model was trained on,
    # in index order.
    class_names = ["class1", "class2"]
    label2id = {name: index for index, name in enumerate(class_names)}
    # Hub configs key id2label by the string form of the index.
    id2label = {str(index): name for name, index in label2id.items()}

    upload_model_to_hub(
        model_path="best_model.pth",
        repo_name="srtangirala/resnet50-exp",
        token=hf_token,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
    )