RohitGandikota commited on
Commit
81a83c8
β€’
1 Parent(s): e89a824

inference test

Browse files
Files changed (1) hide show
  1. app.py +67 -36
app.py CHANGED
@@ -5,6 +5,9 @@ from utils import call
5
  from diffusers.pipelines import StableDiffusionXLPipeline
6
  StableDiffusionXLPipeline.__call__ = call
7
  import os
 
 
 
8
  os.environ['CURL_CA_BUNDLE'] = ''
9
  model_map = {'Age' : 'models/age.pt',
10
  'Chubby': 'models/chubby.pt',
@@ -36,7 +39,7 @@ class Demo:
36
  self.generating = False
37
  self.device = 'cuda'
38
  self.weight_dtype = torch.float16
39
- self.pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=self.weight_dtype)
40
 
41
  with gr.Blocks() as demo:
42
  self.layout()
@@ -57,7 +60,7 @@ class Demo:
57
 
58
  with gr.Row():
59
 
60
- self.explain_infr = gr.Markdown(value='This is a demo of [Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models](https://sliders.baulab.info/). To try out a model that can control a particular concept, select a model and enter any prompt. For example, if you select the model "Surprised Look" you can generate images for the prompt "A picture of a person, realistic, 8k" and compare the slider effect to the image generated by original model. We have also provided several other pre-fine-tuned models like "repair" sliders to repair flaws in SDXL generated images (Check out the "Pretrained Sliders" drop-down). You can also train and run your own custom sliders. Check out the "train" section for custom concept slider training.')
61
 
62
  with gr.Row():
63
 
@@ -82,6 +85,20 @@ class Demo:
82
  label="Seed",
83
  value=12345
84
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  with gr.Column(scale=2):
87
 
@@ -162,6 +179,8 @@ class Demo:
162
  self.infr_button.click(self.inference, inputs = [
163
  self.prompt_input_infr,
164
  self.seed_infr,
 
 
165
  self.model_dropdown
166
  ],
167
  outputs=[
@@ -217,7 +236,7 @@ class Demo:
217
  # return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
218
  return None
219
 
220
- def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
221
 
222
  seed = seed or 12345
223
 
@@ -225,41 +244,53 @@ class Demo:
225
 
226
  model_path = model_map[model_name]
227
 
228
- checkpoint = torch.load(model_path)
229
-
230
- return None
231
-
232
- # finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
233
-
234
- # torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
- # images = self.diffuser(
237
- # prompt,
238
- # n_steps=50,
239
- # generator=generator
240
- # )
241
 
 
 
242
 
243
- # orig_image = images[0][0]
244
-
245
- # torch.cuda.empty_cache()
246
-
247
- # generator = torch.manual_seed(seed)
248
-
249
- # with finetuner:
250
-
251
- # images = self.diffuser(
252
- # prompt,
253
- # n_steps=50,
254
- # generator=generator
255
- # )
256
-
257
- # edited_image = images[0][0]
258
-
259
- # del finetuner
260
- # torch.cuda.empty_cache()
261
-
262
- # return edited_image, orig_image
263
-
264
 
265
  demo = Demo()
 
5
  from diffusers.pipelines import StableDiffusionXLPipeline
6
  StableDiffusionXLPipeline.__call__ = call
7
  import os
8
+ from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
9
+
10
+
11
  os.environ['CURL_CA_BUNDLE'] = ''
12
  model_map = {'Age' : 'models/age.pt',
13
  'Chubby': 'models/chubby.pt',
 
39
  self.generating = False
40
  self.device = 'cuda'
41
  self.weight_dtype = torch.float16
42
+ self.pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=self.weight_dtype).to(self.device)
43
 
44
  with gr.Blocks() as demo:
45
  self.layout()
 
60
 
61
  with gr.Row():
62
 
63
+ self.explain_infr = gr.Markdown(value='This is a demo of [Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models](https://sliders.baulab.info/). To try out a model that can control a particular concept, select a model and enter any prompt, choose a seed, and finally choose the SDEdit timestep for structural preservation. Higher SDEdit timesteps results in more structural change. For example, if you select the model "Surprised Look" you can generate images for the prompt "A picture of a person, realistic, 8k" and compare the slider effect to the image generated by original model. We have also provided several other pre-fine-tuned models like "repair" sliders to repair flaws in SDXL generated images (Check out the "Pretrained Sliders" drop-down). You can also train and run your own custom sliders. Check out the "train" section for custom concept slider training.')
64
 
65
  with gr.Row():
66
 
 
85
  label="Seed",
86
  value=12345
87
  )
88
+
89
+ self.slider_scale_infr = gr.Number(
90
+ label="Slider Scale",
91
+ value=2,
92
+ info="Larger slider scale result in stronger edit"
93
+ )
94
+
95
+
96
+ self.start_noise_infr = gr.Slider(
97
+ 600, 900,
98
+ value=750,
99
+ label="SDEdit Timestep",
100
+ info="Choose smaller values for more structural preservation"
101
+ )
102
 
103
  with gr.Column(scale=2):
104
 
 
179
  self.infr_button.click(self.inference, inputs = [
180
  self.prompt_input_infr,
181
  self.seed_infr,
182
+ self.start_noise_infr,
183
+ self.slider_scale_infr,
184
  self.model_dropdown
185
  ],
186
  outputs=[
 
236
  # return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
237
  return None
238
 
239
+ def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
240
 
241
  seed = seed or 12345
242
 
 
244
 
245
  model_path = model_map[model_name]
246
 
247
+ unet = self.pipe.unet
248
+ network_type = "c3lier"
249
+ if 'full' in model_path:
250
+ train_method = 'full'
251
+ elif 'noxattn' in model_path:
252
+ train_method = 'noxattn'
253
+ elif 'xattn' in model_path:
254
+ train_method = 'xattn'
255
+ network_type = 'lierla'
256
+ else:
257
+ train_method = 'noxattn'
258
+
259
+ modules = DEFAULT_TARGET_REPLACE
260
+ if network_type == "c3lier":
261
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
262
+
263
+ name = os.path.basename(model_path)
264
+ rank = 4
265
+ alpha = 1
266
+ if 'rank4' in model_path:
267
+ rank = 4
268
+ if 'rank8' in model_path:
269
+ rank = 8
270
+ if 'alpha1' in model_path:
271
+ alpha = 1.0
272
+ network = LoRANetwork(
273
+ unet,
274
+ rank=rank,
275
+ multiplier=1.0,
276
+ alpha=alpha,
277
+ train_method=train_method,
278
+ ).to(self.device, dtype=self.weight_dtype)
279
+ network.load_state_dict(torch.load(model_path))
280
 
 
 
 
 
 
281
 
282
+ generator = torch.manual_seed(seed)
283
+ edited_image = pipe(prompt, num_images_per_prompt=1, num_inference_steps=50, generator=generator, network=network, start_noise=start_noise, scale=scale, unet=unet).images[0]
284
 
285
+ generator = torch.manual_seed(seed)
286
+ original_image = pipe(prompt, num_images_per_prompt=1, num_inference_steps=50, generator=generator, network=network, start_noise=start_noise, scale=0, unet=unet).images[0]
287
+
288
+ del unet, network
289
+ unet = None
290
+ network = None
291
+ pipe = None
292
+ torch.cuda.empty_cache()
293
+
294
+ return edited_image, original_image
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  demo = Demo()