|
import os |
|
import sys |
|
import base64 |
|
from io import BytesIO |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
import torch |
|
from torch import nn |
|
from fastapi import FastAPI |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from dalle.models import Dalle |
|
import logging |
|
import streamlit as st |
|
|
|
|
|
# --- Module-level setup: runs once at import time ---------------------------
print("Loading models...")

# FastAPI application object; route handlers below register against it.
app = FastAPI()

# NOTE(review): mid-file import — consider moving to the top-of-file import
# block with the other third-party imports.
from huggingface_hub import hf_hub_download

logging.info("Start downloading")

# Download the fine-tuned checkpoint from the Hugging Face Hub.
# The access token is read from Streamlit secrets under "model_download".
full_dict_path = hf_hub_download(repo_id="ml6team/logo-generator", filename="full_dict_new.ckpt",

use_auth_token=st.secrets["model_download"])

logging.info("End downloading")

# Prefer GPU when available; `device` is also used by sample() below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base minDALL-E weights, then overwrite with the fine-tuned logo checkpoint.
model = Dalle.from_pretrained("minDALL-E/1.3B")

# Load on CPU first (map_location) so loading works on GPU-less machines,
# then move the model to the selected device.
model.load_state_dict(torch.load(full_dict_path, map_location=torch.device('cpu')))

model.to(device=device)

print("Models loaded !")
|
|
|
|
|
@app.get("/")
def read_root():
    """Landing / health-check endpoint.

    Returns:
        dict: a simple JSON greeting.

    Fix: the original returned ``{"minDALL-E!"}`` — a *set* literal, almost
    certainly a typo for a dict. Sets are not JSON-native (FastAPI coerces
    them to a list), so return a proper JSON object instead.
    """
    return {"message": "minDALL-E!"}
|
|
|
|
|
@app.get("/generate")
def generate(prompt):
    """Generate candidate logo images for a text prompt.

    Args:
        prompt (str): text description, passed as a query parameter
            (``/generate?prompt=...``).

    Returns:
        dict: ``{"images": [...]}`` where each entry is a base64-encoded
        JPEG produced by :func:`to_base64`.

    Fix: the original route was ``@app.get("/{generate}")`` — it declared a
    path parameter named ``generate`` that the handler never received, so
    the route matched *any* single-segment path while ``prompt`` was always
    a required query parameter. ``/generate`` makes the intended path
    explicit; the usual call ``/generate?prompt=...`` keeps working.
    """
    images = sample(prompt)
    images = [to_base64(image) for image in images]
    return {"images": images}
|
|
|
|
|
def sample(prompt):
    """Sample candidate images from the module-level model for ``prompt``.

    Args:
        prompt (str): text description to condition generation on.

    Returns:
        list[PIL.Image.Image]: nine candidates (``num_candidates=9``).

    Uses the module-level ``model`` and ``device`` configured at import time.
    """
    logging.info("starting sampling")
    images = (
        model.sampling(prompt=prompt, top_k=96, top_p=None, softmax_temperature=1.0, num_candidates=9, device=device)
        .cpu()
        .numpy()
    )
    logging.info("sampling succeeded")

    # NCHW -> NHWC so each slice is a height x width x channels array for PIL.
    images = np.transpose(images, (0, 2, 3, 1))

    # Scale to 8-bit and wrap as PIL images. (Assumes the model emits floats
    # in [0, 1] — implied by the *255/uint8 cast; TODO confirm.)
    # Comprehension replaces the original index-based append loop.
    return [Image.fromarray((img * 255).astype(np.uint8)) for img in images]
|
|
|
|
|
def to_base64(pil_image):
    """Encode a PIL image as a base64 JPEG string.

    Args:
        pil_image (PIL.Image.Image): image to encode; must be
            JPEG-compatible (e.g. RGB mode).

    Returns:
        str: base64-encoded JPEG data, decoded to ASCII so the value is
        JSON-native. (The original returned raw ``bytes`` from
        ``b64encode`` and relied on the framework's implicit utf-8
        decoding; the serialized output is identical.)
    """
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("ascii")