import torch
import torch.nn as nn
import wandb
import streamlit as st
import os
import clip
import pytorch_lightning as pl
from transformers import GPT2Tokenizer, GPT2LMHeadModel


class ImageEncoder(nn.Module):
    # Frozen CLIP image encoder; returns L2-normalized image features
    def __init__(self, base_network):
        super(ImageEncoder, self).__init__()
        self.base_network = base_network
        self.embedding_size = self.base_network.token_embedding.weight.shape[1]

    def forward(self, images):
        with torch.no_grad():
            x = self.base_network.encode_image(images)
            x = x / x.norm(dim=1, keepdim=True)
            x = x.float()

        return x


class Mapping(nn.Module):
    # Map the CLIP image feature to a sequence of GPT-2 embeddings (the caption prefix)
    def __init__(self, clip_embedding_size, gpt_embedding_size, length=30):  # length: prefix sequence length
        super(Mapping, self).__init__()
        self.clip_embedding_size = clip_embedding_size
        self.gpt_embedding_size = gpt_embedding_size
        self.length = length

        self.fc1 = nn.Linear(clip_embedding_size, gpt_embedding_size * length)

    def forward(self, x):
        x = self.fc1(x)

        return x.view(-1, self.length, self.gpt_embedding_size)


class TextDecoder(nn.Module):
    # GPT-2 decoder; consumes concatenated (image prefix + text) embeddings
    def __init__(self, base_network):
        super(TextDecoder, self).__init__()
        self.base_network = base_network
        self.embedding_size = self.base_network.transformer.wte.weight.shape[1]
        self.vocab_size = self.base_network.transformer.wte.weight.shape[0]

    def forward(self, concat_embedding, mask=None):
        return self.base_network(inputs_embeds=concat_embedding, attention_mask=mask)

    def get_embedding(self, texts):
        return self.base_network.transformer.wte(texts)


class ImageCaptioner(pl.LightningModule):
    def __init__(self, clip_model, gpt_model, tokenizer, total_steps, max_length=20):
        super(ImageCaptioner, self).__init__()

        self.padding_token_id = tokenizer.pad_token_id
        # self.stop_token_id = tokenizer.encode('.')[0]

        # Define networks
        self.clip = ImageEncoder(clip_model)
        self.gpt = TextDecoder(gpt_model)
        self.mapping_network = Mapping(self.clip.embedding_size, self.gpt.embedding_size, max_length)

        # Define variables
        self.total_steps = total_steps
        self.max_length = max_length
        self.clip_embedding_size = self.clip.embedding_size
        self.gpt_embedding_size = self.gpt.embedding_size
        self.gpt_vocab_size = self.gpt.vocab_size

    def forward(self, images, texts, masks):
        texts_embedding = self.gpt.get_embedding(texts)
        images_embedding = self.clip(images)

        images_projection = self.mapping_network(images_embedding).view(-1, self.max_length, self.gpt_embedding_size)
        embedding_concat = torch.cat((images_projection, texts_embedding), dim=1)

        out = self.gpt(embedding_concat, masks)

        return out


# @st.cache_resource
# def download_trained_model():
#     wandb.init(anonymous="must")
#     api = wandb.Api()
#     artifact = api.artifact('hungchiehwu/CLIP-L14_GPT/model-ql03493w:v3')
#     artifact_dir = artifact.download()
#     wandb.finish()

#     return artifact_dir


@st.cache_resource
def load_clip_model():
    clip_model, image_transform = clip.load("ViT-L/14", device="cpu")

    return clip_model, image_transform


@st.cache_resource
def load_gpt_model():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    return gpt_model, tokenizer


@st.cache_resource
def load_model():
    # Path to the fine-tuned checkpoint downloaded from wandb
    artifact_dir = "./artifacts/model-ql03493w:v3"
    PATH = f"{os.getcwd()}/{artifact_dir[2:]}/model.ckpt"

    # Load pretrained GPT-2 and CLIP models from OpenAI
    clip_model, image_transform = load_clip_model()
    gpt_model, tokenizer = load_gpt_model()

    # Load fine-tuned weights into the captioning model
    print(PATH)
    model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
    checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["state_dict"])

    return model, image_transform, tokenizer
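

# A minimal inference sketch, illustrative only: greedy decoding from the image prefix
# produced by the mapping network. `generate_caption`, its arguments, and the 20-token
# budget are assumptions for this example, not part of the trained pipeline above.
@torch.no_grad()
def generate_caption(model, image_transform, tokenizer, pil_image, max_tokens=20):
    model.eval()
    image = image_transform(pil_image).unsqueeze(0)       # (1, 3, H, W)
    prefix = model.clip(image)                            # (1, clip_embedding_size)
    embeddings = model.mapping_network(prefix)            # (1, max_length, gpt_embedding_size)

    generated = []
    for _ in range(max_tokens):
        out = model.gpt(embeddings)
        next_token = out.logits[:, -1, :].argmax(dim=-1)  # greedy choice of next token
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated.append(next_token.item())
        next_embedding = model.gpt.get_embedding(next_token).unsqueeze(1)
        embeddings = torch.cat((embeddings, next_embedding), dim=1)

    return tokenizer.decode(generated)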