from huggingface_hub import HfApi, login
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import json
import os

def upload_model_to_hub(
    model_path: str,
    repo_name: str,
    token: str,
    num_labels: int,
    label2id: dict,
    id2label: dict,
    model_architecture: str = "resnet50",
    task: str = "image-classification",
    private: bool = False,
):
    """
    Upload a PyTorch model to Hugging Face Hub with proper configuration.

    Creates (or reuses) the repository ``repo_name`` and uploads four files,
    each as its own commit: ``config.json``, ``preprocessor_config.json``,
    the weights at ``model_path`` (stored as ``pytorch_model.bin``) and a
    ``README.md`` model card.

    Args:
        model_path: Local path to the weights file to upload.
        repo_name: Hub repository id, e.g. ``"user/my-model"``.
        token: Hugging Face access token with write permission.
        num_labels: Number of output classes; must equal ``len(label2id)``
            and ``len(id2label)``.
        label2id: Mapping from class name to integer id.
        id2label: Mapping from id to class name (JSON serialization stores
            the ids as strings).
        model_architecture: Architecture name used in the model card.
        task: Task name used in the model card.
        private: Whether to create the repository as private. Only applies
            when the repository does not exist yet.

    Returns:
        The value returned by ``HfApi.create_repo`` (the repository URL).

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing file.
        ValueError: If ``label2id`` / ``id2label`` sizes differ from
            ``num_labels``.
    """
    # Validate inputs before touching the Hub so a bad call does not leave a
    # half-populated repository behind.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model weights not found: {model_path}")
    if len(label2id) != num_labels or len(id2label) != num_labels:
        raise ValueError(
            f"num_labels={num_labels} does not match label2id "
            f"({len(label2id)} entries) / id2label ({len(id2label)} entries)"
        )

    # Give the token to the client directly instead of calling login(), which
    # would persist it in the machine's Hugging Face credential store.
    api = HfApi(token=token)

    # Create the repository (no-op if it already exists).
    repo_url = api.create_repo(
        repo_id=repo_name,
        exist_ok=True,
        private=private,
    )

    # Model config consumed by transformers when loading from the Hub.
    # NOTE(review): transformers' ResNetConfig appears to expect one hidden
    # size per stage together with a matching `depths` list; a single-entry
    # `hidden_sizes` may not load as a ResNet-50 — confirm against the
    # checkpoint. The weights file must also use ResNetForImageClassification
    # parameter names for from_pretrained() to load it — TODO confirm.
    config = {
        "architectures": ["ResNetForImageClassification"],
        "model_type": "resnet",
        "num_labels": num_labels,
        "id2label": id2label,
        "label2id": label2id,
        "num_channels": 3,
        "hidden_sizes": [2048],
        "image_size": [224, 224],
    }

    # Preprocessing settings: ImageNet mean/std normalization, resize to 224.
    feature_extractor = {
        "image_mean": [0.485, 0.456, 0.406],
        "image_std": [0.229, 0.224, 0.225],
        "do_normalize": True,
        "do_resize": True,
        "size": 224,
        "resample": 2,
    }

    api.upload_file(
        path_or_fileobj=json.dumps(config, indent=2).encode(),
        path_in_repo="config.json",
        repo_id=repo_name,
        commit_message="Upload model config",
    )

    api.upload_file(
        path_or_fileobj=json.dumps(feature_extractor, indent=2).encode(),
        path_in_repo="preprocessor_config.json",
        repo_id=repo_name,
        commit_message="Upload preprocessor config",
    )

    # Upload the model weights under the filename transformers looks for.
    api.upload_file(
        path_or_fileobj=model_path,
        path_in_repo="pytorch_model.bin",
        repo_id=repo_name,
        commit_message="Upload model weights",
    )

    # Build the card line by line with no leading indentation: the Hub only
    # recognises the YAML metadata block when "---" is the very first line,
    # and indented Markdown would otherwise render as a code block.
    model_card = "\n".join(
        [
            "---",
            "language: en",
            "tags:",
            "- pytorch",
            f"- {model_architecture}",
            f"- {task}",
            "---",
            "",
            f"# Model Card for {repo_name}",
            "",
            f"This model is a fine-tuned version of {model_architecture} for {task}.",
            "",
        ]
    )

    api.upload_file(
        path_or_fileobj=model_card.encode(),
        path_in_repo="README.md",
        repo_id=repo_name,
        commit_message="Upload model card",
    )

    print(f"Model uploaded successfully to: https://huggingface.co/{repo_name}")
    return repo_url

if __name__ == "__main__":
    # The Hub token comes from the environment so it never needs to be
    # hard-coded in this script.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError("Please set the HF_TOKEN environment variable")

    # Example class names - replace with your actual labels. A class's id is
    # its position in this list.
    class_names = [
        "class1",
        "class2",
        # ... add all your classes
    ]
    label_to_id = {name: idx for idx, name in enumerate(class_names)}
    # Reverse mapping keyed by the id rendered as a string.
    id_to_label = {str(idx): name for name, idx in label_to_id.items()}

    # Push the checkpoint, configs and model card to the Hub.
    upload_model_to_hub(
        model_path="best_model.pth",
        repo_name="srtangirala/resnet50-exp",
        token=hf_token,
        num_labels=len(label_to_id),
        label2id=label_to_id,
        id2label=id_to_label,
    )