zamborg committed on
Commit 7925ce3 · 1 Parent(s): 53c2d82
.gitattributes CHANGED
@@ -15,7 +15,6 @@
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ *.pth
+
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: Redcaps Dev
- emoji: 📚
+ title: Redcaps
+ emoji: 📊
  colorFrom: red
- colorTo: red
+ colorTo: pink
  sdk: streamlit
  app_file: app.py
  pinned: false
__pycache__/model.cpython-37.pyc ADDED
Binary file (4.7 kB)
 
app.py ADDED
@@ -0,0 +1,139 @@
+ import streamlit as st
+ import io
+ import sys
+ import time
+ import json
+ sys.path.append("./virtex/")
+ from model import *
+
+ # TODO:
+ # - Reformat the model introduction
+ # - Center the images using the 3 column method
+ # - Make the iterative text generation
+
+ def gen_show_caption(sub_prompt=None, cap_prompt=""):
+     with st.spinner("Generating Caption"):
+         if sub_prompt is None and cap_prompt != "":
+             st.write("Without a specified subreddit we default to /r/pics")
+         subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
+         st.markdown(
+             f"""
+             <style>
+                 red{{
+                     color:#c62828
+                 }}
+                 mono{{
+                     font-family: "Inconsolata";
+                 }}
+             </style>
+
+             ### <red> r/{subreddit} </red> {caption}
+             """,
+             unsafe_allow_html=True)
+
+
+ st.title("Image Captioning Demo from RedCaps")
+ st.sidebar.markdown(
+     """
+     ### Image Captioning Model from VirTex trained on RedCaps
+
+     Use this page to caption your own images or try out some of our samples.
+     You can also generate captions as if they are from specific subreddits,
+     as if they start with a particular prompt, or even both.
+
+     Share your results on Twitter with #redcaps or with a friend.
+     """
+ )
+
+ with st.spinner("Loading Model"):
+     virtexModel, imageLoader, sample_images, valid_subs = create_objects()
+
+ # staggered = st.sidebar.checkbox("Iteratively Generate Captions")
+ # if staggered:
+ #     pass
+ # else:
+
+ select_idx = None
+
+ st.sidebar.title("Select a sample image")
+
+ if st.sidebar.button("Random Sample Image"):
+     select_idx = get_rand_idx(sample_images)
+
+ sample_image = sample_images[0 if select_idx is None else select_idx]
+
+ uploaded_image = None
+ # with st.sidebar.form("file-uploader-form", clear_on_submit=True):
+ uploaded_file = st.sidebar.file_uploader("Choose a file")
+ # submitted = st.form_submit_button("Submit")
+ if uploaded_file is not None:  # and submitted:
+     uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
+     select_idx = None  # set this to help rewrite the cache
+
+ # class OnChange():
+ #     def __init__(self, idx):
+ #         self.idx = idx
+ #
+ #     def __call__(self):
+ #         st.write(f"the idx is: {self.idx}")
+ #         st.write(f"the sample_image is {sample_image}")
+
+ # sample_image = st.sidebar.selectbox(
+ #     "",
+ #     sample_images,
+ #     index=0 if select_idx is None else select_idx,
+ #     on_change=OnChange(0 if select_idx is None else select_idx)
+ # )
+
+ st.sidebar.title("Select a Subreddit")
+ sub = st.sidebar.selectbox(
+     "Type below to condition on a subreddit. Select None for a predicted subreddit",
+     valid_subs
+ )
+
+ st.sidebar.title("Write a Custom Prompt")
+ cap_prompt = st.sidebar.text_input(
+     "Write the start of your caption below",
+     value=""
+ )
+
+ _ = st.sidebar.button("Regenerate Caption")
+
+ advanced = st.sidebar.checkbox("Advanced Options")
+ num_captions = 1
+ if advanced:
+     num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1, 2, 3, 4, 5], value=1)
+     nuc_size = st.sidebar.slider("Nucleus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
+     virtexModel.model.decoder.nucleus_size = nuc_size
+
+ if False:  # uploaded_image is None: # and submitted:
+     st.write("Please select a file to upload")
+ else:
+     image_file = sample_image
+
+     # LOAD AND CACHE THE IMAGE
+     if uploaded_image is not None:
+         image = uploaded_image
+     elif select_idx is None and 'image' in st.session_state:
+         image = st.session_state['image']
+     else:
+         image = Image.open(image_file)
+
+     image = image.convert("RGB")
+
+     st.session_state['image'] = image
+
+     image_dict = imageLoader.transform(image)
+
+     show_image = imageLoader.show_resize(image)
+
+     show = st.image(show_image)
+     show.image(show_image, "Your Image")
+
+     for i in range(num_captions):
+         gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
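
Since app.py is a Streamlit script, it is launched with `streamlit run app.py` rather than `python app.py`. For readers who want the captioning path without the UI, here is a minimal headless sketch; it only rewires the helpers defined in this commit's model.py and assumes the checkpoint files download successfully:

```python
# Hedged sketch: the same wiring app.py does, minus the Streamlit widgets.
# All names (download_files, create_objects, Image) come from model.py below.
import sys
sys.path.append("./virtex/")
from model import *  # mirrors app.py's import style

download_files()  # fetch config.yaml, checkpoint_last5.pth, subreddit_list.json
virtexModel, imageLoader, sample_images, valid_subs = create_objects()

image = Image.open(sample_images[0]).convert("RGB")
image_dict = imageLoader.transform(image)

# sub_prompt=None lets the model predict the subreddit itself
subreddit, caption = virtexModel.predict(image_dict)
print(f"r/{subreddit} {caption}")
```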
config.yaml ADDED
@@ -0,0 +1,63 @@
+ AMP: true
+ CUDNN_BENCHMARK: true
+ CUDNN_DETERMINISTIC: false
+ DATA:
+   EOS_INDEX: 2
+   IMAGE_CROP_SIZE: 224
+   IMAGE_TRANSFORM_TRAIN:
+   - random_resized_crop
+   - horizontal_flip
+   - color_jitter
+   - normalize
+   IMAGE_TRANSFORM_VAL:
+   - smallest_resize
+   - center_crop
+   - normalize
+   MASKED_LM:
+     MASK_PROBABILITY: 0.85
+     MASK_PROPORTION: 0.15
+     REPLACE_PROBABILITY: 0.1
+   MASK_INDEX: 3
+   MAX_CAPTION_LENGTH: 50
+   ROOT: datasets/redcaps/tarfiles/*.tar
+   SOS_INDEX: 1
+   TOKENIZER_MODEL: datasets/common_30k.model
+   UNK_INDEX: 0
+   USE_PERCENTAGE: 100.0
+   USE_SINGLE_CAPTION: false
+   VOCAB_SIZE: 30000
+ MODEL:
+   DECODER:
+     BEAM_SIZE: 5
+     MAX_DECODING_STEPS: 30
+     NAME: nucleus_sampling
+     NUCLEUS_SIZE: 0.9
+   LABEL_SMOOTHING: 0.1
+   NAME: virtex_web
+   TEXTUAL:
+     DROPOUT: 0.1
+     NAME: transdec_prenorm::L6_H512_A8_F2048
+   VISUAL:
+     FEATURE_SIZE: 2048
+     FROZEN: false
+     NAME: torchvision::resnet50
+     PRETRAINED: false
+ OPTIM:
+   BATCH_SIZE: 256
+   CLIP_GRAD_NORM: 10.0
+   CNN_LR: 0.0005
+   LOOKAHEAD:
+     ALPHA: 0.5
+     STEPS: 5
+     USE: false
+   LR: 0.0005
+   LR_DECAY_NAME: cosine
+   LR_GAMMA: 0.1
+   LR_STEPS: []
+   NO_DECAY: .*textual.(embedding|transformer).*(norm.*|bias)
+   NUM_ITERATIONS: 1500000
+   OPTIMIZER_NAME: adamw
+   SGD_MOMENTUM: 0.9
+   WARMUP_STEPS: 10000
+   WEIGHT_DECAY: 0.01
+ RANDOM_SEED: 0
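
This file is the serialized VirTex pretraining config; model.py below consumes it through virtex's Config and factory classes. A minimal sketch of that load path, using only the classes model.py already imports:

```python
# Hedged sketch of how config.yaml is consumed (see model.py below).
from virtex.config import Config
from virtex.factories import TokenizerFactory, PretrainingModelFactory

config = Config("config.yaml")                        # wraps the YAML above
tokenizer = TokenizerFactory.from_config(config)      # datasets/common_30k.model
model = PretrainingModelFactory.from_config(config)   # ResNet-50 + transformer decoder
model = model.to("cpu").eval()                        # weights loaded separately via CheckpointManager
```

Note that MODEL.DECODER.NAME is nucleus_sampling with NUCLEUS_SIZE 0.9; this is the same decoder field the app's "Advanced Options" slider mutates at run time.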
datasets/common_30k.model ADDED
Binary file (748 kB)
 
datasets/tmp ADDED
File without changes
experiment.ipynb ADDED
File without changes
model.py ADDED
@@ -0,0 +1,149 @@
+ import streamlit as st
+ from huggingface_hub import hf_hub_url, cached_download
+ from PIL import Image
+ import os
+ import json
+ import glob
+ import random
+ from typing import Any, Dict, List
+ import torch
+ import torchvision
+
+ import wordsegment as ws
+
+ from virtex.config import Config
+ from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
+ from virtex.utils.checkpointing import CheckpointManager
+
+ CONFIG_PATH = "config.yaml"
+ MODEL_PATH = "checkpoint_last5.pth"
+ VALID_SUBREDDITS_PATH = "subreddit_list.json"
+ SAMPLES_PATH = "./samples/*.jpg"
+
+ class ImageLoader:
+     def __init__(self):
+         self.image_transform = torchvision.transforms.Compose([
+             torchvision.transforms.ToTensor(),
+             torchvision.transforms.Resize(256),
+             torchvision.transforms.CenterCrop(224),
+             torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))])
+         self.show_size = 500
+
+     def load(self, im_path):
+         im = torch.FloatTensor(self.image_transform(Image.open(im_path))).unsqueeze(0)
+         return {"image": im}
+
+     def raw_load(self, im_path):
+         im = torch.FloatTensor(Image.open(im_path))
+         return {"image": im}
+
+     def transform(self, image):
+         im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0)
+         return {"image": im}
+
+     def text_transform(self, text):
+         # at present just lowercasing:
+         return text.lower()
+
+     def show_resize(self, image):
+         # we resize manually because torchvision==0.8 (not 1.9) lacks a max-size resize
+         image = torchvision.transforms.functional.to_tensor(image)
+         x, y = image.shape[-2:]
+         ratio = float(self.show_size / max((x, y)))
+         image = torchvision.transforms.functional.resize(image, [int(x * ratio), int(y * ratio)])
+         return torchvision.transforms.functional.to_pil_image(image)
+
+
+ class VirTexModel:
+     def __init__(self):
+         self.config = Config(CONFIG_PATH)
+         ws.load()
+         self.device = "cpu"
+         self.tokenizer = TokenizerFactory.from_config(self.config)
+         self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
+         CheckpointManager(model=self.model).load(MODEL_PATH)
+         self.model.eval()
+         self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
+
+     def predict(self, image_dict, sub_prompt=None, prompt=""):
+         if sub_prompt is None:
+             subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
+         else:
+             subreddit_tokens = " ".join(ws.segment(ws.clean(sub_prompt)))
+             subreddit_tokens = (
+                 [self.model.sos_index] +
+                 self.tokenizer.encode(subreddit_tokens) +
+                 [self.tokenizer.token_to_id("[SEP]")]
+             )
+             subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
+
+         if prompt != "":
+             # at present prompts without subreddits will break without this change
+             # TODO FIX
+             cap_tokens = self.tokenizer.encode(prompt)
+             cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
+             subreddit_tokens = subreddit_tokens if sub_prompt is not None else torch.tensor(
+                 (
+                     [self.model.sos_index] +
+                     self.tokenizer.encode("pics") +
+                     [self.tokenizer.token_to_id("[SEP]")]
+                 ), device=self.device).long()
+
+             subreddit_tokens = torch.cat(
+                 [
+                     subreddit_tokens,
+                     torch.tensor([self.tokenizer.token_to_id("[SEP]")], device=self.device).long(),
+                     cap_tokens
+                 ])
+
+         predictions: List[Dict[str, Any]] = []
+
+         is_valid_subreddit = False
+         subreddit, rest_of_caption = "", ""
+         image_dict["decode_prompt"] = subreddit_tokens
+         while not is_valid_subreddit:
+
+             with torch.no_grad():
+                 caption = self.model(image_dict)["predictions"][0].tolist()
+
+             # swap the first [SEP] for "://" so subreddit and caption split cleanly
+             if self.tokenizer.token_to_id("[SEP]") in caption:
+                 sep_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
+                 caption[sep_index] = self.tokenizer.token_to_id("://")
+
+             caption = self.tokenizer.decode(caption)
+
+             if "://" in caption:
+                 subreddit, rest_of_caption = caption.split("://", 1)
+                 subreddit = "".join(subreddit.split())
+                 rest_of_caption = rest_of_caption.strip()
+             else:
+                 subreddit, rest_of_caption = "", caption
+
+             # resample until the model emits a subreddit we recognize
+             is_valid_subreddit = subreddit in self.valid_subs
+
+         return subreddit, rest_of_caption
+
+ def download_files():
+     # download model files from the zamborg/redcaps Hub repo
+     for f in [CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH]:
+         fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
+         os.system(f"cp {fp} ./{f}")
+
+ def get_samples():
+     return glob.glob(SAMPLES_PATH)
+
+ def get_rand_idx(samples):
+     return random.randint(0, len(samples) - 1)
+
+ @st.cache(allow_output_mutation=True)  # allow mutation to update nucleus size
+ def create_objects():
+     sample_images = get_samples()
+     virtexModel = VirTexModel()
+     imageLoader = ImageLoader()
+     valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
+     valid_subs.insert(0, None)
+     return virtexModel, imageLoader, sample_images, valid_subs
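
Two implementation details of predict() worth noting: it rewrites the first [SEP] token to "://" so the decoded string splits cleanly into a subreddit prefix and the caption body, and it resamples in a loop until the predicted subreddit appears in subreddit_list.json. A hedged usage sketch with both conditioning knobs; the subreddit and prompt strings are illustrative values, not anything shipped with this commit:

```python
# Hedged usage sketch of the classes above.
from PIL import Image
from model import VirTexModel, ImageLoader, download_files

download_files()  # pull config, checkpoint, and subreddit list from the Hub
loader = ImageLoader()
virtex = VirTexModel()

image_dict = loader.transform(Image.open("samples/0.jpg").convert("RGB"))
sub, caption = virtex.predict(
    image_dict,
    sub_prompt="itookapicture",  # condition decoding on r/itookapicture
    prompt="the view from",      # generated caption must start with this text
)
print(f"r/{sub}: {caption}")
```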
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ albumentations>=0.5.0
+ Cython>=0.25
+ ftfy==5.8
+ future==0.18.0
+ huggingface-hub==0.1.2
+ lmdb==0.97
+ loguru==0.3.2
+ mypy_extensions==0.4.1
+ lvis==0.5.3
+ numpy>=1.17
+ opencv-python==4.1.2.30
+ sentencepiece>=0.1.90
+ torch==1.7.0
+ torchvision==0.8
+ tqdm>=4.50.0
+ wordsegment==1.3.1
+ whatimage==0.0.3
+ git+git://github.com/facebookresearch/fvcore.git#egg=fvcore
samples/.ipynb_checkpoints/test-checkpoint.jpg ADDED
samples/0.jpg ADDED
samples/1.jpg ADDED
samples/10.jpg ADDED
samples/100.jpg ADDED
samples/11.jpg ADDED
samples/12.jpg ADDED
samples/13.jpg ADDED
samples/14.jpg ADDED
samples/15.jpg ADDED
samples/16.jpg ADDED
samples/17.jpg ADDED
samples/18.jpg ADDED
samples/19.jpg ADDED
samples/2.jpg ADDED
samples/20.jpg ADDED
samples/21.jpg ADDED
samples/22.jpg ADDED
samples/23.jpg ADDED
samples/24.jpg ADDED
samples/25.jpg ADDED
samples/26.jpg ADDED
samples/27.jpg ADDED
samples/28.jpg ADDED
samples/29.jpg ADDED
samples/3.jpg ADDED
samples/30.jpg ADDED
samples/31.jpg ADDED
samples/32.jpg ADDED
samples/33.jpg ADDED
samples/34.jpg ADDED
samples/35.jpg ADDED
samples/36.jpg ADDED
samples/37.jpg ADDED
samples/38.jpg ADDED
samples/39.jpg ADDED
samples/4.jpg ADDED
samples/40.jpg ADDED
samples/41.jpg ADDED