import base64
from base64 import b64decode
from io import BytesIO

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from PIL import Image
from transformers import GPT2TokenizerFast, ViTFeatureExtractor

from config import get_config
from model import build_transformer


def process(model, image, tokenizer, device):
    # Decode a caption for a base64-encoded image and print it
    image = get_image(image)
    model.eval()
    with torch.no_grad():
        encoder_input = image.unsqueeze(0).to(device)  # (1, 3, 224, 224) pixel values
        model_out = greedy_decode(model, encoder_input, None, tokenizer, 196, device)
        model_text = tokenizer.decode(model_out.detach().cpu().tolist())
        print(model_text)


# preprocess a base64-encoded image into ViT pixel values
def get_image(image):
    # Load the ViT feature extractor used to preprocess the image.
    # Note: ViTFeatureExtractor is deprecated in recent transformers releases
    # in favour of ViTImageProcessor; it is kept here for compatibility.
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
    image = Image.open(BytesIO(b64decode(image)))
    if image.mode != 'RGB':
        image = image.convert('RGB')
    enc_input = feature_extractor(image, return_tensors='pt')
    # Drop the batch dimension added by the feature extractor:
    # (1, 3, 224, 224) -> (3, 224, 224). A single squeeze(0) suffices.
    return enc_input['pixel_values'].squeeze(0)


# get tokenizer: the pretrained GPT-2 tokenizer with the special tokens used during training
def get_or_build_tokenizer(config):
    tokenizer = GPT2TokenizerFast.from_pretrained(
        'openai-community/gpt2',
        unk_token='[UNK]',
        bos_token='[SOS]',
        eos_token='[EOS]',
        pad_token='[PAD]',
    )
    return tokenizer


def causal_mask(size):
    # Boolean mask that lets each position attend only to itself and earlier
    # positions: True on and below the diagonal, False above it
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


# get model
def get_model(config, vocab_tgt_len):
    model = build_transformer(vocab_tgt_len, config['seq_len'], d_model=config['d_model'])
    return model


# greedy decode: repeatedly pick the most likely next token until EOS or max_len
def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.convert_tokens_to_ids('[SOS]')
    eos_idx = tokenizer_tgt.convert_tokens_to_ids('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, None)
    # Initialize the decoder input with the SOS token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).long().to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # Build the causal mask for the tokens generated so far
        decoder_mask = causal_mask(decoder_input.size(1)).long().to(device)

        # Run one decoding step
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Project the last hidden state to vocabulary logits and greedily pick
        # the most likely next token (the softmax does not change the argmax;
        # it is kept for readability)
        logits = model.project(out[:, -1])
        probabilities = F.softmax(logits, dim=-1)
        next_word = torch.argmax(probabilities, dim=1)

        # Append the new token and stop once EOS is produced
        decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)
        if next_word.item() == eos_idx:
            break

    return decoder_input.squeeze(0)
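
# A comment in the original decoding loop mentioned temperature, but none is
# applied: the loop above is a pure argmax. If stochastic sampling was the
# intent, a minimal sketch (a hypothetical helper, not wired into
# greedy_decode) would scale the logits by a temperature before drawing:
def sample_next_token(logits, temperature=0.8):
    # temperature < 1 sharpens the distribution, > 1 flattens it
    probabilities = F.softmax(logits / temperature, dim=-1)
    # Draw one token id per batch row instead of taking the argmax
    return torch.multinomial(probabilities, num_samples=1).squeeze(-1)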
def image_base64():
    # Read the validation image from disk and return it as a base64 string
    with open('C:/AI/projects/vision_model_pretrained/validation/content/memory_image_23330.jpg', 'rb') as image_file:
        base64_bytes = base64.b64encode(image_file.read())
    base64_string = base64_bytes.decode()
    return base64_string


def start():
    print('start')
    accelerator = Accelerator()
    device = accelerator.device
    config = get_config()
    tokenizer = get_or_build_tokenizer(config)
    model = get_model(config, len(tokenizer))
    model = accelerator.prepare(model)
    # Restore the trained weights saved with accelerator.save_state
    accelerator.load_state('C:/AI/projects/vision_model_pretrained/Vision_Model_pretrained/models/vision_model_04')
    image = image_base64()
    process(model, image, tokenizer, device)


if __name__ == '__main__':
    start()
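
# To caption a different image with an already-loaded model, any file can be
# encoded the same way image_base64 does and passed straight to process
# (illustrative usage; 'some_image.jpg' is a placeholder path):
#
#   with open('some_image.jpg', 'rb') as f:
#       image = base64.b64encode(f.read()).decode()
#   process(model, image, tokenizer, device)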