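# Gradio demo: align a face photo, invert it with PTI (Pivotal Tuning Inversion),
# then re-synthesize it with decade-specific StyleGAN2 generators (1880s-2010s).
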
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

import utils.utils as utils
from run_pti import run_PTI

device = "cpu"
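
# One pretrained generator per decade, 1880 through 2010.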
years = [str(y) for y in range(1880, 2020, 10)]
decades = [y + "s" for y in years]
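
# 256x256 preprocessing transform (not used directly in the demo flow below).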
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
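
# Load the frozen ("original") decade generators once at startup; they serve as
# the reference weights when transferring the PTI delta in predict().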
orig_models = {}
for year in years:
    G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = {"G": G.eval().float()}
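

# Crop and align the face with dlib's 68-point landmark detector.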
def run_alignment(image_path, idx=None):
    import dlib
    from align_all_parallel import align_face
    predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")
    aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
    return aligned_image
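

# Full pipeline for one request: align -> PTI inversion -> transfer the tuned
# delta to every decade generator -> render a strip across all decades.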
def predict(inp, in_decade):
    inp.save("imgs/input.png")
    inversion = run_alignment("imgs/input.png", idx=0)
    inversion.save("imgs/cropped/input.png")
    run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)
    in_year = in_decade[:-1]  # e.g. "2010s" -> "2010"

    # Reload fresh copies of every decade generator for this request, then swap
    # in the PTI-tuned generator for the input decade.
    pti_models = {}
    for year in years:
        G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
        pti_models[year] = {"G": G.eval().float()}
    pti_models[in_year]["G"] = torch.load("checkpoints/model_gradio_demo_input.pt", map_location=device).eval().float()
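
    # Transfer the PTI delta (tuned minus original weights for the input decade)
    # onto every other decade's generator, personalizing all of them at once.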
    for year in years:
        if year != in_year:
            for p_pti, p_orig, (name, p) in zip(pti_models[in_year]["G"].parameters(),
                                                orig_models[in_year]["G"].parameters(),
                                                pti_models[year]["G"].named_parameters()):
                with torch.no_grad():
                    delta = p_pti - p_orig
                    p += delta
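
    # Compose the output strip: the aligned input followed by one 256x256
    # render per decade.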
    space = 0
    dst = Image.new("RGB", (256 * (len(years) + 1) + space * len(years), 256), color="white")
    w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device)
    dst.paste(inversion, (0, 0))
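
    # Render the inverted latent through each decade's (delta-shifted) generator.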
    for i, year in enumerate(years):
        with torch.no_grad():
            child_tensor = pti_models[year]["G"].synthesis(
                w_pti.view(1, 14, 512), noise_mode="const", force_fp32=True)
        img = utils.tensor2im(child_tensor.squeeze(0))
        dst.paste(img, ((256 + space) * (i + 1), 0))
    return dst
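

# Gradio UI: face image + source decade in, decade strip out.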
gr.Interface(
    fn=predict,
    inputs=[gr.Image(label="Input Image", type="pil"),
            gr.Dropdown(label="Input Decade", choices=decades, value="2010s")],
    outputs=gr.Image(label="Decade Transformations", type="pil"),
    examples=[["imgs/Steven-Yeun.jpg", "2010s"]],
).launch()  # for remote serving: .launch(server_name="0.0.0.0", server_port=8098)