subatomicseer's picture
Update app.py
ee2aef9
import os
import streamlit as st
import gdown
from packaging.version import Version
from infer_func import convert
# Directory containing this script; every bundled asset is resolved against it
# so the app works regardless of the current working directory.
ROOT = os.path.dirname(os.path.abspath(__file__))

# Example images offered by the "Use Example Image" picker.
# Outer key: image kind ('content' or 'style'); inner key: display name.
EXAMPLES = {
    'content': {
        'Brad Pitt': f'{ROOT}/examples/content/brad_pitt.jpg',
    },
    'style': {
        'Flower of Life': f'{ROOT}/examples/style/flower_of_life.jpg',
    },
}

# Google Drive direct-download links for the pretrained weights (fetched with gdown).
VGG_WEIGHT_URL = 'https://drive.google.com/uc?id=1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx'
DECODER_WEIGHT_URL = 'https://drive.google.com/uc?id=18JpLtMOapA-vwBz-LRomyTl24A9GwhTF'

# Local destinations for the downloaded weights.
VGG_WEIGHT_FILENAME = f'{ROOT}/vgg.pth'
DECODER_WEIGHT_FILENAME = f'{ROOT}/decoder.pth'
# Run-once cache decorator. ``st.cache_resource`` is the modern API; the legacy
# ``st.cache`` (deprecated in newer Streamlit) is only used when it is absent.
_cache_once = getattr(st, 'cache_resource', None) or st.cache


def _download_if_missing(url, filename, what):
    """Fetch ``url`` into ``filename`` with gdown unless the file already exists.

    Args:
        url: Google Drive download URL.
        filename: Local destination path.
        what: Human-readable name used in the spinner / error text.

    Raises:
        RuntimeError: if ``filename`` still does not exist after the download
            attempt (gdown failures otherwise surface much later in ``convert``).
    """
    if os.path.exists(filename):
        # Already on disk (e.g. from an earlier server process); don't re-fetch.
        return
    with st.spinner(text=f"Downloading {what} weights..."):
        gdown.download(url, output=filename)
    if not os.path.exists(filename):
        raise RuntimeError(
            f"Failed to download {what} weights from {url} to {filename}")


@_cache_once
def download_models():
    """Ensure the VGG encoder and decoder weights exist locally.

    Cached so Streamlit reruns don't repeat the work; files that are already
    present are skipped.
    """
    _download_if_missing(VGG_WEIGHT_URL, VGG_WEIGHT_FILENAME, "VGG")
    _download_if_missing(DECODER_WEIGHT_URL, DECODER_WEIGHT_FILENAME, "Decoder")
def image_getter(image_kind, allow_camera=False):
    """Render widgets that let the user choose an image of ``image_kind``.

    Args:
        image_kind: Key into ``EXAMPLES`` (e.g. ``'content'`` or ``'style'``);
            also used to namespace the widget keys so that calling this once
            per kind does not produce colliding widgets.
        allow_camera: Also offer an 'Open Camera' option. It is only added when
            the installed Streamlit is >= 1.4.0, the version that introduced
            ``st.camera_input``. Off by default.

    Returns:
        The path of the chosen example image, or the value returned by
        ``st.file_uploader`` / ``st.camera_input`` (``None`` until the user has
        uploaded or captured something).
    """
    options = ['Use Example Image', 'Upload Image']
    if allow_camera and Version(st.__version__) >= Version('1.4.0'):
        options.append('Open Camera')

    # Each widget gets its own key: recent Streamlit versions raise
    # DuplicateWidgetID when two widgets in the same run share a key.
    option = st.selectbox('Choose Image', options, key=image_kind)

    if option == 'Use Example Image':
        examples = EXAMPLES[image_kind]
        image_key = st.selectbox(
            'Choose from examples',
            list(examples), key=f'{image_kind}_example')
        return examples[image_key]
    if option == 'Upload Image':
        return st.file_uploader(
            "Upload an image",
            type=['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG'],
            key=f'{image_kind}_upload')
    if option == 'Open Camera':
        # Non-empty label: Streamlit warns about empty widget labels.
        return st.camera_input('Take a picture', key=f'{image_kind}_camera')
    return None
def main():
    """Lay out the AdaIN demo page and run style transfer when requested."""
    st.set_page_config(layout="wide")
    st.header('Adaptive Instance Normalization demo based on '
              '[2022-AdaIN-pytorch](https://github.com/media-comp/2022-AdaIN-pytorch)')
    download_models()

    col1, col2, col3 = st.columns((3, 4, 4))
    with col1:
        st.subheader('Content Image')
        content = image_getter('content')
        st.subheader('Style Image')
        style = image_getter('style')

    # Shown while no image is selected (e.g. before an upload). Resolved against
    # ROOT, like EXAMPLES, so it does not depend on the working directory.
    placeholder = ROOT + '/examples/img.png'
    with col2:
        st.image(content if content is not None else placeholder,
                 width=None, caption='Content Image')
        st.image(style if style is not None else placeholder,
                 width=None, caption='Style Image')

    with col3:
        color_control = st.checkbox('Preserve content image color')
        alpha = st.slider('Strength of style transfer', 0.0, 1.0, 1.0, 0.01)
        process = st.button('Stylize')

    if process:
        if content is None or style is None:
            with col3:
                st.warning('Please provide both a content image and a style image.')
        else:
            with st.spinner('Processing...'):
                output_image = convert(content, style, VGG_WEIGHT_FILENAME,
                                       DECODER_WEIGHT_FILENAME, alpha, color_control)
            # Keep the result bytes in this session's state instead of re-reading
            # a fixed path on every rerun. NOTE(review): the output path looks
            # shared between sessions (default was 'output.png'), so another
            # user's run could overwrite it — confirm against infer_func.convert.
            if os.path.exists(output_image):
                with open(output_image, 'rb') as fh:
                    st.session_state['stylized_image'] = fh.read()
            else:
                with col3:
                    st.error('Stylization did not produce an output image.')

    stylized = st.session_state.get('stylized_image')
    if stylized is not None:
        with col3:
            st.image(stylized, width=None, caption='Stylized Image')


if __name__ == '__main__':
    main()