ebylmz commited on
Commit
70ea298
·
verified ·
1 Parent(s): 7f60893

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+
4
+ from nst.transform_net import TransformNet
5
+ from nst.inference import stylize
6
+ from nst.config import get_style_configs
7
+
8
+ def stylize_pil_image(content_image, style_name, models):
9
+ return stylize(models[style_name], content_image, scaling_width=1024)
10
+
11
+ def select_style(selection: gr.SelectData):
12
+ """
13
+ Style Selector Callback
14
+ """
15
+ return selection.value['caption'] # returns the style name (e.g., "Starry Night")
16
+
17
+ def create_app():
18
+ with gr.Blocks() as app:
19
+ gr.Markdown("## 🎨 Neural Style Transfer")
20
+
21
+ with gr.Row():
22
+ with gr.Column():
23
+ content_input = gr.Image(label="Upload your content image", type="pil", height=500)
24
+
25
+ style_gallery = gr.Gallery(
26
+ label="Choose a style by clicking the image",
27
+ value=thumbnails,
28
+ show_label=True,
29
+ interactive=False,
30
+ columns=len(thumbnails),
31
+ # rows=2,
32
+ height=250
33
+ )
34
+
35
+ selected_style = gr.State(value='Starry Night') # default
36
+ style_gallery.select(fn=select_style, outputs=selected_style)
37
+
38
+ submit_btn = gr.Button("Stylize")
39
+
40
+ with gr.Column():
41
+ result_output = gr.Image(label="Stylized Result", type="pil", height=500)
42
+
43
+ submit_btn.click(
44
+ fn=lambda img, style: stylize_pil_image(img, style, models),
45
+ inputs=[content_input, selected_style],
46
+ outputs=result_output
47
+ )
48
+
49
+ return app
50
+
51
+
52
+ style_configs = get_style_configs()
53
+ thumbnails = [(cfg.style_image_path, cfg.style_name) for cfg in style_configs] # (img_path, caption)
54
+
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+ models = {}
57
+ for cfg in style_configs:
58
+ model = TransformNet().to(device)
59
+ model.load_state_dict(torch.load(cfg.model_path, map_location=device))
60
+ model.eval()
61
+ models[cfg.style_name] = model
62
+
63
+ app = create_app()
64
+ app.launch(debug=True, share=True)