Spaces:
Running
on
A10G
Running
on
A10G
RohitGandikota
commited on
Commit
·
e092b75
1
Parent(s):
655863b
fixing training
Browse files- app.py +2 -4
- trainscripts/textsliders/demotrain.py +1 -1
app.py
CHANGED
@@ -233,10 +233,8 @@ class Demo:
|
|
233 |
# positive_prompt = ''
|
234 |
# if negative_prompt is None:
|
235 |
# negative_prompt = ''
|
236 |
-
|
237 |
-
|
238 |
-
# else:
|
239 |
-
# is_person = True
|
240 |
print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
|
241 |
|
242 |
randn = torch.randint(1, 10000000, (1,)).item()
|
|
|
233 |
# positive_prompt = ''
|
234 |
# if negative_prompt is None:
|
235 |
# negative_prompt = ''
|
236 |
+
if attributes_input == '':
|
237 |
+
attributes_input = None
|
|
|
|
|
238 |
print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)
|
239 |
|
240 |
randn = torch.randint(1, 10000000, (1,)).item()
|
trainscripts/textsliders/demotrain.py
CHANGED
@@ -432,5 +432,5 @@ def train_xl(target, positive, negative, lr, iterations, config_file, rank, devi
|
|
432 |
|
433 |
prompts = prompt_util.load_prompts_from_yaml(path=config.prompts_file, target=target, positive=positive, negative=negative, attributes=attributes)
|
434 |
|
435 |
-
device = torch.device(
|
436 |
train(config, prompts, device)
|
|
|
432 |
|
433 |
prompts = prompt_util.load_prompts_from_yaml(path=config.prompts_file, target=target, positive=positive, negative=negative, attributes=attributes)
|
434 |
|
435 |
+
device = torch.device(device)
|
436 |
train(config, prompts, device)
|