mshukor committed on
Commit
7f260ad
1 Parent(s): 961d6ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -57
app.py CHANGED
@@ -53,52 +53,64 @@ use_cuda = torch.cuda.is_available()
53
  # use fp16 only when GPU is available
54
  use_fp16 = False
55
 
56
- # # download checkpoints
57
- # os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/caption_demo.pt; '
58
- # 'mkdir -p checkpoints; mv caption_demo.pt checkpoints/caption_demo.pt')
59
- # os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/refcoco_demo.pt; '
60
- # 'mkdir -p checkpoints; mv refcoco_demo.pt checkpoints/refcoco_demo.pt')
61
- # os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/general_demo.pt; '
62
- # 'mkdir -p checkpoints; mv general_demo.pt checkpoints/general_demo.pt')
63
 
 
 
64
 
65
- checkpoint_path = 'checkpoints/unival_s2_hs/checkpoint1.pt'
 
 
 
 
 
 
 
 
 
66
 
67
- # # Load ckpt & config for Image Captioning
68
- # caption_overrides={"eval_cider":False, "beam":5, "max_len_b":22, "no_repeat_ngram_size":3, "seed":7, "unnormalized": False,
69
- # "bpe_dir":"utils/BPE", "video_model_path": None,}
70
-
71
- # caption_models, caption_cfg, caption_task = checkpoint_utils.load_model_ensemble_and_task(
72
- # utils.split_paths(checkpoint_path),
73
- # arg_overrides=caption_overrides
74
- # )
75
-
76
- # # Load ckpt & config for Refcoco
77
- # refcoco_overrides = {"bpe_dir":"utils/BPE", "video_model_path": None}
78
-
79
- # refcoco_models, refcoco_cfg, refcoco_task = checkpoint_utils.load_model_ensemble_and_task(
80
- # utils.split_paths(checkpoint_path),
81
- # arg_overrides=refcoco_overrides
82
- # )
83
- # refcoco_cfg.common.seed = 7
84
- # refcoco_cfg.generation.beam = 5
85
- # refcoco_cfg.generation.min_len = 4
86
- # refcoco_cfg.generation.max_len_a = 0
87
- # refcoco_cfg.generation.max_len_b = 4
88
- # refcoco_cfg.generation.no_repeat_ngram_size = 3
89
-
90
- # # Load pretrained ckpt & config for VQA
91
- # parser = options.get_generation_parser()
92
- # input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", f"--path={checkpoint_path}", "--bpe-dir=utils/BPE"]
93
- # args = options.parse_args_and_arch(parser, input_args)
94
- # vqa_cfg = convert_namespace_to_omegaconf(args)
95
- # vqa_task = tasks.setup_task(vqa_cfg.task)
96
- # vqa_models, vqa_cfg = checkpoint_utils.load_model_ensemble(
97
- # utils.split_paths(vqa_cfg.common_eval.path),
98
- # task=vqa_task
99
- # )
 
 
 
 
100
 
101
  # Load pretrained ckpt & config for Generic Interface
 
 
102
  parser = options.get_generation_parser()
103
  input_args = ["", "--task=refcoco", "--beam=10", f"--path={checkpoint_path}", "--bpe-dir=utils/BPE", "--no-repeat-ngram-size=3", "--patch-image-size=384"]
104
  args = options.parse_args_and_arch(parser, input_args)
@@ -113,30 +125,26 @@ general_models, general_cfg = checkpoint_utils.load_model_ensemble(
113
  arg_overrides=overrides
114
  )
115
 
116
- # move models to gpu
117
- # move2gpu(caption_models, caption_cfg)
118
- # move2gpu(refcoco_models, refcoco_cfg)
119
- # move2gpu(vqa_models, vqa_cfg)
120
  move2gpu(general_models, general_cfg)
121
 
122
  # # Initialize generator
123
- # caption_generator = caption_task.build_generator(caption_models, caption_cfg.generation)
124
- # refcoco_generator = refcoco_task.build_generator(refcoco_models, refcoco_cfg.generation)
125
- # vqa_generator = vqa_task.build_generator(vqa_models, vqa_cfg.generation)
126
- # vqa_generator.zero_shot = True
127
- # vqa_generator.constraint_trie = None
128
  general_generator = general_task.build_generator(general_models, general_cfg.generation)
129
 
130
  # Construct image transforms
131
- # caption_transform = construct_transform(caption_cfg.task.patch_image_size)
132
- # refcoco_transform = construct_transform(refcoco_cfg.task.patch_image_size)
133
- # vqa_transform = construct_transform(vqa_cfg.task.patch_image_size)
134
  general_transform = construct_transform(general_cfg.task.patch_image_size)
135
 
136
- # # Text preprocess
137
- # bos_item = torch.LongTensor([caption_task.src_dict.bos()])
138
- # eos_item = torch.LongTensor([caption_task.src_dict.eos()])
139
- # pad_idx = caption_task.src_dict.pad()
140
 
141
  # Text preprocess
142
  bos_item = torch.LongTensor([general_task.src_dict.bos()])
 
53
  # use fp16 only when GPU is available
54
  use_fp16 = False
55
 
56
+ # download checkpoints
57
+ os.system('mkdir -p checkpoints; ')
 
 
 
 
 
58
 
59
+ os.system('wget https://data.isir.upmc.fr/unival/models/unival_s2_hs/checkpoint1.pt; '
60
+ 'mkdir -p checkpoints/unival_s2_hs; mv checkpoint1.pt checkpoints/unival_s2_hs/')
61
 
62
+ os.system('wget https://data.isir.upmc.fr/unival/models/unival_vqa/checkpoint_best.pt; '
63
+ 'mkdir -p checkpoints/unival_vqa; mv checkpoint_best.pt checkpoints/unival_vqa/')
64
+ os.system('wget https://data.isir.upmc.fr/unival/models/unival_caption_stage_1/checkpoint_best_test.pt; '
65
+ 'mkdir -p checkpoints/unival_caption_stage_1; mv checkpoint_best_test.pt checkpoints/unival_caption_stage_1/')
66
+ os.system('wget https://data.isir.upmc.fr/unival/models/unival_refcocog/checkpoint_best.pt; '
67
+ 'mkdir -p checkpoints/unival_refcocog; mv checkpoint_best.pt checkpoints/unival_refcocog/')
68
+
69
+
70
+ # Load ckpt & config for Image Captioning
71
+ checkpoint_path = 'checkpoints/unival_caption_stage_1/checkpoint_best_test.pt'
72
 
73
+ caption_overrides={"eval_cider":False, "beam":5, "max_len_b":22, "no_repeat_ngram_size":3, "seed":7, "unnormalized": False,
74
+ "bpe_dir":"utils/BPE", "video_model_path": None,}
75
+
76
+ caption_models, caption_cfg, caption_task = checkpoint_utils.load_model_ensemble_and_task(
77
+ utils.split_paths(checkpoint_path),
78
+ arg_overrides=caption_overrides
79
+ )
80
+
81
+ # Load ckpt & config for Refcoco
82
+ checkpoint_path = 'checkpoints/unival_refcocog/checkpoint_best.pt'
83
+
84
+ refcoco_overrides = {"bpe_dir":"utils/BPE", "video_model_path": None}
85
+
86
+ refcoco_models, refcoco_cfg, refcoco_task = checkpoint_utils.load_model_ensemble_and_task(
87
+ utils.split_paths(checkpoint_path),
88
+ arg_overrides=refcoco_overrides
89
+ )
90
+ refcoco_cfg.common.seed = 7
91
+ refcoco_cfg.generation.beam = 5
92
+ refcoco_cfg.generation.min_len = 4
93
+ refcoco_cfg.generation.max_len_a = 0
94
+ refcoco_cfg.generation.max_len_b = 4
95
+ refcoco_cfg.generation.no_repeat_ngram_size = 3
96
+
97
+
98
+ # Load pretrained ckpt & config for VQA
99
+ checkpoint_path = 'checkpoints/unival_vqa/checkpoint_best.pt'
100
+
101
+ parser = options.get_generation_parser()
102
+ input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", f"--path={checkpoint_path}", "--bpe-dir=utils/BPE"]
103
+ args = options.parse_args_and_arch(parser, input_args)
104
+ vqa_cfg = convert_namespace_to_omegaconf(args)
105
+ vqa_task = tasks.setup_task(vqa_cfg.task)
106
+ vqa_models, vqa_cfg = checkpoint_utils.load_model_ensemble(
107
+ utils.split_paths(vqa_cfg.common_eval.path),
108
+ task=vqa_task
109
+ )
110
 
111
  # Load pretrained ckpt & config for Generic Interface
112
+ checkpoint_path = 'checkpoints/unival_s2_hs/checkpoint1.pt'
113
+
114
  parser = options.get_generation_parser()
115
  input_args = ["", "--task=refcoco", "--beam=10", f"--path={checkpoint_path}", "--bpe-dir=utils/BPE", "--no-repeat-ngram-size=3", "--patch-image-size=384"]
116
  args = options.parse_args_and_arch(parser, input_args)
 
125
  arg_overrides=overrides
126
  )
127
 
128
+ # move models to gpu
129
+ move2gpu(caption_models, caption_cfg)
130
+ move2gpu(refcoco_models, refcoco_cfg)
131
+ move2gpu(vqa_models, vqa_cfg)
132
  move2gpu(general_models, general_cfg)
133
 
134
  # # Initialize generator
135
+ caption_generator = caption_task.build_generator(caption_models, caption_cfg.generation)
136
+ refcoco_generator = refcoco_task.build_generator(refcoco_models, refcoco_cfg.generation)
137
+ vqa_generator = vqa_task.build_generator(vqa_models, vqa_cfg.generation)
138
+ vqa_generator.zero_shot = True
139
+ vqa_generator.constraint_trie = None
140
  general_generator = general_task.build_generator(general_models, general_cfg.generation)
141
 
142
  # Construct image transforms
143
+ caption_transform = construct_transform(caption_cfg.task.patch_image_size)
144
+ refcoco_transform = construct_transform(refcoco_cfg.task.patch_image_size)
145
+ vqa_transform = construct_transform(vqa_cfg.task.patch_image_size)
146
  general_transform = construct_transform(general_cfg.task.patch_image_size)
147
 
 
 
 
 
148
 
149
  # Text preprocess
150
  bos_item = torch.LongTensor([general_task.src_dict.bos()])