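"""Inference helpers: short-clip generation (text-to-video, SDXL image
generation, image-to-video with Stable Video Diffusion), autoregressive
extension of a short clip into a long video, and video-to-video enhancement
via ModelScope."""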
# General
import os
from os.path import join as opj
import datetime
import torch
from einops import rearrange, repeat

# Utilities (provides helpers used below, e.g. add_margin, resize_and_keep,
# center_crop, resize_to_fit, pad_to_fit)
from inference_utils import *

from modelscope.outputs import OutputKeys
import imageio
from PIL import Image
import numpy as np

import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers.utils import load_image
# Convert PIL images to uint8 tensors of shape (C, H, W), values in [0, 255]
transform = transforms.Compose([
    transforms.PILToTensor()
])


def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"):
    """Generate a short 256x256 clip from `prompt` with the text-to-video pipeline `ms_model`."""
    frames = ms_model(prompt,
                      num_inference_steps=t,
                      generator=inference_generator,
                      eta=1.0,
                      height=256,
                      width=256,
                      latents=None).frames
    frames = torch.stack([torch.from_numpy(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # (F, W, H, C) -> (F, C, W, H)
    return rearrange(frames[0], "F W H C -> F C W H")

def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"):
    """Generate a short 16-frame clip from `prompt` with `ad_model` and resize it to 256x256."""
    frames = ad_model(prompt,
                      negative_prompt="bad quality, worse quality",
                      num_frames=16,
                      num_inference_steps=t,
                      generator=inference_generator,
                      guidance_scale=7.5).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    frames = F.interpolate(frames, size=256)
    # PILToTensor yields uint8 values in [0, 255]; rescale to [0, 1]
    frames = frames / 255.0
    return frames

def sdxl_image_gen(prompt, sdxl_model):
    """Generate a single image from `prompt` with the SDXL pipeline."""
    image = sdxl_model(prompt=prompt).images[0]
    return image

def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"):
    """Generate a short clip with Stable Video Diffusion (`svd_model`).

    If no conditioning image is given, one is first generated from `prompt`
    with SDXL. The image is padded horizontally before generation, and the
    padding is cropped away again afterwards.
    """
    if image is None or image == "":
        image = sdxl_image_gen(prompt, sdxl_model)
        image = image.resize((576, 576))
        image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
    else:
        image = load_image(image)
        image = resize_and_keep(image)
        image = center_crop(image)
        image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))

    frames = svd_model(image, decode_chunk_size=8, generator=inference_generator).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # Keep the first 16 frames and crop away the 224-pixel side padding
    frames = frames[:16, :, :, 224:-224]
    frames = F.interpolate(frames, size=256)
    frames = frames / 255.0
    return frames


def stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, result_file_stem, stream_cli, stream_model):
    """Autoregressively extend `short_video` into a long video via the trainer exposed by `stream_cli`."""
    trainer = stream_cli.trainer
    trainer.limit_predict_batches = 1
    trainer.predict_cfg = {
        "predict_dir": stream_cli.config["result_fol"].as_posix(),
        "result_file_stem": result_file_stem,
        "prompt": prompt,
        "video": short_video,
        "seed": seed,
        "num_inference_steps": t,
        "guidance_scale": image_guidance,
        "n_autoregressive_generations": n_autoreg_gen,
    }

    trainer.predict(model=stream_model, datamodule=stream_cli.datamodule)


def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
    """Enhance the clip at path `video` with the ModelScope video-to-video pipeline `model_v2v`."""
    downscale = cfg_v2v['downscale']
    upscale_size = cfg_v2v['upscale_size']
    pad = cfg_v2v['pad']

    now = datetime.datetime.now()
    name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
    enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4")

    video_frames = imageio.mimread(video)
    h, w, _ = video_frames[0].shape

    # Downscale the frames, then resize them to fit the upscale size
    video = [Image.fromarray(frame).resize((w // downscale, h // downscale)) for frame in video_frames]
    video = [resize_to_fit(frame, upscale_size) for frame in video]

    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]

    # Write the preprocessed frames to a temporary clip for the enhancer
    imageio.mimsave(opj(where_to_log, 'temp.mp4'), video, fps=8)

    p_input = {
        'video_path': opj(where_to_log, 'temp.mp4'),
        'text': prompt
    }
    # Run the enhancer; it writes the result directly to enhanced_video_mp4
    output_video_path = model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]
    return enhanced_video_mp4
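

# Example flow (sketch): assumes the pipelines and the streaming trainer/CLI
# objects (ms_model, stream_cli, stream_model, cfg_v2v, model_v2v) have already
# been constructed elsewhere; shown only to illustrate how the helpers compose.
#
#   generator = torch.Generator(device="cuda").manual_seed(33)
#   short = ms_short_gen("A cat surfing a wave.", ms_model, generator)
#   stream_long_gen("A cat surfing a wave.", short, n_autoreg_gen=2, n_prompt="",
#                   seed=33, t=50, image_guidance=9.0, result_file_stem="cat_surfing",
#                   stream_cli=stream_cli, stream_model=stream_model)
#   # video2video then enhances whatever .mp4 the previous step produced
#   video2video("A cat surfing a wave.", generated_mp4_path, "results", cfg_v2v, model_v2v)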