|
from transformers import ViTConfig, FlaxViTModel, GPT2Config, FlaxGPT2Model, FlaxAutoModelForVision2Seq, FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer |
|
|
|
|
|
# Tiny ViT encoder hyper-parameters — deliberately small so the script runs
# as a fast smoke test rather than a real training setup.
hidden_size = 8
num_hidden_layers = 2
num_attention_heads = 2
intermediate_size = 16

# Matching tiny GPT-2 decoder hyper-parameters.
n_embd = 8
n_layer = 2
n_head = 2
n_inner = 16

encoder_config = ViTConfig(
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
)
decoder_config = GPT2Config(
    n_embd=n_embd,
    n_layer=n_layer,
    n_head=n_head,
    n_inner=n_inner,
)

# Instantiate randomly-initialised models and write each one to disk so the
# combined vision-encoder-decoder can be assembled from the checkpoints below.
encoder = FlaxViTModel(encoder_config)
decoder = FlaxGPT2Model(decoder_config)
encoder.save_pretrained("./encoder-decoder/encoder")
decoder.save_pretrained("./encoder-decoder/decoder")
|
|
|
# Assemble the two locally-saved checkpoints into a single
# vision-encoder-decoder model, round-trip it through disk, and reload it via
# the Auto API to confirm the saved config is self-contained.
# (Fixed: variable was misspelled "enocder_decoder".)
encoder_decoder = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "./encoder-decoder/encoder",
    "./encoder-decoder/decoder",
)
encoder_decoder.save_pretrained("./encoder-decoder")
encoder_decoder = FlaxAutoModelForVision2Seq.from_pretrained("./encoder-decoder")

# Combined config: top-level encoder/decoder settings plus nested
# `config.encoder` / `config.decoder` sub-configs.
config = encoder_decoder.config
|
|
|
def _resolve_token_id(cfg, attr):
    """Return ``cfg.<attr>``, falling back to ``cfg.decoder.<attr>`` when unset.

    Uses ``is None`` checks throughout: a token id of 0 is a perfectly valid
    id, and the previous truthiness tests (``if not token_id``) would have
    wrongly treated it as missing.
    """
    value = getattr(cfg, attr, None)
    if value is None and getattr(cfg, "decoder", None) is not None:
        value = getattr(cfg.decoder, attr, None)
    return value


decoder_start_token_id = _resolve_token_id(config, "decoder_start_token_id")
bos_token_id = _resolve_token_id(config, "bos_token_id")
eos_token_id = _resolve_token_id(config, "eos_token_id")
pad_token_id = _resolve_token_id(config, "pad_token_id")

# Conventional fallbacks: generation starts at BOS when no dedicated start
# token exists, and GPT-2 has no pad token, so padding reuses EOS.
if decoder_start_token_id is None:
    decoder_start_token_id = bos_token_id
if pad_token_id is None:
    pad_token_id = eos_token_id

# Mirror the resolved ids onto both the combined config and the decoder
# sub-config so generation and loss masking agree wherever they look.
config.decoder_start_token_id = decoder_start_token_id
config.bos_token_id = bos_token_id
config.eos_token_id = eos_token_id
config.pad_token_id = pad_token_id

if getattr(config, "decoder", None) is not None:
    config.decoder.decoder_start_token_id = decoder_start_token_id
    config.decoder.bos_token_id = bos_token_id
    config.decoder.eos_token_id = eos_token_id
    config.decoder.pad_token_id = pad_token_id
|
|
|
# Pair the model with its pre-processing components: the feature extractor
# matches the ViT encoder, the tokenizer matches the GPT-2 decoder.
fe = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# GPT-2 ships without a pad token; reuse whichever id the config resolved.
# Guard against a still-unset pad id so convert_ids_to_tokens is never handed
# None (which would raise an opaque TypeError inside the tokenizer).
if config.pad_token_id is not None:
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
else:
    tokenizer.pad_token = tokenizer.eos_token

# Save alongside the corresponding sub-model checkpoints.
fe.save_pretrained("./encoder-decoder/encoder")
tokenizer.save_pretrained("./encoder-decoder/decoder")
|
|
|
# Smoke test: tokenize a couple of target captions the way a training loop
# would, padded/truncated to a fixed length.
targets = ['i love dog', 'you cat is very cute']

encode_kwargs = dict(
    max_length=8,
    padding="max_length",
    truncation=True,
    return_tensors="np",
)
with tokenizer.as_target_tokenizer():
    labels = tokenizer(targets, **encode_kwargs)

print(labels)
|
|