# Import general purpose libraries
import os, re, time
import streamlit as st
import PIL
import cv2
import numpy as np
import uuid
from zipfile import ZipFile, ZIP_DEFLATED
from io import BytesIO
from random import randint
# Import util functions from deoldify
# NOTE: This must be the first call in order to work properly!
from deoldify import device
from deoldify.device_id import DeviceId
#choices: CPU, GPU0...GPU7
device.set(device=DeviceId.CPU)
from deoldify.visualize import *
# Import util functions from app_utils
from app_utils import get_model_bin
# Keys this app keeps in st.session_state; each one is created as None the
# first time the script runs in a session.
SESSION_STATE_VARIABLES = [
    'model_folder',
    'max_img_size',
    'uploaded_file_key',
    'uploaded_files',
]
for i in SESSION_STATE_VARIABLES:
    if i in st.session_state:
        continue
    st.session_state[i] = None

#### SET INPUT PARAMS ###########
# Defaults applied whenever the stored value is falsy (None on first run).
if not st.session_state['model_folder']:
    st.session_state['model_folder'] = 'models/'
if not st.session_state['max_img_size']:
    st.session_state['max_img_size'] = 800
################################
@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_dir, option):
    """Fetch the DeOldify weights for *option* (via get_model_bin) and build a colorizer.

    Args:
        model_dir: Directory the weights file is written to by ``get_model_bin``.
        option: Model variant, ``'artistic'`` or ``'stable'`` (case-insensitive).

    Returns:
        The object returned by ``get_image_colorizer`` for the chosen variant.

    Raises:
        ValueError: If *option* names neither supported variant. (Previously
            this fell through both branches and crashed with UnboundLocalError.)
    """
    choice = option.lower()
    if choice == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        weights_file = "ColorizeArtistic_gen.pth"
        artistic = True
    elif choice == 'stable':
        # NOTE(review): Dropbox "?dl=0" links normally serve a preview page rather
        # than the raw file; confirm get_model_bin copes with that (or use "?dl=1").
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
        weights_file = "ColorizeStable_gen.pth"
        artistic = False
    else:
        raise ValueError(
            f"Unknown model option {option!r}; expected 'artistic' or 'stable'"
        )
    get_model_bin(model_url, os.path.join(model_dir, weights_file))
    # NOTE(review): get_image_colorizer is not told about model_dir, so it
    # presumably loads weights from its own default location — confirm that
    # location matches model_dir when model_dir differs from 'models/'.
    return get_image_colorizer(artistic=artistic)
def resize_img(input_img, max_size):
    """Downscale an image array so its longer side is at most *max_size* pixels.

    The aspect ratio is preserved. Images already within the limit are
    returned as an unmodified copy; the input array is never modified.

    Args:
        input_img: Image array whose first two dimensions are (height, width),
            e.g. as produced by ``np.array(pil_image)``.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A new array, resized with ``cv2.resize`` when a side exceeded *max_size*.
    """
    img_height, img_width = input_img.shape[0], input_img.shape[1]
    if max(img_height, img_width) <= max_size:
        return input_img.copy()
    new_width, new_height = _fit_within(img_height, img_width, max_size)
    # cv2.resize takes the target size as (width, height), not (rows, cols).
    return cv2.resize(input_img, (new_width, new_height))


def _fit_within(height, width, max_size):
    """Return the (width, height) that scales a height x width image so its longer side is max_size."""
    if height > width:
        # Portrait: height becomes max_size, width shrinks proportionally.
        return max(1, int(width * (max_size / height))), int(max_size)
    # Landscape or square: width becomes max_size, height shrinks proportionally.
    # (The old code set height to max_size here and width to the scaled height,
    # which distorted every landscape image.) max(1, ...) keeps extreme aspect
    # ratios from truncating a side to 0 px, which cv2.resize rejects.
    return int(max_size), max(1, int(height * (max_size / width)))
def get_image_download_link(img, filename, button_text):
    """Return HTML for a link that downloads *img* encoded as a JPEG file.

    Args:
        img: PIL image to encode.
        filename: Suggested download name; its extension is replaced with
            ``.jpg`` so it matches the JPEG data actually embedded.
        button_text: Label shown on the link.

    Returns:
        The HTML string produced by ``get_button_html_code``.
    """
    # base64 was otherwise only reachable through `from deoldify.visualize import *`.
    import base64

    # Strip the digits from a random hex UUID so the id is letters only
    # (CSS selectors cannot start with a digit).
    button_id = re.sub(r'\d+', '', uuid.uuid4().hex)
    # JPEG cannot store alpha or palette data; PIL refuses e.g. RGBA as JPEG.
    if img.mode not in ('RGB', 'L'):
        img = img.convert('RGB')
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    jpeg_name = os.path.splitext(filename)[0] + '.jpg'
    return get_button_html_code(img_str, jpeg_name, 'jpeg', button_id, button_text)
def get_button_html_code(data_str, filename, filetype, button_id, button_txt='Download file'):
    """Return HTML for a button-styled download link that embeds the data inline.

    Args:
        data_str: Base64-encoded file contents, placed in a ``data:`` URI.
        filename: Name suggested to the browser via the ``download`` attribute.
            It is HTML-escaped because it can come from a user's upload.
        filetype: Subtype used in the ``data:file/<filetype>`` URI.
        button_id: Element id, also used as the CSS selector. Pass a plain
            CSS identifier; it is inserted verbatim.
        button_txt: Link label, inserted as-is, so it may contain markup.

    Returns:
        A ``<style>`` block followed by an ``<a download=...>`` element, meant to
        be rendered with ``st.markdown(..., unsafe_allow_html=True)``.
    """
    import html

    # Previously the stylesheet and the <a> element were both missing. Only the
    # bare label was returned, so nothing could actually be downloaded.
    custom_css = f"""
<style>
#{button_id} {{
    display: inline-block;
    padding: 0.25em 0.38em;
    border: 1px solid rgb(230, 234, 241);
    border-radius: 4px;
    background-color: rgb(255, 255, 255);
    color: rgb(38, 39, 48);
    text-decoration: none;
}}
#{button_id}:hover {{
    border-color: rgb(246, 51, 102);
    color: rgb(246, 51, 102);
}}
</style>
"""
    safe_filename = html.escape(filename, quote=True)
    href = custom_css + (
        f'<a download="{safe_filename}" id="{button_id}" '
        f'href="data:file/{filetype};base64,{data_str}">{button_txt}</a>'
    )
    return href
def display_single_image(uploaded_file, img_size=800):
    """Colorize one uploaded image and show the input, the output and a download link.

    Args:
        uploaded_file: Uploaded file object; must expose ``.name`` and be
            readable by ``PIL.Image.open``.
        img_size: Longest side, in pixels, the input is downscaled to before
            colorization (see ``resize_img``).

    Uses the globals ``colorizer``, ``st_title_message``, ``st_input_img``,
    ``st_output_img`` and ``st_download_button``. None of these is defined in
    this part of the file; they are presumably created further down the script.
    """
    st_title_message.markdown("**Processing your image, please wait** ⌛")
    source_name = uploaded_file.name

    # Decode, bound the resolution, then hand a PIL image to the model.
    as_array = np.array(PIL.Image.open(uploaded_file))
    model_input = PIL.Image.fromarray(resize_img(as_array, img_size))

    colorized = colorizer.plot_transformed_pil_image(
        model_input, render_factor=35, compare=False
    )

    st_input_img.image(model_input, 'Input image', use_column_width=True)
    st_output_img.image(colorized, 'Output image', use_column_width=True)

    download_html = get_image_download_link(colorized, source_name, 'Download Image')
    st_download_button.markdown(download_html, unsafe_allow_html=True)

    # Restore the idle prompt once everything is on screen.
    st_title_message.markdown("**To begin, please upload an image** 👇")
def process_multiple_images(uploaded_files, img_size=800):
    """Colorize several uploaded images and offer the results as one ZIP download.

    Progress and status are reported through the globals ``st_progress_bar``,
    ``st_title_message`` and ``st_download_button``. The model is the global
    ``colorizer``. None of these is defined in this part of the file; they are
    presumably created further down the script.

    Args:
        uploaded_files: Sequence of uploaded file objects; each must expose
            ``.name`` and be readable by ``PIL.Image.open``.
        img_size: Longest side, in pixels, each input is downscaled to before
            colorization (see ``resize_img``).
    """
    num_imgs = len(uploaded_files)
    output_images_list = []
    img_names_list = []
    st_progress_bar.progress(0)
    for idx, uploaded_file in enumerate(uploaded_files, start=1):
        st_title_message.markdown("**Processing image {}/{}. Please wait** ⌛".format(idx, num_imgs))
        img_name = uploaded_file.name
        # Open the image and bound its resolution before running the model
        pil_img = PIL.Image.open(uploaded_file)
        resized_img_rgb = resize_img(np.array(pil_img), img_size)
        resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
        # Send the image to the model
        output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
        output_images_list.append(output_pil_img)
        # Drop only the final extension: "scan.v2.png" -> "scan.v2". The former
        # split('.')[0] cut at the first dot, truncating names and letting
        # distinct uploads collide inside the ZIP.
        img_names_list.append(os.path.splitext(img_name)[0])
        st_progress_bar.progress(int((idx / num_imgs) * 100))
    # Zip the outputs in memory; zip_path is only the suggested download name.
    zip_path = 'processed_images.zip'
    zip_buf = zip_multiple_images(output_images_list, img_names_list, zip_path)
    st_download_button.download_button(
        label='Download ZIP file',
        data=zip_buf.read(),
        file_name=zip_path,
        mime="application/zip",
    )
    # Show message
    st_title_message.markdown("**Images are ready for download** 💾")
def zip_multiple_images(pil_images_list, img_names_list, dest_path):
    """Pack images into an in-memory ZIP archive, one PNG entry per image.

    Args:
        pil_images_list: Images to store; each must support
            ``save(fp, format="PNG")`` (e.g. ``PIL.Image.Image``).
        img_names_list: Base names (without extension), paired with the images
            by position. ``.png`` is appended to each; repeated names get a
            ``_1``, ``_2``, ... suffix so every entry stays unique.
        dest_path: Unused, because the archive is built in memory only. Kept
            so existing callers keep working.

    Returns:
        A ``BytesIO`` holding the ZIP data, rewound to position 0.
    """
    zip_buf = BytesIO()
    used_names = set()
    with ZipFile(zip_buf, 'w', ZIP_DEFLATED) as zip_obj:
        for pil_img, img_name in zip(pil_images_list, img_names_list):
            # Two uploads can share a stem (e.g. "a.jpg" and "a.png"). zipfile
            # only warns on duplicate entry names, and extractors typically
            # keep just one of them, so make every name unique.
            arcname = img_name + ".png"
            suffix = 1
            while arcname in used_names:
                arcname = f"{img_name}_{suffix}.png"
                suffix += 1
            used_names.add(arcname)
            # Encode to PNG in memory, then store the bytes in the archive.
            with BytesIO() as output:
                pil_img.save(output, format="PNG")
                zip_obj.writestr(arcname, output.getvalue())
    zip_buf.seek(0)
    return zip_buf
# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
# Page-level configuration: wide layout instead of the default centered one.
st.set_page_config(layout="wide")
# NOTE(review): this option was removed in newer Streamlit releases, where
# setting it fails; confirm the pinned Streamlit version still accepts it.
st.set_option('deprecation.showfileUploaderEncoding', False)
st.markdown('''