kuko6 commited on
Commit
6ccac38
·
1 Parent(s): c583015

updated image inputs

Browse files
Files changed (2) hide show
  1. app.py +53 -24
  2. data/content/cat.jpg +3 -0
app.py CHANGED
@@ -15,48 +15,75 @@ def denorm_img(img: torch.Tensor):
15
 
16
 
17
  def main(inp1, inp2, alph, out_size=256):
 
 
 
18
  model = Model()
19
- model.load_state_dict(torch.load("models/model_puddle.pt", map_location=torch.device(device)))
20
  model.eval()
21
 
22
  model.alpha = alph
23
-
24
- style = TF.to_tensor(inp1)
25
- content = TF.to_tensor(inp2)
26
 
27
  norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
28
- transform = transforms.Compose(
29
- [transforms.Resize(out_size, antialias=True), transforms.CenterCrop(out_size)]
30
- )
31
 
32
  style, content = norm(style), norm(content)
33
  style, content = transform(style), transform(content)
34
 
35
  style, content = style.unsqueeze(0).to(device), content.unsqueeze(0).to(device)
36
-
37
  out = model(content, style)
38
 
39
  return denorm_img(out[0].detach()).permute(1, 2, 0).numpy()
40
 
 
 
41
 
42
  with gr.Blocks() as demo:
43
  gr.Markdown("# Style Transfer with AdaIN")
44
- with gr.Row(variant="compact"):
45
- inp1 = gr.Image(type="pil", sources=["upload", "clipboard"], label="Style")
46
- inp2 = gr.Image(type="pil", sources=["upload", "clipboard"], label="Content")
47
- out = gr.Image(type="numpy", label="Output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  with gr.Row():
49
- out_size = (
50
- gr.Dropdown(
51
- choices=[256, 512],
52
- value=256,
53
- multiselect=False,
54
- interactive=True,
55
- allow_custom_value=True,
56
- label="Output size",
57
- info="Size of the output image",
58
- ),
59
  )
 
 
 
60
  alph = gr.Slider(0, 1, value=1, label="Alpha", info="How much to change the original image", interactive=True, scale=3)
61
 
62
  with gr.Row():
@@ -74,13 +101,15 @@ with gr.Blocks() as demo:
74
  gr.Markdown("## Content Examples")
75
  gr.Examples(
76
  examples=[
77
- os.path.join(os.path.dirname(__file__), "data/content/bear.jpg"),
 
78
  os.path.join(os.path.dirname(__file__), "data/content/cow.jpg"),
79
  os.path.join(os.path.dirname(__file__), "data/content/ducks.jpg"),
80
  ],
81
  inputs=inp2,
82
  )
 
83
  btn = gr.Button("Run")
84
- btn.click(fn=main, inputs=[inp1, inp2, alph, out_size[0]], outputs=out)
85
 
86
  demo.launch()
 
15
 
16
 
17
  def main(inp1, inp2, alph, out_size=256):
18
+ # print("inp1 ", inp1)
19
+ # print("inp2 ", inp2)
20
+
21
  model = Model()
22
+ model.load_state_dict(torch.load("./models/model_puddle.pt", map_location=torch.device(device)))
23
  model.eval()
24
 
25
  model.alpha = alph
26
+ style = TF.to_tensor(inp1["composite"])
27
+ content = TF.to_tensor(inp2["composite"])
 
28
 
29
  norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
30
+ transform = transforms.Compose([
31
+ transforms.Resize(out_size, antialias=True)
32
+ ])
33
 
34
  style, content = norm(style), norm(content)
35
  style, content = transform(style), transform(content)
36
 
37
  style, content = style.unsqueeze(0).to(device), content.unsqueeze(0).to(device)
 
38
  out = model(content, style)
39
 
40
  return denorm_img(out[0].detach()).permute(1, 2, 0).numpy()
41
 
42
+ def update_crop_size(crop_size):
43
+ return gr.update(crop_size=(crop_size, crop_size))
44
 
45
  with gr.Blocks() as demo:
46
  gr.Markdown("# Style Transfer with AdaIN")
47
+ with gr.Row(variant="compact", equal_height=False):
48
+ inp1 = gr.ImageEditor(
49
+ type="pil",
50
+ sources=["upload", "clipboard"],
51
+ crop_size=(256, 256),
52
+ eraser=False,
53
+ brush=False,
54
+ layers=False,
55
+ label="Style",
56
+ image_mode="RGB",
57
+ transforms="crop",
58
+ canvas_size=(512, 512)
59
+ )
60
+ inp2 = gr.ImageEditor(
61
+ type="pil",
62
+ sources=["upload", "clipboard"],
63
+ crop_size=(256, 256),
64
+ eraser=False,
65
+ brush=False,
66
+ layers=False,
67
+ label="Content",
68
+ image_mode="RGB",
69
+ transforms="crop",
70
+ canvas_size=(512, 512)
71
+ )
72
+ out = gr.Image(type="pil", label="Output")
73
+
74
  with gr.Row():
75
+ out_size = gr.Dropdown(
76
+ choices=[256, 512],
77
+ value=256,
78
+ multiselect=False,
79
+ interactive=True,
80
+ allow_custom_value=True,
81
+ label="Output size",
82
+ info="Size of the output image"
 
 
83
  )
84
+ out_size.change(fn=update_crop_size, inputs=out_size, outputs=inp1)
85
+ out_size.change(fn=update_crop_size, inputs=out_size, outputs=inp2)
86
+
87
  alph = gr.Slider(0, 1, value=1, label="Alpha", info="How much to change the original image", interactive=True, scale=3)
88
 
89
  with gr.Row():
 
101
  gr.Markdown("## Content Examples")
102
  gr.Examples(
103
  examples=[
104
+ # os.path.join(os.path.dirname(__file__), "data/content/bear.jpg"),
105
+ os.path.join(os.path.dirname(__file__), "data/content/cat.jpg"),
106
  os.path.join(os.path.dirname(__file__), "data/content/cow.jpg"),
107
  os.path.join(os.path.dirname(__file__), "data/content/ducks.jpg"),
108
  ],
109
  inputs=inp2,
110
  )
111
+
112
  btn = gr.Button("Run")
113
+ btn.click(fn=main, inputs=[inp1, inp2, alph, out_size], outputs=out)
114
 
115
  demo.launch()
data/content/cat.jpg ADDED

Git LFS Details

  • SHA256: e566ddb954eece47c0c396eb0729484a385178d83eb2152774db72d905ed4d57
  • Pointer size: 131 Bytes
  • Size of remote file: 348 kB