{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "# Teaser: sequence of 3 interesting edits\n",
    "%matplotlib inline\n",
    "from notebook_init import *\n",
    "\n",
    "def rand():\n",
    "    \"\"\"Draw a random integer seed in [0, int32 max).\"\"\"\n",
    "    return np.random.randint(np.iinfo(np.int32).max)\n",
    "\n",
    "\n",
    "# Folder that perform_edit() saves the teaser images into\n",
    "outdir = Path('out') / 'figures' / 'teaser'\n",
    "makedirs(outdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_model(model_name, class_name, layer_name):\n",
    "    \"\"\"Load a model and its component data into notebook globals.\n",
    "\n",
    "    Rebinds the globals `inst`, `model`, `lat_comp`, `lat_mean` and `lat_std`.\n",
    "    The previous value of `inst` is handed to `get_instrumented_model`, so\n",
    "    `inst` must already exist before the first call (presumably provided by\n",
    "    `notebook_init` -- verify).\n",
    "\n",
    "    Args:\n",
    "        model_name: Model identifier; names containing 'StyleGAN' enable `use_w`.\n",
    "        class_name: Output class passed to the model loader and the `Config`.\n",
    "        layer_name: Layer passed to the model loader and the `Config`.\n",
    "    \"\"\"\n",
    "    global inst, model, lat_comp, lat_mean, lat_std\n",
    "\n",
    "    w_space = 'StyleGAN' in model_name\n",
    "    inst = get_instrumented_model(\n",
    "        model_name, class_name, layer_name, device, use_w=w_space, inst=inst)\n",
    "    model = inst.model\n",
    "\n",
    "    component_cfg = Config(\n",
    "        components=80,\n",
    "        n=1_000_000,\n",
    "        batch_size=200,\n",
    "        layer=layer_name,\n",
    "        model=model_name,\n",
    "        output_class=class_name,\n",
    "        use_w=w_space,\n",
    "    )\n",
    "    archive_path = get_or_compute(component_cfg, inst)\n",
    "\n",
    "    # Note the key is 'lat_stdev' although the global is called lat_std\n",
    "    with np.load(archive_path) as archive:\n",
    "        lat_comp = archive['lat_comp']\n",
    "        lat_mean = archive['lat_mean']\n",
    "        lat_std = archive['lat_stdev']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_edit(seeds, edit_sequence, save_images=False, crop=None):\n",
    "    \"\"\"Render, optionally save, and display a cumulative edit sequence per seed.\n",
    "\n",
    "    For each seed the unedited sample is rendered first, then one more image\n",
    "    after every edit; edits accumulate on the same list of latents. The images\n",
    "    of a seed are shown side by side in one matplotlib figure.\n",
    "\n",
    "    Relies on the notebook globals `model`, `lat_comp`, `lat_mean`, `lat_std`,\n",
    "    `configs` and `outdir`.\n",
    "\n",
    "    Args:\n",
    "        seeds: Seeds for `model.sample_latent`. Only the first 1000 are used\n",
    "            when `save_images` is true, otherwise only the first 10.\n",
    "        edit_sequence: Keys into `configs`, applied in the given order.\n",
    "        save_images: If true, every image is also written to\n",
    "            `outdir / f'teaser_{seed}_{i}.png'`.\n",
    "        crop: Optional sequence of four pixel counts removed from the start\n",
    "            and end of the first image axis and the start and end of the\n",
    "            second image axis, in that order. A count of 0 keeps that side\n",
    "            intact.\n",
    "    \"\"\"\n",
    "    max_figs = 1000 if save_images else 10\n",
    "\n",
    "    for seed in seeds[:max_figs]:\n",
    "        w = model.sample_latent(1, seed=seed).cpu().numpy()\n",
    "        # One entry per latent slot; entries are rebound below, never mutated in place\n",
    "        w = [w]*model.get_max_latents()\n",
    "        imgs = []\n",
    "\n",
    "        # Starting point\n",
    "        imgs.append(model.sample_np(w))\n",
    "\n",
    "        # Perform edits in order\n",
    "        for edit in edit_sequence:\n",
    "            (idx, start, end, strength, invert) = configs[edit]\n",
    "\n",
    "            # Find out coordinate of w along PC (measured on the first latent entry)\n",
    "            w_centered = w[0] - lat_mean\n",
    "            w_coord = np.sum(w_centered.reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]\n",
    "\n",
    "            if invert:\n",
    "                # Invert property (e.g. flip rotation): aim for the opposite side\n",
    "                # of the mean. np.sign yields 0 for a zero coordinate, so such a\n",
    "                # latent is left unchanged instead of turning into NaN.\n",
    "                sign = np.sign(w_coord)\n",
    "                target = -sign*strength\n",
    "            else:\n",
    "                # Move to the fixed coordinate given by strength\n",
    "                target = strength\n",
    "\n",
    "            delta = target - w_coord # offset along the PC, in units of lat_std[idx]\n",
    "\n",
    "            for l in range(start, end):\n",
    "                w[l] = w[l] + lat_comp[idx]*lat_std[idx]*delta\n",
    "            imgs.append(model.sample_np(w))\n",
    "\n",
    "        # Crop away black borders\n",
    "        if crop:\n",
    "            # A stop index of -0 equals 0 and would give an empty slice, so a\n",
    "            # zero end-crop is mapped to None (slice up to the end) instead.\n",
    "            row_stop = -crop[1] or None\n",
    "            col_stop = -crop[3] or None\n",
    "            imgs = [img[crop[0]:row_stop, crop[2]:col_stop, :] for img in imgs]\n",
    "\n",
    "        if save_images:\n",
    "            # Save to disk\n",
    "            # NOTE(review): the uint8 cast assumes img values lie in [0, 1] -- confirm\n",
    "            for i, img in enumerate(imgs):\n",
    "                Image.fromarray((img*255).astype(np.uint8)).save(outdir / f'teaser_{seed}_{i}.png')\n",
    "\n",
    "        # Show in notebook\n",
    "        strip = np.hstack(imgs)\n",
    "        #strip = strip[::2, ::2, :] # 2x downscale for preview\n",
    "        plt.figure(figsize=(30,5))\n",
    "        plt.imshow(strip, interpolation='bilinear')\n",
    "        plt.axis('off')\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edit definitions looked up by name in perform_edit().\n",
    "# Each value is (idx, edit_start, edit_end, strength, invert):\n",
    "#   idx                  - index into lat_comp / lat_std selecting the component\n",
    "#   edit_start, edit_end - the edit is applied to latent entries range(edit_start, edit_end)\n",
    "#   strength             - target coordinate along the component, in units of lat_std[idx]\n",
    "#   invert               - if True, the target is instead placed at `strength` on the\n",
    "#                          opposite side of the mean from the current coordinate\n",
    "configs = {\n",
    "    # StyleGAN2 cars W\n",
    "    'Redness':          (22,  9, 11,   -8, False),\n",
    "    'Horizontal flip':  ( 0,  0,  5,  2.0, True),\n",
    "    'Add grass':        (41,  9, 11,  -18, False),\n",
    "    'Blocky shape':     (16,  3,  6,   25, False),\n",
    "\n",
    "    # BigGAN-512 irish_setter\n",
    "    'Move right':       ( 0,  0, 15, -1.5, False),\n",
    "    'Rotate':           ( 3,  0, 15,  -0.5, False),\n",
    "    'Move back':        ( 4,  0, 15,  2.5, False),\n",
    "    'Zoom in':          ( 6,  0, 15, -2.0, False),\n",
    "    'Zoom out':         (12,  0, 15, -4.0, False),\n",
    "    'Sharpen BG':       (13,  6,  9, 20.0, False),\n",
    "    'Camera down':      (15,  1,  6, -4.0, False),\n",
    "    'Light right':      (28,  7,  8,  30, False),\n",
    "    'Pixelate':         (46, 10, 11,  -25, False),\n",
    "    'Reeds':            (61,  4,  8,  -15, False),\n",
    "    'Dry BG':           (65,  6,  8,  -30, False),\n",
    "    'Grass length':     (69,  5,  8,   15, False),\n",
    "\n",
    "    # StyleGAN2 ffhq\n",
    "    'frizzy_hair':             (31,  2,  6,  20, False),\n",
    "    'background_blur':         (49,  6,  9,  20, False),\n",
    "    'bald':                    (21,  2,  5,  20, False),\n",
    "    'big_smile':               (19,  4,  5,  20, False),\n",
    "    'caricature_smile':        (26,  3,  8,  13, False),\n",
    "    'scary_eyes':              (33,  6,  8,  20, False),\n",
    "    'curly_hair':              (47,  3,  6,  20, False),\n",
    "    'dark_bg_shiny_hair':      (13,  8,  9,  20, False),\n",
    "    'dark_hair_and_light_pos': (14,  8,  9,  20, False),\n",
    "    'dark_hair':               (16,  8,  9,  20, False),\n",
    "    'disgusted':               (43,  6,  8, -30, False),\n",
    "    'displeased':              (36,  4,  7,  20, False),\n",
    "    'eye_openness':            (54,  7,  8,  20, False),\n",
    "    'eye_wrinkles':            (28,  6,  8,  20, False),\n",
    "    'eyebrow_thickness':       (37,  8,  9,  20, False),\n",
    "    'face_roundness':          (37,  0,  5,  20, False),\n",
    "    'fearful_eyes':            (54,  4, 10,  20, False),\n",
    "    'hairline':                (21,  4,  5, -20, False),\n",
    "    'happy_frizzy_hair':       (30,  0,  8,  20, False),\n",
    "    'happy_elderly_lady':      (27,  4,  7,  20, False),\n",
    "    'head_angle_up':           (11,  1,  4,  20, False),\n",
    "    'huge_grin':               (28,  4,  6,  20, False),\n",
    "    'in_awe':                  (23,  3,  6, -15, False),\n",
    "    'wide_smile':              (23,  3,  6,  20, False),\n",
    "    'large_jaw':               (22,  3,  6,  20, False),\n",
    "    'light_lr':                (15,  8,  9,  10, False),\n",
    "    'lipstick_and_age':        (34,  6, 11,  20, False),\n",
    "    'lipstick':                (34, 10, 11,  20, False),\n",
    "    'mascara_vs_beard':        (41,  6,  9,  20, False),\n",
    "    'nose_length':             (51,  4,  5, -20, False),\n",
    "    'elderly_woman':           (34,  6,  7,  20, False),\n",
    "    'overexposed':             (27,  8, 18,  15, False),\n",
    "    'screaming':               (35,  3,  7, -15, False),\n",
    "    'short_face':              (32,  2,  6, -20, False),\n",
    "    'show_front_teeth':        (59,  4,  5,  40, False),\n",
    "    'smile':                   (46,  4,  5, -20, False),\n",
    "    'straight_bowl_cut':       (20,  4,  5, -20, False),\n",
    "    'sunlight_in_face':        (10,  8,  9,  10, False),\n",
    "    'trimmed_beard':           (58,  7,  9,  20, False),\n",
    "    'white_hair':              (57,  7, 10, -24, False),\n",
    "    'wrinkles':                (20,  6,  7, -18, False),\n",
    "    'boyishness':              (8,   2,  5,  20, False),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# StyleGAN2 faces - emphasis on novel edits\n",
    "setup_model(model_name='StyleGAN2', class_name='ffhq', layer_name='style')\n",
    "model.truncation = 0.7\n",
    "model.use_w()\n",
    "\n",
    "# Fixed seeds; append [rand() for _ in range(1)] to try an extra random one\n",
    "seeds = [6293435, 2105448342]\n",
    "print(seeds)\n",
    "edits = ['wrinkles', 'white_hair', 'in_awe', 'overexposed']\n",
    "perform_edit(seeds, edits, save_images=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# StyleGAN2 cars\n",
    "setup_model(model_name='StyleGAN2', class_name='car', layer_name='style')\n",
    "model.truncation = 0.6\n",
    "model.use_w()\n",
    "\n",
    "# Fixed seed; append [rand() for _ in range(10)] to try extra random ones\n",
    "seeds = [440749230]\n",
    "edits = ['Redness', 'Horizontal flip', 'Add grass', 'Blocky shape']\n",
    "# The crop removes the black borders from every image\n",
    "perform_edit(seeds, edits, save_images=True, crop=[64, 64, 1, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BigGAN-512 irish setter\n",
    "# The model and component data are set up with the 'husky' class; the output\n",
    "# class is switched to 'irish_setter' afterwards.\n",
    "setup_model(model_name='BigGAN-512', class_name='husky', layer_name='generator.gen_z')\n",
    "model.set_output_class('irish_setter')\n",
    "model.truncation = 0.6\n",
    "\n",
    "# Fixed seed; append [rand() for _ in range(10)] to try extra random ones\n",
    "seeds = [489408325]\n",
    "edits = ['Rotate', 'Zoom out', 'Camera down', 'Reeds']\n",
    "perform_edit(seeds, edits, save_images=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}