Spaces: Running on T4
import os | |
import io | |
import re | |
import time | |
import random | |
import torch | |
from typing import Final, List, Optional, Tuple, cast | |
from PIL import Image, ImageDraw, ImageEnhance | |
from PIL.Image import Image as PILImage | |
from diffusers import StableDiffusionPipeline | |
model_id: Final = "Onodofthenorth/SD_PixelArt_SpriteSheet_Generator"
# Load the pipeline in half precision (weights cached under ./cache) and move it
# onto the CUDA device in one chained expression; `.to()` returns the pipeline.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    cache_dir="cache",
    torch_dtype=torch.float16,
).to("cuda")
# Facing direction -> label appended to the generation prompt for that view
# (presumably the model's per-view trigger tokens: Front/Right/Back/Left).
# Insertion order matters: callers iterate it to decide generation order.
sprite_sides: Final = dict(
    front="PixelArtFSS",
    right="PixelArtRSS",
    back="PixelArtBSS",
    left="PixelArtLSS",
)
def torchGenerator(
    seed: Optional[int], max: int = 1024, device: str = "cuda"
) -> Tuple[torch.Generator, int]:
    """
    Create a seeded torch random generator.

    Args:
        seed: Seed to use. If None, a random seed in [0, max) is drawn.
            An explicit 0 is honoured as a real seed.
        max: Exclusive upper bound for a randomly drawn seed.
        device: Device the generator is created on (defaults to "cuda").

    Returns:
        The seeded generator and the seed actually used, so callers can
        reproduce the result.
    """
    # `seed or ...` would treat seed=0 as "no seed"; only None means "pick one".
    if seed is None:
        seed = random.randrange(0, max)
    return torch.Generator(device).manual_seed(seed), seed
def generate(
    prompt: str,
    sfw_retries: int = 1,
    seed: Optional[int] = None,
) -> PILImage:
    """
    Generate a 512x512 sprite image from a text description.

    The pipeline runs up to `sfw_retries` times, and always at least once.
    Whenever an output is flagged as NSFW and attempts remain, a new random
    seed different from `seed` is drawn and the prompt is regenerated. If
    every attempt is flagged, the last output is returned unchanged. It is
    presumably already blanked by the pipeline's safety checker; confirm this
    for the loaded model.

    Args:
        prompt: Text prompt passed to the pipeline.
        sfw_retries: Maximum number of generation attempts (values < 1 are
            treated as 1).
        seed: Seed for the first attempt; None picks a random one.

    Returns:
        The first image of the last pipeline run.
    """
    generator = torchGenerator(seed)[0]
    # Guarantee at least one run; with 0 attempts the function used to return None.
    attempts = max(1, sfw_retries)
    image: PILImage | None = None
    for attempt in range(attempts):
        pipe_output = pipe(prompt, generator=generator, width=512, height=512)
        image = pipe_output.images[0]
        # diffusers reports None here when the safety checker is disabled.
        flags = pipe_output.nsfw_content_detected
        if not (flags and flags[0]):
            break
        if attempt == attempts - 1:
            # No retries left: don't announce or prepare a regeneration.
            break
        rand_seed = seed
        while rand_seed == seed:
            rand_seed = random.randrange(0, 1024)
        print(f"Regenerating `{prompt}` with different seed.")
        generator = torchGenerator(rand_seed)[0]
    return cast(PILImage, image)
def generate_sides(
    prompt: str, sfw_retries: int = 1, sides: dict[str, str] = sprite_sides
) -> Tuple[dict[str, PILImage], str]:
    """
    Generate one sprite image per entry in `sides`.

    All sides are generated with the same randomly drawn seed. When `sides`
    contains both "left" and "right", the right view is not generated.
    Instead, it is the left image mirrored horizontally, and it is added to
    the result after all generated sides.

    Returns:
        A mapping of side name to image, and the prompt unchanged.
    """
    print(f"Generating sprites for `{prompt}`")
    shared_seed = random.randrange(0, 1024)
    mirror_left = "left" in sides and "right" in sides
    result: dict[str, PILImage] = {}
    for side_name, side_tag in sides.items():
        if mirror_left and side_name == "right":
            continue  # produced below by flipping the left view
        full_prompt = f"({prompt}) [nsfw] [photograph] {side_tag}"
        result[side_name] = generate(full_prompt, sfw_retries, shared_seed)
    if mirror_left:
        result["right"] = result["left"].transpose(Image.Transpose.FLIP_LEFT_RIGHT)
    return result, prompt
def clean_sprite(
    image: PILImage,
    size: Tuple[int, int] = (192, 192),
    sharpness: float = 1.5,
    thresh: int = 128,
    rescaling: Optional[int] = None,
) -> PILImage:
    """
    Process an image to look more sprite-like.

    Steps:
    1. Sharpen by `sharpness` and convert to RGBA.
    2. Flood-fill from the top-left pixel with transparent white, using
       tolerance `thresh`.
    3. Optionally scale down by `rescaling`, which coarsens the pixels.
    4. Scale to `size` with nearest-neighbour sampling.

    Args:
        image: Source image.
        size: Final output size (width, height).
        sharpness: Enhancement factor passed to ImageEnhance.Sharpness.
        thresh: Colour tolerance for the flood fill.
        rescaling: If an int, first scale the image down by this factor
            (based on its original size) before scaling up to `size`.
            Other values (e.g. None) skip this step.

    Raises:
        ValueError: if `rescaling` is an int <= 0.
    """
    width, height = image.size
    image = ImageEnhance.Sharpness(image).enhance(sharpness)
    image = image.convert("RGBA")
    # Makes the region connected to (0, 0) transparent; assumes the top-left
    # pixel is background. NOTE(review): confirm this holds for model outputs.
    ImageDraw.floodfill(image, (0, 0), (255, 255, 255, 0), thresh=thresh)
    # bool is excluded to match the previous `type(...) is int` check.
    if isinstance(rescaling, int) and not isinstance(rescaling, bool):
        if rescaling <= 0:
            raise ValueError(f"rescaling must be a positive integer, got {rescaling}")
        # Clamp to 1px so very large factors don't request a zero-sized image.
        reduced = (max(1, int(width / rescaling)), max(1, int(height / rescaling)))
        image = image.resize(reduced, resample=Image.Resampling.NEAREST)
    return image.resize(size, resample=Image.Resampling.NEAREST)
def split_sprites(image: PILImage, size: Tuple[int, int] = (96, 96)) -> List[PILImage]:
    """
    Split a sprite image into its four individual frames.

    Cropping:
    - The image is cut into four columns of `int(width / 4)` pixels; the
      last column takes the remainder.
    - Only the vertical band from 1/4 to 3/4 of the image height is kept.

    Placement:
    - Each column is pasted onto a transparent canvas of `size`.
    - It fills the slot spanning the middle half of the canvas width and
      the full canvas height.
    - A column that does not already match the slot is rescaled with
      nearest-neighbour sampling, so the input need not be exactly twice
      `size`.

    Returns:
        The four frames, left to right.
    """
    width, height = image.size
    w, h = size
    column = int(width / 4)
    # Both vertical edges come from the image height so any input size works.
    top, bottom = int(height / 4), int(height * 0.75)
    lefts = (0, column, column * 2, column * 3)
    rights = (column, column * 2, column * 3, width)
    slot_left = int(w / 4)
    slot_size = (int(w * 0.75) - slot_left, h)
    blank = Image.new("RGBA", size, (255, 255, 255, 0))
    frames: List[PILImage] = []
    for left, right in zip(lefts, rights):
        piece = image.crop((left, top, right, bottom))
        if piece.size != slot_size:
            piece = piece.resize(slot_size, resample=Image.Resampling.NEAREST)
        canvas = blank.copy()
        canvas.paste(piece, (slot_left, 0))
        frames.append(canvas)
    return frames
def build_spritesheet(
    images: dict[str, PILImage],
    text: str = "sd_pixelart",
    sprite_size: Tuple[int, int] = (96, 96),
    dir: str = "output",
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[PILImage, str | None]:
    """
    Build a sprite sheet from per-side sprite images.

    1. Clean each image and scale it to twice `sprite_size`.
    2. Split each image into individual frames of `sprite_size`.
    3. Create a canvas with one row per side (in `images` order) and one
       column per frame.
    4. Paste each individual frame onto the canvas.

    Filename handling:
    - `text` keeps only word characters, parentheses, square brackets,
      underscores and hyphens.
    - When `save` is true, the sheet is also written to
      `<dir>/<timestamp>_<text>.png`, creating `dir` if needed.
    - `timestamp` defaults to the current Unix time.

    Returns:
        The sheet re-opened from an in-memory PNG, and the saved file path
        (None when `save` is false).

    Raises:
        ValueError: if `images` is empty.
    """
    if not images:
        raise ValueError("build_spritesheet needs at least one side image")
    width, height = sprite_size
    text = re.sub(r"[^\w()[\]_-]", "", text)
    frames = {
        side: split_sprites(
            clean_sprite(image, (width * 2, height * 2), thresh=thresh), sprite_size
        )
        for side, image in images.items()
    }
    # Size the sheet from the frames themselves instead of assuming a "front" side.
    columns = max(len(side_frames) for side_frames in frames.values())
    canvas = Image.new(
        "RGBA",
        (width * columns, height * len(frames)),
        (255, 255, 255, 0),
    )
    for row, side_frames in enumerate(frames.values()):
        for col, frame in enumerate(side_frames):
            x, y = col * width, row * height
            canvas.paste(frame, (x, y, x + width, y + height))
    spritesheet = io.BytesIO()
    canvas.save(spritesheet, "PNG")
    filepath = None
    if save:
        # An explicit timestamp of 0 is honoured; only None means "now".
        if timestamp is None:
            timestamp = int(time.time())
        if dir:
            os.makedirs(dir, exist_ok=True)
        filepath = os.path.join(dir, f"{timestamp}_{text}.png")
        canvas.save(filepath)
    return Image.open(spritesheet), filepath
def build_gifs(
    images: dict[str, PILImage],
    text: str = "sd_spritesheet",
    dir: str = "output",
    duration: int | List[int] | Tuple[int, ...] = (300, 450, 300, 450),
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[dict[str, PILImage], List[str] | None]:
    """
    Build one animated GIF per side from its frames.

    Encoding:
    - Each side image is cleaned and split into four frames, using the
      default sizes of clean_sprite and split_sprites.
    - The frames are encoded as an infinitely looping GIF (`loop=0`,
      `disposal=3`) with the given per-frame `duration`.

    Filename handling:
    - `text` keeps only word characters, parentheses, square brackets,
      underscores and hyphens.
    - When `save` is true, each GIF is also written to
      `<dir>/<timestamp>_<text>_<side>.gif`, creating `dir` if needed.
    - `timestamp` defaults to the current Unix time.

    Returns:
        A mapping of side name to the animated GIF image, and the list of
        written paths (None when `save` is false).
    """
    gifs: dict[str, PILImage] = {}
    text = re.sub(r"[^\w()[\]_-]", "", text)
    filepaths: List[str] | None = [] if save else None
    if save:
        # An explicit timestamp of 0 is honoured; only None means "now".
        if timestamp is None:
            timestamp = int(time.time())
        if dir:
            os.makedirs(dir, exist_ok=True)
    for side, image in images.items():
        frames = split_sprites(clean_sprite(image, thresh=thresh))
        gif = io.BytesIO()
        frames[0].save(
            gif,
            format="GIF",
            save_all=True,
            append_images=frames[1:],
            disposal=3,
            duration=duration,
            loop=0,
        )
        gifs[side] = Image.open(gif)
        if filepaths is not None:
            filepath = os.path.join(dir, f"{timestamp}_{text}_{side}.gif")
            # Write the already-encoded bytes rather than encoding the GIF twice.
            with open(filepath, "wb") as fh:
                fh.write(gif.getvalue())
            filepaths.append(filepath)
    return gifs, filepaths