RitaParadaRamos committed on
Commit
1212b6f
1 Parent(s): cb08135

Upload 2 files

Files changed (2)
  1. app.py +102 -0
  2. requirements.txt +19 -0
app.py ADDED
@@ -0,0 +1,102 @@
+ import json
+ import requests
+
+ import gradio as gr
+ import torch
+ import numpy as np
+ import clip
+ import faiss
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import ViTFeatureExtractor, AutoTokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM
+ from transformers.models.auto.configuration_auto import AutoConfig
+
+ from src.vision_encoder_decoder import SmallCap, SmallCapConfig
+ from src.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel
+ from src.opt import ThisOPTConfig, ThisOPTForCausalLM
+ from src.utils import prep_strings, postprocess_preds
+ from src.retrieve_caps import *
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # load feature extractor
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # load and configure tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125M")
+ tokenizer.pad_token = '!'
+ tokenizer.eos_token = '.'
+
+ # load model (GPT-2 decoder variant, kept for reference)
+ # AutoConfig.register("this_gpt2", ThisGPT2Config)
+ # AutoModel.register(ThisGPT2Config, ThisGPT2LMHeadModel)
+ # AutoModelForCausalLM.register(ThisGPT2Config, ThisGPT2LMHeadModel)
+ # AutoConfig.register("smallcap", SmallCapConfig)
+ # AutoModel.register(SmallCapConfig, SmallCap)
+ # model = AutoModel.from_pretrained("Yova/SmallCap7M")
+
+ # load model (OPT decoder variant used by this demo)
+ AutoConfig.register("this_opt", ThisOPTConfig)
+ AutoModel.register(ThisOPTConfig, ThisOPTForCausalLM)
+ AutoModelForCausalLM.register(ThisOPTConfig, ThisOPTForCausalLM)
+ AutoConfig.register("smallcap", SmallCapConfig)
+ AutoModel.register(SmallCapConfig, SmallCap)
+ model = AutoModel.from_pretrained("Yova/SmallCapOPT7M")
+ model = model.to(device)
+
+ # prompt template used to condition the decoder
+ template = open('src/template.txt').read().strip() + ' '
+
+ # load the caption datastore and the FAISS index used for retrieval
+ captions = json.load(open('datastore/coco_index_captions.json'))
+ retrieval_model, feature_extractor_retrieval = clip.load("RN50x64", device=device)
+ retrieval_index = faiss.read_index('datastore/coco_index')
+ # res = faiss.StandardGpuResources()
+ # retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index)
+
+ # Download human-readable labels for ImageNet (unused leftover from the
+ # Gradio image-classification template).
+ response = requests.get("https://git.io/JJkYN")
+ labels = response.text.split("\n")
+
+
+ def retrieve_caps(image_embedding, index, k=4):
+     """Return the indices of the k nearest captions in the datastore."""
+     xq = image_embedding.astype(np.float32)
+     faiss.normalize_L2(xq)
+     D, I = index.search(xq, k)
+     return I
+
+
+ def classify_image(image):
+     inp = transforms.ToTensor()(image)  # unused leftover from the Gradio template
+
+     # embed the image with the retrieval encoder and fetch the nearest captions
+     pixel_values_retrieval = feature_extractor_retrieval(image).to(device)
+     with torch.no_grad():
+         image_embedding = retrieval_model.encode_image(pixel_values_retrieval.unsqueeze(0)).cpu().numpy()
+
+     nns = retrieve_caps(image_embedding, retrieval_index)[0]
+     caps = [captions[i] for i in nns][:4]
+
+     # prepare prompt
+     decoder_input_ids = prep_strings('', tokenizer, template=template, retrieved_caps=caps, k=4, is_test=True)
+
+     # generate caption
+     pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
+     with torch.no_grad():
+         pred = model.generate(pixel_values.to(device),
+                               decoder_input_ids=torch.tensor([decoder_input_ids]).to(device),
+                               max_new_tokens=25, no_repeat_ngram_size=0, length_penalty=0,
+                               min_length=1, num_beams=3, eos_token_id=tokenizer.eos_token_id)
+
+     retrieved_caps = "Retrieved captions:\n{}\n{}\n{}\n{}".format(*caps)
+     return retrieved_caps + "\n\n\nGenerated caption:\n" + str(postprocess_preds(tokenizer.decode(pred[0]), tokenizer))
+
+
+ image = gr.Image(type="pil")
+ textbox = gr.Textbox(placeholder="Retrieved captions and generated caption...", lines=4)
+ title = "SmallCap Demo"
+
+ gr.Interface(
+     fn=classify_image, inputs=image, outputs=textbox, title=title
+ ).launch(share=True)
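
Note on the retrieval step: retrieve_caps L2-normalizes the CLIP image embedding and runs an inner-product search over the FAISS index, i.e. cosine similarity against the datastore captions. Below is a minimal, self-contained sketch of the same pattern; it builds a toy random index rather than loading the real datastore/coco_index, so the dimension, data, and index type are illustrative only.

import faiss
import numpy as np

d = 1024                                           # illustrative; in app.py the dimension is fixed by the RN50x64 CLIP encoder
xb = np.random.rand(1000, d).astype(np.float32)    # stand-in for embedded datastore captions
faiss.normalize_L2(xb)

index = faiss.IndexFlatIP(d)                       # inner product on normalized vectors = cosine similarity
index.add(xb)

xq = np.random.rand(1, d).astype(np.float32)       # stand-in for one image embedding
faiss.normalize_L2(xq)
D, I = index.search(xq, 4)                         # scores and indices of the 4 nearest entries
print(I[0])                                        # in app.py these indices select entries of captions

Against the real datastore, the returned indices are mapped to caption strings and passed to prep_strings, which fills the prompt template read from src/template.txt before generation.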
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ datasets==2.4.0
+ faiss-gpu==1.7.2
+ h5py==3.7.0
+ huggingface-hub==0.8.1
+ pandas==1.4.3
+ Pillow==9.2.0
+ pyarrow==9.0.0
+ pyparsing==3.0.9
+ PyYAML==6.0
+ tokenizers==0.12.1
+ torch==1.12.1
+ torchaudio==0.12.1
+ torchvision==0.13.1
+ tqdm==4.64.0
+ transformers==4.21.1
+ ftfy
+ regex
+ tqdm
+ git+https://github.com/openai/CLIP.git
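
For local testing outside the Space, this pinned stack should install with pip install -r requirements.txt, and the demo starts with python app.py (app.py calls .launch(share=True) at module level). faiss-gpu==1.7.2 targets CUDA environments; on CPU-only machines faiss-cpu of the same version is the usual substitute, and the GPU index transfer in app.py is commented out anyway. A small, hedged sanity check that the expected versions are the ones actually imported:

# Hedged sanity check: confirms the pinned core libraries from requirements.txt
# are importable and reports whether the "cuda" path in app.py will be taken.
import torch
import transformers
import faiss  # provided by faiss-gpu (or faiss-cpu on CPU-only setups)
import clip   # installed from the CLIP git repository

print(torch.__version__)          # expected: 1.12.1
print(transformers.__version__)   # expected: 4.21.1
print(torch.cuda.is_available())  # True -> app.py runs the model on the GPU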