# diffusezx / app.py
# Commit 7241f9b by sl: "Steps parametrized." (3.93 kB)
import gradio as gr
from diffusers import DiffusionPipeline
import cv2
import torch
import numpy as np
from PIL import Image
import os
# Hub id of the diffusion checkpoint loaded below.
model = "shadowlamer/sd-zxspectrum-model-256"
# Output size in pixels: 32 x 24 cells of 8x8, matching the 0x1800-byte
# bitmap and 0x300-byte attribute buffers built in generate().
image_width, image_height = 256, 192
# Directory that receives the generated .scr files.
samples_dir = "/tmp"
# Loaded once at import time; the safety checker is explicitly disabled.
pipe = DiffusionPipeline.from_pretrained(
    model,
    safety_checker=None,
    requires_safety_checker=False,
)
# Borrowed from here: https://stackoverflow.com/a/73667318
def quantize_to_palette(_image, _palette):
    """Replace every pixel of ``_image`` with its nearest colour from ``_palette``.

    "Nearest" is decided by OpenCV's k-nearest-neighbour classifier with k=1
    on the raw channel values.

    Args:
        _image: array whose last axis holds 3 channels; it is treated as BGR
            because the quantized result is converted BGR -> RGB at the end.
        _palette: array-like of shape (P, 3) with the allowed colours, in the
            same channel order as ``_image``.

    Returns:
        A PIL ``Image`` (RGB order) with the same height and width as ``_image``.
    """
    palette = np.asarray(_palette)
    samples = _image.reshape(-1, 3).astype(np.float32)
    train_data = palette.astype(np.float32)
    # Each palette row is labelled with its own index. OpenCV's ML module
    # wants int32 (or float32) responses, so ask for int32 explicitly rather
    # than relying on the platform-default int64 being converted.
    labels = np.arange(len(palette), dtype=np.int32)
    knn = cv2.ml.KNearest_create()
    knn.train(train_data, cv2.ml.ROW_SAMPLE, labels)
    _, _, neighbours, _ = knn.findNearest(samples, 1)
    # neighbours is an (N, 1) float array of palette indices. A single
    # fancy-index lookup replaces the original per-pixel Python loop.
    indices = neighbours.ravel().astype(np.intp)
    quantized = palette[indices].reshape(_image.shape)
    return Image.fromarray(cv2.cvtColor(np.asarray(quantized, dtype=np.uint8), cv2.COLOR_BGR2RGB))
def collect_char_colors(image, _x, _y):
    """List the distinct colours of the 8x8 cell whose top-left corner is (_x, _y).

    Pixels are read with ``image.getpixel`` row by row. Colours are returned as
    lists, ordered from most to least frequent. Ties keep the order in which
    the colours were first seen.
    """
    counts = {}
    for dy in range(8):
        for dx in range(8):
            colour = image.getpixel((_x + dx, _y + dy))
            counts[colour] = counts.get(colour, 0) + 1
    # sorted() is stable even with reverse=True, so equal counts stay in
    # first-seen order.
    ranked = sorted(counts, key=counts.get, reverse=True)
    return [list(colour) for colour in ranked]
def palette_to_attr(_palette):
    """Pack up to two colours into a one-byte cell attribute.

    ``_palette[0]`` is the paper colour and ``_palette[1]`` (if present) the
    ink colour. Any further entries are ignored. A channel counts as "on"
    when its value is > 0.

    Returns 0x00 for an empty palette. Otherwise it returns 0x40, plus the
    paper bits (R=0x10, G=0x20, B=0x08), plus the ink bits
    (R=0x02, G=0x04, B=0x01).
    """
    if len(_palette) == 0:
        return 0x00
    # Bit contributed by channel 0, 1, 2 (R, G, B): first row for paper,
    # second row for ink.
    field_bits = ((0x10, 0x20, 0x08), (0x02, 0x04, 0x01))
    attr = 0x40  # bit 6, always set here (presumably the Spectrum BRIGHT flag)
    for colour, bits in zip(_palette[:2], field_bits):
        attr += sum(bit for channel, bit in enumerate(bits) if colour[channel] > 0)
    return attr
def _scr_filename(prompt, seed, steps):
    """Return a filesystem-safe ``.scr`` file name for one generation request.

    Every character of ``prompt`` that is not an ASCII letter/digit, ``-`` or
    ``_`` becomes ``_``. Path separators and ``..`` therefore cannot escape
    ``samples_dir``. The prompt part is capped at 100 characters to stay well
    under filename length limits. ``steps`` is part of the name, so runs that
    differ only in step count get different files.
    """
    safe_prompt = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_" else "_" for ch in prompt
    )[:100]
    return "{}_{}_{}.scr".format(safe_prompt, int(seed), int(steps))


def _encode_scr(image):
    """Encode a quantized RGB image as a raw screen dump.

    The result is a 0x1800-byte pixel bitmap followed by a 0x300-byte
    attribute area, with one attribute per 8x8 cell. Within a cell the most
    frequent colour is paper (bit 0) and every other pixel is ink (bit 1).
    Cells with three or more colours therefore lose detail.
    """
    bitmap = bytearray(0x1800)
    attributes = bytearray([0b00111000] * 0x300)
    for y in range(0, image_height, 8):
        for x in range(0, image_width, 8):
            col, row = x // 8, y // 8
            cell_colors = collect_char_colors(image, x, y)
            paper = cell_colors[0]
            # The bitmap is interleaved:
            # - the 24 cell rows form three blocks of 0x800 bytes;
            # - consecutive pixel lines of a cell are 0x100 bytes apart;
            # - cell rows inside a block are 32 bytes apart.
            base = (row // 8) * 0x800 + (row % 8) * 32 + col
            for cy in range(8):
                bits = 0
                for cx in range(8):
                    bits <<= 1
                    if list(image.getpixel((x + cx, y + cy))) != paper:
                        bits |= 1
                bitmap[base + cy * 0x100] = bits
            attributes[row * 32 + col] = palette_to_attr(cell_colors)
    return bytes(bitmap + attributes)


def generate(prompt, seed, steps):
    """Generate a ZX Spectrum-style picture and a matching ``.scr`` file.

    Args:
        prompt: text prompt passed to the diffusion pipeline.
        seed: seed for the CPU ``torch.Generator``; converted with ``int()``.
        steps: number of inference steps; converted with ``int()``.

    Returns:
        ``[image, path]``: the palette-quantized PIL image and the path of the
        ``.scr`` file written under ``samples_dir``.
    """
    generator = torch.Generator("cpu").manual_seed(int(seed))
    raw_image = pipe(prompt, height=image_height, width=image_width,
                     num_inference_steps=int(steps), generator=generator).images[0]
    # All eight 0/255 RGB combinations. The set is unchanged by channel
    # reordering, so it can be matched directly against the BGR array below.
    zx_palette = np.array(
        [[0, 0, 0], [0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0], [255, 0, 255], [255, 255, 0],
         [255, 255, 255]])
    # RGB -> BGR, because quantize_to_palette converts its result from BGR to RGB.
    bgr_image = np.array(raw_image)[:, :, ::-1].copy()
    image = quantize_to_palette(_image=bgr_image, _palette=zx_palette)
    # Always (re)write the file so it matches the image returned by this call.
    # A cached file keyed without `steps` could belong to a different image.
    out = os.path.join(samples_dir, _scr_filename(prompt, seed, steps))
    with open(out, "wb") as scr:
        scr.write(_encode_scr(image))
    return [image, out]
# Web UI: inputs map positionally to generate(prompt, seed, steps), and the two
# outputs receive the [image, path] list that generate returns.
iface = gr.Interface(
    fn=generate,
    inputs=["text", "number", "number"],
    outputs=["image", "file"],
    title="ZX-Spectrum inspired images generator ",
    # Every example uses seed 123 and 20 inference steps.
    examples=[[subject, 123, 20] for subject in ("Cute cat", "Solar system", "Disco ball")],
)
iface.launch()