# OmniParser FastAPI server (zorba111)
# Uploaded via huggingface_hub — revision 2ad48f3 (verified), 3.47 kB.
# Standard library
import io
import os
import tempfile

# Third-party
import torch
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from PIL import Image
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

# Local — existing utilities and models
from utils import check_ocr_box, get_caption_model_processor, get_som_labeled_img, get_yolo_model
# Initialize FastAPI app
app = FastAPI(title="OmniParser API")
# Attach a slowapi rate limiter keyed by client IP; the actual limit string
# is declared on each route decorator that uses it.
app.state.limiter = Limiter(key_func=get_remote_address)
# Convert RateLimitExceeded into slowapi's standard 429 response.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Load models at startup (reusing your existing code).
# NOTE(review): both loads run at import time, so the server blocks until the
# weights are read; paths are relative to the working directory — verify the
# process is launched from the repo root.
yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
caption_model_processor = get_caption_model_processor(
    model_name="florence2",
    model_name_or_path="weights/icon_caption_florence"
)
# Define request model
class ProcessRequest(BaseModel):
    """Optional tuning parameters accepted by the /process endpoint.

    Every field has a default, so clients may omit any or all of them.
    """
    # Forwarded to get_som_labeled_img as BOX_TRESHOLD (detector box threshold).
    box_threshold: float = 0.05
    # Forwarded to get_som_labeled_img as iou_threshold.
    iou_threshold: float = 0.1
    # Physical screen size used to scale normalized coordinates to pixels.
    screen_width: int = 1920
    screen_height: int = 1080
@app.post("/process")
@app.state.limiter.limit("5/minute")  # Limit to 5 requests per minute per IP
async def process_image(
    request: Request,  # required: slowapi's limiter inspects the Request to resolve the client IP
    file: UploadFile = File(...),
    params: ProcessRequest = None,
):
    """Parse an uploaded UI screenshot into labeled, clickable elements.

    Runs the OCR pass (check_ocr_box) and the detection/captioning pass
    (get_som_labeled_img) over the upload, then converts each labeled box
    into normalized and screen-pixel coordinates.

    Args:
        request: Raw request object, injected by FastAPI (needed by slowapi).
        file: The screenshot to parse (any format PIL can open).
        params: Optional thresholds and target screen size; defaults are used
            when the client sends no body. (Previously a missing body left
            ``params`` as None and crashed on ``params.box_threshold``.)

    Returns:
        dict with ``processed_image`` (base64-encoded annotated image, as
        produced by get_som_labeled_img) and ``elements``, the per-element
        descriptions plus normalized/screen coordinates and dimensions.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    if params is None:  # no JSON body sent -> fall back to documented defaults
        params = ProcessRequest()

    # Read image from request and validate it early with a clean 400.
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from exc

    # The downstream helpers take a file path. Use a unique temp file: the
    # original hard-coded 'imgs/temp_image.png' was clobbered by concurrent
    # requests and never removed.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(temp_path)

        # OCR pass: collect recognized text and boxes in xyxy format.
        ocr_bbox_rslt, _ = check_ocr_box(
            temp_path,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9}
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Detection + captioning pass over the same image.
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            temp_path,
            yolo_model,
            BOX_TRESHOLD=params.box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config={
                'text_scale': 0.8,
                'text_thickness': 2,
                'text_padding': 2,
                'thickness': 2,
            },
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=params.iou_threshold
        )
    finally:
        # Best-effort cleanup; the temp file is not needed past this point.
        try:
            os.remove(temp_path)
        except OSError:
            pass

    elements = _format_elements(
        label_coordinates,
        parsed_content_list,
        params.screen_width,
        params.screen_height,
    )
    return {
        "processed_image": dino_labled_img,  # Base64 encoded image
        "elements": elements,
    }


def _format_elements(label_coordinates, parsed_content_list, screen_width, screen_height):
    """Convert normalized (x, y, w, h) boxes into the per-element output dicts.

    ``label_coordinates`` maps element id -> (x, y, w, h) in ratio coordinates
    (output_coord_in_ratio=True above); ``screen_width``/``screen_height``
    scale those to pixels.
    """
    elements = []
    for i, (element_id, coords) in enumerate(label_coordinates.items()):
        x, y, w, h = coords
        center_x_norm = x + (w / 2)
        center_y_norm = y + (h / 2)
        elements.append({
            # Fall back to a generic label when captioning produced fewer
            # entries than the detector found boxes.
            "description": parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}",
            "normalized_coordinates": {
                "x": center_x_norm,
                "y": center_y_norm,
            },
            "screen_coordinates": {
                "x": int(center_x_norm * screen_width),
                "y": int(center_y_norm * screen_height),
            },
            "dimensions": {
                "width": int(w * screen_width),
                "height": int(h * screen_height),
            },
        })
    return elements