# logo-generator / server.py
# MatthiasC's picture
# Changes, 1 worker on server and remove clip usage from server
# history blame
# 3.16 kB
import os
import sys
import base64
from io import BytesIO
import torch
from torch import nn
from fastapi import FastAPI
import numpy as np
from PIL import Image
#import clip
from dalle.models import Dalle
import logging
from dalle.utils.utils import clip_score, download
# Everything below runs at import time: importing this module prints progress
# to stdout, creates the FastAPI application and instantiates the model.
print("Loading models...")
# FastAPI application instance.
# NOTE(review): no route registration is visible in this part of the file —
# confirm how `read_root` / `generate` are exposed.
app = FastAPI()
# Disabled manual download of the 1.3B archive into ~/.cache/minDALLE.
# `Dalle.from_pretrained` below is commented as downloading the pretrained
# model automatically, so this path is kept for reference only.
# url = "https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz"
# root = os.path.expanduser("~/.cache/minDALLE")
# filename = os.path.basename(url)
# pathname = filename[: -len(".tar.gz")]
# download_target = os.path.join(root, filename)
# result_path = os.path.join(root, pathname)
# if not os.path.exists(result_path):
# result_path = download(url, root)
# Use the GPU when CUDA is available, otherwise fall back to the CPU. The
# resulting string is passed to `model.sampling` as its `device` argument.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Dalle.from_pretrained("minDALL-E/1.3B") # This will automatically download the pretrained model.
# -----------------------------------------------------------
# Overlay the weights stored in ``last.ckpt`` onto stage 1 of the pretrained
# model. The checkpoint is read onto the CPU and is expected to hold its
# tensors under the ``'state_dict'`` key.
#
# Previously the loop only rebound ``param`` and never wrote it anywhere, so
# the checkpoint was loaded and then silently discarded. The copy /
# ``load_state_dict`` steps below mirror the (disabled) stage-2 recipe that
# follows.
state_dict_ = torch.load('last.ckpt', map_location='cpu')
# Built once; it is both the lookup table for valid names and the destination
# of the in-place copies.
vqgan_stage_dict = model.stage1.state_dict()
for name, param in state_dict_['state_dict'].items():
    if name not in vqgan_stage_dict:
        # The checkpoint entry has no counterpart in stage 1: skip it.
        continue
    if isinstance(param, nn.parameter.Parameter):
        # Unwrap Parameters so that a plain tensor is copied.
        param = param.data
    vqgan_stage_dict[name].copy_(param)
model.stage1.load_state_dict(vqgan_stage_dict)
# NOTE(review): nothing visible here moves `model` to `device`; confirm that
# `model.sampling(..., device=device)` handles placement by itself.
# ---------------------------------------------------------
# Disabled: the same overlay for stage 2, reading 'dalle_last.ckpt' and
# stripping the first six characters of every key.
# state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
# dalle_stage_dict = model.stage2.state_dict()
# for name, param in state_dict_dalle['state_dict'].items():
# if name[6:] not in model.stage2.state_dict().keys():
# print(name)
# continue
# if isinstance(param, nn.parameter.Parameter):
# param = param.data
# dalle_stage_dict[name[6:]].copy_(param)
# model.stage2.load_state_dict(dalle_stage_dict)
# Disabled: CLIP model used for re-ranking the generated images.
# model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
# model_clip.to(device=device)
print("Models loaded !")
def read_root():
    """Return the service banner.

    Returns:
        A one-element ``set`` containing the string ``"minDALL-E!"``.
    """
    # NOTE(review): the value is a set literal, not a dict, and no route
    # decorator is visible on this function — confirm both are intended.
    banner = {"minDALL-E!"}
    return banner
def generate(prompt):
    """Sample images for *prompt* and return them base64-encoded.

    Args:
        prompt: Forwarded unchanged to ``sample``.

    Returns:
        A dict with the single key ``"images"`` whose value is a list holding
        the ``to_base64`` encoding of every image returned by ``sample``, in
        the same order.
    """
    encoded = []
    for picture in sample(prompt):
        encoded.append(to_base64(picture))
    return {"images": encoded}
def sample(prompt):
    """Generate candidate images for *prompt* and return them as PIL images.

    Asks the module-level ``model`` for nine candidates (``top_k=96``,
    ``top_p=None``, ``softmax_temperature=1.0``) on the module-level
    ``device``, permutes the axes of the result with ``(0, 2, 3, 1)``,
    multiplies by 255, casts to ``uint8`` and wraps every entry in a PIL
    image.

    Args:
        prompt: Forwarded to ``model.sampling`` as its ``prompt`` argument.

    Returns:
        A list with one ``PIL.Image.Image`` per generated candidate. Earlier
        the images were built but never collected, so the list was always
        empty.
    """
    # Sampling
    logging.info("starting sampling")
    images = model.sampling(
        prompt=prompt,
        top_k=96,
        top_p=None,
        softmax_temperature=1.0,
        num_candidates=9,
        device=device,
    )
    logging.info("sampling succeeded")
    # The steps below are NumPy operations (`.astype` does not exist on torch
    # tensors, and NumPy cannot read GPU memory), so bring a tensor result
    # back to the host first. An ndarray result passes through untouched.
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
    # presumably (batch, channels, height, width) -> (batch, height, width,
    # channels), the layout PIL expects — TODO confirm against Dalle.sampling
    images = np.transpose(images, (0, 2, 3, 1))
    # CLIP Re-ranking
    # rank = clip_score(
    # prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device
    # )
    # images = images[rank]
    # assumes pixel values lie in [0, 1] so that * 255 fits in uint8 — TODO confirm
    return [Image.fromarray((image * 255).astype(np.uint8)) for image in images]
def to_base64(pil_image):
    """Encode an image as base64 JPEG data.

    Args:
        pil_image: Object exposing ``save(fp, format=...)``; it is asked to
            write itself in ``"JPEG"`` format into an in-memory buffer.

    Returns:
        ``bytes`` holding the base64 encoding of everything written to the
        buffer.
    """
    with BytesIO() as stream:
        pil_image.save(stream, format="JPEG")
        jpeg_payload = stream.getvalue()
    return base64.b64encode(jpeg_payload)