import os
import sys
import time
import json
import torch
import base64
from PIL import Image
from io import BytesIO

# Environment tweaks intended to shorten cold starts of the serverless function.
# set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
# (CUDA modules are loaded on first use rather than all at context creation).
# NOTE(review): torch is already imported above; this presumably still takes
# effect because torch initializes CUDA lazily on first use — confirm.
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
# set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
os.environ["SAFETENSORS_FAST_GPU"] = "1"

# Put the vendored seg2art/ directory on sys.path. This must run before the
# seg2art imports that follow. Presumably needed so modules inside seg2art/
# can import each other by top-level name — verify.
sys.path.append(os.path.join(os.path.dirname(__file__), "seg2art"))
from seg2art.sstan_models.pix2pix_model import Pix2PixModel
from seg2art.options.test_options import TestOptions
from seg2art.inference_util import get_artwork

import uvicorn
from fastapi import FastAPI, Form
from fastapi.templating import Jinja2Templates
from fastapi.responses import PlainTextResponse, HTMLResponse
from fastapi.requests import Request
from fastapi.staticfiles import StaticFiles


# declare constants
HOST = "0.0.0.0"
PORT = 7860
# Directory containing this file. Bundled assets are resolved against it so the
# server works regardless of the current working directory.
APP_DIR = os.path.dirname(os.path.abspath(__file__))

# FastAPI
# ``root_path`` is the URL prefix the app is served under behind a proxy, not a
# filesystem path. It is left at its default; pass ``--root-path`` to uvicorn
# when deploying under a prefix.
app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(APP_DIR, "static")),
    name="static",
)
templates = Jinja2Templates(directory=os.path.join(APP_DIR, "templates"))


# Build the SEAN model from the parsed test options and put it in evaluation mode.
opt = TestOptions().parse()
opt.status = "test"
model = Pix2PixModel(opt)
# Use half precision only when a GPU is available; keep full precision on CPU.
if torch.cuda.is_available():
    model = model.half()
model.eval()


from utils.umap_utils import get_code, load_boundries, modify_code

# Loaded once at startup and passed to get_code / modify_code on every request.
boundaries = load_boundries()
# Per-user style codes, keyed by the request's ``user_id``.
# (A module-level ``global`` statement is a no-op, so none is used here.)
current_codes = {}
# Upper bound consulted by process_input when managing per-user state.
max_user_num = 5

initial_code_path = os.path.join(os.path.dirname(__file__), "static/init_code")
# On CPU-only hosts, map stored tensors to the CPU so a checkpoint saved from
# GPU tensors still loads. ``None`` is torch.load's default mapping.
initial_code = torch.load(
    initial_code_path,
    map_location=None if torch.cuda.is_available() else torch.device("cpu"),
)


def EncodeImage(img_pil):
    """Serialize *img_pil* as JPEG and return the bytes base64-encoded.

    Returns:
        ``bytes`` containing the standard-alphabet base64 encoding of the
        JPEG data written by ``img_pil.save``.
    """
    buffer = BytesIO()
    try:
        img_pil.save(buffer, "jpeg")
        jpeg_bytes = buffer.getvalue()
    finally:
        buffer.close()
    return base64.b64encode(jpeg_bytes)


def DecodeImage(img_pil):
    """Decode a base64 image payload into an RGB PIL image.

    Args:
        img_pil: base64 text/bytes of an encoded image (despite the name, not a
            PIL object). ``urlsafe_b64decode`` only remaps ``-``/``_``, so the
            standard alphabet is accepted as well.
    """
    raw_bytes = base64.urlsafe_b64decode(img_pil)
    return Image.open(BytesIO(raw_bytes)).convert("RGB")


def process_input(body, random=False):
    """Stylize the request's image with the requesting user's style code.

    Args:
        body: Raw request body, UTF-8 encoded JSON. Keys read: ``user_id``,
            ``img`` (base64 image accepted by ``DecodeImage``), ``model``
            (domain passed to ``get_code`` / ``modify_code``) and, optionally,
            ``move_range`` (anything ``float()`` accepts; 0 if missing/invalid).
        random: If True, replace the user's stored code with a fresh one from
            ``get_code`` and use a fixed ``move_range`` of 3.

    Returns:
        The result image as base64-encoded JPEG bytes (see ``EncodeImage``).

    Raises:
        KeyError: if ``user_id``, ``model`` or ``img`` is missing from the body.
    """
    json_body = json.loads(body.decode("utf-8"))
    user_id = json_body["user_id"]
    domain = json_body["model"]
    start_time = time.perf_counter()

    # Keep at most ``max_user_num`` per-user codes (at least one, so the current
    # user is never evicted). Popping and re-inserting the current user moves it
    # to the end of the insertion-ordered dict. Eviction from the front therefore
    # drops the least-recently-seen user first.
    user_code = current_codes.pop(user_id, None)
    if user_code is None:
        user_code = initial_code.clone()
    current_codes[user_id] = user_code
    while len(current_codes) > max(max_user_num, 1):
        del current_codes[next(iter(current_codes))]

    if random:
        # Randomize the stored code and always move a fixed distance of 3.
        current_codes[user_id] = get_code(domain, boundaries)
        move_range = 3
    else:
        try:
            move_range = float(json_body["move_range"])
        except (KeyError, TypeError, ValueError):
            # Missing or non-numeric move_range means "no movement".
            move_range = 0

    input_img = DecodeImage(json_body["img"])

    if move_range != 0:
        modified_code = modify_code(current_codes[user_id], boundaries, domain, move_range)
    else:
        # Use an unmodified copy of the user's stored code. The original code
        # referenced an undefined ``current_code`` here and raised NameError.
        modified_code = current_codes[user_id].clone()

    # inference
    result = get_artwork(model, input_img, modified_code)
    print("Time Cost: ", time.perf_counter() - start_time)
    return EncodeImage(result)


@app.get("/", response_class=HTMLResponse)
def root(request: Request):
    """Render the single-page front end (``templates/index.html``)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


@app.post("/predict")
async def predict(request: Request):
    """Stylize the posted image using the caller's current style code."""
    return process_input(await request.body(), random=False)


@app.post("/predict_random")
async def predict_random(request: Request):
    """Stylize the posted image after randomizing the caller's style code."""
    return process_input(await request.body(), random=True)


if __name__ == "__main__":
    # Run the app with uvicorn when this file is executed as a script.
    uvicorn.run(app, log_level="info", host=HOST, port=PORT)