File size: 3,836 Bytes
f3daba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cc889f
 
 
f3daba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import sys
import time
import json
import torch
import base64
from PIL import Image
from io import BytesIO

# Environment flags intended to speed up the serverless function:
# CUDA_MODULE_LOADING=LAZY and SAFETENSORS_FAST_GPU=1.
os.environ.update(
    {
        "CUDA_MODULE_LOADING": "LAZY",
        "SAFETENSORS_FAST_GPU": "1",
    }
)

# Append the "seg2art" directory that sits next to this file to sys.path
# before importing from the seg2art package (presumably so that seg2art's own
# top-level imports resolve — confirm against the package layout).
sys.path.append(os.path.join(os.path.dirname(__file__), "seg2art"))
from seg2art.sstan_models.pix2pix_model import Pix2PixModel
from seg2art.options.test_options import TestOptions
from seg2art.inference_util import get_artwork

import uvicorn
from fastapi import FastAPI, Form
from fastapi.templating import Jinja2Templates
from fastapi.responses import PlainTextResponse, HTMLResponse
from fastapi.requests import Request
from fastapi.staticfiles import StaticFiles


# declare constants
# Host and port handed to uvicorn.run() when this file is executed as a script.
HOST = "0.0.0.0"
PORT = 7860
# FastAPI
# NOTE(review): root_path is set to this file's directory on disk. FastAPI
# documents root_path as a URL path prefix for apps served behind a proxy —
# confirm that passing a filesystem path here is intended.
app = FastAPI(root_path=os.path.abspath(os.path.dirname(__file__)))
# "static" and "templates" are relative paths, so they resolve against the
# process working directory rather than this file's location.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


# initialize SEAN model.
# Options come from TestOptions().parse(); status is forced to "test" before
# the model is constructed from them.
opt = TestOptions().parse()
opt.status = "test"
model = Pix2PixModel(opt)
# .half() is applied only when CUDA is available; otherwise the model is
# left as constructed.
model = model.half() if torch.cuda.is_available() else model
model.eval()


from utils.umap_utils import get_code, load_boundries, modify_code

# Boundaries as returned by load_boundries(); handed to get_code() and
# modify_code() when a request is processed.
boundaries = load_boundries()
# Per-user code cache, keyed by the "user_id" field of each request.
# (A module-level `global` statement is a no-op, so none is needed here.)
current_codes = {}
max_user_num = 5

# Code given to a user the first time their user_id is seen.
initial_code_path = os.path.join(os.path.dirname(__file__), "static/init_code")
# map_location=None is torch.load's default (keep the saved device); on a
# machine without CUDA the stored tensors are remapped to the CPU instead.
initial_code = torch.load(
    initial_code_path,
    map_location=None if torch.cuda.is_available() else torch.device("cpu"),
)


def EncodeImage(img_pil):
    """Save ``img_pil`` as JPEG into memory and return the base64 bytes.

    Args:
        img_pil: Object exposing ``save(file_obj, format)``; a PIL image in
            this application.

    Returns:
        ``bytes`` holding the base64 encoding of the JPEG data.
    """
    with BytesIO() as jpeg_stream:
        img_pil.save(jpeg_stream, "jpeg")
        # Read the buffer while the stream is still open.
        return base64.b64encode(jpeg_stream.getvalue())


def DecodeImage(img_pil):
    """Decode a URL-safe base64 payload into an RGB PIL image.

    Args:
        img_pil: Despite the name, this is the base64-encoded image data
            (URL-safe alphabet), not an image object.

    Returns:
        The decoded image converted to "RGB" mode.
    """
    raw_bytes = base64.urlsafe_b64decode(img_pil)
    return Image.open(BytesIO(raw_bytes)).convert("RGB")


def process_input(body, random=False):
    """Run one stylization request and return the result as base64 JPEG bytes.

    Args:
        body: Raw request body (bytes) containing a UTF-8 JSON object. Keys
            read here: "user_id", "model" (style domain), "img" (base64 image
            data) and, optionally, "move_range".
        random: When True, the user's stored code is replaced with a fresh one
            from ``get_code`` and ``move_range`` is forced to 3.

    Returns:
        The output of ``EncodeImage`` for the generated artwork.

    Raises:
        KeyError: If "user_id", "model" or "img" is missing from the body.
        ValueError: If the body is not valid UTF-8 encoded JSON.
    """
    json_body = json.loads(body.decode("utf-8"))
    user_id = json_body["user_id"]
    # Selected style; read once, before any per-user state is touched.
    domain = json_body["model"]
    start_time = time.time()

    # save current code for different users
    if user_id not in current_codes:
        current_codes[user_id] = initial_code.clone()
    # NOTE(review): this trims the user's own code along its first dimension,
    # not the number of cached users, so `current_codes` itself can grow
    # without bound. Kept as-is; confirm whether evicting old users was the
    # intent behind `max_user_num`.
    if len(current_codes[user_id]) > max_user_num:
        current_codes[user_id] = current_codes[user_id][-max_user_num:]

    if random:
        # randomize code
        current_codes[user_id] = get_code(domain, boundaries)

    # get input
    input_img = DecodeImage(json_body["img"])

    if random:
        # a random request always uses a fixed move range of 3
        move_range = 3
    else:
        # Best effort: a missing or non-numeric "move_range" means "no move".
        try:
            move_range = float(json_body["move_range"])
        except (KeyError, TypeError, ValueError):
            move_range = 0

    if move_range != 0:
        modified_code = modify_code(current_codes[user_id], boundaries, domain, move_range)
    else:
        # Fixed: the original referenced the undefined name `current_code`,
        # raising NameError for every request with a zero move range.
        modified_code = current_codes[user_id].clone()

    # inference
    result = get_artwork(model, input_img, modified_code)
    print("Time Cost: ", time.time() - start_time)
    return EncodeImage(result)


@app.get("/", response_class=HTMLResponse)
def root(request: Request):
    """Render the "index.html" template for the landing page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)

@app.get("/check_gpu")
async def check_gpu():
    """Report whether torch sees a CUDA device (True/False)."""
    gpu_available = torch.cuda.is_available()
    return gpu_available

@app.post("/predict")
async def predict(request: Request):
    """Process the raw request body with the user's current code (random=False)."""
    return process_input(await request.body(), random=False)


@app.post("/predict_random")
async def predict_random(request: Request):
    """Process the raw request body with a freshly randomized code (random=True)."""
    return process_input(await request.body(), random=True)


# Start the uvicorn server only when this file is run as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    uvicorn.run(app, host=HOST, port=PORT, log_level="info")