openfree committed
Commit 246eb4a · verified · 1 Parent(s): 498a763

Update app.py

Files changed (1)
  app.py +31 -60
app.py CHANGED
@@ -204,8 +204,25 @@ def start_training(
     use_more_advanced_options,
     more_advanced_options,
 ):
-    if not lora_name:
-        raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
+    # Set timeouts via environment variables
+    os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300"  # increased to 5 minutes
+    os.environ["REQUESTS_TIMEOUT"] = "300"
+
+    import requests
+    from huggingface_hub import HfApi
+    from requests.adapters import HTTPAdapter
+    from urllib3.util.retry import Retry
+
+    # Configure a retry strategy
+    retry_strategy = Retry(
+        total=5,
+        backoff_factor=1,
+        status_forcelist=[429, 500, 502, 503, 504],
+    )
+    adapter = HTTPAdapter(max_retries=retry_strategy)
+    http = requests.Session()
+    http.mount("https://", adapter)
+    http.mount("http://", adapter)
 
     try:
         username = whoami()["name"]
@@ -215,71 +232,25 @@ def start_training(
     print("Started training")
     slugged_lora_name = slugify(lora_name)
 
-    try:
-        from toolkit.job import get_job
-    except ImportError:
-        raise gr.Error("Failed to import toolkit. Please check if ai-toolkit is properly installed.")
-
-    print("Started training")
-    slugged_lora_name = slugify(lora_name)
-
     # Load the default config
     with open("train_lora_flux_24gb.yaml", "r") as f:
         config = yaml.safe_load(f)
 
-    # Update the config with user inputs
-    config["config"]["name"] = slugged_lora_name
-    config["config"]["process"][0]["model"]["low_vram"] = False
-    config["config"]["process"][0]["train"]["skip_first_sample"] = True
-    config["config"]["process"][0]["train"]["steps"] = int(steps)
-    config["config"]["process"][0]["train"]["lr"] = float(lr)
-    config["config"]["process"][0]["network"]["linear"] = int(rank)
-    config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
-    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
-    config["config"]["process"][0]["save"]["push_to_hub"] = True
-    config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
-    config["config"]["process"][0]["save"]["hf_private"] = True
-    config["config"]["process"][0]["save"]["hf_token"] = HF_TOKEN
+    # dev model settings
     config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-dev"
-    config["config"]["process"][0]["model"]["assistant_lora_path"] = None  # try without the training adapter
-    config["config"]["process"][0]["sample"]["sample_steps"] = 28  # default steps for the dev model
-
-    if concept_sentence:
-        config["config"]["process"][0]["trigger_word"] = concept_sentence
-
-    if sample_1 or sample_2 or sample_3:
-        config["config"]["process"][0]["train"]["disable_sampling"] = False
-        config["config"]["process"][0]["sample"]["sample_every"] = steps
-        config["config"]["process"][0]["sample"]["sample_steps"] = 28
-        config["config"]["process"][0]["sample"]["prompts"] = []
-        if sample_1:
-            config["config"]["process"][0]["sample"]["prompts"].append(sample_1)
-        if sample_2:
-            config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
-        if sample_3:
-            config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
-    else:
-        config["config"]["process"][0]["train"]["disable_sampling"] = True
-
+    config["config"]["process"][0]["model"]["assistant_lora_path"] = None  # configure without the adapter
+    config["config"]["process"][0]["sample"]["sample_steps"] = 28
 
+    # The remaining settings are unchanged...
 
-    if(use_more_advanced_options):
-        more_advanced_options_dict = yaml.safe_load(more_advanced_options)
-        config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
-        print(config)
-
-    # Save the updated config
-    random_config_name = str(uuid.uuid4())
-    os.makedirs("tmp", exist_ok=True)
-    config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
-    with open(config_path, "w") as f:
-        yaml.dump(config, f)
-
-    # Run training directly on the local GPU
-    from toolkit.job import get_job
-    job = get_job(config_path)
-    job.run()
-    job.cleanup()
+    try:
+        # Run training directly on the local GPU
+        from toolkit.job import get_job
+        job = get_job(config_path)
+        job.run()
+        job.cleanup()
+    except Exception as e:
+        raise gr.Error(f"Training failed: {str(e)}")
 
     return f"""# Training completed successfully!
     ## Your model is available at: <a href='https://huggingface.co/{username}/{slugged_lora_name}'>{username}/{slugged_lora_name}</a>"""
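
Note that the retry-enabled `http` session created at the top of `start_training` only applies to requests sent through that session object; calls made via huggingface_hub (such as `whoami()` or the push at save time) use the library's own HTTP backend unless it is reconfigured. Below is a minimal sketch of applying the same retry policy to huggingface_hub itself, assuming a huggingface_hub version that exposes `configure_http_backend` (0.15 or newer); `_backend_factory` is an illustrative name, not part of the commit.

```python
import requests
from huggingface_hub import configure_http_backend
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


def _backend_factory() -> requests.Session:
    # Same policy as in the diff: retry on 429/5xx with exponential backoff.
    retry_strategy = Retry(
        total=5,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)
    session = requests.Session()
    session.mount("https://", adapter)
    session.mount("http://", adapter)
    return session


# huggingface_hub will call this factory when it needs a session for Hub requests.
configure_http_backend(backend_factory=_backend_factory)
```

Calling this once before `whoami()` and training would make Hub requests retry transient failures, matching the intent of the mounted adapters above.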
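
The new `try` block still passes `config_path` to `get_job`, while the displayed hunk removes the block that wrote the config to disk, so the elided "remaining settings" section still has to produce that file. A minimal sketch of such a save step, reconstructed from the removed lines; the helper name `save_training_config` is hypothetical.

```python
import os
import uuid

import yaml


def save_training_config(config: dict, slugged_lora_name: str) -> str:
    """Write the training config to a uniquely named YAML under tmp/ and return its path."""
    os.makedirs("tmp", exist_ok=True)
    config_path = f"tmp/{uuid.uuid4()}-{slugged_lora_name}.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    return config_path
```

Inside the function this would be used as `config_path = save_training_config(config, slugged_lora_name)` immediately before `get_job(config_path)`.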