Ziqi committed
Commit 8913269 · Parent: 2b90a75
Files changed (3)
  1. .gitignore +1 -0
  2. app.py +29 -248
  3. inference.py +81 -66
.gitignore ADDED
@@ -0,0 +1 @@
+experiments/*
app.py CHANGED
@@ -19,7 +19,8 @@ import pathlib
 import gradio as gr
 import torch
 
-from inference import InferencePipeline
+from inference import inference_fn
+# from inference_custom_diffusion import InferencePipeline
 # from trainer import Trainer
 # from uploader import upload
 
@@ -69,173 +70,6 @@ def update_output_files() -> dict:
     paths = [path.as_posix() for path in paths]  # type: ignore
     return gr.update(value=paths or None)
 
-
-def create_training_demo(trainer: Trainer,
-                         pipe: InferencePipeline) -> gr.Blocks:
-    with gr.Blocks() as demo:
-        base_model = gr.Dropdown(
-            choices=['stabilityai/stable-diffusion-2-1-base', 'CompVis/stable-diffusion-v1-4'],
-            value='CompVis/stable-diffusion-v1-4',
-            label='Base Model',
-            visible=True)
-        resolution = gr.Dropdown(choices=['512', '768'],
-                                 value='512',
-                                 label='Resolution',
-                                 visible=True)
-
-        with gr.Row():
-            with gr.Box():
-                concept_images_collection = []
-                concept_prompt_collection = []
-                class_prompt_collection = []
-                buttons_collection = []
-                delete_collection = []
-                is_visible = []
-                maximum_concepts = 3
-                row = [None] * maximum_concepts
-                for x in range(maximum_concepts):
-                    ordinal = lambda n: "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4])
-                    ordinal_concept = ["<new1> cat", "<new2> wooden pot", "<new3> chair"]
-                    if (x == 0):
-                        visible = True
-                        is_visible.append(gr.State(value=True))
-                    else:
-                        visible = False
-                        is_visible.append(gr.State(value=False))
-
-                    concept_images_collection.append(gr.Files(label=f'''Upload the images for your {ordinal(x+1) if (x>0) else ""} concept''', visible=visible))
-                    with gr.Column(visible=visible) as row[x]:
-                        concept_prompt_collection.append(
-                            gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} concept prompt ''', max_lines=1,
-                                       placeholder=f'''Example: "photo of a {ordinal_concept[x]}"'''))
-                        class_prompt_collection.append(
-                            gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} class prompt ''',
-                                       max_lines=1, placeholder=f'''Example: "{ordinal_concept[x][7:]}"'''))
-                        with gr.Row():
-                            if (x < maximum_concepts - 1):
-                                buttons_collection.append(gr.Button(value=f"Add {ordinal(x+2)} concept", visible=visible))
-                            if (x > 0):
-                                delete_collection.append(gr.Button(value=f"Delete {ordinal(x+1)} concept"))
-
-                counter_add = 1
-                for button in buttons_collection:
-                    if (counter_add < len(buttons_collection)):
-                        button.click(lambda: [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), True, None],
-                                     None,
-                                     [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add - 1], buttons_collection[counter_add], is_visible[counter_add], concept_images_collection[counter_add]],
-                                     queue=False)
-                    else:
-                        button.click(lambda: [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), True],
-                                     None,
-                                     [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add - 1], is_visible[counter_add]],
-                                     queue=False)
-                    counter_add += 1
-
-                counter_delete = 1
-                for delete_button in delete_collection:
-                    if (counter_delete < len(delete_collection) + 1):
-                        if counter_delete == 1:
-                            delete_button.click(lambda: [gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), False],
-                                                None,
-                                                [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete - 1], buttons_collection[counter_delete], is_visible[counter_delete]],
-                                                queue=False)
-                        else:
-                            delete_button.click(lambda: [gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=True), False],
-                                                None,
-                                                [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete - 1], is_visible[counter_delete]],
-                                                queue=False)
-                    counter_delete += 1
-                gr.Markdown('''
-                - We use the "\<new1\>" modifier token in front of the concept, e.g., "\<new1\> cat". For multiple concepts, use "\<new2\>", "\<new3\>", etc. Increase the number of steps with more concepts.
-                - For a new concept, an example concept prompt is "photo of a \<new1\> cat", with "cat" as the class prompt.
-                - For a style concept, use "painting in the style of \<new1\> art" as the concept prompt and "art" as the class prompt.
-                - The class prompt should be the object category.
-                - If "Train Text Encoder" is enabled, disable "modifier token" and use any unique text to describe the concept, e.g. "ktn cat".
-                ''')
-            with gr.Box():
-                gr.Markdown('Training Parameters')
-                with gr.Row():
-                    modifier_token = gr.Checkbox(label='modifier token',
-                                                 value=True)
-                    train_text_encoder = gr.Checkbox(label='Train Text Encoder',
-                                                     value=False)
-                    num_training_steps = gr.Number(
-                        label='Number of Training Steps', value=1000, precision=0)
-                    learning_rate = gr.Number(label='Learning Rate', value=0.00001)
-                    batch_size = gr.Number(
-                        label='batch_size', value=1, precision=0)
-                with gr.Row():
-                    use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
-                    gradient_checkpointing = gr.Checkbox(label='Enable gradient checkpointing', value=False)
-                with gr.Accordion('Other Parameters', open=False):
-                    gradient_accumulation = gr.Number(
-                        label='Number of Gradient Accumulation',
-                        value=1,
-                        precision=0)
-                    num_reg_images = gr.Number(
-                        label='Number of Class Concept images',
-                        value=200,
-                        precision=0)
-                    gen_images = gr.Checkbox(label='Generated images as regularization',
-                                             value=False)
-                gr.Markdown('''
-                - Training takes about 10 minutes for 1000 steps and ~21GB on a 3090 GPU.
-                - Our results in the paper were trained with batch size 4 (8 including class regularization samples).
-                - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of a slower backward pass.
-                - Note that your trained models will be deleted when a second training run is started. You can upload your trained model in the "Upload" tab.
-                - We retrieve real images for the class concept using the clip_retrieval library, which can take some time.
-                ''')
-
-        run_button = gr.Button('Start Training')
-        with gr.Box():
-            with gr.Row():
-                check_status_button = gr.Button('Check Training Status')
-                with gr.Column():
-                    with gr.Box():
-                        gr.Markdown('Message')
-                        training_status = gr.Markdown()
-                    output_files = gr.Files(label='Trained Weight Files')
-
-        run_button.click(fn=pipe.clear,
-                         inputs=None,
-                         outputs=None)
-        run_button.click(fn=trainer.run,
-                         inputs=[
-                             base_model,
-                             resolution,
-                             num_training_steps,
-                             learning_rate,
-                             train_text_encoder,
-                             modifier_token,
-                             gradient_accumulation,
-                             batch_size,
-                             use_8bit_adam,
-                             gradient_checkpointing,
-                             gen_images,
-                             num_reg_images,
-                         ] + concept_images_collection
-                           + concept_prompt_collection
-                           + class_prompt_collection,
-                         outputs=[
-                             training_status,
-                             output_files,
-                         ],
-                         queue=False)
-        check_status_button.click(fn=trainer.check_if_running,
-                                  inputs=None,
-                                  outputs=training_status,
-                                  queue=False)
-        check_status_button.click(fn=update_output_files,
-                                  inputs=None,
-                                  outputs=output_files,
-                                  queue=False)
-    return demo
-
-
 def find_weight_files() -> list[str]:
     curr_dir = pathlib.Path(__file__).parent
     paths = sorted(curr_dir.rglob('*.bin'))
@@ -251,49 +85,32 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
     with gr.Blocks() as demo:
         with gr.Row():
             with gr.Column():
-                base_model = gr.Dropdown(
-                    choices=['stabilityai/stable-diffusion-2-1-base', 'CompVis/stable-diffusion-v1-4'],
-                    value='CompVis/stable-diffusion-v1-4',
-                    label='Base Model',
+                model_id = gr.Dropdown(
+                    choices=['experiments/painted_on'],
+                    value='experiments/painted_on',
+                    label='Relation',
                     visible=True)
-                resolution = gr.Dropdown(choices=[512, 768],
-                                         value=512,
-                                         label='Resolution',
-                                         visible=True)
                 reload_button = gr.Button('Reload Weight List')
-                weight_name = gr.Dropdown(choices=find_weight_files(),
-                                          value='custom-diffusion-models/cat.bin',
-                                          label='Custom Diffusion Weight File')
                 prompt = gr.Textbox(
                     label='Prompt',
                     max_lines=1,
-                    placeholder='Example: "\<new1\> cat in outer space"')
-                seed = gr.Slider(label='Seed',
-                                 minimum=0,
-                                 maximum=100000,
-                                 step=1,
-                                 value=42)
+                    placeholder='Example: "cat <R> stone"')
+                placeholder_string = gr.Textbox(
+                    label='Placeholder String',
+                    max_lines=1,
+                    placeholder='Example: "<R>"')
+
                 with gr.Accordion('Other Parameters', open=False):
-                    num_steps = gr.Slider(label='Number of Steps',
-                                          minimum=0,
-                                          maximum=500,
-                                          step=1,
-                                          value=100)
-                    guidance_scale = gr.Slider(label='CFG Scale',
+                    guidance_scale = gr.Slider(label='Classifier-Free Guidance Scale',
                                                minimum=0,
                                                maximum=50,
                                                step=0.1,
-                                               value=6)
-                    eta = gr.Slider(label='DDIM eta',
-                                    minimum=0,
-                                    maximum=1.,
-                                    step=0.1,
-                                    value=1.)
-                    batch_size = gr.Slider(label='Batch Size',
+                                               value=7.5)
+                    num_samples = gr.Slider(label='Batch Size',
                                             minimum=0,
                                             maximum=10.,
                                             step=1,
-                                            value=1)
+                                            value=10)
 
                 run_button = gr.Button('Generate')
 
@@ -308,61 +125,29 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
         reload_button.click(fn=reload_custom_diffusion_weight_list,
                             inputs=None,
                             outputs=weight_name)
-        prompt.submit(fn=pipe.run,
+        prompt.submit(fn=inference_fn,
                       inputs=[
-                          base_model,
-                          weight_name,
+                          model_id,
                           prompt,
-                          seed,
-                          num_steps,
-                          guidance_scale,
-                          eta,
-                          batch_size,
-                          resolution
+                          placeholder_string,
+                          num_samples,  # added: inference_fn expects num_samples before guidance_scale
+                          guidance_scale
                       ],
                       outputs=result,
                       queue=False)
-        run_button.click(fn=pipe.run,
-                         inputs=[
-                             base_model,
-                             weight_name,
-                             prompt,
-                             seed,
-                             num_steps,
-                             guidance_scale,
-                             eta,
-                             batch_size,
-                             resolution
-                         ],
+        run_button.click(fn=inference_fn,
+                         inputs=[
+                             model_id,
+                             prompt,
+                             placeholder_string,
+                             num_samples,  # added: matches the inference_fn signature
+                             guidance_scale
+                         ],
                          outputs=result,
                          queue=False)
     return demo
 
 
-def create_upload_demo() -> gr.Blocks:
-    with gr.Blocks() as demo:
-        model_name = gr.Textbox(label='Model Name')
-        hf_token = gr.Textbox(
-            label='Hugging Face Token (with write permission)')
-        upload_button = gr.Button('Upload')
-        with gr.Box():
-            gr.Markdown('Message')
-            result = gr.Markdown()
-        gr.Markdown('''
-        - You can upload your trained model to your private Model repo (i.e., https://huggingface.co/{your_username}/{model_name}).
-        - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
-        ''')
-
-        upload_button.click(fn=upload,
-                            inputs=[model_name, hf_token],
-                            outputs=result)
-
-    return demo
-
-
-pipe = InferencePipeline()
-trainer = Trainer()
-
 with gr.Blocks(css='style.css') as demo:
     if os.getenv('IS_SHARED_UI'):
         show_warning(SHARED_UI_WARNING)
@@ -374,12 +157,10 @@ with gr.Blocks(css='style.css') as demo:
     gr.Markdown(DETAILDESCRIPTION)
 
     with gr.Tabs():
-        with gr.TabItem('Train'):
-            create_training_demo(trainer, pipe)
+
         with gr.TabItem('Test'):
             create_inference_demo(pipe)
-        with gr.TabItem('Upload'):
-            create_upload_demo()
+
 
 demo.queue(default_enabled=False).launch(share=False)
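
A note on the rewired event handlers above: Gradio passes the components listed in `inputs` to the callback positionally, so the list order must mirror the callback's parameter order; that is why `num_samples` is threaded through to `inference_fn` between `placeholder_string` and `guidance_scale`. A minimal self-contained sketch of that contract, using toy component names rather than the app's actual ones:

```python
import gradio as gr

def toy_fn(model_id: str, prompt: str, scale: float) -> str:
    # values arrive positionally, in the order listed in `inputs` below
    return f"{model_id}: {prompt!r} @ cfg={scale}"

with gr.Blocks() as demo:
    model_id = gr.Dropdown(choices=['experiments/painted_on'],
                           value='experiments/painted_on')
    prompt = gr.Textbox(max_lines=1)
    scale = gr.Slider(minimum=0, maximum=50, step=0.1, value=7.5)
    out = gr.Textbox()
    # this list must mirror toy_fn's signature, element for element
    gr.Button('Run').click(fn=toy_fn,
                           inputs=[model_id, prompt, scale],
                           outputs=out)
```
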
inference.py CHANGED
@@ -12,70 +12,90 @@ import torch
 from diffusers import StableDiffusionPipeline
 sys.path.insert(0, './ReVersion')
 
-class InferencePipeline:
-    def __init__(self):
-        self.pipe = None
-        self.device = torch.device(
-            'cuda:0' if torch.cuda.is_available() else 'cpu')
-        self.weight_path = None
-
-    def clear(self) -> None:
-        self.weight_path = None
-        del self.pipe
-        self.pipe = None
-        torch.cuda.empty_cache()
-        gc.collect()
-
-    @staticmethod
-    def get_weight_path(name: str) -> pathlib.Path:
-        curr_dir = pathlib.Path(__file__).parent
-        return curr_dir / name
-
-    def load_pipe(self, model_id: str, filename: str) -> None:
-        weight_path = self.get_weight_path(filename)
-        if weight_path == self.weight_path:
-            return
-        self.weight_path = weight_path
-        weight = torch.load(self.weight_path, map_location=self.device)
-
-        if self.device.type == 'cpu':
-            pipe = StableDiffusionPipeline.from_pretrained(model_id)
-        else:
-            pipe = StableDiffusionPipeline.from_pretrained(
-                model_id, torch_dtype=torch.float16)
-        pipe = pipe.to(self.device)
-
-        from src import diffuser_training
-        diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, compress=False)
-
-        self.pipe = pipe
-
-    def run(
-        self,
-        base_model: str,
-        weight_name: str,
-        prompt: str,
-        seed: int,
-        n_steps: int,
-        guidance_scale: float,
-        eta: float,
-        batch_size: int,
-        resolution: int,
-    ) -> PIL.Image.Image:
-        if not torch.cuda.is_available():
-            raise gr.Error('CUDA is not available.')
-
-        self.load_pipe(base_model, weight_name)
-
-        generator = torch.Generator(device=self.device).manual_seed(seed)
-        out = self.pipe([prompt] * batch_size,
-                        num_inference_steps=n_steps,
-                        guidance_scale=guidance_scale,
-                        height=resolution, width=resolution,
-                        eta=eta,
-                        generator=generator)  # type: ignore
-        torch.cuda.empty_cache()
-        out = out.images
-        out = PIL.Image.fromarray(np.hstack([np.array(x) for x in out]))
-        return out
+# below is adapted from the original ReVersion inference script
+import os
+import math
+
+from PIL import Image
+
+from templates.templates import inference_templates
+
+"""
+Inference script for generating batch results
+"""
+
+
+def make_image_grid(imgs, rows, cols):
+    assert len(imgs) == rows * cols
+
+    w, h = imgs[0].size
+    grid = Image.new('RGB', size=(cols * w, rows * h))
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
+
+
+def inference_fn(
+    model_id,
+    prompt,
+    placeholder_string,
+    num_samples,
+    guidance_scale,
+    template_name=None,  # replaces the CLI script's args.template_name
+):
+    # create inference pipeline
+    pipe = StableDiffusionPipeline.from_pretrained(
+        model_id, torch_dtype=torch.float16).to('cuda')
+
+    # make directory to save images
+    image_root_folder = os.path.join(model_id, 'inference')
+    os.makedirs(image_root_folder, exist_ok=True)
+
+    if prompt is None and template_name is None:
+        raise ValueError("please input a single prompt through 'prompt' "
+                         "or select a batch of prompts with 'template_name'.")
+
+    # single text prompt
+    if prompt is not None:
+        prompt_list = [prompt]
+    else:
+        prompt_list = []
+
+    if template_name is not None:
+        # read the selected text prompts for generation
+        prompt_list.extend(inference_templates[template_name])
+
+    for prompt in prompt_list:
+        # insert relation prompt <R>
+        prompt = prompt.lower().replace('<r>', '<R>').format(placeholder_string)
+
+        # make sub-folder
+        image_folder = os.path.join(image_root_folder, prompt, 'samples')
+        os.makedirs(image_folder, exist_ok=True)
+
+        # batch generation
+        images = pipe(prompt,
+                      num_inference_steps=50,
+                      guidance_scale=guidance_scale,
+                      num_images_per_prompt=num_samples).images
+
+        # save generated images
+        for idx, image in enumerate(images):
+            image_name = f'{str(idx).zfill(4)}.png'
+            image_path = os.path.join(image_folder, image_name)
+            image.save(image_path)
+
+        # save a grid of images (assumes an even num_samples so rows*cols == len(images))
+        image_grid = make_image_grid(images, rows=2,
+                                     cols=math.ceil(num_samples / 2))
+        image_grid_path = os.path.join(image_root_folder, prompt, f'{prompt}.png')
+        image_grid.save(image_grid_path)  # previously computed but never written
+
+    return image_grid
+
+
+if __name__ == '__main__':
+    # invoked from app.py; no standalone CLI entry point in this version
+    pass
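
For a quick sanity check outside the UI, `inference_fn` can also be called directly. This is a hypothetical invocation, not part of the commit: it assumes a CUDA GPU and a diffusers-format checkpoint at `experiments/painted_on` (the path used by the app's dropdown):

```python
# hypothetical smoke test for inference_fn
from inference import inference_fn

grid = inference_fn(
    model_id='experiments/painted_on',  # diffusers-format checkpoint dir
    prompt='cat <R> stone',             # <R> marks the learned relation token
    placeholder_string='<R>',
    num_samples=4,                      # grid is laid out as 2 x ceil(n / 2)
    guidance_scale=7.5,
)
# individual samples are also written under experiments/painted_on/inference/
grid.save('painted_on_preview.png')
```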