mehdidc commited on
Commit
ddfc4d0
·
1 Parent(s): 7155638

add ksparse

Browse files
Files changed (2) hide show
  1. .gitattributes +1 -0
  2. app.py +19 -7
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  convae.th filter=lfs diff=lfs merge=lfs -text
36
  deep_convae.th filter=lfs diff=lfs merge=lfs -text
 
 
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  convae.th filter=lfs diff=lfs merge=lfs -text
36
  deep_convae.th filter=lfs diff=lfs merge=lfs -text
37
+ fc_sparse.th filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torchvision
3
  import gradio as gr
@@ -5,13 +6,16 @@ from PIL import Image
5
  from cli import iterative_refinement
6
  from viz import grid_of_images_default
7
  models = {
8
- "convae": torch.load("convae.th", map_location="cpu"),
9
- "deep_convae": torch.load("deep_convae.th", map_location="cpu"),
 
10
  }
11
- def gen(md, model, seed, nb_iter, nb_samples, width, height):
12
  torch.manual_seed(int(seed))
13
  bs = 64
14
- model = models[model]
 
 
15
  samples = iterative_refinement(
16
  model,
17
  nb_iter=int(nb_iter),
@@ -19,20 +23,28 @@ def gen(md, model, seed, nb_iter, nb_samples, width, height):
19
  w=int(width), h=int(height), c=1,
20
  batch_size=bs,
21
  )
22
- grid = grid_of_images_default(samples.reshape((samples.shape[0]*samples.shape[1], int(height), int(width), 1)).numpy(), shape=(samples.shape[0], samples.shape[1]))
 
 
 
 
 
 
23
  grid = (grid*255).astype("uint8")
24
  return Image.fromarray(grid)
25
 
26
  text = """
27
- Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`)
28
 
29
  These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
 
 
30
  """
31
  iface = gr.Interface(
32
  fn=gen,
33
  inputs=[
34
  gr.Markdown(text),
35
- gr.Dropdown(list(models.keys()), value="deep_convae"), gr.Number(value=0), gr.Number(value=20), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28)
36
  ],
37
  outputs="image"
38
  )
 
1
+ import math
2
  import torch
3
  import torchvision
4
  import gradio as gr
 
6
  from cli import iterative_refinement
7
  from viz import grid_of_images_default
8
  models = {
9
+ "ConvAE": torch.load("convae.th", map_location="cpu"),
10
+ "Deep ConvAE": torch.load("deep_convae.th", map_location="cpu"),
11
+ "Dense K-Sparse": torch.load("fc_sparse.th", map_location="cpu"),
12
  }
13
+ def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg):
14
  torch.manual_seed(int(seed))
15
  bs = 64
16
+ model = models[model_name]
17
+ if model == "Dense K-Sparse":
18
+ model.nb_active = nb_active
19
  samples = iterative_refinement(
20
  model,
21
  nb_iter=int(nb_iter),
 
23
  w=int(width), h=int(height), c=1,
24
  batch_size=bs,
25
  )
26
+ if only_last:
27
+ s = int(math.sqrt((nb_samples)))
28
+ grid = grid_of_images_default(samples[-1].numpy(), shape=(s, s))
29
+ else:
30
+ grid = grid_of_images_default(samples.reshape((samples.shape[0]*samples.shape[1], int(height), int(width), 1)).numpy(), shape=(samples.shape[0], samples.shape[1]))
31
+ if not black_bg:
32
+ grid = 1 - grid
33
  grid = (grid*255).astype("uint8")
34
  return Image.fromarray(grid)
35
 
36
  text = """
37
+ Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`), Dense K-Sparse model (from [here](https://openreview.net/forum?id=r1QXQkSYg))
38
 
39
  These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
40
+
41
+ NB: `nb_active` is only used for the Dense K-Sparse, specifying nb of activations to keep in the last layer.
42
  """
43
  iface = gr.Interface(
44
  fn=gen,
45
  inputs=[
46
  gr.Markdown(text),
47
+ gr.Dropdown(list(models.keys()), value="Deep ConvAE"), gr.Number(value=0), gr.Number(value=25), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28),gr.Slider(minimum=0,maximum=800, value=800, step=1), gr.Checkbox(value=False, label="Only show last iteration"), gr.Checkbox(value=True, label="Black background")
48
  ],
49
  outputs="image"
50
  )