"""Gradio demo: re-render an input face in every decade from 1880s to 2010s.

One StyleGAN2 generator per decade is loaded from ``pretrained_models/{year}.pkl``.
For each request the uploaded face is aligned, ``run_PTI`` is run on it, and the
parameter difference between the resulting PTI checkpoint and the selected
decade's original generator is added to every other decade's generator. A stored
latent is then synthesized by each generator, and the tiles are pasted side by
side after the aligned input.
"""
import copy
import functools
import math
import os

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

import utils.utils as utils
from run_pti import run_PTI

device = "cpu"
years = [str(y) for y in range(1880, 2020, 10)]  # "1880", "1890", ..., "2010"
decades = [y + "s" for y in years]  # dropdown labels, e.g. "2010s"

TILE_SIZE = 256  # side length (px) of every tile in the output strip
TILE_SPACING = 0  # horizontal gap (px) between tiles

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Pristine (un-tuned) generator per decade, loaded once at start-up. predict()
# never mutates these; it deep-copies them before applying the PTI delta.
orig_models = {}
for year in years:
    G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = {"G": G.eval().float()}


@functools.lru_cache(maxsize=1)
def _get_shape_predictor():
    """Load the dlib 68-landmark predictor once and reuse it for every request."""
    import dlib

    return dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")


def run_alignment(image_path, idx=None):
    """Align and crop the face in ``image_path`` using ``align_face``.

    Args:
        image_path: path of the image file to align.
        idx: forwarded unchanged to ``align_face``.

    Returns:
        Whatever ``align_face`` returns; ``predict`` uses it as a PIL image.
    """
    from align_all_parallel import align_face

    return align_face(filepath=image_path, predictor=_get_shape_predictor(), idx=idx)


def _transplant_pti_delta(pti_G, orig_G, target_G):
    """Add ``pti_G - orig_G`` to ``target_G``'s parameters in place, matched by name.

    Raises KeyError if ``target_G`` has a parameter name that the other two
    generators lack, instead of silently pairing mismatched tensors.
    """
    pti_params = dict(pti_G.named_parameters())
    orig_params = dict(orig_G.named_parameters())
    with torch.no_grad():
        for name, param in target_G.named_parameters():
            param += pti_params[name] - orig_params[name]


def predict(inp, in_decade):
    """Render the uploaded face in every decade.

    Args:
        inp: PIL image from the Gradio image input (``None`` if nothing was uploaded).
        in_decade: decade label of the input photo; must be one of ``decades``.

    Returns:
        A PIL RGB strip. It holds the aligned input, followed by one tile per
        entry of ``years`` in order.

    Raises:
        gr.Error: if no image was provided or the decade label is unknown.
    """
    if inp is None:
        raise gr.Error("Please upload an input image.")
    if in_decade not in decades:
        raise gr.Error(f"Unknown input decade: {in_decade!r}")
    in_year = in_decade[:-1]  # "2010s" -> "2010"

    # Fixed file names: concurrent requests would overwrite each other's files.
    os.makedirs("imgs/cropped", exist_ok=True)
    inp.save("imgs/input.png")
    inversion = run_alignment("imgs/input.png", idx=0)
    inversion.save("imgs/cropped/input.png")

    # NOTE(review): run_PTI takes no path or decade arguments. The two
    # torch.load paths below are presumably where its configuration writes the
    # tuned generator and the pivot latent for the image named "input". Confirm
    # that its output paths and base generator agree with the selected in_year.
    run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)

    # The checkpoint is a fully pickled nn.Module produced locally by run_PTI.
    # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
    # refuses full-module pickles; pass weights_only=False if upgrading torch.
    pti_G = torch.load("checkpoints/model_gradio_demo_input.pt", map_location=device).eval().float()
    w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device)
    orig_in_G = orig_models[in_year]["G"]

    dst = Image.new(
        "RGB",
        (TILE_SIZE * (len(years) + 1) + TILE_SPACING * len(years), TILE_SIZE),
        color="white",
    )
    # Slot 0 shows the aligned input; resize is a no-op copy if already 256px.
    dst.paste(inversion.resize((TILE_SIZE, TILE_SIZE)), (0, 0))

    with torch.no_grad():
        for slot, year in enumerate(years, start=1):
            if year == in_year:
                G_year = pti_G
            else:
                # Fresh copy per decade, so orig_models stays untouched and each
                # copy can be freed once its tile has been rendered.
                G_year = copy.deepcopy(orig_models[year]["G"])
                _transplant_pti_delta(pti_G, orig_in_G, G_year)
            # 14 x 512 latent: assumes 256px generators (2*log2(256) - 2 = 14 w's).
            out = G_year.synthesis(w_pti.view(1, 14, 512), noise_mode="const", force_fp32=True)
            dst.paste(utils.tensor2im(out.squeeze(0)), ((TILE_SIZE + TILE_SPACING) * slot, 0))
    return dst


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Dropdown(label="Input Decade", choices=decades, value="2010s"),
    ],
    outputs=gr.Image(label="Decade Transformations", type="pil"),
    examples=[["imgs/Steven-Yeun.jpg", "2010s"]],
)
# To serve on the network instead: demo.launch(server_name="0.0.0.0", server_port=8098)
demo.launch()