In [13]:
import pandas as pd
import json
from PIL import Image
import numpy as np
import gradio as gr 

In [6]:
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

from src.data.embs import ImageDataset
from src.model.blip_embs import blip_embs

from demo_chat import Chat

In [7]:
from src.data.transforms import transform_test

# Image preprocessing handed to Chat in respond_to_user. The 384 here must stay in
# sync with config["image_size"] from get_blip_config; transform_test presumably
# resizes/normalizes images to that resolution — confirm in src.data.transforms.
transform = transform_test(384)

In [8]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [9]:
def get_blip_config(model="base"):
    """Build the BLIP hyper-parameter dictionary for a given backbone size.

    Args:
        model: ``"base"`` or ``"large"``. Selects the pretrained checkpoint URL,
            ViT variant, batch sizes, gradient-checkpoint layer and initial
            learning rate.

    Returns:
        dict: The size-specific keys ``pretrained``, ``vit``,
        ``batch_size_train``, ``batch_size_test``, ``vit_grad_ckpt``,
        ``vit_ckpt_layer`` and ``init_lr``, followed by the shared keys
        ``image_size``, ``queue_size``, ``alpha``, ``k_test`` and
        ``negative_all_rank``.

    Raises:
        ValueError: If ``model`` is neither ``"base"`` nor ``"large"``. Such a
            name would otherwise yield a config without ``pretrained``/``vit``
            and only fail later with a KeyError.
    """
    if model == "base":
        config = {
            # Must carry no surrounding whitespace: the URL is used verbatim to
            # download the checkpoint.
            "pretrained": (
                "https://storage.googleapis.com/sfr-vision-language-research/"
                "BLIP/models/model_base_capfilt_large.pth"
            ),
            "vit": "base",
            "batch_size_train": 32,
            "batch_size_test": 16,
            "vit_grad_ckpt": True,
            "vit_ckpt_layer": 4,
            "init_lr": 1e-5,
        }
    elif model == "large":
        config = {
            "pretrained": (
                "https://storage.googleapis.com/sfr-vision-language-research/"
                "BLIP/models/model_large_retrieval_coco.pth"
            ),
            "vit": "large",
            "batch_size_train": 16,
            "batch_size_test": 32,
            "vit_grad_ckpt": True,
            "vit_ckpt_layer": 12,
            "init_lr": 5e-6,
        }
    else:
        raise ValueError(
            f"Unknown BLIP model size {model!r}; expected 'base' or 'large'."
        )

    # Settings shared by both backbone sizes.
    config.update(
        image_size=384,
        queue_size=57600,
        alpha=0.4,
        k_test=256,
        negative_all_rank=True,
    )
    return config

In [10]:
print("Creating model")
config = get_blip_config("large")

# blip_embs receives a subset of the config entries as keyword arguments of the
# same name; the module is moved to the selected device and switched to eval mode.
model = blip_embs(
    **{
        key: config[key]
        for key in (
            "pretrained",
            "image_size",
            "vit",
            "vit_grad_ckpt",
            "vit_ckpt_layer",
            "queue_size",
            "negative_all_rank",
        )
    }
).to(device)
model.eval()

Creating model
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth
missing keys:
[]


BLIPEmbs(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1024, out_features=3072, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1024, out_features=1024, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
  

In [11]:
# Recipe metadata table; its "id_" column indexes the pre-computed embeddings.
df = pd.read_json(Path("datasets") / "sidechef" / "my_recipes.json")

In [12]:
print("Loading Target Embedding")

# One pre-computed BLIP embedding file per recipe, named by the recipe's "id_"
# zero-padded to 7 digits. They are stacked in df row order, so row i of
# tar_img_feats corresponds to df.iloc[i].
EMB_DIR = Path("datasets/sidechef/blip-embs-large")

if df.empty:
    raise ValueError("Recipe table is empty; no target embeddings to load.")

# map_location=device lets tensors saved from a GPU load on CPU-only machines
# (the device cell explicitly supports that fallback). It also places them on the
# same device as `model`.
# NOTE(review): confirm Chat expects tar_img_feats on the model's device.
tar_img_feats = torch.stack(
    [
        torch.load(EMB_DIR / "{:07d}.pth".format(_id), map_location=device)
        for _id in df["id_"].tolist()
    ],
    dim=0,
)

Loading Target Embedding


In [46]:


# Custom CSS passed to gr.Interface(css=...): footer styling and the dark banner
# used for the title (custom_header_html).
# NOTE(review): ".body" is a class selector, not the <body> element; it only takes
# effect if some element carries class="body" — confirm this is intended.
custom_css = """
/* Footer style */
.gradio-footer {
    display: flex;
    justify-content: center;
    align-items: center;
    padding: 10px;
    background-color: #f8f9fa;
    color: #333;
    font-size: 0.9em;
}

.custom-header {
    text-align: center;
    padding: 12px;
    background-color: #333; 
    color: white;
    width: 100%;
    font-size: 0.8em;
}

.footer {
    width: 100%;
    background-color: #f2f2f2;
    color: #555;
    text-align: center;
    padding: 10px 0;
    position: absolute;
    bottom: 0;
    left: 0;
}

/* Make sure the interface leaves space for the footer */
.body {
    margin-bottom: 50px;
}
"""

# Contact footer HTML, intended to be injected via the interface description
# (currently not wired into gr.Interface).
custom_footer_html = """
<footer> <p> Reach out to us at {omkar.thawakar, muzammal.naseer}@mbzuai.ac.ae </p> </footer>
"""

# Banner HTML rendered as the interface title; styled by .custom-header above.
custom_header_html = """
<div class='custom-header'>Nutrition-GPT Demo</div>
"""

def respond_to_user(image, message):
    """Answer one user query about an uploaded food image.

    A new ``Chat`` is built for every call from the module-level ``model``,
    ``transform``, ``df`` and ``tar_img_feats``. No conversation state is
    carried between requests.

    Args:
        image: Value produced by the ``gr.Image`` input (a numpy array with
            Gradio's default ``type="numpy"``). It is ``None`` when the user
            submits without uploading an image.
        message: Query text from the ``gr.Textbox`` input.

    Returns:
        str: The reply from ``Chat.ask``, or a prompt to upload an image when
        none was provided.
    """
    # Chat.encode_image is presumably not designed for a missing image
    # (assumption — Chat lives in demo_chat). Answer with guidance instead.
    if image is None:
        return "Please upload a food image before asking a question."

    chat = Chat(model, transform, df, tar_img_feats)
    chat.encode_image(image)
    return chat.ask(message)

# Single-turn UI: an image plus a text query go in, one text answer comes out.
# The title is the HTML banner; the custom CSS styles it. A description or footer
# (custom_footer_html) can be added via `description=` if needed.
iface = gr.Interface(
    fn=respond_to_user,
    title=custom_header_html,
    css=custom_css,
    inputs=[
        gr.Image(height="70%"),
        gr.Textbox(label="Ask Query"),
    ],
    outputs=[
        gr.Textbox(label="Nutrition-GPT"),
    ],
)

# show_error surfaces Python exceptions in the browser; height sizes the inline
# notebook iframe.
iface.launch(height="650px", show_error=True)

Running on local URL:  http://127.0.0.1:7866

To create a public link, set `share=True` in `launch()`.




In [None]:
# example_texts = gr.Dataset(components=[gr.Textbox(visible=False)],
            # label="Prompt Examples",
            # samples=[
            #     ["Provide nutritional information for given food image."],
            #     ["What are the nutrients available in given food image."],
            #     ["Could you provide a detailed nutritional data of the given food image?"],
            #     ["Describe the instructions to prepare given food."],
            #     ["What are the key ingredients in this food image?"],
            #     ["Could you highlight the dietary tags for this food image?"],
            # ],)

# example_images = gr.Dataset(components=[image], label="Food Examples",
#                     samples=[
#                         [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000018.png")],
#                         [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000021.png")],
#                         [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000035.png")],
#                         [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000038.png")],
#                         [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000090.png")],
#                         [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000122.png")],
#                     ])

