openfree commited on
Commit
a294a85
·
verified ·
1 Parent(s): a194156

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -0
app.py CHANGED
@@ -4,6 +4,20 @@ from typing import Union
4
  from huggingface_hub import whoami, HfApi
5
  from fastapi import FastAPI
6
  from starlette.middleware.sessions import SessionMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  is_spaces = True if os.environ.get("SPACE_ID") else False
9
 
@@ -201,6 +215,14 @@ def start_training(
201
  print("Started training")
202
  slugged_lora_name = slugify(lora_name)
203
 
 
 
 
 
 
 
 
 
204
  # Load the default config
205
  with open("train_lora_flux_24gb.yaml", "r") as f:
206
  config = yaml.safe_load(f)
 
4
  from huggingface_hub import whoami, HfApi
5
  from fastapi import FastAPI
6
  from starlette.middleware.sessions import SessionMiddleware
7
+ import sys
8
+
9
+ # ai-toolkit이 없으면 설치
10
+ if not os.path.exists("ai-toolkit"):
11
+ subprocess.run("git clone https://github.com/ostris/ai-toolkit.git", shell=True)
12
+ subprocess.run("cd ai-toolkit && git submodule update --init --recursive", shell=True)
13
+
14
+ # ai-toolkit 경로 추가
15
+ toolkit_path = os.path.join(os.getcwd(), "ai-toolkit")
16
+ sys.path.append(toolkit_path)
17
+
18
+ # 필요한 패키지 설치
19
+ subprocess.run("pip install -r ai-toolkit/requirements.txt", shell=True)
20
+
21
 
22
  is_spaces = True if os.environ.get("SPACE_ID") else False
23
 
 
215
  print("Started training")
216
  slugged_lora_name = slugify(lora_name)
217
 
218
+ try:
219
+ from toolkit.job import get_job
220
+ except ImportError:
221
+ raise gr.Error("Failed to import toolkit. Please check if ai-toolkit is properly installed.")
222
+
223
+ print("Started training")
224
+ slugged_lora_name = slugify(lora_name)
225
+
226
  # Load the default config
227
  with open("train_lora_flux_24gb.yaml", "r") as f:
228
  config = yaml.safe_load(f)