import random

import gradio as gr
import time, os
import numpy as np
import torch
from tqdm import tqdm, trange
from PIL import Image


def random_clip(x, min=-1.3, max=1.3):
    """Clamp every element of ``x`` into the closed range [min, max].

    Accepts NumPy arrays (via ``np.clip``) or torch tensors (via ``torch.clip``)
    and returns an object of the same library; any other type raises TypeError.
    """
    for array_type, clamp in ((np.ndarray, np.clip), (torch.Tensor, torch.clip)):
        if isinstance(x, array_type):
            return clamp(x, min, max)
    raise TypeError(f"type of x is {type(x)}")


class Sampler:
    """Base class for the latent-diffusion samplers.

    Builds a ``total_step`` (= 1000) noise schedule and stores it as float32
    tensors on ``device``: ``betas`` and ``afas_cumprod``, the cumulative
    product of (1 - beta) ("afa" is this code's spelling of alpha).
    Timesteps used by ``apple_noise`` are 1-based: step ``t`` reads index ``t - 1``.
    """

    def __init__(self, device, normal_t):
        self.device = device
        self.total_step = 1000
        # stored for subclasses, which read it in their sampling code
        self.normal_t = normal_t

        # available schedule names: "cosine", "linear", "scaled_linear"
        alpha_bars, beta_values = self.get_afa_bars("scaled_linear", self.total_step)
        self.afas_cumprod = torch.Tensor(alpha_bars).to(self.device)
        self.betas = torch.Tensor(beta_values).to(self.device)

    def betas_for_alpha_bar(self, num_diffusion_timesteps, alpha_bar, max_beta=0.999):
        """
        Discretize a continuous alpha_bar(t), t in [0, 1], into a beta schedule.

        Beta i is ``1 - alpha_bar((i + 1) / n) / alpha_bar(i / n)``, capped at
        ``max_beta``.

        :param num_diffusion_timesteps: number of betas ``n`` to produce.
        :param alpha_bar: callable mapping t in [0, 1] to the cumulative product
                          of (1 - beta) up to t.
        :param max_beta: upper bound for each beta; values below 1 avoid
                         singularities near the end of the schedule.
        :return: NumPy array of ``n`` betas.
        """
        n = num_diffusion_timesteps
        return np.array([
            min(1 - alpha_bar((idx + 1) / n) / alpha_bar(idx / n), max_beta)
            for idx in range(n)
        ])

    def get_named_beta_schedule(self, schedule_name, num_diffusion_timesteps):
        """
        Return the predefined beta schedule called ``schedule_name``.

        "linear" and "scaled_linear" both span [1e-4, 2e-2] rescaled by
        1000 / num_diffusion_timesteps. "scaled_linear" interpolates linearly in
        sqrt(beta) and then squares. "cosine" discretizes the cosine alpha_bar.
        These schedules are kept stable so results stay reproducible.

        :raises NotImplementedError: for any other name.
        """
        if schedule_name == "cosine":
            return self.betas_for_alpha_bar(
                num_diffusion_timesteps,
                lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2,
            )
        if schedule_name not in ("linear", "scaled_linear"):
            raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

        # Linear schedule from Ho et al., extended to any number of diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start, beta_end = scale * 0.0001, scale * 0.02
        if schedule_name == "linear":
            return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
        return np.linspace(
            beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2

    def get_afa_bars(self, beta_schedule_name, total_step):
        """
        Build the cumulative alpha-bar schedule of length ``total_step``.

        :param beta_schedule_name: name passed to ``get_named_beta_schedule``.
        :param total_step: number of diffusion steps.
        :return: ``(afas_cumprod, betas)`` as NumPy arrays, where
                 ``afas_cumprod = cumprod(1 - betas)``.
        """
        schedule = self.get_named_beta_schedule(schedule_name=beta_schedule_name,
                                                num_diffusion_timesteps=total_step)
        return np.cumprod(1 - schedule), schedule

    # Generation from pure noise; implemented by subclasses.
    @torch.no_grad()
    def sample_loop(self, model, vae_middle_c, batch_size, step, eta, shape=(32, 32)):
        """Placeholder; subclasses provide the generator. Returns None here."""

    def apple_noise(self, data, step):
        """
        Diffuse clean latents ``data`` forward to timestep ``step``.

        Computes x_t = sqrt(alpha_bar) * data + sqrt(1 - alpha_bar) * noise with
        fresh Gaussian noise. Only x_t is returned; the noise itself is not.

        :param data: clean latents; moved to ``self.device``.
        :param step: 1-based timestep; reads ``afas_cumprod[step - 1]``.
        :return: the noised latents x_t.
        """
        latents = data.to(self.device)
        eps = torch.randn(size=latents.shape).to(self.device)
        alpha_bar = self.afas_cumprod[step - 1]
        return torch.sqrt(alpha_bar) * latents + torch.sqrt(1 - alpha_bar) * eps

    # Image-to-image generation; implemented by subclasses.
    @torch.no_grad()
    def sample_loop_img2img(self, input_img, model, vae_middle_c, batch_size, step, eta):
        """Placeholder; subclasses provide the generator. Returns None here."""

    @torch.no_grad()
    def decode_img(self, vae, x0):
        """
        Decode latents ``x0`` with ``vae.decoder`` into a uint8 NHWC image batch.

        When ``vae.middle_c == 8`` the decoder output is mapped with
        ``(v + 1) * 127.5`` ([-1, 1] -> [0, 255]); otherwise with ``v * 255``
        ([0, 1] -> [0, 255]). Values are clipped to [0, 255] before the cast.
        """
        pixels = vae.decoder(x0).cpu().numpy()
        if vae.middle_c == 8:
            pixels = (pixels + 1) * 127.5
        else:
            pixels = pixels * 255
        pixels = np.clip(pixels.transpose(0, 2, 3, 1), 0, 255)  # NCHW -> NHWC (RGB)
        return pixels.astype(np.uint8)

    @torch.no_grad()
    def encode_img(self, vae, x0):
        """Encode images ``x0`` and return only the mean of ``vae.encoder``'s (mean, var) pair."""
        mean, _unused_var = vae.encoder(x0)
        return mean


class DDIMSampler(Sampler):
    """DDIM sampler built on the noise schedule of ``Sampler``.

    Timesteps are 1-based (``t`` in [1, total_step]). ``sample`` converts them
    to 0-based indices into ``afas_cumprod`` / ``betas``.
    """

    def __init__(self, device, normal_t):
        super(DDIMSampler, self).__init__(device, normal_t)

    @torch.no_grad()
    def sample(self, model, x, t, next_t, eta):
        """
        Run one reverse-diffusion step from timestep ``t`` to ``next_t``.

        When the 0-based index of ``t`` is greater than 1, this performs a DDIM
        update through the predicted x0. ``eta`` scales the injected (clipped)
        Gaussian noise. Otherwise it applies the noise-free DDPM mean update
        for the current step.

        :param model: noise predictor, called as ``model(x, t_)`` with ``t_`` of
                      shape (batch, 1) (divided by ``total_step`` if ``normal_t``).
        :param x: current latents x_t.
        :param t: current timestep, in [1, total_step].
        :param next_t: target timestep, in [1, total_step].
        :param eta: scale of the stochastic term (0 adds no noise).
        :return: latents at ``next_t``.
        """
        t_ = torch.ones((x.shape[0], 1)) * t
        t_ = t_.to(self.device)
        if self.normal_t:
            t_ = t_ / self.total_step
        epsilon = model(x, t_)
        # convert 1-based timesteps to 0-based schedule indices
        t = int(t - 1)
        next_t = int(next_t - 1)
        if t > 1:
            # pred_x0 = (x - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
            prede_x0 = (x - torch.sqrt(1 - self.afas_cumprod[t]) * epsilon) / torch.sqrt(self.afas_cumprod[t])
            x_t_1 = torch.sqrt(self.afas_cumprod[next_t]) * prede_x0
            delta = eta * torch.sqrt((1 - self.afas_cumprod[next_t]) / (1 - self.afas_cumprod[t])) * torch.sqrt(
                1 - self.afas_cumprod[t] / self.afas_cumprod[next_t])
            x_t_1 = x_t_1 + torch.sqrt(1 - self.afas_cumprod[next_t] - delta ** 2) * epsilon
            x_t_1 = delta * random_clip(torch.randn_like(x)) + x_t_1
        else:
            coeff = self.betas[t] / (torch.sqrt(1 - self.afas_cumprod[t]))
            x_t_1 = (1 / torch.sqrt(1 - self.betas[t])) * (x - coeff * epsilon)

        return x_t_1

    @torch.no_grad()
    def sample_loop(self, model, vae_middle_c, batch_size, step, eta, shape=(32, 32)):
        """
        Generate from (clipped) Gaussian noise, yielding the latents after each step.

        Uses ``step`` evenly spaced timesteps from ``total_step`` down to 1. The
        last timestep is repeated, so the final iteration evaluates the
        ``t == 1`` step.

        :param model: noise-prediction network.
        :param vae_middle_c: number of latent channels.
        :param batch_size: number of samples.
        :param step: number of sampling timesteps (at least 2 are used).
        :param eta: scale of the stochastic term passed to ``sample``.
        :param shape: latent spatial size (h, w).
        """
        # A single timestep would only run a total_step -> total_step no-op and
        # return the initial noise; use at least total_step -> 1.
        step = max(step, 2)
        steps = np.floor(np.linspace(self.total_step, 1, step))
        steps = np.concatenate((steps, steps[-1:]), axis=0)

        x_t = random_clip(torch.randn((batch_size, vae_middle_c, *shape))).to(self.device)
        for i in range(len(steps) - 1):
            x_t = self.sample(model, x_t, steps[i], steps[i + 1], eta)
            yield x_t

    @torch.no_grad()
    def sample_loop_img2img(self, input_img_latents, noise_steps, model, vae_middle_c, batch_size, step, eta):
        """
        Image-to-image: noise the input latents, then denoise them, yielding after each step.

        ``input_img_latents`` is noised to timestep ``noise_steps``. The result is
        tiled to ``batch_size`` and denoised over evenly spaced timesteps from
        ``noise_steps`` down to 1, with one yield per transition.
        ``vae_middle_c`` is unused; it keeps the signature parallel to
        ``sample_loop``.

        :param noise_steps: 1-based timestep to noise the input to.
        :param step: requested number of timesteps; capped at ``noise_steps``
                     and raised to at least 2.
        """
        noised_latents = self.apple_noise(input_img_latents, noise_steps)
        # Never use more timesteps than were noised. Always keep at least two
        # points: with a single point there is no transition and nothing is yielded.
        step = max(min(noise_steps, step), 2)
        steps = np.floor(np.linspace(noise_steps, 1, step))

        x_t = torch.tile(noised_latents, (batch_size, 1, 1, 1)).to(self.device)
        for i in trange(len(steps) - 1):
            x_t = self.sample(model, x_t, steps[i], steps[i + 1], eta)
            yield x_t


class EulerDpmppSampler(Sampler):
    """Euler-ancestral and DPM-Solver++(2M) samplers over a Karras sigma schedule.

    Each sigma is mapped to a 1-based timestep of the base ``Sampler`` schedule
    via ``floor((sigma * (1 - 1/T) + 1/T) * T)`` before calling the UNet.
    This mapping assumes sigmas in [0, 1]; ``sample_loop`` uses sigmas in
    [1e-5, 0.999].
    """

    def __init__(self, device, normal_t):
        super(EulerDpmppSampler, self).__init__(device, normal_t)
        self.sample_fun = self.sample_dpmpp_2m

    @staticmethod
    def append_zero(x):
        """Append a single trailing zero to a 1-D tensor."""
        return torch.cat([x, x.new_zeros([1])])

    @staticmethod
    def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cuda'):
        """Constructs the noise schedule of Karras et al. (2022).

        Returns ``n`` sigmas from ``sigma_max`` down to ``sigma_min``, followed by a final 0.
        """
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return EulerDpmppSampler.append_zero(sigmas).to(device)

    @staticmethod
    def default_noise_sampler(x):
        """Return a noise sampler that draws ``randn_like(x)`` regardless of the sigmas."""
        return lambda sigma, sigma_next: torch.randn_like(x)

    @staticmethod
    def get_ancestral_step(sigma_from, sigma_to, eta=1.):
        """Calculates the noise level (sigma_down) to step down to and the amount
        of noise to add (sigma_up) when doing an ancestral sampling step."""
        if not eta:
            return sigma_to, 0.
        sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up

    @staticmethod
    def append_dims(x, target_dims):
        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
        dims_to_append = target_dims - x.ndim
        if dims_to_append < 0:
            raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
        return x[(...,) + (None,) * dims_to_append]

    @staticmethod
    def to_d(x, sigma, denoised):
        """Converts a denoiser output to a Karras ODE derivative."""
        return (x - denoised) / EulerDpmppSampler.append_dims(sigma, x.ndim)

    @staticmethod
    def to_denoised(x, sigma, d):
        """Inverse of ``to_d``: recover the denoised estimate from a derivative."""
        return x - d * EulerDpmppSampler.append_dims(sigma, x.ndim)

    @torch.no_grad()
    def sample_euler_ancestral(self, model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1.,
                               noise_sampler=None):
        """Ancestral sampling with Euler method steps; yields x after every step."""
        extra_args = {} if extra_args is None else extra_args
        noise_sampler = EulerDpmppSampler.default_noise_sampler(x) if noise_sampler is None else noise_sampler
        s_in = x.new_ones([x.shape[0], 1])
        for i in trange(len(sigmas) - 1, disable=disable):
            t = sigmas[i] * (1 - 1 / self.total_step) + 1 / self.total_step
            t = torch.floor(t * self.total_step)  # an un-normalized t must be an integer

            afa_bar_t = self.afas_cumprod[int(t) - 1]  # alpha-bar of this timestep
            if self.normal_t:
                t = t / self.total_step

            t = t * s_in
            output = model(x, t, **extra_args)
            # the model predicts noise; convert it to an x0 estimate
            denoised = (x - torch.sqrt(1 - afa_bar_t) * output) / torch.sqrt(afa_bar_t)

            sigma_down, sigma_up = self.get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            d = self.to_d(x, sigmas[i], denoised)
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
            if sigmas[i + 1] > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
            yield x

    @torch.no_grad()
    def sample_dpmpp_2m(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """DPM-Solver++(2M); yields x after every step."""
        extra_args = {} if extra_args is None else extra_args
        s_in = x.new_ones([x.shape[0], 1])
        sigma_fn = lambda t: t.neg().exp()
        t_fn = lambda sigma: sigma.log().neg()
        old_denoised = None

        for i in trange(len(sigmas) - 1, disable=disable):
            t = sigmas[i] * (1 - 1 / self.total_step) + 1 / self.total_step
            t = torch.floor(t * self.total_step)  # an un-normalized t must be an integer

            afa_bar_t = self.afas_cumprod[int(t) - 1]  # alpha-bar of this timestep
            if self.normal_t:
                t = t / self.total_step

            t = t * s_in
            output = model(x, t, **extra_args)
            # the model predicts noise; convert it to an x0 estimate
            denoised = (x - torch.sqrt(1 - afa_bar_t) * output) / torch.sqrt(afa_bar_t)

            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            # from here on `t` is the solver's log-sigma time, not the UNet timestep
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            if old_denoised is None or sigmas[i + 1] == 0:
                x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
            else:
                h_last = t - t_fn(sigmas[i - 1])
                r = h_last / h
                denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
                x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
            old_denoised = denoised
            yield x

    def switch_sampler(self, sampler_name):
        """Select "euler a" or "dpmpp 2m"; any other name falls back to Euler ancestral."""
        if sampler_name == "euler a":
            self.sample_fun = self.sample_euler_ancestral
        elif sampler_name == "dpmpp 2m":
            self.sample_fun = self.sample_dpmpp_2m
        else:
            self.sample_fun = self.sample_euler_ancestral

    def sample_loop(self, model, vae_middle_c, batch_size, step, eta, shape=(32, 32)):
        """
        Generate from Gaussian noise with ``self.sample_fun``, yielding ``step`` latents.

        :param model: noise-prediction network (previously ignored in favour of a global).
        :param vae_middle_c: number of latent channels.
        :param batch_size: number of samples.
        :param step: number of non-zero sigmas in the Karras schedule.
        :param eta: accepted for interface parity with ``DDIMSampler``. It is not
                    forwarded: ``sample_dpmpp_2m`` takes no eta and Euler
                    ancestral uses its default.
        :param shape: latent spatial size (h, w).
        """
        # Use the passed-in model, this sampler's device and the requested shape,
        # not script-level `unet` / `device` globals (which only exist under
        # __main__) and a hardcoded 32x32 latent.
        x = torch.randn((batch_size, vae_middle_c, *shape)).to(self.device)
        sigmas = self.get_sigmas_karras(step, 1e-5, 0.999, device=self.device)
        yield from self.sample_fun(model, x, sigmas)


class PretrainVae:
    """Wrapper around a pretrained diffusers ``AutoencoderKL`` that works in scaled latent space.

    ``encoder`` returns latents multiplied by ``vae_scaleing`` and ``decoder``
    divides by it before decoding.
    """

    def __init__(self, device):
        from diffusers import AutoencoderKL
        self.vae = AutoencoderKL.from_pretrained("gsdf/Counterfeit-V2.5",  # segmind/small-sd
                                                 subfolder="vae",
                                                 cache_dir="./vae/pretrain_vae").to(device)
        self.vae.requires_grad_(False)
        self.middle_c = 4  # number of latent channels
        self.vae_scaleing = 0.18215  # latent scaling factor (attribute name kept for callers)

    def encoder(self, x):
        """Encode images ``x`` to the scaled latent posterior.

        :return: ``(mean, var)`` of the posterior in scaled latent units.
        """
        dist = self.vae.encode(x).latent_dist
        mean = dist.mean * self.vae_scaleing
        # Scaling a random variable by s scales its variance by s**2, not s.
        var = dist.var * self.vae_scaleing ** 2
        return mean, var

    def decoder(self, latents):
        """Undo the latent scaling and decode back to image space."""
        latents = latents / self.vae_scaleing
        return self.vae.decode(latents).sample

    # Free the encoder
    def res_encoder(self):
        """Delete the VAE encoder submodule and empty the CUDA cache, e.g. when only decoding is needed."""
        del self.vae.encoder
        torch.cuda.empty_cache()


# ================================================================

def merge_images(images: np.ndarray):
    """
    Tile a batch of images into a single square grid image.

    Images are placed row-major into an ``nn x nn`` grid, where
    ``nn = ceil(sqrt(n))``. Unused cells stay zero (black).

    :param images: array of shape (n, h, w, c)
    :return: uint8 array of shape (h * nn, w * nn, c), clipped to [0, 255]
    """
    n, h, w, c = images.shape
    nn = int(np.ceil(n ** 0.5))
    # Use the input's channel count instead of assuming 3 (RGB).
    merged_image = np.zeros((h * nn, w * nn, c), dtype=images.dtype)
    for i, img in enumerate(images):
        row, col = divmod(i, nn)
        merged_image[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = img

    merged_image = np.clip(merged_image, 0, 255)
    return merged_image.astype(np.uint8)


def get_models(device):
    """
    Build the UNet, load its EMA weights, and construct the pretrained VAE.

    :param device: torch device to place the models on
    :return: ``(unet, vae, normal_t)``. ``normal_t`` is the config flag handed
             to the samplers, which divide timesteps by ``total_step`` when it
             is true.
    """
    def modelLoad(model, model_path, data_parallel=False):
        """Strictly load the state dict at ``model_path`` into ``model``; optionally wrap in DataParallel."""
        # map_location=device restores tensors directly onto the target device
        # and works for "cpu", "cuda" and indexed devices such as "cuda:1".
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)

        if data_parallel:
            model = torch.nn.DataParallel(model)
        return model

    from net.UNet import UNet
    config = {
        # model architecture
        "en_out_c": (256, 256, 256, 320, 320, 320, 576, 576, 576, 704, 704, 704),
        "en_down": (0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
        "en_skip": (0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1),
        "en_att_heads": (8, 8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),
        "de_out_c": (704, 576, 576, 576, 320, 320, 320, 256, 256, 256, 256),
        "de_up": ("none", "subpix", "none", "none", "subpix", "none", "none", "subpix", "none", "none", "none"),
        "de_skip": (1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
        "de_att_heads": (8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),  # no self-attention where there is a skip
        "t_out_c": 256,
        "vae_c": 4,
        "block_deep": 3,
        "use_pretrain_vae": True,

        "normal_t": True,

        "model_save_path": "./weight",
        "model_name": "unet",
        "model_tail": "ema",
    }
    print("加载模型...")
    unet = UNet(config["en_out_c"], config["en_down"], config["en_skip"], config["en_att_heads"],
                config["de_out_c"], config["de_up"], config["de_skip"], config["de_att_heads"],
                config["t_out_c"], config["vae_c"], config["block_deep"]).to(device)
    unet = modelLoad(unet, os.path.join(config["model_save_path"],
                                        f"{config['model_name']}_{config['model_tail']}.pth"))
    # Inference only: put layers such as dropout / batch-norm (if the UNet has any)
    # into evaluation mode.
    unet.eval()

    vae = PretrainVae(device)
    print("加载完成")
    return unet, vae, config["normal_t"]


def init_webui(unet, vae, normal_t):
    """
    Build the Gradio Blocks UI for unconditional and image-to-image generation.

    :param unet: noise-prediction UNet handed to the samplers
    :param vae: VAE wrapper exposing ``middle_c``, ``encoder`` and ``decoder``
    :param normal_t: forwarded to the samplers' constructors
    :return: the (not yet launched) ``gr.Blocks`` interface
    """
    # Run the samplers on the model's own device. This avoids relying on a
    # module-level `device` that only exists when this file is run as a script.
    device = next(unet.parameters()).device

    # Gradio callbacks
    def process_image(input_image_value, noise_step, step_value, batch_size, sampler_name, img_size, random_seed,
                      progress=gr.Progress()):
        """
        Run the selected sampler and return ``[grid_image, *images]``.

        ``input_image_value`` is None for unconditional generation. Otherwise
        the image is resized, encoded, noised to ``noise_step * total_step``
        and then denoised.
        """
        progress(0, desc="开始...")

        setup_seed(int(random_seed))
        noise_step = float(noise_step)
        step_value = int(step_value)
        batch_size = int(batch_size)
        # latent resolution is 1/8 of the pixel resolution
        img_size = int(img_size) // 8
        img_size = (img_size, img_size)

        if sampler_name == "DDIM":
            sampler = DDIMSampler(device, normal_t)
        elif sampler_name in ("euler a", "dpmpp 2m"):
            sampler = EulerDpmppSampler(device, normal_t)
            sampler.switch_sampler(sampler_name)
        else:
            raise ValueError(f"Unknown sampler_name: {sampler_name}")

        if input_image_value is None:
            looper = sampler.sample_loop(unet, vae.middle_c, batch_size, step_value, shape=img_size, eta=1.)
        else:
            # convert("RGB") normalises grayscale / RGBA uploads to the 3 channels the VAE expects
            pil_img = Image.fromarray(input_image_value).convert("RGB").resize(
                (img_size[0] * 8, img_size[1] * 8), resample=Image.BILINEAR)
            pixels = np.array(pil_img, dtype=np.float32) / 255.
            pixels = np.ascontiguousarray(np.transpose(pixels, (2, 0, 1)))  # HWC -> CHW
            input_tensor = torch.from_numpy(pixels).unsqueeze(0).to(device)
            input_img_latents = sampler.encode_img(vae, input_tensor)
            looper = sampler.sample_loop_img2img(input_img_latents,
                                                 int(noise_step * sampler.total_step),
                                                 unet,
                                                 vae.middle_c,
                                                 batch_size,
                                                 step_value,
                                                 eta=1.)

        # The sampler may yield fewer than step_value results (e.g. img2img caps
        # the step count at the number of noise steps); keep the last one.
        output = None
        for _ in progress.tqdm(range(1, step_value + 1)):
            try:
                output = next(looper)
            except StopIteration:
                break
        if output is None:
            raise gr.Error("The sampler produced no output; try a larger number of steps.")

        output = sampler.decode_img(vae, output)
        output = np.clip(output, 0, 255)
        marge_img = merge_images(output)

        return [marge_img] + list(output)

    def process_image_u(step_value, batch_size, sampler_name, img_size, random_seed,
                        progress=gr.Progress()):
        """Unconditional-generation callback: ``process_image`` without an input image."""
        return process_image(None, 0, step_value, batch_size, sampler_name, img_size, random_seed,
                             progress)

    with gr.Blocks() as iface:
        gr.Markdown(
            "This is a diffusion model for generating second-dimensional avatars, which can be used for unconditional generation or image-to-image generation.")

        with gr.Tab(label="unconditional generation"):
            with gr.Column():
                with gr.Row():
                    # sampler selection ("euler a" / "dpmpp 2m" not exposed)
                    sampler_name_u = gr.Dropdown(["DDIM"], label="sampler", value="DDIM")
                    # sliders
                    step_u = gr.Slider(minimum=1, maximum=1000, value=40, label="steps", step=1)
                    batch_size_u = gr.Slider(minimum=1, maximum=4, label="batch size", step=1)
                    img_size_u = gr.Slider(minimum=256, maximum=512, value=256, label="img size", step=64)
                    random_seed_u = gr.Number(value=-1, label="random seed(-1 as random number)")
                    # run button
                    start_button_u = gr.Button(value="Run")
            # output gallery
            output_images_u = gr.Gallery(show_label=False, height=400, columns=5)
            gr.Examples(
                examples=[[60, 4, "DDIM", 256, 255392]],
                inputs=[step_u, batch_size_u, sampler_name_u, img_size_u, random_seed_u],
                outputs=output_images_u,
                fn=process_image_u,
                cache_examples=False,
            )
        with gr.Tab(label="image to image"):
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # input image
                        input_image = gr.Image(label="image to image")
                        # noise strength (fraction of total_step)
                        noise_step = gr.Slider(minimum=0.05, maximum=1, value=0.6, label="加噪程度", step=0.01)
                    with gr.Column():
                        # sampler selection ("euler a" / "dpmpp 2m" not exposed)
                        sampler_name = gr.Dropdown(["DDIM"], label="sampler", value="DDIM")
                        # sliders
                        step = gr.Slider(minimum=1, maximum=1000, value=40, label="steps", step=1)
                        batch_size = gr.Slider(minimum=1, maximum=4, label="batch size", step=1)
                        img_size = gr.Slider(minimum=256, maximum=512, value=256, label="img size", step=64)
                        random_seed_box = gr.Number(value=-1, label="random seed(-1 as random number)")
                        # run button
                        start_button = gr.Button(value="Run")

            # output gallery
            output_images = gr.Gallery(show_label=False, height=400, columns=5)
            gr.Examples(
                examples=[["./example.jpg", 0.4, 60, 4, "DDIM", 320, 231324]],
                inputs=[input_image, noise_step, step, batch_size, sampler_name, img_size, random_seed_box],
                outputs=output_images,
                fn=process_image,
                cache_examples=False,
            )

        start_button.click(process_image,
                           [input_image, noise_step, step, batch_size, sampler_name, img_size, random_seed_box],
                           [output_images])
        start_button_u.click(process_image_u, [step_u, batch_size_u, sampler_name_u, img_size_u, random_seed_u],
                             [output_images_u])

    return iface


def setup_seed(seed=0):
    """
    Seed Python's ``random``, NumPy and torch (CPU and, if available, all GPUs).

    :param seed: seed to use; ``-1`` draws a fresh seed in [0, 1000000].
    :return: the seed actually used (it is also printed).
    """
    if seed == -1:
        # Draw from OS entropy. The global `random` state is re-seeded below on
        # every call, so drawing from it would make every "-1" request after a
        # fixed seed produce the same follow-up seed.
        seed = random.SystemRandom().randint(0, 1000000)
    print(seed)
    torch.manual_seed(seed)  # CPU RNG
    np.random.seed(seed)  # NumPy RNG
    random.seed(seed)  # Python random module
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(seed)  # current GPU
        torch.cuda.manual_seed_all(seed)  # all GPUs
    return seed


if __name__ == '__main__':
    # Models run on the CPU; use torch.device('cuda') instead to run on a GPU.
    # NOTE: `device`, `unet`, `vae` and `normal_t` are module-level names here
    # and other code in this file may read them as globals.
    device = torch.device('cpu')
    unet, vae, normal_t = get_models(device)

    def run_with_ui(unet, vae, normal_t):
        """Build the Gradio interface and serve it with request queuing enabled."""
        interface = init_webui(unet, vae, normal_t)
        queued = interface.queue()
        queued.launch()

    run_with_ui(unet, vae, normal_t)