Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,20 @@ from typing import Union
|
|
4 |
from huggingface_hub import whoami, HfApi
|
5 |
from fastapi import FastAPI
|
6 |
from starlette.middleware.sessions import SessionMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
is_spaces = True if os.environ.get("SPACE_ID") else False
|
9 |
|
@@ -201,6 +215,14 @@ def start_training(
|
|
201 |
print("Started training")
|
202 |
slugged_lora_name = slugify(lora_name)
|
203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
# Load the default config
|
205 |
with open("train_lora_flux_24gb.yaml", "r") as f:
|
206 |
config = yaml.safe_load(f)
|
|
|
4 |
from huggingface_hub import whoami, HfApi
|
5 |
from fastapi import FastAPI
|
6 |
from starlette.middleware.sessions import SessionMiddleware
|
7 |
+
import sys
|
8 |
+
|
9 |
+
# ai-toolkit이 없으면 설치
|
10 |
+
if not os.path.exists("ai-toolkit"):
|
11 |
+
subprocess.run("git clone https://github.com/ostris/ai-toolkit.git", shell=True)
|
12 |
+
subprocess.run("cd ai-toolkit && git submodule update --init --recursive", shell=True)
|
13 |
+
|
14 |
+
# ai-toolkit 경로 추가
|
15 |
+
toolkit_path = os.path.join(os.getcwd(), "ai-toolkit")
|
16 |
+
sys.path.append(toolkit_path)
|
17 |
+
|
18 |
+
# 필요한 패키지 설치
|
19 |
+
subprocess.run("pip install -r ai-toolkit/requirements.txt", shell=True)
|
20 |
+
|
21 |
|
22 |
is_spaces = True if os.environ.get("SPACE_ID") else False
|
23 |
|
|
|
215 |
print("Started training")
|
216 |
slugged_lora_name = slugify(lora_name)
|
217 |
|
218 |
+
try:
|
219 |
+
from toolkit.job import get_job
|
220 |
+
except ImportError:
|
221 |
+
raise gr.Error("Failed to import toolkit. Please check if ai-toolkit is properly installed.")
|
222 |
+
|
223 |
+
print("Started training")
|
224 |
+
slugged_lora_name = slugify(lora_name)
|
225 |
+
|
226 |
# Load the default config
|
227 |
with open("train_lora_flux_24gb.yaml", "r") as f:
|
228 |
config = yaml.safe_load(f)
|