# File size: 4,734 Bytes
# 3d5e231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import random
import urllib
import hashlib
import tarfile
import torch
import clip
import numpy as np
from PIL import Image
from torch.nn import functional as F
from tqdm import tqdm
import torchvision.utils as vutils
import matplotlib.pyplot as plt


def set_seed(seed: int):
    """Seed every random number generator used by this module.

    The same value is handed, in order, to Python's ``random`` module,
    NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


@torch.no_grad()
def clip_score(prompt: str,
               images: np.ndarray,
               model_clip: torch.nn.Module,
               preprocess_clip,
               device: str) -> np.ndarray:
    """Rank images by the cosine similarity of their CLIP features to a prompt.

    Args:
        prompt: Text that every image is compared against.
        images: Sequence of image arrays. Each one is multiplied by 255 and
            cast to ``uint8`` before being wrapped in a PIL image
            (presumably float values in [0, 1] -- confirm against caller).
        model_clip: Model exposing ``encode_image`` and ``encode_text``.
        preprocess_clip: Callable that maps a PIL image to a tensor.
        device: Device the image batch and the tokenized prompt are moved to.

    Returns:
        Array of image indices ordered from the highest to the lowest
        similarity score. It is always 1-D, including for a single image.
    """
    preprocessed = [
        preprocess_clip(Image.fromarray((image * 255).astype(np.uint8)))
        for image in images
    ]
    image_batch = torch.stack(preprocessed, dim=0).to(device=device)

    # Tokenize the prompt once, then repeat it so there is one row per image.
    texts = clip.tokenize(prompt).to(device=device)
    texts = torch.repeat_interleave(texts, image_batch.shape[0], dim=0)

    image_features = model_clip.encode_image(image_batch)
    text_features = model_clip.encode_text(texts)

    # cosine_similarity reduces over dim=1 by default, so for (N, D) features
    # the result is already (N,). It must not be squeezed: with a single image
    # that would produce a 0-d tensor and a 0-d rank, which drops the batch
    # axis when used as an index.
    scores = F.cosine_similarity(image_features, text_features)
    rank = torch.argsort(scores, descending=True).cpu().numpy()
    return rank


def download(url: str, root: str) -> str:
    """Download a ``.tar.gz`` archive, verify its MD5 and extract it into ``root``.

    The expected MD5 digest is read from the second-to-last ``/``-separated
    component of ``url``. The archive is stored as ``root/<basename of url>``
    and the returned path is that name with the trailing ``.tar.gz`` removed.
    If the archive file already exists and the result path exists and is not
    a regular file, nothing is downloaded.

    Args:
        url: Location of the archive; its basename is expected to end with
            ``.tar.gz``.
        root: Directory (created when missing) that receives the archive and
            its extracted contents.

    Returns:
        ``os.path.join(root, <archive name without .tar.gz>)``.

    Raises:
        RuntimeError: If the MD5 of the downloaded file does not match, or if
            an archive member would be extracted outside ``root``.
    """
    # ``import urllib`` alone does not guarantee that the ``request``
    # submodule is loaded, so import it explicitly where it is used.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    pathname = filename[:-len('.tar.gz')]

    expected_md5 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    result_path = os.path.join(root, pathname)

    if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
        return result_path

    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        # Servers may omit Content-Length; tqdm accepts total=None in that case.
        content_length = source.info().get('Content-Length')
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Hash in chunks through a context manager so the handle is closed and
    # the whole archive is never held in memory at once.
    digest = hashlib.md5()
    with open(download_target, 'rb') as downloaded:
        for chunk in iter(lambda: downloaded.read(1024 * 1024), b''):
            digest.update(chunk)
    if digest.hexdigest() != expected_md5:
        raise RuntimeError('Model has been downloaded but the md5 checksum does not match')

    root_real = os.path.realpath(root)
    root_prefix = os.path.join(root_real, '')
    with tarfile.open(download_target, 'r:gz') as f:
        members = f.getmembers()

        # Refuse archives whose members would land outside ``root``
        # (e.g. names containing ``..`` or absolute paths).
        for member in members:
            target = os.path.realpath(os.path.join(root, member.name))
            if target != root_real and not target.startswith(root_prefix):
                raise RuntimeError(f'Refusing to extract {member.name!r} outside of {root!r}')

        pbar = tqdm(members, total=len(members))
        for member in pbar:
            pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
            f.extract(member=member, path=root)

    return result_path


def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
    """Resolve a location, downloading it first when it is an HTTP(S) URL.

    Args:
        url_or_path: Either a local path or an ``http``/``https`` URL.
        root: Directory handed to ``download`` when ``url_or_path`` is a URL;
            not used for local paths.

    Returns:
        The value returned by ``download`` for HTTP(S) URLs, otherwise
        ``url_or_path`` unchanged.
    """
    # ``import urllib`` alone does not guarantee that the ``parse`` submodule
    # is loaded, so import the function explicitly.
    from urllib.parse import urlparse

    if urlparse(url_or_path).scheme in ('http', 'https'):
        return download(url_or_path, root)
    return url_or_path


def images_to_numpy(tensor):
    """Convert an image tensor with values in [-1, 1] to a ``uint8`` array.

    Args:
        tensor: 3-D torch tensor. Its axes are reordered with
            ``transpose(1, 2, 0)`` (presumably channels-first to
            channels-last -- confirm against caller). The tensor itself is
            left unmodified.

    Returns:
        NumPy ``uint8`` array with values clamped to [-1, 1] and then mapped
        linearly onto [0, 255].
    """
    generated = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    # For a CPU tensor the NumPy array shares memory with the tensor, so
    # clamping in place would silently modify the caller's data. np.clip
    # returns a new array instead.
    generated = np.clip(generated, -1, 1)
    generated = (generated + 1) / 2 * 255
    return generated.astype('uint8')


def save_image(ground_truth, images, out_dir, batch_idx):
    """Write each entry of ``images`` into ``out_dir`` as PNG via ``plt.imsave``.

    An entry whose shape has three dimensions is saved as
    ``test_<batch_idx>_<i>.png``. Any other entry is treated as a stack along
    its first axis and every slice ``j`` is saved as
    ``test_<batch_idx>_<i>_<j>.png``.

    Args:
        ground_truth: Currently unused; kept so existing callers keep working.
        images: Iterable of array-like objects exposing ``shape``.
        out_dir: Directory the files are written to.
        batch_idx: Value embedded in every file name.
    """
    # NOTE: a disabled earlier variant built an image grid with
    # vutils.make_grid and saved it stacked with the ground truth; only the
    # per-image export below is active.
    for image_idx, image in enumerate(images):
        if len(image.shape) != 3:
            # Not a single 3-D image: save every slice along the first axis.
            for slice_idx in range(image.shape[0]):
                file_name = 'test_%s_%s_%s.png' % (batch_idx, image_idx, slice_idx)
                plt.imsave(os.path.join(out_dir, file_name), image[slice_idx])
            continue

        file_name = 'test_%s_%s.png' % (batch_idx, image_idx)
        plt.imsave(os.path.join(out_dir, file_name), image)