# chameleon/viewer/backend/models/chameleon_local.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
import io
import json
from typing import AsyncGenerator, Generator
import PIL.Image
import torch
import transformers
from tokenizers import Tokenizer
from transformers import (
MaxLengthCriteria,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
)
from chameleon.inference.alignment import AlignPromptRight
from chameleon.inference.generation import ChameleonGenerator
from chameleon.inference.image_tokenizer import ImageTokenizer
from chameleon.inference.loader import load_model
from chameleon.inference.logits_processor import (
AllowOnlyTokensAfterIndexLogitsProcessor,
AllowOnlyTokensLogitsProcessor,
InBatchInstructCFGLogitsProcessor,
)
from chameleon.inference.model_adapter import ChameleonModelAdapter
from chameleon.inference.stopping_criteria import StopOnEOS, StopOnEOSAfterBatchIndex
from chameleon.inference.token_selector import (
MultinomialTokenSelector,
ReplicatedInputTokenSelector,
)
from chameleon.inference.vocab import VocabInfo, VocabTranslation
from chameleon.viewer.backend.models.abstract_model import (
DEFAULT_IMAGE_CFG_IMAGE,
DEFAULT_IMAGE_CFG_TEXT,
DEFAULT_MULTIMODAL_CFG_IMAGE,
DEFAULT_MULTIMODAL_CFG_TEXT,
AbstractMultimodalGenerator,
MixedSequenceType,
StreamingImage,
)
from chameleon.viewer.backend.utils import get_logger
logger = get_logger(__name__)
def set_seed(seed: int) -> None:
transformers.enable_full_determinism(seed, warn_only=True)
def get_rank() -> int:
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return 0
class ChameleonTokenizationMixin:
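    """Helpers for converting between mixed text/image inputs and Chameleon token ids.

    ``tokens_from_inputs`` flattens a list of strings and PIL images into a single
    BPE token sequence (BOS, text tokens, and begin_image/end_image-delimited image
    tokens); the ``*_from_bpe_tokens`` helpers run the VQGAN image tokenizer in the
    opposite direction to render image tokens back into pixels.
    """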
def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
img = self.pillow_from_bpe_tokens(bpe_tokens)
img_io = io.BytesIO()
img.save(img_io, format="PNG")
return img_io.getvalue()
def pillow_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image.Image:
image_tensor = VocabTranslation(self.vocab).convert_bpe2img(bpe_tokens)
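        # Streamed previews may contain fewer than the full 1024 image tokens; pad
        # with the first token so the VQGAN decoder always receives a complete grid.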
if image_tensor.shape[0] < 1024:
padding = (
torch.ones([1024 - image_tensor.shape[0]], dtype=int) * image_tensor[0]
)
image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
return self.image_tokenizer.pil_from_img_toks(image_tensor)
def tokens_from_inputs(
self,
inputs: MixedSequenceType,
suffix_tokens: list[str] | None = None,
) -> list[int]:
tokens = [self.vocab.bos_id]
for input_ in inputs:
if isinstance(input_, str):
tokens.extend(self.tokenizer.encode(input_.strip()).ids)
elif isinstance(input_, PIL.Image.Image):
tokens.append(self.vocab.begin_image)
imgtoks = self.image_tokenizer.img_tokens_from_pil(input_)
tokens.extend(VocabTranslation(self.vocab).convert_img2bp2(imgtoks))
tokens.append(self.vocab.end_image)
else:
raise ValueError(f"Unknown input type: {type(input_)}")
if suffix_tokens is not None:
for t in suffix_tokens:
tokens.extend(self.tokenizer.encode(t).ids)
sanitized_tokens = []
for t in tokens:
if isinstance(t, torch.Tensor):
sanitized_tokens.append(t.item())
else:
sanitized_tokens.append(t)
return sanitized_tokens
class GeneratorWrapper:
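    """Thin indirection over a decoder generator.

    Iteration delegates to ``self.gen``, so the active decoder can be swapped
    mid-stream (e.g. ``TextDecoder`` -> ``ImageDecoder``) without restarting the loop.
    """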
def __init__(self, gen):
self.gen = gen
def __iter__(self):
return self
def __next__(self):
return next(self.gen)
class Decoder:
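    """Interface for streaming decoders.

    Each ``__next__`` call returns a 5-tuple
    ``(message_type, cpu_tokens, frontend_tokens, is_final, next_decoder)``, where
    ``next_decoder`` is a ``Decoder`` subclass to hand control to, or ``None``.
    """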
def __init__(
self,
chameleon_generator: "ChameleonLocalGenerator",
input_ids: list[int],
):
...
    def __next__(self) -> tuple[str, list[int], list[int], bool, type["Decoder"] | None]:
...
class TextDecoder(Decoder):
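    """Streams text tokens one at a time.

    Stops on EOS (plus any additional EOS tokens) or at ``max_seq_len``. When the
    model emits ``begin_image`` the stream also ends; the ``ImageDecoder`` handoff
    is left commented out below.
    """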
def __init__(
self,
chameleon_generator: "ChameleonLocalGenerator",
input_ids: list[int],
*,
temp: float,
top_p: float,
max_seq_len: int,
        # TODO: Propagate setting upwards
repetition_penalty: float,
**kwargs,
):
self.chameleon_generator = chameleon_generator
assert chameleon_generator.vocab.eos_id is not None
stopping_criteria = [
StopOnEOS(chameleon_generator.vocab.eos_id),
MaxLengthCriteria(max_seq_len),
]
if chameleon_generator.additional_eos_tokens is not None:
for token in chameleon_generator.additional_eos_tokens:
stopping_criteria.append(
StopOnEOSAfterBatchIndex(
chameleon_generator.tokenizer.token_to_id(token), [len(input_ids)]
)
)
logits_processors = [
AllowOnlyTokensLogitsProcessor(
chameleon_generator.vocab.text_tokens
+ [chameleon_generator.vocab.eos_id, chameleon_generator.vocab.begin_image]
),
# Don't allow any more images near the end since there isn't enough room
AllowOnlyTokensAfterIndexLogitsProcessor(
chameleon_generator.vocab.text_tokens + [chameleon_generator.vocab.eos_id],
                # TODO: Calculate exact cutoff.
                # 1024 * 3 - 3 = 4096 - 1024 - 3: past this index a new
                # 1024-token image can no longer fit in the 4096-token context.
                1024 * 3 - 3,
),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
TemperatureLogitsWarper(temp),
TopPLogitsWarper(top_p),
]
self.gen = ChameleonGenerator(
model=ChameleonModelAdapter(chameleon_generator.model, max_seq_len=max_seq_len),
input_ids=[input_ids],
stopping_criteria=stopping_criteria,
logits_processors=logits_processors,
)
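        # The generator re-emits the prompt positions first; advance past them so
        # iteration yields only newly generated tokens.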
for _ in range(len(input_ids)):
next(self.gen)
    def __next__(self) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
        # .item() moves the sampled token id off the GPU into a plain Python int.
        cpu_tok = next(self.gen).id.item()
if cpu_tok == self.chameleon_generator.vocab.begin_image:
# return "TEXT", [cpu_tok], [], False, ImageDecoder
raise StopIteration()
return (
"TEXT",
[cpu_tok],
[cpu_tok],
False,
None,
)
class ImageDecoder(Decoder):
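    """Streams the 1024 tokens of a single image with classifier-free guidance.

    Three replicated sequences are decoded in one batch -- fully conditioned,
    image-only conditioned, and unconditioned -- and their logits are combined by
    ``InBatchInstructCFGLogitsProcessor``. Partial token lists are yielded every
    ``yield_every_n`` steps; the final yield is marked ``is_final`` and hands
    control back to ``TextDecoder``.
    """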
def __init__(
self,
chameleon_generator: "ChameleonLocalGenerator",
input_ids: list[int],
*,
cfg_image_weight: float,
cfg_text_weight: float,
temp: float,
top_p: float,
yield_every_n: int,
**kwargs,
):
self.yield_every_n = yield_every_n
self.chameleon_generator = chameleon_generator
logits_processors = [
InBatchInstructCFGLogitsProcessor(cfg_text_weight, cfg_image_weight),
AllowOnlyTokensLogitsProcessor(chameleon_generator.vocab.image_tokens),
TemperatureLogitsWarper(temp),
TopPLogitsWarper(top_p),
]
image_conditioned_allowed = set(chameleon_generator.vocab.image_tokens) | {
chameleon_generator.vocab.bos_id,
chameleon_generator.vocab.begin_image,
chameleon_generator.vocab.end_image,
}
full_conditioned = input_ids
image_conditioned = [
in_id for in_id in input_ids if in_id in image_conditioned_allowed
]
unconditioned = [
chameleon_generator.vocab.bos_id,
chameleon_generator.vocab.begin_image,
]
self.gen = ChameleonGenerator(
model=ChameleonModelAdapter(
chameleon_generator.model, max_seq_len=len(input_ids) + 1024
),
input_ids=[full_conditioned, image_conditioned, unconditioned],
logits_processors=logits_processors,
alignment=AlignPromptRight(chameleon_generator.vocab.pad_id),
token_selector=ReplicatedInputTokenSelector(
MultinomialTokenSelector(), n=3
),
)
for _ in range(len(input_ids)):
next(self.gen)
self.image_builder: list[torch.LongTensor] = []
self.gpu_tok_batch: list[torch.LongTensor] = []
    def __next__(self) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
while True:
gpu_tok = next(self.gen)
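            # The three CFG branches are decoded in lockstep; keep only the token
            # from the fully conditioned branch.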
gpu_tok = torch.chunk(gpu_tok, chunks=3, dim=0)[0]
self.image_builder.append(gpu_tok)
self.gpu_tok_batch.append(gpu_tok)
if len(self.image_builder) == 1024:
return (
"IMAGE",
torch.tensor(self.gpu_tok_batch).tolist()
+ [self.chameleon_generator.vocab.end_image],
torch.tensor(self.image_builder).tolist(),
True,
TextDecoder,
)
elif len(self.image_builder) % self.yield_every_n == 0:
cpu_toks = torch.tensor(self.gpu_tok_batch).tolist()
self.gpu_tok_batch = []
return (
"IMAGE",
cpu_toks,
torch.tensor(self.image_builder).tolist(),
False,
None,
)
class ChameleonForwardMixin:
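    """Low-level streaming generation primitives.

    These methods operate on already-tokenized prompts and yield raw token ids;
    the public ``ChameleonLocalGenerator`` methods wrap them with tokenization,
    decoding, and context-length checks.
    """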
@torch.inference_mode()
def _generate_text_streaming(
self,
input_ids: list[int],
max_gen_tokens: int = 256,
temp: float = 1.0,
top_p: float = 0.8,
repetition_penalty: float = 1.2,
seed: int | None = None,
) -> Generator[str, None, None]:
if seed is not None:
set_seed(seed)
logger.info(
"Rank: %s, set seed: %s",
get_rank(),
seed,
)
logits_processors = [
# Only allow text tokens and end-of-sequence.
AllowOnlyTokensLogitsProcessor(
self.vocab.text_tokens + [self.vocab.eos_id]
),
# Don't allow the first token to be end-of-sequence.
# DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
TemperatureLogitsWarper(temp),
TopPLogitsWarper(top_p),
]
stopping_criteria = [
StopOnEOS(self.vocab.eos_id),
MaxLengthCriteria(len(input_ids) + max_gen_tokens),
]
if self.additional_eos_tokens is not None:
for token in self.additional_eos_tokens:
stopping_criteria.append(
StopOnEOSAfterBatchIndex(
self.tokenizer.token_to_id(token), [len(input_ids)]
)
)
for tok in ChameleonGenerator(
model=ChameleonModelAdapter(
self.model,
max_seq_len=len(input_ids) + max_gen_tokens,
),
input_ids=[input_ids],
stopping_criteria=stopping_criteria,
logits_processors=logits_processors,
):
yield tok.tolist()
@torch.inference_mode()
def _generate_batched_text_streaming(
self,
batch: list[list[int]],
max_gen_tokens: int = 256,
temp: float = 1.0,
top_p: float = 0.8,
repetition_penalty: float = 1.2,
seed: int | None = None,
) -> Generator[list[str], None, None]:
if seed is not None:
set_seed(seed)
logits_processors = [
# Only allow text tokens and end-of-sequence.
AllowOnlyTokensLogitsProcessor(
self.vocab.text_tokens + [self.vocab.eos_id]
),
# Don't allow the first token to be end-of-sequence.
# DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
TemperatureLogitsWarper(temp),
TopPLogitsWarper(top_p),
]
max_batch_size = max(len(p) for p in batch)
stopping_criteria = [
StopOnEOS(self.vocab.eos_id),
MaxLengthCriteria(max_batch_size + max_gen_tokens),
]
if self.additional_eos_tokens is not None:
for token in self.additional_eos_tokens:
stopping_criteria.append(
StopOnEOSAfterBatchIndex(
self.tokenizer.token_to_id(token), [len(x) for x in batch]
)
)
for tok in ChameleonGenerator(
model=ChameleonModelAdapter(
self.model,
max_seq_len=max_batch_size + max_gen_tokens,
),
input_ids=batch,
stopping_criteria=stopping_criteria,
logits_processors=logits_processors,
):
yield tok.unsqueeze(1).tolist()
@torch.inference_mode()
def _generate_image_streaming(
self,
tokenized_prompt: list[int],
temp: float = 1.0,
top_p: float = 0.8,
cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
yield_every_n: int = 32,
seed: int | None = None,
) -> Generator[tuple[list[int], bool], None, None]:
if seed is not None:
set_seed(seed)
logger.info(
"Rank: %s, set seed: %s",
get_rank(),
seed,
)
decoder = ImageDecoder(
self,
tokenized_prompt,
cfg_image_weight=cfg_image_weight,
cfg_text_weight=cfg_text_weight,
temp=temp,
top_p=top_p,
yield_every_n=yield_every_n,
)
        for _, _, frontend_tokens, is_final, next_decoder in GeneratorWrapper(decoder):
            # Emit the tokens (including the final, complete image) before the
            # decoder hands off, so consumers always receive a final=True frame.
            yield torch.tensor(frontend_tokens).tolist(), is_final
            if next_decoder is not None:
                break
@torch.inference_mode()
def _generate_multimodal_streaming(
self,
input_ids: list[int],
temp: float = 1.0,
top_p: float = 0.8,
cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
yield_every_n: int = 32,
max_gen_tokens: int = 4096,
repetition_penalty: float = 1.2,
seed: int | None = None,
) -> Generator[tuple[str, list[int], bool], None, None]:
if seed is not None:
set_seed(seed)
logger.info(
"Rank: %s, set seed: %s",
get_rank(),
seed,
)
max_seq_len = min(len(input_ids) + max_gen_tokens, 4096)
gen_wrapper = GeneratorWrapper(
TextDecoder(
self,
input_ids,
temp=temp,
top_p=top_p,
max_seq_len=max_seq_len,
repetition_penalty=repetition_penalty,
)
)
for (
message_type,
cpu_toks,
frontend_tokens,
is_final,
next_decoder,
) in gen_wrapper:
input_ids.extend(cpu_toks)
if len(frontend_tokens) > 0:
yield message_type, frontend_tokens, is_final
if next_decoder is not None:
gen_wrapper.gen = next_decoder(
self,
input_ids,
temp=temp,
top_p=top_p,
max_seq_len=max_seq_len,
cfg_image_weight=cfg_image_weight,
cfg_text_weight=cfg_text_weight,
yield_every_n=yield_every_n,
repetition_penalty=repetition_penalty,
)
class ChameleonLocalGenerator(
AbstractMultimodalGenerator, ChameleonForwardMixin, ChameleonTokenizationMixin
):
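    """Runs Chameleon generation locally: text, image, and interleaved streaming.

    Wires together the transformer checkpoint, the BPE text tokenizer, and the
    VQGAN image tokenizer, and exposes batched text generation plus async
    streaming entry points used by the viewer backend.
    """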
def __init__(
self,
model_path: str,
tokenizer_path: str,
vqgan_config_path: str,
vqgan_ckpt_path: str | None = None,
additional_eos_tokens: list[str] | None = None,
) -> None:
super().__init__()
logger.info("Loading model...")
self.model = load_model(model_path)
self.additional_eos_tokens = additional_eos_tokens
        logger.info("Loading tokenizer...")
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
        with open(tokenizer_path) as f:
            self.vocab = VocabInfo(json.load(f)["model"]["vocab"])
logger.info("Loading VQGAN...")
self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)
@torch.inference_mode()
def generate_batched_text(
self,
prompts: list[MixedSequenceType],
max_gen_tokens: int = 256,
temp: float = 1.0,
top_p: float = 0.8,
repetition_penalty: float = 1.2,
seed: int | None = None,
) -> list[str]:
outputs = [""] * len(prompts)
for vals in self.generate_batched_text_streaming(
prompts,
max_gen_tokens=max_gen_tokens,
temp=temp,
top_p=top_p,
repetition_penalty=repetition_penalty,
seed=seed,
):
for idx, val in enumerate(vals):
outputs[idx] += val
return outputs
@torch.inference_mode()
def generate_batched_text_streaming(
self,
prompts: list[MixedSequenceType],
max_gen_tokens: int = 256,
temp: float = 1.0,
top_p: float = 0.8,
repetition_penalty: float = 1.2,
seed: int | None = None,
) -> Generator[list[str], None, None]:
batch = []
for prompt in prompts:
batch.append(self.tokens_from_inputs(prompt))
for tok in self._generate_batched_text_streaming(
batch,
max_gen_tokens=max_gen_tokens,
temp=temp,
top_p=top_p,
repetition_penalty=repetition_penalty,
seed=seed,
):
yield self.tokenizer.decode_batch(tok)
@torch.inference_mode()
async def generate_text_streaming(
self,
prompt: MixedSequenceType,
max_gen_tokens: int = 256,
temp: float = 1.0,
top_p: float = 0.8,
repetition_penalty: float = 1.2,
seed: int | None = None,
debug: dict | None = None,
    ) -> AsyncGenerator[str, None]:
tokenized_prompt = self.tokens_from_inputs(prompt)
if len(tokenized_prompt) > (4096 - 3):
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
return
for out in self.generate_batched_text_streaming(
[prompt],
max_gen_tokens=max_gen_tokens,
temp=temp,
top_p=top_p,
repetition_penalty=repetition_penalty,
seed=seed,
):
yield out[0]
@torch.inference_mode()
async def generate_image_streaming(
self,
prompt: MixedSequenceType,
temp: float = 1.0,
top_p: float = 0.8,
cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
yield_every_n: int = 32,
seed: int | None = None,
debug: dict | None = None,
    ) -> AsyncGenerator[StreamingImage, None]:
assert isinstance(prompt, list)
tokenized_prompt = self.tokens_from_inputs(prompt)
tokenized_prompt.append(self.vocab.begin_image)
if len(tokenized_prompt) > (4096 - 3 - 1024):
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
return
for tokens, final in self._generate_image_streaming(
tokenized_prompt,
temp=temp,
top_p=top_p,
cfg_image_weight=cfg_image_weight,
cfg_text_weight=cfg_text_weight,
yield_every_n=yield_every_n,
seed=seed,
):
yield StreamingImage(
image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
)
@torch.inference_mode()
async def generate_multimodal_streaming(
self,
prompt: MixedSequenceType,
temp: float = 1.0,
top_p: float = 0.8,
cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
yield_every_n: int = 32,
max_gen_tokens: int = 4096,
repetition_penalty: float = 1.2,
suffix_tokens: list[str] | None = None,
seed: int | None = None,
debug: dict | None = None,
    ) -> AsyncGenerator[MixedSequenceType, None]:
input_ids = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
if len(input_ids) > (4096 - 3):
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens."
return
for token_type, tokens, is_final in self._generate_multimodal_streaming(
input_ids,
temp=temp,
top_p=top_p,
cfg_image_weight=cfg_image_weight,
cfg_text_weight=cfg_text_weight,
yield_every_n=yield_every_n,
max_gen_tokens=max_gen_tokens,
repetition_penalty=repetition_penalty,
seed=seed,
):
match token_type:
case "TEXT":
yield self.tokenizer.decode(tokens)
case "IMAGE":
yield StreamingImage(
image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
final=is_final,
)
            case _:
                raise ValueError(f"Unknown token type: {token_type}")
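

# Example usage: a minimal sketch that drives the async streaming API end to end.
# The checkpoint, tokenizer, and VQGAN paths below are placeholders; point them at
# a real Chameleon/Anole release before running.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        generator = ChameleonLocalGenerator(
            model_path="./data/model",                    # placeholder
            tokenizer_path="./data/text_tokenizer.json",  # placeholder
            vqgan_config_path="./data/vqgan.yaml",        # placeholder
            vqgan_ckpt_path="./data/vqgan.ckpt",          # placeholder
        )
        async for chunk in generator.generate_multimodal_streaming(
            ["Describe and then draw a red apple on a wooden table."],
            max_gen_tokens=512,
            seed=0,
        ):
            if isinstance(chunk, StreamingImage):
                if chunk.final:
                    chunk.image.save("generated_image.png")
            else:
                print(chunk, end="", flush=True)

    asyncio.run(_demo())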