# Source-page header captured together with this file (not code):
#   echen01 · working demo · 2fec875 · raw · history blame · 3.06 kB
import gradio as gr
import utils.utils as utils
from PIL import Image
import torch
import math
from torchvision import transforms
from run_pti import run_PTI
# Every model and tensor in this script is loaded onto this device.
device = "cpu"

# Decade start years as strings ("1880", "1890", ..., "2010") and the
# matching "<year>s" labels offered in the UI dropdown.
years = list(map(str, range(1880, 2020, 10)))
decades = [f"{start}s" for start in years]

# Preprocessing pipeline: resize to 256x256, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
# NOTE(review): not referenced anywhere else in the visible code.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# One pretrained generator per decade, loaded once at import time and kept
# unmodified; predict() reads these as the reference weights when it
# computes how far the fine-tuned generator has moved.
orig_models = {}
for year in years:
    G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = dict(G=G.eval().float())
def run_alignment(image_path, idx=None):
    """Run face alignment on the image stored at ``image_path``.

    Args:
        image_path: Path of the image file handed to ``align_face``.
        idx: Forwarded unchanged to ``align_face``.
            NOTE(review): presumably selects which detected face to use --
            confirm against ``align_all_parallel.align_face``.

    Returns:
        Whatever ``align_face`` returns. The caller in this file calls
        ``.save()`` on it and pastes it into a PIL image, so it is presumably
        a ``PIL.Image.Image`` -- confirm against ``align_face``.
    """
    # Imported inside the function, so dlib and the alignment module are only
    # needed once an alignment is actually requested.
    import dlib
    from align_all_parallel import align_face

    landmark_model_path = "pretrained_models/shape_predictor_68_face_landmarks.dat"
    shape_predictor = dlib.shape_predictor(landmark_model_path)
    return align_face(filepath=image_path, predictor=shape_predictor, idx=idx)
def predict(inp, in_decade):
    """Build a strip with the aligned input followed by one image per decade.

    Pipeline:
      1. Save the uploaded image, align it with ``run_alignment`` and save
         the aligned crop.
      2. Run ``run_PTI`` for the "gradio_demo" run.
      3. Load the fine-tuned generator checkpoint for the input decade and
         add its weight offset (relative to the pretrained generator of that
         same decade) to a freshly loaded generator of every other decade.
      4. Synthesise one image per decade from the saved PTI latent and paste
         the results, left to right, after the aligned input.

    Args:
        inp: Input image as a PIL image (the Gradio input uses ``type="pil"``).
        in_decade: Decade label such as ``"2010s"``; the trailing character
            is stripped to obtain the year key (``"2010"``).

    Returns:
        A white-backed RGB PIL image, 256 px high and
        ``256 * (len(years) + 1)`` px wide.

    Raises:
        ValueError: If ``inp`` is ``None`` (no image was supplied).
        KeyError: If the year derived from ``in_decade`` is not in
            ``orig_models``.
    """
    if inp is None:
        # Gradio passes None when the form is submitted without an image;
        # fail with a clear message instead of an AttributeError on .save().
        raise ValueError("No input image was provided.")

    # NOTE(review): these fixed paths are shared by every request, so
    # concurrent calls would overwrite each other's files.
    inp.save("imgs/input.png")
    inversion = run_alignment("imgs/input.png", idx=0)
    inversion.save("imgs/cropped/input.png")

    # Deliberately not wrapped in torch.no_grad(): presumably run_PTI
    # optimises the generator and needs gradients -- confirm against run_pti.
    # It is also expected to produce the checkpoint and the latent file that
    # are read further down.
    run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)

    in_year = in_decade[:-1]  # e.g. "2010s" -> "2010"

    tuned_G = torch.load("checkpoints/model_gradio_demo_input.pt", device).eval().float()
    base_G = orig_models[in_year]["G"]

    # The offset introduced by fine-tuning is the same for every target
    # decade, so compute it once instead of once per decade.
    with torch.no_grad():
        deltas = [
            p_tuned - p_base
            for p_tuned, p_base in zip(tuned_G.parameters(), base_G.parameters())
        ]

    # Fresh copies are loaded for the other decades because their weights are
    # modified in place below; orig_models must stay pristine for later calls.
    # The input decade uses the fine-tuned generator directly.
    pti_models = {in_year: tuned_G}
    for year in years:
        if year == in_year:
            continue
        G, _ = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
        G = G.eval().float()
        with torch.no_grad():
            for p, delta in zip(G.parameters(), deltas):
                p.add_(delta)
        pti_models[year] = G

    w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device)
    latent = w_pti.view(1, 14, 512)

    tile = 256
    space = 0  # horizontal gap between tiles, in pixels
    dst = Image.new(
        "RGB",
        (tile * (len(years) + 1) + space * len(years), tile),
        color="white",
    )
    dst.paste(inversion, (0, 0))

    # A red border around the input decade's tile was prototyped here but is
    # currently disabled.
    for slot, year in enumerate(years, start=1):
        with torch.no_grad():
            child_tensor = pti_models[year].synthesis(
                latent, noise_mode="const", force_fp32=True
            )
        img = utils.tensor2im(child_tensor.squeeze(0))
        dst.paste(img, ((tile + space) * slot, 0))

    return dst
# UI wiring: an image upload plus a decade dropdown feed predict(); the
# resulting strip is shown as a single image. One example row is provided.
demo_inputs = [
    gr.Image(label="Input Image", type="pil"),
    gr.Dropdown(label="Input Decade", choices=decades, value="2010s"),
]
demo_output = gr.Image(label="Decade Transformations", type="pil")
demo_examples = [["imgs/Steven-Yeun.jpg", "2010s"]]

demo = gr.Interface(
    fn=predict,
    inputs=demo_inputs,
    outputs=demo_output,
    examples=demo_examples,
)

# Alternative kept from the original for binding to all network interfaces:
#   demo.launch(server_name="0.0.0.0", server_port=8098)
demo.launch()