openfree committed on
Commit
a9043b3
·
verified ·
1 Parent(s): 217860d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -25
app.py CHANGED
@@ -1,12 +1,9 @@
1
  import os
2
  import subprocess
3
  from typing import Union
4
- from huggingface_hub import whoami
5
- is_spaces = True if os.environ.get("SPACE_ID") else False
6
 
7
- if is_spaces:
8
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
- import spaces
10
 
11
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
12
  import sys
@@ -22,7 +19,6 @@ import gradio as gr
22
  from PIL import Image
23
  import torch
24
  import uuid
25
- import os
26
  import shutil
27
  import json
28
  import yaml
@@ -38,16 +34,11 @@ if not is_spaces:
38
  MAX_IMAGES = 150
39
 
40
 
41
- import subprocess
42
- from typing import Union
43
- from huggingface_hub import whoami, HfApi
44
-
45
  # Hugging Face 토큰 설정
46
  HF_TOKEN = os.getenv("HF_TOKEN")
47
  if not HF_TOKEN:
48
  raise ValueError("HF_TOKEN environment variable is not set")
49
 
50
- is_spaces = True if os.environ.get("SPACE_ID") else False
51
 
52
  if is_spaces:
53
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -203,11 +194,13 @@ def start_training(
203
  print("Started training")
204
  slugged_lora_name = slugify(lora_name)
205
 
206
-
 
 
207
 
208
  # Update the config with user inputs
209
  config["config"]["name"] = slugged_lora_name
210
- config["config"]["process"][0]["model"]["low_vram"] = False # L40S has enough VRAM
211
  config["config"]["process"][0]["train"]["skip_first_sample"] = True
212
  config["config"]["process"][0]["train"]["steps"] = int(steps)
213
  config["config"]["process"][0]["train"]["lr"] = float(lr)
@@ -215,18 +208,15 @@ def start_training(
215
  config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
216
  config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
217
  config["config"]["process"][0]["save"]["push_to_hub"] = True
218
-
219
- try:
220
- username = profile.username
221
- except:
222
- raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
223
-
224
  config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
225
  config["config"]["process"][0]["save"]["hf_private"] = True
226
  config["config"]["process"][0]["save"]["hf_token"] = HF_TOKEN
 
 
 
 
227
 
228
 
229
-
230
  if concept_sentence:
231
  config["config"]["process"][0]["trigger_word"] = concept_sentence
232
 
@@ -244,11 +234,6 @@ def start_training(
244
  else:
245
  config["config"]["process"][0]["train"]["disable_sampling"] = True
246
 
247
- if(which_model == "[schnell] (4 step fast model)"):
248
- # schnell 관련 조건문을 dev로 변경
249
- config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-dev"
250
- config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-dev-training-adapter"
251
- config["config"]["process"][0]["sample"]["sample_steps"] = 28 # dev 모델의 기본 스텝
252
 
253
 
254
  if(use_more_advanced_options):
@@ -375,12 +360,15 @@ with gr.Blocks(theme=theme, css=css) as demo:
375
  placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
376
  interactive=True,
377
  )
 
 
378
 
379
  which_model = gr.Radio(
380
  ["[dev] (high quality model)"],
381
  label="Base model",
382
  value="[dev] (high quality model)"
383
  )
 
384
 
385
  with gr.Group(visible=True) as image_upload:
386
  with gr.Row():
 
1
  import os
2
  import subprocess
3
  from typing import Union
4
+ from huggingface_hub import whoami, HfApi
 
5
 
6
+ is_spaces = True if os.environ.get("SPACE_ID") else False
 
 
7
 
8
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
9
  import sys
 
19
  from PIL import Image
20
  import torch
21
  import uuid
 
22
  import shutil
23
  import json
24
  import yaml
 
34
  MAX_IMAGES = 150
35
 
36
 
 
 
 
 
37
  # Hugging Face 토큰 설정
38
  HF_TOKEN = os.getenv("HF_TOKEN")
39
  if not HF_TOKEN:
40
  raise ValueError("HF_TOKEN environment variable is not set")
41
 
 
42
 
43
  if is_spaces:
44
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
194
  print("Started training")
195
  slugged_lora_name = slugify(lora_name)
196
 
197
+ # Load the default config
198
+ with open("train_lora_flux_24gb.yaml", "r") as f:
199
+ config = yaml.safe_load(f)
200
 
201
  # Update the config with user inputs
202
  config["config"]["name"] = slugged_lora_name
203
+ config["config"]["process"][0]["model"]["low_vram"] = False
204
  config["config"]["process"][0]["train"]["skip_first_sample"] = True
205
  config["config"]["process"][0]["train"]["steps"] = int(steps)
206
  config["config"]["process"][0]["train"]["lr"] = float(lr)
 
208
  config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
209
  config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
210
  config["config"]["process"][0]["save"]["push_to_hub"] = True
 
 
 
 
 
 
211
  config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
212
  config["config"]["process"][0]["save"]["hf_private"] = True
213
  config["config"]["process"][0]["save"]["hf_token"] = HF_TOKEN
214
+ config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-dev"
215
+ config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-dev-training-adapter"
216
+ config["config"]["process"][0]["sample"]["sample_steps"] = 28 # dev 모델의 기본 스텝
217
+
218
 
219
 
 
220
  if concept_sentence:
221
  config["config"]["process"][0]["trigger_word"] = concept_sentence
222
 
 
234
  else:
235
  config["config"]["process"][0]["train"]["disable_sampling"] = True
236
 
 
 
 
 
 
237
 
238
 
239
  if(use_more_advanced_options):
 
360
  placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
361
  interactive=True,
362
  )
363
+ # model_warning 변수 추가
364
+ model_warning = gr.Markdown(visible=False)
365
 
366
  which_model = gr.Radio(
367
  ["[dev] (high quality model)"],
368
  label="Base model",
369
  value="[dev] (high quality model)"
370
  )
371
+
372
 
373
  with gr.Group(visible=True) as image_upload:
374
  with gr.Row():