# Author: NIRVANALAN
# Commit: 11e6f7b ("init")
# Size: 6.65 kB (header scraped from the file viewer; "raw" / "history blame" were page links)
from st_keyup import st_keyup
from streamlit_helpers import *
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler
# Per-model settings for the demo. Both turbo models share the same image
# size (H x W), "C" (4, the same value sample() uses as its latent channel
# count) and "f" (8, the same value sample() divides H and W by); they differ
# only in their inference config and checkpoint path.
VERSION2SPECS = {
    model_name: {
        "H": 512,
        "W": 512,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": config_path,
        "ckpt": ckpt_path,
    }
    for model_name, config_path, ckpt_path in (
        (
            "SDXL-Turbo",
            "configs/inference/sd_xl_base.yaml",
            "checkpoints/sd_xl_turbo_1.0.safetensors",
        ),
        (
            "SD-Turbo",
            "configs/inference/sd_2_1.yaml",
            "checkpoints/sd_turbo.safetensors",
        ),
    )
}
class SubstepSampler(EulerAncestralSampler):
    """Euler-ancestral sampler restricted to a few hand-picked schedule indices.

    The full sigma schedule is discretized, then only the first
    ``n_sample_steps`` entries of ``steps_subset`` plus its last entry are
    kept (e.g. ``n_sample_steps=1`` keeps indices ``[0, 1000]``).
    """

    def __init__(self, n_sample_steps=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        # Indices into the discretized sigma schedule; the final entry is
        # always appended after the first ``n_sample_steps`` ones.
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Build the reduced sigma schedule and scale the starting noise.

        ``num_steps`` overrides ``self.num_steps`` (presumably set by the base
        class) for the discretization. ``uc`` is ignored: ``cond`` is
        returned in its place, so no separate unconditional input is used.
        ``x`` is scaled in place.

        Returns:
            ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where ``uc is cond``.
        """
        schedule_len = num_steps if num_steps is not None else self.num_steps
        full_schedule = self.discretization(schedule_len, device=self.device)

        kept_indices = list(self.steps_subset[: self.n_sample_steps])
        kept_indices.extend(self.steps_subset[-1:])
        sigmas = full_schedule[kept_indices]

        # In-place so the caller's tensor reflects the scaling as well.
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, len(sigmas), cond, cond
def seeded_randn(shape, seed, device="cuda"):
    """Return standard-normal noise of ``shape`` drawn reproducibly from ``seed``.

    Samples come from NumPy's legacy ``RandomState`` generator (float64) and
    are converted to a float32 tensor, so the result does not depend on
    torch's global RNG state.

    Args:
        shape: Sequence of ints (tuple, list or ``torch.Size``).
        seed: Integer seed. It is reduced modulo ``2**32`` because
            ``RandomState`` only accepts seeds in ``[0, 2**32 - 1]``, while
            ``torch.seed()`` returns 64-bit values and ``SeededNoise`` keeps
            incrementing its seed. Seeds already in range are unchanged.
        device: Target device for the returned tensor (default ``"cuda"``).

    Returns:
        float32 ``torch.Tensor`` with the given shape on ``device``.
    """
    rng = np.random.RandomState(seed % 2**32)
    noise = rng.randn(*shape)
    return torch.from_numpy(noise).to(device=device, dtype=torch.float32)
class SeededNoise:
    """Callable noise source producing a reproducible sequence of noise tensors.

    The stored seed is advanced by one *before* each draw, so successive
    calls use ``seed + 1``, ``seed + 2``, ... Only ``x.shape`` is used from
    the argument.
    """

    def __init__(self, seed):
        self.seed = seed

    def __call__(self, x):
        next_seed = self.seed + 1
        self.seed = next_seed
        return seeded_randn(x.shape, next_seed)
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Build the ``value_dict`` of conditioning values for the given embedder keys.

    Only entries required by ``keys`` are filled in; unrecognized keys are
    ignored.

    Args:
        keys: Iterable of conditioner input keys such as ``"txt"`` or
            ``"original_size_as_tuple"``.
        init_dict: Must contain ``orig_width``/``orig_height`` when
            ``"original_size_as_tuple"`` is requested, and
            ``target_width``/``target_height`` when ``"target_size_as_tuple"``
            is requested.
        prompt: Stored under ``"prompt"`` for the ``"txt"`` key.
        negative_prompt: Stored under ``"negative_prompt"`` for the ``"txt"``
            key; ``None`` (the default) stores an empty string.

    Returns:
        dict mapping value names to their values.

    Raises:
        KeyError: if ``init_dict`` lacks a size entry required by ``keys``.
    """
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = (
                "" if negative_prompt is None else negative_prompt
            )
        elif key == "original_size_as_tuple":
            value_dict["orig_width"] = init_dict["orig_width"]
            value_dict["orig_height"] = init_dict["orig_height"]
        elif key == "crop_coords_top_left":
            # Fixed at (0, 0): no crop offset is applied.
            value_dict["crop_coords_top"] = 0
            value_dict["crop_coords_left"] = 0
        elif key == "aesthetic_score":
            # Fixed scores; not configurable from this function.
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5
        elif key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]
    return value_dict
def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    """Generate a single image for ``prompt`` and return it as uint8 pixels.

    Args:
        model: Loaded diffusion model exposing ``conditioner``, ``denoiser``,
            ``model`` and ``decode_first_stage``.
        sampler: Callable invoked as
            ``sampler(denoiser, noise, cond=..., uc=...)`` returning latents.
        prompt: Text prompt.
        H, W: Requested height/width in pixels; the initial noise has
            spatial size ``(H // 8, W // 8)``.
        seed: Seed for the initial noise; ``None`` draws a fresh
            non-deterministic seed.
        filter: Optional callable applied to the decoded image batch (values
            clamped to ``[0, 1]``) before quantization. The name shadows the
            builtin but is kept for caller compatibility.

    Returns:
        uint8 ``numpy.ndarray`` in (batch, height, width, channels) layout,
        with batch size 1.
    """
    F = 8  # spatial divisor for the noise shape (same value as "f" in VERSION2SPECS)
    C = 4  # noise channels (same value as "C" in VERSION2SPECS)
    shape = (1, C, H // F, W // F)

    # Computed once and reused for both the value dict and the batch.
    keys = get_unique_embedder_keys_from_conditioner(model.conditioner)
    value_dict = init_embedder_options(
        keys=keys,
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        # torch.seed() returns a 64-bit value, but np.random.RandomState
        # (used by seeded_randn) only accepts seeds in [0, 2**32 - 1].
        seed = torch.seed() % 2**32
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            batch, _batch_uc = get_batch(keys, value_dict, [1])
            c = model.conditioner(batch)
            uc = None
            randn = seeded_randn(shape, seed)

            def denoiser(input, sigma, c):
                return model.denoiser(model.model, input, sigma, c)

            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
            samples_x = model.decode_first_stage(samples_z)
            # Decoder output is presumably in [-1, 1]; rescale and clamp to [0, 1].
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            if filter is not None:
                samples = filter(samples)
            # NCHW float -> NHWC uint8 NumPy array.
            samples = (
                (255 * samples)
                .to(dtype=torch.uint8)
                .permute(0, 2, 3, 1)
                .detach()
                .cpu()
                .numpy()
            )
    return samples
def v_spacer(height) -> None:
    """Pad the Streamlit layout vertically by writing ``height`` blank lines."""
    blank_line = "\n"
    for _line_no in range(height):
        st.write(blank_line)
if __name__ == "__main__":
    # Interactive Streamlit demo: choose a turbo model, optionally load it,
    # and render one image for the typed prompt at the current seed.
    # NOTE(review): under Streamlit's rerun model this body presumably
    # re-executes on each widget interaction; widget call order sets layout.
    st.title("Turbo")
    head_cols = st.columns([1, 1, 1])
    with head_cols[0]:
        version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
        version_dict = VERSION2SPECS[version]
    with head_cols[1]:
        v_spacer(2)
        # Nothing is loaded or sampled until this box is ticked.
        if st.checkbox("Load Model"):
            mode = "txt2img"
        else:
            mode = "skip"
    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]
        load_model(model)
    # seed: kept in session state so the arrow buttons can step through it.
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        # The seed is never decremented below 0.
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    with head_cols[2]:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)
    sampler = SubstepSampler(
        n_sample_steps=1,
        num_steps=1000,
        eta=1.0,
        discretization_config=dict(
            target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        ),
    )
    # Overrides the constructor's n_sample_steps=1 with the user's choice
    # (1..4, matching the usable length of SubstepSampler.steps_subset).
    sampler.n_sample_steps = n_steps
    default_prompt = (
        "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    )
    # debounce=300 presumably in milliseconds — confirm against st_keyup docs.
    prompt = st_keyup(
        "Enter a value", value=default_prompt, debounce=300, key="interactive_text"
    )
    cols = st.columns([1, 5, 1])
    if mode != "skip":
        # Side columns hold the seed-stepping buttons, padded downward.
        with cols[0]:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with cols[2]:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)
        # Per-step noise is seeded from the same session seed as the
        # initial latent noise, so a given seed reproduces the same image.
        sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=st.session_state.seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with cols[1]:
            st.image(out[0])