import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from .ddpm import DDPMSampler
import logging
from .config import Config, default_config

WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

logging.basicConfig(level=logging.INFO)

def validate_strength(strength):
    if not 0 < strength <= 1:
        raise ValueError("Strength must be between 0 and 1")

def initialize_generator(seed, device):
    generator = torch.Generator(device=device)
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)
    return generator

def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
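    """Encode the text prompt(s) with CLIP. With classifier-free guidance the
    conditional and unconditional contexts are concatenated along the batch
    dimension (conditional first), matching the chunk order in run_diffusion."""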
    clip.to(device)
    if do_cfg:
        cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77, truncation=True).input_ids
        cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
        cond_context = clip(cond_tokens)
        uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt or ""], padding="max_length", max_length=77, truncation=True).input_ids
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
        uncond_context = clip(uncond_tokens)
        context = torch.cat([cond_context, uncond_context])
    else:
        tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77, truncation=True).input_ids
        tokens = torch.tensor(tokens, dtype=torch.long, device=device)
        context = clip(tokens)
    return context

def rescale(x, old_range, new_range, clamp=False):
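    """Linearly map x from old_range to new_range (note: modifies x in place)."""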
    old_min, old_max = old_range
    new_min, new_max = new_range
    x -= old_min
    x *= (new_max - new_min) / (old_max - old_min)
    x += new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def preprocess_image(input_image):
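    """Convert to RGB, resize to (WIDTH, HEIGHT), rescale pixel values from
    [0, 255] to [-1, 1], and return a (1, 3, H, W) float32 tensor in NCHW layout."""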
    input_image_tensor = input_image.convert("RGB").resize((WIDTH, HEIGHT))  # force 3 channels (handles grayscale/RGBA inputs)
    input_image_tensor = np.array(input_image_tensor)
    input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32)
    input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
    input_image_tensor = input_image_tensor.unsqueeze(0)
    input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
    return input_image_tensor

def encode_image(input_image, models, device):
    # Preprocess the input image
    image_tensor = preprocess_image(input_image).to(device)
    
    # Encode the image using the VAE encoder
    encoder = models["encoder"]
    encoder.to(device)
    with torch.no_grad():
        # Zero noise makes the VAE encoder return the mean of the latent
        # distribution, i.e. a deterministic encoding of the input image
        noise = torch.zeros((1, 4, LATENTS_HEIGHT, LATENTS_WIDTH), device=device)
        latents = encoder(image_tensor, noise)
    
    return latents

def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps, mask_image=None):
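    """Prepare the starting latents: pure noise for text-to-image, or the
    VAE-encoded input image for image-to-image (optionally masked for
    inpainting, then blended with noise according to strength)."""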
    if input_image is None:
        # Initialize with random noise
        latents = torch.randn((1, 4, LATENTS_HEIGHT, LATENTS_WIDTH), generator=generator, device=device)
    else:
        # Initialize with encoded input image
        latents = encode_image(input_image, models, device)
        
        # If mask is provided for inpainting
        if mask_image is not None:
            # Process mask
            mask = mask_image.convert("L").resize((WIDTH, HEIGHT))  # single-channel mask
            mask = np.array(mask)
            mask = torch.tensor(mask, dtype=torch.float32).to(device)
            mask = mask / 255.0  # Normalize to [0, 1]
            mask = mask.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
            mask = F.interpolate(mask, (LATENTS_HEIGHT, LATENTS_WIDTH))  # Downsample to latent resolution
            mask = mask.repeat(1, 4, 1, 1)  # Repeat for all latent channels

            # Fill the masked region with noise from the seeded generator
            # (torch.randn accepts a generator, unlike torch.randn_like)
            noise = torch.randn(latents.shape, generator=generator, device=device)
            masked_latents = latents * (1 - mask) + noise * mask
            latents = masked_latents
        
        # Add noise based on strength (for img2img)
        # Use the seeded generator so the added noise is reproducible
        noise = torch.randn(latents.shape, generator=generator, device=device)
        latents = (1 - strength) * latents + strength * noise
    
    return latents

def get_sampler(sampler_name, generator, n_inference_steps):
    if sampler_name == "ddpm":
        sampler = DDPMSampler(generator)
        sampler.set_inference_timesteps(n_inference_steps)
    else:
        raise ValueError(f"Unknown sampler value {sampler_name}.")
    return sampler

def get_time_embedding(timestep):
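    """Sinusoidal timestep embedding: 160 frequencies, cosines and sines
    concatenated into a (1, 320) tensor for the diffusion model."""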
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

def run_diffusion(latents, context, do_cfg, cfg_scale, models, device, sampler_name, n_inference_steps, generator):
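    """Run the reverse-diffusion loop over the sampler's timesteps, applying
    classifier-free guidance when requested, then decode the final latents
    into image space with the VAE decoder."""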
    diffusion = models["diffusion"]
    diffusion.to(device)
    sampler = get_sampler(sampler_name, generator, n_inference_steps)
    timesteps = tqdm(sampler.timesteps)
    for timestep in timesteps:
        time_embedding = get_time_embedding(timestep).to(device)
        model_input = latents.repeat(2, 1, 1, 1) if do_cfg else latents
        model_output = diffusion(model_input, context, time_embedding)
        if do_cfg:
            output_cond, output_uncond = model_output.chunk(2)
            model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
        latents = sampler.step(timestep, latents, model_output)
    decoder = models["decoder"]
    decoder.to(device)
    images = decoder(latents)
    return images

def postprocess_images(images):
    images = rescale(images, (-1, 1), (0, 255), clamp=True)
    images = images.permute(0, 2, 3, 1)
    images = images.to("cpu", torch.uint8).numpy()
    return images[0]

def generate(
    prompt,
    uncond_prompt=None,
    input_image=None,
    mask_image=None,
    config: Config = default_config,
):
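    """Full pipeline: validate inputs, encode the prompt, prepare latents
    (noise, img2img, or inpainting), run diffusion, and return a single
    (HEIGHT, WIDTH, 3) uint8 numpy image."""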
    with torch.no_grad():
        # Validate inputs and parameters
        if prompt is None or prompt.strip() == "":
            raise ValueError("Prompt cannot be empty")
            
        if uncond_prompt is None:
            uncond_prompt = ""
            
        validate_strength(config.diffusion.strength)
        
        # Initialize generator for reproducibility
        generator = initialize_generator(config.seed, config.device.device)
        
        # Encode text prompt
        context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg, 
                               config.tokenizer, config.models["clip"], config.device.device)
        
        # Initialize latents (either from noise or from input image)
        latents = initialize_latents(input_image, config.diffusion.strength, generator, 
                                    config.models, config.device.device, 
                                    config.diffusion.sampler_name, 
                                    config.diffusion.n_inference_steps, 
                                    mask_image)
        
        # Run diffusion process
        images = run_diffusion(latents, context, config.diffusion.do_cfg, 
                              config.diffusion.cfg_scale, config.models, 
                              config.device.device, config.diffusion.sampler_name, 
                              config.diffusion.n_inference_steps, generator)
        
        # Post-process and return the images
        return postprocess_images(images)
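

# Minimal usage sketch (not part of the library API): it assumes that
# `default_config` already carries the tokenizer and the loaded
# clip/encoder/decoder/diffusion models, which depends on how the surrounding
# project builds its Config. Pillow is used only to save the output array.
if __name__ == "__main__":
    from PIL import Image

    output_image = generate(
        prompt="a photograph of an astronaut riding a horse",
        uncond_prompt="",
    )
    Image.fromarray(output_image).save("output.png")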