|
"""Send a test detection request to a model worker and save the annotated image."""
|
import argparse |
|
import json |
|
import time |
|
from io import BytesIO |
|
import cv2 |
|
from groundingdino.util.inference import annotate, annotate_xyxy |
|
import numpy as np |
|
|
|
|
|
import requests |
|
from PIL import Image |
|
import base64 |
|
|
|
import torch |
|
import torchvision.transforms.functional as F |
|
|
|
import pycocotools.mask as mask_util |
|
|
|
|
|
def load_image(image_path, max_size=800):
    """Load an image as RGB, shrinking it so its longer side fits ``max_size``.

    Args:
        image_path: Path or file object accepted by ``PIL.Image.open``.
        max_size: Upper bound in pixels (positive int) for the longer side.
            Images already within the bound are returned without resizing.

    Returns:
        The RGB ``PIL.Image.Image``, downscaled with its aspect ratio
        preserved when the longer side exceeded ``max_size``.

    Raises:
        ValueError: If ``max_size`` is smaller than 1.
    """
    if max_size < 1:
        raise ValueError(f"max_size must be >= 1, got {max_size}")

    # ``convert`` returns an independent copy of the pixel data, so the file
    # handle can be released as soon as the ``with`` block exits.
    with Image.open(image_path) as source:
        img = source.convert('RGB')

    w, h = img.size
    if max(h, w) > max_size:
        # Pin the longer side to max_size and scale the other proportionally.
        # The shorter side is clamped to 1 px so that extreme aspect ratios
        # cannot truncate to a zero-sized dimension.
        if h > w:
            new_h = max_size
            new_w = max(1, int(w * max_size / h))
        else:
            new_w = max_size
            new_h = max(1, int(h * max_size / w))

        # torchvision's resize takes the target size as (height, width).
        img = F.resize(img, (new_h, new_w))
    return img
|
|
|
def encode(image: Image.Image) -> str:
    """Serialize an image to JPEG and return the bytes as base64 text.

    Args:
        image: Image to encode. If its mode is one the JPEG writer cannot
            store (for example ``RGBA``, ``LA`` or ``P``), it is converted
            to ``RGB`` first instead of letting ``save`` raise.

    Returns:
        The base64-encoded JPEG data as a ``str``.
    """
    # Modes Pillow's JPEG encoder accepts directly; these are saved untouched
    # so existing output for them does not change.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()
|
|
|
def show_mask(mask, image, random_color=True):
    """Blend a translucent colored mask over an image.

    Args:
        mask: Mask whose last two dimensions are (height, width). The code
            calls ``.cpu()`` on the colored product, so a torch tensor is
            expected -- TODO confirm against callers outside this file.
        image: Array accepted by ``PIL.Image.fromarray``; used as the
            background layer.
        random_color: When true, draw a random RGB color with alpha 0.8;
            otherwise use a fixed blue with alpha 0.6.

    Returns:
        A numpy array of the RGBA alpha-composited result.
    """
    # RGBA overlay color with every component expressed in [0, 1].
    color = (
        np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )

    # Broadcast the (H, W, 1) mask against the (1, 1, 4) color to get an
    # (H, W, 4) overlay, then scale it to 8-bit channel values.
    rows, cols = mask.shape[-2:]
    overlay = mask.reshape(rows, cols, 1) * color.reshape(1, 1, -1)
    overlay_u8 = (overlay.cpu().numpy() * 255).astype(np.uint8)

    base_rgba = Image.fromarray(image).convert("RGBA")
    overlay_rgba = Image.fromarray(overlay_u8).convert("RGBA")
    blended = Image.alpha_composite(base_rgba, overlay_rgba)
    return np.array(blended)
|
|
|
def main():
    """Send one detection request to a model worker and save the visualization.

    Configuration is read from the module-level ``args`` namespace. The
    worker is either given explicitly (``args.worker_address``) or looked up
    through the controller. The worker's boxes, phrases and RLE masks are
    drawn onto the image, which is written to ``annotated_image_mask.jpg``
    in the current working directory.

    Raises:
        requests.HTTPError: If the controller lookup or the worker request
            returns an error status.
    """
    model_name = args.model_name

    if args.worker_address:
        worker_addr = args.worker_address
    else:
        controller_addr = args.controller_address
        # Best-effort refresh: the response carries nothing that is used here.
        requests.post(controller_addr + "/refresh_all_workers")

        ret = requests.post(controller_addr + "/list_models")
        ret.raise_for_status()
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": model_name}
        )
        ret.raise_for_status()
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        print(f"No available workers for {model_name}")
        return

    headers = {"User-Agent": "FastChat Client"}
    if args.send_image:
        # Ship the (possibly downscaled) image inline as base64 JPEG.
        img = load_image(args.image_path)
        img_arg = encode(img)
    else:
        # Send only the path; the worker must be able to read it itself.
        img_arg = args.image_path
        img = None
    datas = {
        "model": model_name,
        "caption": args.caption,
        "image": img_arg,
        "box_threshold": args.box_threshold,
        "text_threshold": args.text_threshold,
    }
    tic = time.time()
    response = requests.post(
        worker_addr + "/worker_generate",
        headers=headers,
        json=datas,
    )
    toc = time.time()
    print(f"Time: {toc - tic:.3f}s")

    # Surface HTTP failures here instead of as a confusing KeyError or JSON
    # decoding error further down, and parse the body only once.
    response.raise_for_status()
    res = response.json()

    print("detection result:")
    print(res)

    boxes = torch.Tensor(res["boxes"])
    logits = torch.Tensor(res["logits"])
    phrases = res["phrases"]
    if img is not None:
        image_source = np.array(img.convert("RGB"))
    else:
        # Close the file handle once the pixels have been copied out.
        with Image.open(args.image_path) as source:
            image_source = np.array(source.convert("RGB"))
    annotated_frame = annotate_xyxy(
        image_source=image_source, boxes=boxes, logits=logits, phrases=phrases
    )

    # Overlay every returned mask (COCO run-length encoded) on the frame.
    for mask_rle in res["masks_rle"]:
        mask = torch.Tensor(mask_util.decode(mask_rle))
        annotated_frame = show_mask(mask, annotated_frame)
    cv2.imwrite("annotated_image_mask.jpg", annotated_frame)
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line into the module-level
    # ``args`` that ``main`` reads, then run the request.
    parser = argparse.ArgumentParser()

    # Routing: an explicit worker address wins; otherwise the controller is
    # asked which worker serves the model.
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="grounded_sam")

    # Request payload.
    parser.add_argument("--caption", type=str, default="dogs .")
    parser.add_argument(
        "--image_path",
        type=str,
        default="/home/liushilong/code/GroundingFolder/Grounded-Segment-Anything/assets/demo2.jpg",
    )
    parser.add_argument("--box_threshold", type=float, default=0.3)
    parser.add_argument("--text_threshold", type=float, default=0.25)

    # When set, the image is sent base64-encoded instead of as a path.
    parser.add_argument("--send_image", action="store_true")

    args = parser.parse_args()

    main()
|
|