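# SmallCap Gradio demo: retrieve captions similar to the input image from a COCO
# datastore (CLIP + FAISS) and use them to prompt SmallCap to generate a caption.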
import json

import torch
import gradio as gr
from transformers import AutoTokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM
from transformers.models.auto.configuration_auto import AutoConfig

from src.vision_encoder_decoder import SmallCap, SmallCapConfig
from src.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel
from src.opt import ThisOPTConfig, ThisOPTForCausalLM
from src.utils import prep_strings, postprocess_preds
from src.retrieve_caps import *  # clip, faiss and np used below come in via this wildcard import

device = "cuda" if torch.cuda.is_available() else "cpu"

# load feature extractor
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# load and configure tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = '!'
tokenizer.eos_token = '.'

# register the custom config/model classes so AutoModel can load the checkpoint
# (uncomment the GPT-2 lines below to use the GPT-2-based SmallCap checkpoint instead)
# AutoConfig.register("this_gpt2", ThisGPT2Config)
# AutoModel.register(ThisGPT2Config, ThisGPT2LMHeadModel)
# AutoModelForCausalLM.register(ThisGPT2Config, ThisGPT2LMHeadModel)
# AutoConfig.register("smallcap", SmallCapConfig)
# AutoModel.register(SmallCapConfig, SmallCap)
# model = AutoModel.from_pretrained("Yova/SmallCap7M")

AutoConfig.register("this_opt", ThisOPTConfig)
AutoModel.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoModelForCausalLM.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoConfig.register("smallcap", SmallCapConfig)
AutoModel.register(SmallCapConfig, SmallCap)
model = AutoModel.from_pretrained("Yova/SmallCapOPT7M")

model = model.to(device)

template = open('src/template.txt').read().strip() + ' '

# load the retrieval components: datastore captions, CLIP retrieval encoder and FAISS index
captions = json.load(open('coco_index_captions.json'))
retrieval_model, feature_extractor_retrieval = clip.load("RN50x64", device=device)
retrieval_index = faiss.read_index('coco_index')
# optionally move the index to GPU:
#res = faiss.StandardGpuResources()
#retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index)



def retrieve_caps(image_embedding, index, k=4):
    """Return the indices of the k nearest captions to the image embedding in the FAISS index."""
    xq = image_embedding.astype(np.float32)
    faiss.normalize_L2(xq)
    D, I = index.search(xq, k)
    return I

def generate_caption(image):
    # embed the image with the CLIP retrieval encoder and fetch the nearest captions
    pixel_values_retrieval = feature_extractor_retrieval(image).to(device)
    with torch.no_grad():
        image_embedding = retrieval_model.encode_image(pixel_values_retrieval.unsqueeze(0)).cpu().numpy()

    nns = retrieve_caps(image_embedding, retrieval_index)[0]
    caps = [captions[i] for i in nns][:4]

    # build the decoder prompt from the template and the retrieved captions
    decoder_input_ids = prep_strings('', tokenizer, template=template, retrieved_caps=caps, k=4, is_test=True)

    # encode the image for SmallCap and generate the caption
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pred = model.generate(pixel_values.to(device),
                              decoder_input_ids=torch.tensor([decoder_input_ids]).to(device),
                              max_new_tokens=25, no_repeat_ngram_size=0, length_penalty=0,
                              min_length=1, num_beams=3, eos_token_id=tokenizer.eos_token_id)

    retrieved_caps = "Retrieved captions:\n{}\n{}\n{}\n{}".format(*caps)
    return str(postprocess_preds(tokenizer.decode(pred[0]), tokenizer)) + "\n\n\n" + retrieved_caps

image = gr.Image(type="pil")

textbox = gr.Textbox(placeholder="Generated caption and retrieved captions...", lines=4)

title = "SmallCap Demo"
gr.Interface(
    fn=generate_caption, inputs=image, outputs=textbox, title=title
).launch()