from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from uniperceiver.config import configurable
from uniperceiver.functional import load_vocab, decode_sequence, decode_sequence_bert
# from uniperceiver.tokenization import BertTokenizer
from uniperceiver.tokenization import ClipTokenizer


class DecodeStrategy(nn.Module, metaclass=ABCMeta):
    """Base class for sequence decoding strategies (e.g. beam search).

    Subclasses implement `_forward`, which must return a dict carrying the
    generated token ids under the key "G_SENTS_IDS".
    """

    @configurable
    def __init__(
        self,
        *,
        vocab_path,
        vocab_name,
        beam_size,
        max_seq_len,
        tokenizer,
        bos_token_id,
        eos_token_id,
        spe_token_id=None,
        fp16=False,
        cfg=None,
    ):
        super().__init__()
        self.beam_size = beam_size
        # Without a tokenizer, fall back to a plain id-to-word vocabulary
        # file; otherwise the tokenizer itself handles decoding.
        if tokenizer is None:
            self.vocab = load_vocab(vocab_path)
        else:
            self.vocab = None
        if len(vocab_name) > 1:
            raise NotImplementedError("Only support caption inference on a single vocabulary!")
        else:
            self.vocab_name = vocab_name[0]
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenizer
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.spe_token_id = spe_token_id
        self.fp16 = fp16
        self.cfg = cfg
        # 0.0 leaves beam scores unnormalized by length.
        self.len_penalty = self.cfg.DECODE_STRATEGY.get('LEN_PENALTY', 0.0)
    @classmethod
    def from_config(cls, cfg):
        tokenizer_map = {
            # 'BERT': BertTokenizer,
            'CLIP': ClipTokenizer,
            # Without this entry the CLIP_CAPTION branch below is unreachable,
            # since `tokenizer_cls` would be None for that vocab.
            'CLIP_CAPTION': ClipTokenizer,
        }
        tokenizer_cls = tokenizer_map.get(cfg.INFERENCE.VOCAB, None)
        spe_token_id = None
        if tokenizer_cls is None:
            tokenizer = None
            bos_token_id = 0
            eos_token_id = 0
        elif cfg.INFERENCE.VOCAB == 'CLIP':
            tokenizer = tokenizer_cls()
            bos_token_id = tokenizer.vocab['<|startoftext|>']
            eos_token_id = tokenizer.vocab['<|endoftext|>']
            spe_token_id = tokenizer.vocab['<|spe|>']
        elif cfg.INFERENCE.VOCAB == 'CLIP_CAPTION':
            tokenizer = tokenizer_cls()
            bos_token_id = tokenizer.vocab['<|gen|>']
            eos_token_id = tokenizer.vocab['<|endoftext|>']
        else:
            tokenizer = tokenizer_cls.from_pretrained(
                cfg.MODEL.PRETRAINING.MODEL_NAME,
                do_lower_case=cfg.MODEL.PRETRAINING.DO_LOWER_CASE,
            )
            if cfg.INFERENCE.VOCAB == 'BERT':
                raise NotImplementedError
            bos_token_id = tokenizer.vocab["[CLS]"]
            eos_token_id = tokenizer.vocab["[SEP]"]
        return {
            "vocab_path": cfg.INFERENCE.VOCAB,
            "vocab_name": cfg.DATASETS.TARGET_SET,
            "beam_size": cfg.DECODE_STRATEGY.BEAM_SIZE,
            "max_seq_len": cfg.MODEL.EVAL_MAX_SEQ_LEN if 'EVAL_MAX_SEQ_LEN' in cfg.MODEL else cfg.MODEL.MAX_SEQ_LEN,
            "tokenizer": tokenizer,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
            "spe_token_id": spe_token_id,
            "cfg": cfg,
            # "fp16": cfg.SOLVER.AMP_FP16,
        }
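
    # Config keys consumed by `from_config` / `__init__` (as read from this file):
    #   cfg.INFERENCE.VOCAB               -- tokenizer choice ('CLIP', 'CLIP_CAPTION', ...)
    #   cfg.DATASETS.TARGET_SET           -- list holding a single vocabulary name
    #   cfg.DECODE_STRATEGY.BEAM_SIZE     -- beam width
    #   cfg.DECODE_STRATEGY.LEN_PENALTY   -- optional, defaults to 0.0
    #   cfg.MODEL.EVAL_MAX_SEQ_LEN / cfg.MODEL.MAX_SEQ_LEN  -- generation length cap
    #   cfg.MODEL.PRETRAINING.MODEL_NAME, cfg.MODEL.PRETRAINING.DO_LOWER_CASE
    #       -- only for the pretrained-tokenizer fallback branch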
    @abstractmethod
    def _forward(self, batched_inputs, model):
        pass

    def forward(self, batched_inputs, output_sents, model):
        ret = self._forward(batched_inputs, model)
        if output_sents:
            # Decode ids to text via the vocabulary file if one was loaded,
            # otherwise via the tokenizer.
            if self.vocab:
                sents = decode_sequence(self.vocab, ret["G_SENTS_IDS"])
            else:
                sents = decode_sequence_bert(self.tokenizer, ret["G_SENTS_IDS"], self.eos_token_id)
            ret.update({"output": sents})
        return ret
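

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream code): a minimal concrete
# subclass showing the contract `_forward` must satisfy -- return a dict with
# the generated ids under "G_SENTS_IDS", which `forward` then decodes. The
# `model(batched_inputs, seq)` call and the "batch_size" key are hypothetical
# stand-ins for the real Uni-Perceiver model interface; the strategies in
# this repo implement beam search against the actual API instead.
# ---------------------------------------------------------------------------
class GreedyDecodeStrategy(DecodeStrategy):
    def _forward(self, batched_inputs, model):
        batch_size = batched_inputs["batch_size"]  # assumed key, for illustration
        device = next(model.parameters()).device
        # Start every sequence with the BOS token.
        seq = torch.full((batch_size, 1), self.bos_token_id, dtype=torch.long, device=device)
        for _ in range(self.max_seq_len - 1):
            # Hypothetical call: next-token logits of shape (batch, vocab).
            logits = model(batched_inputs, seq)
            next_token = logits.argmax(dim=-1, keepdim=True)
            seq = torch.cat([seq, next_token], dim=1)
            # Stop early once every sequence has just emitted EOS.
            if (next_token == self.eos_token_id).all():
                break
        return {"G_SENTS_IDS": seq[:, 1:]}  # drop the leading BOS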