import os
from functools import partial
from typing import List

import cv2
import numpy as np
import torch
from PIL import Image
from pycocotools import mask as mask_utils
from transformers import AutoTokenizer, CLIPImageProcessor

from seagull.constants import IMAGE_TOKEN_INDEX
from seagull.conversation import conv_templates, SeparatorStyle
from seagull.mm_utils import tokenizer_image_token
from seagull.model.language_model.seagull_llama import SeagullLlamaForCausalLM
from seagull.train.train import DataArguments
from seagull.utils import disable_torch_init


class Seagull:
    def __init__(self, model_path, device='cuda'):
        disable_torch_init()
        model_path = os.path.expanduser(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, model_max_length=2048, padding_side="right", use_fast=True)
        self.model = SeagullLlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16).to(device)
        self.tokenizer.pad_token = self.tokenizer.unk_token

        # CLIP preprocessing pinned to a 512x512 input resolution.
        self.image_processor = CLIPImageProcessor(
            do_resize=True,
            size={"shortest_edge": 512},
            resample=3,  # PIL bicubic
            do_center_crop=True,
            crop_size={"height": 512, "width": 512},
            do_rescale=True,
            rescale_factor=0.00392156862745098,  # 1 / 255
            do_normalize=True,
            image_mean=[0.48145466, 0.4578275, 0.40821073],
            image_std=[0.26862954, 0.26130258, 0.27577711],
            do_convert_rgb=True,
        )

        # Special region tokens: <mask> carries the mask feature, <pos> the
        # position embedding. (The angle-bracket tokens had been stripped by
        # text extraction and are restored here.)
        spi_tokens = ['<mask>', '<pos>']
        self.tokenizer.add_tokens(spi_tokens, special_tokens=True)

        for m in self.model.modules():
            m.tokenizer = self.tokenizer

        vision_tower = self.model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(dtype=torch.float16, device=device)

        begin_str = ("<image>\nThis provides an overview of the image.\n "
                     "Please answer the following questions about the provided region. "
                     "Note: Distortions include: blur, colorfulness, compression, "
                     "contrast exposure and noise.\n Here is the region <mask><pos>. ")
        instruction = {
            'distortion': 'Provide the distortion type of this region.',
            'quality': 'Analyze the quality of this region.',
            'importance': 'Consider the impact of this region on the overall image '
                          'quality. Analyze its importance to the overall image quality.',
        }

        # Pre-tokenize one prompt per instruction type; <image> is mapped to
        # IMAGE_TOKEN_INDEX so the model can splice in the visual features.
        self.ids_input = {}
        for ins_type, ins in instruction.items():
            conv = conv_templates['seagull_v1'].copy()
            qs = begin_str + ins
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            self.ids_input[ins_type] = tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
                return_tensors='pt').unsqueeze(0).to(self.model.device)

        self.stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    def init_image(self, img):
        if isinstance(img, dict):
            img = img['image']
        elif isinstance(img, List):
            img = cv2.imread(img[0])
            img = img[:, :, ::-1]  # BGR -> RGB
        h_, w_ = img.shape[:2]
        # Keep the image no taller than 512 pixels, preserving aspect ratio.
        if h_ > 512:
            ratio = 512 / h_
            new_h, new_w = int(h_ * ratio), int(w_ * ratio)
            preprocessed_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        else:
            preprocessed_img = img.copy()
        return (preprocessed_img, preprocessed_img, preprocessed_img, preprocessed_img)

    def preprocess(self, img):
        image = self.image_processor.preprocess(
            img, do_center_crop=False, return_tensors='pt')['pixel_values'][0]
        image = torch.nn.functional.interpolate(
            image.unsqueeze(0), size=(512, 512), mode='bilinear',
            align_corners=False).squeeze(0)
        return image
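    # `seagull_predict` (below) expects the ROI either as an RLE `counts` payload
    # (mask_type='rle') or as an (x, y, w, h) box (mask_type='points'). A minimal
    # sketch of producing the RLE payload with pycocotools -- the mask shape and
    # region below are illustrative placeholders, not part of the original code:
    #
    #   binary = np.zeros((h, w), dtype=np.uint8)
    #   binary[y0:y1, x0:x1] = 1  # mark the region of interest
    #   rle = mask_utils.encode(np.asfortranarray(binary))
    #   answer = model.seagull_predict(img, rle['counts'], 'quality', mask_type='rle')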
    def seagull_predict(self, img, mask, instruct_type, mask_type='rle'):
        if isinstance(img, str):
            img = cv2.imread(img)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w, _ = img.shape

        if mask_type == 'rle':  # an RLE-encoded mask indicates the ROI
            compressed_rle = {'size': [h, w], 'counts': mask}
            mask = mask_utils.decode(compressed_rle)
        elif mask_type == 'points':  # an (x, y, w, h) box indicates the ROI
            x_min, y_min, w1, h1 = mask
            x_max, y_max = x_min + w1, y_min + h1
            mask = np.zeros_like(img[:, :, 0])
            mask[max(0, y_min):min(y_max, mask.shape[0]),
                 max(0, x_min):min(x_max, mask.shape[1])] = 1

        image = self.preprocess(img)

        # Derive the tight bounding box of the (non-empty) mask.
        mask = np.array(mask, dtype=np.uint8)  # np.int was removed in NumPy 1.24
        ys, xs = np.where(mask > 0)
        if len(xs) == 0 or len(ys) == 0:
            raise ValueError('The provided mask is empty; no region to analyze.')
        x_min, x_max = np.min(xs), np.max(xs)
        y_min, y_max = np.min(ys), np.max(ys)
        bounding_box = (x_min, y_min, x_max - x_min, y_max - y_min)

        mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
        mask = np.array(mask > 0.1, dtype=np.uint8)
        masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)

        input_ids = self.ids_input[instruct_type.split()[0].lower()]

        x1, y1, w1, h1 = list(map(int, bounding_box))  # x, y, w, h
        cropped_img = img[y1:y1 + h1, x1:x1 + w1]
        cropped_img = Image.fromarray(cropped_img)
        cropped_img = self.preprocess(cropped_img)

        with torch.inference_mode():
            # Temporarily bind the mask and crop as extra forward() arguments,
            # then restore the original forward after generation.
            self.model.orig_forward = self.model.forward
            self.model.forward = partial(
                self.model.orig_forward,
                img_metas=[None],
                masks=[masks.half()],
                cropped_img=cropped_img.unsqueeze(0),
            )
            output_ids = self.model.generate(
                input_ids,
                images=image.unsqueeze(0).half().to(self.model.device),
                do_sample=False,
                temperature=1,
                max_new_tokens=2048,
                use_cache=True,
                num_beams=1,
                top_k=0,
                top_p=1,
            )
            self.model.forward = self.model.orig_forward

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = self.tokenizer.batch_decode(
            output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(self.stop_str):
            outputs = outputs[:-len(self.stop_str)]
        outputs = outputs.strip()
        if ':' in outputs:
            outputs = outputs.split(':')[1]

        # Keep each sentence once and stop at the first repetition, to trim
        # degenerate repeated generations.
        outputs_list = outputs.split('.')
        outputs_list_final = []
        outputs_str = ''
        for output in outputs_list:
            if output == '':
                continue
            if output in outputs_list_final:
                break
            outputs_list_final.append(output)
            outputs_str += output + '.'
        return outputs_str
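

# A minimal usage sketch, assuming a local checkpoint and a test image; both
# paths below are hypothetical placeholders, not files shipped with this code.
if __name__ == '__main__':
    seagull = Seagull(model_path='./checkpoints/seagull-7b')  # hypothetical path
    # Rate the quality of a 200x150 box whose top-left corner is at (150, 100).
    answer = seagull.seagull_predict(
        './example.jpg', (150, 100, 200, 150), 'quality', mask_type='points')
    print(answer)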