# cogvlm-chat-hf / handler.py
# Marlon Wiprud
# log outputs
# 4f23a3a
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
from transformers import AutoModelForCausalLM, LlamaTokenizer
import torch
from accelerate import (
init_empty_weights,
infer_auto_device_map,
load_checkpoint_and_dispatch,
)
import os
import logging
from transformers import logging as hf_logging
# Module-level logging setup, executed once when this file is imported.
# Configure the root logger so records at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)
# Raise the transformers library's own logging verbosity to debug.
hf_logging.set_verbosity_debug()
def list_files(directory, depth, max_depth=5):
    """Recursively print the path of every entry found under ``directory``.

    Args:
        directory: Directory whose entries are listed.
        depth: Current recursion depth; top-level callers pass 0.
        max_depth: Subdirectories are only descended into while ``depth``
            is strictly below this value.

    Raises:
        OSError: Propagated from ``os.listdir`` when a directory cannot be read.
    """
    for filename in os.listdir(directory):
        entry_path = os.path.join(directory, filename)
        print(entry_path)
        # The check must use the joined path: a bare entry name would be
        # resolved against the current working directory, so regular files
        # would be mistaken for directories and os.listdir would fail on them.
        if depth < max_depth and os.path.isdir(entry_path):
            list_files(entry_path, depth + 1, max_depth)
class EndpointHandler:
    """Endpoint handler that answers a text query about an image with CogVLM.

    The tokenizer and model are loaded once in ``__init__``; each call then
    downloads the requested image, builds the conversation inputs in "vqa"
    mode, runs greedy generation and returns the decoded continuation.
    """

    MODEL_ID = "THUDM/cogvlm-chat-hf"
    TOKENIZER_ID = "lmsys/vicuna-7b-v1.5"
    DEVICE = "cuda"
    # Seconds to wait on the image host before giving up; without a timeout
    # a stalled server would block the request indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the tokenizer and the model onto the GPU.

        Args:
            path: Accepted for interface compatibility; not used, since both
                checkpoints are loaded by their hub ids.
        """
        self.tokenizer = LlamaTokenizer.from_pretrained(self.TOKENIZER_ID)
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                self.MODEL_ID,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            .to(self.DEVICE)
            .eval()
        )
        # NOTE: a multi-GPU variant was tried here and is currently disabled.
        # It built the model under accelerate's init_empty_weights(), derived
        # a device map with infer_auto_device_map() (12GiB on each of GPUs
        # 0-3 plus 180GiB of CPU memory, with
        # no_split_module_classes=["CogVLMDecoderLayer"]) and loaded the
        # weights from the local hub-cache snapshot with
        # load_checkpoint_and_dispatch().

    def _fetch_image(self, img_uri):
        """Download ``img_uri`` and return its content as an RGB PIL image."""
        with requests.get(
            img_uri,
            stream=True,
            timeout=self.REQUEST_TIMEOUT,
        ) as response:
            # Fail on HTTP errors instead of handing an error page to PIL.
            response.raise_for_status()
            # The raw stream is not content-decoded by default; enable it so
            # gzip/deflate-encoded responses still yield valid image bytes.
            response.raw.decode_content = True
            # convert() forces the pixel data to be read before the
            # connection is released by the context manager.
            return Image.open(response.raw).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> str:
        """Answer ``data["query"]`` about the image located at ``data["img_uri"]``.

        Args:
            data: Request payload with two required keys:
                ``query`` - the question text passed to the model;
                ``img_uri`` - URL of the image, fetched over HTTP.

        Returns:
            The generated tokens that follow the prompt, decoded with the
            tokenizer's default ``decode`` settings.

        Raises:
            KeyError: If ``query`` or ``img_uri`` is missing from ``data``.
            requests.RequestException: If the image download fails, times
                out or returns an HTTP error status.
        """
        query = data["query"]
        img_uri = data["img_uri"]

        image = self._fetch_image(img_uri)

        features = self.model.build_conversation_input_ids(
            self.tokenizer,
            query=query,
            history=[],
            images=[image],
            template_version="vqa",
        )  # vqa mode

        # Add a leading batch dimension to every tensor and move it to the GPU.
        inputs = {
            key: features[key].unsqueeze(0).to(self.DEVICE)
            for key in ("input_ids", "token_type_ids", "attention_mask")
        }
        inputs["images"] = [
            [features["images"][0].to(self.DEVICE).to(torch.bfloat16)]
        ]

        gen_kwargs = {"max_length": 2048, "do_sample": False}

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            print("outputs 1: ", outputs)
            # Drop the prompt tokens so only the newly generated part remains.
            outputs = outputs[:, inputs["input_ids"].shape[1]:]
            print("outputs 2: ", outputs)
            response = self.tokenizer.decode(outputs[0])

        return response
# query = "How many houses are there in this cartoon?"
# image = Image.open(
# requests.get(
# "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true", stream=True
# ).raw
# ).convert("RGB")
# inputs = model.build_conversation_input_ids(
# tokenizer, query=query, history=[], images=[image], template_version="vqa"
# ) # vqa mode
# inputs = {
# "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
# "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
# "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
# "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)]],
# }
# gen_kwargs = {"max_length": 2048, "do_sample": False}
# with torch.no_grad():
# outputs = model.generate(**inputs, **gen_kwargs)
# outputs = outputs[:, inputs["input_ids"].shape[1] :]
# print(tokenizer.decode(outputs[0]))