import torch
import torch.nn as nn
import wandb
import streamlit as st
import os
import clip
import pytorch_lightning as pl
from transformers import GPT2Tokenizer, GPT2LMHeadModel


class ImageEncoder(nn.Module):
    # Frozen CLIP image encoder that returns L2-normalized image features
    def __init__(self, base_network):
        super(ImageEncoder, self).__init__()
        self.base_network = base_network
        self.embedding_size = self.base_network.token_embedding.weight.shape[1]

    def forward(self, images):
        # CLIP stays frozen: encode, normalize, and cast to float32
        with torch.no_grad():
            x = self.base_network.encode_image(images)
            x = x / x.norm(dim=1, keepdim=True)
            x = x.float()
        return x


class Mapping(nn.Module):
    # Map the feature vector from the CLIP encoder into GPT-2's embedding space
    def __init__(self, clip_embedding_size, gpt_embedding_size, length=30):  # length: number of prefix tokens
        super(Mapping, self).__init__()
        self.clip_embedding_size = clip_embedding_size
        self.gpt_embedding_size = gpt_embedding_size
        self.length = length
        self.fc1 = nn.Linear(clip_embedding_size, gpt_embedding_size * length)

    def forward(self, x):
        x = self.fc1(x)
        return x.view(-1, self.length, self.gpt_embedding_size)


class TextDecoder(nn.Module):
    # GPT-2 language model used as the caption decoder
    def __init__(self, base_network):
        super(TextDecoder, self).__init__()
        self.base_network = base_network
        self.embedding_size = self.base_network.transformer.wte.weight.shape[1]
        self.vocab_size = self.base_network.transformer.wte.weight.shape[0]

    def forward(self, concat_embedding, mask=None):
        return self.base_network(inputs_embeds=concat_embedding, attention_mask=mask)

    def get_embedding(self, texts):
        # Look up GPT-2 token embeddings for the given token ids
        return self.base_network.transformer.wte(texts)


class ImageCaptioner(pl.LightningModule):
    def __init__(self, clip_model, gpt_model, tokenizer, total_steps, max_length=20):
        super(ImageCaptioner, self).__init__()
        self.padding_token_id = tokenizer.pad_token_id
        # self.stop_token_id = tokenizer.encode('.')[0]

        # Define networks
        self.clip = ImageEncoder(clip_model)
        self.gpt = TextDecoder(gpt_model)
        self.mapping_network = Mapping(self.clip.embedding_size, self.gpt.embedding_size, max_length)

        # Define variables
        self.total_steps = total_steps
        self.max_length = max_length
        self.clip_embedding_size = self.clip.embedding_size
        self.gpt_embedding_size = self.gpt.embedding_size
        self.gpt_vocab_size = self.gpt.vocab_size

    def forward(self, images, texts, masks):
        # Project the CLIP image features into a prefix of GPT-2 embeddings,
        # prepend it to the caption token embeddings, and decode with GPT-2
        texts_embedding = self.gpt.get_embedding(texts)
        images_embedding = self.clip(images)
        images_projection = self.mapping_network(images_embedding).view(-1, self.max_length, self.gpt_embedding_size)
        embedding_concat = torch.cat((images_projection, texts_embedding), dim=1)
        out = self.gpt(embedding_concat, masks)
        return out


# @st.cache_resource
# def download_trained_model():
#     wandb.init(anonymous="must")
#     api = wandb.Api()
#     artifact = api.artifact('hungchiehwu/CLIP-L14_GPT/model-ql03493w:v3')
#     artifact_dir = artifact.download()
#     wandb.finish()
#     return artifact_dir


@st.cache_resource
def load_clip_model():
    # Load the pretrained CLIP ViT-L/14 backbone and its image preprocessing transform
    clip_model, image_transform = clip.load("ViT-L/14", device="cpu")
    return clip_model, image_transform


@st.cache_resource
def load_gpt_model():
    # Load the pretrained GPT-2 decoder and tokenizer; reuse EOS as the padding token
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    return gpt_model, tokenizer


@st.cache_resource
def load_model():
    # Path to the fine-tuned checkpoint downloaded from wandb
    artifact_dir = "./artifacts/model-ql03493w:v3"
    PATH = f"{os.getcwd()}/{artifact_dir[2:]}/model.ckpt"

    # Load pretrained GPT-2 and CLIP models
    clip_model, image_transform = load_clip_model()
    gpt_model, tokenizer = load_gpt_model()

    # Load the fine-tuned weights into the captioning model
    print(PATH)
    model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
    checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["state_dict"])
    return model, image_transform, tokenizer
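

# Usage sketch (not part of the original file): one way to wire the cached
# model into the Streamlit UI. The greedy-decoding loop below is an assumed,
# minimal approach to turning the projected CLIP prefix into a caption; the
# deployed app may decode differently.
#
# from PIL import Image
#
# model, image_transform, tokenizer = load_model()
# model.eval()
#
# uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
# if uploaded_file is not None:
#     image = Image.open(uploaded_file).convert("RGB")
#     st.image(image)
#
#     image_tensor = image_transform(image).unsqueeze(0)  # (1, 3, 224, 224)
#     with torch.no_grad():
#         # Project the image into a prefix of GPT-2 embeddings
#         prefix = model.mapping_network(model.clip(image_tensor))  # (1, max_length, gpt_embedding_size)
#         generated, tokens = prefix, []
#         for _ in range(model.max_length):
#             logits = model.gpt.base_network(inputs_embeds=generated).logits
#             next_token = logits[:, -1, :].argmax(dim=-1)  # greedy choice
#             if next_token.item() == tokenizer.eos_token_id:
#                 break
#             tokens.append(next_token.item())
#             next_embedding = model.gpt.get_embedding(next_token).unsqueeze(1)
#             generated = torch.cat((generated, next_embedding), dim=1)
#
#     st.write(tokenizer.decode(tokens, skip_special_tokens=True))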