|
from st_keyup import st_keyup |
|
from streamlit_helpers import * |
|
|
|
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler |
|
|
|
# Fields that are identical for every Turbo checkpoint offered by this demo.
_TURBO_SHARED_SPEC = {
    "H": 512,
    "W": 512,
    "C": 4,
    "f": 8,
    "is_legacy": False,
}

# Per-model settings, keyed by the label shown in the "Model Version" select
# box. The chosen entry is handed to init_st() by the script section below.
# Each entry gets its own copy of the shared fields, so the two dicts stay
# independent objects.
VERSION2SPECS = {
    "SDXL-Turbo": {
        **_TURBO_SHARED_SPEC,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_turbo_1.0.safetensors",
    },
    "SD-Turbo": {
        **_TURBO_SHARED_SPEC,
        "config": "configs/inference/sd_2_1.yaml",
        "ckpt": "checkpoints/sd_turbo.safetensors",
    },
}
|
|
|
|
|
class SubstepSampler(EulerAncestralSampler):
    """Euler-ancestral sampler restricted to a few hand-picked schedule entries.

    Instead of walking the whole discretization, only the first
    ``n_sample_steps`` indices of ``steps_subset`` plus its final index are
    kept, so sampling takes ``n_sample_steps`` steps.
    """

    def __init__(self, n_sample_steps=1, *args, **kwargs):
        """Forward ``*args``/``**kwargs`` to the parent and store the step count.

        Args:
            n_sample_steps: How many leading entries of ``steps_subset`` to use.
        """
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        # Indices into the full sigma schedule; the last one is always used.
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Build the shortened sigma schedule and scale the initial noise.

        Args:
            x: Initial noise tensor. It is scaled IN PLACE.
            cond: Conditioning passed through unchanged.
            uc: Accepted for interface compatibility but ignored; the
                conditioning is returned in its place.
            num_steps: Optional override for ``self.num_steps``.

        Returns:
            Tuple ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where ``uc`` is
            the same object as ``cond``.
        """
        step_count = num_steps if num_steps is not None else self.num_steps
        full_schedule = self.discretization(step_count, device=self.device)

        # Leading n_sample_steps indices followed by the final schedule index.
        # NOTE(review): assumes the schedule supports indexing by a list of
        # ints and has more than 1000 entries -- confirm for other num_steps.
        picked = self.steps_subset[: self.n_sample_steps] + self.steps_subset[-1:]
        sigmas = full_schedule[picked]

        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        batch_ones = x.new_ones([x.shape[0]])

        # The conditioning doubles as the "unconditional" input.
        return x, batch_ones, sigmas, len(sigmas), cond, cond
|
|
|
|
|
def seeded_randn(shape, seed, device="cuda"):
    """Return reproducible standard-normal noise as a float32 tensor.

    The values are drawn with NumPy's ``RandomState`` so that the same
    ``seed`` always yields the same numbers, then moved to ``device``.

    Args:
        shape: Iterable of ints giving the output dimensions.
        seed: Seed for ``np.random.RandomState``. Python ints are wrapped
            modulo ``2**32`` because ``RandomState`` only accepts 32-bit
            unsigned seeds (e.g. ``torch.seed()`` returns a 64-bit value).
            Any other seed type is passed through untouched.
        device: Target torch device; defaults to ``"cuda"`` as before.

    Returns:
        A ``torch.float32`` tensor of the requested shape on ``device``.
    """
    if isinstance(seed, int):
        # Seeds already inside [0, 2**32) are left unchanged by the modulo.
        seed %= 2**32
    noise = np.random.RandomState(seed).randn(*shape)
    return torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
|
|
|
|
class SeededNoise:
    """Callable noise source whose seed advances by one on every call."""

    def __init__(self, seed):
        """Remember the starting seed; the first draw uses ``seed + 1``."""
        self.seed = seed

    def __call__(self, x):
        """Return seeded noise shaped like ``x`` and advance the stored seed."""
        next_seed = self.seed + 1
        self.seed = next_seed
        return seeded_randn(x.shape, next_seed)
|
|
|
|
|
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Collect the conditioning values required by the given embedder keys.

    Args:
        keys: Iterable of embedder input-key names. Unrecognized keys are
            ignored.
        init_dict: Mapping providing ``orig_width``/``orig_height`` and
            ``target_width``/``target_height``; each is only read when the
            corresponding key is present in ``keys``.
        prompt: Text prompt stored under ``"prompt"`` for the ``"txt"`` key.
        negative_prompt: Text stored under ``"negative_prompt"`` for the
            ``"txt"`` key. ``None`` (the default) yields the empty string.
            Previously this argument was accepted but always discarded.

    Returns:
        A new dict with the values for every recognized key.

    Raises:
        KeyError: If ``init_dict`` lacks an entry needed by one of ``keys``.
    """
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = (
                "" if negative_prompt is None else negative_prompt
            )
        elif key == "original_size_as_tuple":
            value_dict["orig_width"] = init_dict["orig_width"]
            value_dict["orig_height"] = init_dict["orig_height"]
        elif key == "crop_coords_top_left":
            # The demo never crops, so both coordinates are fixed at 0.
            value_dict["crop_coords_top"] = 0
            value_dict["crop_coords_left"] = 0
        elif key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5
        elif key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]
    return value_dict
|
|
|
|
|
def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    """Generate one image for ``prompt`` and return it as a uint8 array.

    Args:
        model: Object providing ``conditioner``, ``denoiser``, ``model`` and
            ``decode_first_stage``.
        sampler: Callable invoked as ``sampler(denoiser, noise, cond=..., uc=...)``.
        prompt: Text prompt used for conditioning.
        H: Requested height; the initial noise has ``H // 8`` rows.
        W: Requested width; the initial noise has ``W // 8`` columns.
        seed: Seed for the initial noise. ``None`` draws a fresh
            non-deterministic seed from ``torch.seed()``.
        filter: Optional callable applied to the clamped ``[0, 1]`` samples
            before they are converted to uint8.

    Returns:
        A NumPy ``uint8`` array holding the decoded samples after
        ``permute(0, 2, 3, 1)``.
        NOTE(review): presumably (batch, height, width, channels) --
        confirm against ``decode_first_stage``.
    """
    F = 8
    C = 4
    shape = (1, C, H // F, W // F)

    value_dict = init_embedder_options(
        keys=get_unique_embedder_keys_from_conditioner(model.conditioner),
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        # torch.seed() returns a 64-bit value, but the NumPy RandomState used
        # by seeded_randn only accepts 32-bit seeds, so fold it into range.
        seed = torch.seed() % 2**32

    with torch.no_grad(), autocast("cuda"):
        # The unconditional batch is not needed: no "uc" is passed on.
        batch, _batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner),
            value_dict,
            [1],
        )
        c = model.conditioner(batch)
        randn = seeded_randn(shape, seed)

        def denoiser(input, sigma, c):
            # Bind the network so the sampler only supplies (input, sigma, c).
            return model.denoiser(
                model.model,
                input,
                sigma,
                c,
            )

        samples_z = sampler(denoiser, randn, cond=c, uc=None)
        samples_x = model.decode_first_stage(samples_z)
        # Map the decoder output from [-1, 1] to [0, 1] before quantizing.
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
        if filter is not None:
            samples = filter(samples)
        samples = (
            (255 * samples)
            .to(dtype=torch.uint8)
            .permute(0, 2, 3, 1)
            .detach()
            .cpu()
            .numpy()
        )
    return samples
|
|
|
|
|
def v_spacer(height) -> None:
    """Push following widgets down by writing ``height`` newline entries."""
    blank_rows = range(height)
    for _row in blank_rows:
        st.write("\n")
|
|
|
|
|
if __name__ == "__main__":
    # Widgets are created in the order they should appear on the page, so the
    # statement order below is significant.
    st.title("Turbo")

    # Header row: model picker | "Load Model" toggle | step count.
    model_col, load_col, steps_col = st.columns([1, 1, 1])
    with model_col:
        version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
        version_dict = VERSION2SPECS[version]

    with load_col:
        v_spacer(2)
        mode = "txt2img" if st.checkbox("Load Model") else "skip"

    model_requested = mode != "skip"
    if model_requested:
        state = init_st(version_dict, load_filter=True)
        status_msg = state["msg"]
        if status_msg:
            st.info(status_msg)
        model = state["model"]
        load_model(model)

    # The seed lives in session state so the arrow buttons can change it.
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        """Button callback: move to the next seed."""
        st.session_state.seed += 1

    def decrement_counter():
        """Button callback: move to the previous seed, never below zero."""
        if st.session_state.seed <= 0:
            return
        st.session_state.seed -= 1

    with steps_col:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)

    sampler = SubstepSampler(
        n_sample_steps=n_steps,
        num_steps=1000,
        eta=1.0,
        discretization_config={
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        },
    )

    prompt = st_keyup(
        "Enter a value",
        value="A cinematic shot of a baby racoon wearing an intricate italian priest robe.",
        debounce=300,
        key="interactive_text",
    )

    # Body row: previous-seed button | generated image | next-seed button.
    prev_col, image_col, next_col = st.columns([1, 5, 1])
    if model_requested:
        with prev_col:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with next_col:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)

        current_seed = st.session_state.seed
        sampler.noise_sampler = SeededNoise(seed=current_seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=current_seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with image_col:
            st.image(out[0])
|
|