import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token
from transformers.generation.streamers import TextIteratorStreamer

from PIL import Image

import requests
from io import BytesIO

from cog import BasePredictor, Input, Path, ConcatenateIterator

import time
import subprocess
from threading import Thread

import os

# NOTE(review): this assignment runs after the llava/transformers imports
# above (order kept from the original). Libraries that read the variable at
# import time would not see it -- confirm whether that matters here.
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"

# url for the weights mirror
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"

# Upper bound, in seconds, for the plain HTTP requests issued by this module,
# so a stalled connection cannot block forever.
REQUEST_TIMEOUT_SECONDS = 60

# files to download from the weights mirrors
weights = [
    {
        "dest": "liuhaotian/llava-v1.5-13b",
        # git commit hash from huggingface
        "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            "pytorch_model-00001-of-00003.bin",
            "pytorch_model-00002-of-00003.bin",
            "pytorch_model-00003-of-00003.bin",
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ]
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": [
            "config.json",
            "preprocessor_config.json",
            "pytorch_model.bin"
        ],
    }
]


def download_json(url: str, dest: Path):
    """Fetch ``url`` with ``requests`` and write the response body to ``dest``.

    The file is written only when the response status is 200 and the body is
    non-empty; otherwise a message is printed and nothing is written.

    Args:
        url: Address of the file to fetch.
        dest: Path the response body is written to (opened in binary mode).
    """
    res = requests.get(url, allow_redirects=True, timeout=REQUEST_TIMEOUT_SECONDS)
    if res.status_code == 200 and res.content:
        with dest.open("wb") as f:
            f.write(res.content)
    else:
        print(f"Failed to download {url}. Status code: {res.status_code}")


def download_weights(baseurl: str, basedest: str, files: list[str]):
    """Download every file in ``files`` that is not already present locally.

    ``.json`` files are fetched with :func:`download_json`; everything else is
    fetched by running the external ``pget`` command.

    Args:
        baseurl: Path below ``REPLICATE_WEIGHTS_URL`` that holds the files.
        basedest: Local directory to download into; created if missing.
        files: File names, relative to both ``baseurl`` and ``basedest``.

    Raises:
        subprocess.CalledProcessError: If ``pget`` exits with a non-zero status.
    """
    target_dir = Path(basedest)
    start = time.time()
    print("downloading to: ", target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    for filename in files:
        dest = target_dir / filename
        # URLs always use "/" as separator, independent of the host OS.
        url = "/".join((REPLICATE_WEIGHTS_URL, baseurl, filename))
        if dest.exists():
            continue
        print("downloading url: ", url)
        if dest.suffix == ".json":
            download_json(url, dest)
        else:
            subprocess.check_call(["pget", url, str(dest)], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    """Cog predictor that answers a text prompt about a single image with LLaVA v1.5 13B."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        for weight in weights:
            download_weights(weight["src"], weight["dest"], weight["files"])
        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            "liuhaotian/llava-v1.5-13b",
            model_name="llava-v1.5-13b",
            model_base=None,
            load_8bit=False,
            load_4bit=False,
        )

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(description="Prompt to use for text generation"),
        top_p: float = Input(description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", ge=0.0, le=1.0, default=1.0),
        temperature: float = Input(description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", default=0.2, ge=0.0),
        max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0),
    ) -> ConcatenateIterator[str]:
        """Run a single prediction on the model

        Builds a one-turn ``llava_v1`` conversation whose user message is the
        image token followed by ``prompt``, runs ``generate`` on a background
        thread and yields the streamed text pieces as they arrive.

        A ``temperature`` of 0 selects greedy decoding; any positive value
        enables sampling with the given ``temperature`` and ``top_p``.
        """
        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()

        image_data = load_image(str(image))
        image_tensor = self.image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().cuda()

        # just one turn, always prepend image token
        user_message = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], user_message)
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0)

        # Sampling needs a strictly positive temperature, so temperature == 0
        # (documented above as "deterministic") falls back to greedy decoding
        # and the sampling-only arguments are left out.
        do_sample = temperature > 0
        generate_kwargs = dict(
            inputs=input_ids,
            images=image_tensor,
            do_sample=do_sample,
            max_new_tokens=max_tokens,
            streamer=streamer,
            use_cache=True,
        )
        if do_sample:
            generate_kwargs.update(temperature=temperature, top_p=top_p)

        # NOTE(review): torch.inference_mode is documented as thread-local, so
        # this context presumably does not extend to the worker thread --
        # confirm that generate() disables gradient tracking on its own.
        with torch.inference_mode():
            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
            thread.start()
            # workaround: second-to-last token is always " "
            # but we want to keep it if it's not the second-to-last token
            prepend_space = False
            for new_text in streamer:
                if new_text == " ":
                    # Hold the space back until the next piece shows whether
                    # it is followed by more text or by the stop string.
                    prepend_space = True
                    continue
                if new_text.endswith(stop_str):
                    new_text = new_text[:-len(stop_str)].strip()
                    prepend_space = False
                elif prepend_space:
                    new_text = " " + new_text
                    prepend_space = False
                if new_text:
                    yield new_text
            if prepend_space:
                yield " "
            thread.join()


def load_image(image_file):
    """Open an image from an HTTP(S) URL or a local path and convert it to RGB.

    Args:
        image_file: An ``http://`` / ``https://`` URL, or a local file path.

    Returns:
        The loaded ``PIL.Image.Image`` in ``RGB`` mode.

    Raises:
        requests.HTTPError: If the URL responds with an error status.
    """
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=REQUEST_TIMEOUT_SECONDS)
        # Fail on HTTP errors here rather than handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    return Image.open(image_file).convert('RGB')