# Hugging Face Spaces status when this file was captured: Runtime error
import os | |
import streamlit as st | |
import gdown | |
from packaging.version import Version | |
from infer_func import convert | |
# Directory containing this script. Bundled assets and downloaded weights are
# resolved against it (it is an absolute path), so they are found regardless
# of the directory the app is launched from.
ROOT = os.path.dirname(os.path.abspath(__file__))

# Bundled example images offered in the UI, keyed first by image kind
# ('content' / 'style') and then by the label shown in the example selectbox.
EXAMPLES = {
    'content': {
        'Brad Pitt': os.path.join(ROOT, 'examples', 'content', 'brad_pitt.jpg'),
    },
    'style': {
        'Flower of Life': os.path.join(ROOT, 'examples', 'style', 'flower_of_life.jpg'),
    },
}

# Google Drive links for the pretrained weights and the local files they are
# saved to by download_models().
VGG_WEIGHT_URL = 'https://drive.google.com/uc?id=1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx'
DECODER_WEIGHT_URL = 'https://drive.google.com/uc?id=18JpLtMOapA-vwBz-LRomyTl24A9GwhTF'
VGG_WEIGHT_FILENAME = os.path.join(ROOT, 'vgg.pth')
DECODER_WEIGHT_FILENAME = os.path.join(ROOT, 'decoder.pth')
def download_models():
    """Download the VGG encoder and decoder weights unless already on disk.

    Streamlit re-executes the whole script on every widget interaction, and
    this function is called from the script body, so without the existence
    check the weights would be fetched from Google Drive again on every rerun.
    """
    downloads = (
        ('VGG', VGG_WEIGHT_URL, VGG_WEIGHT_FILENAME),
        ('Decoder', DECODER_WEIGHT_URL, DECODER_WEIGHT_FILENAME),
    )
    for label, url, filename in downloads:
        if os.path.exists(filename):
            continue
        with st.spinner(text=f"Downloading {label} weights..."):
            gdown.download(url, output=filename)
def image_getter(image_kind):
    """Render the widgets that let the user pick a content or style image.

    Args:
        image_kind: 'content' or 'style'; selects the example set in EXAMPLES
            and namespaces the widget keys so both pickers can coexist.

    Returns:
        The file path of the chosen example image, or the value returned by
        the upload/camera widget. The caller treats None as "no image yet".
    """
    image = None
    options = ['Use Example Image', 'Upload Image']
    # Camera input is currently disabled; re-enable with a version guard:
    # if Version(st.__version__) >= Version('1.4.0'):
    #     options.append('Open Camera')

    # Each widget gets its own key. Streamlit rejects two widgets sharing a key
    # in one run, and this source selector is always rendered together with
    # whichever picker below is active.
    option = st.selectbox(
        'Choose Image',
        options, key=f'{image_kind}_source')
    if option == 'Use Example Image':
        image_key = st.selectbox(
            'Choose from examples',
            list(EXAMPLES[image_kind]), key=f'{image_kind}_example')
        image = EXAMPLES[image_kind][image_key]
    elif option == 'Upload Image':
        # Both letter cases are listed so .jpeg/.JPEG uploads are accepted.
        image = st.file_uploader(
            "Upload an image",
            type=['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG'],
            key=f'{image_kind}_upload')
    elif option == 'Open Camera':
        image = st.camera_input('Take a photo', key=f'{image_kind}_camera')
    return image
if __name__ == '__main__':
    st.set_page_config(layout="wide")
    st.header('Adaptive Instance Normalization demo based on '
              '[2022-AdaIN-pytorch](https://github.com/media-comp/2022-AdaIN-pytorch)')
    download_models()

    # Layout: image pickers | image previews | controls and stylized result.
    col1, col2, col3 = st.columns((3, 4, 4))

    with col1:
        st.subheader('Content Image')
        content = image_getter('content')
        st.subheader('Style Image')
        style = image_getter('style')

    # Preview shown until an image is chosen. It is resolved against ROOT, like
    # the bundled examples, so it does not depend on the working directory.
    placeholder = os.path.join(ROOT, 'examples', 'img.png')
    with col2:
        st.image(content if content is not None else placeholder,
                 width=None, caption='Content Image')
        st.image(style if style is not None else placeholder,
                 width=None, caption='Style Image')

    with col3:
        color_control = st.checkbox('Preserve content image color')
        alpha = st.slider('Strength of style transfer', 0.0, 1.0, 1.0, 0.01)
        process = st.button('Stylize')

    # NOTE(review): this fixed, cwd-relative default means any existing
    # output.png (e.g. from an earlier run, if convert writes there) is shown
    # before Stylize is pressed -- confirm that is intended.
    output_image = 'output.png'
    if process:
        if content is None or style is None:
            # Previously a silent no-op; tell the user what is missing.
            with col3:
                st.warning('Please choose both a content image and a style image.')
        else:
            print(content, style)
            with st.spinner('Processing...'):
                output_image = convert(content, style, VGG_WEIGHT_FILENAME,
                                       DECODER_WEIGHT_FILENAME, alpha, color_control)
    if os.path.exists(output_image):
        with col3:
            st.image(output_image, width=None, caption='Stylized Image')