import gradio as gr
import torch
from diffusers.models import UNet2DModel
from huggingface_hub import hf_hub_download
from oadg.sampling import (
    evaluate_entropy,
    initialize_empty_realizations_and_paths,
    make_conditional_paths_and_realization,
    sample,
)

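# Canvas and sample resolution (matches the 32x32 MNIST checkpoint),
# number of samples drawn per request, and the inference device.
image_size = 32
batch_size = 1
device = 'cpu'

# Download the pretrained weights from the Hugging Face Hub (cached locally after the first call).
path = hf_hub_download(repo_id="porestar/oadg_mnist_32", filename="model.pt")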

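# UNet backbone. These hyperparameters must match the ones used to train the
# checkpoint; otherwise load_state_dict() fails on missing or mismatched weights.
model = UNet2DModel(
    sample_size=32,
    in_channels=2,
    out_channels=2,
    layers_per_block=2,
    block_out_channels=(64, 64, 128, 128),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # self-attention at the third resolution level
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",  # mirrors the attention block on the way down
        "UpBlock2D",
        "UpBlock2D",
    ),
)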

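# Load the weights onto the CPU first, then move the model to the target device
# and switch it to inference mode.
model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
model = model.to(device)
model.eval()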


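def sample_image(img):
    """Sample one image, conditioned on the canvas drawing if there is one, and its entropy map."""
    if img is None:
        # No drawing: start from an empty realization (unconditional sampling).
        idx_start, random_paths, realization = initialize_empty_realizations_and_paths(
            batch_size, image_size, image_size, device=device)
    else:
        # Binarize the drawing: every non-zero pixel is treated as observed conditioning data.
        img = (img > 0).astype(int)
        idx_start, random_paths, realization = make_conditional_paths_and_realization(
            img, batch_size=batch_size, device=device)

    # Complete the realization with the model and scale it to the 0-255 grayscale range.
    sampled = sample(model, batch_size=batch_size, image_size=image_size,
                     realization=realization, idx_start=idx_start, random_paths=random_paths, device=device)
    sampled = sampled.reshape(image_size, image_size) * 255

    # Per-pixel entropy predicted by the model for the same conditioning data.
    entropy = evaluate_entropy(model, batch_size=batch_size, image_size=image_size,
                               realization=realization, idx_start=idx_start, random_paths=random_paths, device=device)
    entropy = (entropy.reshape(image_size, image_size) * 255).astype(int)

    # Order matches the `outputs` list of the interface below.
    return entropy, sampled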


# Note: `source`, `shape`, and `invert_colors` are Gradio 3.x arguments of gr.Image;
# Gradio 4 removed them, so this app needs a Gradio 3.x release.
img = gr.Image(image_mode="L", source="canvas", shape=(image_size, image_size), invert_colors=True, label="Drawing Canvas")
out_realization = gr.Image(image_mode="L", shape=(image_size, image_size), invert_colors=True, label="Sample Realization")
out_entropy = gr.Image(image_mode="L", shape=(image_size, image_size), invert_colors=True, label="Entropy of Drawn Data")

demo = gr.Interface(
    fn=sample_image,
    inputs=img,
    outputs=[out_entropy, out_realization],
    title="Order-Agnostic Autoregressive Diffusion MNIST Demo",
    description=(
        "Draw on the canvas to sample an image conditioned on your strokes, "
        "or leave it empty to sample unconditionally. "
        "Outputs a randomly sampled realization and the trained model's predicted "
        "entropy for the conditioning data."
    ),
)


demo.launch()