File size: 3,056 Bytes
0513aaf
2fec875
0513aaf
 
 
 
2fec875
0513aaf
 
2fec875
0513aaf
 
2fec875
 
 
 
 
0513aaf
 
 
 
2fec875
0513aaf
 
dd1add1
 
 
 
2fec875
 
c00162e
dd1add1
 
2fec875
0513aaf
c00162e
2fec875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0513aaf
 
 
2fec875
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import utils.utils as utils
from PIL import Image
import torch
import math
from torchvision import transforms
from run_pti import run_PTI
device = "cpu"

# Decade start years "1880", "1890", ..., "2010" as strings; they are used as
# dict keys and as checkpoint file stems under pretrained_models/.
years = list(map(str, range(1880, 2020, 10)))
# Dropdown labels shown in the UI, e.g. "2010s".
decades = [f"{year_str}s" for year_str in years]


# Image preprocessing: resize to 256x256, convert to a tensor and normalise
# each RGB channel with mean 0.5 / std 0.5.
# NOTE(review): not referenced anywhere else in the visible code.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# One generator per decade, loaded once at import time and only read (never
# modified) by predict(), which diffs the PTI-tuned generator against the
# entry for the selected decade.
orig_models = {}

for year in years:
    G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = {"G": G.eval().float()}
    

def run_alignment(image_path, idx=None):
    """Align the face found in ``image_path`` and return the aligned image.

    Uses dlib's 68-point landmark model from ``pretrained_models/`` together
    with ``align_all_parallel.align_face``. Both imports are deferred to call
    time, so the module can be imported without dlib installed.

    Args:
        image_path: path of the image file to align.
        idx: forwarded unchanged to ``align_face``.

    Returns:
        Whatever ``align_face`` returns (used as a PIL image by the caller).
    """
    import dlib
    from align_all_parallel import align_face

    landmark_model = "pretrained_models/shape_predictor_68_face_landmarks.dat"
    return align_face(
        filepath=image_path,
        predictor=dlib.shape_predictor(landmark_model),
        idx=idx,
    )

def _invert_input(inp):
    """Save ``inp``, align its face, run PTI, and return the aligned face image."""
    inp.save("imgs/input.png")
    inversion = run_alignment("imgs/input.png", idx=0)
    inversion.save("imgs/cropped/input.png")
    # NOTE(review): run_PTI is not told which decade was selected; it presumably
    # takes its target generator/output paths from its own configuration.
    # predict() afterwards reads checkpoints/model_gradio_demo_input.pt and
    # embeddings/<in_year>/PTI/input/0.pt, so confirm run_PTI writes those.
    run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)
    return inversion


def _load_decade_generators(in_year):
    """Return ``{year: generator}`` for every decade, each carrying the PTI update.

    The selected decade uses the PTI-tuned generator loaded from disk. Every
    other decade is loaded fresh from its pretrained checkpoint, then shifted
    by the same parameter delta (tuned - original) that PTI applied to the
    selected decade's generator.
    """
    tuned = torch.load(
        "checkpoints/model_gradio_demo_input.pt", map_location=device
    ).eval().float()
    original = orig_models[in_year]["G"]

    models = {in_year: tuned}
    for year in years:
        if year == in_year:
            continue  # already the tuned model; no need to load its pickle
        G, _ = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
        G = G.eval().float()
        with torch.no_grad():
            # Parameters are paired positionally, which assumes every decade
            # generator shares one architecture (same parameter order/shapes).
            for p_tuned, p_orig, p in zip(
                tuned.parameters(), original.parameters(), G.parameters()
            ):
                p += p_tuned - p_orig
        models[year] = G
    return models


def _compose_strip(inversion, models, w_pti):
    """Paste the aligned input, then one synthesized face per decade, left to right."""
    tile = 256  # width/height of each tile in the output strip
    space = 0  # horizontal gap between tiles
    dst = Image.new(
        "RGB", (tile * (len(years) + 1) + space * len(years), tile), color="white"
    )
    # NOTE(review): assumes the aligned face is tile x tile; a larger image
    # would be cropped by the generated tiles pasted to its right.
    dst.paste(inversion, (0, 0))

    # Hard-coded latent layout (14 w vectors of width 512), as in the original.
    # TODO confirm it matches the generators' synthesis network.
    ws = w_pti.view(1, 14, 512)
    for col, year in enumerate(years, start=1):
        with torch.no_grad():
            child_tensor = models[year].synthesis(
                ws, noise_mode="const", force_fp32=True
            )
        img = utils.tensor2im(child_tensor.squeeze(0))
        dst.paste(img, ((tile + space) * col, 0))
    return dst


def predict(inp, in_decade):
    """Gradio handler: render the uploaded face as it would appear in every decade.

    Steps:
      1. Align the face and invert it with PTI into the selected decade's generator.
      2. Transfer the PTI parameter delta to every other decade's generator.
      3. Synthesize the PTI latent with each generator.

    Args:
        inp: PIL image supplied by ``gr.Image(type="pil")``; ``None`` if the
            user did not upload one.
        in_decade: decade label from the dropdown, e.g. ``"2010s"``.

    Returns:
        A PIL image strip: the aligned input followed by one generated face per
        entry of ``years``, in order.

    Raises:
        gr.Error: if no image was provided or ``in_decade`` is not a known decade.
    """
    if inp is None:
        raise gr.Error("Please upload an input image.")
    if in_decade not in decades:
        raise gr.Error(f"Unknown decade: {in_decade!r}")
    in_year = in_decade[:-1]  # "2010s" -> "2010"

    inversion = _invert_input(inp)
    models = _load_decade_generators(in_year)
    w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device)
    return _compose_strip(inversion, models, w_pti)


# Web UI: an uploaded image plus a decade dropdown in, one image strip out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Dropdown(label="Input Decade", choices=decades, value="2010s"),
    ],
    outputs=gr.Image(label="Decade Transformations", type="pil"),
    examples=[["imgs/Steven-Yeun.jpg", "2010s"]],
)
# For self-hosting use e.g. demo.launch(server_name="0.0.0.0", server_port=8098).
demo.launch()