# localizing-anomalies / push_to_hf.py
# Author: ahsanMah (commit 8933ee4, 1.74 kB)
# Utility script for pushing trained models to the Hugging Face Hub.
import json
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
import click
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from msma import EDMScorer, ScoreFlow, build_model_from_pickle
@click.command()
@click.option(
    "--basedir",
    help="Directory holding the model weights and logs",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    required=True,
)
@click.option(
    "--preset", help="Preset of the score model used", type=str, required=True
)
@click.option(
    "--repo-name",
    help="Hugging Face Hub repository to push to (created if missing)",
    type=str,
    default="ahsanMah/localizing-edm",
    show_default=True,
)
@click.option(
    "--num-flows",
    help="num_flows passed to ScoreFlow; must match the saved flow.pt",
    type=int,
    default=8,
    show_default=True,
)
def main(basedir, preset, repo_name="ahsanMah/localizing-edm", num_flows=8):
    """Upload the trained ScoreFlow for PRESET to the Hugging Face Hub.

    Reads BASEDIR/PRESET/flow.pt and BASEDIR/PRESET/logs, rebuilds the
    ScoreFlow around the preset's score network, and uploads
    ``model.safetensors``, ``config.json`` and ``logs/`` into the PRESET/
    folder of the target repository with a single ``upload_folder`` call.

    Args:
        basedir: Directory containing one sub-directory per preset.
        preset: Preset name; passed to ``build_model_from_pickle`` and used
            both as the local sub-directory and as the folder in the repo.
        repo_name: Hub repository id (``owner/name``); created if missing.
        num_flows: ``num_flows`` forwarded to ``ScoreFlow``; presumably must
            match the configuration flow.pt was trained with.

    Raises:
        click.FileError: If ``flow.pt`` or ``logs/`` is missing. This is
            checked before any network call so a failed run does not leave
            an empty repository behind.
    """
    basedir = Path(basedir)
    modeldir = basedir / preset
    flow_path = modeldir / "flow.pt"
    logs_dir = modeldir / "logs"

    # Fail fast: validate local inputs before building the model or touching the Hub.
    for required in (flow_path, logs_dir):
        if not required.exists():
            raise click.FileError(str(required), hint="required artifact not found")

    net = build_model_from_pickle(preset)
    model = ScoreFlow(net, num_flows=num_flows)
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine.
    model.flow.load_state_dict(torch.load(flow_path, map_location="cpu"))

    api = HfApi()
    # Create the repo if it does not exist yet and get its canonical repo_id.
    repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

    # Stage every artifact in a temporary directory and push them in a single commit.
    with TemporaryDirectory() as tmp:
        tmpdir = Path(tmp)
        # Save weights (plain str path for safetensors).
        save_file(model.state_dict(), str(tmpdir / "model.safetensors"))
        # Save config with an explicit encoding rather than the platform default.
        (tmpdir / "config.json").write_text(
            json.dumps(model.config, sort_keys=True, indent=4), encoding="utf-8"
        )
        # TODO: save gmm and cached score norms
        # TODO: generate and save a model card (README.md)
        # Save logs
        shutil.copytree(logs_dir, tmpdir / "logs")
        # Push to hub
        api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
if __name__ == "__main__":
    # Run the click command-line entry point when executed as a script.
    main()