openfree commited on
Commit
19ab1ac
·
verified ·
1 Parent(s): 246eb4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -21
app.py CHANGED
@@ -204,25 +204,8 @@ def start_training(
204
  use_more_advanced_options,
205
  more_advanced_options,
206
  ):
207
- # 환경 변수로 타임아웃 설정
208
- os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300" # 5분으로 증가
209
- os.environ["REQUESTS_TIMEOUT"] = "300"
210
-
211
- import requests
212
- from huggingface_hub import HfApi
213
- from requests.adapters import HTTPAdapter
214
- from urllib3.util.retry import Retry
215
-
216
- # 재시도 전략 설정
217
- retry_strategy = Retry(
218
- total=5,
219
- backoff_factor=1,
220
- status_forcelist=[429, 500, 502, 503, 504],
221
- )
222
- adapter = HTTPAdapter(max_retries=retry_strategy)
223
- http = requests.Session()
224
- http.mount("https://", adapter)
225
- http.mount("http://", adapter)
226
 
227
  try:
228
  username = whoami()["name"]
@@ -237,13 +220,51 @@ def start_training(
237
  config = yaml.safe_load(f)
238
 
239
  # dev 모델 설정
 
240
  config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-dev"
241
  config["config"]["process"][0]["model"]["assistant_lora_path"] = None # adapter 없이 설정
 
 
 
 
 
 
 
 
 
 
 
242
  config["config"]["process"][0]["sample"]["sample_steps"] = 28
243
 
244
- # 나머지 설정은 동일...
245
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  try:
 
 
 
 
 
 
 
247
  # 직접 로컬 GPU에서 학습 실행
248
  from toolkit.job import get_job
249
  job = get_job(config_path)
@@ -255,6 +276,7 @@ def start_training(
255
  return f"""# Training completed successfully!
256
  ## Your model is available at: <a href='https://huggingface.co/{username}/{slugged_lora_name}'>{username}/{slugged_lora_name}</a>"""
257
 
 
258
  def update_pricing(steps):
259
  try:
260
  seconds_per_iteration = 7.54
 
204
  use_more_advanced_options,
205
  more_advanced_options,
206
  ):
207
+ if not lora_name:
208
+ raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  try:
211
  username = whoami()["name"]
 
220
  config = yaml.safe_load(f)
221
 
222
  # dev 모델 설정
223
+ config["config"]["name"] = slugged_lora_name
224
  config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-dev"
225
  config["config"]["process"][0]["model"]["assistant_lora_path"] = None # adapter 없이 설정
226
+ config["config"]["process"][0]["model"]["low_vram"] = False
227
+ config["config"]["process"][0]["train"]["skip_first_sample"] = True
228
+ config["config"]["process"][0]["train"]["steps"] = int(steps)
229
+ config["config"]["process"][0]["train"]["lr"] = float(lr)
230
+ config["config"]["process"][0]["network"]["linear"] = int(rank)
231
+ config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
232
+ config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
233
+ config["config"]["process"][0]["save"]["push_to_hub"] = True
234
+ config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
235
+ config["config"]["process"][0]["save"]["hf_private"] = True
236
+ config["config"]["process"][0]["save"]["hf_token"] = HF_TOKEN
237
  config["config"]["process"][0]["sample"]["sample_steps"] = 28
238
 
239
+ if concept_sentence:
240
+ config["config"]["process"][0]["trigger_word"] = concept_sentence
241
+
242
+ if sample_1 or sample_2 or sample_3:
243
+ config["config"]["process"][0]["train"]["disable_sampling"] = False
244
+ config["config"]["process"][0]["sample"]["sample_every"] = steps
245
+ config["config"]["process"][0]["sample"]["prompts"] = []
246
+ if sample_1:
247
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_1)
248
+ if sample_2:
249
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
250
+ if sample_3:
251
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
252
+ else:
253
+ config["config"]["process"][0]["train"]["disable_sampling"] = True
254
+
255
+ if(use_more_advanced_options):
256
+ more_advanced_options_dict = yaml.safe_load(more_advanced_options)
257
+ config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
258
+ print(config)
259
+
260
  try:
261
+ # Save the updated config
262
+ random_config_name = str(uuid.uuid4())
263
+ os.makedirs("tmp", exist_ok=True)
264
+ config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
265
+ with open(config_path, "w") as f:
266
+ yaml.dump(config, f)
267
+
268
  # 직접 로컬 GPU에서 학습 실행
269
  from toolkit.job import get_job
270
  job = get_job(config_path)
 
276
  return f"""# Training completed successfully!
277
  ## Your model is available at: <a href='https://huggingface.co/{username}/{slugged_lora_name}'>{username}/{slugged_lora_name}</a>"""
278
 
279
+
280
  def update_pricing(steps):
281
  try:
282
  seconds_per_iteration = 7.54