openfree committed on
Commit
0c5a016
·
verified ·
1 Parent(s): 09e5977

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -40
app.py CHANGED
@@ -157,6 +157,7 @@ def recursive_update(d, u):
157
  d[k] = v
158
  return d
159
 
 
160
  def start_training(
161
  lora_name,
162
  concept_sentence,
@@ -173,29 +174,19 @@ def start_training(
173
  profile: Union[gr.OAuthProfile, None],
174
  oauth_token: Union[gr.OAuthToken, None],
175
  ):
176
-
177
  if not lora_name:
178
  raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
179
 
180
- if not is_spaces:
181
- try:
182
- if whoami()["auth"]["accessToken"]["role"] == "write" or "repo.write" in whoami()["auth"]["accessToken"]["fineGrained"]["scoped"][0]["permissions"]:
183
- gr.Info(f"Starting training locally {whoami()['name']}. Your LoRA will be available locally and in Hugging Face after it finishes.")
184
- else:
185
- raise gr.Error(f"You logged in to Hugging Face with not enough permissions, you need a token that allows writing to your profile.")
186
- except:
187
- raise gr.Error(f"You logged in to Hugging Face with not enough permissions, you need a token that allows writing to your profile.")
188
-
189
  print("Started training")
190
  slugged_lora_name = slugify(lora_name)
191
 
192
  # Load the default config
193
- with open("train_lora_flux_24gb.yaml" if is_spaces else "ai-toolkit/config/examples/train_lora_flux_24gb.yaml", "r") as f:
194
  config = yaml.safe_load(f)
195
 
196
  # Update the config with user inputs
197
  config["config"]["name"] = slugged_lora_name
198
- config["config"]["process"][0]["model"]["low_vram"] = True
199
  config["config"]["process"][0]["train"]["skip_first_sample"] = True
200
  config["config"]["process"][0]["train"]["steps"] = int(steps)
201
  config["config"]["process"][0]["train"]["lr"] = float(lr)
@@ -203,12 +194,15 @@ def start_training(
203
  config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
204
  config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
205
  config["config"]["process"][0]["save"]["push_to_hub"] = True
 
206
  try:
207
- username = whoami()["name"] if not is_spaces else profile.username
208
  except:
209
  raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
 
210
  config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
211
  config["config"]["process"][0]["save"]["hf_private"] = True
 
212
  if concept_sentence:
213
  config["config"]["process"][0]["trigger_word"] = concept_sentence
214
 
@@ -237,39 +231,20 @@ def start_training(
237
  print(config)
238
 
239
  # Save the updated config
240
- # generate a random name for the config
241
  random_config_name = str(uuid.uuid4())
242
  os.makedirs("tmp", exist_ok=True)
243
  config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
244
  with open(config_path, "w") as f:
245
  yaml.dump(config, f)
246
- if is_spaces:
247
- # copy config to dataset_folder as config.yaml
248
- shutil.copy(config_path, dataset_folder + "/config.yaml")
249
- # get location of this script
250
- script_location = os.path.dirname(os.path.abspath(__file__))
251
- # copy script.py from current directory to dataset_folder
252
- shutil.copy(script_location + "/script.py", dataset_folder)
253
- # copy requirements.autotrain to dataset_folder as requirements.txt
254
- shutil.copy(script_location + "/requirements.autotrain", dataset_folder + "/requirements.txt")
255
- # command to run autotrain spacerunner
256
- cmd = f"autotrain spacerunner --project-name {slugged_lora_name} --script-path {dataset_folder}"
257
- cmd += f" --username {profile.username} --token {oauth_token.token} --backend spaces-l4x1"
258
- outcome = subprocess.run(cmd.split())
259
- if outcome.returncode == 0:
260
- return f"""# Your training has started.
261
- ## - Training Status: <a href='https://huggingface.co/spaces/{profile.username}/autotrain-{slugged_lora_name}?logs=container'>{profile.username}/autotrain-{slugged_lora_name}</a> <small>(in the logs tab)</small>
262
- ## - Model page: <a href='https://huggingface.co/{profile.username}/{slugged_lora_name}'>{profile.username}/{slugged_lora_name}</a> <small>(will be available when training finishes)</small>"""
263
- else:
264
- print("Error: ", outcome.stderr)
265
- raise gr.Error("Something went wrong. Make sure the name of your LoRA is unique and try again")
266
- else:
267
- # run the job locally
268
- job = get_job(config_path)
269
- job.run()
270
- job.cleanup()
271
 
272
- return f"Training completed successfully. Model saved as {slugged_lora_name}"
 
 
 
 
 
 
 
273
 
274
  def swap_visibilty(profile: Union[gr.OAuthProfile, None]):
275
  if is_spaces:
 
157
  d[k] = v
158
  return d
159
 
160
+ # start_training 함수 수정 부분
161
  def start_training(
162
  lora_name,
163
  concept_sentence,
 
174
  profile: Union[gr.OAuthProfile, None],
175
  oauth_token: Union[gr.OAuthToken, None],
176
  ):
 
177
  if not lora_name:
178
  raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
179
 
 
 
 
 
 
 
 
 
 
180
  print("Started training")
181
  slugged_lora_name = slugify(lora_name)
182
 
183
  # Load the default config
184
+ with open("train_lora_flux_24gb.yaml", "r") as f:
185
  config = yaml.safe_load(f)
186
 
187
  # Update the config with user inputs
188
  config["config"]["name"] = slugged_lora_name
189
+ config["config"]["process"][0]["model"]["low_vram"] = False # L40S has enough VRAM
190
  config["config"]["process"][0]["train"]["skip_first_sample"] = True
191
  config["config"]["process"][0]["train"]["steps"] = int(steps)
192
  config["config"]["process"][0]["train"]["lr"] = float(lr)
 
194
  config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
195
  config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
196
  config["config"]["process"][0]["save"]["push_to_hub"] = True
197
+
198
  try:
199
+ username = profile.username
200
  except:
201
  raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
202
+
203
  config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
204
  config["config"]["process"][0]["save"]["hf_private"] = True
205
+
206
  if concept_sentence:
207
  config["config"]["process"][0]["trigger_word"] = concept_sentence
208
 
 
231
  print(config)
232
 
233
  # Save the updated config
 
234
  random_config_name = str(uuid.uuid4())
235
  os.makedirs("tmp", exist_ok=True)
236
  config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
237
  with open(config_path, "w") as f:
238
  yaml.dump(config, f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
+ # 직접 로컬 GPU에서 학습 실행
241
+ from toolkit.job import get_job
242
+ job = get_job(config_path)
243
+ job.run()
244
+ job.cleanup()
245
+
246
+ return f"""# Training completed successfully!
247
+ ## Your model is available at: <a href='https://huggingface.co/{username}/{slugged_lora_name}'>{username}/{slugged_lora_name}</a>"""
248
 
249
  def swap_visibilty(profile: Union[gr.OAuthProfile, None]):
250
  if is_spaces: