# File size: 1,997 Bytes
# Revision: b7f49b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from mmdet.apis import init_detector, inference_detector
import gradio as gr
import cv2
import sys
import torch
import numpy as np

print('Loading model...')
# PyTorch has no 'gpu' device type; CUDA devices are addressed as 'cuda' or
# 'cuda:N'. Passing 'gpu' makes torch reject the device string at load time.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Table detector built from the local mmdet config and checkpoint files.
table_det = init_detector('model/table-det/config.py',
                          'model/table-det/model.pth', device=device)
def get_corners(points):
    """
    Order the four vertices of a quadrilateral by position.

    The points are split into a left pair and a right pair by x-coordinate.
    Within the left pair the smaller y is taken as the top corner; within the
    right pair the larger y is taken first. With image coordinates (y grows
    downward) the returned order is::

        (top-left, bottom-right, top-right, bottom-left)

    NOTE: this is not the tl/tr/br/bl order; callers must unpack accordingly.
    The left/right split is ambiguous when vertices share an x-coordinate
    (e.g. a rectangle rotated by exactly 45 degrees); the stable sort then
    keeps the input order of the tied points.

    Args:
        points: Sequence of exactly four (x, y) pairs (tuples, lists or
            array rows).

    Returns:
        tuple: The four input points in the order described above.

    Raises:
        ValueError: If ``points`` does not contain exactly four points.
    """
    points = list(points)
    if len(points) != 4:
        raise ValueError(f'expected 4 points, got {len(points)}')

    by_x = sorted(points, key=lambda p: p[0])
    # Left pair: ascending y -> [top-left, bottom-left].
    left = sorted(by_x[:2], key=lambda p: p[1])
    # Right pair: descending y -> [bottom-right, top-right].
    right = sorted(by_x[2:], key=lambda p: p[1], reverse=True)

    return (left[0], right[0], right[1], left[1])

def funct(mask_array):
    """
    Fit a four-corner outline to every contour found in a mask.

    For each contour, ``cv2.approxPolyDP`` (tolerance: 2% of the contour
    perimeter) is tried first; if it does not yield exactly four vertices,
    the corners of the contour's minimum-area rotated rectangle are used.

    Args:
        mask_array: Image passed directly to ``cv2.findContours``; presumably
            a single-channel uint8 mask -- TODO confirm against caller.

    Returns:
        list: One entry per contour, each ``[top_left, top_right,
        bottom_right, bottom_left]`` with every corner an ``[x, y]`` list.
    """
    table_bboxes = []
    # RETR_TREE returns nested contours as well, and every contour gets a box.
    contours, _hierarchy = cv2.findContours(
        mask_array, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    for cnt in contours:
        perimeter = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, 0.02 * perimeter, True)
        points = np.squeeze(approx)
        if len(points) != 4:
            # Fall back to the rotated bounding rectangle (always 4 corners).
            points = cv2.boxPoints(cv2.minAreaRect(cnt))
        # get_corners returns (top-left, bottom-right, top-right, bottom-left).
        tl, br, tr, bl = get_corners(points.tolist())
        table_bboxes.append([tl, tr, br, bl])
    return table_bboxes
        
    
def predict(image_input):
    """
    Run the table detector on an image submitted through the Gradio UI.

    Args:
        image_input: Array from ``gr.components.Image()`` (numpy output, RGB
            channel order), or ``None`` when the form is submitted without
            an image.

    Returns:
        dict: Status payload rendered by the JSON output component.
    """
    if image_input is None:
        return {'message': 'no image provided'}
    # Gradio delivers RGB arrays; mmdet treats ndarray input like an image
    # read with cv2/mmcv (BGR). Convert so channel order matches.
    image_bgr = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
    results = inference_detector(table_det, image_bgr)
    print(results)
    return {'message': 'success'}

def run(server_name="0.0.0.0", server_port=7860):
    """
    Build the Gradio interface around ``predict`` and start serving it.

    Args:
        server_name: Host/interface to bind; the default ``"0.0.0.0"``
            listens on all interfaces.
        server_port: TCP port to serve on.
    """
    demo = gr.Interface(
        fn=predict,
        inputs=gr.components.Image(),
        outputs=gr.JSON(),
    )

    demo.launch(server_name=server_name, server_port=server_port)


if __name__ == '__main__':
    # Start the demo server only when this file is executed directly.
    run()