|
import re
import sys
from functools import wraps
from itertools import product, groupby
from typing import NamedTuple, List, Callable

import numpy as np
import torch
from numpy.linalg import norm
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
def ccw(A, B, C):
    """Orientation test for the ordered point triple A, B, C.

    Returns True when the cross product (B - A) x (C - A) is strictly
    positive, i.e. the points turn counter-clockwise in a y-up coordinate
    system. Collinear triples give False. Each point only needs ``x`` and
    ``y`` attributes.
    """
    ac_rise_times_ab_run = (C.y - A.y) * (B.x - A.x)
    ab_rise_times_ac_run = (B.y - A.y) * (C.x - A.x)
    return ac_rise_times_ab_run > ab_rise_times_ac_run
|
|
|
|
|
def intersect(A, B, C, D):
    """Do line segments AB and CD intersect?

    The segments are reported as crossing when A and B have different
    orientations relative to CD, and C and D have different orientations
    relative to AB. Orientation is decided by ``ccw``.
    """
    ab_split_by_cd = ccw(A, C, D) != ccw(B, C, D)
    if not ab_split_by_cd:
        # Short-circuit: the second pair of orientation tests is skipped.
        return ab_split_by_cd
    return ccw(A, B, C) != ccw(A, B, D)
|
|
|
|
|
class Box(NamedTuple):
    """Axis-aligned rectangle stored as a top-left corner plus a size.

    With the default zero width and height a ``Box`` doubles as a plain 2-D
    point; ``center`` and ``corners`` return such point boxes.
    """

    # Horizontal coordinate of the left edge.
    x: int
    # Vertical coordinate of the top edge.
    y: int
    # Horizontal extent; 0 for a point.
    w: int = 0
    # Vertical extent; 0 for a point.
    h: int = 0

    @property
    def left(self):
        """Coordinate of the left edge (``x``)."""
        return self.x

    @property
    def right(self):
        """Coordinate of the right edge (``x + w``)."""
        return self.x + self.w

    @property
    def top(self):
        """Coordinate of the top edge (``y``)."""
        return self.y

    @property
    def bottom(self):
        """Coordinate of the bottom edge (``y + h``)."""
        return self.y + self.h

    @property
    def center(self):
        """Centre of the box as a point ``Box``, using floor division by 2."""
        return Box(self.x + self.w // 2, self.y + self.h // 2)

    def corners(self):
        """Yield the four corners as point boxes.

        Order: (x, y), (x + w, y), (x + w, y + h), (x, y + h).
        """
        yield Box(self.x, self.y)
        yield Box(self.x + self.w, self.y)
        yield Box(self.x + self.w, self.y + self.h)
        yield Box(self.x, self.y + self.h)

    @property
    def area(self):
        """Width times height."""
        return self.w * self.h

    def intersect(self, other: "Box") -> "Box":
        """Return the overlap of this box and ``other``.

        When the boxes do not overlap along an axis, the extent along that
        axis is clamped to 0 rather than going negative.
        """
        x1 = max(self.x, other.x)
        x2 = max(x1, min(self.x + self.w, other.x + other.w))
        y1 = max(self.y, other.y)
        y2 = max(y1, min(self.y + self.h, other.y + other.h))
        return Box(x=x1, y=y1, w=x2 - x1, h=y2 - y1)

    def min_bounding(self, other: "Box") -> "Box":
        """Return the smallest box that contains both this box and ``other``."""
        points = [*self.corners(), *other.corners()]
        xs = [point.x for point in points]
        ys = [point.y for point in points]
        min_x, min_y = min(xs), min(ys)
        return Box(min_x, min_y, max(xs) - min_x, max(ys) - min_y)

    def expand(self, growth: float = .1) -> "Box":
        """Return a copy scaled by ``1 + growth`` about the box's own centre.

        Args:
            growth: Fractional increase applied to both width and height
                (0.1 makes the box 10% wider and 10% taller).

        Returns:
            A new ``Box``; its coordinates are floats because of the true
            division used to re-centre it.
        """
        factor = 1 + growth
        w = factor * self.w
        h = factor * self.h
        # Shift the origin by half of the added size so the centre stays put.
        # (The original referenced undefined min_x / min_y here.)
        return Box(self.x - (w - self.w) / 2, self.y - (h - self.h) / 2, w, h)
|
|
|
|
|
def iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Both arguments need ``x``, ``y``, ``w``, ``h`` and ``area`` attributes.

    Returns:
        Intersection area divided by union area. When the union has zero
        area (both boxes are degenerate) the result is 0.0 instead of a
        ZeroDivisionError.
    """
    x1 = max(box1.x, box2.x)
    # Clamp so that non-overlapping boxes give a zero extent, not a negative one.
    x2 = max(x1, min(box1.x + box1.w, box2.x + box2.w))
    y1 = max(box1.y, box2.y)
    y2 = max(y1, min(box1.y + box1.h, box2.y + box2.h))
    intersection_area = (x2 - x1) * (y2 - y1)
    union_area = box1.area + box2.area - intersection_area
    if union_area == 0:
        return 0.0
    return intersection_area / union_area
|
|
|
|
|
def all_equal(iterable):
    """Are all elements the same?

    Consecutive equal elements are collapsed into runs with ``groupby``;
    the answer is True exactly when there is at most one run, so an empty
    iterable also counts as all-equal.
    """
    runs = groupby(iterable)
    # Step past the first run, if there is one.
    next(runs, None)
    # groupby only yields tuples, so None can only mean "no second run".
    return next(runs, None) is None
|
|
|
|
|
class spatial:
    """A decorator that converts a predicate over boxes to a function that returns a tensor over all boxes.

    The decorated predicate takes ``arity`` boxes. The resulting function
    takes an environment (any object with a ``boxes`` sequence) and returns
    an array with ``arity`` axes, each of length ``len(env.boxes)``, whose
    entry at ``(i, j, ...)`` is ``predicate(boxes[i], boxes[j], ...)``.
    """

    def __init__(self, arity: int = 2, enforce_antisymmetry: bool = False):
        """Configure the decorator.

        Args:
            arity: Number of box arguments the predicate takes.
            enforce_antisymmetry: When True, every entry whose index tuple
                repeats a box is set to 0 without calling the predicate.
        """
        self.arity = arity
        self.enforce_antisymmetry = enforce_antisymmetry

    def __call__(self, predicate: Callable[..., float]) -> Callable[["Environment"], np.ndarray]:
        """Wrap ``predicate`` into a function from an environment to an array."""

        # wraps() keeps the predicate's name and docstring on the wrapper.
        @wraps(predicate)
        def _rel(env):
            n_boxes = len(env.boxes)
            # Every cell is assigned below, so the uninitialised memory from
            # np.empty is never exposed.
            tensor = np.empty([n_boxes] * self.arity)
            enum_boxes = list(enumerate(env.boxes))
            for combo in product(enum_boxes, repeat=self.arity):
                indices, boxes = zip(*combo)
                has_repeat = len(set(indices)) < len(indices)
                if self.enforce_antisymmetry and has_repeat:
                    tensor[indices] = 0.
                else:
                    tensor[indices] = predicate(*boxes)
            return tensor

        return _rel
|
|
|
|
|
class Environment:
    """An image, its candidate boxes, and the operations that score those boxes.

    ``filter`` scores boxes against a caption through ``executor``; the
    spatial relations (``left_of``, ``within``, ...) are built with the
    ``spatial`` decorator and are called as methods, so the instance itself
    is the environment they iterate over.
    """

    def __init__(self, image: Image.Image, boxes: List[Box], executor: "Executor" = None, freeform_boxes: bool = False, image_name: str = None, image_pth: str = None):
        """Store the image, boxes and scoring configuration.

        Args:
            image: Image the boxes refer to; ``width`` and ``height`` are read.
            boxes: Candidate boxes.
            executor: Callable invoked by ``filter`` as
                ``executor(caption, image, boxes, image_name=..., image_pth=...)``.
            freeform_boxes: When True, ``executor`` is expected to return a
                ``(scores, boxes)`` pair and ``filter`` replaces ``self.boxes``.
            image_name: Forwarded unchanged to ``executor``.
            image_pth: Forwarded unchanged to ``executor``.
        """
        self.image = image
        self.boxes = boxes
        self.executor = executor
        self.freeform_boxes = freeform_boxes
        self.image_name = image_name
        self.image_pth = image_pth

    def uniform(self) -> np.ndarray:
        """Return a uniform distribution over the boxes (empty if there are none)."""
        n_boxes = len(self.boxes)
        if n_boxes == 0:
            # Avoid 1 / 0; there is nothing to distribute mass over.
            return np.ones(0)
        return 1 / n_boxes * np.ones(n_boxes)

    def filter(self,
               caption: str,
               temperature: float = 1.,
               area_threshold: float = 0.0,
               softmax: bool = False,
               expand: float = None
               ) -> np.ndarray:
        """Return a new distribution reflecting the likelihood that `caption` describes the content of each box.

        Args:
            caption: Text passed to the executor.
            temperature: Multiplier applied to the standardised scores before
                the sigmoid; unused when ``softmax`` is True.
            area_threshold: Boxes rejected by ``filter_area`` are not scored,
                unless that would leave no box at all.
            softmax: Apply a softmax over the raw scores instead of
                standardising them and applying a sigmoid.
            expand: If given, each scored box is grown with ``Box.expand``
                before being passed to the executor.

        Returns:
            An array with one entry per box in ``self.boxes``; boxes that
            were not scored keep the value 0.

        Side effects:
            With ``freeform_boxes`` set, ``self.boxes`` is replaced by the
            boxes the executor returns.
        """
        # The area mask is consumed on the host; there is no need to move it
        # to the executor's device just to read it back element by element.
        area_mask = self.filter_area(area_threshold)
        candidate_indices = [i for i, keep in enumerate(area_mask) if keep > 0]
        boxes = [self.boxes[i] for i in candidate_indices]
        if not boxes:
            # Nothing passed the area filter: fall back to scoring every box.
            boxes = self.boxes
            candidate_indices = list(range(len(boxes)))
        if expand is not None:
            boxes = [box.expand(expand) for box in boxes]
        result_partial = self.executor(caption, self.image, boxes, image_name=self.image_name, image_pth=self.image_pth)
        if self.freeform_boxes:
            result_partial, boxes = result_partial
            # Rows are read as (x1, y1, x2, y2) corner coordinates.
            new_boxes = []
            for i in range(boxes.shape[0]):
                x1, y1, x2, y2 = (boxes[i, k].item() for k in range(4))
                new_boxes.append(Box(x=x1, y=y1, w=x2 - x1, h=y2 - y1))
            self.boxes = new_boxes
            candidate_indices = list(range(len(self.boxes)))
        result_partial = result_partial.float()
        if softmax:
            scores = result_partial.softmax(dim=-1)
        else:
            # std() of fewer than two values is NaN, which would turn the
            # whole result into NaN; treat the spread as 0 in that case.
            spread = result_partial.std() if result_partial.numel() > 1 else 0.0
            standardised = (result_partial - result_partial.mean()) / (spread + 1e-9)
            scores = (temperature * standardised).sigmoid()
        result = torch.zeros(len(self.boxes)).to(result_partial.device)
        result[candidate_indices] = scores
        return result.cpu().numpy()

    def filter_area(self, area_threshold: float) -> np.ndarray:
        """Return a 0/1 mask over the boxes based on their relative area.

        An entry is 1 when the box's area as a fraction of the image area is
        strictly greater than ``area_threshold``, and 0 otherwise.
        """
        image_area = self.image.width * self.image.height
        return np.array([1 if box.area / image_area > area_threshold else 0 for box in self.boxes])

    @spatial()
    def left_of(b1, b2):
        """Is the horizontal midpoint of b1 left of that of b2?"""
        return (b1.right+b1.left) / 2 < (b2.right+b2.left) / 2

    @spatial()
    def right_of(b1, b2):
        """Is the horizontal midpoint of b1 right of that of b2?"""
        return (b1.right+b1.left) / 2 > (b2.right+b2.left) / 2

    @spatial()
    def above(b1, b2):
        """Is the vertical midpoint of b1 smaller than that of b2?"""
        # Comparing the sums is equivalent to comparing the midpoints.
        return (b1.bottom+b1.top) < (b2.bottom+b2.top)

    @spatial()
    def below(b1, b2):
        """Is the vertical midpoint of b1 larger than that of b2?"""
        return (b1.bottom+b1.top) > (b2.bottom+b2.top)

    @spatial()
    def bigger_than(b1, b2):
        """Does b1 have a strictly larger area than b2?"""
        return b1.area > b2.area

    @spatial()
    def smaller_than(b1, b2):
        """Does b1 have a strictly smaller area than b2?"""
        return b1.area < b2.area

    @spatial(enforce_antisymmetry=False)
    def within(box1, box2):
        """Return the fraction of box1's area that lies inside box2 (0 for a zero-area box1)."""
        if box1.area == 0:
            # A degenerate box has no area to be inside anything; avoid 0 / 0.
            return 0.
        intersection = box1.intersect(box2)
        return intersection.area / box1.area

    @spatial(arity=3, enforce_antisymmetry=True)
    def between(box1, box2, box3):
        """How much of box1 lies in min bounding box over box2 and box3?

        Returned as a fraction of box1's area; 0 for a zero-area box1.
        """
        if box1.area == 0:
            return 0.
        min_bounding = box2.min_bounding(box3)
        intersect = box1.intersect(min_bounding)
        return intersect.area / box1.area
|
|