File size: 1,758 Bytes
8ce5e2d
 
 
 
 
 
 
 
 
 
 
 
 
b46e9dc
e65b549
9f824d9
8ce5e2d
 
 
 
e65b549
 
 
72a955a
 
e65b549
8ce5e2d
 
9b4f999
e65b549
5cb75cc
e65b549
8ce5e2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b46e9dc
8ce5e2d
 
 
 
 
b46e9dc
8ce5e2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import sys
import base64
from io import BytesIO
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from torch import nn
from fastapi import FastAPI
import numpy as np
from PIL import Image

from dalle.models import Dalle
import logging
import streamlit as st


# ---------------------------------------------------------------------------
# Module-level setup: runs once at import time. Creates the FastAPI app,
# downloads the fine-tuned checkpoint and loads the model used by `sample`.
# ---------------------------------------------------------------------------
print("Loading models...")
app = FastAPI()

from huggingface_hub import hf_hub_download

# NOTE(review): no logging configuration is visible in this file; with the
# standard library's default root level (WARNING) these INFO records would not
# be emitted — confirm logging is configured elsewhere.
logging.info("Start downloading")
# Fetch the checkpoint file from the Hugging Face Hub. The access token is read
# from the Streamlit secrets store under the "model_download" key.
full_dict_path = hf_hub_download(repo_id="ml6team/logo-generator", filename="full_dict_new.ckpt",
                                 use_auth_token=st.secrets["model_download"])
logging.info("End downloading")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Dalle.from_pretrained("minDALL-E/1.3B")

# Replace the model's weights with the downloaded state dict. Tensors are
# mapped to the CPU while loading; the model is moved to `device` afterwards.
# NOTE(review): torch.load deserializes the downloaded file — confirm the
# checkpoint source is trusted.
model.load_state_dict(torch.load(full_dict_path, map_location=torch.device('cpu')))
model.to(device=device)

print("Models loaded !")


@app.get("/")
def read_root():
    """Handle ``GET /`` by returning a fixed banner value.

    NOTE(review): the returned object is a one-element *set* literal, not a
    dict or a plain string — confirm this response shape is intended.
    """
    banner = {"minDALL-E!"}
    return banner


@app.get("/{generate}")
def generate(prompt):
    """Generate images for ``prompt`` and return them base64-encoded.

    Runs ``sample(prompt)`` and passes every resulting image through
    ``to_base64``.

    NOTE(review): the route template contains the placeholder ``{generate}``,
    which is not a parameter of this function, and ``prompt`` does not appear
    in the path (so it is presumably read from the query string). Confirm
    whether the literal path ``/generate`` was intended.

    Returns:
        A dict with a single key ``"images"`` holding the list of encoded
        images, in the order ``sample`` produced them.
    """
    encoded = []
    for picture in sample(prompt):
        encoded.append(to_base64(picture))
    return {"images": encoded}


def sample(prompt):
    """Sample candidate images for ``prompt`` and return them as PIL images.

    Calls ``model.sampling`` on the module-level ``model`` and ``device`` with
    fixed settings (``top_k=96``, ``top_p=None``, ``softmax_temperature=1.0``,
    ``num_candidates=9``), then converts the result to a NumPy array on the
    CPU and wraps each entry along the first axis in a PIL image.

    Returns:
        A list of PIL images, one per entry along the first axis of the
        sampled array.
    """
    logging.info("starting sampling")
    sampled = model.sampling(
        prompt=prompt,
        top_k=96,
        top_p=None,
        softmax_temperature=1.0,
        num_candidates=9,
        device=device,
    )
    batch = sampled.cpu().numpy()
    logging.info("sampling succeeded")

    # Move axis 1 to the last position (presumably channels-first to
    # channels-last for PIL — TODO confirm against the model's output layout).
    batch = np.transpose(batch, (0, 2, 3, 1))

    # Scale by 255 and cast to uint8 before handing each frame to PIL
    # (assumes the sampled values lie in [0, 1] — TODO confirm).
    return [Image.fromarray((frame * 255).astype(np.uint8)) for frame in batch]


def to_base64(pil_image):
    """Encode ``pil_image`` as JPEG and return the base64 form of those bytes.

    The image is written to an in-memory buffer via
    ``pil_image.save(..., format="JPEG")``; nothing touches the filesystem.

    Returns:
        ``bytes`` containing the base64 encoding of the JPEG data.
    """
    with BytesIO() as stream:
        pil_image.save(stream, format="JPEG")
        jpeg_bytes = stream.getvalue()
    return base64.b64encode(jpeg_bytes)