import torch
import torch.nn as nn
import wandb
import streamlit as st
import os

import clip
import pytorch_lightning as pl
from transformers import GPT2Tokenizer, GPT2LMHeadModel


class ImageEncoder(nn.Module):
    """Frozen CLIP image encoder that returns normalized image embeddings."""

    def __init__(self, base_network):
        super(ImageEncoder, self).__init__()
        self.base_network = base_network
        # Width of CLIP's token embedding, used as the size of the image features.
        self.embedding_size = self.base_network.token_embedding.weight.shape[1]

    def forward(self, images):
        # CLIP stays frozen: no gradients are tracked through the image encoder.
        with torch.no_grad():
            x = self.base_network.encode_image(images)
            x = x / x.norm(dim=1, keepdim=True)  # L2-normalize each embedding
            x = x.float()
        return x


class Mapping(nn.Module):
    """Projects a CLIP embedding into a prefix of `length` GPT-2 token embeddings."""

    def __init__(self, clip_embedding_size, gpt_embedding_size, length=30):
        super(Mapping, self).__init__()

        self.clip_embedding_size = clip_embedding_size
        self.gpt_embedding_size = gpt_embedding_size
        self.length = length

        # Single linear layer that emits all `length` prefix embeddings at once.
        self.fc1 = nn.Linear(clip_embedding_size, gpt_embedding_size * length)

    def forward(self, x):
        x = self.fc1(x)
        # (batch, length * gpt_embedding_size) -> (batch, length, gpt_embedding_size)
        return x.view(-1, self.length, self.gpt_embedding_size)


class TextDecoder(nn.Module):
    """GPT-2 language model head driven directly by input embeddings."""

    def __init__(self, base_network):
        super(TextDecoder, self).__init__()
        self.base_network = base_network
        # GPT-2's word-token-embedding matrix gives both the embedding size and the vocab size.
        self.embedding_size = self.base_network.transformer.wte.weight.shape[1]
        self.vocab_size = self.base_network.transformer.wte.weight.shape[0]

    def forward(self, concat_embedding, mask=None):
        return self.base_network(inputs_embeds=concat_embedding, attention_mask=mask)

    def get_embedding(self, texts):
        return self.base_network.transformer.wte(texts)


class ImageCaptioner(pl.LightningModule):
    """Maps a CLIP image embedding to a GPT-2 prefix and decodes a caption from it."""

    def __init__(self, clip_model, gpt_model, tokenizer, total_steps, max_length=20):
        super(ImageCaptioner, self).__init__()

        self.padding_token_id = tokenizer.pad_token_id

        self.clip = ImageEncoder(clip_model)
        self.gpt = TextDecoder(gpt_model)
        self.mapping_network = Mapping(self.clip.embedding_size, self.gpt.embedding_size, max_length)

        self.total_steps = total_steps
        self.max_length = max_length
        self.clip_embedding_size = self.clip.embedding_size
        self.gpt_embedding_size = self.gpt.embedding_size
        self.gpt_vocab_size = self.gpt.vocab_size

    def forward(self, images, texts, masks):
        texts_embedding = self.gpt.get_embedding(texts)  # (batch, caption_len, gpt_embedding_size)
        images_embedding = self.clip(images)             # (batch, clip_embedding_size)

        # Project the image embedding into a prefix of `max_length` GPT-2 embeddings,
        # then prepend it to the caption embeddings.
        images_projection = self.mapping_network(images_embedding).view(-1, self.max_length, self.gpt_embedding_size)
        embedding_concat = torch.cat((images_projection, texts_embedding), dim=1)

        out = self.gpt(embedding_concat, masks)
        return out
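

# Example: one way the `masks` argument of ImageCaptioner.forward can be laid out.
# The image prefix contributes `max_length` positions that are always attended to,
# followed by the caption's own padding mask. `build_prefix_mask` is a hypothetical
# helper sketched here for illustration; it is not used elsewhere in this module.
def build_prefix_mask(caption_mask, prefix_length):
    # caption_mask: (batch, caption_len), 1 for real tokens, 0 for padding.
    prefix = torch.ones(caption_mask.shape[0], prefix_length,
                        dtype=caption_mask.dtype, device=caption_mask.device)
    return torch.cat((prefix, caption_mask), dim=1)  # (batch, prefix_length + caption_len)

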
@st.cache_resource
def load_clip_model():
    # Load the pretrained CLIP backbone and its matching image preprocessing pipeline.
    clip_model, image_transform = clip.load("ViT-L/14", device="cpu")

    return clip_model, image_transform


@st.cache_resource
def load_gpt_model():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')

    # GPT-2 has no dedicated padding token, so reuse the end-of-text token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    return gpt_model, tokenizer


@st.cache_resource
def load_model():
    # Checkpoint downloaded from the wandb model artifact.
    artifact_dir = "./artifacts/model-ql03493w:v3"
    PATH = f"{os.getcwd()}/{artifact_dir[2:]}/model.ckpt"

    clip_model, image_transform = load_clip_model()
    gpt_model, tokenizer = load_gpt_model()

    print(PATH)
    model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
    checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["state_dict"])

    return model, image_transform, tokenizer
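

# Example usage (hypothetical): greedy decoding with the loaded model. This is a
# sketch of how the components above could be wired together for inference; the
# name `generate_caption` and the simple greedy loop are illustrative assumptions,
# not part of the original training or app code.
def generate_caption(model, image_transform, tokenizer, pil_image, max_tokens=20):
    model.eval()
    image = image_transform(pil_image).unsqueeze(0)  # (1, 3, H, W)

    with torch.no_grad():
        # Project the CLIP image embedding into a GPT-2 prefix of `max_length` vectors.
        prefix = model.mapping_network(model.clip(image))  # (1, max_length, gpt_embedding_size)

        tokens = []
        embeddings = prefix
        for _ in range(max_tokens):
            logits = model.gpt(embeddings).logits
            next_token = logits[:, -1, :].argmax(dim=-1)  # greedy choice of the next token
            if next_token.item() == tokenizer.eos_token_id:
                break
            tokens.append(next_token.item())
            next_embedding = model.gpt.get_embedding(next_token).unsqueeze(1)
            embeddings = torch.cat((embeddings, next_embedding), dim=1)

    return tokenizer.decode(tokens)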