from fastapi import FastAPI, UploadFile, File, Depends, HTTPException, Request
from pydantic import BaseModel
from PIL import Image, UnidentifiedImageError
import io
import os
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# Import your existing utilities and models
from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img

# Initialize FastAPI app
app = FastAPI(title="OmniParser API")
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Load models at startup (reusing your existing code)
yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
caption_model_processor = get_caption_model_processor(
    model_name="florence2", 
    model_name_or_path="weights/icon_caption_florence"
)

# Define request parameters. A Pydantic model can't be read from a JSON body
# alongside a multipart file upload, so the endpoint takes it via Depends(),
# which surfaces each field as a query parameter.
class ProcessRequest(BaseModel):
    box_threshold: float = 0.05
    iou_threshold: float = 0.1
    screen_width: int = 1920
    screen_height: int = 1080
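
# Illustrative request (hypothetical values): with Depends() below, these
# fields arrive as query parameters, e.g.
#   POST /process?box_threshold=0.1&screen_width=2560&screen_height=1440
# Fields omitted from the query string fall back to the defaults above.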

@app.post("/process")
@limiter.limit("5/minute")  # Limit to 5 requests per minute per IP
async def process_image(
    request: Request,  # slowapi's limiter requires the Request object in the signature
    file: UploadFile = File(...),
    params: ProcessRequest = Depends()
):
    # Read the image from the request, rejecting files PIL cannot decode
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")

    # Save image temporarily (reusing your existing logic)
    os.makedirs('imgs', exist_ok=True)
    temp_path = 'imgs/temp_image.png'
    image.save(temp_path)
    
    # Process image using your existing functions
    ocr_bbox_rslt, _ = check_ocr_box(
        temp_path, 
        display_img=False, 
        output_bb_format='xyxy', 
        goal_filtering=None, 
        easyocr_args={'paragraph': False, 'text_threshold': 0.9}
    )
    
    text, ocr_bbox = ocr_bbox_rslt
    
    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        temp_path,
        yolo_model,
        BOX_TRESHOLD=params.box_threshold,  # spelling matches the upstream signature
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config={
            'text_scale': 0.8,
            'text_thickness': 2,
            'text_padding': 2,
            'thickness': 2,
        },
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=params.iou_threshold
    )
    
    # Format output: label_coordinates holds normalized (x, y, w, h) boxes
    # with a top-left origin, since output_coord_in_ratio=True above
    output_text = []
    for i, (element_id, coords) in enumerate(label_coordinates.items()):
        x, y, w, h = coords
        center_x_norm = x + (w / 2)
        center_y_norm = y + (h / 2)
        screen_x = int(center_x_norm * params.screen_width)
        screen_y = int(center_y_norm * params.screen_height)
        screen_w = int(w * params.screen_width)
        screen_h = int(h * params.screen_height)
        
        element_desc = parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}"
        output_text.append({
            "description": element_desc,
            "normalized_coordinates": {
                "x": center_x_norm,
                "y": center_y_norm
            },
            "screen_coordinates": {
                "x": screen_x,
                "y": screen_y
            },
            "dimensions": {
                "width": screen_w,
                "height": screen_h
            }
        })
    
    return {
        "processed_image": dino_labled_img,  # Base64 encoded image
        "elements": output_text
    }
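
# A minimal way to run the server locally, assuming uvicorn is installed;
# host and port are illustrative defaults, not part of the original snippet:
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call (hypothetical file name and server address):
#   curl -X POST "http://localhost:8000/process?box_threshold=0.05" \
#        -F "file=@screenshot.png"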