Update app.py
app.py
CHANGED
@@ -53,52 +53,64 @@ use_cuda = torch.cuda.is_available()
 # use fp16 only when GPU is available
 use_fp16 = False
 
-#
-# os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/caption_demo.pt; '
-#           'mkdir -p checkpoints; mv caption_demo.pt checkpoints/caption_demo.pt')
-# os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/refcoco_demo.pt; '
-#           'mkdir -p checkpoints; mv refcoco_demo.pt checkpoints/refcoco_demo.pt')
-# os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/general_demo.pt; '
-#           'mkdir -p checkpoints; mv general_demo.pt checkpoints/general_demo.pt')
+# download checkpoints
+os.system('mkdir -p checkpoints; ')
 
+os.system('wget https://data.isir.upmc.fr/unival/models/unival_s2_hs/checkpoint1.pt; '
+          'mkdir -p checkpoints/unival_s2_hs; mv checkpoint1.pt checkpoints/unival_s2_hs/')
 
+os.system('wget https://data.isir.upmc.fr/unival/models/unival_vqa/checkpoint_best.pt; '
+          'mkdir -p checkpoints/unival_vqa; mv checkpoint_best.pt checkpoints/unival_vqa/')
+os.system('wget https://data.isir.upmc.fr/unival/models/unival_caption_stage_1/checkpoint_best_test.pt; '
+          'mkdir -p checkpoints/unival_caption_stage_1; mv checkpoint_best_test.pt checkpoints/unival_caption_stage_1/')
+os.system('wget https://data.isir.upmc.fr/unival/models/unival_refcocog/checkpoint_best.pt; '
+          'mkdir -p checkpoints/unival_refcocog; mv checkpoint_best.pt checkpoints/unival_refcocog/')
+
+
+# Load ckpt & config for Image Captioning
+checkpoint_path = 'checkpoints/unival_caption_stage_1/checkpoint_best_test.pt'
 
+caption_overrides={"eval_cider":False, "beam":5, "max_len_b":22, "no_repeat_ngram_size":3, "seed":7, "unnormalized": False,
+                   "bpe_dir":"utils/BPE", "video_model_path": None,}
+
+caption_models, caption_cfg, caption_task = checkpoint_utils.load_model_ensemble_and_task(
+    utils.split_paths(checkpoint_path),
+    arg_overrides=caption_overrides
+)
+
+# Load ckpt & config for Refcoco
+checkpoint_path = 'checkpoints/unival_refcocog/checkpoint_best.pt'
+
+refcoco_overrides = {"bpe_dir":"utils/BPE", "video_model_path": None}
+
+refcoco_models, refcoco_cfg, refcoco_task = checkpoint_utils.load_model_ensemble_and_task(
+    utils.split_paths(checkpoint_path),
+    arg_overrides=refcoco_overrides
+)
+refcoco_cfg.common.seed = 7
+refcoco_cfg.generation.beam = 5
+refcoco_cfg.generation.min_len = 4
+refcoco_cfg.generation.max_len_a = 0
+refcoco_cfg.generation.max_len_b = 4
+refcoco_cfg.generation.no_repeat_ngram_size = 3
+
+
+# Load pretrained ckpt & config for VQA
+checkpoint_path = 'checkpoints/unival_vqa/checkpoint_best.pt'
+
+parser = options.get_generation_parser()
+input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", f"--path={checkpoint_path}", "--bpe-dir=utils/BPE"]
+args = options.parse_args_and_arch(parser, input_args)
+vqa_cfg = convert_namespace_to_omegaconf(args)
+vqa_task = tasks.setup_task(vqa_cfg.task)
+vqa_models, vqa_cfg = checkpoint_utils.load_model_ensemble(
+    utils.split_paths(vqa_cfg.common_eval.path),
+    task=vqa_task
+)
 
 # Load pretrained ckpt & config for Generic Interface
+checkpoint_path = 'checkpoints/unival_s2_hs/checkpoint1.pt'
+
 parser = options.get_generation_parser()
 input_args = ["", "--task=refcoco", "--beam=10", f"--path={checkpoint_path}", "--bpe-dir=utils/BPE", "--no-repeat-ngram-size=3", "--patch-image-size=384"]
 args = options.parse_args_and_arch(parser, input_args)
@@ -113,30 +125,26 @@ general_models, general_cfg = checkpoint_utils.load_model_ensemble(
     arg_overrides=overrides
 )
 
+# move models to gpu
+move2gpu(caption_models, caption_cfg)
+move2gpu(refcoco_models, refcoco_cfg)
+move2gpu(vqa_models, vqa_cfg)
 move2gpu(general_models, general_cfg)
 
 # # Initialize generator
+caption_generator = caption_task.build_generator(caption_models, caption_cfg.generation)
+refcoco_generator = refcoco_task.build_generator(refcoco_models, refcoco_cfg.generation)
+vqa_generator = vqa_task.build_generator(vqa_models, vqa_cfg.generation)
+vqa_generator.zero_shot = True
+vqa_generator.constraint_trie = None
 general_generator = general_task.build_generator(general_models, general_cfg.generation)
 
 # Construct image transforms
+caption_transform = construct_transform(caption_cfg.task.patch_image_size)
+refcoco_transform = construct_transform(refcoco_cfg.task.patch_image_size)
+vqa_transform = construct_transform(vqa_cfg.task.patch_image_size)
 general_transform = construct_transform(general_cfg.task.patch_image_size)
 
-# # Text preprocess
-# bos_item = torch.LongTensor([caption_task.src_dict.bos()])
-# eos_item = torch.LongTensor([caption_task.src_dict.eos()])
-# pad_idx = caption_task.src_dict.pad()
 
 # Text preprocess
 bos_item = torch.LongTensor([general_task.src_dict.bos()])