# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import io
import json
from typing import Generator

import PIL.Image
import torch
import transformers
from tokenizers import Tokenizer
from transformers import (
    MaxLengthCriteria,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

from chameleon.inference.alignment import AlignPromptRight
from chameleon.inference.generation import ChameleonGenerator
from chameleon.inference.image_tokenizer import ImageTokenizer
from chameleon.inference.loader import load_model
from chameleon.inference.logits_processor import (
    AllowOnlyTokensAfterIndexLogitsProcessor,
    AllowOnlyTokensLogitsProcessor,
    InBatchInstructCFGLogitsProcessor,
)
from chameleon.inference.model_adapter import ChameleonModelAdapter
from chameleon.inference.stopping_criteria import StopOnEOS, StopOnEOSAfterBatchIndex
from chameleon.inference.token_selector import (
    MultinomialTokenSelector,
    ReplicatedInputTokenSelector,
)
from chameleon.inference.vocab import VocabInfo, VocabTranslation
from chameleon.viewer.backend.models.abstract_model import (
    DEFAULT_IMAGE_CFG_IMAGE,
    DEFAULT_IMAGE_CFG_TEXT,
    DEFAULT_MULTIMODAL_CFG_IMAGE,
    DEFAULT_MULTIMODAL_CFG_TEXT,
    AbstractMultimodalGenerator,
    MixedSequenceType,
    StreamingImage,
)
from chameleon.viewer.backend.utils import get_logger

logger = get_logger(__name__)


def set_seed(seed: int) -> None:
    transformers.enable_full_determinism(seed, warn_only=True)


def get_rank() -> int:
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


class ChameleonTokenizationMixin:
    def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
        img = self.pillow_from_bpe_tokens(bpe_tokens)
        img_io = io.BytesIO()
        img.save(img_io, format="PNG")
        return img_io.getvalue()

    def pillow_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image.Image:
        image_tensor = VocabTranslation(self.vocab).convert_bpe2img(bpe_tokens)
        if image_tensor.shape[0] < 1024:
            # Pad partial images out to the full 1024-token grid before decoding.
            padding = (
                torch.ones([1024 - image_tensor.shape[0]], dtype=int)
                * image_tensor[0]
            )
            image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
        return self.image_tokenizer.pil_from_img_toks(image_tensor)

    def tokens_from_inputs(
        self,
        inputs: MixedSequenceType,
        suffix_tokens: list[str] | None = None,
    ) -> list[int]:
        tokens = [self.vocab.bos_id]
        for input_ in inputs:
            if isinstance(input_, str):
                tokens.extend(self.tokenizer.encode(input_.strip()).ids)
            elif isinstance(input_, PIL.Image.Image):
                tokens.append(self.vocab.begin_image)
                imgtoks = self.image_tokenizer.img_tokens_from_pil(input_)
                tokens.extend(VocabTranslation(self.vocab).convert_img2bp2(imgtoks))
                tokens.append(self.vocab.end_image)
            else:
                raise ValueError(f"Unknown input type: {type(input_)}")
        if suffix_tokens is not None:
            for t in suffix_tokens:
                tokens.extend(self.tokenizer.encode(t).ids)

        sanitized_tokens = []
        for t in tokens:
            if isinstance(t, torch.Tensor):
                sanitized_tokens.append(t.item())
            else:
                sanitized_tokens.append(t)
        return sanitized_tokens


class GeneratorWrapper:
    def __init__(self, gen):
        self.gen = gen

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.gen)


class Decoder:
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
    ):
        ...

    def __next__(
        self,
    ) -> tuple[str, list[int], list[int], bool, type["Decoder"] | None]:
        ...

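# The concrete decoders below (TextDecoder, ImageDecoder) return a 5-tuple from
# __next__:
#   (message_type, cpu_tokens, frontend_tokens, is_final, next_decoder)
# where message_type is "TEXT" or "IMAGE", cpu_tokens extend the running prompt,
# frontend_tokens are streamed to the caller, is_final marks a completed image,
# and next_decoder (when not None) names the Decoder class that should take
# over generation.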
class TextDecoder(Decoder):
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
        *,
        temp: float,
        top_p: float,
        max_seq_len: int,  # TODO: Propagate setting upwards
        repetition_penalty: float,
        **kwargs,
    ):
        self.chameleon_generator = chameleon_generator
        assert chameleon_generator.vocab.eos_id is not None
        stopping_criteria = [
            StopOnEOS(chameleon_generator.vocab.eos_id),
            MaxLengthCriteria(max_seq_len),
        ]
        if chameleon_generator.additional_eos_tokens is not None:
            for token in chameleon_generator.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        chameleon_generator.tokenizer.token_to_id(token),
                        [len(input_ids)],
                    )
                )
        logits_processors = [
            AllowOnlyTokensLogitsProcessor(
                chameleon_generator.vocab.text_tokens
                + [
                    chameleon_generator.vocab.eos_id,
                    chameleon_generator.vocab.begin_image,
                ]
            ),
            # Don't allow any more images near the end since there isn't enough room
            AllowOnlyTokensAfterIndexLogitsProcessor(
                chameleon_generator.vocab.text_tokens
                + [chameleon_generator.vocab.eos_id],
                # TODO: Calculate exact
                1024 * 3 - 3,
            ),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(
                chameleon_generator.model, max_seq_len=max_seq_len
            ),
            input_ids=[input_ids],
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        )
        # Fast-forward through the prompt tokens before generating new ones.
        for _ in range(len(input_ids)):
            next(self.gen)

    def __next__(
        self,
    ) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
        gpu_tok = next(self.gen).id.item()
        cpu_tok = gpu_tok
        if cpu_tok == self.chameleon_generator.vocab.begin_image:
            # return "TEXT", [cpu_tok], [], False, ImageDecoder
            raise StopIteration()
        return (
            "TEXT",
            [cpu_tok],
            [cpu_tok],
            False,
            None,
        )


class ImageDecoder(Decoder):
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
        *,
        cfg_image_weight: float,
        cfg_text_weight: float,
        temp: float,
        top_p: float,
        yield_every_n: int,
        **kwargs,
    ):
        self.yield_every_n = yield_every_n
        self.chameleon_generator = chameleon_generator
        logits_processors = [
            InBatchInstructCFGLogitsProcessor(cfg_text_weight, cfg_image_weight),
            AllowOnlyTokensLogitsProcessor(chameleon_generator.vocab.image_tokens),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        image_conditioned_allowed = set(chameleon_generator.vocab.image_tokens) | {
            chameleon_generator.vocab.bos_id,
            chameleon_generator.vocab.begin_image,
            chameleon_generator.vocab.end_image,
        }
        # Three parallel sequences for in-batch classifier-free guidance:
        # the full prompt, an image-token-only conditioning, and an
        # unconditioned prompt.
        full_conditioned = input_ids
        image_conditioned = [
            in_id for in_id in input_ids if in_id in image_conditioned_allowed
        ]
        unconditioned = [
            chameleon_generator.vocab.bos_id,
            chameleon_generator.vocab.begin_image,
        ]
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(
                chameleon_generator.model, max_seq_len=len(input_ids) + 1024
            ),
            input_ids=[full_conditioned, image_conditioned, unconditioned],
            logits_processors=logits_processors,
            alignment=AlignPromptRight(chameleon_generator.vocab.pad_id),
            token_selector=ReplicatedInputTokenSelector(
                MultinomialTokenSelector(), n=3
            ),
        )
        for _ in range(len(input_ids)):
            next(self.gen)
        self.image_builder: list[torch.LongTensor] = []
        self.gpu_tok_batch: list[torch.LongTensor] = []

    def __next__(
        self,
    ) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
        while True:
            gpu_tok = next(self.gen)
            gpu_tok = torch.chunk(gpu_tok, chunks=3, dim=0)[0]
            self.image_builder.append(gpu_tok)
            self.gpu_tok_batch.append(gpu_tok)
            if len(self.image_builder) == 1024:
                return (
                    "IMAGE",
                    torch.tensor(self.gpu_tok_batch).tolist()
                    + [self.chameleon_generator.vocab.end_image],
                    torch.tensor(self.image_builder).tolist(),
                    True,
                    TextDecoder,
                )
            elif len(self.image_builder) % self.yield_every_n == 0:
                cpu_toks = torch.tensor(self.gpu_tok_batch).tolist()
                self.gpu_tok_batch = []
                return (
                    "IMAGE",
                    cpu_toks,
                    torch.tensor(self.image_builder).tolist(),
                    False,
                    None,
                )


class ChameleonForwardMixin:
    @torch.inference_mode()
    def _generate_text_streaming(
        self,
        input_ids: list[int],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[str, None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info(
                "Rank: %s, set seed: %s",
                get_rank(),
                seed,
            )
        logits_processors = [
            # Only allow text tokens and end-of-sequence.
            AllowOnlyTokensLogitsProcessor(
                self.vocab.text_tokens + [self.vocab.eos_id]
            ),
            # Don't allow the first token to be end-of-sequence.
            # DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        stopping_criteria = [
            StopOnEOS(self.vocab.eos_id),
            MaxLengthCriteria(len(input_ids) + max_gen_tokens),
        ]
        if self.additional_eos_tokens is not None:
            for token in self.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        self.tokenizer.token_to_id(token), [len(input_ids)]
                    )
                )
        for tok in ChameleonGenerator(
            model=ChameleonModelAdapter(
                self.model,
                max_seq_len=len(input_ids) + max_gen_tokens,
            ),
            input_ids=[input_ids],
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        ):
            yield tok.tolist()

    @torch.inference_mode()
    def _generate_batched_text_streaming(
        self,
        batch: list[list[int]],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[list[str], None, None]:
        if seed is not None:
            set_seed(seed)
        logits_processors = [
            # Only allow text tokens and end-of-sequence.
            AllowOnlyTokensLogitsProcessor(
                self.vocab.text_tokens + [self.vocab.eos_id]
            ),
            # Don't allow the first token to be end-of-sequence.
            # DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        max_batch_size = max(len(p) for p in batch)
        stopping_criteria = [
            StopOnEOS(self.vocab.eos_id),
            MaxLengthCriteria(max_batch_size + max_gen_tokens),
        ]
        if self.additional_eos_tokens is not None:
            for token in self.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        self.tokenizer.token_to_id(token), [len(x) for x in batch]
                    )
                )
        for tok in ChameleonGenerator(
            model=ChameleonModelAdapter(
                self.model,
                max_seq_len=max_batch_size + max_gen_tokens,
            ),
            input_ids=batch,
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        ):
            yield tok.unsqueeze(1).tolist()

    @torch.inference_mode()
    def _generate_image_streaming(
        self,
        tokenized_prompt: list[int],
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
        yield_every_n: int = 32,
        seed: int | None = None,
    ) -> Generator[tuple[list[int], bool], None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info(
                "Rank: %s, set seed: %s",
                get_rank(),
                seed,
            )
        decoder = ImageDecoder(
            self,
            tokenized_prompt,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            temp=temp,
            top_p=top_p,
            yield_every_n=yield_every_n,
        )
        for _, _, frontend_tokens, is_final, next_decoder in GeneratorWrapper(decoder):
            if next_decoder is not None:
                break
            yield torch.tensor(frontend_tokens).tolist(), is_final

    @torch.inference_mode()
    def _generate_multimodal_streaming(
        self,
        input_ids: list[int],
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
        yield_every_n: int = 32,
        max_gen_tokens: int = 4096,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[tuple[str, list[int], bool], None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info(
                "Rank: %s, set seed: %s",
                get_rank(),
                seed,
            )
        max_seq_len = min(len(input_ids) + max_gen_tokens, 4096)
        gen_wrapper = GeneratorWrapper(
            TextDecoder(
                self,
                input_ids,
                temp=temp,
                top_p=top_p,
                max_seq_len=max_seq_len,
                repetition_penalty=repetition_penalty,
            )
        )
        for (
            message_type,
            cpu_toks,
            frontend_tokens,
            is_final,
            next_decoder,
        ) in gen_wrapper:
            input_ids.extend(cpu_toks)
            if len(frontend_tokens) > 0:
                yield message_type, frontend_tokens, is_final
            if next_decoder is not None:
                # Hand off generation to the next decoder (e.g. text -> image).
                gen_wrapper.gen = next_decoder(
                    self,
                    input_ids,
                    temp=temp,
                    top_p=top_p,
                    max_seq_len=max_seq_len,
                    cfg_image_weight=cfg_image_weight,
                    cfg_text_weight=cfg_text_weight,
                    yield_every_n=yield_every_n,
                    repetition_penalty=repetition_penalty,
                )


class ChameleonLocalGenerator(
    AbstractMultimodalGenerator, ChameleonForwardMixin, ChameleonTokenizationMixin
):
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str,
        vqgan_config_path: str,
        vqgan_ckpt_path: str | None = None,
        additional_eos_tokens: list[str] | None = None,
    ) -> None:
        super().__init__()
        logger.info("Loading model...")
        self.model = load_model(model_path)
        self.additional_eos_tokens = additional_eos_tokens

        logger.info("Loading tokenizer...")
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
        self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])

        logger.info("Loading VQGAN...")
        self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)

    @torch.inference_mode()
    def generate_batched_text(
        self,
        prompts: list[MixedSequenceType],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> list[str]:
        outputs = [""] * len(prompts)
        for vals in self.generate_batched_text_streaming(
            prompts,
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            for idx, val in enumerate(vals):
                outputs[idx] += val
        return outputs

    @torch.inference_mode()
    def generate_batched_text_streaming(
        self,
        prompts: list[MixedSequenceType],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[list[str], None, None]:
        batch = []
        for prompt in prompts:
            batch.append(self.tokens_from_inputs(prompt))
        for tok in self._generate_batched_text_streaming(
            batch,
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            yield self.tokenizer.decode_batch(tok)

    @torch.inference_mode()
    async def generate_text_streaming(
        self,
        prompt: MixedSequenceType,
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> Generator[str, None, None]:
        tokenized_prompt = self.tokens_from_inputs(prompt)
        if len(tokenized_prompt) > (4096 - 3):
            yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
            return
        for out in self.generate_batched_text_streaming(
            [prompt],
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            yield out[0]

    @torch.inference_mode()
    async def generate_image_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
        yield_every_n: int = 32,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> Generator[StreamingImage, None, None]:
        assert isinstance(prompt, list)
        tokenized_prompt = self.tokens_from_inputs(prompt)
        tokenized_prompt.append(self.vocab.begin_image)
        if len(tokenized_prompt) > (4096 - 3 - 1024):
            yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
            return
        for tokens, final in self._generate_image_streaming(
            tokenized_prompt,
            temp=temp,
            top_p=top_p,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            yield_every_n=yield_every_n,
            seed=seed,
        ):
            yield StreamingImage(
                image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
            )

    @torch.inference_mode()
    async def generate_multimodal_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
        yield_every_n: int = 32,
        max_gen_tokens: int = 4096,
        repetition_penalty: float = 1.2,
        suffix_tokens: list[str] | None = None,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> Generator[MixedSequenceType, None, None]:
        input_ids = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
        if len(input_ids) > (4096 - 3):
            yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens."
            return
        for token_type, tokens, is_final in self._generate_multimodal_streaming(
            input_ids,
            temp=temp,
            top_p=top_p,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            yield_every_n=yield_every_n,
            max_gen_tokens=max_gen_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            match token_type:
                case "TEXT":
                    yield self.tokenizer.decode(tokens)
                case "IMAGE":
                    yield StreamingImage(
                        image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
                        final=is_final,
                    )
                case _:
                    raise ValueError("Unknown token type")
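

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how a
# ChameleonLocalGenerator might be constructed and driven for batched text
# generation. All file paths below are hypothetical placeholders; substitute
# real checkpoint, tokenizer, and VQGAN paths before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    generator = ChameleonLocalGenerator(
        model_path="/path/to/model_dir",  # placeholder path
        tokenizer_path="/path/to/tokenizer.json",  # placeholder path
        vqgan_config_path="/path/to/vqgan_config.yaml",  # placeholder path
        vqgan_ckpt_path="/path/to/vqgan.ckpt",  # placeholder path
    )
    # Each prompt is a mixed sequence; here a single text segment per prompt.
    completions = generator.generate_batched_text(
        prompts=[["Describe a chameleon in one sentence."]],
        max_gen_tokens=64,
        temp=1.0,
        top_p=0.8,
        seed=0,
    )
    for text in completions:
        logger.info("Generated text: %s", text)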