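# Gradio app for the Hugging Face Space: generates a 256x192 image with the
# shadowlamer/sd-zxspectrum-model-256 diffusion model, quantizes it to the
# eight-colour ZX Spectrum palette, and packs the result into a 6912-byte
# .scr screen dump that is offered for download alongside the preview image.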
import gradio as gr
from diffusers import DiffusionPipeline
import cv2
import torch
import numpy as np
from PIL import Image
import os

model = "shadowlamer/sd-zxspectrum-model-256"
image_width = 256    # ZX Spectrum screen width in pixels
image_height = 192   # ZX Spectrum screen height in pixels
samples_dir = "/tmp"

pipe = DiffusionPipeline.from_pretrained(model, safety_checker=None, requires_safety_checker=False)

# Borrowed from here: https://stackoverflow.com/a/73667318
def quantize_to_palette(_image, _palette):
    # Map every pixel to its nearest palette entry with a 1-nearest-neighbour search.
    x_query = _image.reshape(-1, 3).astype(np.float32)
    x_index = _palette.astype(np.float32)
    knn = cv2.ml.KNearest_create()
    knn.train(x_index, cv2.ml.ROW_SAMPLE, np.arange(len(_palette)))
    ret, results, neighbours, dist = knn.findNearest(x_query, 1)
    _quantized_image = np.array([_palette[idx] for idx in neighbours.astype(int)])
    _quantized_image = _quantized_image.reshape(_image.shape)
    return Image.fromarray(cv2.cvtColor(np.array(_quantized_image, dtype=np.uint8), cv2.COLOR_BGR2RGB))

def collect_char_colors(image, _x, _y):
    # Count the colours used in the 8x8 character cell at (_x, _y) and return
    # them sorted by frequency, most common first.
    _colors = {}
    for _char_y in range(8):
        for _char_x in range(8):
            _pixel = image.getpixel((_x + _char_x, _y + _char_y))
            _colors[_pixel] = 1 if _pixel not in _colors else _colors[_pixel] + 1
    _colors = sorted(_colors.items(), key=lambda _v: _v[1], reverse=True)
    return [list(_tuple[0]) for _tuple in list(_colors)]

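# ZX Spectrum attribute byte (one per 8x8 cell):
#   bit 7 FLASH, bit 6 BRIGHT, bits 5-3 PAPER (G, R, B), bits 2-0 INK (G, R, B).
# The most frequent colour in a cell becomes PAPER, the second most frequent
# becomes INK, and BRIGHT (0x40) is set for every non-empty cell.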
def palette_to_attr(_palette):
    if len(_palette) == 0:
        return 0x00
    _attr = 0x40
    _paper = _palette[0]
    if _paper[0] > 0:
        _attr = _attr + 0x10  # r
    if _paper[1] > 0:
        _attr = _attr + 0x20  # g
    if _paper[2] > 0:
        _attr = _attr + 0x08  # b
    if len(_palette) == 1:
        return _attr
    _ink = _palette[1]
    if _ink[0] > 0:
        _attr = _attr + 0x02  # r
    if _ink[1] > 0:
        _attr = _attr + 0x04  # g
    if _ink[2] > 0:
        _attr = _attr + 0x01  # b
    return _attr

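# A .scr file is a raw 6912-byte dump of ZX Spectrum screen memory: 0x1800
# bitmap bytes (one bit per pixel, 1 = INK, 0 = PAPER) followed by 0x300
# attribute bytes, one per cell, in row order. Bitmap addresses are interleaved:
# the screen is split into three 64-line thirds of 0x800 bytes each, and within
# a third consecutive pixel rows of a character cell are 0x100 bytes apart.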
def generate(prompt, seed):
    generator = torch.Generator("cpu").manual_seed(int(seed))
    raw_image = \
        pipe(prompt, height=image_height, width=image_width, num_inference_steps=20, generator=generator).images[0]
    palette = np.array(
        [[0, 0, 0], [0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0], [255, 0, 255], [255, 255, 0],
         [255, 255, 255]])
    # The pipeline returns an RGB PIL image; flip the channels to BGR for OpenCV.
    # quantize_to_palette converts the quantized result back to RGB.
    input_image = np.array(raw_image)
    input_image = input_image[:, :, ::-1].copy()
    image = quantize_to_palette(_image=input_image, _palette=palette)
    out = samples_dir + "/" + prompt.replace(" ", "_") + "_" + str(seed) + ".scr"
    if not os.path.exists(out):
        byte_buffer = [0] * 0x1800           # bitmap area: 6144 bytes
        attr_buffer = [0b00111000] * 0x300   # attribute area: 768 bytes, default white PAPER / black INK
        for y in range(0, image_height, 8):
            for x in range(0, image_width, 8):
                px = x // 8
                py = y // 8
                palette = collect_char_colors(image, x, y)
                byte_index = (py // 8) * 0x800 + (py % 8) * 32 + px
                for cy in range(8):
                    # Pack the 8 pixels of this row into one byte, MSB first;
                    # a set bit means INK, i.e. any colour other than the cell's PAPER.
                    byte = 0
                    for cx in range(8):
                        byte = byte * 2
                        pixel = list(image.getpixel((x + cx, y + cy)))
                        if palette[0] != pixel:
                            byte = byte + 1
                    byte_buffer[byte_index] = byte
                    byte_index = byte_index + 0x100
                attr = palette_to_attr(palette)
                attr_buffer[py * 32 + px] = attr
        with open(out, 'wb') as scr:
            scr.write(bytearray(byte_buffer))
            scr.write(bytearray(attr_buffer))
    return [image, out]

iface = gr.Interface(fn=generate,
                     title="ZX-Spectrum-inspired image generator",
                     inputs=["text", "number"],
                     outputs=["image", "file"],
                     examples=[["Cute cat", 123], ["Solar system", 123], ["Disco ball", 123]])
iface.launch()