# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
import os
import random
import urllib.request
import urllib.parse
import hashlib
import tarfile
import torch
import clip
import numpy as np
from PIL import Image
from torch.nn import functional as F
from tqdm import tqdm
import torchvision.utils as vutils
import matplotlib.pyplot as plt
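
# Helper utilities: RNG seeding, CLIP-based re-ranking of generated images,
# checkpoint download and extraction, and saving generated images to disk.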

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
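
# Re-rank generated candidates with CLIP: `images` is expected to hold H x W x C
# float images in [0, 1] (they are scaled by 255 and cast to uint8 before CLIP
# preprocessing). Returns candidate indices sorted by descending image-text
# cosine similarity, so e.g. `images[rank]` puts the best match first.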
@torch.no_grad()
def clip_score(prompt: str,
               images: np.ndarray,
               model_clip: torch.nn.Module,
               preprocess_clip,
               device: str) -> np.ndarray:
    images = [preprocess_clip(Image.fromarray((image * 255).astype(np.uint8))) for image in images]
    images = torch.stack(images, dim=0).to(device=device)
    texts = clip.tokenize(prompt).to(device=device)
    texts = torch.repeat_interleave(texts, images.shape[0], dim=0)

    image_features = model_clip.encode_image(images)
    text_features = model_clip.encode_text(texts)

    scores = F.cosine_similarity(image_features, text_features).squeeze()
    rank = torch.argsort(scores, descending=True).cpu().numpy()
    return rank
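
# Download and extract a model tarball. The URL is expected to look like
# .../<md5-checksum>/<name>.tar.gz: the second-to-last path component is used to
# verify the downloaded archive before it is unpacked into `root`. Returns the
# path of the extracted `<name>` directory.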
def download(url: str, root: str) -> str:
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    pathname = filename[:-len('.tar.gz')]

    expected_md5 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    result_path = os.path.join(root, pathname)

    if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
        return result_path

    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5:
        raise RuntimeError('Model has been downloaded but the md5 checksum does not match')

    with tarfile.open(download_target, 'r:gz') as f:
        pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
        for member in pbar:
            pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
            f.extract(member=member, path=root)

    return result_path
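
# Resolve a checkpoint location: download and extract if given an http(s) URL
# (skipped when the extracted directory already exists), otherwise treat the
# argument as a local path and return it unchanged.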
def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
    if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
        return download(url_or_path, root)
    return url_or_path
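
# Convert a C x H x W tensor with values in [-1, 1] into an H x W x C uint8 array.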
def images_to_numpy(tensor):
    generated = tensor.data.cpu().numpy().transpose(1, 2, 0)
    generated[generated < -1] = -1
    generated[generated > 1] = 1
    generated = (generated + 1) / 2 * 255
    return generated.astype('uint8')
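
# Save generated images to `out_dir`, one PNG per sample; 4-D entries (a batch of
# images) are written out frame by frame. `ground_truth` is currently unused; the
# commented-out block below is the old grid-based saving path.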
def save_image(ground_truth, images, out_dir, batch_idx):
    for i, im in enumerate(images):
        if len(im.shape) == 3:
            plt.imsave(os.path.join(out_dir, 'test_%s_%s.png' % (batch_idx, i)), im)
        else:
            bs = im.shape[0]
            # plt.imsave()
            for j in range(bs):
                plt.imsave(os.path.join(out_dir, 'test_%s_%s_%s.png' % (batch_idx, i, j)), im[j])

    # print("Ground truth Images shape: ", ground_truth.shape, len(images))
    # images = vutils.make_grid(images, nrow=ground_truth.shape[0])
    # images = images_to_numpy(images)
    #
    # if ground_truth is not None:
    #     ground_truth = vutils.make_grid(ground_truth, 5)
    #     ground_truth = images_to_numpy(ground_truth)
    #     print("Ground Truth shape, Generated Images shape: ", ground_truth.shape, images.shape)
    #     images = np.concatenate([ground_truth, images], axis=0)
    #
    # output = Image.fromarray(images)
    # output.save('%s/fake_samples_epoch_%03d.png' % (out_dir, batch_idx))
    # if texts is not None:
    #     fid = open('%s/fake_samples_epoch_%03d_%03d.txt' % (image_dir, epoch, idx), 'w')
    #     for idx in range(images.shape[0]):
    #         fid.write(str(idx) + '--------------------------------------------------------\n')
    #         for i in range(len(texts)):
    #             fid.write(texts[i][idx] + '\n')
    #         fid.write('\n\n')
    #     fid.close()
    return
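
# Illustrative usage sketch (the generator call and `out_dir` are placeholders,
# not part of this module):
#
#     set_seed(0)
#     model_clip, preprocess_clip = clip.load('ViT-B/32', device=device)
#     images = generate_images(prompt)   # hypothetical: N x H x W x C floats in [0, 1]
#     rank = clip_score(prompt, images, model_clip, preprocess_clip, device)
#     save_image(None, [images[rank[0]]], out_dir, 0)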