File size: 3,928 Bytes
d49dda4
 
 
 
 
 
184c926
d49dda4
 
184c926
 
 
 
941019b
 
d49dda4
 
 
 
 
 
 
 
 
 
 
 
 
184c926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7241f9b
d49dda4
941019b
7241f9b
d49dda4
 
 
 
 
184c926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d49dda4
 
184c926
 
7241f9b
184c926
7241f9b
d49dda4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from diffusers import DiffusionPipeline
import cv2
import torch
import numpy as np
from PIL import Image
import os

# Model id (or path) handed to DiffusionPipeline.from_pretrained; presumably a
# checkpoint tuned for ZX Spectrum-style artwork, judging by its name.
model = "shadowlamer/sd-zxspectrum-model-256"
# Output size requested from the pipeline; the .scr encoder in generate()
# walks it in 8x8 character cells (32 x 24 cells for 256 x 192).
image_width, image_height = 256, 192
# Directory where generate() writes the .scr files.
samples_dir = "/tmp"

# Loaded once at import time, with the safety checker explicitly disabled.
pipe = DiffusionPipeline.from_pretrained(
    model,
    safety_checker=None,
    requires_safety_checker=False,
)


# Borrowed from here: https://stackoverflow.com/a/73667318
def quantize_to_palette(_image, _palette):
    """Replace every pixel of ``_image`` with its nearest colour from ``_palette``.

    Args:
        _image: array whose last axis holds 3 channels. generate() passes it in
            BGR order, and the result is converted BGR -> RGB below.
        _palette: (k, 3) array of palette colours in the same channel order
            as ``_image``.

    Returns:
        A PIL ``Image`` built from the quantized pixels after a BGR -> RGB
        conversion.
    """
    pixels = _image.reshape(-1, 3).astype(np.float32)
    samples = _palette.astype(np.float32)
    knn = cv2.ml.KNearest_create()
    # One training sample per palette entry, labelled with its index. int32 is
    # one of the response types OpenCV's ml module accepts; with k=1 the
    # predicted label is simply the nearest entry's index.
    knn.train(samples, cv2.ml.ROW_SAMPLE, np.arange(len(_palette), dtype=np.int32))
    _, _, neighbours, _ = knn.findNearest(pixels, 1)
    # Single vectorized fancy-index instead of a per-pixel Python loop.
    quantized = _palette[neighbours.ravel().astype(int)].reshape(_image.shape)
    return Image.fromarray(cv2.cvtColor(quantized.astype(np.uint8), cv2.COLOR_BGR2RGB))


def collect_char_colors(image, _x, _y):
    """List the distinct colours of the 8x8 cell whose top-left is (_x, _y).

    Colours come back as lists, most frequent first. Colours with equal counts
    keep the order in which they were first met, scanning row by row.
    """
    tally = {}
    cell_pixels = (image.getpixel((_x + dx, _y + dy)) for dy in range(8) for dx in range(8))
    for pixel in cell_pixels:
        tally[pixel] = tally.get(pixel, 0) + 1
    # sorted() is stable even with reverse=True, so ties keep first-seen order.
    ranked = sorted(tally, key=tally.__getitem__, reverse=True)
    return [list(colour) for colour in ranked]


def palette_to_attr(_palette):
    """Pack a cell palette into a single attribute byte.

    ``_palette[0]`` is the paper colour, mapped to bits 0x10 (red), 0x20 (green)
    and 0x08 (blue). ``_palette[1]``, when present, is the ink colour, mapped to
    bits 0x02 (red), 0x04 (green) and 0x01 (blue). A channel counts as "on" when
    it is > 0, and entries past the second are ignored. Bit 0x40 (presumably the
    Spectrum BRIGHT flag) is always set, except that an empty palette yields 0x00.
    """
    if len(_palette) == 0:
        return 0x00
    # (red, green, blue) bit masks for the paper layer, then the ink layer.
    channel_bits = ((0x10, 0x20, 0x08), (0x02, 0x04, 0x01))
    attr = 0x40
    for colour, (red_bit, green_bit, blue_bit) in zip(_palette[:2], channel_bits):
        red, green, blue = colour[0], colour[1], colour[2]
        attr += (red_bit if red > 0 else 0) + (green_bit if green > 0 else 0) + (blue_bit if blue > 0 else 0)
    return attr


def _safe_file_stem(text, max_length=100):
    """Reduce free-form ``text`` to a short, single, ASCII-only path component.

    Anything other than ASCII letters, digits, ``-`` and ``_`` becomes ``_``
    (so spaces still map to ``_``). This rules out path separators and ``..``
    traversal. The result is truncated to ``max_length`` characters so it stays
    well under filesystem name limits.
    """
    stem = "".join(
        ch if (ch.isalnum() and ord(ch) < 128) or ch in "-_" else "_"
        for ch in text
    )
    return stem[:max_length]


def _encode_scr(image):
    """Encode ``image`` as the bytes of a ``.scr`` screen dump.

    The image is walked in 8x8 cells. Within each cell, the most frequent colour
    (``collect_char_colors`` order) is the paper and every pixel of any other
    colour is set as an ink bit. ``palette_to_attr`` turns the cell palette into
    its attribute byte.

    Bitmap addressing: 0x800 bytes per group of 8 character rows, 32 bytes per
    character row within a group, and 0x100 bytes between successive pixel
    lines of one cell. The 0x1800-byte bitmap is followed by 0x300 attributes.
    """
    bitmap = bytearray(0x1800)
    # Default attribute; every cell is overwritten by the loop below.
    attributes = bytearray([0b00111000] * 0x300)
    for y in range(0, image_height, 8):
        for x in range(0, image_width, 8):
            cell_x, cell_y = x // 8, y // 8
            cell_palette = collect_char_colors(image, x, y)
            paper = cell_palette[0]
            offset = (cell_y // 8) * 0x800 + (cell_y % 8) * 32 + cell_x
            for line in range(8):
                bits = 0
                for dx in range(8):
                    bits <<= 1
                    if list(image.getpixel((x + dx, y + line))) != paper:
                        bits |= 1
                bitmap[offset + line * 0x100] = bits
            attributes[cell_y * 32 + cell_x] = palette_to_attr(cell_palette)
    return bytes(bitmap + attributes)


def generate(prompt, seed, steps):
    """Generate a palette-quantized image and write its ``.scr`` file.

    Args:
        prompt: text prompt for the diffusion pipeline. A sanitised form of it
            is also used in the output file name.
        seed: RNG seed, converted with ``int()``.
        steps: number of inference steps, converted with ``int()``.

    Returns:
        ``[image, path]``: the quantized PIL image and the path of the
        ``.scr`` file written inside ``samples_dir``.
    """
    seed = int(seed)
    generator = torch.Generator("cpu").manual_seed(seed)
    raw_image = pipe(prompt, height=image_height, width=image_width,
                     num_inference_steps=int(steps), generator=generator).images[0]
    zx_palette = np.array(
        [[0, 0, 0], [0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0], [255, 0, 255], [255, 255, 0],
         [255, 255, 255]])
    # Force 3 channels, then flip RGB -> BGR as quantize_to_palette expects.
    bgr_image = np.array(raw_image.convert("RGB"))[:, :, ::-1].copy()
    image = quantize_to_palette(_image=bgr_image, _palette=zx_palette)

    # The prompt is untrusted user input: keep it to one bounded path component
    # so it cannot escape samples_dir or break filename length limits.
    out = os.path.join(samples_dir, _safe_file_stem(prompt) + "_" + str(seed) + ".scr")
    # Always (re)write the file. An existing file with this name may come from a
    # call with different steps (or a colliding prompt) and would not match
    # ``image``.
    with open(out, "wb") as scr:
        scr.write(_encode_scr(image))

    return [image, out]


# Gradio UI: text prompt plus numeric seed and step count in; the quantized
# image and the generated .scr file out.
example_prompts = ("Cute cat", "Solar system", "Disco ball")
iface = gr.Interface(
    fn=generate,
    title="ZX-Spectrum inspired images generator ",
    inputs=["text", "number", "number"],
    outputs=["image", "file"],
    examples=[[example, 123, 20] for example in example_prompts],
)
iface.launch()