|
import os |
|
import sys |
|
import base64 |
|
from io import BytesIO |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
import torch |
|
from torch import nn |
|
from fastapi import FastAPI |
|
import numpy as np |
|
from PIL import Image |
|
|
|
|
|
from dalle.models import Dalle |
|
import logging |
|
from dalle.utils.utils import clip_score, download |
|
|
|
print("Loading models...")

app = FastAPI()

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Dalle.from_pretrained("minDALL-E/1.3B")
model.to(device=device)


def _overlay_checkpoint(stage, checkpoint_path, prefix_len=0, report_skipped=False):
    """Copy matching tensors from a checkpoint file into ``stage``.

    The checkpoint is read inside this function so that it is released as
    soon as the copy is done instead of staying referenced from a module
    global for the whole lifetime of the server.

    Args:
        stage: Module whose ``state_dict()`` entries are overwritten in place.
        checkpoint_path: File handed to ``torch.load``; the loaded object must
            expose its tensors under the ``'state_dict'`` key.
        prefix_len: Number of leading characters dropped from every
            checkpoint key before it is looked up in ``stage``.
        report_skipped: When true, print each checkpoint key that has no
            counterpart in ``stage``.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Built once; rebuilding the state dict for every checkpoint entry only
    # to test key membership is wasted work.
    target = stage.state_dict()

    for name, param in checkpoint['state_dict'].items():
        key = name[prefix_len:]
        if key not in target:
            if report_skipped:
                print(name)
            continue
        if isinstance(param, nn.parameter.Parameter):
            param = param.data
        target[key].copy_(param)

    stage.load_state_dict(target)


# Overlay the locally trained weights on top of the pretrained ones.
_overlay_checkpoint(model.stage1, 'last.ckpt')
# Keys in this checkpoint are matched after dropping their first six
# characters (presumably a wrapper prefix -- TODO confirm which one).
_overlay_checkpoint(model.stage2, 'dalle_last.ckpt', prefix_len=6, report_skipped=True)

print("Models loaded !")
|
|
|
|
|
@app.get("/")
def read_root():
    """Handle GET ``/`` by returning a one-element set with the service name."""
    greeting = {"minDALL-E!"}
    return greeting
|
|
|
|
|
@app.get("/generate")
def generate(prompt: str):
    """Generate images for ``prompt`` and return them base64-encoded.

    The route is the literal path ``/generate``. Writing it with braces would
    turn the segment into a path parameter that this handler does not accept,
    making the route match every single-segment URL.

    Args:
        prompt: Text handed to ``sample``. It is not part of the path, so the
            framework takes it from the query string.

    Returns:
        A dict whose ``"images"`` entry lists one base64 text string per
        image returned by ``sample``.
    """
    images = sample(prompt)
    # to_base64 yields ASCII bytes; decode them here so the payload is
    # plain text without relying on the framework's bytes handling.
    encoded = [to_base64(image).decode("ascii") for image in images]
    return {"images": encoded}
|
|
|
|
|
def sample(prompt):
    """Sample nine candidate images for ``prompt`` and return them as PIL images.

    Calls ``model.sampling`` with ``top_k=96``, no ``top_p`` and a softmax
    temperature of 1.0, then converts the result to a list of ``PIL.Image``
    objects.
    """
    logging.info("starting sampling")
    batch = model.sampling(
        prompt=prompt,
        top_k=96,
        top_p=None,
        softmax_temperature=1.0,
        num_candidates=9,
        device=device,
    )
    batch = batch.cpu().numpy()
    logging.info("sampling succeeded")

    # Move axis 1 to the end (assumes a channels-first batch, i.e.
    # (N, C, H, W) -> (N, H, W, C) -- TODO confirm against model.sampling).
    channels_last = np.transpose(batch, (0, 2, 3, 1))

    # Scale by 255 and cast to uint8 before handing each frame to PIL
    # (assumes values lie in [0, 1] -- TODO confirm).
    return [
        Image.fromarray((frame * 255).astype(np.uint8))
        for frame in channels_last
    ]
|
|
|
|
|
def to_base64(pil_image):
    """Encode ``pil_image`` as JPEG and return the base64 form as bytes.

    Args:
        pil_image: Object exposing ``save(fp, format=...)``; it is asked to
            write itself in ``"JPEG"`` format into an in-memory buffer.

    Returns:
        The base64 encoding (``bytes``) of everything written to the buffer.
    """
    with BytesIO() as jpeg_stream:
        pil_image.save(jpeg_stream, format="JPEG")
        raw_jpeg = jpeg_stream.getvalue()
    return base64.b64encode(raw_jpeg)