Spaces:
Runtime error
Runtime error
Pause Space
Browse files- app_training.py +3 -3
- trainer.py +4 -6
app_training.py
CHANGED
@@ -105,8 +105,8 @@ def create_training_demo(trainer: Trainer,
|
|
105 |
choices=[_.value for _ in UploadTarget],
|
106 |
value=UploadTarget.MODEL_LIBRARY.value)
|
107 |
|
108 |
-
|
109 |
-
label='
|
110 |
value=False,
|
111 |
interactive=bool(os.getenv('SPACE_ID')),
|
112 |
visible=False)
|
@@ -143,7 +143,7 @@ def create_training_demo(trainer: Trainer,
|
|
143 |
use_private_repo,
|
144 |
delete_existing_repo,
|
145 |
upload_to,
|
146 |
-
|
147 |
hf_token,
|
148 |
])
|
149 |
return demo
|
|
|
105 |
choices=[_.value for _ in UploadTarget],
|
106 |
value=UploadTarget.MODEL_LIBRARY.value)
|
107 |
|
108 |
+
pause_space_after_training = gr.Checkbox(
|
109 |
+
label='Pause this Space after training',
|
110 |
value=False,
|
111 |
interactive=bool(os.getenv('SPACE_ID')),
|
112 |
visible=False)
|
|
|
143 |
use_private_repo,
|
144 |
delete_existing_repo,
|
145 |
upload_to,
|
146 |
+
pause_space_after_training,
|
147 |
hf_token,
|
148 |
])
|
149 |
return demo
|
trainer.py
CHANGED
@@ -60,7 +60,7 @@ class Trainer:
|
|
60 |
use_private_repo: bool,
|
61 |
delete_existing_repo: bool,
|
62 |
upload_to: str,
|
63 |
-
|
64 |
hf_token: str,
|
65 |
) -> None:
|
66 |
if not torch.cuda.is_available():
|
@@ -140,9 +140,7 @@ class Trainer:
|
|
140 |
with open(self.log_file, 'a') as f:
|
141 |
f.write(upload_message)
|
142 |
|
143 |
-
if
|
144 |
-
space_id
|
145 |
-
if space_id:
|
146 |
api = HfApi(token=os.getenv('HF_TOKEN') or hf_token)
|
147 |
-
api.
|
148 |
-
hardware='cpu-basic')
|
|
|
60 |
use_private_repo: bool,
|
61 |
delete_existing_repo: bool,
|
62 |
upload_to: str,
|
63 |
+
pause_space_after_training: bool,
|
64 |
hf_token: str,
|
65 |
) -> None:
|
66 |
if not torch.cuda.is_available():
|
|
|
140 |
with open(self.log_file, 'a') as f:
|
141 |
f.write(upload_message)
|
142 |
|
143 |
+
if pause_space_after_training:
|
144 |
+
if space_id := os.getenv('SPACE_ID'):
|
|
|
145 |
api = HfApi(token=os.getenv('HF_TOKEN') or hf_token)
|
146 |
+
api.pause_space(repo_id=space_id)
|
|