# kadirnar's picture
# update
# 2a37fe9
import json
import math
import random
import time
from pathlib import Path
from uuid import uuid4
import torch
from diffusers import __version__ as diffusers_version
from huggingface_hub import CommitOperationAdd, create_commit, create_repo
from .upsampling import RealESRGANModel
from .utils import pad_along_axis
def get_all_files(root: Path):
    """Yield every regular file found under ``root``, descending into sub-directories.

    The walk is iterative: directories waiting to be scanned are kept on a
    stack, so the most recently discovered directory is scanned next.

    Args:
        root (Path): Directory at which the walk starts.

    Yields:
        Path: One path per file encountered.
    """
    pending = [root]
    while pending:
        current = pending.pop()
        for entry in current.iterdir():
            if entry.is_file():
                yield entry
            if entry.is_dir():
                # Scanned later, once the current directory is exhausted.
                pending.append(entry)
def get_groups_of_n(n: int, iterator):
    """Split an iterable into consecutive lists of at most ``n`` elements.

    Every yielded list has exactly ``n`` elements except possibly the last
    one, which holds whatever remains. Nothing is yielded for an empty input.

    Args:
        n (int): Maximum group size; must be at least 1.
        iterator: Any iterable to be chunked.

    Yields:
        list: The next group of elements, in input order.

    Raises:
        ValueError: If ``n`` is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts.
    """
    # An explicit exception instead of `assert`: asserts vanish under `-O`,
    # and a group size of 1 is perfectly valid.
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}.")

    buffer = []
    for elt in iterator:
        buffer.append(elt)
        if len(buffer) == n:
            # Hand the group over as soon as it is full rather than waiting
            # for the next element to arrive.
            yield buffer
            buffer = []
    if buffer:
        yield buffer
def upload_folder_chunked(
    repo_id: str,
    upload_dir: Path,
    n: int = 100,
    private: bool = False,
    create_pr: bool = False,
):
    """Upload a folder to the Hugging Face Hub in chunks of n files at a time.

    The target is always a dataset repo; it is created if it does not exist
    yet. Each chunk of files becomes one commit named ``Upload part <i>``.

    Args:
        repo_id (str): The repo id to upload to.
        upload_dir (Path): The directory to upload.
        n (int, *optional*, defaults to 100): The number of files to upload at a time.
        private (bool, *optional*): Whether to upload the repo as private.
        create_pr (bool, *optional*): Whether to create a PR after uploading instead of committing directly.

    Raises:
        ValueError: If ``upload_dir`` does not exist.
    """
    # Validate the local directory first so that a bad path does not leave an
    # empty repo behind on the Hub.
    root = Path(upload_dir)
    if not root.exists():
        raise ValueError(f"Upload directory {root} does not exist.")

    url = create_repo(repo_id, exist_ok=True, private=private, repo_type="dataset")
    print(f"Uploading files to: {url}")

    for i, file_paths in enumerate(get_groups_of_n(n, get_all_files(root))):
        print(f"Committing {file_paths}")
        # NOTE(review): only the immediate parent directory name is kept in the
        # repo path, so deeper nesting is flattened and two files sharing the
        # same "<parent>/<name>" would map to the same repo path - confirm
        # this is intended before uploading nested trees.
        operations = [
            CommitOperationAdd(
                path_in_repo=f"{file_path.parent.name}/{file_path.name}",
                path_or_fileobj=str(file_path),
            )
            for file_path in file_paths
        ]
        create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=f"Upload part {i}",
            repo_type="dataset",
            create_pr=create_pr,
        )
def generate_input_batches(pipeline, prompts, seeds, batch_size, height, width):
    """Yield batches of text embeddings and matching seeded latent noise.

    For every ``(prompt, seed)`` pair the prompt is embedded with
    ``pipeline.embed_text`` and a noise tensor of shape
    ``(1, pipeline.unet.in_channels, height // 8, width // 8)`` is drawn from a
    generator seeded with ``seed``. Items are accumulated until the embeddings
    hold ``batch_size`` rows (or the inputs run out) and then yielded.

    On a CUDA pipeline both tensors are cast to float16 on the pipeline's
    device. On any other device the embeddings are yielded as returned by
    ``pipeline.embed_text`` and the noise is cast to the embeddings' dtype, so
    the function is also usable without a GPU.

    Args:
        pipeline: Object exposing ``embed_text(prompt)``, ``unet.in_channels``
            and ``device`` (a ``torch.device``).
        prompts (list[str]): Prompts to embed.
        seeds (list[int]): One seed per prompt.
        batch_size (int): Number of embedding rows per yielded batch.
        height (int): Target image height in pixels.
        width (int): Target image width in pixels.

    Yields:
        tuple: ``(batch_idx, embeds_batch, noise_batch)``.

    Raises:
        ValueError: If ``prompts`` and ``seeds`` differ in length.
    """
    if len(prompts) != len(seeds):
        raise ValueError("Number of prompts and seeds must be equal.")

    device = pipeline.device
    on_cuda = device.type == "cuda"
    # MPS cannot be used for the generator here, so noise for MPS pipelines is
    # drawn on the CPU and moved to the pipeline device afterwards.
    noise_device = torch.device("cpu") if device.type == "mps" else device

    embeds_parts, noise_parts = [], []
    rows = 0
    batch_idx = 0
    for i, (prompt, seed) in enumerate(zip(prompts, seeds)):
        embeds = pipeline.embed_text(prompt)
        generator = torch.Generator(device=noise_device).manual_seed(seed)
        noise = torch.randn(
            (1, pipeline.unet.in_channels, height // 8, width // 8),
            device=noise_device,
            generator=generator,
        ).to(device)

        embeds_parts.append(embeds)
        noise_parts.append(noise)
        rows += embeds.shape[0]

        batch_is_ready = rows == batch_size or i + 1 == len(prompts)
        if not batch_is_ready:
            continue

        # Concatenate once per batch instead of re-concatenating per item.
        embeds_batch = torch.cat(embeds_parts)
        noise_batch = torch.cat(noise_parts)
        if on_cuda:
            embeds_batch = embeds_batch.to(device=device, dtype=torch.float16)
            noise_batch = noise_batch.to(device=device, dtype=torch.float16)
        else:
            noise_batch = noise_batch.to(dtype=embeds_batch.dtype)

        yield batch_idx, embeds_batch, noise_batch
        batch_idx += 1

        # Drop our references before building the next batch to keep peak
        # memory down.
        del embeds_batch, noise_batch
        embeds_parts, noise_parts = [], []
        rows = 0
        if on_cuda:
            torch.cuda.empty_cache()
def generate_images(
    pipeline,
    prompt,
    batch_size=1,
    num_batches=1,
    seeds=None,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_dir="./images",
    image_file_ext=".jpg",
    upsample=False,
    height=512,
    width=512,
    eta=0.0,
    push_to_hub=False,
    repo_id=None,
    private=False,
    create_pr=False,
    name=None,
):
    """Generate images using the StableDiffusion pipeline.

    Images are written to ``output_dir/name`` as ``<seed><image_file_ext>``,
    next to a ``prompt_config.json`` recording the generation settings. When
    ``push_to_hub`` is set, that folder is uploaded once all images are saved.

    Args:
        pipeline (StableDiffusionWalkPipeline): The StableDiffusion pipeline instance.
        prompt (str): The prompt to use for the image generation.
        batch_size (int, *optional*, defaults to 1): The batch size to use for image generation.
        num_batches (int, *optional*, defaults to 1): The number of batches to generate.
        seeds (list[int], *optional*): The seeds to use for the image generation. When omitted
            (or empty), ``batch_size * num_batches`` random seeds are drawn.
        num_inference_steps (int, *optional*, defaults to 50): The number of inference steps to take.
        guidance_scale (float, *optional*, defaults to 7.5): The guidance scale to use for image generation.
        output_dir (str, *optional*, defaults to "./images"): The output directory to save the images to.
        image_file_ext (str, *optional*, defaults to '.jpg'): The image file extension to use.
        upsample (bool, *optional*, defaults to False): Whether to upsample the images.
        height (int, *optional*, defaults to 512): The height of the images to generate.
        width (int, *optional*, defaults to 512): The width of the images to generate.
        eta (float, *optional*, defaults to 0.0): The eta parameter to use for image generation.
        push_to_hub (bool, *optional*, defaults to False): Whether to push the generated images to the Hugging Face Hub.
        repo_id (str, *optional*): The repo id to push the images to.
        private (bool, *optional*): Whether to push the repo as private.
        create_pr (bool, *optional*): Whether to create a PR after pushing instead of committing directly.
        name (str, *optional*, defaults to current timestamp str): The name of the sub-directory of
            output_dir to save the images to.

    Returns:
        list[str]: Paths of the saved images, in generation order.

    Raises:
        ValueError: If ``push_to_hub`` is set without ``repo_id``, or if the number of
            seeds differs from ``batch_size * num_batches``.
        FileExistsError: If ``output_dir/name`` already exists.
    """
    if push_to_hub and repo_id is None:
        raise ValueError("Must provide repo_id if push_to_hub is True.")

    # Validate the seeds before touching the filesystem so a bad call does not
    # leave an empty output directory behind.
    num_images = batch_size * num_batches
    # randrange draws from the same [0, 9999999) range without materializing a
    # ten-million-element list for every seed.
    seeds = seeds or [random.randrange(0, 9999999) for _ in range(num_images)]
    if len(seeds) != num_images:
        raise ValueError("Number of seeds must be equal to batch_size * num_batches.")

    name = name or time.strftime("%Y%m%d-%H%M%S")
    save_path = Path(output_dir) / name
    save_path.mkdir(exist_ok=False, parents=True)
    prompt_config_path = save_path / "prompt_config.json"

    if upsample:
        # Load the upsampler lazily and cache it on the pipeline for later calls.
        if getattr(pipeline, "upsampler", None) is None:
            pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
        pipeline.upsampler.to(pipeline.device)

    cfg = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        eta=eta,
        num_inference_steps=num_inference_steps,
        upsample=upsample,
        height=height,
        width=width,
        scheduler=dict(pipeline.scheduler.config),
        tiled=pipeline.tiled,
        diffusers_version=diffusers_version,
        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown",
    )
    prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))

    frame_filepaths = []
    for batch_idx, embeds, noise in generate_input_batches(
        pipeline, [prompt] * num_images, seeds, batch_size, height, width
    ):
        print(f"Generating batch {batch_idx}")
        outputs = pipeline(
            text_embeddings=embeds,
            latents=noise,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            height=height,
            width=width,
            # The upsampler is fed the numpy output; otherwise PIL images are saved directly.
            output_type="pil" if not upsample else "numpy",
        )["images"]

        images = [pipeline.upsampler(output) for output in outputs] if upsample else outputs

        for image in images:
            # The number of frames saved so far indexes the seed of the current image.
            frame_filepath = save_path / f"{seeds[len(frame_filepaths)]}{image_file_ext}"
            image.save(frame_filepath)
            frame_filepaths.append(str(frame_filepath))

    # Upload before returning; previously the upload sat after the return
    # statement and could never run.
    if push_to_hub:
        upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)

    return frame_filepaths
def generate_images_flax(
    pipeline,
    params,
    prompt,
    batch_size=1,
    num_batches=1,
    seeds=None,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_dir="./images",
    image_file_ext=".jpg",
    upsample=False,
    height=512,
    width=512,
    push_to_hub=False,
    repo_id=None,
    private=False,
    create_pr=False,
    name=None,
):
    """Generate images using the Flax StableDiffusion pipeline.

    Images are written to ``output_dir/name`` under random UUID file names,
    next to a ``prompt_config.json`` recording the generation settings. When
    ``push_to_hub`` is set, that folder is uploaded once all images are saved.

    Args:
        pipeline (StableDiffusionWalkPipeline): The StableDiffusion pipeline instance.
        params (`Union[Dict, FrozenDict]`): The model parameters.
        prompt (str): The prompt to use for the image generation.
        batch_size (int, *optional*, defaults to 1): The per-device batch size to use for image generation.
        num_batches (int, *optional*, defaults to 1): The number of batches to generate.
        seeds (int, *optional*): The seed to use for the image generation. A random seed is
            drawn when omitted; an explicit ``0`` is honored.
        num_inference_steps (int, *optional*, defaults to 50): The number of inference steps to take.
        guidance_scale (float, *optional*, defaults to 7.5): The guidance scale to use for image generation.
        output_dir (str, *optional*, defaults to "./images"): The output directory to save the images to.
        image_file_ext (str, *optional*, defaults to '.jpg'): The image file extension to use.
        upsample (bool, *optional*, defaults to False): Whether to upsample the images.
        height (int, *optional*, defaults to 512): The height of the images to generate.
        width (int, *optional*, defaults to 512): The width of the images to generate.
        push_to_hub (bool, *optional*, defaults to False): Whether to push the generated images to the Hugging Face Hub.
        repo_id (str, *optional*): The repo id to push the images to.
        private (bool, *optional*): Whether to push the repo as private.
        create_pr (bool, *optional*): Whether to create a PR after pushing instead of committing directly.
        name (str, *optional*, defaults to current timestamp str): The name of the sub-directory of
            output_dir to save the images to.

    Returns:
        list[str]: Paths of the saved images, in generation order.

    Raises:
        ValueError: If ``push_to_hub`` is set without ``repo_id``.
        FileExistsError: If ``output_dir/name`` already exists.
    """
    # Imported lazily so the module can be used without jax/flax installed.
    import jax
    from flax.training.common_utils import shard

    if push_to_hub and repo_id is None:
        raise ValueError("Must provide repo_id if push_to_hub is True.")

    name = name or time.strftime("%Y%m%d-%H%M%S")
    save_path = Path(output_dir) / name
    save_path.mkdir(exist_ok=False, parents=True)
    prompt_config_path = save_path / "prompt_config.json"

    num_images = batch_size * num_batches
    # `is None` rather than truthiness so that a caller-supplied seed of 0 is
    # kept; randrange avoids building a ten-million-element list.
    if seeds is None:
        seeds = random.randrange(0, 9999999)
    prng_seed = jax.random.PRNGKey(seeds)

    if upsample:
        # Load the upsampler lazily and cache it on the pipeline for later calls.
        if getattr(pipeline, "upsampler", None) is None:
            pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
        if not torch.cuda.is_available():
            print("Upsampling is recommended to be done on a GPU, as it is very slow on CPU")
        else:
            pipeline.upsampler = pipeline.upsampler.cuda()

    cfg = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        upsample=upsample,
        height=height,
        width=width,
        scheduler=dict(pipeline.scheduler.config),
        diffusers_version=diffusers_version,
        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown",
    )
    prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))

    num_devices = jax.device_count()
    jit = True  # force jit, assume params are already sharded
    # With jit, batch_size applies per device, so one step covers
    # batch_size * num_devices prompts.
    batch_size_total = num_devices * batch_size if jit else batch_size

    prompts = [prompt] * num_images
    frame_filepaths = []
    for batch_idx in range(math.ceil(len(prompts) / batch_size_total)):
        prompt_batch = prompts[batch_idx * batch_size_total : (batch_idx + 1) * batch_size_total]
        print(f"Generating batches: {batch_idx*num_devices} - {min((batch_idx+1)*num_devices, num_batches)}")

        prompt_ids_batch = pipeline.prepare_inputs(prompt_batch)
        # NOTE(review): the same key is used for every batch, so each batch
        # appears to receive identical randomness - confirm whether the key
        # should be advanced between batches.
        prng_seed_batch = prng_seed

        pad_size = 0
        if jit:
            # The batch must divide evenly across devices before sharding; pad
            # the prompt ids along the batch axis when it does not.
            pad_size = (-len(prompt_batch)) % num_devices
            if pad_size:
                prompt_ids_batch = pad_along_axis(prompt_ids_batch, pad_size, axis=0)
            prompt_ids_batch = shard(prompt_ids_batch)
            prng_seed_batch = jax.random.split(prng_seed, jax.device_count())

        outputs = pipeline(
            params,
            prng_seed=prng_seed_batch,
            prompt_ids=prompt_ids_batch,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            # The upsampler is fed the numpy output; otherwise PIL images are saved directly.
            output_type="pil" if not upsample else "numpy",
            jit=jit,
        )["images"]

        if jit and pad_size:
            # Drop the outputs produced for the padding entries.
            outputs = outputs[:-pad_size]

        images = [pipeline.upsampler(output) for output in outputs] if upsample else outputs

        for image in images:
            frame_filepath = save_path / f"{uuid4()}{image_file_ext}"
            image.save(frame_filepath)
            frame_filepaths.append(str(frame_filepath))

    # Upload before returning; previously the upload sat after the return
    # statement and could never run.
    if push_to_hub:
        upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)

    return frame_filepaths