{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import gradio as gr "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from src.data.embs import ImageDataset\n",
    "from src.model.blip_embs import blip_embs\n",
    "\n",
    "from demo_chat import Chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data.transforms import transform_test\n",
    "\n",
    "transform = transform_test(384)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_blip_config(model=\"base\"):\n",
    "    config = dict()\n",
    "    if model == \"base\":\n",
    "        config[\n",
    "            \"pretrained\"\n",
    "        ] = \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth \"\n",
    "        config[\"vit\"] = \"base\"\n",
    "        config[\"batch_size_train\"] = 32\n",
    "        config[\"batch_size_test\"] = 16\n",
    "        config[\"vit_grad_ckpt\"] = True\n",
    "        config[\"vit_ckpt_layer\"] = 4\n",
    "        config[\"init_lr\"] = 1e-5\n",
    "    elif model == \"large\":\n",
    "        config[\n",
    "            \"pretrained\"\n",
    "        ] = \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth\"\n",
    "        config[\"vit\"] = \"large\"\n",
    "        config[\"batch_size_train\"] = 16\n",
    "        config[\"batch_size_test\"] = 32\n",
    "        config[\"vit_grad_ckpt\"] = True\n",
    "        config[\"vit_ckpt_layer\"] = 12\n",
    "        config[\"init_lr\"] = 5e-6\n",
    "\n",
    "    config[\"image_size\"] = 384\n",
    "    config[\"queue_size\"] = 57600\n",
    "    config[\"alpha\"] = 0.4\n",
    "    config[\"k_test\"] = 256\n",
    "    config[\"negative_all_rank\"] = True\n",
    "\n",
    "    return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating model\n",
      "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth\n",
      "missing keys:\n",
      "[]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "BLIPEmbs(\n",
       "  (visual_encoder): VisionTransformer(\n",
       "    (patch_embed): PatchEmbed(\n",
       "      (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))\n",
       "      (norm): Identity()\n",
       "    )\n",
       "    (pos_drop): Dropout(p=0.0, inplace=False)\n",
       "    (blocks): ModuleList(\n",
       "      (0): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.004)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (2): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.009)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (3): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.013)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (4): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.017)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (5): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.022)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (6): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.026)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (7): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.030)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (8): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.035)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (9): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.039)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (10): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.043)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (11): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.048)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (12): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.052)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (13): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.057)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (14): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.061)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (15): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.065)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (16): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.070)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (17): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.074)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (18): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.078)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (19): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.083)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (20): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.087)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (21): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.091)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (22): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.096)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (23): Block(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "        (drop_path): DropPath(drop_prob=0.100)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
       "  )\n",
       "  (text_encoder): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(30524, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (crossattention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=1024, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=1024, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (vision_proj): Linear(in_features=1024, out_features=256, bias=True)\n",
       "  (text_proj): Linear(in_features=768, out_features=256, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Creating model\")\n",
    "config = get_blip_config(\"large\")\n",
    "\n",
    "model = blip_embs(\n",
    "        pretrained=config[\"pretrained\"],\n",
    "        image_size=config[\"image_size\"],\n",
    "        vit=config[\"vit\"],\n",
    "        vit_grad_ckpt=config[\"vit_grad_ckpt\"],\n",
    "        vit_ckpt_layer=config[\"vit_ckpt_layer\"],\n",
    "        queue_size=config[\"queue_size\"],\n",
    "        negative_all_rank=config[\"negative_all_rank\"],\n",
    "    )\n",
    "\n",
    "model = model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_json(\"datasets/sidechef/my_recipes.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Target Embedding\n"
     ]
    }
   ],
   "source": [
    "print(\"Loading Target Embedding\")\n",
    "tar_img_feats = []\n",
    "for _id in df[\"id_\"].tolist():     \n",
    "    tar_img_feats.append(torch.load(\"datasets/sidechef/blip-embs-large/{:07d}.pth\".format(_id)).unsqueeze(0))\n",
    "\n",
    "tar_img_feats = torch.cat(tar_img_feats, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7866\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"650px\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Stylesheet passed to gr.Interface through its `css` argument.\n",
    "# NOTE(review): `.gradio-footer`, `.footer` and `.body` are class selectors, and no\n",
    "# markup defined in this cell carries those classes -- confirm they are still needed.\n",
    "custom_css = \"\"\"\n",
    "/* Footer style */\n",
    ".gradio-footer {\n",
    "    display: flex;\n",
    "    justify-content: center;\n",
    "    align-items: center;\n",
    "    padding: 10px;\n",
    "    background-color: #f8f9fa;\n",
    "    color: #333;\n",
    "    font-size: 0.9em;\n",
    "}\n",
    "\n",
    ".custom-header {\n",
    "    text-align: center;\n",
    "    padding: 12px;\n",
    "    background-color: #333; \n",
    "    color: white;\n",
    "    width: 100%;\n",
    "    font-size: 0.8em;\n",
    "}\n",
    "\n",
    ".footer {\n",
    "    width: 100%;\n",
    "    background-color: #f2f2f2;\n",
    "    color: #555;\n",
    "    text-align: center;\n",
    "    padding: 10px 0;\n",
    "    position: absolute;\n",
    "    bottom: 0;\n",
    "    left: 0;\n",
    "}\n",
    "\n",
    "/* Make sure the interface leaves space for the footer */\n",
    ".body {\n",
    "    margin-bottom: 50px;\n",
    "}\n",
    "\"\"\"\n",
    "\n",
    "# Footer markup (contact line). It is not passed to gr.Interface as an active argument.\n",
    "custom_footer_html = \"\"\"\n",
    "<footer> <p> Reach out to us at {omkar.thawakar, muzammal.naseer}@mbzuai.ac.ae </p> </footer>\n",
    "\"\"\"\n",
    "\n",
    "# Banner markup used as the interface title; styled by the `.custom-header` rule.\n",
    "custom_header_html = \"\"\"\n",
    "<div class='custom-header'>Nutrition-GPT Demo</div>\n",
    "\"\"\"\n",
    "\n",
    "def respond_to_user(image, message):\n",
    "    \"\"\"Answer a text query about an uploaded image.\n",
    "\n",
    "    Builds a `Chat` from the notebook-level `model`, `transform`, `df` and\n",
    "    `tar_img_feats`, encodes the image with it, and returns what `chat.ask` yields.\n",
    "\n",
    "    Args:\n",
    "        image: Value delivered by the image input; `None` when nothing was uploaded.\n",
    "        message: Query text entered by the user.\n",
    "\n",
    "    Returns:\n",
    "        The result of `chat.ask(message)`, or a short hint asking for an image\n",
    "        when `image` is `None`.\n",
    "    \"\"\"\n",
    "    # Gradio invokes the handler with image=None if the user submits without a\n",
    "    # picture; reply with a hint rather than passing None into Chat.encode_image.\n",
    "    if image is None:\n",
    "        return \"Please upload a food image before asking a question.\"\n",
    "\n",
    "    # A new Chat is built for every request.\n",
    "    chat = Chat(model, transform, df, tar_img_feats)\n",
    "    chat.encode_image(image)\n",
    "    return chat.ask(message)\n",
    "\n",
    "# Inputs: an image and a free-text query.\n",
    "food_image = gr.Image(height=\"70%\")\n",
    "query_box = gr.Textbox(label=\"Ask Query\")\n",
    "\n",
    "# Output: a single text box for the reply.\n",
    "answer_box = gr.Textbox(label=\"Nutrition-GPT\")\n",
    "\n",
    "# `title` receives the HTML banner and `css` the custom stylesheet.\n",
    "# No `description=` is set (candidates tried: a usage hint, custom_footer_html).\n",
    "iface = gr.Interface(\n",
    "    fn=respond_to_user,\n",
    "    inputs=[food_image, query_box],\n",
    "    outputs=[answer_box],\n",
    "    title=custom_header_html,\n",
    "    css=custom_css,\n",
    ")\n",
    "\n",
    "# Start the demo server; kept as the last expression of the cell.\n",
    "iface.launch(show_error=True, height=\"650px\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# example_texts = gr.Dataset(components=[gr.Textbox(visible=False)],\n",
    "            # label=\"Prompt Examples\",\n",
    "            # samples=[\n",
    "            #     [\"Provide nutritional information for given food image.\"],\n",
    "            #     [\"What are the nutrients available in given food image.\"],\n",
    "            #     [\"Could you provide a detailed nutritional data of the given food image?\"],\n",
    "            #     [\"Describe the instructions to prepare given food.\"],\n",
    "            #     [\"What are the key ingredients in this food image?\"],\n",
    "            #     [\"Could you highlight the dietary tags for this food image?\"],\n",
    "            # ],)\n",
    "\n",
    "# example_images = gr.Dataset(components=[image], label=\"Food Examples\",\n",
    "#                     samples=[\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000018.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000021.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000035.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000038.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000090.png\")],\n",
    "#                         [os.path.join(os.path.dirname(\"./\"), \"./datasets/sidechef/sample_images/0000122.png\")],\n",
    "#                     ])\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "covr",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}