import os
import sys
import time
import json
import base64
from io import BytesIO

# set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
# (set before importing torch so it is in place before any CUDA initialisation)
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
# set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
os.environ["SAFETENSORS_FAST_GPU"] = "1"

import torch
from PIL import Image

# Absolute directory of this file. Bundled assets are resolved against it so
# the server works no matter which working directory it is launched from.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# Make the vendored seg2art package importable.
sys.path.append(os.path.join(BASE_DIR, "seg2art"))

from seg2art.sstan_models.pix2pix_model import Pix2PixModel
from seg2art.options.test_options import TestOptions
from seg2art.inference_util import get_artwork

import uvicorn
from fastapi import FastAPI, Form
from fastapi.templating import Jinja2Templates
from fastapi.responses import PlainTextResponse, HTMLResponse
from fastapi.requests import Request
from fastapi.staticfiles import StaticFiles

# declare constants
HOST = "0.0.0.0"
PORT = 7860

# FastAPI app.
# `root_path` is the URL prefix under which a reverse proxy exposes the app,
# not a filesystem location, so it is left unset here; a filesystem path there
# would be prepended to generated URLs (e.g. the OpenAPI docs). The static and
# template directories are instead anchored to BASE_DIR.
app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "static")),
    name="static",
)
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))

# initialize SEAN model.
opt = TestOptions().parse()
opt.status = "test"
model = Pix2PixModel(opt)
# Half precision only when CUDA is available; otherwise keep the default dtype.
model = model.half() if torch.cuda.is_available() else model
model.eval()

# NOTE(review): imported after model construction; possibly order-dependent
# (sys.path / option parsing) — confirm before moving to the top of the file.
from utils.umap_utils import get_code, load_boundries, modify_code

boundaries = load_boundries()

# Current style code per user_id. Dicts preserve insertion order, which is used
# below as a least-recently-used order for eviction.
current_codes = {}
# Maximum number of users whose style codes are kept in memory at once.
max_user_num = 5

initial_code_path = os.path.join(os.path.dirname(__file__), "static/init_code")
# NOTE: torch.load unpickles the file; it is a bundled asset, not user input.
# Without CUDA, remap any GPU tensors stored in the file onto the CPU.
initial_code = torch.load(
    initial_code_path,
    map_location=None if torch.cuda.is_available() else torch.device("cpu"),
)


def EncodeImage(img_pil):
    """Encode a PIL image as base64 JPEG.

    Args:
        img_pil: PIL image to encode.

    Returns:
        bytes: standard-alphabet base64 of the JPEG-encoded image.
    """
    with BytesIO() as buffer:
        img_pil.save(buffer, "jpeg")
        return base64.b64encode(buffer.getvalue())


def DecodeImage(img_pil):
    """Decode base64 image data (via ``base64.urlsafe_b64decode``) into RGB.

    Args:
        img_pil: base64-encoded image data (str or bytes).

    Returns:
        PIL.Image.Image converted to RGB mode.
    """
    raw = BytesIO(base64.urlsafe_b64decode(img_pil))
    return Image.open(raw).convert("RGB")


def _get_user_code(user_id):
    """Return the stored code for ``user_id``, creating it and evicting old users."""
    if user_id in current_codes:
        # Re-insert so this user becomes the most recently used entry.
        current_codes[user_id] = current_codes.pop(user_id)
    else:
        current_codes[user_id] = initial_code.clone()
        # Bound memory: drop the least recently used users beyond the cap.
        while len(current_codes) > max_user_num:
            current_codes.pop(next(iter(current_codes)))
    return current_codes[user_id]


def process_input(body, random=False):
    """Run one stylization request.

    Args:
        body: raw request body, UTF-8 JSON with keys "user_id", "model" and
            "img" (base64 image), plus an optional numeric "move_range".
        random: if True, replace the user's stored code with
            ``get_code(domain, boundaries)`` and force a move range of 3.

    Returns:
        bytes: base64-encoded JPEG of the generated artwork.
    """
    json_body = json.loads(body.decode("utf-8"))
    user_id = json_body["user_id"]
    # selected style domain
    domain = json_body["model"]
    start_time = time.time()

    _get_user_code(user_id)
    if random:
        # randomize code
        current_codes[user_id] = get_code(domain, boundaries)

    input_img = DecodeImage(json_body["img"])

    # "move_range" is optional: missing, null or non-numeric values mean 0.
    try:
        move_range = float(json_body["move_range"])
    except (KeyError, TypeError, ValueError, OverflowError):
        move_range = 0

    # Random requests always use a move range of 3.
    if random:
        move_range = 3

    user_code = current_codes[user_id]
    if move_range != 0:
        modified_code = modify_code(user_code, boundaries, domain, move_range)
    else:
        # Clone so the stored per-user code is not shared with inference.
        modified_code = user_code.clone()

    # inference
    result = get_artwork(model, input_img, modified_code)
    print("Time Cost: ", time.time() - start_time)
    return EncodeImage(result)


@app.get("/", response_class=HTMLResponse)
def root(request: Request):
    """Serve the front-end page."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/check_gpu")
async def check_gpu():
    """Report whether CUDA is available to this process."""
    return torch.cuda.is_available()


# NOTE: the handlers below call the synchronous process_input directly, so
# inference blocks the event loop; this also serializes access to `model` and
# `current_codes`, which are not otherwise protected.
@app.post("/predict")
async def predict(request: Request):
    """Stylize the posted image with the user's current style code."""
    body = await request.body()
    return process_input(body, random=False)


@app.post("/predict_random")
async def predict_random(request: Request):
    """Stylize the posted image with a freshly sampled style code."""
    body = await request.body()
    return process_input(body, random=True)


if __name__ == "__main__":
    uvicorn.run(app, host=HOST, port=PORT, log_level="info")