File size: 5,514 Bytes
38548f2
 
 
 
 
83ea110
b4b3999
38548f2
 
 
83ea110
 
 
38548f2
83ea110
 
38548f2
6db8e8a
38548f2
83ea110
38548f2
 
 
 
 
83ea110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38548f2
 
 
 
 
b4b3999
1d37aeb
 
 
 
 
 
b4b3999
 
 
83ea110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38548f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83ea110
 
 
 
 
 
38548f2
 
 
 
 
 
b4b3999
 
 
 
 
 
 
 
 
 
 
 
 
 
83ea110
 
 
 
 
38548f2
 
 
 
 
 
 
 
 
 
 
 
83ea110
 
 
 
 
 
38548f2
 
 
 
 
 
b4b3999
 
 
 
 
 
83ea110
 
 
 
 
38548f2
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from glob import glob
import os
import shutil
import gradio as gr
from infer.lib.train.process_ckpt import extract_small_model
from app.train import train_index
from huggingface_hub import upload_folder


def download_weight(exp_dir: str) -> str:
    """Extract a compact model from the newest generator checkpoint.

    Args:
        exp_dir: Experiment directory containing ``G_*.pth`` checkpoints.

    Returns:
        Path of the extracted ``model.pth`` inside *exp_dir*.

    Raises:
        gr.Error: If no ``G_*.pth`` checkpoint exists, or if no output file
            was produced by the extraction step.
    """
    # Escape the directory part so characters like "[" in exp_dir are taken
    # literally instead of being interpreted as glob wildcards.
    from glob import escape

    checkpoints = glob(f"{escape(exp_dir)}/G_*.pth")
    if not checkpoints:
        raise gr.Error("No checkpoint found")

    # Use modification time: on Windows getctime is the *creation* time, so a
    # checkpoint file that is overwritten in place would keep looking old.
    latest_checkpoint = max(checkpoints, key=os.path.getmtime)
    print(f"Latest checkpoint: {latest_checkpoint}")

    out = os.path.join(exp_dir, "model.pth")
    # NOTE(review): "40k" / True / "v2" look like sample-rate, pitch and
    # model-version settings for extract_small_model — confirm against its
    # definition before changing them.
    result = extract_small_model(
        latest_checkpoint, out, "40k", True, "Model trained by ZeroGPU.", "v2"
    )
    # How extract_small_model reports failure is not visible here; at least
    # make sure the file we are about to hand out actually exists.
    if not os.path.isfile(out):
        raise gr.Error(f"Failed to extract model: {result}")

    return out


def download_inference_pack(exp_dir: str) -> str:
    """Bundle the extracted model and its index into ``inference_pack.zip``.

    The model is (re)extracted via ``download_weight``. If no
    ``added_*.index`` file exists yet, ``train_index`` is run once to build
    one.

    Args:
        exp_dir: Experiment directory.

    Returns:
        Path of ``<exp_dir>/inference_pack.zip``.

    Raises:
        gr.Error: If no index can be found even after training one (or
            whatever ``download_weight`` raises).
    """
    # Escape the directory part so "[" etc. in exp_dir are not wildcards.
    from glob import escape

    net_g = download_weight(exp_dir)

    index_pattern = f"{escape(exp_dir)}/added_*.index"
    indexes = glob(index_pattern)
    if not indexes:
        # No index yet: build one, then look again.
        train_index(exp_dir)
        indexes = glob(index_pattern)
    if not indexes:
        raise gr.Error("Index not found")
    # glob order is arbitrary; ship the most recently written index.
    index = max(indexes, key=os.path.getmtime)

    # Stage both files in a scratch directory and zip it next to itself.
    pack_dir = os.path.join(exp_dir, "inference_pack")
    if os.path.exists(pack_dir):
        shutil.rmtree(pack_dir)
    os.makedirs(pack_dir)
    try:
        shutil.copy(net_g, pack_dir)
        shutil.copy(index, pack_dir)
        shutil.make_archive(pack_dir, "zip", pack_dir)
    finally:
        # Always remove the staging directory, even if copying/zipping failed.
        shutil.rmtree(pack_dir, ignore_errors=True)

    return f"{pack_dir}.zip"


def download_expdir(exp_dir: str) -> str:
    """Zip the whole experiment directory and return the archive path.

    The archive is written next to the directory as ``<exp_dir>.zip``.
    Trailing path separators are normalised away first; without this,
    ``"logs/exp/"`` would produce ``"logs/exp/.zip"`` *inside* the directory
    being archived.

    Args:
        exp_dir: Experiment directory to archive.

    Returns:
        Path of the created ``.zip`` file.

    Raises:
        gr.Error: If *exp_dir* is not an existing directory.
    """
    if not exp_dir or not os.path.isdir(exp_dir):
        raise gr.Error("Experiment directory not found")

    base_name = os.path.normpath(exp_dir)
    shutil.make_archive(base_name, "zip", exp_dir)
    return f"{base_name}.zip"


def upload_to_huggingface(exp_dir: str, repo_id: str, token: str) -> str:
    """Upload the experiment directory to a Hugging Face Hub repository.

    Args:
        exp_dir: Local experiment directory to upload.
        repo_id: Target repository ID; surrounding whitespace is ignored.
        token: Personal access token. Surrounding whitespace is stripped
            (pasted tokens often carry a stray space/newline). It is only
            forwarded when it starts with ``"hf_"``; otherwise ``None`` is
            passed and huggingface_hub uses its own configured credentials.

    Returns:
        URL of the created commit.

    Raises:
        gr.Error: If *repo_id* is empty.
    """
    repo_id = (repo_id or "").strip()
    if not repo_id:
        raise gr.Error("Repository ID is required")
    token = (token or "").strip()

    commit = upload_folder(
        repo_id=repo_id,
        folder_path=exp_dir,
        # Skip `_data`, generated zip archives (e.g. the inference pack) and
        # the scratch `tmp.wav`.
        ignore_patterns=["_data", "*.zip", "tmp.wav"],
        token=token if token.startswith("hf_") else None,
    )
    return commit.commit_url


def remove_legacy_checkpoints(exp_dir: str):
    """Delete every ``G_*.pth`` / ``D_*.pth`` checkpoint except the newest of each.

    Both checkpoint sets are located before anything is removed. If either
    set is empty, ``gr.Error`` is raised and no file is deleted, so a missing
    ``D_*.pth`` no longer leaves the ``G_*.pth`` set already pruned.

    Args:
        exp_dir: Experiment directory containing the checkpoints.

    Raises:
        gr.Error: If no ``G_*.pth`` or no ``D_*.pth`` checkpoint exists.
    """
    # Escape the directory part so "[" etc. in exp_dir are not wildcards.
    from glob import escape

    pattern_root = escape(exp_dir)

    # Validate both groups up front so a failure deletes nothing.
    groups = []
    for prefix in ("G", "D"):
        checkpoints = glob(f"{pattern_root}/{prefix}_*.pth")
        if not checkpoints:
            raise gr.Error("No checkpoint found")
        groups.append(checkpoints)

    for checkpoints in groups:
        # Use modification time: on Windows getctime is the creation time,
        # which does not change when a checkpoint is overwritten in place.
        latest_checkpoint = max(checkpoints, key=os.path.getmtime)
        print(f"Latest checkpoint: {latest_checkpoint}")
        for checkpoint in checkpoints:
            if checkpoint != latest_checkpoint:
                os.remove(checkpoint)
                print(f"Removed: {checkpoint}")


def remove_expdir(exp_dir: str) -> str:
    """Recursively delete *exp_dir* and return an empty string.

    The empty string is returned so the caller can use it to clear the
    experiment-directory input after deletion.
    """
    target_dir = exp_dir
    shutil.rmtree(target_dir)
    cleared_value = ""
    return cleared_value


class ExportTab:
    """Gradio tab to download, upload, prune or delete an experiment directory."""

    def __init__(self):
        # All components are created in ``ui``; nothing to initialise here.
        pass

    def ui(self):
        """Create the tab's components.

        NOTE(review): presumably called by the app inside an active
        ``gr.Blocks``/tab context; ``build`` must be called afterwards to wire
        the buttons.
        """
        gr.Markdown("# Download Model or Experiment Directory")
        gr.Markdown(
            "You can download the latest model or the entire experiment directory here."
        )

        with gr.Row():
            self.download_weight_btn = gr.Button(
                value="Latest model (for inferencing)", variant="primary"
            )
            self.download_weight_output = gr.File(label="Prune latest model")

        with gr.Row():
            self.download_inference_pack_btn = gr.Button(
                value="Download inference pack (model + index)", variant="primary"
            )
            self.download_inference_pack_output = gr.File(label="Inference pack")

        with gr.Row():
            self.download_expdir_btn = gr.Button(
                value="Download experiment directory", variant="primary"
            )
            self.download_expdir_output = gr.File(label="Archive experiment directory")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload to Hugging Face")
                gr.Markdown(
                    "You can upload the entire experiment directory to Hugging Face."
                )
                # Filled with the commit URL after a successful upload.
                self.commit_link = gr.Markdown("")
            with gr.Column():
                self.repo_id = gr.Textbox(label="Repository ID")
                # The access token is a secret: mask it instead of showing it
                # in plain text on screen.
                self.token = gr.Textbox(label="Personal access token", type="password")
            self.upload_to_huggingface_btn = gr.Button(
                value="Upload to Hugging Face", variant="primary"
            )

        with gr.Row():
            self.remove_legacy_checkpoints_btn = gr.Button(
                value="Remove legacy checkpoints"
            )

        with gr.Row():
            self.remove_expdir_btn = gr.Button(
                value="REMOVE experiment directory", variant="stop"
            )

    def build(self, exp_dir: gr.Textbox):
        """Wire every button to its handler; all handlers read *exp_dir*."""
        self.download_weight_btn.click(
            fn=download_weight,
            inputs=[exp_dir],
            outputs=[self.download_weight_output],
        )

        self.download_inference_pack_btn.click(
            fn=download_inference_pack,
            inputs=[exp_dir],
            outputs=[self.download_inference_pack_output],
        )

        self.download_expdir_btn.click(
            fn=download_expdir,
            inputs=[exp_dir],
            outputs=[self.download_expdir_output],
        )

        self.upload_to_huggingface_btn.click(
            fn=upload_to_huggingface,
            inputs=[exp_dir, self.repo_id, self.token],
            outputs=[self.commit_link],
        )

        # No outputs: pruning only has file-system side effects.
        self.remove_legacy_checkpoints_btn.click(
            fn=remove_legacy_checkpoints,
            inputs=[exp_dir],
        )

        # remove_expdir returns "", which clears the experiment-directory box.
        self.remove_expdir_btn.click(
            fn=remove_expdir,
            inputs=[exp_dir],
            outputs=[exp_dir],
        )