File size: 4,499 Bytes
36a599e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import modal
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
import io
import base64
from typing import Optional
import traceback

# Create the Modal app, named "ui-coordinates-finder", that the functions in this
# file are registered on, and the FastAPI application that receives the HTTP
# route and middleware defined below.
app = modal.App("ui-coordinates-finder")
web_app = FastAPI()

# Model loading, registered on the Modal app with a T4 GPU
@app.function(gpu="T4")
def init_models():
    """Load the icon detector and the icon caption model.

    The loaders are imported from the project's ``utils`` module at call time,
    so importing this file does not require ``utils`` to be importable.

    Returns:
        A 2-tuple ``(yolo_model, caption_model_processor)``: the result of
        ``get_yolo_model`` for ``weights/icon_detect/best.pt`` followed by the
        result of ``get_caption_model_processor`` for the ``florence2`` model
        stored under ``weights/icon_caption_florence``.
    """
    from utils import get_caption_model_processor, get_yolo_model

    detector_weights = 'weights/icon_detect/best.pt'
    caption_weights = "weights/icon_caption_florence"

    detector = get_yolo_model(model_path=detector_weights)
    captioner = get_caption_model_processor(
        model_name="florence2",
        model_name_or_path=caption_weights,
    )
    return detector, captioner

# Configure CORS.
# The API accepts requests from any origin. A wildcard origin must not be
# combined with credentialed requests (cookies / Authorization headers), so
# credentials are disabled; nothing in this service relies on them.
# To support credentialed requests, list the allowed origins explicitly and
# set allow_credentials=True.
web_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

def _format_elements(label_coordinates, parsed_content_list, screen_width, screen_height):
    """Pair each detected box with its description and scale it to screen pixels.

    Args:
        label_coordinates: Mapping of element id to ``(x, y, w, h)``. The values
            are treated as fractions of the image size (the caller requests
            ``output_coord_in_ratio=True``).
        parsed_content_list: Descriptions, matched to the boxes by position.
        screen_width: Target screen width in pixels.
        screen_height: Target screen height in pixels.

    Returns:
        A list of dicts with the keys ``description``, ``normalized_coords``,
        ``screen_coords`` and ``dimensions``.
    """
    elements = []
    # zip() stops at the shorter input, so a box without a description at the
    # same position is skipped.
    for coords, element_desc in zip(label_coordinates.values(), parsed_content_list):
        x, y, w, h = coords

        # Center of the box, still in normalized units
        center_x_norm = x + (w / 2)
        center_y_norm = y + (h / 2)

        elements.append({
            "description": element_desc,
            "normalized_coords": (center_x_norm, center_y_norm),
            "screen_coords": (
                int(center_x_norm * screen_width),
                int(center_y_norm * screen_height),
            ),
            "dimensions": (int(w * screen_width), int(h * screen_height)),
        })
    return elements


@app.function(gpu="T4", timeout=300)
@web_app.post("/process")
async def process_image_endpoint(
    request: Request,
    file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
    screen_width: int = 1920,
    screen_height: int = 1080
):
    """Detect UI elements in an uploaded screenshot and return their coordinates.

    The upload is decoded with PIL, written to a per-request temporary PNG, run
    through ``check_ocr_box`` and ``get_som_labeled_img`` from the project's
    ``utils`` module, and the resulting boxes are scaled to the requested
    screen size.

    Args:
        request: The incoming request (not used by the handler body).
        file: The uploaded image.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img`` as ``iou_threshold``.
        screen_width: Width in pixels used to scale normalized coordinates.
        screen_height: Height in pixels used to scale normalized coordinates.

    Returns:
        A 200 ``JSONResponse`` with ``message``, ``filename``,
        ``processed_image`` and ``elements``. Any exception is caught and
        reported as a 500 ``JSONResponse`` with ``error`` and ``details``.

    NOTE(review): ``@web_app.post`` is applied before ``@app.function``, so the
    route is presumably bound to the undecorated coroutine and the GPU/timeout
    settings may not apply to HTTP traffic - confirm against Modal's web
    endpoint documentation.
    """
    import os
    import tempfile

    try:
        # Add logging for debugging
        print(f"Processing file: {file.filename}")

        # Read the uploaded bytes
        contents = await file.read()
        print("File read successfully")

        draw_bbox_config = {
            'text_scale': 0.8,
            'text_thickness': 2,
            'text_padding': 2,
            'thickness': 2,
        }

        # Each request gets its own scratch directory. A fixed path under /tmp
        # would let concurrent requests overwrite each other's image before it
        # is read back; the directory is removed when the block exits.
        with tempfile.TemporaryDirectory() as work_dir:
            image_save_path = os.path.join(work_dir, 'saved_image_demo.png')
            with Image.open(io.BytesIO(contents)) as image:
                image.save(image_save_path)

            # Initialize models
            # NOTE(review): init_models is a Modal function object; confirm that
            # calling it directly (rather than via .local()/.remote()) is
            # supported by the Modal version in use.
            yolo_model, caption_model_processor = init_models()

            # Process with OCR and detection
            from utils import check_ocr_box, get_som_labeled_img

            ocr_bbox_rslt, _ = check_ocr_box(
                image_save_path,
                display_img=False,
                output_bb_format='xyxy',
                goal_filtering=None,
                easyocr_args={'paragraph': False, 'text_threshold': 0.9}
            )
            text, ocr_bbox = ocr_bbox_rslt

            labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image_save_path,
                yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold
            )

        # Format the output similar to Gradio demo
        output_text = _format_elements(
            label_coordinates, parsed_content_list, screen_width, screen_height
        )

        return JSONResponse(
            status_code=200,
            content={
                "message": "Success",
                "filename": file.filename,
                # presumably a base64-encoded image, per the original comment - confirm in utils
                "processed_image": labeled_img,
                "elements": output_text
            }
        )

    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error processing request: {error_details}")
        # NOTE(review): the full traceback is returned to the client; useful for
        # debugging, but consider dropping "details" for public deployments.
        return JSONResponse(
            status_code=500,
            content={
                "error": str(e),
                "details": error_details
            }
        )

@app.function()
@modal.asgi_app()
def fastapi_app():
    """Return the FastAPI application so Modal can expose it as an ASGI app.

    NOTE(review): this function is registered without a ``gpu`` argument,
    unlike the other Modal functions in this file - confirm which container
    is expected to execute the ``/process`` route.
    """
    return web_app

# Script entry point: runs only when the file is executed directly, not on import.
# NOTE(review): confirm that `serve()` is available on `modal.App` in the Modal
# version in use.
if __name__ == "__main__":
    app.serve()