# Page-header residue captured together with this file (not part of the code):
# yashonwu / add captioning / 9bf9e42 / raw / history blame / 2.73 kB
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import numpy as np
import torch
from .ShowTellModel import ShowTellModel
from .FCModel import FCModel
from .AttModel import *
from .TransformerModel import TransformerModel
from .cachedTransformer import TransformerModel as cachedTransformer
from .BertCapModel import BertCapModel
from .M2Transformer import M2TransformerModel
from .AoAModel import AoAModel
from .OldModel import ShowAttendTellModel
def setup(opt):
if opt.caption_model in ['fc', 'show_tell']:
print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model)
if opt.caption_model == 'fc':
print('Use newfc instead of fc')
if opt.caption_model == 'fc':
model = FCModel(opt)
elif opt.caption_model == 'language_model':
model = LMModel(opt)
elif opt.caption_model == 'newfc':
model = NewFCModel(opt)
elif opt.caption_model == 'show_tell':
model = ShowTellModel(opt)
elif opt.caption_model == 'show_attend_tell':
model = ShowAttendTellModel(opt)
# Att2in model in self-critical
elif opt.caption_model == 'att2in':
model = Att2inModel(opt)
# Att2in model with two-layer MLP img embedding and word embedding
elif opt.caption_model == 'att2in2':
model = Att2in2Model(opt)
elif opt.caption_model == 'att2all2':
print('Warning: this is not a correct implementation of the att2all model in the original paper.')
model = Att2all2Model(opt)
# Adaptive Attention model from Knowing when to look
elif opt.caption_model == 'adaatt':
model = AdaAttModel(opt)
# Adaptive Attention with maxout lstm
elif opt.caption_model == 'adaattmo':
model = AdaAttMOModel(opt)
# Top-down attention model
elif opt.caption_model in ['topdown', 'updown']:
model = UpDownModel(opt)
# StackAtt
elif opt.caption_model == 'stackatt':
model = StackAttModel(opt)
# DenseAtt
elif opt.caption_model == 'denseatt':
model = DenseAttModel(opt)
# Transformer
elif opt.caption_model == 'transformer':
if getattr(opt, 'cached_transformer', False):
model = cachedTransformer(opt)
else:
model = TransformerModel(opt)
# AoANet
elif opt.caption_model == 'aoa':
model = AoAModel(opt)
elif opt.caption_model == 'bert':
model = BertCapModel(opt)
elif opt.caption_model == 'm2transformer':
model = M2TransformerModel(opt)
else:
raise Exception("Caption model not supported: {}".format(opt.caption_model))
return model