from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Depends
from pydantic import BaseModel
from PIL import Image, UnidentifiedImageError
import io
import os
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# Import your existing utilities and models
from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img

# Initialize FastAPI app and rate limiter (slowapi keys requests by client IP)
app = FastAPI(title="OmniParser API")
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Load models once at startup (reusing your existing code)
yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
caption_model_processor = get_caption_model_processor(
    model_name="florence2",
    model_name_or_path="weights/icon_caption_florence"
)


# Tunable parameters; used with Depends() below so the fields arrive as query
# parameters, which combine cleanly with the multipart file upload
class ProcessRequest(BaseModel):
    box_threshold: float = 0.05
    iou_threshold: float = 0.1
    screen_width: int = 1920
    screen_height: int = 1080


@app.post("/process")
@limiter.limit("5/minute")  # Limit to 5 requests per minute per IP
async def process_image(
    request: Request,  # required by slowapi's limiter
    file: UploadFile = File(...),
    params: ProcessRequest = Depends()
):
    # Read and validate the uploaded image
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")

    # Save image temporarily (reusing your existing logic)
    os.makedirs('imgs', exist_ok=True)
    temp_path = 'imgs/temp_image.png'
    image.save(temp_path)

    # OCR pass: recover on-screen text and its bounding boxes
    ocr_bbox_rslt, _ = check_ocr_box(
        temp_path,
        display_img=False,
        output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9}
    )
    text, ocr_bbox = ocr_bbox_rslt

    # Detection + captioning pass: label interactable elements on the screenshot
    # (BOX_TRESHOLD matches the parameter name in the existing utils)
    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        temp_path,
        yolo_model,
        BOX_TRESHOLD=params.box_threshold,
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config={
            'text_scale': 0.8,
            'text_thickness': 2,
            'text_padding': 2,
            'thickness': 2,
        },
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=params.iou_threshold
    )

    # Format output: convert each element's normalized (x, y, w, h) box into
    # center coordinates, both normalized and in target-screen pixels
    output_text = []
    for i, (element_id, coords) in enumerate(label_coordinates.items()):
        x, y, w, h = coords
        center_x_norm = x + (w / 2)
        center_y_norm = y + (h / 2)
        screen_x = int(center_x_norm * params.screen_width)
        screen_y = int(center_y_norm * params.screen_height)
        screen_w = int(w * params.screen_width)
        screen_h = int(h * params.screen_height)
        element_desc = parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}"
        output_text.append({
            "description": element_desc,
            "normalized_coordinates": {"x": center_x_norm, "y": center_y_norm},
            "screen_coordinates": {"x": screen_x, "y": screen_y},
            "dimensions": {"width": screen_w, "height": screen_h}
        })

    return {
        "processed_image": dino_labled_img,  # Base64-encoded annotated image
        "elements": output_text
    }
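

# Usage sketch (an assumption, not part of the original code): one way to
# launch and exercise this service, assuming the file is saved as api.py and
# uvicorn is installed. The port and file names are illustrative.
#
#   uvicorn api:app --host 0.0.0.0 --port 8000
#
# Example request; with Depends(), the ProcessRequest fields are passed as
# query parameters alongside the multipart upload:
#
#   curl -X POST "http://localhost:8000/process?box_threshold=0.05&screen_width=2560&screen_height=1440" \
#        -F "file=@screenshot.png"

if __name__ == "__main__":
    # Convenience entry point so the module can also be run directly
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)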