|
import streamlit as st |
|
from PIL import Image |
|
from io import BytesIO |
|
from collections import namedtuple |
|
import numpy as np |
|
|
|
from src.simswap import SimSwap |
|
|
|
|
|
def _read_rgb(uploaded_file):
    """Decode an uploaded image into an (H, W, 3) uint8 RGB numpy array.

    Converting through ``Image.convert("RGB")`` (rather than slicing the raw
    array to three channels) also handles grayscale, palette, LA and CMYK
    uploads, which would otherwise crash or yield the wrong channels.

    Returns None when nothing was uploaded, or when the file cannot be decoded
    as an image (an error message is shown in the current Streamlit container).
    """
    if uploaded_file is None:
        return None
    try:
        with Image.open(uploaded_file) as img:
            return np.array(img.convert("RGB"))
    except OSError as err:  # PIL.UnidentifiedImageError is a subclass of OSError
        name = getattr(uploaded_file, "name", "uploaded file")
        st.error(f"Could not read {name!r} as an image: {err}")
        return None


def _show_inputs(titled_images):
    """Show each (header, image) pair side by side, one Streamlit column each."""
    # st.columns needs at least one column, even when nothing was uploaded.
    columns = st.columns(max(len(titled_images), 1))
    for column, (title, image) in zip(columns, titled_images):
        with column:
            st.header(title)
            st.image(image)


def _show_output(output):
    """Display the swapped image and offer it as a JPEG download.

    ``output`` is assumed to be a numpy array of 0-255 pixel values, as returned
    by the model call in run() — TODO confirm the exact dtype against SimSwap.
    """
    with st.container():
        st.header("SimSwap output")
        st.image(output)

        # Clip before the uint8 cast: out-of-range values would otherwise wrap
        # around. convert("RGB") guarantees a JPEG-compatible mode.
        pixels = np.clip(output, 0, 255).astype(np.uint8)
        output_to_download = Image.fromarray(pixels).convert("RGB")
        buf = BytesIO()
        output_to_download.save(buf, format="JPEG")

        st.download_button(
            label="Download",
            data=buf.getvalue(),
            file_name="output.jpg",
            mime="image/jpeg",
        )


def run(model):
    """Render the SimSwap Streamlit page and run ``model`` on the uploaded images.

    The sidebar collects the ID image, the attribute image, an optional
    "specific person" image, and the inference parameters. The uploaded inputs
    are previewed in columns. When both the ID and attribute images are present,
    the parameters are pushed onto ``model``, ``model(attr_image)`` is called,
    and a non-None result is displayed with a download button.

    Args:
        model: A SimSwap-like object exposing the ``set_*`` setters used below,
            the ``enhance_output`` / ``specific_latent`` / ``specific_id_image`` /
            ``id_latent`` / ``id_image`` attributes, and ``__call__(attr_image)``.
            It is mutated in place on every rerun.
    """
    with st.sidebar:
        id_image = _read_rgb(st.file_uploader("Select an ID image"))
        attr_image = _read_rgb(st.file_uploader("Select an Attribute image"))
        specific_image = _read_rgb(
            st.file_uploader("Select a specific person image (Optional)")
        )

        face_alignment_type = st.radio("Face alignment type:", ("none", "ffhq"))
        enhance_output = st.radio("Enhance output:", ("yes", "no"))

        smooth_mask_iter = st.slider(
            label="smooth_mask_iter", min_value=1, max_value=60, step=1, value=7
        )
        smooth_mask_kernel_size = st.slider(
            label="smooth_mask_kernel_size", min_value=1, max_value=61, step=2, value=17
        )
        smooth_mask_threshold = st.slider(
            label="smooth_mask_threshold",
            min_value=0.01,
            max_value=1.0,
            step=0.01,
            value=0.9,
        )
        specific_latent_match_threshold = st.slider(
            label="specific_latent_match_threshold",
            min_value=0.0,
            max_value=10.0,
            value=0.05,
        )

    titled_images = [
        (title, image)
        for title, image in (
            ("ID image", id_image),
            ("Attribute image", attr_image),
            ("Specific image", specific_image),
        )
        if image is not None
    ]
    _show_inputs(titled_images)

    # Both the identity source and the attribute target are required.
    if id_image is None or attr_image is None:
        return

    model.set_face_alignment_type(face_alignment_type)
    model.set_smooth_mask_iter(smooth_mask_iter)
    model.set_smooth_mask_kernel_size(smooth_mask_kernel_size)
    model.set_smooth_mask_threshold(smooth_mask_threshold)
    model.set_specific_latent_match_threshold(specific_latent_match_threshold)
    model.enhance_output = enhance_output == "yes"

    # The latents are cleared before the new images are assigned. Presumably
    # SimSwap caches latents and would otherwise reuse stale ones — confirm.
    model.specific_latent = None
    model.specific_id_image = specific_image

    model.id_latent = None
    model.id_image = id_image

    output = model(attr_image)

    if output is not None:
        _show_output(output)
|
|
|
|
|
def _model_cache_decorator():
    """Return the Streamlit decorator used to cache the loaded model.

    Prefers ``st.cache_resource`` (Streamlit >= 1.18), the documented cache for
    shared, unserializable objects such as ML models. On older Streamlit
    versions it falls back to the deprecated
    ``st.cache(allow_output_mutation=True)``. Both return the same cached object
    on every call with an equal argument instead of copying it.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource
    return st.cache(allow_output_mutation=True)


@_model_cache_decorator()
def load_model(config):
    """Construct a SimSwap model from ``config``, cached by the config value.

    The model is built without an ID image or specific-person image; those are
    assigned later on the returned instance.

    NOTE: the cached instance is shared, not copied, so mutations made by the
    caller (run() sets attributes on it on every rerun) persist between calls.
    """
    return SimSwap(
        config=config,
        id_image=None,
        specific_image=None,
    )
|
|
|
|
|
|
|
# Immutable bundle of SimSwap settings: paths to the model weight files, the
# device string, the crop size and checkpoint type, and the default
# mask-smoothing, face-detection, latent-matching and enhancement parameters.
Config = namedtuple(
    "Config",
    [
        "face_detector_weights",
        "face_id_weights",
        "parsing_model_weights",
        "simswap_weights",
        "gfpgan_weights",
        "blend_module_weights",
        "device",
        "crop_size",
        "checkpoint_type",
        "face_alignment_type",
        "smooth_mask_iter",
        "smooth_mask_kernel_size",
        "smooth_mask_threshold",
        "face_detector_threshold",
        "specific_latent_match_threshold",
        "enhance_output",
    ],
)
|
|
|
if __name__ == "__main__":
    # Startup configuration for the "official_224" checkpoint with a 224 px crop.
    # The weight paths are relative; presumably they are resolved against the
    # working directory the app is launched from — confirm.
    settings = dict(
        face_detector_weights="weights/scrfd_10g_bnkps.onnx",
        face_id_weights="weights/arcface_net.jit",
        parsing_model_weights="weights/79999_iter.pth",
        simswap_weights="weights/latest_net_G.pth",
        gfpgan_weights="weights/GFPGANv1.4_ema.pth",
        blend_module_weights="weights/blend.jit",
        device="cuda",
        crop_size=224,
        checkpoint_type="official_224",
        face_alignment_type="none",
        smooth_mask_iter=7,
        smooth_mask_kernel_size=17,
        smooth_mask_threshold=0.9,
        face_detector_threshold=0.6,
        specific_latent_match_threshold=0.05,
        enhance_output=True,
    )

    run(load_model(Config(**settings)))
|
|