Commit b7b4c25
Parent(s): d9cf71d
update, added more options
Files changed:
- app.py +211 -20
- disclaimer.md +30 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/attention.cpython-312.pyc +0 -0
- src/__pycache__/clip.cpython-312.pyc +0 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/ddpm.cpython-312.pyc +0 -0
- src/__pycache__/decoder.cpython-312.pyc +0 -0
- src/__pycache__/diffusion.cpython-312.pyc +0 -0
- src/__pycache__/encoder.cpython-312.pyc +0 -0
- src/__pycache__/model_converter.cpython-312.pyc +0 -3
- src/__pycache__/model_loader.cpython-312.pyc +0 -0
- src/__pycache__/pipeline.cpython-312.pyc +0 -0
- src/pipeline.py +99 -40
app.py
CHANGED
@@ -52,7 +52,7 @@ config.models = model_loader.load_models(str(model_file), device)
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
-def infer(
+def txt2img(
     prompt,
     negative_prompt,
     seed,
@@ -77,6 +77,7 @@ def infer(
     output_image = pipeline.generate(
         prompt=prompt,
         uncond_prompt=negative_prompt,
+        input_image=None,
         config=config
     )
 
@@ -85,6 +86,103 @@ def infer(
 
     return image, seed
 
+def img2img(
+    prompt,
+    negative_prompt,
+    seed,
+    randomize_seed,
+    width,
+    height,
+    guidance_scale,
+    num_inference_steps,
+    input_image,
+    strength,
+    progress=gr.Progress(track_tqdm=True),
+):
+    try:
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+
+        if input_image is None:
+            return None, seed
+
+        # Update config with user settings
+        config.seed = seed
+        config.diffusion.cfg_scale = guidance_scale
+        config.diffusion.n_inference_steps = num_inference_steps
+        config.model.width = width
+        config.model.height = height
+        config.diffusion.strength = strength
+
+        # Generate image
+        output_image = pipeline.generate(
+            prompt=prompt,
+            uncond_prompt=negative_prompt,
+            input_image=input_image,
+            config=config
+        )
+
+        # Convert numpy array to PIL Image
+        image = Image.fromarray(output_image)
+
+        return image, seed
+    except Exception as e:
+        print(f"Error in img2img: {str(e)}")
+        gr.Warning(f"Error: {str(e)}")
+        return None, seed
+
+def inpaint(
+    prompt,
+    negative_prompt,
+    seed,
+    randomize_seed,
+    width,
+    height,
+    guidance_scale,
+    num_inference_steps,
+    input_image,
+    mask_image,
+    strength,
+    progress=gr.Progress(track_tqdm=True),
+):
+    try:
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+
+        if input_image is None or mask_image is None:
+            gr.Warning("Both input image and mask are required for inpainting")
+            return None, seed
+
+        # Ensure mask is in the right format
+        if mask_image.mode != "L":
+            mask_image = mask_image.convert("L")
+
+        # Update config with user settings
+        config.seed = seed
+        config.diffusion.cfg_scale = guidance_scale
+        config.diffusion.n_inference_steps = num_inference_steps
+        config.model.width = width
+        config.model.height = height
+        config.diffusion.strength = strength
+
+        # Generate image with mask
+        output_image = pipeline.generate(
+            prompt=prompt,
+            uncond_prompt=negative_prompt,
+            input_image=input_image,
+            mask_image=mask_image,
+            config=config
+        )
+
+        # Convert numpy array to PIL Image
+        image = Image.fromarray(output_image)
+
+        return image, seed
+    except Exception as e:
+        print(f"Error in inpainting: {str(e)}")
+        gr.Warning(f"Error: {str(e)}")
+        return None, seed
+
 examples = [
     "A ultra sharp photorealtici painting of a futuristic cityscape at night with neon lights and flying cars",
     "A serene mountain landscape at sunset with snow-capped peaks and a clear lake reflection",
@@ -96,31 +194,81 @@ css = """
    margin: 0 auto;
    max-width: 640px;
}
+
+.tabs {
+    margin-top: 10px;
+    margin-bottom: 10px;
+}
+
+.disclaimer {
+    font-size: 0.8em;
+    color: #666;
+    margin-top: 20px;
+}
 """
 
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown(" # LiteDiffusion")
 
-        with gr.
-        run_button = gr.Button("Run", scale=0, variant="primary")
+        with gr.Tabs(elem_classes="tabs") as tabs:
+            with gr.TabItem("Text-to-Image"):
+                txt2img_prompt = gr.Text(
+                    label="Prompt",
+                    max_lines=1,
+                    placeholder="Enter your prompt",
+                )
+                txt2img_run = gr.Button("Generate", variant="primary")
+                txt2img_result = gr.Image(label="Result")
 
+            with gr.TabItem("Image-to-Image"):
+                img2img_prompt = gr.Text(
+                    label="Prompt",
+                    max_lines=1,
+                    placeholder="Enter your prompt",
+                )
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        input_image = gr.Image(label="Input Image", type="pil")
+                        strength_slider = gr.Slider(
+                            label="Strength",
+                            minimum=0.0,
+                            maximum=1.0,
+                            step=0.01,
+                            value=0.8,
+                        )
+                        img2img_run = gr.Button("Generate", variant="primary")
+
+                    with gr.Column(scale=1):
+                        img2img_result = gr.Image(label="Result")
+
+            with gr.TabItem("Inpainting"):
+                inpaint_prompt = gr.Text(
+                    label="Prompt",
+                    max_lines=1,
+                    placeholder="Enter your prompt",
+                )
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        inpaint_image = gr.Image(label="Input Image", type="pil")
+                        inpaint_mask = gr.Image(label="Mask (White areas will be inpainted)", type="pil")
+                        inpaint_strength = gr.Slider(
+                            label="Strength",
+                            minimum=0.0,
+                            maximum=1.0,
+                            step=0.01,
+                            value=0.8,
+                        )
+                        inpaint_run = gr.Button("Generate", variant="primary")
+
+                    with gr.Column(scale=1):
+                        inpaint_result = gr.Image(label="Result")
 
         with gr.Accordion("Advanced Settings", open=False):
             negative_prompt = gr.Text(
                 label="Negative prompt",
                 max_lines=1,
                 placeholder="Enter a negative prompt",
-                visible=False,
            )
 
            seed = gr.Slider(
@@ -166,14 +314,54 @@ with gr.Blocks(css=css) as demo:
                step=1,
                value=50,
            )
-        gr.
+
+        gr.Markdown(
+            "By using LiteDiffusion, you agree to the terms in our [disclaimer](disclaimer.md).",
+            elem_classes="disclaimer"
+        )
+
+        # Example prompts for text to image
+        gr.Examples(examples=examples, inputs=[txt2img_prompt])
 
-            fn=
+        # Text-to-Image generation
+        txt2img_run.click(
+            fn=txt2img,
+            inputs=[
+                txt2img_prompt,
+                negative_prompt,
+                seed,
+                randomize_seed,
+                width,
+                height,
+                guidance_scale,
+                num_inference_steps,
+            ],
+            outputs=[txt2img_result, seed],
+        )
+
+        # Image-to-Image generation
+        img2img_run.click(
+            fn=img2img,
+            inputs=[
+                img2img_prompt,
+                negative_prompt,
+                seed,
+                randomize_seed,
+                width,
+                height,
+                guidance_scale,
+                num_inference_steps,
+                input_image,
+                strength_slider,
+            ],
+            outputs=[img2img_result, seed],
+        )
+
+        # Inpainting
+        inpaint_run.click(
+            fn=inpaint,
            inputs=[
+                inpaint_prompt,
                negative_prompt,
                seed,
                randomize_seed,
@@ -181,8 +369,11 @@ with gr.Blocks(css=css) as demo:
                height,
                guidance_scale,
                num_inference_steps,
+                inpaint_image,
+                inpaint_mask,
+                inpaint_strength,
            ],
-            outputs=[
+            outputs=[inpaint_result, seed],
        )
 
if __name__ == "__main__":
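Side note on the UI changes above: the new tab and button wiring follows the stock Gradio Blocks pattern. Below is a minimal, self-contained sketch of that pattern only; the handler, labels, and component names are placeholders for illustration, not the app's actual code.

```python
# Minimal sketch of the Blocks / Tabs / click pattern used in app.py above.
# The handler and components here are placeholders, not LiteDiffusion's real ones.
import gradio as gr

def echo(prompt):
    # Stand-in for a real generation handler that returns a result
    return f"You asked for: {prompt}"

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Text-to-Image"):
            prompt = gr.Text(label="Prompt", max_lines=1)
            run = gr.Button("Generate", variant="primary")
            result = gr.Textbox(label="Result")
            # Clicking the button calls the handler with `inputs` and writes to `outputs`
            run.click(fn=echo, inputs=[prompt], outputs=[result])

if __name__ == "__main__":
    demo.launch()
```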
disclaimer.md
ADDED
@@ -0,0 +1,30 @@
+# Disclaimer
+
+## LiteDiffusion - Legal Disclaimer
+
+The LiteDiffusion model ("the Model") is provided by Torin Etheridge ("the Author") as-is and without warranty of any kind, express or implied.
+
+### Limitation of Liability
+
+Torin Etheridge is not responsible for any misuse of this model or any content generated using this software. Users are solely responsible for how they use the Model and any content they generate with it.
+
+### Content Generation
+
+The Model is capable of generating synthetic images based on text prompts. Users are responsible for:
+- Ensuring they have the right to generate specific content
+- Using the generated content in accordance with applicable laws and regulations
+- Not using the Model to create harmful, offensive, or illegal content
+
+### No Medical or Professional Advice
+
+Content generated by the Model should not be used for medical, legal, financial, or other professional advice.
+
+### Changes to this Disclaimer
+
+This disclaimer may be updated from time to time without notice.
+
+### Contact
+
+If you have any questions about this disclaimer, please contact the Author.
+
+**By using LiteDiffusion, you acknowledge that you have read and understood this disclaimer.**
src/__pycache__/__init__.cpython-312.pyc
DELETED
Binary file (196 Bytes)

src/__pycache__/attention.cpython-312.pyc
DELETED
Binary file (4.69 kB)

src/__pycache__/clip.cpython-312.pyc
DELETED
Binary file (4.02 kB)

src/__pycache__/config.cpython-312.pyc
DELETED
Binary file (3.4 kB)

src/__pycache__/ddpm.cpython-312.pyc
DELETED
Binary file (6.46 kB)

src/__pycache__/decoder.cpython-312.pyc
DELETED
Binary file (4.93 kB)

src/__pycache__/diffusion.cpython-312.pyc
DELETED
Binary file (14.2 kB)

src/__pycache__/encoder.cpython-312.pyc
DELETED
Binary file (2.56 kB)

src/__pycache__/model_converter.cpython-312.pyc
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cc31a7458a7d5afc6251204fd5949d56297f0e0bc97b6b307d2d70b3e2b38d97
-size 170127

src/__pycache__/model_loader.cpython-312.pyc
DELETED
Binary file (1.86 kB)

src/__pycache__/pipeline.cpython-312.pyc
DELETED
Binary file (8.11 kB)
src/pipeline.py
CHANGED
@@ -13,20 +13,6 @@ LATENTS_HEIGHT = HEIGHT // 8
 
 logging.basicConfig(level=logging.INFO)
 
-def generate(
-    prompt,
-    uncond_prompt=None,
-    input_image=None,
-    config: Config = default_config,
-):
-    with torch.no_grad():
-        validate_strength(config.diffusion.strength)
-        generator = initialize_generator(config.seed, config.device.device)
-        context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg, config.tokenizer, config.models["clip"], config.device.device)
-        latents = initialize_latents(input_image, config.diffusion.strength, generator, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps)
-        images = run_diffusion(latents, context, config.diffusion.do_cfg, config.diffusion.cfg_scale, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps, generator)
-        return postprocess_images(images)
-
 def validate_strength(strength):
     if not 0 < strength <= 1:
         raise ValueError("Strength must be between 0 and 1")
@@ -45,7 +31,7 @@ def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
         cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
         cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
         cond_context = clip(cond_tokens)
-        uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
+        uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt or ""], padding="max_length", max_length=77).input_ids
         uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
         uncond_context = clip(uncond_tokens)
         context = torch.cat([cond_context, uncond_context])
@@ -55,17 +41,15 @@ def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
         context = clip(tokens)
     return context
 
-def
-    latents = (1 - strength) * latents + strength * noise
-    return latents
+def rescale(x, old_range, new_range, clamp=False):
+    old_min, old_max = old_range
+    new_min, new_max = new_range
+    x -= old_min
+    x *= (new_max - new_min) / (old_max - old_min)
+    x += new_min
+    if clamp:
+        x = x.clamp(new_min, new_max)
+    return x
 
 def preprocess_image(input_image):
     input_image_tensor = input_image.resize((WIDTH, HEIGHT))
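The `rescale` helper added above is a straight linear remap from one value range to another, with optional clamping. A quick sanity check with made-up values, assuming the function is imported from src/pipeline.py as laid out in this repo:

```python
import torch

from src.pipeline import rescale  # import path assumed from this repo layout

# rescale mutates tensors in place, so clone when the original is still needed
x = torch.tensor([0.0, 127.5, 255.0])
print(rescale(x.clone(), (0, 255), (-1, 1)))
# tensor([-1.,  0.,  1.])

# Out-of-range inputs can be clamped to the target range
print(rescale(torch.tensor([300.0]), (0, 255), (-1, 1), clamp=True))
# tensor([1.])
```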
@@ -76,6 +60,51 @@ def preprocess_image(input_image):
     input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
     return input_image_tensor
 
+def encode_image(input_image, models, device):
+    # Preprocess the input image
+    image_tensor = preprocess_image(input_image).to(device)
+
+    # Encode the image using the VAE encoder
+    encoder = models["encoder"]
+    encoder.to(device)
+    with torch.no_grad():
+        # Create deterministic noise (zeros) since we want exact reconstruction
+        noise = torch.zeros((1, 4, LATENTS_WIDTH, LATENTS_HEIGHT), device=device)
+        latents = encoder(image_tensor, noise)
+
+    return latents
+
+def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps, mask_image=None):
+    if input_image is None:
+        # Initialize with random noise
+        latents = torch.randn((1, 4, LATENTS_WIDTH, LATENTS_HEIGHT), generator=generator, device=device)
+    else:
+        # Initialize with encoded input image
+        latents = encode_image(input_image, models, device)
+
+    # If mask is provided for inpainting
+    if mask_image is not None:
+        # Process mask
+        mask = mask_image.resize((WIDTH, HEIGHT))
+        mask = np.array(mask)
+        mask = torch.tensor(mask, dtype=torch.float32).to(device)
+        mask = mask / 255.0  # Normalize to 0-1
+        mask = mask.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
+        mask = F.interpolate(mask, (LATENTS_WIDTH, LATENTS_HEIGHT))
+        mask = mask.repeat(1, 4, 1, 1)  # Repeat for all latent channels
+
+        # Create masked noise - torch.randn_like doesn't accept generator
+        noise = torch.randn(latents.shape, device=device)
+        masked_latents = latents * (1 - mask) + noise * mask
+        latents = masked_latents
+
+    # Add noise based on strength (for img2img)
+    # torch.randn_like doesn't accept generator
+    noise = torch.randn(latents.shape, device=device)
+    latents = (1 - strength) * latents + strength * noise
+
+    return latents
+
 def get_sampler(sampler_name, generator, n_inference_steps):
     if sampler_name == "ddpm":
         sampler = DDPMSampler(generator)
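For orientation, the two blends `initialize_latents` performs above boil down to plain tensor arithmetic. A small illustration (shapes and values here are illustrative only, mirroring the formulas in the diff):

```python
import torch

latents = torch.zeros(1, 4, 64, 64)   # stands in for the encoded input image
noise = torch.randn(1, 4, 64, 64)
mask = torch.ones(1, 4, 64, 64)       # 1.0 = repaint this region, 0.0 = keep it

# Inpainting blend: keep unmasked latents, swap the masked region for noise
latents = latents * (1 - mask) + noise * mask

# img2img blend: strength=0.0 keeps the image, strength=1.0 is pure noise
strength = 0.8
latents = (1 - strength) * latents + strength * noise
```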
@@ -84,6 +113,11 @@ def get_sampler(sampler_name, generator, n_inference_steps):
         raise ValueError(f"Unknown sampler value {sampler_name}.")
     return sampler
 
+def get_time_embedding(timestep):
+    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
+    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
+
 def run_diffusion(latents, context, do_cfg, cfg_scale, models, device, sampler_name, n_inference_steps, generator):
     diffusion = models["diffusion"]
     diffusion.to(device)
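The `get_time_embedding` helper added above builds a sinusoidal timestep embedding: 160 frequencies, with cosine and sine halves concatenated into a 320-dimensional vector. A small shape check, assuming the same import path as before:

```python
import torch

from src.pipeline import get_time_embedding  # import path assumed

emb = get_time_embedding(999)
print(emb.shape)  # torch.Size([1, 320]) -- 160 cosines followed by 160 sines
print(emb.dtype)  # torch.float32
```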
@@ -108,17 +142,42 @@ def postprocess_images(images):
     images = images.to("cpu", torch.uint8).numpy()
     return images[0]
 
-def
+def generate(
+    prompt,
+    uncond_prompt=None,
+    input_image=None,
+    mask_image=None,
+    config: Config = default_config,
+):
+    with torch.no_grad():
+        # Validate inputs and parameters
+        if prompt is None or prompt.strip() == "":
+            raise ValueError("Prompt cannot be empty")
+
+        if uncond_prompt is None:
+            uncond_prompt = ""
+
+        validate_strength(config.diffusion.strength)
+
+        # Initialize generator for reproducibility
+        generator = initialize_generator(config.seed, config.device.device)
+
+        # Encode text prompt
+        context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg,
+                                config.tokenizer, config.models["clip"], config.device.device)
+
+        # Initialize latents (either from noise or from input image)
+        latents = initialize_latents(input_image, config.diffusion.strength, generator,
+                                     config.models, config.device.device,
+                                     config.diffusion.sampler_name,
+                                     config.diffusion.n_inference_steps,
+                                     mask_image)
+
+        # Run diffusion process
+        images = run_diffusion(latents, context, config.diffusion.do_cfg,
+                               config.diffusion.cfg_scale, config.models,
+                               config.device.device, config.diffusion.sampler_name,
+                               config.diffusion.n_inference_steps, generator)
+
+        # Post-process and return the images
+        return postprocess_images(images)
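Finally, a minimal sketch of calling the reworked `generate` entry point for inpainting. It assumes the import paths implied by this repo layout and that `config` is the already-initialized Config that app.py builds (tokenizer, models, and device set up via model_loader.load_models); the file paths are hypothetical.

```python
from PIL import Image

from src import pipeline                           # import path assumed from this repo layout
from src.config import default_config as config   # assumed location; the real app also fills config.models first

init_image = Image.open("photo.png").convert("RGB")  # hypothetical input file
mask_image = Image.open("mask.png").convert("L")     # white areas get repainted

output = pipeline.generate(
    prompt="replace the sky with a dramatic sunset",
    uncond_prompt="",
    input_image=init_image,
    mask_image=mask_image,
    config=config,
)
Image.fromarray(output).save("inpainted.png")
```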