Update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,11 @@ def start_train():
|
|
32 |
# Handles CUDA OOM errors.
|
33 |
os.system(f"export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
|
34 |
os.system("echo 'Okay, trying training.'")
|
|
|
|
|
35 |
os.system(f"cd pytorch-image-models; ./train.sh 4 --dataset hfds/datacomp/imagenet-1k-random-10.0-frac-1over4 --log-wandb --wandb-project ImageNetTraining10.0-frac-1over4 --experiment ImageNetTraining10.0-frac-1over4 --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4")
|
|
|
|
|
36 |
os.system("echo 'Done'.")
|
37 |
os.system("ls")
|
38 |
# Upload output to repository
|
@@ -44,8 +48,15 @@ def run():
|
|
44 |
with gr.Blocks() as app:
|
45 |
gr.Markdown(f"Randomization: 10.0")
|
46 |
gr.Markdown(f"Subset: frac-1over4")
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
app.launch(server_name="0.0.0.0", server_port=7860)
|
50 |
|
51 |
if __name__ == '__main__':
|
|
|
32 |
# Handles CUDA OOM errors.
|
33 |
os.system(f"export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
|
34 |
os.system("echo 'Okay, trying training.'")
|
35 |
+
status_box.value = "Status: Training"
|
36 |
+
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="TRAINING")
|
37 |
os.system(f"cd pytorch-image-models; ./train.sh 4 --dataset hfds/datacomp/imagenet-1k-random-10.0-frac-1over4 --log-wandb --wandb-project ImageNetTraining10.0-frac-1over4 --experiment ImageNetTraining10.0-frac-1over4 --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4")
|
38 |
+
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="IDLE")
|
39 |
+
status_box.value = "Status: Idle"
|
40 |
os.system("echo 'Done'.")
|
41 |
os.system("ls")
|
42 |
# Upload output to repository
|
|
|
48 |
with gr.Blocks() as app:
|
49 |
gr.Markdown(f"Randomization: 10.0")
|
50 |
gr.Markdown(f"Subset: frac-1over4")
|
51 |
+
status_box = gr.Textbox()
|
52 |
+
space_variables = API.get_space_variables(repo_id=experiment_repo)
|
53 |
+
if 'STATUS' not in space_variables or space_variables['STATUS'].value != 'TRAINING':
|
54 |
+
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="IDLE")
|
55 |
+
status_box.value = "Status: Idle"
|
56 |
+
start = gr.Button("Start")
|
57 |
+
start.click(start_train, inputs=status_box)
|
58 |
+
else:
|
59 |
+
status_box.value = "Status: Training"
|
60 |
app.launch(server_name="0.0.0.0", server_port=7860)
|
61 |
|
62 |
if __name__ == '__main__':
|