# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
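"""Local, in-process Chameleon generator used by the viewer backend.

Wraps the core inference stack (ChameleonGenerator plus the VQGAN image
tokenizer) behind the AbstractMultimodalGenerator interface, streaming text
tokens and partially decoded images back to the frontend.
"""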
import io
import json
from typing import Generator

import PIL.Image
import torch
import transformers
from tokenizers import Tokenizer
from transformers import (
    MaxLengthCriteria,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

from chameleon.inference.alignment import AlignPromptRight
from chameleon.inference.generation import ChameleonGenerator
from chameleon.inference.image_tokenizer import ImageTokenizer
from chameleon.inference.loader import load_model
from chameleon.inference.logits_processor import (
    AllowOnlyTokensAfterIndexLogitsProcessor,
    AllowOnlyTokensLogitsProcessor,
    InBatchInstructCFGLogitsProcessor,
)
from chameleon.inference.model_adapter import ChameleonModelAdapter
from chameleon.inference.stopping_criteria import StopOnEOS, StopOnEOSAfterBatchIndex
from chameleon.inference.token_selector import (
    MultinomialTokenSelector,
    ReplicatedInputTokenSelector,
)
from chameleon.inference.vocab import VocabInfo, VocabTranslation
from chameleon.viewer.backend.models.abstract_model import (
    DEFAULT_IMAGE_CFG_IMAGE,
    DEFAULT_IMAGE_CFG_TEXT,
    DEFAULT_MULTIMODAL_CFG_IMAGE,
    DEFAULT_MULTIMODAL_CFG_TEXT,
    AbstractMultimodalGenerator,
    MixedSequenceType,
    StreamingImage,
)
from chameleon.viewer.backend.utils import get_logger

logger = get_logger(__name__)

def set_seed(seed: int) -> None:
    transformers.enable_full_determinism(seed, warn_only=True)


def get_rank() -> int:
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0
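

# ChameleonTokenizationMixin converts between the viewer's mixed text/image
# inputs and the model's flat token stream. Each image occupies exactly 1024
# VQGAN codes bracketed by the begin_image/end_image sentinel tokens.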
class ChameleonTokenizationMixin:
    def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
        img = self.pillow_from_bpe_tokens(bpe_tokens)
        img_io = io.BytesIO()
        img.save(img_io, format="PNG")
        return img_io.getvalue()

    def pillow_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image.Image:
        image_tensor = VocabTranslation(self.vocab).convert_bpe2img(bpe_tokens)
        if image_tensor.shape[0] < 1024:
            # Partial image: pad to the full 1024 codes (repeating the first
            # code) so the VQGAN can still decode a preview.
            padding = (
                torch.ones([1024 - image_tensor.shape[0]], dtype=int) * image_tensor[0]
            )
            image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
        return self.image_tokenizer.pil_from_img_toks(image_tensor)

    def tokens_from_inputs(
        self,
        inputs: MixedSequenceType,
        suffix_tokens: list[str] | None = None,
    ) -> list[int]:
        tokens = [self.vocab.bos_id]
        for input_ in inputs:
            if isinstance(input_, str):
                tokens.extend(self.tokenizer.encode(input_.strip()).ids)
            elif isinstance(input_, PIL.Image.Image):
                tokens.append(self.vocab.begin_image)
                imgtoks = self.image_tokenizer.img_tokens_from_pil(input_)
                tokens.extend(VocabTranslation(self.vocab).convert_img2bp2(imgtoks))
                tokens.append(self.vocab.end_image)
            else:
                raise ValueError(f"Unknown input type: {type(input_)}")
        if suffix_tokens is not None:
            for t in suffix_tokens:
                tokens.extend(self.tokenizer.encode(t).ids)
        # Image tokens arrive as 0-dim tensors; flatten everything to plain ints.
        sanitized_tokens = []
        for t in tokens:
            if isinstance(t, torch.Tensor):
                sanitized_tokens.append(t.item())
            else:
                sanitized_tokens.append(t)
        return sanitized_tokens
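

# Thin indirection over a generator so the underlying generator can be swapped
# mid-iteration (see _generate_multimodal_streaming, which replaces .gen when
# one decoder hands off to the next).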
class GeneratorWrapper:
    def __init__(self, gen):
        self.gen = gen

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.gen)
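

# Decoder protocol: each __next__ yields a 5-tuple of
#   (message_type, cpu_tokens, frontend_tokens, is_final, next_decoder)
# where next_decoder, when not None, names the Decoder subclass the caller
# should switch to for the following tokens.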
class Decoder:
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
    ):
        ...

    def __next__(
        self,
    ) -> tuple[str, list[int], list[int], bool, type["Decoder"] | None]:
        ...
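

# Samples text tokens one at a time. Generation stops on EOS, any additional
# EOS-like tokens, max length, or when the model emits begin_image (the
# text -> image handoff is currently disabled; see __next__).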
class TextDecoder(Decoder):
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
        *,
        temp: float,
        top_p: float,
        max_seq_len: int,
        # TODO: Propagate setting upwards
        repetition_penalty: float,
        **kwargs,
    ):
        self.chameleon_generator = chameleon_generator
        assert chameleon_generator.vocab.eos_id is not None
        stopping_criteria = [
            StopOnEOS(chameleon_generator.vocab.eos_id),
            MaxLengthCriteria(max_seq_len),
        ]
        if chameleon_generator.additional_eos_tokens is not None:
            for token in chameleon_generator.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        chameleon_generator.tokenizer.token_to_id(token),
                        [len(input_ids)],
                    )
                )
        logits_processors = [
            AllowOnlyTokensLogitsProcessor(
                chameleon_generator.vocab.text_tokens
                + [
                    chameleon_generator.vocab.eos_id,
                    chameleon_generator.vocab.begin_image,
                ]
            ),
            # Don't allow any more images near the end since there isn't enough room.
            AllowOnlyTokensAfterIndexLogitsProcessor(
                chameleon_generator.vocab.text_tokens
                + [chameleon_generator.vocab.eos_id],
                # TODO: Calculate exact
                1024 * 3 - 3,
            ),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(
                chameleon_generator.model, max_seq_len=max_seq_len
            ),
            input_ids=[input_ids],
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        )
        # Skip over the prompt tokens; only newly generated tokens are yielded.
        for _ in range(len(input_ids)):
            next(self.gen)

    def __next__(self) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
        gpu_tok = next(self.gen).id.item()
        cpu_tok = gpu_tok
        if cpu_tok == self.chameleon_generator.vocab.begin_image:
            # Text -> image handoff is disabled; stop instead of switching.
            # return "TEXT", [cpu_tok], [], False, ImageDecoder
            raise StopIteration()
        return (
            "TEXT",
            [cpu_tok],
            [cpu_tok],
            False,
            None,
        )
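

# Generates one 1024-token image with classifier-free guidance. The prompt is
# replicated into three conditioning streams -- full context, image-token-only
# context, and an unconditioned stream -- whose logits are mixed by
# InBatchInstructCFGLogitsProcessor using the text/image CFG weights.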
class ImageDecoder(Decoder):
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
        *,
        cfg_image_weight: float,
        cfg_text_weight: float,
        temp: float,
        top_p: float,
        yield_every_n: int,
        **kwargs,
    ):
        self.yield_every_n = yield_every_n
        self.chameleon_generator = chameleon_generator
        logits_processors = [
            InBatchInstructCFGLogitsProcessor(cfg_text_weight, cfg_image_weight),
            AllowOnlyTokensLogitsProcessor(chameleon_generator.vocab.image_tokens),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        # Three conditioning streams for classifier-free guidance: the full
        # prompt, the image-token-only prompt, and an unconditioned prompt.
        image_conditioned_allowed = set(chameleon_generator.vocab.image_tokens) | {
            chameleon_generator.vocab.bos_id,
            chameleon_generator.vocab.begin_image,
            chameleon_generator.vocab.end_image,
        }
        full_conditioned = input_ids
        image_conditioned = [
            in_id for in_id in input_ids if in_id in image_conditioned_allowed
        ]
        unconditioned = [
            chameleon_generator.vocab.bos_id,
            chameleon_generator.vocab.begin_image,
        ]
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(
                chameleon_generator.model, max_seq_len=len(input_ids) + 1024
            ),
            input_ids=[full_conditioned, image_conditioned, unconditioned],
            logits_processors=logits_processors,
            alignment=AlignPromptRight(chameleon_generator.vocab.pad_id),
            token_selector=ReplicatedInputTokenSelector(
                MultinomialTokenSelector(), n=3
            ),
        )
        # Skip over the prompt tokens; only newly generated tokens are yielded.
        for _ in range(len(input_ids)):
            next(self.gen)
        self.image_builder: list[torch.LongTensor] = []
        self.gpu_tok_batch: list[torch.LongTensor] = []

    def __next__(self) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
        while True:
            gpu_tok = next(self.gen)
            # Keep only the fully conditioned stream of the CFG triple.
            gpu_tok = torch.chunk(gpu_tok, chunks=3, dim=0)[0]
            self.image_builder.append(gpu_tok)
            self.gpu_tok_batch.append(gpu_tok)
            if len(self.image_builder) == 1024:
                # Image complete: close it and hand off to text decoding.
                return (
                    "IMAGE",
                    torch.tensor(self.gpu_tok_batch).tolist()
                    + [self.chameleon_generator.vocab.end_image],
                    torch.tensor(self.image_builder).tolist(),
                    True,
                    TextDecoder,
                )
            elif len(self.image_builder) % self.yield_every_n == 0:
                cpu_toks = torch.tensor(self.gpu_tok_batch).tolist()
                self.gpu_tok_batch = []
                return (
                    "IMAGE",
                    cpu_toks,
                    torch.tensor(self.image_builder).tolist(),
                    False,
                    None,
                )
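

# ChameleonForwardMixin holds the low-level streaming loops over token ids;
# the public ChameleonLocalGenerator methods below handle tokenization and
# decoding around them.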
class ChameleonForwardMixin:
    def _generate_text_streaming(
        self,
        input_ids: list[int],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[list[int], None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info(
                "Rank: %s, set seed: %s",
                get_rank(),
                seed,
            )
        logits_processors = [
            # Only allow text tokens and end-of-sequence.
            AllowOnlyTokensLogitsProcessor(
                self.vocab.text_tokens + [self.vocab.eos_id]
            ),
            # Don't allow the first token to be end-of-sequence.
            # DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        stopping_criteria = [
            StopOnEOS(self.vocab.eos_id),
            MaxLengthCriteria(len(input_ids) + max_gen_tokens),
        ]
        if self.additional_eos_tokens is not None:
            for token in self.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        self.tokenizer.token_to_id(token), [len(input_ids)]
                    )
                )
        for tok in ChameleonGenerator(
            model=ChameleonModelAdapter(
                self.model,
                max_seq_len=len(input_ids) + max_gen_tokens,
            ),
            input_ids=[input_ids],
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        ):
            yield tok.tolist()

    def _generate_batched_text_streaming(
        self,
        batch: list[list[int]],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[list[list[int]], None, None]:
        if seed is not None:
            set_seed(seed)
        logits_processors = [
            # Only allow text tokens and end-of-sequence.
            AllowOnlyTokensLogitsProcessor(
                self.vocab.text_tokens + [self.vocab.eos_id]
            ),
            # Don't allow the first token to be end-of-sequence.
            # DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        max_prompt_len = max(len(p) for p in batch)
        stopping_criteria = [
            StopOnEOS(self.vocab.eos_id),
            MaxLengthCriteria(max_prompt_len + max_gen_tokens),
        ]
        if self.additional_eos_tokens is not None:
            for token in self.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        self.tokenizer.token_to_id(token), [len(x) for x in batch]
                    )
                )
        for tok in ChameleonGenerator(
            model=ChameleonModelAdapter(
                self.model,
                max_seq_len=max_prompt_len + max_gen_tokens,
            ),
            input_ids=batch,
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        ):
            # One new token id per batch element, shaped [[id], [id], ...].
            yield tok.unsqueeze(1).tolist()

    def _generate_image_streaming(
        self,
        tokenized_prompt: list[int],
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
        yield_every_n: int = 32,
        seed: int | None = None,
    ) -> Generator[tuple[list[int], bool], None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info(
                "Rank: %s, set seed: %s",
                get_rank(),
                seed,
            )
        decoder = ImageDecoder(
            self,
            tokenized_prompt,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            temp=temp,
            top_p=top_p,
            yield_every_n=yield_every_n,
        )
        for _, _, frontend_tokens, is_final, next_decoder in GeneratorWrapper(decoder):
            # Yield the final chunk (is_final=True) before stopping at the
            # decoder handoff, so the completed image is not dropped.
            yield torch.tensor(frontend_tokens).tolist(), is_final
            if next_decoder is not None:
                break

    def _generate_multimodal_streaming(
        self,
        input_ids: list[int],
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
        yield_every_n: int = 32,
        max_gen_tokens: int = 4096,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[tuple[str, list[int], bool], None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info(
                "Rank: %s, set seed: %s",
                get_rank(),
                seed,
            )
        # The model context is 4096 tokens total.
        max_seq_len = min(len(input_ids) + max_gen_tokens, 4096)
        gen_wrapper = GeneratorWrapper(
            TextDecoder(
                self,
                input_ids,
                temp=temp,
                top_p=top_p,
                max_seq_len=max_seq_len,
                repetition_penalty=repetition_penalty,
            )
        )
        for (
            message_type,
            cpu_toks,
            frontend_tokens,
            is_final,
            next_decoder,
        ) in gen_wrapper:
            input_ids.extend(cpu_toks)
            if len(frontend_tokens) > 0:
                yield message_type, frontend_tokens, is_final
            if next_decoder is not None:
                # Swap in the next decoder (e.g. image -> text after an image
                # completes), conditioning it on everything generated so far.
                gen_wrapper.gen = next_decoder(
                    self,
                    input_ids,
                    temp=temp,
                    top_p=top_p,
                    max_seq_len=max_seq_len,
                    cfg_image_weight=cfg_image_weight,
                    cfg_text_weight=cfg_text_weight,
                    yield_every_n=yield_every_n,
                    repetition_penalty=repetition_penalty,
                )
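

# Public entry point: loads the transformer, text tokenizer, and VQGAN, and
# exposes sync batched-text plus async streaming text/image/multimodal APIs.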
class ChameleonLocalGenerator(
    AbstractMultimodalGenerator, ChameleonForwardMixin, ChameleonTokenizationMixin
):
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str,
        vqgan_config_path: str,
        vqgan_ckpt_path: str | None = None,
        additional_eos_tokens: list[str] | None = None,
    ) -> None:
        super().__init__()
        logger.info("Loading model...")
        self.model = load_model(model_path)
        self.additional_eos_tokens = additional_eos_tokens

        logger.info("Loading tokenizer...")
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
        with open(tokenizer_path) as f:
            self.vocab = VocabInfo(json.load(f)["model"]["vocab"])

        logger.info("Loading VQGAN...")
        self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)

    def generate_batched_text(
        self,
        prompts: list[MixedSequenceType],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> list[str]:
        outputs = [""] * len(prompts)
        for vals in self.generate_batched_text_streaming(
            prompts,
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            for idx, val in enumerate(vals):
                outputs[idx] += val
        return outputs

    def generate_batched_text_streaming(
        self,
        prompts: list[MixedSequenceType],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[list[str], None, None]:
        batch = [self.tokens_from_inputs(prompt) for prompt in prompts]
        for tok in self._generate_batched_text_streaming(
            batch,
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            yield self.tokenizer.decode_batch(tok)

    async def generate_text_streaming(
        self,
        prompt: MixedSequenceType,
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> Generator[str, None, None]:
        tokenized_prompt = self.tokens_from_inputs(prompt)
        if len(tokenized_prompt) > (4096 - 3):
            yield (
                "ERROR: Your input exceeds the model's context length of 4096."
                " Note that images consume 1024 tokens whether in input or output."
            )
            return
        for out in self.generate_batched_text_streaming(
            [prompt],
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            yield out[0]

    async def generate_image_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
        yield_every_n: int = 32,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> Generator[StreamingImage, None, None]:
        assert isinstance(prompt, list)
        tokenized_prompt = self.tokens_from_inputs(prompt)
        tokenized_prompt.append(self.vocab.begin_image)
        # Reserve room for the 1024 image tokens themselves.
        if len(tokenized_prompt) > (4096 - 3 - 1024):
            yield (
                "ERROR: Your input exceeds the model's context length of 4096."
                " Note that images consume 1024 tokens whether in input or output."
            )
            return
        for tokens, final in self._generate_image_streaming(
            tokenized_prompt,
            temp=temp,
            top_p=top_p,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            yield_every_n=yield_every_n,
            seed=seed,
        ):
            yield StreamingImage(
                image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
            )

    async def generate_multimodal_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
        yield_every_n: int = 32,
        max_gen_tokens: int = 4096,
        repetition_penalty: float = 1.2,
        suffix_tokens: list[str] | None = None,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> Generator[MixedSequenceType, None, None]:
        input_ids = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
        if len(input_ids) > (4096 - 3):
            yield (
                "ERROR: Your input exceeds the model's context length of 4096."
                " Note that images consume 1024 tokens whether in input or output."
            )
            return
        for token_type, tokens, is_final in self._generate_multimodal_streaming(
            input_ids,
            temp=temp,
            top_p=top_p,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            yield_every_n=yield_every_n,
            max_gen_tokens=max_gen_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            match token_type:
                case "TEXT":
                    yield self.tokenizer.decode(tokens)
                case "IMAGE":
                    yield StreamingImage(
                        image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
                        final=is_final,
                    )
                case _:
                    raise ValueError(f"Unknown token type: {token_type}")
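

# Minimal usage sketch (illustrative only -- the checkpoint and tokenizer
# paths below are placeholders, not part of this module):
#
#   generator = ChameleonLocalGenerator(
#       model_path="./models/7b",
#       tokenizer_path="./tokenizer/text_tokenizer.json",
#       vqgan_config_path="./tokenizer/vqgan.yaml",
#       vqgan_ckpt_path="./tokenizer/vqgan.ckpt",
#   )
#   outputs = generator.generate_batched_text([["A photo of"]], max_gen_tokens=32)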