File size: 3,651 Bytes
8ce5e2d
 
 
 
 
 
 
 
 
 
 
 
a94ffe4
8ce5e2d
b46e9dc
e65b549
8ce5e2d
 
 
 
 
e65b549
 
 
 
 
 
 
 
8ce5e2d
 
 
 
 
 
 
 
 
 
 
 
 
e65b549
8ce5e2d
e65b549
 
8ce5e2d
e65b549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cb75cc
e65b549
8ce5e2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b46e9dc
8ce5e2d
 
 
 
 
b46e9dc
8ce5e2d
 
 
b46e9dc
 
 
 
8ce5e2d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import sys
import base64
from io import BytesIO
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from torch import nn
from fastapi import FastAPI
import numpy as np
from PIL import Image

#import clip
from dalle.models import Dalle
import logging
import streamlit as st
from dalle.utils.utils import clip_score, download

# --- Module-level startup ------------------------------------------------
# Everything below runs once, at import time, before the API serves any
# request: create the FastAPI app, download the fine-tuned checkpoint, and
# load it into a minDALL-E model on the best available device. Importing this
# module therefore performs network downloads and model loading.
print("Loading models...")
app = FastAPI()

from huggingface_hub import hf_hub_download

# NOTE(review): these logging.info calls only appear if logging is configured
# at INFO level somewhere; nothing visible in this module configures it.
logging.info("Start downloading")
# Download the fine-tuned weights from the Hugging Face Hub; the returned value
# is used as a local file path by torch.load below. The access token is read
# from Streamlit secrets (key "model_hub"), so a Streamlit secrets
# configuration must be present for this module to import.
full_dict_path = hf_hub_download(repo_id="MatthiasC/dall-e-logo", filename="full_dict_new.ckpt",
                                 use_auth_token=st.secrets["model_hub"])
logging.info("End downloading")
logging.info(full_dict_path)


# Disabled alternative: manually download and unpack the upstream minDALL-E
# 1.3B archive into ~/.cache/minDALLE. Kept for reference only.
# url = "https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz"
# root = os.path.expanduser("~/.cache/minDALLE")
# filename = os.path.basename(url)
# pathname = filename[: -len(".tar.gz")]
# download_target = os.path.join(root, filename)
# result_path = os.path.join(root, pathname)
# if not os.path.exists(result_path):
#     result_path = download(url, root)


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Base minDALL-E 1.3B model; its weights are overwritten by the fine-tuned
# state dict loaded further down.
model = Dalle.from_pretrained("minDALL-E/1.3B")  # This will automatically download the pretrained model.
#model.to(device=device)


# OLD CODE
# (Disabled.) Loaded stage1 and stage2 weights from two separate checkpoints
# ('last.ckpt' and 'dalle_last.ckpt'), copying only keys the model already
# has and stripping a 6-character prefix from the stage2 key names.
# -----------------------------------------------------------
# state_dict_ = torch.load('last.ckpt', map_location='cpu')
# vqgan_stage_dict = model.stage1.state_dict()
#
# for name, param in state_dict_['state_dict'].items():
#     if name not in model.stage1.state_dict().keys():
#         continue
#     if isinstance(param, nn.parameter.Parameter):
#         param = param.data
#     vqgan_stage_dict[name].copy_(param)
#
# model.stage1.load_state_dict(vqgan_stage_dict)
# #---------------------------------------------------------
# state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
# dalle_stage_dict = model.stage2.state_dict()
#
# for name, param in state_dict_dalle['state_dict'].items():
#     if name[6:] not in model.stage2.state_dict().keys():
#         print(name)
#         continue
#     if isinstance(param, nn.parameter.Parameter):
#         param = param.data
#     dalle_stage_dict[name[6:]].copy_(param)
#
# model.stage2.load_state_dict(dalle_stage_dict)

# NEW METHOD
# Load the combined fine-tuned state dict in one call. map_location='cpu'
# lets a checkpoint saved on GPU load on a CPU-only host; the model is then
# moved to `device`.
# NOTE(review): torch.load unpickles the file, so only point it at trusted
# checkpoints.
model.load_state_dict(torch.load(full_dict_path, map_location=torch.device('cpu')))
model.to(device=device)

# CLIP re-ranking model (disabled; the matching re-ranking step in sample()
# is also commented out).
# model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
# model_clip.to(device=device)

print("Models loaded !")


@app.get("/")
def read_root():
    """Root endpoint: return a fixed greeting (e.g. usable as a liveness check).

    Returns:
        A one-element list, serialised by FastAPI as the JSON array
        ``["minDALL-E!"]``.
    """
    # A list rather than a set literal: sets are not JSON-native. FastAPI's
    # encoder turned the former set into this same list, so the response body
    # is unchanged.
    return ["minDALL-E!"]


@app.get("/generate")
def generate(prompt: str):
    """Generate candidate images for *prompt* and return them base64-encoded.

    Exposed as ``GET /generate?prompt=...``; ``prompt`` is a required query
    parameter.

    Returns:
        ``{"images": [...]}`` where each entry is the ASCII base64 string of
        one image as encoded by ``to_base64``.
    """
    # Literal path, not "/{generate}": a path parameter that the handler never
    # reads makes the route capture every single-segment URL.
    images = sample(prompt)
    # base64 output is pure ASCII; decode it explicitly so the payload holds
    # strings instead of relying on the framework to decode bytes.
    encoded = [to_base64(image).decode("ascii") for image in images]
    return {"images": encoded}


def sample(prompt, *, top_k=96, top_p=None, softmax_temperature=1.0, num_candidates=9):
    """Sample candidate images for *prompt* from the module-level ``model``.

    All sampling knobs are forwarded unchanged to ``model.sampling``. Their
    defaults are the values this service has always used, so ``sample(prompt)``
    behaves as before.

    Args:
        prompt: Text prompt to condition generation on.
        top_k: Top-k cutoff forwarded to ``model.sampling``.
        top_p: Top-p cutoff forwarded to ``model.sampling`` (``None`` by default).
        softmax_temperature: Temperature forwarded to ``model.sampling``.
        num_candidates: Candidate count forwarded to ``model.sampling``;
            presumably the number of images returned (TODO confirm).

    Returns:
        A list of ``PIL.Image.Image``, one per entry along the first axis of
        the sampling output.
    """
    logging.info("starting sampling")
    pixels = (
        model.sampling(
            prompt=prompt,
            top_k=top_k,
            top_p=top_p,
            softmax_temperature=softmax_temperature,
            num_candidates=num_candidates,
            device=device,
        )
        .cpu()
        .numpy()
    )
    logging.info("sampling succeeded")
    # Move axis 1 to the end. Presumably (N, C, H, W) -> (N, H, W, C), the
    # channels-last layout Image.fromarray expects — TODO confirm.
    pixels = np.transpose(pixels, (0, 2, 3, 1))

    # CLIP re-ranking (disabled; needs model_clip / preprocess_clip):
    # rank = clip_score(
    #     prompt=prompt, images=pixels, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device
    # )
    # pixels = pixels[rank]

    # Clamp before the uint8 cast: converting floats outside [0, 1] * 255 to
    # uint8 wraps or is platform-dependent, which shows up as speckled pixels.
    # In-range values are unaffected.
    pixels = np.clip(pixels, 0.0, 1.0)
    return [Image.fromarray((frame * 255).astype(np.uint8)) for frame in pixels]


def to_base64(pil_image):
    """Serialise *pil_image* as JPEG and return those bytes base64-encoded.

    The result is ``bytes`` (ASCII base64 alphabet), not ``str``.
    """
    with BytesIO() as jpeg_stream:
        pil_image.save(jpeg_stream, format="JPEG")
        jpeg_bytes = jpeg_stream.getvalue()
    return base64.b64encode(jpeg_bytes)