File size: 1,736 Bytes
8933ee4
1b96548
 
 
 
8933ee4
1b96548
 
 
 
8933ee4
1b96548
 
8933ee4
 
 
 
 
 
 
 
 
 
 
 
 
1b96548
8933ee4
 
 
 
 
 
1b96548
8933ee4
 
1b96548
8933ee4
 
1b96548
8933ee4
 
 
1b96548
8933ee4
 
1b96548
8933ee4
 
 
 
1b96548
8933ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory

import click
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file

from msma import EDMScorer, ScoreFlow, build_model_from_pickle


@click.command()
@click.option(
    "--basedir",
    help="Directory holding the model weights and logs",
    type=click.Path(exists=True, file_okay=False),
    required=True,
)
@click.option(
    "--preset", help="Preset of the score model used", type=str, required=True
)
@click.option(
    "--repo-name",
    help="Hugging Face Hub repository to upload to",
    type=str,
    default="ahsanMah/localizing-edm",
    show_default=True,
)
@click.option(
    "--num-flows",
    help="Number of flow layers; must match the checkpoint in flow.pt",
    type=int,
    default=8,
    show_default=True,
)
def main(basedir, preset, repo_name="ahsanMah/localizing-edm", num_flows=8):
    """Upload a trained ScoreFlow model for PRESET to the Hugging Face Hub.

    Reads BASEDIR/PRESET/flow.pt and BASEDIR/PRESET/logs, stages the model
    weights (safetensors), its config (JSON) and the logs in a temporary
    directory, and pushes them under the PRESET folder of the Hub repository
    in a single commit.
    """
    modeldir = Path(basedir) / preset
    flow_ckpt = modeldir / "flow.pt"
    logdir = modeldir / "logs"

    # Fail fast: check the local inputs before building the score network
    # or creating anything on the Hub.
    if not flow_ckpt.is_file():
        raise click.ClickException(f"Flow checkpoint not found: {flow_ckpt}")
    if not logdir.is_dir():
        raise click.ClickException(f"Log directory not found: {logdir}")

    net = build_model_from_pickle(preset)
    model = ScoreFlow(net, num_flows=num_flows)
    # map_location="cpu" lets a checkpoint saved from a GPU be loaded on a
    # CPU-only machine; load_state_dict then copies the values into the
    # model's own parameters. NOTE: torch.load unpickles the file, so only
    # use it on checkpoints you trust.
    model.flow.load_state_dict(torch.load(flow_ckpt, map_location="cpu"))

    api = HfApi()

    # Create repo if not existing yet and get the associated repo_id
    repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

    # Stage all files in a temporary directory and push them in a single commit
    with TemporaryDirectory() as tmp:
        staging = Path(tmp)

        # safetensors refuses non-contiguous tensors, so normalise the layout
        # (contiguous() is a no-op for tensors that already are).
        weights = {
            name: tensor.contiguous()
            for name, tensor in model.state_dict().items()
        }
        save_file(weights, str(staging / "model.safetensors"))

        (staging / "config.json").write_text(
            json.dumps(model.config, sort_keys=True, indent=4),
            encoding="utf-8",
        )

        # TODO: save gmm and cached score norms
        # TODO: generate a model card (README.md) once a generator exists

        shutil.copytree(logdir, staging / "logs")

        api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=staging)


if __name__ == "__main__":
    # Run as a script: click parses sys.argv, invokes the command and, in
    # standalone mode (click's default, spelled out here), exits the process
    # with the command's status code.
    main(standalone_mode=True)