RohitGandikota commited on
Commit
a9bcbb2
β€’
1 Parent(s): 47a88ae

pushing training code

Browse files
Files changed (1) hide show
  1. app.py +35 -41
app.py CHANGED
@@ -135,21 +135,33 @@ class Demo:
135
  self.target_concept = gr.Text(
136
  placeholder="Enter target concept to make edit on ...",
137
  label="Prompt of concept on which edit is made",
138
- info="Prompt corresponding to concept to edit"
 
139
  )
140
 
141
  self.positive_prompt = gr.Text(
142
- placeholder="Enter the enhance prompt for the edit...",
143
  label="Prompt to enhance",
144
- info="Prompt corresponding to concept to enhance"
 
145
  )
146
 
147
  self.negative_prompt = gr.Text(
148
- placeholder="Enter the suppress prompt for the edit...",
149
  label="Prompt to suppress",
150
- info="Prompt corresponding to concept to supress"
 
151
  )
152
-
 
 
 
 
 
 
 
 
 
153
 
154
  self.rank = gr.Number(
155
  value=4,
@@ -198,12 +210,22 @@ class Demo:
198
  self.negative_prompt,
199
  self.rank,
200
  self.iterations_input,
201
- self.lr_input
 
 
202
  ],
203
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
204
  )
205
 
206
- def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
207
 
208
 
209
  randn = torch.randint(1, 10000000, (1,)).item()
@@ -214,9 +236,11 @@ class Demo:
214
 
215
  if self.training:
216
  return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
217
-
 
 
218
  self.training = True
219
- train_xl(target, postive, negative, lr, iterations, config_file, rank, device, attributes)
220
 
221
  self.training = False
222
 
@@ -224,38 +248,8 @@ class Demo:
224
  model_map['Custom Slider'] = f'models/{save_name}'
225
 
226
  return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom Slider')]
227
- # if train_method == 'ESD-x':
228
-
229
- # modules = ".*attn2$"
230
- # frozen = []
231
-
232
- # elif train_method == 'ESD-u':
233
-
234
- # modules = "unet$"
235
- # frozen = [".*attn2$", "unet.time_embedding$", "unet.conv_out$"]
236
-
237
- # elif train_method == 'ESD-self':
238
-
239
- # modules = ".*attn1$"
240
- # frozen = []
241
-
242
- #
243
-
244
- # save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
245
-
246
- # self.training = True
247
-
248
- # train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
249
-
250
- # self.training = False
251
-
252
- # torch.cuda.empty_cache()
253
-
254
- # model_map['Custom'] = save_path
255
-
256
- #
257
- return [None, None, None, None]
258
 
 
259
  def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
260
 
261
  seed = seed or 12345
 
135
  self.target_concept = gr.Text(
136
  placeholder="Enter target concept to make edit on ...",
137
  label="Prompt of concept on which edit is made",
138
+ info="Prompt corresponding to concept to edit",
139
+ value = ''
140
  )
141
 
142
  self.positive_prompt = gr.Text(
143
+ placeholder="Enter the enhance prompt for the edit ...",
144
  label="Prompt to enhance",
145
+ info="Prompt corresponding to concept to enhance",
146
+ value = ''
147
  )
148
 
149
  self.negative_prompt = gr.Text(
150
+ placeholder="Enter the suppress prompt for the edit ...",
151
  label="Prompt to suppress",
152
+ info="Prompt corresponding to concept to supress",
153
+ value = ''
154
  )
155
+
156
+ self.attributes_input = gr.Text(
157
+ placeholder="Enter the concepts to preserve (comma seperated). Leave empty if not required ...",
158
+ label="Concepts to Preserve",
159
+ info="Comma seperated concepts to preserve/disentangle",
160
+ value = ''
161
+ )
162
+ self.is_person = gr.Checkbox(
163
+ label="Person",
164
+ info="Are you training a slider for person?")
165
 
166
  self.rank = gr.Number(
167
  value=4,
 
210
  self.negative_prompt,
211
  self.rank,
212
  self.iterations_input,
213
+ self.lr_input,
214
+ self.attributes_input,
215
+ self.is_person
216
  ],
217
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
218
  )
219
 
220
+ def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, train_method, neg_guidance, iterations, lr, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
221
+ if '' in attributes_input:
222
+ attributes_input = None
223
+ if '...' in target_concept:
224
+ target_concept = ''
225
+ if '...' in positive_prompt:
226
+ positive_prompt = ''
227
+ if '...' in negative_prompt:
228
+ negative_prompt = ''
229
 
230
 
231
  randn = torch.randint(1, 10000000, (1,)).item()
 
236
 
237
  if self.training:
238
  return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
239
+ attributes = attributes_input
240
+ if is_person:
241
+ attributes = 'white, black, asian, hispanic, indian, male, female'
242
  self.training = True
243
+ train_xl(target=target_concept, postive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=rank, device=self.device, attributes=attributes, save_name=save_name)
244
 
245
  self.training = False
246
 
 
248
  model_map['Custom Slider'] = f'models/{save_name}'
249
 
250
  return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom Slider')]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
+
253
  def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
254
 
255
  seed = seed or 12345