EasyAnimate/easyanimate/video_caption/utils/image_captioner_sglang.py
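"""Image captioning via an SGLang-served LLaVA model.

Wraps an SGLang runtime around llava-v1.6-vicuna-7b so that single images or batches
(given as file paths or PIL images) can be captioned through a simple callable interface.
"""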
import os
import time
from datetime import datetime
from typing import List, Union
from pathlib import Path

import sglang as sgl
from PIL import Image

from utils.logger import logger

TMP_DIR = "./tmp"

def get_timestamp():
    """Return the current time as a filename-friendly string with millisecond precision."""
    timestamp_ns = int(time.time_ns())
    milliseconds = timestamp_ns // 1000000
    formatted_time = datetime.fromtimestamp(milliseconds / 1000).strftime("%Y-%m-%d_%H-%M-%S-%f")[:-3]
    return formatted_time

class LLaVASRT:
    def __init__(self, device: str = "cuda:0", quantized: bool = True):
        self.runtime = sgl.Runtime(
            model_path="liuhaotian/llava-v1.6-vicuna-7b", tokenizer_path="llava-hf/llava-1.5-7b-hf"
        )
        sgl.set_default_backend(self.runtime)
        logger.info(
            f"Start the SGLang runtime for llava-v1.6-vicuna-7b with chat template: {self.runtime.endpoint.chat_template.name}. "
            "Input parameters device and quantized do not take effect."
        )
        if not os.path.exists(TMP_DIR):
            os.makedirs(TMP_DIR, exist_ok=True)

    # SGLang program: a single user turn containing the image and the prompt, followed by generation.
    @sgl.function
    def image_qa(s, prompt: str, image: str):
        s += sgl.user(sgl.image(image) + prompt)
        s += sgl.assistant(sgl.gen("answer"))

    def __call__(self, prompt: Union[str, List[str]], image: Union[str, Image.Image, List[str]]):
        pil_input_flag = False
        if isinstance(prompt, str) and isinstance(image, (str, Image.Image)):
            # A PIL image is first saved to a temporary file because SGLang expects an image path.
            if isinstance(image, Image.Image):
                pil_input_flag = True
                image_path = os.path.join(TMP_DIR, get_timestamp() + ".jpg")
                image.save(image_path)
                image = image_path
            state = self.image_qa.run(prompt=prompt, image=image, max_new_tokens=256)
            # Post-process: clean up the temporary image file.
            if pil_input_flag:
                os.remove(image)
            return state["answer"], state
        elif isinstance(prompt, list) and isinstance(image, list):
            assert len(prompt) == len(image)
            # PIL images are first saved to temporary files because SGLang expects image paths.
            if isinstance(image[0], Image.Image):
                pil_input_flag = True
                image_path = [os.path.join(TMP_DIR, get_timestamp() + f"-{i}.jpg") for i in range(len(image))]
                for i in range(len(image)):
                    image[i].save(image_path[i])
                image = image_path
            batch_query = [{"prompt": p, "image": img} for p, img in zip(prompt, image)]
            state = self.image_qa.run_batch(batch_query, max_new_tokens=256)
            # Post-process: clean up the temporary image files.
            if pil_input_flag:
                for i in range(len(image)):
                    os.remove(image[i])
            return [s["answer"] for s in state], state
        else:
            raise ValueError("Input prompt and image must both be strings or lists of the same length.")

    def __del__(self):
        self.runtime.shutdown()

if __name__ == "__main__":
    image_folder = "demo/"
    wildcard_list = ["*.jpg", "*.png"]
    image_list = []
    for wildcard in wildcard_list:
        image_list.extend([str(image_path) for image_path in Path(image_folder).glob(wildcard)])

    # SGLang needs the exclusive GPU and cannot re-initialize CUDA in a forked subprocess
    # (see the sketch at the end of this file).
    llava_srt = LLaVASRT()
    # Batch inference.
    llava_srt_prompt = ["Please describe this image in detail."] * len(image_list)
    response, _ = llava_srt(llava_srt_prompt, image_list)
    print(response)
    # Single-image inference.
    llava_srt_prompt = "Please describe this image in detail."
    for image in image_list:
        response, _ = llava_srt(llava_srt_prompt, image)
        print(image, response)
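
# A minimal, hypothetical sketch of the constraint noted above: CUDA state cannot be
# re-initialized in a forked child, so if captioning had to run in worker processes,
# the "spawn" start method would be required and each worker would build its own
# LLaVASRT runtime. The helper name caption_worker is illustrative only and, for spawn
# to import it, it would need to live at module level rather than under the __main__ guard.
#
#     import multiprocessing as mp
#
#     def caption_worker(paths):
#         captioner = LLaVASRT()  # create the SGLang runtime inside the spawned process
#         prompts = ["Please describe this image in detail."] * len(paths)
#         return captioner(prompts, paths)[0]
#
#     if __name__ == "__main__":
#         mp.set_start_method("spawn", force=True)
#         with mp.Pool(processes=1) as pool:
#             print(pool.map(caption_worker, [image_list]))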