Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import copy | |
import numpy as np | |
import torch | |
from .ShowTellModel import ShowTellModel | |
from .FCModel import FCModel | |
from .AttModel import * | |
from .TransformerModel import TransformerModel | |
from .cachedTransformer import TransformerModel as cachedTransformer | |
from .BertCapModel import BertCapModel | |
from .M2Transformer import M2TransformerModel | |
from .AoAModel import AoAModel | |
from .OldModel import ShowAttendTellModel | |
def setup(opt): | |
if opt.caption_model in ['fc', 'show_tell']: | |
print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model) | |
if opt.caption_model == 'fc': | |
print('Use newfc instead of fc') | |
if opt.caption_model == 'fc': | |
model = FCModel(opt) | |
elif opt.caption_model == 'language_model': | |
model = LMModel(opt) | |
elif opt.caption_model == 'newfc': | |
model = NewFCModel(opt) | |
elif opt.caption_model == 'show_tell': | |
model = ShowTellModel(opt) | |
elif opt.caption_model == 'show_attend_tell': | |
model = ShowAttendTellModel(opt) | |
# Att2in model in self-critical | |
elif opt.caption_model == 'att2in': | |
model = Att2inModel(opt) | |
# Att2in model with two-layer MLP img embedding and word embedding | |
elif opt.caption_model == 'att2in2': | |
model = Att2in2Model(opt) | |
elif opt.caption_model == 'att2all2': | |
print('Warning: this is not a correct implementation of the att2all model in the original paper.') | |
model = Att2all2Model(opt) | |
# Adaptive Attention model from Knowing when to look | |
elif opt.caption_model == 'adaatt': | |
model = AdaAttModel(opt) | |
# Adaptive Attention with maxout lstm | |
elif opt.caption_model == 'adaattmo': | |
model = AdaAttMOModel(opt) | |
# Top-down attention model | |
elif opt.caption_model in ['topdown', 'updown']: | |
model = UpDownModel(opt) | |
# StackAtt | |
elif opt.caption_model == 'stackatt': | |
model = StackAttModel(opt) | |
# DenseAtt | |
elif opt.caption_model == 'denseatt': | |
model = DenseAttModel(opt) | |
# Transformer | |
elif opt.caption_model == 'transformer': | |
if getattr(opt, 'cached_transformer', False): | |
model = cachedTransformer(opt) | |
else: | |
model = TransformerModel(opt) | |
# AoANet | |
elif opt.caption_model == 'aoa': | |
model = AoAModel(opt) | |
elif opt.caption_model == 'bert': | |
model = BertCapModel(opt) | |
elif opt.caption_model == 'm2transformer': | |
model = M2TransformerModel(opt) | |
else: | |
raise Exception("Caption model not supported: {}".format(opt.caption_model)) | |
return model | |