torinriley committed on
Commit b7b4c25 · 1 Parent(s): d9cf71d

update, added more options

app.py CHANGED
@@ -52,7 +52,7 @@ config.models = model_loader.load_models(str(model_file), device)
  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024
 
- def infer(
+ def txt2img(
  prompt,
  negative_prompt,
  seed,
@@ -77,6 +77,7 @@ def infer(
  output_image = pipeline.generate(
  prompt=prompt,
  uncond_prompt=negative_prompt,
+ input_image=None,
  config=config
  )
 
@@ -85,6 +86,103 @@ def infer(
 
  return image, seed
 
+ def img2img(
+ prompt,
+ negative_prompt,
+ seed,
+ randomize_seed,
+ width,
+ height,
+ guidance_scale,
+ num_inference_steps,
+ input_image,
+ strength,
+ progress=gr.Progress(track_tqdm=True),
+ ):
+ try:
+ if randomize_seed:
+ seed = random.randint(0, MAX_SEED)
+
+ if input_image is None:
+ return None, seed
+
+ # Update config with user settings
+ config.seed = seed
+ config.diffusion.cfg_scale = guidance_scale
+ config.diffusion.n_inference_steps = num_inference_steps
+ config.model.width = width
+ config.model.height = height
+ config.diffusion.strength = strength
+
+ # Generate image
+ output_image = pipeline.generate(
+ prompt=prompt,
+ uncond_prompt=negative_prompt,
+ input_image=input_image,
+ config=config
+ )
+
+ # Convert numpy array to PIL Image
+ image = Image.fromarray(output_image)
+
+ return image, seed
+ except Exception as e:
+ print(f"Error in img2img: {str(e)}")
+ gr.Warning(f"Error: {str(e)}")
+ return None, seed
+
+ def inpaint(
+ prompt,
+ negative_prompt,
+ seed,
+ randomize_seed,
+ width,
+ height,
+ guidance_scale,
+ num_inference_steps,
+ input_image,
+ mask_image,
+ strength,
+ progress=gr.Progress(track_tqdm=True),
+ ):
+ try:
+ if randomize_seed:
+ seed = random.randint(0, MAX_SEED)
+
+ if input_image is None or mask_image is None:
+ gr.Warning("Both input image and mask are required for inpainting")
+ return None, seed
+
+ # Ensure mask is in the right format
+ if mask_image.mode != "L":
+ mask_image = mask_image.convert("L")
+
+ # Update config with user settings
+ config.seed = seed
+ config.diffusion.cfg_scale = guidance_scale
+ config.diffusion.n_inference_steps = num_inference_steps
+ config.model.width = width
+ config.model.height = height
+ config.diffusion.strength = strength
+
+ # Generate image with mask
+ output_image = pipeline.generate(
+ prompt=prompt,
+ uncond_prompt=negative_prompt,
+ input_image=input_image,
+ mask_image=mask_image,
+ config=config
+ )
+
+ # Convert numpy array to PIL Image
+ image = Image.fromarray(output_image)
+
+ return image, seed
+ except Exception as e:
+ print(f"Error in inpainting: {str(e)}")
+ gr.Warning(f"Error: {str(e)}")
+ return None, seed
+
  examples = [
  "An ultra-sharp photorealistic painting of a futuristic cityscape at night with neon lights and flying cars",
  "A serene mountain landscape at sunset with snow-capped peaks and a clear lake reflection",
@@ -96,31 +194,81 @@ css = """
  margin: 0 auto;
  max-width: 640px;
  }
+
+ .tabs {
+ margin-top: 10px;
+ margin-bottom: 10px;
+ }
+
+ .disclaimer {
+ font-size: 0.8em;
+ color: #666;
+ margin-top: 20px;
+ }
  """
 
  with gr.Blocks(css=css) as demo:
  with gr.Column(elem_id="col-container"):
  gr.Markdown(" # LiteDiffusion")
 
- with gr.Row():
- prompt = gr.Text(
- label="Prompt",
- show_label=False,
- max_lines=1,
- placeholder="Enter your prompt",
- container=False,
- )
-
- run_button = gr.Button("Run", scale=0, variant="primary")
+ with gr.Tabs(elem_classes="tabs") as tabs:
+ with gr.TabItem("Text-to-Image"):
+ txt2img_prompt = gr.Text(
+ label="Prompt",
+ max_lines=1,
+ placeholder="Enter your prompt",
+ )
+ txt2img_run = gr.Button("Generate", variant="primary")
+ txt2img_result = gr.Image(label="Result")
 
- result = gr.Image(label="Result", show_label=False)
+ with gr.TabItem("Image-to-Image"):
+ img2img_prompt = gr.Text(
+ label="Prompt",
+ max_lines=1,
+ placeholder="Enter your prompt",
+ )
+ with gr.Row():
+ with gr.Column(scale=1):
+ input_image = gr.Image(label="Input Image", type="pil")
+ strength_slider = gr.Slider(
+ label="Strength",
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=0.8,
+ )
+ img2img_run = gr.Button("Generate", variant="primary")
+
+ with gr.Column(scale=1):
+ img2img_result = gr.Image(label="Result")
+
+ with gr.TabItem("Inpainting"):
+ inpaint_prompt = gr.Text(
+ label="Prompt",
+ max_lines=1,
+ placeholder="Enter your prompt",
+ )
+ with gr.Row():
+ with gr.Column(scale=1):
+ inpaint_image = gr.Image(label="Input Image", type="pil")
+ inpaint_mask = gr.Image(label="Mask (White areas will be inpainted)", type="pil")
+ inpaint_strength = gr.Slider(
+ label="Strength",
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=0.8,
+ )
+ inpaint_run = gr.Button("Generate", variant="primary")
+
+ with gr.Column(scale=1):
+ inpaint_result = gr.Image(label="Result")
 
  with gr.Accordion("Advanced Settings", open=False):
  negative_prompt = gr.Text(
  label="Negative prompt",
  max_lines=1,
  placeholder="Enter a negative prompt",
- visible=False,
  )
 
  seed = gr.Slider(
@@ -166,14 +314,54 @@ with gr.Blocks(css=css) as demo:
  step=1,
  value=50,
  )
-
- gr.Examples(examples=examples, inputs=[prompt])
+
+ gr.Markdown(
+ "By using LiteDiffusion, you agree to the terms in our [disclaimer](disclaimer.md).",
+ elem_classes="disclaimer"
+ )
+
+ # Example prompts for text to image
+ gr.Examples(examples=examples, inputs=[txt2img_prompt])
 
- gr.on(
- triggers=[run_button.click, prompt.submit],
- fn=infer,
+ # Text-to-Image generation
+ txt2img_run.click(
+ fn=txt2img,
+ inputs=[
+ txt2img_prompt,
+ negative_prompt,
+ seed,
+ randomize_seed,
+ width,
+ height,
+ guidance_scale,
+ num_inference_steps,
+ ],
+ outputs=[txt2img_result, seed],
+ )
+
+ # Image-to-Image generation
+ img2img_run.click(
+ fn=img2img,
+ inputs=[
+ img2img_prompt,
+ negative_prompt,
+ seed,
+ randomize_seed,
+ width,
+ height,
+ guidance_scale,
+ num_inference_steps,
+ input_image,
+ strength_slider,
+ ],
+ outputs=[img2img_result, seed],
+ )
+
+ # Inpainting
+ inpaint_run.click(
+ fn=inpaint,
  inputs=[
- prompt,
+ inpaint_prompt,
  negative_prompt,
  seed,
  randomize_seed,
@@ -181,8 +369,11 @@ with gr.Blocks(css=css) as demo:
  height,
  guidance_scale,
  num_inference_steps,
+ inpaint_image,
+ inpaint_mask,
+ inpaint_strength,
  ],
- outputs=[result, seed],
+ outputs=[inpaint_result, seed],
  )
 
  if __name__ == "__main__":
disclaimer.md ADDED
@@ -0,0 +1,30 @@
+ # Disclaimer
+
+ ## LiteDiffusion - Legal Disclaimer
+
+ The LiteDiffusion model ("the Model") is provided by Torin Etheridge ("the Author") as-is and without warranty of any kind, express or implied.
+
+ ### Limitation of Liability
+
+ Torin Etheridge is not responsible for any misuse of this model or any content generated using this software. Users are solely responsible for how they use the Model and any content they generate with it.
+
+ ### Content Generation
+
+ The Model is capable of generating synthetic images based on text prompts. Users are responsible for:
+ - Ensuring they have the right to generate specific content
+ - Using the generated content in accordance with applicable laws and regulations
+ - Not using the Model to create harmful, offensive, or illegal content
+
+ ### No Medical or Professional Advice
+
+ Content generated by the Model should not be used for medical, legal, financial, or other professional advice.
+
+ ### Changes to this Disclaimer
+
+ This disclaimer may be updated from time to time without notice.
+
+ ### Contact
+
+ If you have any questions about this disclaimer, please contact the Author.
+
+ **By using LiteDiffusion, you acknowledge that you have read and understood this disclaimer.**
src/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (196 Bytes)
 
src/__pycache__/attention.cpython-312.pyc DELETED
Binary file (4.69 kB)
 
src/__pycache__/clip.cpython-312.pyc DELETED
Binary file (4.02 kB)
 
src/__pycache__/config.cpython-312.pyc DELETED
Binary file (3.4 kB)
 
src/__pycache__/ddpm.cpython-312.pyc DELETED
Binary file (6.46 kB)
 
src/__pycache__/decoder.cpython-312.pyc DELETED
Binary file (4.93 kB)
 
src/__pycache__/diffusion.cpython-312.pyc DELETED
Binary file (14.2 kB)
 
src/__pycache__/encoder.cpython-312.pyc DELETED
Binary file (2.56 kB)
 
src/__pycache__/model_converter.cpython-312.pyc DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:cc31a7458a7d5afc6251204fd5949d56297f0e0bc97b6b307d2d70b3e2b38d97
- size 170127

src/__pycache__/model_loader.cpython-312.pyc DELETED
Binary file (1.86 kB)
 
src/__pycache__/pipeline.cpython-312.pyc DELETED
Binary file (8.11 kB)
 
src/pipeline.py CHANGED
@@ -13,20 +13,6 @@ LATENTS_HEIGHT = HEIGHT // 8
 
  logging.basicConfig(level=logging.INFO)
 
- def generate(
- prompt,
- uncond_prompt=None,
- input_image=None,
- config: Config = default_config,
- ):
- with torch.no_grad():
- validate_strength(config.diffusion.strength)
- generator = initialize_generator(config.seed, config.device.device)
- context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg, config.tokenizer, config.models["clip"], config.device.device)
- latents = initialize_latents(input_image, config.diffusion.strength, generator, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps)
- images = run_diffusion(latents, context, config.diffusion.do_cfg, config.diffusion.cfg_scale, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps, generator)
- return postprocess_images(images)
-
  def validate_strength(strength):
  if not 0 < strength <= 1:
  raise ValueError("Strength must be between 0 and 1")
@@ -45,7 +31,7 @@ def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
  cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
  cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
  cond_context = clip(cond_tokens)
- uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
+ uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt or ""], padding="max_length", max_length=77).input_ids
  uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
  uncond_context = clip(uncond_tokens)
  context = torch.cat([cond_context, uncond_context])
@@ -55,17 +41,15 @@ def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
  context = clip(tokens)
  return context
 
- def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps):
- if input_image is None:
- # Initialize with random noise
- latents = torch.randn((1, 4, 64, 64), generator=generator, device=device)
- else:
- # Initialize with encoded input image
- latents = encode_image(input_image, models, device)
- # Add noise based on strength
- noise = torch.randn_like(latents, generator=generator)
- latents = (1 - strength) * latents + strength * noise
- return latents
+ def rescale(x, old_range, new_range, clamp=False):
+ old_min, old_max = old_range
+ new_min, new_max = new_range
+ x -= old_min
+ x *= (new_max - new_min) / (old_max - old_min)
+ x += new_min
+ if clamp:
+ x = x.clamp(new_min, new_max)
+ return x
 
  def preprocess_image(input_image):
  input_image_tensor = input_image.resize((WIDTH, HEIGHT))
@@ -76,6 +60,51 @@ def preprocess_image(input_image):
  input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
  return input_image_tensor
 
+ def encode_image(input_image, models, device):
+ # Preprocess the input image
+ image_tensor = preprocess_image(input_image).to(device)
+
+ # Encode the image using the VAE encoder
+ encoder = models["encoder"]
+ encoder.to(device)
+ with torch.no_grad():
+ # Create deterministic noise (zeros) since we want exact reconstruction
+ noise = torch.zeros((1, 4, LATENTS_WIDTH, LATENTS_HEIGHT), device=device)
+ latents = encoder(image_tensor, noise)
+
+ return latents
+
+ def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps, mask_image=None):
+ if input_image is None:
+ # Initialize with random noise
+ latents = torch.randn((1, 4, LATENTS_WIDTH, LATENTS_HEIGHT), generator=generator, device=device)
+ else:
+ # Initialize with encoded input image
+ latents = encode_image(input_image, models, device)
+
+ # If mask is provided for inpainting
+ if mask_image is not None:
+ # Process mask
+ mask = mask_image.resize((WIDTH, HEIGHT))
+ mask = np.array(mask)
+ mask = torch.tensor(mask, dtype=torch.float32).to(device)
+ mask = mask / 255.0 # Normalize to 0-1
+ mask = mask.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
+ mask = F.interpolate(mask, (LATENTS_WIDTH, LATENTS_HEIGHT))
+ mask = mask.repeat(1, 4, 1, 1) # Repeat for all latent channels
+
+ # Create masked noise - torch.randn_like doesn't accept generator
+ noise = torch.randn(latents.shape, device=device)
+ masked_latents = latents * (1 - mask) + noise * mask
+ latents = masked_latents
+
+ # Add noise based on strength (for img2img)
+ # torch.randn_like doesn't accept generator
+ noise = torch.randn(latents.shape, device=device)
+ latents = (1 - strength) * latents + strength * noise
+
+ return latents
+
  def get_sampler(sampler_name, generator, n_inference_steps):
  if sampler_name == "ddpm":
  sampler = DDPMSampler(generator)
@@ -84,6 +113,11 @@ def get_sampler(sampler_name, generator, n_inference_steps):
  raise ValueError(f"Unknown sampler value {sampler_name}.")
  return sampler
 
+ def get_time_embedding(timestep):
+ freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+ x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
+ return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
+
  def run_diffusion(latents, context, do_cfg, cfg_scale, models, device, sampler_name, n_inference_steps, generator):
  diffusion = models["diffusion"]
  diffusion.to(device)
@@ -108,17 +142,42 @@ def postprocess_images(images):
  images = images.to("cpu", torch.uint8).numpy()
  return images[0]
 
- def rescale(x, old_range, new_range, clamp=False):
- old_min, old_max = old_range
- new_min, new_max = new_range
- x -= old_min
- x *= (new_max - new_min) / (old_max - old_min)
- x += new_min
- if clamp:
- x = x.clamp(new_min, new_max)
- return x
-
- def get_time_embedding(timestep):
- freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
- x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
- return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
+ def generate(
+ prompt,
+ uncond_prompt=None,
+ input_image=None,
+ mask_image=None,
+ config: Config = default_config,
+ ):
+ with torch.no_grad():
+ # Validate inputs and parameters
+ if prompt is None or prompt.strip() == "":
+ raise ValueError("Prompt cannot be empty")
+
+ if uncond_prompt is None:
+ uncond_prompt = ""
+
+ validate_strength(config.diffusion.strength)
+
+ # Initialize generator for reproducibility
+ generator = initialize_generator(config.seed, config.device.device)
+
+ # Encode text prompt
+ context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg,
+ config.tokenizer, config.models["clip"], config.device.device)
+
+ # Initialize latents (either from noise or from input image)
+ latents = initialize_latents(input_image, config.diffusion.strength, generator,
+ config.models, config.device.device,
+ config.diffusion.sampler_name,
+ config.diffusion.n_inference_steps,
+ mask_image)
+
+ # Run diffusion process
+ images = run_diffusion(latents, context, config.diffusion.do_cfg,
+ config.diffusion.cfg_scale, config.models,
+ config.device.device, config.diffusion.sampler_name,
+ config.diffusion.n_inference_steps, generator)
+
+ # Post-process and return the images
+ return postprocess_images(images)
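
For context, here is a minimal usage sketch (not part of this commit) of how the reworked `generate()` entry point in `src/pipeline.py` can be driven for the three modes added in `app.py`. It assumes the same setup `app.py` performs (a config object that already carries the tokenizer and loaded models, PIL images as inputs); the import paths and the file names `init.png` / `mask.png` are illustrative only.

```python
# Illustrative sketch only: import paths, file names, and config wiring are
# assumptions based on app.py, not part of this commit.
from PIL import Image

from src import pipeline
from src.config import default_config as config  # assumed to already hold tokenizer + models

# Text-to-image: no input image, latents start from pure noise.
txt2img_arr = pipeline.generate(
    prompt="A serene mountain landscape at sunset",
    uncond_prompt="",
    input_image=None,
    config=config,
)

# Image-to-image: the input image is encoded to latents and blended with
# noise according to config.diffusion.strength.
init_image = Image.open("init.png").convert("RGB")
config.diffusion.strength = 0.8
img2img_arr = pipeline.generate(
    prompt="The same scene as a watercolor painting",
    uncond_prompt="",
    input_image=init_image,
    config=config,
)

# Inpainting: a greyscale mask marks the latent regions to repaint (white = repaint).
mask = Image.open("mask.png").convert("L")
inpaint_arr = pipeline.generate(
    prompt="Replace the sky with a starry night",
    uncond_prompt="",
    input_image=init_image,
    mask_image=mask,
    config=config,
)

# generate() returns a numpy array; app.py converts it with Image.fromarray().
result = Image.fromarray(txt2img_arr)
```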