Update app.py
Browse files
app.py
CHANGED
@@ -32,11 +32,9 @@ def start_train(status_box):
|
|
32 |
# Handles CUDA OOM errors.
|
33 |
os.system(f"export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
|
34 |
os.system("echo 'Okay, trying training.'")
|
35 |
-
status_box.value = "Status: Training"
|
36 |
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="TRAINING")
|
37 |
os.system(f"cd pytorch-image-models; ./train.sh 4 --dataset hfds/datacomp/imagenet-1k-random-10.0-frac-1over4 --log-wandb --wandb-project ImageNetTraining10.0-frac-1over4 --experiment ImageNetTraining10.0-frac-1over4 --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4")
|
38 |
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="IDLE")
|
39 |
-
status_box.value = "Status: Idle"
|
40 |
os.system("echo 'Done'.")
|
41 |
os.system("ls")
|
42 |
# Upload output to repository
|
@@ -54,7 +52,7 @@ def run():
|
|
54 |
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="IDLE")
|
55 |
status_box.value = "Status: Idle"
|
56 |
start = gr.Button("Start")
|
57 |
-
start.click(start_train, inputs=status_box)
|
58 |
else:
|
59 |
status_box.value = "Status: Training"
|
60 |
app.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
32 |
# Handles CUDA OOM errors.
|
33 |
os.system(f"export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
|
34 |
os.system("echo 'Okay, trying training.'")
|
|
|
35 |
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="TRAINING")
|
36 |
os.system(f"cd pytorch-image-models; ./train.sh 4 --dataset hfds/datacomp/imagenet-1k-random-10.0-frac-1over4 --log-wandb --wandb-project ImageNetTraining10.0-frac-1over4 --experiment ImageNetTraining10.0-frac-1over4 --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4")
|
37 |
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="IDLE")
|
|
|
38 |
os.system("echo 'Done'.")
|
39 |
os.system("ls")
|
40 |
# Upload output to repository
|
|
|
52 |
API.add_space_variable(repo_id=experiment_repo, key="STATUS", value="IDLE")
|
53 |
status_box.value = "Status: Idle"
|
54 |
start = gr.Button("Start")
|
55 |
+
start.click(start_train, inputs=status_box, outputs=status_box)
|
56 |
else:
|
57 |
status_box.value = "Status: Training"
|
58 |
app.launch(server_name="0.0.0.0", server_port=7860)
|