from __future__ import annotations

import rerun.blueprint as rrb
import torch
from transformers.utils import hub

from . import sample, sample_utils


def create_model():
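    """Instantiate the Vista model from its inference config, fetching the checkpoint from the Hugging Face hub."""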
    return sample_utils.init_model(
        {
            "config": "./vista/configs/inference/vista.yaml",
            "ckpt": hub.get_file_from_repo("OpenDriveLab/Vista", "vista.safetensors"),
        }
    )


def generate_blueprint(n_rounds: int) -> rrb.Blueprint:
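    """Lay out one tensor view of the latents per diffusion round above a 2D view of the generated video."""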
    row1 = rrb.Horizontal(
        *[
            rrb.TensorView(origin=f"diffusion_{i}", name=f"Latents Segment {i+1}")
            for i in range(n_rounds)
        ],
    )
    row2 = rrb.Spatial2DView(origin="generated_image", name="Generated Video")

    return rrb.Blueprint(rrb.Vertical(row1, row2), collapse_panels=True)


def run_sampling(
    log_queue,
    first_frame_file_name,
    height,
    width,
    n_rounds,
    n_frames,
    n_steps,
    cfg_scale,
    cond_aug,
    model=None,
) -> None:
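    """Generate a video conditioned on a single frame, reporting progress and results through `log_queue`."""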
    if model is None:
        model = create_model()

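    # Collect the conditioning keys expected by the model's embedders.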
    unique_keys = {x.input_key for x in model.conditioner.embedders}
    value_dict = sample_utils.init_embedder_options(unique_keys)

    action_dict = None

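    # Load the conditioning frame on the GPU and expand it to the full clip length.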
    first_frame = sample.load_img(first_frame_file_name, height, width, "cuda")[None]
    repeated_frame = first_frame.expand(n_frames, -1, -1, -1)

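    # Condition on the first frame, both clean and with noise augmentation applied.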
    cond_img = first_frame
    value_dict["cond_frames_without_noise"] = cond_img
    value_dict["cond_aug"] = cond_aug
    value_dict["cond_frames"] = cond_img + cond_aug * torch.randn_like(cond_img)
    if action_dict is not None:
        for key, value in action_dict.items():
            value_dict[key] = value

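    # Multi-round sampling uses the triangle prediction guider; a single round falls back to vanilla CFG.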
    if n_rounds > 1:
        guider = "TrianglePredictionGuider"
    else:
        guider = "VanillaCFG"
    sampler = sample_utils.init_sampling(
        guider=guider,
        steps=n_steps,
        cfg_scale=cfg_scale,
        num_frames=n_frames,
    )

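    # Conditioning keys whose embeddings are zeroed out for the unconditional branch of classifier-free guidance.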
    uc_keys = [
        "cond_frames",
        "cond_frames_without_noise",
        "command",
        "trajectory",
        "speed",
        "angle",
        "goal",
    ]

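    # Run the sampling loop; the returned tensors are unused here because results are reported through log_queue.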
    _generated_images, _samples_z, _inputs = sample_utils.do_sample(
        repeated_frame,
        model,
        sampler,
        value_dict,
        num_rounds=n_rounds,
        num_frames=n_frames,
        force_uc_zero_embeddings=uc_keys,
        initial_cond_indices=[0],
        log_queue=log_queue,
    )

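    # Signal to the consumer of the queue that sampling has finished.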
    log_queue.put("done")