Spaces:
Runtime error
Runtime error
File size: 4,984 Bytes
7d1df38 f307fe5 7d1df38 3150e77 a8416ee 3150e77 214bd84 09da12b 214bd84 4dab50d 7d1df38 4dab50d 79c7b01 f307fe5 a8416ee f307fe5 79c7b01 4dab50d 7d1df38 4dab50d 79c7b01 ab30850 79c7b01 09da12b 1df8b5f 79c7b01 7d1df38 b80df5c 16484d3 b80df5c 250dc27 b80df5c c838395 ed768de 99ca89b d92334f b80df5c 7d1df38 4dab50d 7d1df38 d92334f 7d1df38 0674c7e 7d1df38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from huggingface_hub import hf_hub_url, cached_download
from PIL import Image
import os
import json
import glob
import random
from typing import Any, Dict, List
import torch
import torchvision
import wordsegment as ws
from virtex.config import Config
from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
from virtex.utils.checkpointing import CheckpointManager
# Local file names/patterns used by this module. The first three are also the
# names requested from the Hugging Face Hub repository in download_files().

# Model configuration file handed to virtex's Config in VirTexModel.__init__.
CONFIG_PATH = "config.yaml"
# File name of the model checkpoint fetched by download_files().
MODEL_PATH = "checkpoint_last5.pth"
# JSON file loaded into VirTexModel.valid_subs and used for membership checks
# on generated subreddit names.
VALID_SUBREDDITS_PATH = "subreddit_list.json"
# Glob pattern expanded by get_samples() to list the bundled sample images.
SAMPLES_PATH = "./samples/*.jpg"
class ImageLoader():
    """Turn images into the ``{"image": tensor}`` dictionaries used by the model.

    Also provides small helpers for lower-casing prompt text and for
    down/up-scaling an image for display.
    """

    def __init__(self):
        # ToTensor runs first, so Resize / CenterCrop / Normalize all operate
        # on tensors rather than PIL images.
        # The mean/std triples are presumably the usual ImageNet channel
        # statistics -- TODO confirm against the training config.
        self.transformer = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))])
        # Target length, in pixels, of the longest side produced by show_resize().
        self.show_size = 500

    def load(self, im_path):
        """Open the image at ``im_path`` and return it as a model-ready batch.

        Args:
            im_path: Path (or file object) accepted by ``PIL.Image.open``.

        Returns:
            ``{"image": tensor}`` where the tensor has a leading batch
            dimension of size 1.
        """
        # The context manager releases the file handle once the pixel data
        # has been read by the transform pipeline.
        with Image.open(im_path) as img:
            return self.transform(img)

    def raw_load(self, im_path):
        """Open the image at ``im_path`` and wrap it in a tensor without transforms.

        Returns:
            ``{"image": tensor}`` with no resizing, normalisation or batch
            dimension applied.
        """
        # NOTE(review): the PIL image is passed directly to torch.FloatTensor;
        # it is unverified that this constructor accepts PIL images -- confirm
        # before relying on this method.
        with Image.open(im_path) as img:
            im = torch.FloatTensor(img)
        return {"image": im}

    def transform(self, image):
        """Apply the preprocessing pipeline to an already-loaded image.

        Args:
            image: A PIL image (any mode) or another input accepted by the
                transform pipeline.

        Returns:
            ``{"image": tensor}`` where the tensor has a leading batch
            dimension of size 1.
        """
        # Normalize is configured with three channel statistics, so grayscale,
        # palette or RGBA images must be converted to RGB first.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
        return {"image": im}

    def text_transform(self, text):
        """Normalise prompt text; at present this only lower-cases it."""
        return text.lower()

    def show_resize(self, image):
        """Return a PIL copy of ``image`` whose longest side is ``show_size`` pixels.

        The aspect ratio is preserved (up to integer truncation).
        """
        # The resize is done by hand on a tensor; the original author's note
        # says the installed (older) torch/torchvision release required this.
        image = torchvision.transforms.functional.to_tensor(image)
        # The last two tensor dimensions are (height, width).
        height, width = image.shape[-2:]
        ratio = self.show_size / max(height, width)
        # Clamp to 1 so extremely elongated images never get a 0-pixel side.
        new_size = [max(1, int(height * ratio)), max(1, int(width * ratio))]
        image = torchvision.transforms.functional.resize(image, new_size)
        return torchvision.transforms.functional.to_pil_image(image)
class VirTexModel():
    """CPU inference wrapper that generates a (subreddit, caption) pair for an image."""

    def __init__(self):
        """Build the tokenizer and model from ``CONFIG_PATH`` and load the checkpoint.

        Expects the files named by ``CONFIG_PATH``, ``MODEL_PATH`` and
        ``VALID_SUBREDDITS_PATH`` to exist in the working directory.
        """
        self.config = Config(CONFIG_PATH)
        # Load wordsegment's data once; predict() uses it to split subreddit
        # prompts into words.
        ws.load()
        self.device = 'cpu'
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        CheckpointManager(model=self.model).load(MODEL_PATH)
        self.model.eval()
        # Collection of subreddit names considered valid model output.
        with open(VALID_SUBREDDITS_PATH) as f:
            self.valid_subs = json.load(f)

    def predict(self, image_dict, sub_prompt=None, prompt="", max_attempts=10):
        """Generate a subreddit and caption for the image in ``image_dict``.

        Args:
            image_dict: Batch dictionary passed to the model. It is mutated:
                the decoding prompt is stored under ``"decode_prompt"``.
            sub_prompt: Optional subreddit name to condition on. It is cleaned
                and word-segmented, then followed by a ``[SEP]`` token.
            prompt: Optional caption prefix appended after the subreddit part.
                Empty or ``None`` means no caption prompt.
            max_attempts: Upper bound on model calls when no prompt is given
                and the generated subreddit is not in ``valid_subs``. At
                least one attempt is always made.

        Returns:
            ``(subreddit, rest_of_caption)``. ``subreddit`` is ``""`` when the
            decoded caption has no ``"://"`` separator. If every attempt
            yields an invalid subreddit, the last attempt is returned.
        """
        sep_id = self.tokenizer.token_to_id("[SEP]")

        # Decoding always starts from the start-of-sequence token, optionally
        # followed by "<subreddit words> [SEP]".
        prompt_ids = [self.model.sos_index]
        if sub_prompt is not None:
            segmented = " ".join(ws.segment(ws.clean(sub_prompt)))
            prompt_ids = prompt_ids + self.tokenizer.encode(segmented) + [sep_id]
        decode_prompt = torch.tensor(prompt_ids, device=self.device).long()

        # Truthiness (not identity) check: skips both "" and None.
        if prompt:
            cap_tokens = torch.tensor(
                self.tokenizer.encode(prompt), device=self.device
            ).long()
            decode_prompt = torch.cat([decode_prompt, cap_tokens])

        image_dict["decode_prompt"] = decode_prompt

        # When the caller supplied any prompt, the output is accepted as is;
        # otherwise the generated subreddit must be a known one.
        user_guided = sub_prompt is not None or bool(prompt)

        subreddit, rest_of_caption = "", ""
        for _ in range(max(1, max_attempts)):
            with torch.no_grad():
                caption = self.model(image_dict)["predictions"][0].tolist()

            # Swap the first [SEP] for the "://" token so that the decoded
            # string reads "<subreddit> :// <caption>".
            if sep_id in caption:
                caption[caption.index(sep_id)] = self.tokenizer.token_to_id("://")
            caption = self.tokenizer.decode(caption)

            if "://" in caption:
                # Split on the first separator only; the caption text itself
                # may legitimately contain "://".
                subreddit, _, rest_of_caption = caption.partition("://")
                subreddit = "".join(subreddit.split())
                rest_of_caption = rest_of_caption.strip()
            else:
                subreddit, rest_of_caption = "", caption

            if user_guided or subreddit in self.valid_subs:
                break

        return subreddit, rest_of_caption
def download_files():
    """Fetch the config, checkpoint and subreddit list into the working directory.

    Each file is downloaded from the ``zamborg/redcaps`` Hugging Face Hub
    repository (via the local download cache) and then copied to
    ``./<filename>``.

    Raises:
        OSError: If copying a downloaded file into the working directory fails.
    """
    # Local import keeps this function self-contained.
    import shutil

    # NOTE(review): newer huggingface_hub releases deprecate cached_download
    # in favour of hf_hub_download -- confirm against the pinned version.
    for filename in (CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH):
        cached_path = cached_download(hf_hub_url("zamborg/redcaps", filename=filename))
        # shutil.copy avoids spawning a shell: it is portable, handles paths
        # containing spaces, and raises on failure instead of silently
        # continuing.
        shutil.copy(cached_path, f"./{filename}")
def get_samples():
    """Return a list of the paths that match the ``SAMPLES_PATH`` glob pattern."""
    matches = list(glob.iglob(SAMPLES_PATH))
    return matches
def get_rand_img(samples):
    """Pick one entry of ``samples`` uniformly at random.

    Returns:
        ``(index, item)`` where ``item`` is ``samples[index]``.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    upper = len(samples) - 1
    choice = random.randint(0, upper)
    return choice, samples[choice]
|