import os
import time
from datetime import datetime
from pathlib import Path
from typing import List, Union

import sglang as sgl
from PIL import Image

from utils.logger import logger

TMP_DIR = "./tmp"
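

# Millisecond-resolution timestamp, used to build unique names for temporary image files.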
def get_timestamp():
    timestamp_ns = time.time_ns()
    milliseconds = timestamp_ns // 1_000_000
    formatted_time = datetime.fromtimestamp(milliseconds / 1000).strftime("%Y-%m-%d_%H-%M-%S-%f")[:-3]
    return formatted_time
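

# Thin wrapper around an SGLang runtime serving LLaVA for visual question answering.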
class LLaVASRT:
    def __init__(self, device: str = "cuda:0", quantized: bool = True):
        self.runtime = sgl.Runtime(
            model_path="liuhaotian/llava-v1.6-vicuna-7b",
            tokenizer_path="llava-hf/llava-1.5-7b-hf",
        )
        sgl.set_default_backend(self.runtime)
        logger.info(
            f"Started the SGLang runtime for llava-v1.6-vicuna-7b with chat template: {self.runtime.endpoint.chat_template.name}. "
            "The `device` and `quantized` arguments currently have no effect."
        )
        os.makedirs(TMP_DIR, exist_ok=True)
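
    # SGLang program: feed the image and the user prompt, then generate the answer.
    # The @sgl.function decorator is what exposes .run() / .run_batch() on this function.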
    @sgl.function
    def image_qa(s, prompt: str, image: str):
        s += sgl.user(sgl.image(image) + prompt)
        s += sgl.assistant(sgl.gen("answer"))
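
    # Single prompt + image -> (answer, state); parallel lists -> (answers, batch states).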
    def __call__(self, prompt: Union[str, List[str]], image: Union[str, Image.Image, List[str]]):
        pil_input_flag = False
        if isinstance(prompt, str) and isinstance(image, (str, Image.Image)):
            if isinstance(image, Image.Image):
                pil_input_flag = True
                image_path = os.path.join(TMP_DIR, get_timestamp() + ".jpg")
                image.save(image_path)
                image = image_path
            state = self.image_qa.run(prompt=prompt, image=image, max_new_tokens=256)
            # Post-process: remove the temporary file created for PIL input.
            if pil_input_flag:
                os.remove(image)
            return state["answer"], state
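        # Batched inputs: one prompt per image, dispatched via run_batch.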
        elif isinstance(prompt, list) and isinstance(image, list):
            assert len(prompt) == len(image), "prompt and image lists must have the same length"
            if isinstance(image[0], Image.Image):
                pil_input_flag = True
                image_path = [os.path.join(TMP_DIR, get_timestamp() + f"-{i}.jpg") for i in range(len(image))]
                for i in range(len(image)):
                    image[i].save(image_path[i])
                image = image_path
            batch_query = [{"prompt": p, "image": img} for p, img in zip(prompt, image)]
            state = self.image_qa.run_batch(batch_query, max_new_tokens=256)
            # Post-process: remove the temporary files created for PIL input.
            if pil_input_flag:
                for path in image:
                    os.remove(path)
            return [s["answer"] for s in state], state
        else:
            raise ValueError("prompt and image must both be single inputs or lists of the same length.")
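
    # Best-effort cleanup: shut down the SGLang runtime when the wrapper is garbage-collected.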
    def __del__(self):
        self.runtime.shutdown()
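

# Smoke test: caption every image under demo/ with LLaVA, first batched, then one image at a time.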
if __name__ == "__main__":
    image_folder = "demo/"
    wildcard_list = ["*.jpg", "*.png"]
    image_list = []
    for wildcard in wildcard_list:
        image_list.extend([str(image_path) for image_path in Path(image_folder).glob(wildcard)])
    # SGLang needs exclusive GPU access and cannot re-initialize CUDA in a forked subprocess.
    llava_srt = LLaVASRT()
    # Batch inference.
    llava_srt_prompt = ["Please describe this image in detail."] * len(image_list)
    response, _ = llava_srt(llava_srt_prompt, image_list)
    print(response)
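    # Single-image inference, one call per image.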
    llava_srt_prompt = "Please describe this image in detail."
    for image in image_list:
        response, _ = llava_srt(llava_srt_prompt, image)
        print(image, response)