# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import json
import multiprocessing
import os
import random
import sys
import threading
import time
import traceback
from functools import partial
from typing import Any, Generator, TypeVar

import redis
import redis.asyncio as async_redis
import torch
from tokenizers import Tokenizer

from chameleon.inference.image_tokenizer import ImageTokenizer
from chameleon.inference.loader import load_model
from chameleon.inference.vocab import VocabInfo
from chameleon.viewer.backend.data_types import WSMessageType
from chameleon.viewer.backend.models.abstract_model import (
    DEFAULT_IMAGE_CFG_IMAGE,
    DEFAULT_IMAGE_CFG_TEXT,
    DEFAULT_MULTIMODAL_CFG_IMAGE,
    DEFAULT_MULTIMODAL_CFG_TEXT,
    AbstractMultimodalGenerator,
    MixedSequenceType,
    StreamingImage,
)
from chameleon.viewer.backend.models.chameleon_local import (
    ChameleonForwardMixin,
    ChameleonTokenizationMixin,
)
from chameleon.viewer.backend.utils import get_logger

logger = get_logger(__name__)

START = "START"
T = TypeVar("T")


def find_any(queue_by_id: dict[str, list]) -> str | None:
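    """Return the id of any queue that currently has pending items, or None."""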
    for candidate_queue_id, candidate_queue in queue_by_id.items():
        if len(candidate_queue) > 0:
            return candidate_queue_id
    return None


class RedisQueue:
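    """Blocking multi-producer queue stored as a JSON blob in a single Redis key.

    Values are grouped into per-``queue_id`` lists inside one JSON document, and
    every read/modify/write cycle is guarded by a Redis lock named
    ``lock_for_<name>``. Values are consumed FIFO per queue id, and ``get``
    polls every ``interval`` seconds until a value is available.
    """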

    def __init__(self, redis_client: redis.Redis, name: str, interval: float = 0.1):
        self.redis_client = redis_client
        self.name = name
        self.interval = interval
        self.lock = redis.lock.Lock(redis_client, f"lock_for_{name}")

    def reset(self):
        self.redis_client.set(self.name, json.dumps({}))
        try:
            self.lock.release()
        except redis.lock.LockError:
            pass

    def size(self) -> int:
        maybe_queue_by_id = self.redis_client.get(self.name)
        if maybe_queue_by_id is None:
            return 0
        else:
            return len(json.loads(maybe_queue_by_id))

    def clear(self, queue_id: str):
        with self.lock:
            maybe_queue_by_id = self.redis_client.get(self.name)
            if maybe_queue_by_id is None:
                queue_by_id: dict[str, list] = {}
            else:
                queue_by_id: dict[str, list] = json.loads(maybe_queue_by_id)
            queue_by_id[queue_id] = []
            self.redis_client.set(self.name, json.dumps(queue_by_id))

    def put(self, queue_id: str, value: T):
        logger.debug(
            "Thread %s: Starting PUT(%s) for %s",
            threading.get_ident(),
            self.name,
            queue_id,
        )
        with self.lock:
            maybe_queue_by_id = self.redis_client.get(self.name)
            if maybe_queue_by_id is None:
                queue_by_id: dict[str, list[T]] = {}
            else:
                queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
            if queue_id not in queue_by_id:
                queue_by_id[queue_id] = []
            queue_by_id[queue_id] = [value] + queue_by_id[queue_id]
            self.redis_client.set(self.name, json.dumps(queue_by_id))
        logger.debug(
            "Thread %s: Finished PUT(%s) for %s",
            threading.get_ident(),
            self.name,
            queue_id,
        )

    def get(self, queue_id: str | None) -> tuple[str, T]:
""" | |
Get the next value in the queue. | |
if queue_id is None, will get a value from any queue | |
if queue_id is not none, will wait to get a value from a specific queue | |
""" | |
        logger.debug(
            "Thread %s: Starting GET(%s) for %s",
            threading.get_ident(),
            self.name,
            queue_id,
        )
        while True:
            with self.lock:
                # Initialization hasn't happened yet, so wait for it to happen
                maybe_queue_by_id = self.redis_client.get(self.name)
                if maybe_queue_by_id is None:
                    continue
                queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
                if queue_id is None:
                    queue_id = find_any(queue_by_id)
                # Ensure a queue_id was found or that it already existed
                if queue_id is not None and queue_id in queue_by_id:
                    queue = queue_by_id[queue_id]
                    if len(queue) == 0:
                        continue
                    value = queue.pop(-1)
                    # queue is mutated and queue_by_id references it, so this works
                    self.redis_client.set(self.name, json.dumps(queue_by_id))
                    logger.debug(
                        "Thread %s: Finished GET(%s) for %s",
                        threading.get_ident(),
                        self.name,
                        queue_id,
                    )
                    return queue_id, value
            time.sleep(self.interval)


class AsyncRedisQueue:
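    """Async counterpart of ``RedisQueue`` built on ``redis.asyncio``.

    Same JSON-in-one-key layout and lock protocol, but every operation is a
    coroutine and waiting uses ``asyncio.sleep`` instead of blocking the thread.
    """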

    def __init__(
        self, redis_client: async_redis.Redis, name: str, interval: float = 0.1
    ) -> None:
        self.redis_client = redis_client
        self.name = name
        self.interval = interval
        self.lock = async_redis.lock.Lock(redis_client, f"lock_for_{name}")

    async def reset(self):
        await self.redis_client.set(self.name, json.dumps({}))
        try:
            await self.lock.release()
        except async_redis.lock.LockError:
            pass

    async def size(self) -> int:
        maybe_queue_by_id = await self.redis_client.get(self.name)
        if maybe_queue_by_id is None:
            return 0
        else:
            return len(json.loads(maybe_queue_by_id))

    async def clear(self, queue_id: str):
        logger.debug(
            "ASYNC Thread %s: Starting CLEAR(%s) for %s",
            threading.get_ident(),
            self.name,
            queue_id,
        )
        async with self.lock:
            maybe_queue_by_id = await self.redis_client.get(self.name)
            if maybe_queue_by_id is None:
                queue_by_id: dict[str, list] = {}
            else:
                queue_by_id: dict[str, list] = json.loads(maybe_queue_by_id)
            queue_by_id[queue_id] = []
            await self.redis_client.set(self.name, json.dumps(queue_by_id))
        logger.debug(
            "ASYNC Thread %s: Finished CLEAR(%s) for %s",
            threading.get_ident(),
            self.name,
            queue_id,
        )

    async def put(self, queue_id: str, value: T):
        logger.debug(
            "ASYNC Thread %s: Starting PUT(%s) for %s",
            threading.get_ident(),
            self.name,
            queue_id,
        )
        async with self.lock:
            maybe_queue_by_id = await self.redis_client.get(self.name)
            if maybe_queue_by_id is None:
                queue_by_id: dict[str, list[T]] = {}
            else:
                queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
            if queue_id not in queue_by_id:
                queue_by_id[queue_id] = []
            queue_by_id[queue_id] = [value] + queue_by_id[queue_id]
            await self.redis_client.set(self.name, json.dumps(queue_by_id))
        logger.debug(
            "ASYNC Thread %s: Finished PUT(%s) for %s",
            threading.get_ident(),
            self.name,
            queue_id,
        )

    async def get(self, queue_id: str | None):
""" | |
Get the next value in the queue. | |
if queue_id is None, will get a value from any queue | |
if queue_id is not none, will wait to get a value from a specific queue | |
""" | |
        logger.debug(
            "ASYNC Thread %s: Starting GET(%s) for %s",
            threading.get_ident(),
            self.name,
            queue_id,
        )
        while True:
            async with self.lock:
                maybe_queue_by_id = await self.redis_client.get(self.name)
                if maybe_queue_by_id is None:
                    continue
                queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
                if queue_id is None:
                    queue_id = find_any(queue_by_id)
                # Ensure a queue_id was found or that it already existed
                if queue_id is not None and queue_id in queue_by_id:
                    queue: list = queue_by_id[queue_id]
                    if len(queue) == 0:
                        continue
                    value = queue.pop(-1)
                    # queue is mutated and queue_by_id references it, so this works
                    await self.redis_client.set(self.name, json.dumps(queue_by_id))
                    logger.debug(
                        "ASYNC Thread %s: Finished GET(%s) for %s",
                        threading.get_ident(),
                        self.name,
                        queue_id,
                    )
                    return queue_id, value
            await asyncio.sleep(self.interval)


class AsyncRedisCounter:
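    """Simple lock-guarded integer counter stored under a single Redis key."""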

    def __init__(self, redis_client: async_redis.Redis, name: str) -> None:
        self.redis_client = redis_client
        self.name = name
        self.lock = async_redis.lock.Lock(redis_client, f"lock_for_{name}")

    async def reset(self) -> None:
        try:
            await self.lock.release()
        except async_redis.lock.LockError:
            pass
        await self.redis_client.set(self.name, 0)

    async def add(self, n: int) -> int:
        async with self.lock:
            current_val = await self.redis_client.get(self.name)
            if current_val is None:
                current_val = 0
            else:
                current_val = int(current_val)
            new_val = current_val + n
            await self.redis_client.set(self.name, new_val)
            return new_val

    async def sub(self, n: int) -> int:
        async with self.lock:
            current_val = await self.redis_client.get(self.name)
            if current_val is None:
                raise ValueError("Cannot decrement a counter that does not exist")
            current_val = int(current_val)
            if current_val <= 0:
                raise ValueError("Cannot decrement a counter that is already zero")
            new_val = current_val - n
            await self.redis_client.set(self.name, new_val)
            return new_val

    async def count(self) -> int:
        value = await self.redis_client.get(self.name)
        if value is None:
            return 0
        else:
            return int(value)


def distributed_workers(
    model_args: dict,
    master_address: str,
    master_port: str,
    world_size: int,
    rank: int,
    redis_port: int,
    worker_queues: dict[int, multiprocessing.Queue],
) -> None:
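    """Per-rank worker loop for model-parallel generation.

    Every rank connects to Redis and loads its model shard. Rank 0 (the
    coordinator) pops requests from the "request" queue and wakes the other
    ranks through their ``multiprocessing.Queue``; all ranks then receive the
    request via ``torch.distributed.broadcast_object_list`` and run generation
    together. Only rank 0 streams results into the "response" queue.
    """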
    redis_client = redis.Redis("redis", redis_port)
    request_queue = RedisQueue(redis_client, "request")
    response_queue = RedisQueue(redis_client, "response")
    os.environ["MASTER_ADDR"] = master_address
    os.environ["MASTER_PORT"] = str(master_port)
    torch.set_default_tensor_type("torch.cuda.FloatTensor")
    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
    assert rank == torch.distributed.get_rank()
    torch.cuda.set_device(rank)
    is_coord = rank == 0
    worker = ChameleonWorker(
        rank=rank,
        model_path=model_args["model_path"],
        tokenizer_path=model_args["tokenizer_path"],
        additional_eos_tokens=model_args["additional_eos_tokens"],
    )
    worker_id = id(worker)
    logger.info("Rank %s, master_port=%s worker=%s", rank, master_port, worker_id)
    step = 0
    while True:
        step += 1
        redis_client.set(f"status_rank_{rank}", "Pre-coordinator sync")
        if is_coord:
            distributed_objs = [request_queue.get(None)]
            logger.info("Objects from queue: %s", distributed_objs)
            for worker_rank in range(1, world_size):
                worker_message = {"message": START, "src": rank, "dst": worker_rank}
                logger.info("Rank %s Sending: %s", rank, worker_message)
                worker_queues[worker_rank].put(worker_message)
        else:
            distributed_objs = [None]
            logger.info("Rank %s worker %s waiting for rank 0", rank, worker_id)
            message_from_rank_0 = worker_queues[rank].get()
            logger.info(
                "Received message from rank 0 in rank %s: %s", rank, message_from_rank_0
            )
            if message_from_rank_0["message"] != START:
                raise ValueError(
                    f"Unexpected message from rank 0: {message_from_rank_0['message']}"
                )
        redis_client.set(f"status_rank_{rank}", "Post-coordinator sync")
        try:
            logger.info(
                "Broadcast Starting: Rank %s, worker %s, step %s",
                rank,
                worker_id,
                step,
            )
            redis_client.set(f"status_rank_{rank}", "Pre-torch sync")
            torch.distributed.broadcast_object_list(distributed_objs, src=0)
            redis_client.set(f"status_rank_{rank}", "Post-torch sync")
            logger.info(
                "Broadcast Complete: Rank %s, worker %s, step %s",
                rank,
                worker_id,
                step,
            )
        except RuntimeError as e:
            logger.error(
                "Rank %s, worker %s, step %s, Error detected in torch broadcast: %s",
                rank,
                worker_id,
                step,
                str(e),
            )
            raise
        logger.info("rank %s, objs %s", rank, distributed_objs)
        queue_id, data = distributed_objs[0]
        mode = data.pop("mode")
        request_id = data.pop("request_id")
        assert queue_id == request_id
        tokenized_prompt = data.pop("tokenized_prompt")
        try:
            match mode:
                case WSMessageType.GENERATE_TEXT:
                    generator_fn = partial(
                        worker._generate_text_streaming, tokenized_prompt, **data
                    )
                case WSMessageType.GENERATE_IMAGE:
                    generator_fn = partial(
                        worker._generate_image_streaming, tokenized_prompt, **data
                    )
                case WSMessageType.GENERATE_MULTIMODAL:
                    generator_fn = partial(
                        worker._generate_multimodal_streaming, tokenized_prompt, **data
                    )
                case _:
                    logger.error(
                        "Encountered unknown mode, crashing the program: %s", mode
                    )
                    response_queue.put(
                        queue_id, {"error": True, "final": True, "message": mode}
                    )
                    raise ValueError("Unknown mode")
            logger.info("Rank: %s, Processing request: %s", rank, request_id)
            i = 0
            redis_client.set(f"status_rank_{rank}", "Pre-generate")
            for output in generator_fn():
                i += 1
                if is_coord:
                    response = {"final": False, "output": output, "error": False}
                    logger.info(
                        "Rank: %s, Adding to response queue: %.100s",
                        rank,
                        response,
                    )
                    redis_client.set(f"status_rank_{rank}", f"Generate Pre Put {i}")
                    response_queue.put(queue_id, response)
                    redis_client.set(f"status_rank_{rank}", f"Generate Post Put {i}")
                else:
                    redis_client.set(f"status_rank_{rank}", f"Generate {i}")
                redis_client.set(f"step_on_rank_{rank}", i)
            redis_client.set(f"status_rank_{rank}", "Post-generate")
            if is_coord:
                logger.info("Rank: %s, Adding final result to output queue", rank)
                response_queue.put(queue_id, {"final": True, "error": False})
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Encountered OOM, crashing the program: %s", e)
            response_queue.put(
                queue_id, {"error": True, "final": True, "message": str(e)}
            )
            crash_program()
        except RuntimeError as e:
            message = str(e)
            if "CUDA" in message:
                logger.error("Encountered CUDA error, crashing the program: %s", e)
                response_queue.put(
                    queue_id, {"error": True, "final": True, "message": str(e)}
                )
                crash_program()
            else:
                logger.error(
                    "Encountered unexpected runtime error, crashing the program: %s %s",
                    e,
                    traceback.format_exc(),
                )
                response_queue.put(
                    queue_id, {"error": True, "final": True, "message": str(e)}
                )
                crash_program()
        except Exception as e:
            logger.error(
                "Encountered unexpected exception: %s %s",
                str(e),
                traceback.format_exc(),
            )
            response_queue.put(
                queue_id, {"error": True, "final": True, "message": str(e)}
            )
            crash_program()


class ChameleonWorker(ChameleonForwardMixin):
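    """Holds the per-rank model shard, text tokenizer, and vocab used by a worker process."""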

    def __init__(
        self,
        *,
        rank: int,
        model_path: str,
        tokenizer_path: str,
        additional_eos_tokens: list[str] | None,
    ) -> None:
        self.rank = rank
        self.model_path = model_path
        self.additional_eos_tokens = additional_eos_tokens
        torch.set_default_device(f"cuda:{rank}")
        self.model = load_model(model_path, rank)
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
        self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
        logger.info(
            "Rank: %s, Model loaded in worker_obj: %s",
            rank,
            id(self),
        )


def crash_program() -> None:
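    """Terminate the current process with a non-zero exit code."""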
    logger.error(
        "Crashing the program as instructed, likely due to distributed worker failures"
    )
    sys.exit(1)


class ChameleonDistributedGenerator(
    AbstractMultimodalGenerator, ChameleonTokenizationMixin
):
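    """Multimodal generator that fans requests out to model-parallel worker processes.

    The constructor spawns one worker process per rank and talks to them
    through the Redis-backed "request" and "response" queues. The async
    ``generate_*_streaming`` methods tokenize the prompt, enqueue a request,
    and stream decoded text and images back to the caller as results arrive.
    """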

    def __init__(
        self,
        *,
        world_size: int,
        model_path: str,
        master_port: int,
        tokenizer_path: str,
        vqgan_config_path: str,
        vqgan_ckpt_path: str | None = None,
        master_address: str = "0.0.0.0",
        additional_eos_tokens: list[str] | None = None,
        redis_port: int | None = None,
    ) -> None:
        self.master_port = master_port
        self.master_address = master_address
        self.additional_eos_tokens = additional_eos_tokens
        logger.info("Loading tokenizer...")
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
        self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
        logger.info("Loading VQGAN...")
        self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)
        self.redis_port = redis_port
        self.redis_pool = async_redis.ConnectionPool.from_url(
            f"redis://redis:{redis_port}"
        )
        self.redis_client = async_redis.Redis.from_pool(self.redis_pool)
        self.request_queue = AsyncRedisQueue(self.redis_client, "request")
        self.response_queue = AsyncRedisQueue(self.redis_client, "response")
        self.worker_queues: dict[int, multiprocessing.Queue] = {
            rank: multiprocessing.Queue() for rank in range(world_size)
        }
        self.procs: list[multiprocessing.Process] = []
        model_args = {
            "model_path": model_path,
            "master_address": master_address,
            "master_port": master_port,
            "tokenizer_path": tokenizer_path,
            "additional_eos_tokens": additional_eos_tokens,
        }
logger.info("Launching paralle model with world_size=%s", world_size) | |
        for i in range(world_size):
            proc = multiprocessing.Process(
                target=distributed_workers,
                args=(
                    model_args,
                    master_address,
                    master_port,
                    world_size,
                    i,
                    self.redis_port,
                    self.worker_queues,
                ),
                daemon=True,
            )
            self.procs.append(proc)
            proc.start()

    def check_error(self, output: dict) -> None:
        if output["error"]:
            print(f"check_error({output})", file=sys.stderr)
            self.kill_procs()
            logger.error(
                "COORDINATOR: Encountered error in managed processes, exiting: %s",
                output,
            )
            crash_program()

    def __del__(self) -> None:
        self.kill_procs(error=False)

    def kill_procs(self, error: bool = True) -> None:
        if error:
            log_fn = logger.error
        else:
            log_fn = logger.info
        log_fn("Killing worker procs: %s", self.procs)
        for p in self.procs:
            try:
                log_fn("Killing: %s", p)
                p.kill()
            except Exception:
                log_fn("Encountered issue killing process and ignoring: %s", p)

    # ALLOW_ANY(get_next_output.return)
    async def get_next_output(self, request_id: str) -> Any:
        logger.info("Waiting for response for request_id=%s", request_id)
        queue_id, output = await self.response_queue.get(request_id)
        assert queue_id == request_id
        return output

    async def generate_text_streaming(
        self,
        prompt: MixedSequenceType,
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> Generator[str, None, None]:
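        """Tokenize ``prompt``, enqueue a GENERATE_TEXT request, and yield decoded text chunks."""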
        tokenized_prompt = self.tokens_from_inputs(prompt)
        request_id = f"request_{random.randint(100_000, 200_000)}"
        if seed is None:
            seed = random.randint(1, 2048)
        if debug is not None:
            debug["seed"] = seed
        if len(tokenized_prompt) > (4096 - 3):
            yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
            return
        assert not isinstance(tokenized_prompt, torch.Tensor)
        request = {
            "mode": WSMessageType.GENERATE_TEXT.value,
            "request_id": request_id,
            "tokenized_prompt": tokenized_prompt,
            "max_gen_tokens": max_gen_tokens,
            "temp": temp,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "seed": seed,
        }
        logger.info(
            "Sending request_id=%s: %s",
            request_id,
            request,
        )
        await asyncio.gather(
            self.request_queue.clear(request_id),
            self.response_queue.clear(request_id),
        )
        logger.info("Cleared request/response queue for %s", request_id)
        await self.request_queue.put(request_id, request)
        logger.info("Sent request to coordinator %s", request_id)
        try:
            while True:
                output = await self.get_next_output(request_id)
                logger.info("Received response for %s", request_id)
                self.check_error(output)
                if output["final"]:
                    break
                n_outs = len(output["output"])
                if n_outs != 1:
                    logger.error(
                        "Encountered unexpected number of outputs (%s) in: %s",
                        n_outs,
                        output["output"],
                    )
                tokens = output["output"]
                assert not isinstance(tokens, torch.Tensor)
                logger.info("output info: type=%s, value=%.20s", type(tokens), tokens)
                yield self.tokenizer.decode(tokens)
        finally:
            logger.info("Cleaning up queues in request_id=%s", request_id)
            await asyncio.gather(
                self.request_queue.clear(request_id),
                self.response_queue.clear(request_id),
            )
            logger.info("Completed cleaning for request_id=%s", request_id)

    async def generate_image_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
        yield_every_n: int = 32,
        debug: dict | None = None,
        seed: int | None = None,
    ) -> Generator[StreamingImage, None, None]:
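        """Tokenize ``prompt``, enqueue a GENERATE_IMAGE request, and yield partial images as they decode."""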
        tokenized_prompt = self.tokens_from_inputs(prompt)
        tokenized_prompt.append(self.vocab.begin_image)
        assert not isinstance(tokenized_prompt, torch.Tensor)
        request_id = f"request_{random.randint(100_000, 200_000)}"
        if seed is None:
            seed = random.randint(1, 2048)
        if debug is not None:
            debug["seed"] = seed
        if len(tokenized_prompt) > (4096 - 3 - 1024):
            yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
            return
        request = {
            "mode": WSMessageType.GENERATE_IMAGE.value,
            "request_id": request_id,
            "tokenized_prompt": tokenized_prompt,
            "cfg_image_weight": cfg_image_weight,
            "cfg_text_weight": cfg_text_weight,
            "yield_every_n": yield_every_n,
            "temp": temp,
            "top_p": top_p,
            "seed": seed,
        }
        logger.info(
            "Sending request_id=%s: %s",
            request_id,
            request,
        )
        await asyncio.gather(
            self.request_queue.clear(request_id),
            self.response_queue.clear(request_id),
        )
        logger.info("Cleared request/response queue for %s", request_id)
        await self.request_queue.put(request_id, request)
        logger.info("Sent request to coordinator %s", request_id)
        try:
            while True:
                output = await self.get_next_output(request_id)
                logger.info("Received response for %s", request_id)
                self.check_error(output)
                if output["final"]:
                    break
                n_outs = len(output["output"])
                if n_outs != 2:
                    logger.error(
                        "Encountered unexpected number of outputs (%s) in: %s",
                        n_outs,
                        output["output"],
                    )
                tokens, final = output["output"]
                assert not isinstance(tokens, torch.Tensor)
                yield StreamingImage(
                    image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
                )
        finally:
            logger.info("Cleaning up queues in request_id=%s", request_id)
            await asyncio.gather(
                self.request_queue.clear(request_id),
                self.response_queue.clear(request_id),
            )
            logger.info("Completed cleaning for request_id=%s", request_id)

    async def generate_multimodal_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
        yield_every_n: int = 32,
        max_gen_tokens: int = 4096,
        repetition_penalty: float = 1.2,
        suffix_tokens: list[str] | None = None,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> Generator[MixedSequenceType, None, None]:
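        """Tokenize ``prompt``, enqueue a GENERATE_MULTIMODAL request, and yield interleaved text and images."""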
        tokenized_prompt = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
        assert not isinstance(tokenized_prompt, torch.Tensor)
        request_id = f"request_{random.randint(100_000, 200_000)}"
        if seed is None:
            seed = random.randint(1, 2048)
        if debug is not None:
            debug["seed"] = seed
        if len(tokenized_prompt) > (4096 - 3):
            yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens."
            return
        request = {
            "mode": WSMessageType.GENERATE_MULTIMODAL.value,
            "request_id": request_id,
            "tokenized_prompt": tokenized_prompt,
            "cfg_image_weight": cfg_image_weight,
            "cfg_text_weight": cfg_text_weight,
            "repetition_penalty": repetition_penalty,
            "yield_every_n": yield_every_n,
            "max_gen_tokens": max_gen_tokens,
            "temp": temp,
            "top_p": top_p,
            "seed": seed,
        }
        logger.info(
            "Sending request_id=%s: %s",
            request_id,
            request,
        )
        await asyncio.gather(
            self.request_queue.clear(request_id),
            self.response_queue.clear(request_id),
        )
        logger.info("Cleared request/response queue for %s", request_id)
        await self.request_queue.put(request_id, request)
        logger.info("Sent request to coordinator %s", request_id)
        try:
            while True:
                output = await self.get_next_output(request_id)
                logger.info("Received response for %s", request_id)
                self.check_error(output)
                if output["final"]:
                    break
                n_outs = len(output["output"])
                if n_outs != 3:
                    logger.error(
                        "Encountered unexpected number of outputs (%s) in: %s",
                        n_outs,
                        output["output"],
                    )
                token_type, tokens, image_is_final = output["output"]
                assert not isinstance(tokens, torch.Tensor)
                match token_type:
                    case "TEXT":
                        yield self.tokenizer.decode(tokens)
                    case "IMAGE":
                        yield StreamingImage(
                            image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
                            final=image_is_final,
                        )
                    case _:
                        raise ValueError("Unknown token type")
        finally:
            logger.info("Cleaning up queues in request_id=%s", request_id)
            await self.request_queue.clear(request_id)
            await self.response_queue.clear(request_id)
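

# A minimal usage sketch (not part of the original module): it shows how this
# generator is typically driven, assuming a running Redis instance reachable
# at host "redis" and placeholder checkpoint/tokenizer paths. All paths,
# ports, and the world size below are illustrative assumptions, not values
# defined by this file.
#
#     generator = ChameleonDistributedGenerator(
#         world_size=2,
#         model_path="/checkpoints/chameleon/7b",        # hypothetical path
#         master_port=29500,
#         tokenizer_path="/checkpoints/tokenizer.json",  # hypothetical path
#         vqgan_config_path="/checkpoints/vqgan.yaml",   # hypothetical path
#         vqgan_ckpt_path="/checkpoints/vqgan.ckpt",     # hypothetical path
#         redis_port=6379,
#     )
#
#     async def demo() -> None:
#         # generate_text_streaming is an async generator, so iterate with async for
#         async for chunk in generator.generate_text_streaming(["Describe a cat."]):
#             print(chunk, end="", flush=True)
#
#     asyncio.run(demo())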