from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import gradio as gr
from PIL import Image
import cv2
import os, random, gc
import numpy as np
from transformers import pipeline
import PIL.Image
from diffusers.utils import load_image, export_to_video
from accelerate import Accelerator
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
import torch
from moviepy.video.fx.all import crop
from diffusers.utils import export_to_gif
import mediapy
from image_tools.sizes import resize_and_crop
from moviepy.editor import *
from pathlib import Path
from typing import Optional, List
from tqdm import tqdm
import supervision as sv


# Accelerate is pinned to the CPU; the ControlNet below and the pipelines built in
# plex() are passed through `accelerator.prepare`.
accelerator = Accelerator(cpu=True)

# Base checkpoints selectable in the UI. plex() loads each one into
# StableDiffusionControlNetImg2ImgPipeline together with the SD 1.5 Canny ControlNet
# defined below, so only checkpoints with the SD 1.x architecture can work here.
# Removed because they fail at load/run time with this pipeline:
#   - "kandinsky-community/kandinsky-2-1": a Kandinsky pipeline (no `vae` component),
#     not loadable as a Stable Diffusion pipeline.
#   - "stabilityai/stable-diffusion-2-1": its text encoder emits 1024-dim embeddings,
#     while the SD 1.5 ControlNet's cross-attention expects 768.
models = [
    "runwayml/stable-diffusion-v1-5",
    "prompthero/openjourney-v4",
    "CompVis/stable-diffusion-v1-4",
    "stablediffusionapi/edge-of-realism",
    "sd-dreambooth-library/fashion",
    "DucHaiten/DucHaitenDreamWorld",
    "plasmo/woolitize-768sd1-5",
    "wavymulder/modelshoot",
    "Fictiverse/Stable_Diffusion_VoxelArt_Model",
    "darkstorm2150/Protogen_v2.2_Official_Release",
    "hassanblend/HassanBlend1.5.1.2",
    "hassanblend/hassanblend1.4",
    "nitrosocke/redshift-diffusion",
    "prompthero/openjourney-v2",
    "Lykon/DreamShaper",
    "nitrosocke/mo-di-diffusion",
    "dreamlike-art/dreamlike-diffusion-1.0",
    "dreamlike-art/dreamlike-photoreal-2.0",
    "digiplay/RealismEngine_v1",
    "digiplay/AIGEN_v1.4_diffusers",
    "stablediffusionapi/dreamshaper-v6",
    # NOTE(review): the name suggests an SD 2.1 (768px) base, which would hit the same
    # 768-vs-1024 cross-attention mismatch as stable-diffusion-2-1 — verify or remove.
    "TheLastBen/froggy-style-v21-768",
    "digiplay/PotoPhotoRealism_v1",
]

# Canny-edge ControlNet (SD 1.5 family), loaded once at import time and shared by
# every call to plex().
controlnet = accelerator.prepare(ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float32))
def _prepare_clip(path, out_path, size=512, max_fps=10):
    """Write a ``size`` x ``size`` centre-cropped copy of the video at *path* to *out_path*.

    The clip is scaled so its shorter side equals *size*, then the central
    *size* x *size* window is kept, so the picture is never stretched.  The
    output frame rate is capped at *max_fps*; when OpenCV reports no usable
    rate (0 or NaN) *max_fps* is used instead.

    Returns:
        The frame rate the clip was written at, or ``None`` when the file cannot
        be opened or its shorter side is smaller than *size* (no upscaling).
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            return None
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()
    # Also rejects 0x0 metadata, so no division by zero is possible below.
    if min(width, height) < size:
        return None
    # Written as not(...) so that NaN (all comparisons False) also falls back.
    if not (0 < fps <= max_fps):
        fps = max_fps
    clip = VideoFileClip(path)
    try:
        # Bring the shorter side to `size`; the longer side keeps the aspect ratio.
        if width >= height:
            scaled = clip.resize(height=size)
        else:
            scaled = clip.resize(width=size)
        square = crop(scaled, x_center=scaled.w / 2, y_center=scaled.h / 2,
                      width=size, height=size)
        square.write_videofile(out_path, fps=fps)
    finally:
        clip.close()
    return fps


def _read_frames(path, limit):
    """Return at most *limit* frames of the video at *path*, converted BGR -> RGB."""
    frames = []
    if limit <= 0:
        return frames
    for frame in tqdm(sv.get_video_frames_generator(source_path=path), total=limit):
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # Stop decoding as soon as enough frames are collected.
        if len(frames) >= limit:
            break
    return frames


def _canny_rgb(image, low=100, high=200):
    """Return the Canny edge map of PIL *image* as a 3-channel PIL image."""
    edges = cv2.Canny(np.array(image), low, high)
    return Image.fromarray(np.stack([edges, edges, edges], axis=2))


def plex(fpath, text, neg_prompt, modil, one, two, three, four, five):
    """Restyle the first ~2 seconds of a video frame by frame with Canny-ControlNet img2img.

    Parameters (positional, in the order of the Gradio inputs):
        fpath: uploaded video — a path (str / os.PathLike) or a file-like object
            exposing the path as ``.name``.
        text: positive prompt.
        neg_prompt: negative prompt.
        modil: Hugging Face repo id of the base Stable Diffusion checkpoint.
        one: number of inference steps (cast to int).
        two: img2img strength; values above 1.0 are clamped to 1.0.
        three: ControlNet conditioning scale.
        four: classifier-free guidance scale.
        five: RNG seed (cast to int), re-applied for every frame.

    Returns:
        ``(final_video, canny_video, original_video)`` file paths, or ``None``
        when no path was given, the video cannot be read, its shorter side is
        below 512 px, or no frames could be decoded.
    """
    gc.collect()
    video = './video.mp4'
    orvid = './orvid.mp4'
    canvid = './canvid.mp4'
    size = 512

    if isinstance(fpath, (str, os.PathLike)):
        path = os.fspath(fpath)
    else:
        path = getattr(fpath, "name", None)
    if not path:
        return None

    # Validate and prepare the input before paying for a model load.
    fps = _prepare_clip(path, video, size=size)
    if fps is None:
        return None
    # Only two seconds of footage are processed: CPU inference is slow.
    frames = _read_frames(video, int(fps * 2))
    if not frames:
        return None

    pipe = accelerator.prepare(StableDiffusionControlNetImg2ImgPipeline.from_pretrained(modil, controlnet=controlnet, torch_dtype=torch.float32, use_safetensors=False, safety_checker=None))
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.scheduler = accelerator.prepare(DPMSolverMultistepScheduler.from_config(pipe.scheduler.config))
    pipe = pipe.to("cpu")

    steps = int(one)
    # diffusers documents img2img `strength` as 0..1; larger values only ever acted
    # like 1.0, so clamp instead of forwarding an out-of-range value.
    strength = min(float(two), 1.0)
    seed = int(five)

    orframes = []
    canframes = []
    fin_frames = []
    for frame in frames:
        original = load_image(Image.fromarray(frame)).resize((size, size)).convert("RGB")
        # Last processed frame is kept on disk (presumably a debugging aid).
        original.save('./image.png', 'PNG')
        cannyimage = _canny_rgb(original)
        orframes.append(original)
        canframes.append(cannyimage)
        # A fresh generator with the same seed gives every frame identical initial
        # noise (presumably to limit frame-to-frame flicker).
        generator = torch.Generator(device="cpu").manual_seed(seed)
        result = pipe(prompt=text, image=[original], control_image=[cannyimage], guidance_scale=four, num_inference_steps=steps, generator=generator, strength=strength, negative_prompt=neg_prompt, controlnet_conditioning_scale=three, width=size, height=size)
        fin_frames.append(result.images[0])

    # Note: the final render overwrites the intermediate cropped clip at `video`.
    export_to_video(fin_frames, video, fps=fps)
    export_to_video(orframes, orvid, fps=fps)
    export_to_video(canframes, canvid, fps=fps)
    return video, canvid, orvid

# Gradio front end. The order of `inputs` must match plex()'s positional parameters
# (fpath, text, neg_prompt, modil, one, two, three, four, five); `outputs` matches
# plex()'s (final, canny, original) return order.
iface = gr.Interface(
    fn=plex,
    inputs=[
        gr.File(label="Your video", interactive=True, file_types=['.mp4',]),
        gr.Textbox(label="prompt"),
        gr.Textbox(label="neg prompt"),
        gr.Dropdown(choices=models, label="Models", value=models[0], type="value"),
        gr.Slider(label="num inference steps", minimum=1, step=1, maximum=10, value=4),
        # img2img strength is a fraction of the denoising schedule, documented by
        # diffusers as 0..1. The old range (up to 20, default 5.0) only ever behaved
        # like 1.0, so 1.00 keeps the previous effective default.
        gr.Slider(label="Strength", minimum=0.01, step=0.01, maximum=1.00, value=1.00),
        gr.Slider(label="controlnet scale", minimum=0.01, step=0.01, maximum=0.99, value=0.80),
        gr.Slider(label="Guidance scale", minimum=0.01, step=0.01, maximum=10.00, value=2.00),
        gr.Slider(label="Manual seed", minimum=0, step=32, maximum=4836928, value=0),
    ],
    outputs=[
        gr.Video(label="final"),
        gr.Video(label="canny vid"),
        gr.Video(label="orig"),
    ],
    description="Running on cpu, very slow! by JoPmt.",
)
# One queued job and one worker thread: CPU inference cannot run requests in parallel,
# and plex() writes to fixed output paths.
iface.queue(max_size=1, api_open=False)
iface.launch(max_threads=1)