Spaces:
No application file
No application file
File size: 33,784 Bytes
8b14bed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 |
import os
import queue
import threading
import time
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
import click
import hydra
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.models.text2semantic.llama import BaseModelArgs
from fish_speech.text import clean_text, split_text
# Turn off the HuggingFace `tokenizers` library's internal parallelism.
# Presumably done to avoid its fork/thread warnings or deadlocks when a tokenizer
# is used from the worker threads started below (TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TorchInductor settings. They only matter when load_model() compiles the
# decode function with the "inductor" backend (the CUDA path).
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future.
    # Guarded with hasattr, presumably because older torch builds lack this option.
    torch._inductor.config.fx_graph_cache = True
from torch.nn.attention import SDPBackend, sdpa_kernel
from fish_speech.models.text2semantic.llama import (
BaseTransformer,
DualARTransformer,
NaiveTransformer,
)
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Draw one index from the last dimension of ``probs_sort``.

    Uses the exponential-race trick. With i.i.d. E_k ~ Exp(1), the index that
    maximises ``p_k / E_k`` is distributed as ``p_k / sum(p)``. This avoids
    ``torch.multinomial`` and its host sync.

    Returns:
        int32 tensor with the last dimension kept (size 1).
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    winners = torch.argmax(probs_sort / noise, dim=-1, keepdim=True)
    return winners.to(dtype=torch.int)
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Turn a 1-D logit vector into sampling probabilities.

    Steps:
        1. Repetition penalty (only when ``previous_tokens`` is given).
           Logits of previously seen ids are divided by the penalty if they
           are non-negative and multiplied by it if they are negative. NOTE:
           this writes into ``logits`` in place.
        2. Top-p filtering. Tokens outside the smallest prefix of the sorted
           distribution whose cumulative mass exceeds ``top_p`` get ``-inf``.
           The top-1 token is always kept.
        3. Temperature scaling (floored at 1e-5), then softmax.
    """
    softmax = torch.nn.functional.softmax

    if previous_tokens is not None:
        history = previous_tokens.long()
        seen = torch.gather(logits, dim=0, index=history)
        penalized = torch.where(
            seen < 0, seen * repetition_penalty, seen / repetition_penalty
        )
        logits.scatter_(dim=0, index=history, src=penalized)

    ranked, order = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(softmax(ranked, dim=-1), dim=-1)
    drop_ranked = cumulative > top_p
    drop_ranked[0] = False  # keep at least one option
    drop = drop_ranked.scatter(dim=0, index=order, src=drop_ranked)

    filtered = logits.masked_fill(drop, -float("Inf"))
    scaled = filtered / max(temperature, 1e-5)
    return softmax(scaled, dim=-1)
def multinomial_sample_one_no_sync_agent(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Batched twin of ``multinomial_sample_one_no_sync``.

    Picks ``argmax(p / E)`` with E ~ Exp(1) noise along the last dimension.
    The result keeps that dimension (size 1) and is cast to int32.
    """
    exp_noise = torch.empty_like(probs_sort)
    exp_noise.exponential_(1)
    picked = torch.argmax(probs_sort / exp_noise, dim=-1, keepdim=True)
    return picked.to(dtype=torch.int)
def logits_to_probs_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Batched version of ``logits_to_probs``: every op works on the last dim.

    ``logits`` is expected to be ``(batch, vocab)``. ``previous_tokens``, if
    given, must index the same last dimension (e.g. ``(batch, window)``). The
    repetition penalty writes into ``logits`` in place. Top-p always keeps the
    highest-ranked token of each row. Temperature is floored at 1e-5.
    """
    softmax = torch.nn.functional.softmax

    if previous_tokens is not None:
        history = previous_tokens.long()
        seen = torch.gather(logits, dim=-1, index=history)
        penalized = torch.where(
            seen < 0, seen * repetition_penalty, seen / repetition_penalty
        )
        logits.scatter_(dim=-1, index=history, src=penalized)

    ranked, order = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(softmax(ranked, dim=-1), dim=-1)
    drop_ranked = cumulative > top_p
    drop_ranked[..., 0] = False  # keep at least one option per row
    drop = drop_ranked.scatter(dim=-1, index=order, src=drop_ranked)

    filtered = logits.masked_fill(drop, -float("Inf"))
    scaled = filtered / max(temperature, 1e-5)
    return softmax(scaled, dim=-1)
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one token from the last position of batch element 0.

    Only ``logits[0, -1]`` is used. Other batch elements are ignored.

    Returns:
        ``(idx_next, probs)``: the sampled id (int32, trailing dim of size 1)
        and the filtered probability vector it was drawn from.
    """
    final_step_logits = logits[0, -1]
    probs = logits_to_probs(
        logits=final_step_logits,
        previous_tokens=previous_tokens,
        **sampling_kwargs,
    )
    return multinomial_sample_one_no_sync(probs), probs
def sample_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batched sampling from the last position of every batch element.

    Uses ``logits[:, -1]``, so each row of the batch draws its own token.

    Returns:
        ``(idx_next, probs)`` with ``idx_next`` shaped ``(batch, 1)`` (int32).
    """
    final_step_logits = logits[:, -1]
    probs = logits_to_probs_agent(
        logits=final_step_logits,
        previous_tokens=previous_tokens,
        **sampling_kwargs,
    )
    return multinomial_sample_one_no_sync_agent(probs), probs
def decode_one_token_ar_agent(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    semantic_id: int = 32003,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one batched step with a DualARTransformer (agent variant).

    ``model.forward_generate`` yields logits for the main token row plus a
    hidden state. The main token is sampled with fixed low-temperature
    settings (temperature=0.1, top_p=0.1, repetition_penalty=1.0). The fast
    layers' KV caches are then zeroed. Each codebook is sampled in turn with
    ``model.forward_generate_fast`` at positions 0..num_codebooks-1, and each
    sampled code is fed back through ``model.fast_embeddings``.

    Args:
        model: DualARTransformer whose caches were set up by the caller.
        x: Step input. As passed by ``decode_n_tokens_agent`` it is
            ``(batch, num_codebooks + 1, t)``.
        input_pos: Positions of ``x`` in the main transformer's cache.
        previous_tokens: Optional ``(batch, num_codebooks + 1, window)``
            history used for the codebook repetition penalty. Row 0 is not
            penalized.
        semantic_id: Main-token id marking a semantic step. Codebook entries
            of every other step are overwritten with ``CODEBOOK_PAD_TOKEN_ID``.
        **sampling_kwargs: ``temperature`` / ``top_p`` / ``repetition_penalty``
            forwarded to ``sample_agent`` for the codebook rows only.

    Returns:
        ``(batch, num_codebooks + 1, 1)`` int tensor: the main token followed
        by one code per codebook. This assumes the model's logits are
        ``(batch, seq, vocab)``, as ``sample_agent``'s ``[:, -1]`` expects
        (TODO confirm).
    """
    x = model.forward_generate(x, input_pos)
    logits = x.logits  # [:, -1:]
    hidden_states = x.hidden_states  # [:, -1:]
    # The main token row uses fixed low-temperature settings regardless of the
    # caller's sampling_kwargs.
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0
    codebooks = [
        sample_agent(
            logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]
    # Cleanup the cache: the fast decoder restarts at position 0 every step,
    # so stale keys/values from the previous step are cleared first.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)
    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample_agent(
            logits,
            previous_tokens=(
                previous_tokens[:, codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        # The embedding of the code just sampled is the next fast-decoder input.
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)
    codebooks = torch.stack(codebooks, dim=1)
    # Steps whose main token is not the semantic marker carry no audio codes:
    # pad every codebook row for them.
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
    )
    return codebooks
def decode_one_token_naive_agent(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    semantic_id: int = 32003,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one batched step with a NaiveTransformer (agent variant).

    The model output exposes ``token_logits`` (main token row) and
    ``codebook_logits`` whose third axis indexes the codebook. Every row is
    sampled with ``sample_agent``, so each batch element gets its own token.
    The original code used the unbatched ``sample`` for the main token. That
    returns a ``(1,)`` tensor which cannot be stacked with the ``(batch, 1)``
    codebook samples, so the function always failed.

    Args:
        model: NaiveTransformer with caches already set up.
        x: Step input, ``(batch, num_codebooks + 1, t)`` as passed by
            ``decode_n_tokens_agent``.
        input_pos: Positions of ``x`` in the KV cache.
        previous_tokens: Optional ``(batch, num_codebooks + 1, window)``
            history for the codebook repetition penalty. Row 0 is not
            penalized.
        semantic_id: Main-token id marking a semantic step. Codebook entries
            of every other step are replaced by ``CODEBOOK_PAD_TOKEN_ID``.
        **sampling_kwargs: Forwarded to ``sample_agent`` for every row.

    Returns:
        ``(batch, num_codebooks + 1, 1)`` int tensor.
    """
    x = model.forward_generate(x, input_pos)
    codebooks = [
        sample_agent(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]
    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample_agent(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[:, i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )
    codebooks = torch.stack(codebooks, dim=1)
    # Non-semantic steps carry no audio codes: pad their codebook rows.
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
    )
    return codebooks
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    semantic_id: int = 0,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one unbatched step with a DualARTransformer.

    The main token is sampled from ``model.forward_generate``'s logits. The
    fast layers' KV caches are then zeroed. Each codebook is sampled in turn
    with ``model.forward_generate_fast`` at positions 0..num_codebooks-1,
    feeding each code back through ``model.fast_embeddings``. All sampling
    goes through ``sample``, which only looks at batch element 0.

    Args:
        model: DualARTransformer with caches already set up.
        x: Step input (``(1, num_codebooks + 1, t)`` as passed by
            ``generate`` / ``decode_n_tokens``).
        input_pos: Positions of ``x`` in the main transformer's cache.
        previous_tokens: Optional ``(num_codebooks + 1, window)`` history for
            the codebook repetition penalty. Row 0 is not penalized.
        semantic_id: Main-token id marking a semantic step. Codebook entries
            of every other step are replaced by ``CODEBOOK_PAD_TOKEN_ID``.
        **sampling_kwargs: ``temperature`` / ``top_p`` / ``repetition_penalty``
            used for the main token and every codebook.

    Returns:
        ``(num_codebooks + 1, 1)`` int tensor, assuming the model's logits are
        ``(1, seq, vocab)`` as ``sample``'s ``[0, -1]`` expects (TODO confirm).
    """
    x = model.forward_generate(x, input_pos)
    # Unlike decode_one_token_naive / decode_one_token_ar_agent, no
    # low-temperature override is applied to the main token here: the main
    # token uses an unmodified copy of the caller's settings.
    sampling_kwargs_main = sampling_kwargs.copy()
    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]
    x = x.hidden_states
    # Cleanup the cache: the fast decoder restarts at position 0 every step.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)
    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
        logits = model.forward_generate_fast(x, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        # The embedding of the code just sampled is the next fast-decoder input.
        x = model.fast_embeddings(a)
        codebooks.append(a)
    codebooks = torch.stack(codebooks, dim=0)
    # Non-semantic steps carry no audio codes: pad their codebook rows.
    codebooks[1:, :] = torch.masked_fill(
        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
    )
    return codebooks
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    semantic_id: Optional[int] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one unbatched step with a NaiveTransformer.

    The NaiveTransformer output exposes ``token_logits`` (main token row) and
    ``codebook_logits`` (third axis = codebook). This matches how
    ``decode_one_token_naive_agent`` reads it; the previous ``x.logits``
    access did not exist on this output. The main token is sampled with fixed
    low-temperature settings.

    Args:
        model: NaiveTransformer with caches already set up.
        x: Step input, ``(1, num_codebooks + 1, t)``.
        input_pos: Positions of ``x`` in the KV cache.
        previous_tokens: Optional ``(num_codebooks + 1, window)`` history for
            the codebook repetition penalty. Row 0 is not penalized.
        semantic_id: Main-token id marking a semantic step. ``generate`` and
            ``decode_n_tokens`` always pass it; it used to leak through
            ``**sampling_kwargs`` into ``logits_to_probs`` and raise
            ``TypeError``. When given, codebook entries of non-semantic steps
            are padded with ``CODEBOOK_PAD_TOKEN_ID`` (same as the other
            decoders). When omitted, no masking is applied.
        **sampling_kwargs: ``temperature`` / ``top_p`` / ``repetition_penalty``
            for the codebook rows.

    Returns:
        ``(num_codebooks + 1, 1)`` int tensor.
    """
    x = model.forward_generate(x, input_pos)

    # Main token row: fixed low-temperature sampling, no repetition penalty.
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]
    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )
    codebooks = torch.stack(codebooks, dim=0)
    if semantic_id is not None:
        codebooks[1:, :] = torch.masked_fill(
            codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
        )
    return codebooks
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    semantic_id: int = 0,
    **sampling_kwargs,
):
    """Autoregressively decode up to ``num_new_tokens`` unbatched steps.

    Args:
        model: Model whose caches are already primed with the prompt.
        cur_token: Last emitted step, ``(1, num_codebooks + 1, 1)``.
        input_pos: Cache position of the next step. Incremented in place.
        num_new_tokens: Maximum number of steps to decode (may be 0).
        im_end_id: Main-token id that stops decoding. That step is included
            in the output.
        decode_one_token: Single-step decoder (possibly torch.compile'd).
        semantic_id: Forwarded to ``decode_one_token``.
        **sampling_kwargs: Forwarded to ``decode_one_token``.

    Returns:
        ``(num_codebooks + 1, n)`` int tensor of decoded steps, where
        ``0 <= n <= num_new_tokens``. Previously ``num_new_tokens == 0``
        raised ``UnboundLocalError``.
    """
    codebook_dim = model.config.num_codebooks + 1
    win_size = 16  # repetition-penalty history window
    previous_tokens = torch.zeros(
        (codebook_dim, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    i = -1  # makes the return below an empty slice when no step runs
    for i in tqdm(range(num_new_tokens)):
        # NOTE: for the first win_size steps the window is the first win_size
        # columns, which still include not-yet-filled zero columns.
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        # Math-only SDPA (same backend choice as decode_n_tokens_agent) is
        # actually better for Inductor to codegen attention here.
        with (
            sdpa_kernel(SDPBackend.MATH)
            if torch.cuda.is_available()
            else nullcontext()
        ):
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_id=semantic_id,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, codebook_dim, -1)
        previous_tokens[:, i : i + 1] = next_token.view(codebook_dim, -1)

        if cur_token[0, 0, -1] == im_end_id:
            break

    return previous_tokens[:, : i + 1]
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: Model with caches set up (Naive or DualAR transformer).
        prompt: ``(num_codebooks + 1, T)`` prompt, e.g. from ``encode_tokens``.
        max_new_tokens: Maximum steps to generate, including the prefill step.
            A falsy value means "fill up to ``model.config.max_seq_len``".
            Larger values are truncated to fit.
        im_end_id: Main-token id that stops decoding.
        decode_one_token: Single-step decoder used after the prefill step.
        **sampling_kwargs: Forwarded to the decoders.

    Returns:
        ``(num_codebooks + 1, T + n)`` tensor: the prompt followed by the
        ``n >= 1`` generated steps.

    Raises:
        ValueError: If the prompt already fills ``model.config.max_seq_len``.
            Previously this wrote past the output buffer.
    """
    T = prompt.size(1)
    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )
    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
    else:
        max_new_tokens = model.config.max_seq_len - T

    device, dtype = prompt.device, prompt.dtype
    codebook_dim = 1 + model.config.num_codebooks

    # Output buffer sized for the longest possible sequence; trimmed at the end.
    seq = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    seq[:, :T] = prompt
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )
    next_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        semantic_id=semantic_id,
        **sampling_kwargs,
    )
    seq[:, T : T + 1] = next_token

    # Only the prefill step fits (or was requested): nothing left to decode.
    if max_new_tokens <= 1:
        return seq[:, : T + 1]

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        semantic_id=semantic_id,
        **sampling_kwargs,
    )
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x
    return seq
def decode_n_tokens_agent(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    im_end_id: int = 4,
    semantic_id: int = 32003,
    decode_one_token=decode_one_token_naive_agent,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """Generator that decodes batched steps and yields each one on the CPU.

    Args:
        model: Model whose caches are already primed with the prompt.
        cur_token: Last emitted step, ``(batch, num_codebooks + 1, 1)``.
        input_pos: Cache position of the next step. Incremented in place.
        num_new_tokens: Maximum number of steps to decode (may be 0).
        im_end_id: Main-token id that marks a batch element as finished.
        semantic_id: Forwarded to ``decode_one_token``.
        decode_one_token: Single-step batched decoder.
        early_stop_threshold: If strictly between 0 and 1, stop once at least
            ``round(batch * threshold)`` elements have finished. Otherwise stop
            only when all have finished.
        **sampling_kwargs: Forwarded to ``decode_one_token``.

    Yields:
        ``(batch, num_codebooks + 1, 1)`` CPU tensors, one per decoded step.
    """
    batch_size = cur_token.size(0)
    codebook_dim = model.config.num_codebooks + 1
    win_size = 16  # repetition-penalty history window
    previous_tokens = torch.zeros(
        (batch_size, codebook_dim, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )
    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
    finished = finished | (cur_token[:, 0, -1] == im_end_id)
    start_time = time.time()

    i = -1  # keeps the summary valid when no step runs (num_new_tokens == 0)
    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
        # NOTE: for the first win_size steps the window still includes
        # not-yet-filled zero columns.
        if i < win_size:
            window = previous_tokens[:, :, :win_size]
        else:
            window = previous_tokens[:, :, i - win_size : i]

        with sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_id=semantic_id,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(batch_size, codebook_dim, -1)
        previous_tokens[:, :, i : i + 1] = next_token.view(
            batch_size, codebook_dim, -1
        )

        yield cur_token.cpu()

        finished = finished | (cur_token[:, 0, -1] == im_end_id)
        if finished.all() or (
            0 < early_stop_threshold < 1
            and finished.sum() >= round(batch_size * early_stop_threshold)
        ):
            break

    total_time = time.time() - start_time
    generated_tokens = i + 1
    # Guard against a zero elapsed time on very coarse clocks.
    tokens_per_second = (
        (generated_tokens / total_time) * batch_size if total_time > 0 else 0.0
    )
    logger.info(
        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
    )
@torch.no_grad()
@torch.inference_mode()
def generate_agent(
    *,
    model: BaseTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    im_end_id: int = 4,
    semantic_id: int = 32003,
    decode_one_token=decode_one_token_naive_agent,
    num_samples: int = 1,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The 2-D ``(num_codebooks + 1, T)`` prompt is replicated ``num_samples``
    times along a new batch axis. The prefill step is yielded first, then
    every step produced by ``decode_n_tokens_agent``. All yielded tensors are
    ``(num_samples, num_codebooks + 1, 1)`` and live on the CPU.

    Raises:
        ValueError: If the prompt already fills ``model.config.max_seq_len``.
    """
    seq_len = prompt.size(1)
    batched_prompt = prompt[None].repeat(num_samples, 1, 1)
    limit = model.config.max_seq_len
    if seq_len >= limit:
        raise ValueError(
            f"Input sequence length {seq_len} exceeds max_seq_len {limit}"
        )

    # A falsy budget means "fill the context"; oversize budgets are truncated.
    if not max_new_tokens:
        max_new_tokens = limit - seq_len
    elif seq_len + max_new_tokens > limit:
        max_new_tokens = limit - seq_len
        logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

    device = batched_prompt.device
    codebook_dim = 1 + model.config.num_codebooks

    # Use non-accelerated version for now, to avoid compilation overhead
    if isinstance(model, NaiveTransformer):
        prefill_decode = decode_one_token_naive_agent
    else:
        prefill_decode = decode_one_token_ar_agent

    first_step = prefill_decode(
        model,
        batched_prompt,
        torch.arange(0, seq_len, device=device),
        semantic_id=semantic_id,
        **sampling_kwargs,
    ).view(num_samples, codebook_dim, -1)
    yield first_step.cpu()

    yield from decode_n_tokens_agent(
        model,
        first_step,
        torch.tensor([seq_len], device=device, dtype=torch.int),
        max_new_tokens - 1,
        im_end_id=im_end_id,
        semantic_id=semantic_id,
        decode_one_token=decode_one_token,
        early_stop_threshold=early_stop_threshold,
        **sampling_kwargs,
    )
def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    """Build a ``(num_codebooks + 1, L)`` int prompt for one user turn.

    Row 0 holds the chat-formatted text tokens. Rows 1.. hold
    ``CODEBOOK_PAD_TOKEN_ID`` under the text. When ``prompt_tokens``
    (reference codes, ``(num_codebooks, n)`` or ``(1, num_codebooks, n)``) is
    given, the codes are shifted by +1, a zero column is appended, and the
    result is concatenated to the right. That part's row 0 is
    ``<|semantic|>`` ids ending with ``<|im_end|>``.
    """
    chat = f"<|im_start|>user\n{clean_text(string)}<|im_end|><|im_start|>assistant\n"
    text_ids = tokenizer.encode(
        chat,
        add_special_tokens=False,
        max_length=10**6,
        truncation=False,
    )
    text_row = torch.tensor([text_ids], dtype=torch.int, device=device)
    pad_rows = (
        torch.ones((num_codebooks, text_row.size(1)), dtype=torch.int, device=device)
        * CODEBOOK_PAD_TOKEN_ID
    )
    prompt = torch.cat((text_row, pad_rows), dim=0)

    if prompt_tokens is None:
        return prompt

    # Accept a leading singleton batch axis.
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), "3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]
    assert prompt_tokens.ndim == 2

    codes = prompt_tokens + 1  # codes are stored shifted by one
    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
        )
        codes = codes[:num_codebooks]

    # One trailing zero column (described upstream as a pad token).
    # NOTE(review): this is a literal 0, not CODEBOOK_PAD_TOKEN_ID; confirm they agree.
    trailing = torch.zeros((codes.size(0), 1), dtype=torch.int, device=device)
    codes = torch.cat((codes, trailing), dim=1)

    # Since 1.0, we use <|semantic|>
    semantic_token = tokenizer.convert_tokens_to_ids("<|semantic|>")
    im_end_token = tokenizer.convert_tokens_to_ids("<|im_end|>")
    main_row = (
        torch.ones((1, codes.size(1)), dtype=torch.int, device=device) * semantic_token
    )
    main_row[0, -1] = im_end_token

    reference = torch.cat((main_row, codes), dim=0)
    return torch.cat((prompt, reference), dim=1)
def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
    """Load a text2semantic checkpoint and pick its single-step decoder.

    The decoder is chosen from the loaded model's class
    (DualARTransformer vs. otherwise Naive) and from ``is_agent``. With
    ``compile=True`` it is wrapped in ``torch.compile(fullgraph=True)``, using
    inductor + "reduce-overhead" on CUDA and "aot_eager" elsewhere.

    Returns:
        ``(model.eval(), decode_one_token)``.
    """
    loaded: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True
    )
    model = loaded.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    decoders = {
        (True, True): decode_one_token_ar_agent,
        (True, False): decode_one_token_ar,
        (False, True): decode_one_token_naive_agent,
        (False, False): decode_one_token_naive,
    }
    is_dual_ar = isinstance(model, DualARTransformer)
    decoder = decoders[is_dual_ar, bool(is_agent)]
    logger.info("Using DualARTransformer" if is_dual_ar else "Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        on_cuda = torch.cuda.is_available()
        decoder = torch.compile(
            decoder,
            fullgraph=True,
            backend="inductor" if on_cuda else "aot_eager",
            mode="reduce-overhead" if on_cuda else None,
        )

    return model.eval(), decoder
@dataclass
class GenerateResponse:
    """One item yielded by ``generate_long``.

    ``action == "sample"`` carries the codes generated for one text chunk
    together with that chunk. ``action == "next"`` marks the end of the
    current sample and carries no payload.
    """

    # "sample": one generated segment; "next": the current sample is complete.
    action: Literal["sample", "next"]
    # Codebook rows of the generated segment (shifted back by -1); set for "sample".
    codes: Optional[torch.Tensor] = None
    # The text chunk these codes were generated for; set for "sample".
    text: Optional[str] = None
def _select_context_window(global_encoded: list, budget: int) -> list:
    """Choose which encoded segments condition the next generation.

    ``global_encoded`` alternates text segments and generated-code segments
    and always ends with the text segment about to be synthesized. The first
    two segments are always kept to avoid drift. After them, as many trailing
    segments as fit ``budget`` are kept, trimmed to an odd count so the tail
    starts on a text segment. At least the current segment is always kept.
    Previously an index of 0 made ``global_encoded[-0:]`` select the whole
    list, duplicating the first segments.
    """
    lengths = [seg.size(1) for seg in reversed(global_encoded)]
    count = 0
    i = 0
    for i, length in enumerate(lengths):
        count += length
        # NOTE(review): `length` is already included in `count`, so it is
        # effectively counted twice here (a conservative budget check).
        if count + length > budget:
            break
    if i != 0 and i % 2 == 0:
        i -= 1
    i = max(i, 1)  # always include the segment being synthesized
    # Rotate the list, always make sure first segment is included to avoid drift
    if i < len(global_encoded) - 2:
        return global_encoded[:2] + global_encoded[-i:]
    return global_encoded


def generate_long(
    *,
    model,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    prompt_text: Optional[str | list[str]] = None,
    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
):
    """Generate semantic codes for ``text``, chunk by chunk.

    With ``iterative_prompt`` the text is split into chunks of roughly
    ``chunk_length`` (via ``split_text``). Each chunk is generated
    conditioned on the optional reference prompts plus a window of earlier
    chunks and their generated codes.

    Args:
        model: Loaded model with caches set up and a ``tokenizer`` attribute.
        device: Device for the encoded prompts and sampling parameters.
        decode_one_token: Single-step decoder passed to ``generate``.
        text: Text to synthesize.
        num_samples: Number of independent samples to produce.
        max_new_tokens: Per-chunk generation cap (0 = fill the context).
        top_p: Nucleus threshold in (0, 1].
        repetition_penalty: Penalty in (0, 2).
        temperature: Temperature in (0, 2).
        compile: Only used to log the first call's time as compilation time.
        iterative_prompt: Split ``text`` into chunks.
        max_length: Context length the conditioning window is budgeted
            against (1024 positions are reserved for generation).
        chunk_length: Target chunk size for ``split_text``.
        prompt_text: Reference transcript(s).
        prompt_tokens: Reference code tensor(s), one per ``prompt_text``.

    Yields:
        ``GenerateResponse(action="sample", ...)`` per chunk, then
        ``GenerateResponse(action="next")`` after every sample.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    use_prompt = prompt_text is not None and prompt_tokens is not None
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]

    assert use_prompt is False or len(prompt_text) == len(
        prompt_tokens
    ), "Prompt text and tokens must have the same length"

    # Bytes of weights, so the bandwidth log below is really in GB/s
    # (it previously used the raw parameter count).
    model_size = sum(
        p.numel() * p.element_size() for p in model.parameters() if p.requires_grad
    )
    tokenizer = model.tokenizer
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    num_codebooks = model.config.num_codebooks

    texts = split_text(text, chunk_length) if iterative_prompt else [text]

    encoded_prompts = []
    if use_prompt:
        for ref_text, ref_codes in zip(prompt_text, prompt_tokens):
            encoded_prompts.append(
                encode_tokens(
                    tokenizer,
                    string=ref_text,
                    device=device,
                    prompt_tokens=ref_codes,
                    num_codebooks=num_codebooks,
                )
            )

    encoded = []
    for chunk in texts:
        encoded.append(
            encode_tokens(
                tokenizer,
                string=chunk,
                device=device,
                num_codebooks=num_codebooks,
            )
        )
        logger.info(f"Encoded text: {chunk}")

    # Positions available for the rolling context once 1024 are reserved for
    # generation and the reference prompts are accounted for (loop-invariant).
    context_budget = max_length - 1024 - sum(t.shape[1] for t in encoded_prompts)

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        seg_idx = 0
        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            global_encoded.append(encoded[seg_idx])
            partial_encoded = _select_context_window(global_encoded, context_budget)
            if use_prompt:
                partial_encoded = encoded_prompts + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            elapsed = time.perf_counter() - t0
            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / elapsed
            logger.info(
                f"Generated {tokens_generated} tokens in {elapsed:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Drop the final generated step (normally the <|im_end|> step that
            # ended decoding). Rows 1.. are codebooks stored shifted by +1
            # (see encode_tokens), so shift them back.
            codes = y[1:, prompt_length:-1].clone() - 1
            assert (codes >= 0).all(), f"Negative code found: {codes}"

            # NOTE(review): the rolling context also drops that final step,
            # whereas encode_tokens ends reference codes with an <|im_end|>
            # column; confirm whether it should be kept here.
            decoded = y[:, prompt_length:-1].clone()
            global_encoded.append(decoded)

            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

        # This indicates the end of the current sample
        yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Envelope put on a response queue by ``launch_thread_safe_queue``'s worker.

    On success ``response`` is one ``GenerateResponse`` chunk from
    ``generate_long``. On failure it is the exception that was raised.
    """

    # "success" for a normal chunk, "error" when generation raised.
    status: Literal["success", "error"]
    # The GenerateResponse chunk, or the raised Exception for "error".
    response: Optional[GenerateResponse | Exception] = None
@dataclass
class GenerateRequest:
    """A job submitted to a worker's input queue.

    The worker passes ``request`` as keyword arguments to ``generate_long``
    (or to ``generate_agent`` for the agent worker) and streams results into
    ``response_queue``.
    """

    # Keyword arguments for the generation function (model/decoder are supplied by the worker).
    request: dict
    # Queue the worker writes results to for this request.
    response_queue: queue.Queue
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker that owns the model and serves generate_long jobs.

    The model is loaded and its caches are set up inside the worker thread.
    This call blocks until that initialization finishes. If initialization
    fails, the worker's exception is re-raised here. Previously the caller
    waited forever because the init event was never set.

    Jobs are ``GenerateRequest`` objects put on the returned queue. Each chunk
    (or the error) is sent back as a ``WrappedGenerateResponse`` on the job's
    ``response_queue``. Putting ``None`` stops the worker.

    Returns:
        The input ``queue.Queue``.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()
    init_errors: list = []

    def worker():
        try:
            model, decode_one_token = load_model(
                checkpoint_path, device, precision, compile=compile
            )
            with torch.device(device):
                model.setup_caches(
                    max_batch_size=1,
                    max_seq_len=model.config.max_seq_len,
                    dtype=next(model.parameters()).dtype,
                )
        except Exception as e:
            init_errors.append(e)
            return
        finally:
            # Always release the waiting caller, on success or failure.
            init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()
    if init_errors:
        raise init_errors[0]

    return input_queue
def launch_thread_safe_queue_agent(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker that owns the model and serves generate_agent jobs.

    The tokenizer and config are loaded in the calling thread. The model is
    loaded and its caches are set up inside the worker. This call blocks until
    that initialization finishes and re-raises the worker's exception if it
    failed. Previously the caller waited forever in that case.

    For each ``GenerateRequest`` the worker puts every yielded token tensor on
    ``response_queue``, followed by ``"stop"``, or ``"error"`` if generation
    raised. Putting ``None`` on the input queue stops the worker.

    Returns:
        ``(input_queue, tokenizer, config)``.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()
    init_errors: list = []
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    config = BaseModelArgs.from_pretrained(checkpoint_path)

    def worker():
        try:
            model, decode_one_token = load_model(
                checkpoint_path, device, precision, compile=compile, is_agent=True
            )
            with torch.device(device):
                model.setup_caches(
                    max_batch_size=1,
                    max_seq_len=model.config.max_seq_len,
                    dtype=next(model.parameters()).dtype,
                )
        except Exception as e:
            init_errors.append(e)
            return
        finally:
            # Always release the waiting caller, on success or failure.
            init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for token in generate_agent(
                    model=model,
                    decode_one_token=decode_one_token,
                    **kwargs,
                ):
                    response_queue.put(token)

                response_queue.put("stop")
            except Exception:
                # logger.exception already attaches the active traceback.
                logger.exception("Error in worker")
                response_queue.put("error")

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()
    if init_errors:
        raise init_errors[0]

    return input_queue, tokenizer, config
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.4",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=100)
def main(
    text: str,
    prompt_text: Optional[list[str]],
    prompt_tokens: Optional[list[Path]],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
) -> None:
    """Generate semantic codes for --text and save each sample as codes_<idx>.npy."""
    precision = torch.half if half else torch.bfloat16

    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    if prompt_tokens is not None:
        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    # Codes are accumulated per sample and written out when generate_long
    # signals the end of that sample with action == "next".
    idx = 0
    codes = []
    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to codes_{idx}.npy")
            logger.info("Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()
|