import os
import sys
import time
import json
import torch
import base64
from PIL import Image
from io import BytesIO
# set CUDA_MODULE_LOADING=LAZY to speed up start-up (CUDA modules are loaded on demand)
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
# set SAFETENSORS_FAST_GPU=1 to speed up loading the model weights onto the GPU
os.environ["SAFETENSORS_FAST_GPU"] = "1"
sys.path.append(os.path.join(os.path.dirname(__file__), "seg2art"))
from seg2art.sstan_models.pix2pix_model import Pix2PixModel
from seg2art.options.test_options import TestOptions
from seg2art.inference_util import get_artwork
import uvicorn
from fastapi import FastAPI, Form
from fastapi.templating import Jinja2Templates
from fastapi.responses import PlainTextResponse, HTMLResponse
from fastapi.requests import Request
from fastapi.staticfiles import StaticFiles
# declare constants
HOST = "0.0.0.0"
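# 7860 is the default application port expected by Hugging Face Spaces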
PORT = 7860
# FastAPI
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# initialize SEAN model.
opt = TestOptions().parse()
opt.status = "test"
model = Pix2PixModel(opt)
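# use fp16 weights on the GPU to cut memory use and speed up inference; keep fp32 on CPU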
model = model.half() if torch.cuda.is_available() else model
model.eval()
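# helpers for sampling a latent style code and shifting it along per-domain boundaries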
from utils.umap_utils import get_code, load_boundries, modify_code
boundaries = load_boundries()
# per-user latent style codes, keyed by user_id
current_codes = {}
# keep at most this many user codes cached in memory
max_user_num = 5
initial_code_path = os.path.join(os.path.dirname(__file__), "static/init_code")
initial_code = (
    torch.load(initial_code_path)
    if torch.cuda.is_available()
    else torch.load(initial_code_path, map_location=torch.device("cpu"))
)
def EncodeImage(img_pil):
    # serialise a PIL image to base64-encoded JPEG bytes for the response body
    with BytesIO() as buffer:
        img_pil.save(buffer, "jpeg")
        image_data = base64.b64encode(buffer.getvalue())
    return image_data
def DecodeImage(img_b64):
    # decode a urlsafe-base64 string back into an RGB PIL image
    img_buffer = BytesIO(base64.urlsafe_b64decode(img_b64))
    return Image.open(img_buffer).convert("RGB")
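# Expected JSON request body for /predict and /predict_random (a sketch inferred from the
# fields read in process_input; valid "model" values depend on the trained checkpoint):
#   {
#       "user_id": "<stable client identifier>",
#       "img": "<urlsafe-base64-encoded segmentation map>",
#       "model": "<style domain name>",
#       "move_range": <float controlling how far the style code is shifted>
#   }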
def process_input(body, random=False):
global current_codes
json_body = json.loads(body.decode("utf-8"))
user_id = json_body["user_id"]
start_time = time.time()
    # cache a separate latent style code per user, starting from the shared initial code
    if user_id not in current_codes:
        current_codes[user_id] = initial_code.clone()
    # cap the cache at max_user_num entries; evict the oldest codes first,
    # but never the code of the user handling the current request
    while len(current_codes) > max_user_num:
        oldest_id = next(iter(current_codes))
        if oldest_id == user_id:
            break
        current_codes.pop(oldest_id)
if random:
# randomize code
domain = json_body["model"]
current_codes[user_id] = get_code(domain, boundaries)
# get input
input_img = DecodeImage(json_body["img"])
try:
move_range = float(json_body["move_range"])
    except (KeyError, TypeError, ValueError):
move_range = 0
# set move range to 3 if random is True
move_range = 3 if random else move_range
# print("Input image was received")
# get selected style
domain = json_body["model"]
    if move_range != 0:
        # shift the user's style code along the learned boundary for the selected domain
        modified_code = modify_code(current_codes[user_id], boundaries, domain, move_range)
    else:
        modified_code = current_codes[user_id].clone()
# inference
result = get_artwork(model, input_img, modified_code)
print("Time Cost: ", time.time() - start_time)
return EncodeImage(result)
@app.get("/", response_class=HTMLResponse)
def root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.get("/check_gpu")
async def check_gpu():
return torch.cuda.is_available()
@app.post("/predict")
async def predict(request: Request):
body = await request.body()
result = process_input(body, random=False)
return result
@app.post("/predict_random")
async def predict_random(request: Request):
body = await request.body()
result = process_input(body, random=True)
return result
if __name__ == "__main__":
uvicorn.run(app, host=HOST, port=PORT, log_level="info")
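# Example client call (a sketch, not part of the app; assumes the server is running locally
# and that "vangogh" is one of the style domains supported by the checkpoint):
#
#   import base64, requests
#   with open("segmentation_map.png", "rb") as f:
#       img_b64 = base64.urlsafe_b64encode(f.read()).decode("utf-8")
#   payload = {"user_id": "demo", "img": img_b64, "model": "vangogh", "move_range": 1.5}
#   resp = requests.post("http://localhost:7860/predict", json=payload)
#   jpeg_bytes = base64.b64decode(resp.json())  # the endpoint returns a base64 JPEG string
#   with open("artwork.jpg", "wb") as f:
#       f.write(jpeg_bytes)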