import os
from time import perf_counter
from typing import Any, List, Tuple, Union

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

from inference.core.entities.requests.cogvlm import CogVLMInferenceRequest
from inference.core.entities.responses.cogvlm import CogVLMResponse
from inference.core.env import (
    API_KEY,
    COGVLM_LOAD_4BIT,
    COGVLM_LOAD_8BIT,
    COGVLM_VERSION_ID,
    MODEL_CACHE_DIR,
)
from inference.core.models.base import Model, PreprocessReturnMetadata
from inference.core.utils.image_utils import load_image_rgb

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class CogVLM(Model):
    """CogVLM multimodal (image + text) model exposed through the `inference` Model interface."""

    def __init__(self, model_id=f"cogvlm/{COGVLM_VERSION_ID}", **kwargs):
        """Resolve the model version, validate the quantization flags, and load the tokenizer and weights."""
        self.model_id = model_id
        self.endpoint = model_id
        self.api_key = API_KEY
        self.dataset_id, self.version_id = model_id.split("/")
        if COGVLM_LOAD_4BIT and COGVLM_LOAD_8BIT:
            raise ValueError(
                "Only one of the environment variables `COGVLM_LOAD_4BIT` and `COGVLM_LOAD_8BIT` can be true"
            )
        self.cache_dir = os.path.join(MODEL_CACHE_DIR, self.endpoint)
        with torch.inference_mode():
            self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
            self.model = AutoModelForCausalLM.from_pretrained(
                f"THUDM/{self.version_id}",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                load_in_4bit=COGVLM_LOAD_4BIT,
                load_in_8bit=COGVLM_LOAD_8BIT,
                cache_dir=self.cache_dir,
            ).eval()
        self.task_type = "lmm"

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[Image.Image, PreprocessReturnMetadata]:
        """Load the request image and convert it to an RGB PIL image."""
        pil_image = Image.fromarray(load_image_rgb(image))

        return pil_image, PreprocessReturnMetadata({})

    def postprocess(
        self,
        predictions: Tuple[str],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Unpack the generated text from the single-element tuple returned by `predict`."""
        return predictions[0]

    def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs):
        """Generate a text response for the given image, prompt, and optional chat history."""
        images = [image_in]
        if history is None:
            history = []
        built_inputs = self.model.build_conversation_input_ids(
            self.tokenizer, query=prompt, history=history, images=images
        )
        inputs = {
            "input_ids": built_inputs["input_ids"].unsqueeze(0).to(DEVICE),
            "token_type_ids": built_inputs["token_type_ids"].unsqueeze(0).to(DEVICE),
            "attention_mask": built_inputs["attention_mask"].unsqueeze(0).to(DEVICE),
            "images": [[built_inputs["images"][0].to(DEVICE).to(torch.float16)]],
        }
        gen_kwargs = {"max_length": 2048, "do_sample": False}

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            # Keep only the newly generated tokens, dropping the prompt portion.
            outputs = outputs[:, inputs["input_ids"].shape[1] :]
            text = self.tokenizer.decode(outputs[0])
        # Strip the trailing end-of-sequence token if present.
        if text.endswith("</s>"):
            text = text[:-4]
        return (text,)

    def infer_from_request(self, request: CogVLMInferenceRequest) -> CogVLMResponse:
        """Run inference for an HTTP request and return the response with timing information."""
        t1 = perf_counter()
        text = self.infer(**request.dict())
        response = CogVLMResponse(response=text)
        response.time = perf_counter() - t1
        return response


if __name__ == "__main__":
    # Manual entry point for local testing.
    m = CogVLM()
    m.infer()