File size: 3,123 Bytes
7999e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d0f78
 
ee2aef9
7999e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2612f87
7999e5a
 
2612f87
 
7999e5a
2612f87
 
7999e5a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import streamlit as st
import gdown
from packaging.version import Version

from infer_func import convert

# Absolute directory containing this script; every bundled asset and weight
# file is addressed relative to it.
ROOT = os.path.dirname(os.path.abspath(__file__))

# Bundled example images, grouped by role and keyed by the label shown in the
# example picker.
EXAMPLES = dict(
    content={
        'Brad Pitt': f'{ROOT}/examples/content/brad_pitt.jpg',
    },
    style={
        'Flower of Life': f'{ROOT}/examples/style/flower_of_life.jpg',
    },
)

# Google Drive download links for the pretrained weights.
_DRIVE_URL_PREFIX = 'https://drive.google.com/uc?id='
VGG_WEIGHT_URL = _DRIVE_URL_PREFIX + '1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx'
DECODER_WEIGHT_URL = _DRIVE_URL_PREFIX + '18JpLtMOapA-vwBz-LRomyTl24A9GwhTF'

# Local paths the downloaded weights are written to.
VGG_WEIGHT_FILENAME = f'{ROOT}/vgg.pth'
DECODER_WEIGHT_FILENAME = f'{ROOT}/decoder.pth'


@st.cache
def download_models():
    """Download the VGG and decoder weight files that are not yet on disk.

    Each file is fetched with ``gdown`` from its Google Drive URL and written
    to the matching ``*_WEIGHT_FILENAME`` path, with a spinner shown while the
    transfer runs. A weight file that already exists is left untouched, so
    restarting the app does not fetch the same weights again.
    """
    downloads = (
        ("Downloading VGG weights...", VGG_WEIGHT_URL, VGG_WEIGHT_FILENAME),
        ("Downloading Decoder weights...", DECODER_WEIGHT_URL, DECODER_WEIGHT_FILENAME),
    )
    for message, url, filename in downloads:
        # Skip the network round-trip when the weights are already present.
        if os.path.isfile(filename):
            continue
        with st.spinner(text=message):
            gdown.download(url, output=filename)


def image_getter(image_kind):
    """Render the widgets for choosing one image and return the selection.

    Args:
        image_kind: Key into ``EXAMPLES`` selecting which example set to
            offer. It is also used as the prefix of every widget key created
            here, so calls made with different kinds do not collide.

    Returns:
        The file path of the chosen example image, or whatever the upload /
        camera widget returns for the current run, or ``None`` when no
        option branch matches.
    """
    options = ['Use Example Image', 'Upload Image']

    # Camera capture is currently switched off; the original gate was:
    # if Version(st.__version__) >= Version('1.4.0'):
    #    options.append('Open Camera')

    # Every widget below gets its own key. Reusing the bare ``image_kind`` for
    # all of them gave the two select boxes the same key, which Streamlit
    # rejects as a duplicate widget.
    option = st.selectbox(
        'Choose Image',
        options, key=f'{image_kind}_source')

    if option == 'Use Example Image':
        image_key = st.selectbox(
            'Choose from examples',
            EXAMPLES[image_kind], key=f'{image_kind}_example')
        return EXAMPLES[image_kind][image_key]

    if option == 'Upload Image':
        return st.file_uploader(
            "Upload an image",
            type=['png', 'jpg', 'PNG', 'JPG', 'JPEG'],
            key=f'{image_kind}_upload')

    if option == 'Open Camera':
        return st.camera_input('', key=f'{image_kind}_camera')

    return None


def main():
    """Build the demo page and run style transfer when the user asks for it.

    Layout: the first column holds the content/style image pickers, the
    second previews both images, and the third holds the controls and the
    stylized result.
    """
    st.set_page_config(layout="wide")
    st.header('Adaptive Instance Normalization demo based on '
              '[2022-AdaIN-pytorch](https://github.com/media-comp/2022-AdaIN-pytorch)')

    download_models()

    col1, col2, col3 = st.columns((3, 4, 4))
    with col1:
        st.subheader('Content Image')
        content = image_getter('content')
        st.subheader('Style Image')
        style = image_getter('style')

    with col2:
        # Shown while an image has not been provided yet. Anchored at ROOT
        # like the other bundled assets so it does not depend on the
        # directory the app was launched from.
        placeholder = ROOT + '/examples/img.png'
        st.image(content if content is not None else placeholder,
                 width=None, caption='Content Image')
        st.image(style if style is not None else placeholder,
                 width=None, caption='Style Image')

    with col3:
        color_control = st.checkbox('Preserve content image color')
        alpha = st.slider('Strength of style transfer', 0.0, 1.0, 1.0, 0.01)
        process = st.button('Stylize')

    # Default result path; replaced by whatever convert() returns when a new
    # stylization is run in this pass.
    output_image = 'output.png'
    if content is not None and style is not None and process:
        with st.spinner('Processing...'):
            output_image = convert(content, style, VGG_WEIGHT_FILENAME,
                                   DECODER_WEIGHT_FILENAME, alpha, color_control)

    # Only display a result when the file is actually there.
    if os.path.exists(output_image):
        with col3:
            st.image(output_image, width=None, caption='Stylized Image')


if __name__ == '__main__':
    main()