|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image, ImageEnhance, ImageDraw |
|
import torch |
|
import streamlit as st |
|
from model.inference_cpu import inference_case |
|
|
|
# Default canvas drawing: a single rectangle in what appears to be Fabric.js
# object-JSON format (version "4.4.0"), presumably fed to a drawable canvas as
# its initial state — confirm against the caller.
# The shape is a 100x100 square at (left=50, top=50) with a semi-transparent
# orange fill and a 3px blue (#2909F1) outline.
initial_rectangle = {
    "version": "4.4.0",
    "objects": [
        dict(
            type="rect",
            version="4.4.0",
            # Anchor and geometry (canvas coordinates).
            originX="left",
            originY="top",
            left=50,
            top=50,
            width=100,
            height=100,
            # Fill and outline.
            fill="rgba(255, 165, 0, 0.3)",
            stroke="#2909F1",
            strokeWidth=3,
            strokeDashArray=None,
            strokeLineCap="butt",
            strokeDashOffset=0,
            strokeLineJoin="miter",
            strokeUniform=True,
            strokeMiterLimit=4,
            # Transform: unscaled, unrotated, unflipped.
            scaleX=1,
            scaleY=1,
            angle=0,
            flipX=False,
            flipY=False,
            # Rendering options.
            opacity=1,
            shadow=None,
            visible=True,
            backgroundColor="",
            fillRule="nonzero",
            paintFirst="fill",
            globalCompositeOperation="source-over",
            # No skew and square corners (rx/ry are the corner radii).
            skewX=0,
            skewY=0,
            rx=0,
            ry=0,
        )
    ],
}
|
|
|
def run():
    """Run the segmentation model on the current data item and store the result.

    Reads the input volumes and the prompt toggles from ``st.session_state``,
    converts the enabled point/box prompts with ``reflect_points_into_model`` /
    ``reflect_box_into_model``, and writes the return value of
    ``inference_case`` to ``st.session_state.preds_3D``. Prompts that are
    disabled (or, for points, empty) are passed as ``None``.
    """
    state = st.session_state
    volumes = state.data_item
    image = volumes["image"].float()
    image_zoom_out = volumes["zoom_out_image"].float()

    text_prompt = state.text_prompt if state.use_text_prompt else None

    has_points = state.use_point_prompt and len(state.points) > 0
    point_prompt = reflect_points_into_model(state.points) if has_points else None

    box_prompt = (
        reflect_box_into_model(state.rectangle_3Dbox) if state.use_box_prompt else None
    )

    # NOTE(review): ``inference_case`` looks like a Streamlit-cached function
    # (it has ``.clear()`` and takes underscore-prefixed, presumably unhashed
    # arguments); clearing it appears intended to force a fresh run with the
    # current prompts — confirm.
    inference_case.clear()
    state.preds_3D = inference_case(
        image,
        image_zoom_out,
        text_prompt=text_prompt,
        _point_prompt=point_prompt,
        _box_prompt=box_prompt,
    )
|
|
|
def reflect_box_into_model(box_3d, *, display_size=325.0, model_xy_size=256.0, model_z_size=32.0):
    """Map a 3D box from display coordinates to model-input coordinates.

    Every coordinate is rescaled as ``int(value * target / display_size)``, so
    results are truncated toward zero exactly as before.

    Args:
        box_3d: Six numbers ``(z1, y1, x1, z2, y2, x2)`` in display coordinates.
        display_size: Extent of the display space the box was drawn in
            (default 325, the value previously hard-coded here).
        model_xy_size: Model input extent along y and x (default 256).
        model_z_size: Model input extent along z (default 32).

    Returns:
        A 1-D integer ``torch.Tensor`` ``[z1, y1, x1, z2, y2, x2]`` in model
        coordinates.
    """
    z1, y1, x1, z2, y2, x2 = box_3d

    def _rescale(value, target):
        # Keep the original ``value * target / source`` evaluation order so the
        # floating-point result, and therefore the int() truncation, is unchanged.
        return int(value * target / display_size)

    box_prompt = [
        _rescale(z1, model_z_size),
        _rescale(y1, model_xy_size),
        _rescale(x1, model_xy_size),
        _rescale(z2, model_z_size),
        _rescale(y2, model_xy_size),
        _rescale(x2, model_xy_size),
    ]
    return torch.tensor(np.array(box_prompt))
|
|
|
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first canvas rectangle's extent into the session's 3D box.

    Only the ``'xy'`` view is handled. There, the rectangle's top/left edges
    become ``y1``/``x1`` and its scaled bottom/right edges become ``y2``/``x2``
    of ``st.session_state.rectangle_3Dbox``, which is indexed as
    ``(z1, y1, x1, z2, y2, x2)``. Other views leave the box untouched. The
    current box is printed in either case.

    Args:
        json_data: Canvas JSON whose ``'objects'`` list starts with the
            rectangle (``top``, ``left``, ``width``, ``height``, ``scaleX``,
            ``scaleY`` are read).
        view: Name of the view the rectangle was drawn in.
    """
    if view == 'xy':
        rect = json_data['objects'][0]
        box = st.session_state.rectangle_3Dbox
        box[1] = rect['top']
        box[2] = rect['left']
        # The canvas stores unscaled width/height plus scale factors.
        box[4] = rect['top'] + rect['height'] * rect['scaleY']
        box[5] = rect['left'] + rect['width'] * rect['scaleX']

    print(st.session_state.rectangle_3Dbox)
|
|
|
def reflect_points_into_model(points, *, display_size=325.0, model_xy_size=256.0, model_z_size=32.0):
    """Map clicked points from display coordinates to model-input coordinates.

    Every coordinate is rescaled as ``int(value * target / display_size)``, so
    results are truncated toward zero exactly as before. All points get label 1.

    Args:
        points: Iterable of ``(z, y, x)`` triples in display coordinates.
        display_size: Extent of the display space the points were picked in
            (default 325, the value previously hard-coded here).
        model_xy_size: Model input extent along y and x (default 256).
        model_z_size: Model input extent along z (default 32).

    Returns:
        ``(coords, labels)``. ``coords`` is a ``torch.Tensor`` of shape
        ``(N, 3)`` holding ``[z, y, x]`` model coordinates. ``labels`` is a
        float ``torch.Tensor`` of ``N`` ones. An empty ``points`` yields shapes
        ``(0, 3)`` and ``(0,)``.
    """
    points_prompt_list = []
    for z, y, x in points:
        # Same ``value * target / source`` order as before, so truncation is unchanged.
        points_prompt_list.append([
            int(z * model_z_size / display_size),
            int(y * model_xy_size / display_size),
            int(x * model_xy_size / display_size),
        ])

    # reshape keeps an empty input 2-D ((0, 3) rather than (0,)); it is a
    # no-op for any non-empty list of triples.
    points_prompt = np.array(points_prompt_list).reshape(-1, 3)
    points_label = np.ones(points_prompt.shape[0])
    print(points_prompt, points_label)
    return (torch.tensor(points_prompt), torch.tensor(points_label))
|
|
|
def show_points(points_ax, points_label, ax):
    """Scatter one point on a Matplotlib-style axis.

    Args:
        points_ax: Sequence whose first two items are the plot x and y.
        points_label: Point label; ``0`` is drawn red, anything else blue.
        ax: Object exposing ``scatter`` (e.g. a Matplotlib ``Axes``).
    """
    if points_label == 0:
        marker_colour = 'red'
    else:
        marker_colour = 'blue'
    x_coord, y_coord = points_ax[0], points_ax[1]
    ax.scatter(x_coord, y_coord, c=marker_colour, marker='o', s=200)
|
|
|
def make_fig(image, preds, point_axs=None, current_idx=None, view=None, *, transparency=None):
    """Render a slice as a contrast-enhanced RGB image with optional overlays.

    Args:
        image: NumPy array of pixel intensities, assumed to be normalised to
            [0, 1] (TODO confirm). Values are clipped to that range before the
            8-bit conversion so out-of-range values saturate instead of
            wrapping around.
        preds: Optional 2-D mask with the same height/width as ``image``
            (``Image.blend`` requires matching sizes). Pixels equal to 1 are
            tinted yellow (red + green channels). ``None`` disables the overlay.
        point_axs: Optional iterable of ``(z, y, x)`` points to mark.
        current_idx: Index of the displayed slice; only points on it are drawn.
        view: ``'xy'`` draws points whose z equals ``current_idx`` at (x, y).
            ``'xz'`` draws points whose y equals ``current_idx`` at (x, z).
            Any other value draws no points.
        transparency: Weight of the mask overlay passed to ``Image.blend`` as
            ``alpha``. Defaults to ``st.session_state.transparency``.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    # Clip before scaling: casting an out-of-range float to uint8 wraps around
    # (e.g. 1.01 * 255 -> 1) instead of saturating.
    pixels = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    image = Image.fromarray(pixels).convert("RGB")
    image = ImageEnhance.Contrast(image).enhance(2.0)

    if preds is not None:
        mask = np.where(preds == 1, 255, 0).astype(np.uint8)
        mask = Image.merge("RGB",
                           (Image.fromarray(mask),
                            Image.fromarray(mask),
                            Image.fromarray(np.zeros_like(mask, dtype=np.uint8))))
        if transparency is None:
            transparency = st.session_state.transparency
        # ``image`` is already RGB here, so no extra convert() is needed.
        image = Image.blend(image, mask, alpha=transparency)

    if point_axs is not None:
        draw = ImageDraw.Draw(image)
        radius = 5
        for z, y, x in point_axs:
            if view == 'xy' and z == current_idx:
                cx, cy = x, y
            elif view == 'xz' and y == current_idx:
                cx, cy = x, z
            else:
                continue
            draw.ellipse((cx - radius, cy - radius, cx + radius, cy + radius), fill="blue")

    return image