File size: 4,619 Bytes
8b14bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import pyrootutils
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from transformers import AutoTokenizer

# Locate the project root (the directory marked by a ".project-root" file) and,
# via pythonpath=True, add it to the import path. This must run before the
# `fish_speech` / `tools` imports below so that they resolve.
# NOTE(review): the original comment also mentioned registering an "eval
# resolver"; nothing in this call visibly does that -- confirm.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from torch.utils.data import DataLoader

from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
from tools.llama.generate import load_model


def smooth(scalars: list[float], weight: float) -> list[float]:
    """Return an exponentially smoothed copy of *scalars*.

    Each output value is ``previous * weight + (1 - weight) * point``, where
    ``previous`` is the previously emitted smoothed value (seeded with the
    first input value, so the first output equals the first input).

    Args:
        scalars: Values to smooth, in plotting order.
        weight: Smoothing factor between 0 and 1. ``0`` returns the input
            unchanged; values close to ``1`` smooth more heavily.

    Returns:
        A new list with the same length as *scalars*. An empty input yields
        an empty list instead of raising ``IndexError``.
    """
    if not scalars:
        return []

    smoothed = []
    last = scalars[0]  # Seed with the first value (first timestep)
    for point in scalars:
        # Blend the running value with the new point and carry it forward.
        last = last * weight + (1 - weight) * point
        smoothed.append(last)

    return smoothed


@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length, max_batches=10):
    """Compute the mean semantic (codebook) loss per frame position for a model.

    The model is loaded with ``load_model``, run over at most *max_batches*
    batches from *loader*, and the per-frame cross-entropy of the codebook
    logits is accumulated per sequence position. Frames whose codebook labels
    are all ``-100`` (the ignore index) are treated as padding and skipped.

    Args:
        loader: Iterable of batches. Each batch is a dict of tensors that
            contains at least the keys ``"inputs"``, ``"attention_masks"``
            and ``"labels"``.
        config: Model config identifier, forwarded to ``load_model``.
        weight: Checkpoint path, forwarded to ``load_model``.
        max_length: Maximum sequence length; forwarded to ``load_model`` and
            used as the size of the per-position accumulators.
        max_batches: Maximum number of batches to evaluate (default ``10``).
            ``None`` consumes the whole loader.

    Returns:
        A tuple ``(xs, ys, smoothed_ys)``: the positions that received at
        least one non-padded frame, the mean loss at each of those positions,
        and the same losses after ``smooth(ys, 0.95)``. All three are empty
        lists when no frame was counted.
    """
    from itertools import islice

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]
    model.eval()

    num_codebooks = model.config.num_codebooks

    # Per-position running sum of frame losses and number of frames seen.
    semantic_loss_sum = torch.zeros(max_length, dtype=torch.float32, device=device)
    counter = torch.zeros(max_length, dtype=torch.long, device=device)

    # islice stops after max_batches items without fetching an extra batch.
    for batch in islice(loader, max_batches):
        batch = {k: v.to(device) for k, v in batch.items()}

        labels = batch["labels"]
        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        codebook_logits = outputs.codebook_logits

        # Rows 1 .. num_codebooks of the label tensor hold the codebook
        # targets; .mT swaps the last two dims so codebooks come last.
        # NOTE(review): assumes labels is (batch, 1 + num_codebooks, frames)
        # -- confirm against the collator.
        codebook_labels = labels[:, 1 : 1 + num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        ).reshape(codebook_labels.shape)

        # Average over codebooks to get one loss value per frame.
        semantic_loss_frame = semantic_loss.mean(-1)
        # A frame is padding when every codebook label is the ignore index.
        valid = codebook_labels.sum(-1) != -100 * num_codebooks

        # The batch may be shorter than max_length, so only touch the leading
        # slice of the accumulators that this batch actually covers.
        num_frames = min(valid.size(-1), max_length)
        valid = valid[:, :num_frames]
        # Accumulate in float32 and zero out padded frames before reducing
        # over the batch dimension.
        frame_loss = semantic_loss_frame[:, :num_frames].float()
        frame_loss = frame_loss.masked_fill(~valid, 0.0)
        semantic_loss_sum[:num_frames] += frame_loss.sum(0)
        counter[:num_frames] += valid.sum(0)

    semantic_loss_sum = semantic_loss_sum.cpu()
    counter = counter.cpu()

    # Keep only positions that were observed at least once.
    seen = counter > 0
    xs = torch.nonzero(seen).flatten().tolist()
    ys = (semantic_loss_sum[seen] / counter[seen]).tolist()

    smoothed_ys = smooth(ys, 0.95) if ys else []

    # Unload model so the next checkpoint can reuse the device memory.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return xs, ys, smoothed_ys


def main():
    """Plot the smoothed per-frame semantic loss of several checkpoints.

    Builds a dataset and loader from the hard-coded proto path, runs
    ``analyze_one_model`` for every entry in ``tests`` and draws one curve
    per checkpoint on a log-scaled plot that is written to
    ``semantic_loss.png`` in the current working directory.
    """
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096

    ds = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )

    loader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,
    )

    fig = plt.figure(figsize=(10, 5), dpi=200)

    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

    # (legend label, model config name, checkpoint path)
    tests = [
        (
            "pretrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for name, config, weight in tests:
        # Only the smoothed curve is plotted; the raw per-position means are
        # discarded.
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=name)

    plt.legend()
    plt.savefig("semantic_loss.png")
    # Release the figure once it has been written to disk.
    plt.close(fig)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()