RohitGandikota commited on
Commit
e092b75
·
1 Parent(s): 655863b

fixing training

Browse files
Files changed (2) hide show
  1. app.py +2 -4
  2. trainscripts/textsliders/demotrain.py +1 -1
app.py CHANGED
@@ -233,10 +233,8 @@ class Demo:
233
  # positive_prompt = ''
234
  # if negative_prompt is None:
235
  # negative_prompt = ''
236
- # if is_person is None:
237
- # is_person = False
238
- # else:
239
- # is_person = True
240
  print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
241
 
242
  randn = torch.randint(1, 10000000, (1,)).item()
 
233
  # positive_prompt = ''
234
  # if negative_prompt is None:
235
  # negative_prompt = ''
236
+ if attributes_input == '':
237
+ attributes_input = None
 
 
238
  print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
239
 
240
  randn = torch.randint(1, 10000000, (1,)).item()
trainscripts/textsliders/demotrain.py CHANGED
@@ -432,5 +432,5 @@ def train_xl(target, positive, negative, lr, iterations, config_file, rank, devi
432
 
433
  prompts = prompt_util.load_prompts_from_yaml(path=config.prompts_file, target=target, positive=positive, negative=negative, attributes=attributes)
434
 
435
- device = torch.device(f"cuda:{device}")
436
  train(config, prompts, device)
 
432
 
433
  prompts = prompt_util.load_prompts_from_yaml(path=config.prompts_file, target=target, positive=positive, negative=negative, attributes=attributes)
434
 
435
+ device = torch.device(device)
436
  train(config, prompts, device)