# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(os.path.join(os.path.dirname(__file__), "../gradio_demo"))

import cv2
import time
import torch
import mimetypes
import subprocess
import numpy as np
from typing import List
from cog import BasePredictor, Input, Path

import PIL
from PIL import Image

import diffusers
from diffusers import LCMScheduler
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

from model_util import get_torch_device
from insightface.app import FaceAnalysis
from transformers import CLIPImageProcessor
from controlnet_util import openpose, get_depth_map, get_canny_image

from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from pipeline_stable_diffusion_xl_instantid_full import (
    StableDiffusionXLInstantIDPipeline,
    draw_kps,
)

# Register the `.webp` extension with `mimetypes` (outputs may be saved as WebP in `predict`).
mimetypes.add_type("image/webp", ".webp")

# GPU global variables
DEVICE = get_torch_device()
# Half precision only when running on CUDA; every other device uses float32.
DTYPE = torch.float16 if "cuda" in str(DEVICE) else torch.float32

# for `ip-adapter`, `ControlNetModel`, and `stable-diffusion-xl-base-1.0`
CHECKPOINTS_CACHE = "./checkpoints"
CHECKPOINTS_URL = "https://weights.replicate.delivery/default/InstantID/checkpoints.tar"

# for `models/antelopev2`
MODELS_CACHE = "./models"
MODELS_URL = "https://weights.replicate.delivery/default/InstantID/models.tar"

# for the safety checker
SAFETY_CACHE = "./safety-cache"
FEATURE_EXTRACTOR = "./feature-extractor"
SAFETY_URL = "https://weights.replicate.delivery/default/playgroundai/safety-cache.tar"

def _replicate_hosted_sdxl(slug):
    """Return the ``slug``/``url``/``path`` descriptor for one mirrored SDXL repo.

    The local folder and the mirrored tarball share the name
    ``models--<org>--<name>`` derived from the ``<org>/<name>`` slug.
    """
    cache_dir_name = "models--" + slug.replace("/", "--")
    return {
        "slug": slug,
        "url": f"https://weights.replicate.delivery/default/InstantID/{cache_dir_name}.tar",
        "path": f"checkpoints/{cache_dir_name}",
    }


# These are all huggingface models that we host via gcp + pget.
# Each entry is keyed by the repo name, i.e. the part of the slug after "/".
SDXL_NAME_TO_PATHLIKE = {
    slug.split("/", 1)[1]: _replicate_hosted_sdxl(slug)
    for slug in (
        "stabilityai/stable-diffusion-xl-base-1.0",
        "stablediffusionapi/afrodite-xl-v2",
        "stablediffusionapi/albedobase-xl-20",
        "stablediffusionapi/albedobase-xl-v13",
        "stablediffusionapi/animagine-xl-30",
        "stablediffusionapi/anime-art-diffusion-xl",
        "stablediffusionapi/anime-illust-diffusion-xl",
        "stablediffusionapi/dreamshaper-xl",
        "stablediffusionapi/dynavision-xl-v0610",
        "stablediffusionapi/guofeng4-xl",
        "stablediffusionapi/juggernaut-xl-v8",
        "stablediffusionapi/nightvision-xl-0791",
        "stablediffusionapi/omnigen-xl",
        "stablediffusionapi/pony-diffusion-v6-xl",
        "stablediffusionapi/protovision-xl-high-fidel",
        "SG161222/RealVisXL_V3.0_Turbo",
        "SG161222/RealVisXL_V4.0_Lightning",
    )
}


def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV-style BGR array into an RGB ``PIL.Image.Image``."""
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))


def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert an RGB ``PIL.Image.Image`` into an OpenCV-style BGR array."""
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)


def resize_img(
    input_image,
    max_side=1280,
    min_side=1024,
    size=None,
    pad_to_max_side=False,
    mode=PIL.Image.BILINEAR,
    base_pixel_number=64,
):
    """Resize a PIL image to dimensions suitable for the SDXL pipeline.

    If ``size`` is given as ``(width, height)`` the image is resized to exactly
    that. Otherwise the image is scaled so its shorter side equals ``min_side``,
    then so its longer side equals ``max_side``. It is then resampled again with
    each side floored to a multiple of ``base_pixel_number``.

    When ``pad_to_max_side`` is true, the result is centred on a white
    ``max_side`` x ``max_side`` RGB canvas. The slice assignment requires a
    3-channel image that fits inside that canvas.
    """
    src_w, src_h = input_image.size
    if size is None:
        # Bring the shorter side to `min_side`...
        grow = min_side / min(src_h, src_w)
        mid_w, mid_h = round(grow * src_w), round(grow * src_h)
        # ...then the longer side to `max_side`.
        shrink = max_side / max(mid_h, mid_w)
        fit_w, fit_h = round(shrink * mid_w), round(shrink * mid_h)
        input_image = input_image.resize([fit_w, fit_h], mode)
        w_resize_new = fit_w // base_pixel_number * base_pixel_number
        h_resize_new = fit_h // base_pixel_number * base_pixel_number
    else:
        w_resize_new, h_resize_new = size
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if not pad_to_max_side:
        return input_image

    canvas = np.full([max_side, max_side, 3], 255, dtype=np.uint8)
    left = (max_side - w_resize_new) // 2
    top = (max_side - h_resize_new) // 2
    canvas[top : top + h_resize_new, left : left + w_resize_new] = np.array(input_image)
    return Image.fromarray(canvas)


def download_weights(url, dest):
    """Download ``url`` to ``dest`` using the ``pget`` CLI.

    URLs containing ``.tar`` are passed to pget with ``-x`` (an extract flag,
    per the pget CLI; confirm against the installed version).

    Callers decide whether to download by checking ``os.path.exists(dest)``.
    If pget fails, any ``dest`` that this call created is removed. Otherwise a
    half-written cache would make the next run skip the download.

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    import contextlib
    import shutil

    start = time.time()
    print("[!] Initiating download from URL: ", url)
    print("[~] Destination path: ", dest)
    command = ["pget", "-vf", url, dest]
    if ".tar" in url:
        command.append("-x")

    existed_before = os.path.lexists(dest)
    try:
        subprocess.check_call(command, close_fds=False)
    except subprocess.CalledProcessError as e:
        print(
            f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}."
        )
        if not existed_before and os.path.lexists(dest):
            # Best-effort cleanup; never let it mask the original download error.
            if os.path.isdir(dest) and not os.path.islink(dest):
                shutil.rmtree(dest, ignore_errors=True)
            else:
                with contextlib.suppress(OSError):
                    os.remove(dest)
        raise
    print("[+] Download completed in: ", time.time() - start, "seconds")


class Predictor(BasePredictor):
    """Cog predictor wrapping the InstantID SDXL pipeline.

    ``setup`` fetches (if missing) and loads the InsightFace analyser, the
    IdentityNet ControlNet, the extra pose/canny/depth ControlNets, the default
    SDXL base weights and the safety checker. ``predict`` runs one generation
    per call and reloads the base weights when a different ``sdxl_weights`` is
    requested.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

        if not os.path.exists(CHECKPOINTS_CACHE):
            download_weights(CHECKPOINTS_URL, CHECKPOINTS_CACHE)

        if not os.path.exists(MODELS_CACHE):
            download_weights(MODELS_URL, MODELS_CACHE)

        # Detector input size; `predict` re-prepares the detector when it changes.
        self.face_detection_input_width, self.face_detection_input_height = 640, 640
        self.app = FaceAnalysis(
            name="antelopev2",
            root="./",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        self.app.prepare(
            ctx_id=0,
            det_size=(self.face_detection_input_width, self.face_detection_input_height),
        )

        # Path to InstantID models
        self.face_adapter = "./checkpoints/ip-adapter.bin"
        controlnet_path = "./checkpoints/ControlNetModel"

        # Load pipeline face ControlNetModel (IdentityNet)
        self.controlnet_identitynet = ControlNetModel.from_pretrained(
            controlnet_path,
            torch_dtype=DTYPE,
            cache_dir=CHECKPOINTS_CACHE,
            local_files_only=True,
        )
        self.setup_extra_controlnets()

        self.load_weights("stable-diffusion-xl-base-1.0")
        self.setup_safety_checker()

    @staticmethod
    def _ensure_instantid_cache(cache_key):
        """Download ``checkpoints/<cache_key>`` from the InstantID bucket unless it exists."""
        local_path = f"checkpoints/{cache_key}"
        if not os.path.exists(local_path):
            download_weights(
                f"https://weights.replicate.delivery/default/InstantID/{cache_key}.tar",
                local_path,
            )

    @staticmethod
    def _largest_face(faces):
        """Return the detected face whose ``bbox`` (``x1, y1, x2, y2``) covers the largest area."""

        def bbox_area(face):
            x1, y1, x2, y2 = face["bbox"]
            return (x2 - x1) * (y2 - y1)

        return max(faces, key=bbox_area)

    def setup_safety_checker(self):
        """Download (if needed) and load the safety checker and its CLIP feature extractor."""
        print("[~] Seting up safety checker")

        if not os.path.exists(SAFETY_CACHE):
            download_weights(SAFETY_URL, SAFETY_CACHE)

        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            SAFETY_CACHE,
            torch_dtype=DTYPE,
            local_files_only=True,
        )
        self.safety_checker.to(DEVICE)
        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)

    def run_safety_checker(self, image):
        """Run the safety checker on a single PIL image.

        Returns:
            The ``(images, has_nsfw_concept)`` pair produced by the safety
            checker; ``has_nsfw_concept`` holds one flag per checked image.
        """
        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
            DEVICE
        )
        np_image = np.array(image)
        image, has_nsfw_concept = self.safety_checker(
            images=[np_image],
            clip_input=safety_checker_input.pixel_values.to(DTYPE),
        )
        return image, has_nsfw_concept

    def load_weights(self, sdxl_weights):
        """(Re)build ``self.pipe`` on top of the named SDXL base weights.

        Args:
            sdxl_weights: key into ``SDXL_NAME_TO_PATHLIKE``.

        Raises:
            KeyError: if ``sdxl_weights`` is not a known key.
        """
        self.base_weights = sdxl_weights
        weights_info = SDXL_NAME_TO_PATHLIKE[self.base_weights]

        path_to_weights_dir = weights_info["path"]
        if not os.path.exists(path_to_weights_dir):
            download_weights(weights_info["url"], path_to_weights_dir)

        path_to_weights_file = os.path.join(
            path_to_weights_dir,
            weights_info.get("file", ""),
        )

        print(f"[~] Loading new SDXL weights: {path_to_weights_file}")
        if "slug" in weights_info:
            self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
                weights_info["slug"],
                controlnet=[self.controlnet_identitynet],
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
                local_files_only=True,
                safety_checker=None,
                feature_extractor=None,
            )
            self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(
                self.pipe.scheduler.config
            )
        else:  # e.g. .safetensors, NOTE: This functionality is not being used right now
            # `from_single_file` constructs and returns a new pipeline, so the
            # result must be assigned for it to take effect.
            self.pipe = StableDiffusionXLInstantIDPipeline.from_single_file(
                path_to_weights_file,
                controlnet=self.controlnet_identitynet,
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
            )

        self.pipe.load_ip_adapter_instantid(self.face_adapter)
        self.setup_lcm_lora()
        # NOTE(review): `.cuda()` is hard-coded although DEVICE/DTYPE are detected
        # at import time; presumably the InstantID pipeline's own `cuda()` also
        # moves its image-projection model, so confirm before swapping to `.to(DEVICE)`.
        self.pipe.cuda()

    def setup_lcm_lora(self):
        """Load the LCM-LoRA weights into ``self.pipe`` and leave them disabled.

        ``generate_image`` enables them only for requests with LCM turned on.
        """
        print("[~] Seting up LCM (just in case)")

        self._ensure_instantid_cache("models--latent-consistency--lcm-lora-sdxl")
        self.pipe.load_lora_weights(
            "latent-consistency/lcm-lora-sdxl",
            cache_dir=CHECKPOINTS_CACHE,
            local_files_only=True,
            weight_name="pytorch_lora_weights.safetensors",
        )
        self.pipe.disable_lora()

    def setup_extra_controlnets(self):
        """Download (if needed) and load the pose, canny and depth ControlNets.

        Populates ``self.controlnet_map`` (name -> model on DEVICE) and
        ``self.controlnet_map_fn`` (name -> function producing its conditioning image).
        """
        print("[~] Seting up pose, canny, depth ControlNets")

        controlnet_repos = {
            "pose": "thibaud/controlnet-openpose-sdxl-1.0",
            "canny": "diffusers/controlnet-canny-sdxl-1.0",
            "depth": "diffusers/controlnet-depth-sdxl-1.0-small",
        }

        # Local folders are named "models--<org>--<name>" after each repo id.
        for repo_id in controlnet_repos.values():
            self._ensure_instantid_cache("models--" + repo_id.replace("/", "--"))

        self.controlnet_map = {
            name: ControlNetModel.from_pretrained(
                repo_id,
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
                local_files_only=True,
            ).to(DEVICE)
            for name, repo_id in controlnet_repos.items()
        }
        self.controlnet_map_fn = {
            "pose": openpose,
            "canny": get_canny_image,
            "depth": get_depth_map,
        }

    def generate_image(
        self,
        face_image_path,
        pose_image_path,
        prompt,
        negative_prompt,
        num_steps,
        identitynet_strength_ratio,
        adapter_strength_ratio,
        pose_strength,
        canny_strength,
        depth_strength,
        controlnet_selection,
        guidance_scale,
        seed,
        scheduler,
        enable_LCM,
        enhance_face_region,
        num_images_per_prompt,
    ):
        """Run one InstantID generation and return the pipeline's list of images.

        The identity embedding always comes from the largest face in
        ``face_image_path``. Keypoints (and the optional face-region mask) come
        from ``pose_image_path`` when given, otherwise from the face image.
        ``controlnet_selection`` lists extra ControlNets ("pose", "canny",
        "depth") to stack after IdentityNet.

        Raises:
            Exception: if no face image path is given or no face is detected.
        """
        if face_image_path is None:
            raise Exception(
                f"Cannot find any input face `image`! Please upload the face `image`"
            )

        if enable_LCM:
            self.pipe.enable_lora()
            self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
        else:
            self.pipe.disable_lora()
            # Scheduler names look like "<ClassName>[-Karras[-SDE]]".
            scheduler_class_name, *scheduler_flags = scheduler.split("-")
            add_kwargs = {}
            if len(scheduler_flags) >= 1:
                add_kwargs["use_karras_sigmas"] = True
            if len(scheduler_flags) >= 2:
                add_kwargs["algorithm_type"] = "sde-dpmsolver++"
            scheduler_cls = getattr(diffusers, scheduler_class_name)
            self.pipe.scheduler = scheduler_cls.from_config(
                self.pipe.scheduler.config,
                **add_kwargs,
            )

        face_image = resize_img(load_image(face_image_path))
        face_image_cv2 = convert_from_image_to_cv2(face_image)
        height, width, _ = face_image_cv2.shape

        # Extract face features; only the largest face is used.
        faces = self.app.get(face_image_cv2)
        if not faces:
            raise Exception(
                "Face detector could not find a face in the `image`. Please use a different `image` as input."
            )
        face_info = self._largest_face(faces)
        face_emb = face_info["embedding"]
        face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])

        img_controlnet = face_image
        if pose_image_path is not None:
            pose_image = resize_img(load_image(pose_image_path), max_side=1024)
            img_controlnet = pose_image

            pose_faces = self.app.get(convert_from_image_to_cv2(pose_image))
            if not pose_faces:
                raise Exception(
                    "Face detector could not find a face in the `pose_image`. Please use a different `pose_image` as input."
                )
            # Keypoints, output size and the mask bbox below follow the pose image.
            face_info = self._largest_face(pose_faces)
            face_kps = draw_kps(pose_image, face_info["kps"])
            width, height = face_kps.size

        if enhance_face_region:
            control_mask = np.zeros([height, width, 3])
            # Detector boxes can extend past the image edge; clamp at 0 so a
            # negative coordinate is not treated as an index from the end.
            x1, y1, x2, y2 = (max(int(v), 0) for v in face_info["bbox"])
            control_mask[y1:y2, x1:x2] = 255
            control_mask = Image.fromarray(control_mask.astype(np.uint8))
        else:
            control_mask = None

        if controlnet_selection:
            controlnet_scales = {
                "pose": pose_strength,
                "canny": canny_strength,
                "depth": depth_strength,
            }
            # IdentityNet always comes first, followed by the requested extras.
            self.pipe.controlnet = MultiControlNetModel(
                [self.controlnet_identitynet]
                + [self.controlnet_map[s] for s in controlnet_selection]
            )
            control_scales = [float(identitynet_strength_ratio)] + [
                controlnet_scales[s] for s in controlnet_selection
            ]
            control_images = [face_kps] + [
                self.controlnet_map_fn[s](img_controlnet).resize((width, height))
                for s in controlnet_selection
            ]
        else:
            self.pipe.controlnet = self.controlnet_identitynet
            control_scales = float(identitynet_strength_ratio)
            control_images = face_kps

        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        print("Start inference...")
        print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")

        self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
        return self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image_embeds=face_emb,
            image=control_images,
            control_mask=control_mask,
            controlnet_conditioning_scale=control_scales,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=generator,
            num_images_per_prompt=num_images_per_prompt,
        ).images

    def predict(
        self,
        image: Path = Input(
            description="Input face image",
        ),
        pose_image: Path = Input(
            description="(Optional) reference pose image",
            default=None,
        ),
        prompt: str = Input(
            description="Input prompt",
            default="a person",
        ),
        negative_prompt: str = Input(
            description="Input Negative Prompt",
            default="",
        ),
        sdxl_weights: str = Input(
            description="Pick which base weights you want to use",
            default="stable-diffusion-xl-base-1.0",
            choices=[
                "stable-diffusion-xl-base-1.0",
                "juggernaut-xl-v8",
                "afrodite-xl-v2",
                "albedobase-xl-20",
                "albedobase-xl-v13",
                "animagine-xl-30",
                "anime-art-diffusion-xl",
                "anime-illust-diffusion-xl",
                "dreamshaper-xl",
                "dynavision-xl-v0610",
                "guofeng4-xl",
                "nightvision-xl-0791",
                "omnigen-xl",
                "pony-diffusion-v6-xl",
                "protovision-xl-high-fidel",
                "RealVisXL_V3.0_Turbo",
                "RealVisXL_V4.0_Lightning",
            ],
        ),
        face_detection_input_width: int = Input(
            description="Width of the input image for face detection",
            default=640,
            ge=640,
            le=4096,
        ),
        face_detection_input_height: int = Input(
            description="Height of the input image for face detection",
            default=640,
            ge=640,
            le=4096,
        ),
        scheduler: str = Input(
            description="Scheduler",
            choices=[
                "DEISMultistepScheduler",
                "HeunDiscreteScheduler",
                "EulerDiscreteScheduler",
                "DPMSolverMultistepScheduler",
                "DPMSolverMultistepScheduler-Karras",
                "DPMSolverMultistepScheduler-Karras-SDE",
            ],
            default="EulerDiscreteScheduler",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps",
            default=30,
            ge=1,
            le=500,
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance",
            default=7.5,
            ge=1,
            le=50,
        ),
        ip_adapter_scale: float = Input(
            description="Scale for image adapter strength (for detail)",  # adapter_strength_ratio
            default=0.8,
            ge=0,
            le=1.5,
        ),
        controlnet_conditioning_scale: float = Input(
            description="Scale for IdentityNet strength (for fidelity)",  # identitynet_strength_ratio
            default=0.8,
            ge=0,
            le=1.5,
        ),
        enable_pose_controlnet: bool = Input(
            description="Enable Openpose ControlNet, overrides strength if set to false",
            default=True,
        ),
        pose_strength: float = Input(
            description="Openpose ControlNet strength, effective only if `enable_pose_controlnet` is true",
            default=0.4,
            ge=0,
            le=1,
        ),
        enable_canny_controlnet: bool = Input(
            description="Enable Canny ControlNet, overrides strength if set to false",
            default=False,
        ),
        canny_strength: float = Input(
            description="Canny ControlNet strength, effective only if `enable_canny_controlnet` is true",
            default=0.3,
            ge=0,
            le=1,
        ),
        enable_depth_controlnet: bool = Input(
            description="Enable Depth ControlNet, overrides strength if set to false",
            default=False,
        ),
        depth_strength: float = Input(
            description="Depth ControlNet strength, effective only if `enable_depth_controlnet` is true",
            default=0.5,
            ge=0,
            le=1,
        ),
        enable_lcm: bool = Input(
            description="Enable Fast Inference with LCM (Latent Consistency Models) - speeds up inference steps, trade-off is the quality of the generated image. Performs better with close-up portrait face images",
            default=False,
        ),
        lcm_num_inference_steps: int = Input(
            description="Only used when `enable_lcm` is set to True, Number of denoising steps when using LCM",
            default=5,
            ge=1,
            le=10,
        ),
        lcm_guidance_scale: float = Input(
            description="Only used when `enable_lcm` is set to True, Scale for classifier-free guidance when using LCM",
            default=1.5,
            ge=1,
            le=20,
        ),
        enhance_nonface_region: bool = Input(
            description="Enhance non-face region", default=True
        ),
        output_format: str = Input(
            description="Format of the output images",
            choices=["webp", "jpg", "png"],
            default="webp",
        ),
        output_quality: int = Input(
            description="Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.",
            default=80,
            ge=0,
            le=100,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed",
            default=None,
        ),
        num_outputs: int = Input(
            description="Number of images to output",
            default=1,
            ge=1,
            le=8,
        ),
        disable_safety_checker: bool = Input(
            description="Disable safety checker for generated images",
            default=False,
        ),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Returns one saved file path per generated image, written to
        ``/tmp/out_<i>.<ext>``.

        Raises:
            Exception: if no face is found, or the safety checker (when enabled)
                flags any generated image.
        """
        # If no seed is provided, generate a random 16-bit seed
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # Load the weights if they are different from the base weights
        if sdxl_weights != self.base_weights:
            self.load_weights(sdxl_weights)

        # Re-prepare the face detector if a different detection input size was requested
        det_size = (face_detection_input_width, face_detection_input_height)
        if det_size != (self.face_detection_input_width, self.face_detection_input_height):
            print(
                f"[!] Re-preparing face detector with det_size {face_detection_input_width}x{face_detection_input_height}"
            )
            self.face_detection_input_width, self.face_detection_input_height = det_size
            self.app.prepare(ctx_id=0, det_size=det_size)

        # An extra ControlNet is used only when it is enabled AND has a positive strength
        requested_controlnets = [
            ("pose", enable_pose_controlnet, pose_strength),
            ("canny", enable_canny_controlnet, canny_strength),
            ("depth", enable_depth_controlnet, depth_strength),
        ]
        controlnet_selection = [
            name
            for name, enabled, strength in requested_controlnets
            if strength > 0 and enabled
        ]

        # Switch to LCM inference steps and guidance scale if LCM is enabled
        if enable_lcm:
            num_inference_steps = lcm_num_inference_steps
            guidance_scale = lcm_guidance_scale

        images = self.generate_image(
            face_image_path=str(image),
            pose_image_path=str(pose_image) if pose_image else None,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_steps=num_inference_steps,
            identitynet_strength_ratio=controlnet_conditioning_scale,
            adapter_strength_ratio=ip_adapter_scale,
            pose_strength=pose_strength,
            canny_strength=canny_strength,
            depth_strength=depth_strength,
            controlnet_selection=controlnet_selection,
            scheduler=scheduler,
            guidance_scale=guidance_scale,
            seed=seed,
            enable_LCM=enable_lcm,
            enhance_face_region=enhance_nonface_region,
            num_images_per_prompt=num_outputs,
        )

        # Output format settings are identical for every image, so compute them once.
        extension = output_format.lower()
        extension = "jpeg" if extension == "jpg" else extension
        save_params = {"format": extension.upper()}
        if output_format != "png":
            save_params["quality"] = output_quality
            save_params["optimize"] = True

        # Save the generated images and check for NSFW content
        output_paths = []
        for i, output_image in enumerate(images):
            if not disable_safety_checker:
                _, has_nsfw_content_list = self.run_safety_checker(output_image)
                has_nsfw_content = any(has_nsfw_content_list)
                print(f"NSFW content detected: {has_nsfw_content}")
                if has_nsfw_content:
                    raise Exception(
                        "NSFW content detected. Try running it again, or try a different prompt."
                    )

            output_path = f"/tmp/out_{i}.{extension}"
            print(f"[~] Saving to {output_path}...")
            print(f"[~] Output format: {extension.upper()}")
            if output_format != "png":
                print(f"[~] Output quality: {output_quality}")

            output_image.save(output_path, **save_params)
            output_paths.append(Path(output_path))
        return output_paths