"""Send a test message."""
import argparse
import json
import time
from io import BytesIO
import cv2
from groundingdino.util.inference import annotate
import numpy as np
import requests
from PIL import Image
import base64
import torch
import torchvision.transforms.functional as F
def load_image(image_path, max_size=800):
    """Load an RGB image, downscaling so its longer side is at most *max_size*.

    Args:
        image_path: Path to the image file on disk.
        max_size: Maximum length in pixels of the longer side (default 800,
            the size the original script hard-coded for the worker).

    Returns:
        A PIL image in RGB mode; resized with aspect ratio preserved when the
        original exceeds *max_size*, otherwise returned at its native size.
    """
    img = Image.open(image_path).convert('RGB')
    w, h = img.size
    if max(h, w) > max_size:
        # Scale the longer side down to max_size, keeping the aspect ratio.
        if h > w:
            new_h = max_size
            new_w = int(w * max_size / h)
        else:
            new_w = max_size
            new_h = int(h * max_size / w)
        img = F.resize(img, (new_h, new_w))
    return img
def encode(image: "Image.Image", fmt: str = "JPEG") -> str:
    """Serialize *image* to an in-memory file and return it base64-encoded.

    Args:
        image: PIL image (any object exposing ``save(fp, format=...)``).
        fmt: Format name passed to ``Image.save`` (default ``"JPEG"``).

    Returns:
        The encoded bytes as an ASCII base64 string.
    """
    # NOTE: the original annotated this as ``Image``, which is the PIL
    # *module*, not a type; ``"Image.Image"`` is the intended class (kept as a
    # string so it is never evaluated at definition time).
    buffered = BytesIO()
    image.save(buffered, format=fmt)
    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
    return img_b64_str
def main():
    """Send one detection request to a worker and print the JSON response.

    Reads the global ``args`` namespace parsed in the ``__main__`` guard.
    When ``--worker-address`` is not given, the controller is queried to
    find a worker serving ``args.model_name``; if none is available the
    function prints a message and returns without sending a request.
    """
    model_name = args.model_name
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        # Discover a worker through the controller.
        controller_addr = args.controller_address
        ret = requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")
        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": model_name}
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")
    if worker_addr == "":
        print(f"No available workers for {model_name}")
        return
    headers = {"User-Agent": "GSAM Client"}
    if args.send_image:
        # Ship the (possibly resized) image itself, base64-encoded.
        img = load_image(args.image_path)
        img_arg = encode(img)
    else:
        # Otherwise send only the path; the worker must be able to read it.
        img_arg = args.image_path
    payload = {
        "model": model_name,
        "image": img_arg,
        "box_threshold": args.box_threshold,
        "text_threshold": args.text_threshold,
    }
    tic = time.time()
    # NOTE(review): no timeout is set, so this call can hang indefinitely if
    # the worker stalls -- consider ``requests.post(..., timeout=...)``.
    response = requests.post(
        worker_addr + "/worker_generate",
        headers=headers,
        json=payload,
    )
    toc = time.time()
    print(f"Time: {toc - tic:.3f}s")
    print("detection result:")
    print(response.json())
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Send a test detection request to a GSAM worker."
    )
    # worker parameters
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001",
        help="Base URL of the controller used to look up a worker.",
    )
    parser.add_argument(
        "--worker-address", type=str,
        help="Worker URL; when given, the controller lookup is skipped.",
    )
    parser.add_argument(
        "--model-name", type=str, default='clip',
        help="Name of the model to route the request to.",
    )
    # model parameters
    parser.add_argument(
        "--image_path", type=str, default="assets/demo2.jpg",
        help="Path to the query image.",
    )
    parser.add_argument(
        "--box_threshold", type=float, default=0.3,
        help="Detection box confidence threshold.",
    )
    parser.add_argument(
        "--text_threshold", type=float, default=0.25,
        help="Text/phrase confidence threshold.",
    )
    parser.add_argument(
        "--send_image", action="store_true",
        help="Base64-encode and send the image itself instead of its path.",
    )
    args = parser.parse_args()
    main()