porestar commited on
Commit
a4048c1
·
1 Parent(s): 9461745

Create new file

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers.models import UNet2DModel
4
+ from huggingface_hub import hf_hub_download
5
+ from oadg.sampling import sample, make_conditional_paths_and_realization, initialize_empty_realizations_and_paths
6
+ from oadg.sampling import evaluate_entropy
7
+
8
+ image_size = 32
9
+ batch_size = 1
10
+ device = 'cpu'
11
+
12
+ path = hf_hub_download(repo_id="porestar/oadg_mnist_32", filename="model.pt")
13
+
14
+ model = UNet2DModel(
15
+ sample_size=32,
16
+ in_channels=2,
17
+ out_channels=2,
18
+ layers_per_block=2,
19
+ block_out_channels=(64, 64, 128, 128),
20
+ down_block_types=(
21
+ "DownBlock2D",
22
+ "DownBlock2D",
23
+ "AttnDownBlock2D",
24
+ "DownBlock2D",
25
+ ),
26
+ up_block_types=(
27
+ "UpBlock2D",
28
+ "AttnUpBlock2D",
29
+ "UpBlock2D",
30
+ "UpBlock2D",
31
+ ),
32
+ )
33
+
34
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
35
+
36
+ model = model.to(device)
37
+
38
+
39
+ def sample_image(img):
40
+ if img is None:
41
+ idx_start, random_paths, realization = initialize_empty_realizations_and_paths(batch_size, image_size, image_size, device=device)
42
+ else:
43
+ img = (img > 0).astype(int)
44
+ idx_start, random_paths, realization = make_conditional_paths_and_realization(img, batch_size=batch_size, device=device)
45
+
46
+ img = sample(model, batch_size=batch_size, image_size=image_size,
47
+ realization=realization, idx_start=idx_start, random_paths=random_paths, device=device)
48
+ img = img.reshape(image_size, image_size) * 255
49
+
50
+ entropy = evaluate_entropy(model, batch_size=batch_size, image_size=image_size,
51
+ realization=realization, idx_start=idx_start, random_paths=random_paths, device=device)
52
+ entropy = (entropy.reshape(image_size, image_size) * 255).astype(int)
53
+
54
+ return entropy, img
55
+
56
+
57
+ img = gr.Image(image_mode="L", source="canvas", shape=(image_size, image_size), invert_colors=True, label="Drawing Canvas")
58
+ out_realization = gr.Image(image_mode="L", shape=(image_size, image_size), invert_colors=True, label="Sample Realization")
59
+ out_entropy = gr.Image(image_mode="L", shape=(image_size, image_size), invert_colors=True, label="Entropy of Drawn Data")
60
+
61
+ demo = gr.Interface(fn=sample_image, inputs=img, outputs=[out_entropy, out_realization],
62
+ title="Order Agnostic Autoregressive Diffusion MNIST Demo",
63
+ description="""Sample conditional or unconditional images by drawing into the canvas.
64
+ Outputs a random sampled realization and predicted entropy under the trained model for the conditioning data.""")
65
+
66
+
67
+ demo.launch()