{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Systematic" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import clip\n", "from evaluation_utils import norm, denorm\n", "from general_utils import *\n", "from datasets.lvis_oneshot3 import LVIS_OneShot3\n", "\n", "clip_device = 'cuda'\n", "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n", "clip_model.eval();\n", "\n", "from models.clipseg import CLIPDensePredTMasked\n", "\n", "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n", "clip_mask_model.eval();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n", " text_class_labels=True, image_size=352, min_area=0.1,\n", " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_data(lvis)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "import json\n", "\n", "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n", "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n", "\n", "objects_per_image = defaultdict(lambda : set())\n", "for ann in lvis_raw['annotations']:\n", " objects_per_image[ann['image_id']].add(ann['category_id'])\n", " \n", "for ann in lvis_val_raw['annotations']:\n", " objects_per_image[ann['image_id']].add(ann['category_id']) \n", " \n", "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n", "\n", "del lvis_raw, lvis_val_raw" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#bs = 32\n", "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from general_utils import get_batch\n", "from functools import partial\n", "from evaluation_utils import img_preprocess\n", "import torch\n", "\n", "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n", "\n", " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n", "\n", " all_prompts = []\n", " \n", " with torch.no_grad():\n", " valid_sims = []\n", " torch.manual_seed(571)\n", " \n", " if type(batches_or_dataset) == list:\n", " loader = batches_or_dataset # already loaded\n", " max_iter = float('inf')\n", " else:\n", " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n", " max_iter = 50\n", " \n", " global batch\n", " for i_batch, (batch, batch_y) in enumerate(loader):\n", " \n", " if i_batch >= max_iter: break\n", " \n", " processed_batch = process(batch)\n", " if type(processed_batch) == dict:\n", " \n", " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n", " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n", " else:\n", " processed_batch = process(batch).to(clip_device)\n", " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n", " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n", " image_features = 
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from general_utils import get_batch\n",
    "from functools import partial\n",
    "from evaluation_utils import img_preprocess\n",
    "import torch\n",
    "import torch.nn.functional as nnf\n",
    "from torch.utils.data import DataLoader  # may also be provided by general_utils\n",
    "\n",
    "# For each sample: softmax over CLIP similarities between the (processed) support image and\n",
    "# prompts for the target object plus all other objects annotated in the same image.\n",
    "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
    "\n",
    "    # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
    "\n",
    "    all_prompts = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        valid_sims = []\n",
    "        torch.manual_seed(571)\n",
    "\n",
    "        if isinstance(batches_or_dataset, list):\n",
    "            loader = batches_or_dataset  # already loaded\n",
    "            max_iter = float('inf')\n",
    "        else:\n",
    "            loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
    "            max_iter = 50\n",
    "\n",
    "        global batch  # exposed for interactive inspection\n",
    "        for i_batch, (batch, batch_y) in enumerate(loader):\n",
    "\n",
    "            if i_batch >= max_iter: break\n",
    "\n",
    "            processed_batch = process(batch)\n",
    "            if isinstance(processed_batch, dict):\n",
    "                # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
    "                image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
    "            else:\n",
    "                processed_batch = processed_batch.to(clip_device)\n",
    "                processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
    "                #image_features = clip_model.encode_image(processed_batch.to(clip_device))\n",
    "                image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
    "\n",
    "            image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
    "            bs = len(batch[0])\n",
    "            for j in range(bs):\n",
    "\n",
    "                c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
    "                support_image = basename(lvis.samples[c][sid])\n",
    "\n",
    "                # all object names annotated in this support image (used as distractor prompts)\n",
    "                img_objs = [o.replace('_', ' ') for o in objects_per_image[int(support_image)]]\n",
    "                other_words = [f'a photo of a {o}' for o in img_objs\n",
    "                               if o != batch_y[2][j]]\n",
    "\n",
    "                prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
    "                all_prompts += [prompts]\n",
    "\n",
    "                text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
    "                text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
    "\n",
    "                global logits, sim  # exposed for interactive inspection\n",
    "                logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
    "                sim = torch.softmax(logits, dim=-1)\n",
    "\n",
    "                valid_sims += [sim]\n",
    "\n",
    "        #valid_sims = torch.stack(valid_sims)\n",
    "        return valid_sims, all_prompts\n",
    "\n",
    "\n",
    "def new_img_preprocess(x):\n",
    "    return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
    "\n",
    "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
    "get_similarities(lvis, lambda x: x[1]);  # quick test run; the baseline is recomputed below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# uncomment entries below to evaluate additional variants\n",
    "preprocessing_functions = [\n",
    "#    ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
    "#    ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
    "#    ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
    "#    ['colorize object red', partial(img_preprocess, colorize=True)],\n",
    "#    ['add red outline', partial(img_preprocess, outline=True)],\n",
    "\n",
    "#    ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
    "#    ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
    "#    ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
    "#    ['BG blur', partial(img_preprocess, blur=3)],\n",
    "#    ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
    "\n",
    "#    ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
    "#    ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
    "    ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
    "    ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
    "#    ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
    "]\n",
    "\n",
    "base, base_p = get_similarities(lvis, lambda x: x[1])\n",
    "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mean gain in target-prompt probability over the unmodified baseline\n",
    "# (restricted to samples with at least three prompts, i.e. at least two distractors)\n",
    "for j in range(1):\n",
    "    print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pandas import DataFrame\n",
    "\n",
    "# per-variant mean improvement of the target-prompt probability, printed as LaTeX table rows\n",
    "tab = dict()\n",
    "for j, (name, _) in enumerate(preprocessing_functions):\n",
    "    tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
    "\n",
    "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k, v in tab.items()))"
   ]
  },
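  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional alternative view (not part of the original notebook): the same per-variant improvements as a pandas `DataFrame` instead of LaTeX rows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative sketch; assumes `tab` from the previous cell\n",
    "DataFrame({'improvement over baseline (%)': [float(v) * 100 for v in tab.values()]},\n",
    "          index=list(tab.keys()))"
   ]
  },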
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visual Examples\n",
    "\n",
    "Qualitative examples: each visual prompt variant of a sample image is shown together with the probabilities CLIP assigns to a set of candidate prompts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from evaluation_utils import denorm, norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import join, expanduser\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "def load_sample(filename, filename2):\n",
    "    # mimic the dataset batch layout: [None, normalized image (1x3x224x224), object mask (1x224x224)]\n",
    "    bp = expanduser('~/cloud/resources/sample_images')\n",
    "    tf = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "        transforms.Resize(224),\n",
    "        transforms.CenterCrop(224)\n",
    "    ])\n",
    "    tf2 = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Resize(224),\n",
    "        transforms.CenterCrop(224)\n",
    "    ])\n",
    "    inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
    "    inp1[1] = inp1[1].unsqueeze(0)\n",
    "    inp1[2] = inp1[2][:1]\n",
    "    return inp1\n",
    "\n",
    "def all_preprocessing(inp1):\n",
    "    # the six visual prompt variants shown in the figure below\n",
    "    return [\n",
    "        img_preprocess(inp1),\n",
    "        img_preprocess(inp1, colorize=True),\n",
    "        img_preprocess(inp1, outline=True),\n",
    "        img_preprocess(inp1, blur=3),\n",
    "        img_preprocess(inp1, bg_fac=0.1),\n",
    "        #img_preprocess(inp1, bg_fac=0.5),\n",
    "        #img_preprocess(inp1, blur=3, bg_fac=0.5),\n",
    "        img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "from matplotlib import pyplot as plt\n",
    "from evaluation_utils import img_preprocess\n",
    "import clip\n",
    "\n",
    "images_queries = [\n",
    "    [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
    "    [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
    "]\n",
    "\n",
    "# load a CPU copy of CLIP once for this visualization\n",
    "clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
    "\n",
    "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
    "\n",
    "for j, (images, objects) in enumerate(images_queries):\n",
    "\n",
    "    # all preprocessing variants of this image, stacked into one batch\n",
    "    joint_image = all_preprocessing(images)\n",
    "    joint_image = torch.stack(joint_image)[:, 0]\n",
    "    image_features = clip_model.encode_image(joint_image)\n",
    "    image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
    "\n",
    "    prompts = [f'a photo of a {obj}' for obj in objects]\n",
    "    text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
    "    text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
    "    logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
    "    sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
    "\n",
    "    for i, img in enumerate(joint_image):\n",
    "        ax[2*j, i].axis('off')\n",
    "        ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1, 2, 0), 0, 1))\n",
    "        ax[2*j + 1, i].grid(True)\n",
    "\n",
    "        ax[2*j + 1, i].set_ylim(0, 1)\n",
    "        ax[2*j + 1, i].set_yticklabels([])\n",
    "        ax[2*j + 1, i].set_xticks([])  # set_xticks(range(len(prompts)))\n",
    "#        ax[1, i].set_xticklabels(objects, rotation=90)\n",
    "        for k in range(len(sim[i])):\n",
    "            ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
    "            ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env2",
   "language": "python",
   "name": "env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}