# Snapshot of commit e6a22e6 (4,826 bytes).
import streamlit as st
from PIL import Image
from io import BytesIO
from collections import namedtuple
import numpy as np
from src.simswap import SimSwap
def _get_np_image(file):
    """Decode an uploaded image file into an RGB NumPy array.

    ``convert("RGB")`` normalises every PIL mode to exactly three channels.
    Slicing the raw array with ``[:, :, :3]`` instead would have two problems:
    - it raises ``IndexError`` for 2-D modes such as "L", "P" or "I;16";
    - it keeps the wrong channels for "LA" (two channels) and "CMYK" (CMY, not RGB).
    For "RGBA", both approaches simply drop the alpha channel.
    The ``with`` block closes the underlying file handle once the pixels
    have been copied into the array.
    """
    with Image.open(file) as img:
        return np.array(img.convert("RGB"))


def _read_sidebar():
    """Render the sidebar controls (in their original order) and read them.

    Returns:
        ``(id_image, attr_image, specific_image, params)``:
        - each image is an RGB array, or ``None`` when nothing was uploaded;
        - ``params`` maps each setting name to the raw widget value.
    """
    with st.sidebar:
        images = []
        for label in (
            "Select an ID image",
            "Select an Attribute image",
            "Select a specific person image (Optional)",
        ):
            uploaded_file = st.file_uploader(label)
            images.append(
                _get_np_image(uploaded_file) if uploaded_file is not None else None
            )
        # Dict values are evaluated top to bottom, so the widgets are created
        # in the same order (and with the same arguments) as before.
        params = {
            "face_alignment_type": st.radio("Face alignment type:", ("none", "ffhq")),
            "enhance_output": st.radio("Enhance output:", ("yes", "no")),
            "smooth_mask_iter": st.slider(
                label="smooth_mask_iter", min_value=1, max_value=60, step=1, value=7
            ),
            "smooth_mask_kernel_size": st.slider(
                label="smooth_mask_kernel_size",
                min_value=1,
                max_value=61,
                step=2,
                value=17,
            ),
            "smooth_mask_threshold": st.slider(
                label="smooth_mask_threshold",
                min_value=0.01,
                max_value=1.0,
                step=0.01,
                value=0.9,
            ),
            "specific_latent_match_threshold": st.slider(
                label="specific_latent_match_threshold",
                min_value=0.0,
                max_value=10.0,
                value=0.05,
            ),
        }
    id_image, attr_image, specific_image = images
    return id_image, attr_image, specific_image, params


def _show_inputs(id_image, attr_image, specific_image):
    """Show every uploaded image in its own column, left to right."""
    shown = [
        (title, image)
        for title, image in (
            ("ID image", id_image),
            ("Attribute image", attr_image),
            ("Specific image", specific_image),
        )
        if image is not None
    ]
    # st.columns needs at least one column, even when nothing was uploaded.
    cols = st.columns(max(len(shown), 1))
    for col, (title, image) in zip(cols, shown):
        with col:
            st.header(title)
            st.image(image)


def _swap(model, id_image, attr_image, specific_image, params):
    """Push the sidebar settings and images into ``model`` and run it.

    Returns whatever ``model(attr_image)`` returns; the caller treats ``None``
    as "no output".
    """
    model.set_face_alignment_type(params["face_alignment_type"])
    model.set_smooth_mask_iter(params["smooth_mask_iter"])
    model.set_smooth_mask_kernel_size(params["smooth_mask_kernel_size"])
    model.set_smooth_mask_threshold(params["smooth_mask_threshold"])
    model.set_specific_latent_match_threshold(
        params["specific_latent_match_threshold"]
    )
    model.enhance_output = params["enhance_output"] == "yes"
    # The latents are cleared before the new images are assigned.
    # NOTE(review): presumably SimSwap recomputes them from the images;
    # TODO confirm against SimSwap.
    model.specific_latent = None
    model.specific_id_image = specific_image
    model.id_latent = None
    model.id_image = id_image
    return model(attr_image)


def _show_output(output):
    """Display the swap result and offer it as a JPEG download."""
    with st.container():
        st.header("SimSwap output")
        st.image(output)
        output_to_download = Image.fromarray(output.astype("uint8"), "RGB")
        buf = BytesIO()
        output_to_download.save(buf, format="JPEG")
        st.download_button(
            label="Download",
            data=buf.getvalue(),
            file_name="output.jpg",
            mime="image/jpeg",
        )


def run(model):
    """Render the SimSwap Streamlit page.

    1. Read uploads and settings from the sidebar.
    2. Preview the uploaded images.
    3. When both an ID image and an attribute image are present, configure
       ``model``, swap, and show the result with a download button.

    Args:
        model: a SimSwap instance (as built by ``load_model``). It is mutated
            in place with the current settings and images.
    """
    id_image, attr_image, specific_image, params = _read_sidebar()
    _show_inputs(id_image, attr_image, specific_image)

    if id_image is None or attr_image is None:
        return

    output = _swap(model, id_image, attr_image, specific_image, params)
    if output is not None:
        _show_output(output)
def _cache_resource(func):
    """Cache ``func``'s return value across Streamlit reruns.

    - Uses ``st.cache_resource`` (Streamlit >= 1.18). It is the documented
      replacement for the deprecated ``st.cache(allow_output_mutation=True)``.
    - Falls back to the legacy decorator on older Streamlit releases that
      lack it.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_model(config):
    """Build the SimSwap model once per ``config`` and reuse it across reruns.

    The model is created without an ID image or a specific-person image; the
    caller assigns those later.

    NOTE(review): a resource-cached object is shared by every session, and the
    page mutates this model per request. Concurrent users could therefore
    interleave settings/images.
    """
    return SimSwap(
        config=config,
        id_image=None,
        specific_image=None,
    )
# TODO: remove it and use config files from 'configs'
# Flat, immutable settings record passed to ``load_model`` / ``SimSwap``.
# It holds weight-file paths, the compute device, and default swap parameters.
Config = namedtuple(
    "Config",
    [
        # model weight files
        "face_detector_weights",
        "face_id_weights",
        "parsing_model_weights",
        "simswap_weights",
        "gfpgan_weights",
        "blend_module_weights",
        # runtime / checkpoint selection
        "device",
        "crop_size",
        "checkpoint_type",
        # swap parameters
        "face_alignment_type",
        "smooth_mask_iter",
        "smooth_mask_kernel_size",
        "smooth_mask_threshold",
        "face_detector_threshold",
        "specific_latent_match_threshold",
        "enhance_output",
    ],
)
if __name__ == "__main__":
    # Start-up defaults. On every rerun, ``run`` overrides most swap parameters
    # (alignment, mask smoothing, latent threshold, enhancement) with the
    # values from its sidebar widgets.
    run(
        load_model(
            Config(
                # weight files
                face_detector_weights="weights/scrfd_10g_bnkps.onnx",
                face_id_weights="weights/arcface_net.jit",
                parsing_model_weights="weights/79999_iter.pth",
                simswap_weights="weights/latest_net_G.pth",
                gfpgan_weights="weights/GFPGANv1.4_ema.pth",
                blend_module_weights="weights/blend.jit",
                # runtime / checkpoint
                device="cuda",
                crop_size=224,
                checkpoint_type="official_224",
                # swap parameters
                face_alignment_type="none",
                smooth_mask_iter=7,
                smooth_mask_kernel_size=17,
                smooth_mask_threshold=0.9,
                face_detector_threshold=0.6,
                specific_latent_match_threshold=0.05,
                enhance_output=True,
            )
        )
    )