File size: 4,349 Bytes
137645c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import math
import numpy as np
import torch


def cal_rectify_ratio(start_t, gamma):
    """Return the multiplicative factor used to rescale a stage's start time.

    Computes ``1 / (sqrt(1 - 1/gamma) * (1 - start_t) + start_t)``: the
    reciprocal of a linear blend between ``sqrt(1 - 1/gamma)`` (reached at
    ``start_t == 0``) and ``1`` (reached at ``start_t == 1``).

    Args:
        start_t: Normalized start time of a stage.
        gamma: Rectification parameter. ``1 - 1/gamma`` must be non-negative
            (i.e. ``gamma < 0`` or ``gamma >= 1``); otherwise ``math.sqrt``
            raises ``ValueError``. ``gamma == 0`` raises ``ZeroDivisionError``.

    Returns:
        The rectification factor as a float.
    """
    endpoint_scale = math.sqrt(1 - (1 / gamma))
    blended = endpoint_scale * (1 - start_t) + start_t
    return 1 / blended


class PixelFlowScheduler:
    """Multi-stage flow scheduler with a piecewise-linear timestep mapping.

    The normalized training time axis is split into ``num_stages`` equal
    stages. Every stage after the first has its start time multiplied by
    ``cal_rectify_ratio(start_t, gamma)``. The training timestep range is then
    divided among the stages in proportion to each stage's (rectified) length.

    Note: ``Timesteps`` and ``t`` hold the training grid after construction and
    are overwritten by ``set_timesteps`` with the inference grid of one stage.
    """

    def __init__(self, num_train_timesteps, num_stages, gamma=-1 / 3):
        """Precompute per-stage time windows and timestep ranges.

        Args:
            num_train_timesteps: Number of discrete training timesteps N.
            num_stages: Number of stages; must be positive.
            gamma: Rectification parameter forwarded to ``cal_rectify_ratio``.
        """
        assert num_stages > 0, f"num_stages must be positive, got {num_stages}"
        self.num_stages = num_stages
        self.gamma = gamma

        # Training grid: timesteps 0..N-1 and their normalized times T / N.
        self.Timesteps = torch.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=torch.float32)

        self.t = self.Timesteps / num_train_timesteps  # normalized time in [0, 1)

        self.stage_range = [x / num_stages for x in range(num_stages + 1)]

        self.original_start_t = dict()
        self.start_t, self.end_t = dict(), dict()
        self.t_window_per_stage = dict()
        self.Timesteps_per_stage = dict()
        stage_distance = list()

        # Inference state; populated by set_timesteps().
        self.num_inference_steps = None
        self._step_index = None

        # stage_idx = 0: min t, min resolution, most noisy
        # stage_idx = num_stages - 1 : max t, max resolution, most clear
        for stage_idx in range(num_stages):
            # Exact integer floor(N * i / num_stages). The float product
            # N * stage_range[i] can land just below an integer
            # (e.g. 100 * 0.57 == 56.99999999999999) and truncate wrongly.
            start_idx = num_train_timesteps * stage_idx // num_stages
            end_idx = num_train_timesteps * (stage_idx + 1) // num_stages

            start_t = self.t[start_idx].item()
            # The grid stops at (N-1)/N, so the final boundary is pinned to 1.0.
            end_t = self.t[end_idx].item() if end_idx < num_train_timesteps else 1.0

            self.original_start_t[stage_idx] = start_t

            if stage_idx > 0:
                start_t *= cal_rectify_ratio(start_t, gamma)

            self.start_t[stage_idx] = start_t
            self.end_t[stage_idx] = end_t
            stage_distance.append(end_t - start_t)

        total_stage_distance = sum(stage_distance)
        # Every stage shares the same local time window [0, 1 - 1/N].
        t_within_stage = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float64)[:-1]

        last_stage = num_stages - 1
        # Running sum of stage_distance[:stage_idx]; accumulates in the same
        # left-to-right order as sum(), so the ratios are bit-identical.
        covered = 0
        for stage_idx in range(num_stages):
            start_ratio = 0.0 if stage_idx == 0 else covered / total_stage_distance
            covered += stage_distance[stage_idx]
            end_ratio = 1.0 if stage_idx == last_stage else covered / total_stage_distance

            Timestep_start = self.Timesteps[int(num_train_timesteps * start_ratio)]
            Timestep_end = self.Timesteps[min(int(num_train_timesteps * end_ratio), num_train_timesteps - 1)]

            self.t_window_per_stage[stage_idx] = t_within_stage

            if stage_idx == last_stage:
                # The last stage keeps its end point.
                self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps, dtype=torch.float64)
            else:
                # Drop the end point; the next stage starts from the same ratio.
                self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps + 1, dtype=torch.float64)[:-1]

    @staticmethod
    def time_linear_to_Timesteps(t, t_start, t_end, T_start, T_end):
        """
        linearly map t to T: T = k * t + b

        Maps t_start -> T_start and t_end -> T_end; t_end must differ from
        t_start. Works elementwise when ``t`` is an array.
        """
        k = (T_end - T_start) / (t_end - t_start)
        b = T_start - t_start * k
        return k * t + b

    def set_timesteps(self, num_inference_steps, stage_index, device=None, shift=1.0):
        """Build the inference grid for one stage and reset the step counter.

        Args:
            num_inference_steps: Number of steps to take in this stage (>= 1).
            stage_index: Key into the per-stage tables built in ``__init__``.
            device: Device for the resulting ``Timesteps`` and ``t`` tensors.
            shift: Positive time-shift factor; local times are remapped with
                ``t / (shift + (1 - shift) * t)``. ``1.0`` means no shift.

        Raises:
            ValueError: If ``num_inference_steps < 1`` or ``shift <= 0``.
        """
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
        # shift == 0 gives 0/0 at t=0; shift < 0 puts a pole inside [0, 1].
        if shift <= 0:
            raise ValueError(f"shift must be positive, got {shift}")

        self.num_inference_steps = num_inference_steps

        stage_T_start = self.Timesteps_per_stage[stage_index][0].item()
        stage_T_end = self.Timesteps_per_stage[stage_index][-1].item()

        t_start = self.t_window_per_stage[stage_index][0].item()
        t_end = self.t_window_per_stage[stage_index][-1].item()

        t = np.linspace(t_start, t_end, num_inference_steps, dtype=np.float64)
        t = t / (shift + (1 - shift) * t)

        Timesteps = self.time_linear_to_Timesteps(t, t_start, t_end, stage_T_start, stage_T_end)
        self.Timesteps = torch.from_numpy(Timesteps).to(device=device)

        # Append the terminal time 1.0 so step i can read t[i + 1].
        self.t = torch.from_numpy(np.append(t, 1.0)).to(device=device, dtype=torch.float64)
        self._step_index = None

    def step(self, model_output, sample):
        """Advance ``sample`` by one explicit Euler step: x + (t_next - t) * v.

        ``model_output`` is presumably the predicted velocity dx/dt (confirm
        against the caller). The update is computed in float32 and cast back to
        ``model_output.dtype``.

        Raises:
            RuntimeError: If ``set_timesteps`` has not been called.
            IndexError: If all ``num_inference_steps`` steps have been taken.
        """
        if getattr(self, "num_inference_steps", None) is None:
            raise RuntimeError("set_timesteps() must be called before step()")
        if self.step_index is None:
            self._step_index = 0
        if self._step_index >= self.num_inference_steps:
            raise IndexError(
                f"step() called {self._step_index + 1} times but only "
                f"{self.num_inference_steps} inference steps were scheduled"
            )

        sample = sample.to(torch.float32)
        t = self.t[self.step_index].float()
        t_next = self.t[self.step_index + 1].float()

        prev_sample = sample + (t_next - t) * model_output
        self._step_index += 1

        return prev_sample.to(model_output.dtype)

    @property
    def step_index(self):
        """Current step index for the scheduler."""
        return self._step_index