Spaces:
Build error
Build error
File size: 3,469 Bytes
2ad48f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import contextlib
import io
import os
import tempfile

import torch
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from PIL import Image
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

# Import your existing utilities and models
from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
# Create the FastAPI application and wire up per-client rate limiting.
app = FastAPI(title="OmniParser API")

# Clients are keyed by their remote address. The limiter lives on app.state,
# which is where route decorators in this module look it up.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# Delegate over-limit responses to slowapi's stock handler.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Load both models once, at import time, so every request handled by this
# module reuses the same instances instead of reloading weights per call.
_ICON_DETECT_WEIGHTS = 'weights/icon_detect/best.pt'
_ICON_CAPTION_WEIGHTS = "weights/icon_caption_florence"

yolo_model = get_yolo_model(model_path=_ICON_DETECT_WEIGHTS)
caption_model_processor = get_caption_model_processor(
    model_name="florence2",
    model_name_or_path=_ICON_CAPTION_WEIGHTS,
)
class ProcessRequest(BaseModel):
    """Tunable parameters for a /process call.

    The two thresholds are passed through to the labeling pipeline; the
    screen size scales normalized element coordinates into screen coordinates.
    """

    # Passed to get_som_labeled_img as BOX_TRESHOLD.
    box_threshold: float = 0.05
    # Passed to get_som_labeled_img as iou_threshold.
    iou_threshold: float = 0.1
    # Screen dimensions used to scale normalized centers and box sizes.
    screen_width: int = 1920
    screen_height: int = 1080
def _decode_upload(image_bytes):
    """Decode uploaded bytes into a fully loaded PIL image.

    Args:
        image_bytes: Raw request payload.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        HTTPException: 400 when the payload is empty, not a recognizable
            image, truncated/corrupt, or trips PIL's decompression-bomb guard.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open only parses the header; load() forces a full decode so a
        # corrupt payload is rejected here rather than deep inside the models.
        image.load()
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc
    return image


def _format_elements(label_coordinates, parsed_content_list, screen_width, screen_height):
    """Build one response entry per detected element.

    Args:
        label_coordinates: Mapping of element id -> box, unpacked as
            ``(x, y, w, h)`` ratios of the image size (the pipeline is called
            with ``output_coord_in_ratio=True``); x/y are treated as the box's
            top-left corner.
        parsed_content_list: Descriptions aligned by position with
            ``label_coordinates``; missing entries fall back to ``"Icon {i}"``.
        screen_width: Width used to scale horizontal ratios.
        screen_height: Height used to scale vertical ratios.

    Returns:
        A list of dicts with ``description``, ``normalized_coordinates`` (box
        center), ``screen_coordinates`` (scaled center) and ``dimensions``
        (scaled box size).
    """
    elements = []
    for i, coords in enumerate(label_coordinates.values()):
        # float() normalizes numeric types that JSON encoding may reject
        # (presumably numpy scalars from the pipeline -- confirm in utils).
        x, y, w, h = (float(v) for v in coords)
        center_x_norm = x + w / 2
        center_y_norm = y + h / 2
        description = (
            parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}"
        )
        elements.append({
            "description": description,
            "normalized_coordinates": {
                "x": center_x_norm,
                "y": center_y_norm,
            },
            "screen_coordinates": {
                "x": int(center_x_norm * screen_width),
                "y": int(center_y_norm * screen_height),
            },
            "dimensions": {
                "width": int(w * screen_width),
                "height": int(h * screen_height),
            },
        })
    return elements


@app.post("/process")
@app.state.limiter.limit("5/minute")  # Limit to 5 requests per minute per IP
async def process_image(
    file: UploadFile = File(...),
    params: ProcessRequest = None,
    request: Request = None,
):
    """Run OCR plus icon detection/captioning on an uploaded screenshot.

    Args:
        file: The uploaded image.
        params: Thresholds and screen size; defaults are used when omitted.
            NOTE(review): the request body is multipart/form-data, so FastAPI
            most likely cannot also parse a JSON body into this model and the
            defaults will nearly always apply -- consider Form fields.
        request: The incoming request. slowapi's ``limit`` decorator requires
            a parameter with this name to derive the client key; FastAPI
            injects it.

    Returns:
        A dict with ``processed_image`` (the labeled image exactly as returned
        by ``get_som_labeled_img``) and ``elements`` (one entry per box).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    if params is None:
        params = ProcessRequest()

    image = _decode_upload(await file.read())

    # One unique temp file per request: a fixed shared path would let
    # concurrent requests overwrite each other's image mid-processing.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(temp_path)
        # NOTE(review): these calls are synchronous inside an async handler,
        # so they block the event loop while they run.
        ocr_bbox_rslt, _ = check_ocr_box(
            temp_path,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
        )
        text, ocr_bbox = ocr_bbox_rslt
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            temp_path,
            yolo_model,
            BOX_TRESHOLD=params.box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config={
                'text_scale': 0.8,
                'text_thickness': 2,
                'text_padding': 2,
                'thickness': 2,
            },
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=params.iou_threshold,
        )
    finally:
        # Best-effort cleanup; a failed delete must not mask the result or an
        # in-flight exception.
        with contextlib.suppress(OSError):
            os.remove(temp_path)

    return {
        # Presumably base64-encoded (per the original comment) -- confirm in utils.
        "processed_image": dino_labled_img,
        "elements": _format_elements(
            label_coordinates,
            parsed_content_list,
            params.screen_width,
            params.screen_height,
        ),
    }