Spaces:
Build error
Build error
File size: 4,499 Bytes
36a599e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import modal
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
import io
import base64
from typing import Optional
import traceback
# Module-level wiring: the Modal application that owns every function in this
# file, and the FastAPI instance served by the ASGI entry point below.
# NOTE(review): no ``image=`` is passed, so Modal presumably builds containers
# from its default image, without torch / the local ``utils`` module / the
# ``weights/`` directory. Confirm the dependencies are provided elsewhere.
app = modal.App(name="ui-coordinates-finder")
web_app = FastAPI()
# Loaded models, kept per process so repeated init_models() calls reuse the
# weights already in memory instead of reloading them from disk each time
# (the /process endpoint calls init_models() once per request).
_MODEL_CACHE = {}


@app.function(gpu="T4")
def init_models():
    """Load (once per process) the icon detector and the caption model.

    Returns:
        A ``(yolo_model, caption_model_processor)`` tuple built by the local
        ``utils`` helpers. The same tuple is returned on every later call
        made in the same process.
    """
    cached = _MODEL_CACHE.get("models")
    if cached is not None:
        return cached

    # Imported lazily so the module can be loaded where ``utils`` (and its
    # heavy ML dependencies) are not importable.
    from utils import get_yolo_model, get_caption_model_processor

    yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
    caption_model_processor = get_caption_model_processor(
        model_name="florence2",
        model_name_or_path="weights/icon_caption_florence",
    )
    models = (yolo_model, caption_model_processor)
    _MODEL_CACHE["models"] = models
    return models
# Cross-origin policy for the FastAPI app: any origin, any method and any
# header are accepted, and credentialed requests are permitted.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests.
# Starlette appears to echo the caller's Origin instead, which effectively
# grants credentialed access to any site. Nothing in this file reads cookies,
# so allow_credentials=False may be safer; confirm no client relies on it.
web_app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def _parse_screenshot(image_save_path, box_threshold, iou_threshold):
    """Run OCR plus icon detection/captioning on the image at *image_save_path*.

    Returns:
        Whatever ``utils.get_som_labeled_img`` returns; the caller unpacks it
        as ``(labeled_image, label_coordinates, parsed_content_list)``.
    """
    yolo_model, caption_model_processor = init_models()
    from utils import check_ocr_box, get_som_labeled_img

    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 2,
        'thickness': 2,
    }
    ocr_bbox_rslt, _ = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9},
    )
    text, ocr_bbox = ocr_bbox_rslt
    return get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,  # sic: keyword spelling expected by utils
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=iou_threshold,
    )


def _format_elements(label_coordinates, parsed_content_list, screen_width, screen_height):
    """Turn ratio-based boxes into centre points and screen-space sizes.

    Boxes are paired positionally with ``parsed_content_list``. Boxes without
    a matching description are skipped.

    Each box is unpacked as ``(x, y, w, h)`` and its centre is taken as
    ``(x + w/2, y + h/2)``, i.e. it assumes a top-left origin and values given
    as fractions of the image size (``output_coord_in_ratio=True``). Confirm
    this against ``utils.get_som_labeled_img``.
    """
    elements = []
    for coords, description in zip(label_coordinates.values(), parsed_content_list):
        # float() turns numpy/torch scalars into plain floats so JSONResponse
        # can serialize them. It is a no-op for built-in floats.
        x, y, w, h = (float(v) for v in coords)
        center_x = x + w / 2
        center_y = y + h / 2
        elements.append({
            "description": description,
            "normalized_coords": (center_x, center_y),
            "screen_coords": (int(center_x * screen_width), int(center_y * screen_height)),
            "dimensions": (int(w * screen_width), int(h * screen_height)),
        })
    return elements


# NOTE(review): decorators apply bottom-up. FastAPI's ``post`` registers the
# plain coroutine with ``web_app`` and returns it unchanged, so requests routed
# through ``fastapi_app`` presumably do not run under this GPU/timeout
# configuration. Confirm the intended deployment.
@app.function(gpu="T4", timeout=300)
@web_app.post("/process")
async def process_image_endpoint(
    request: Request,
    file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
    screen_width: int = 1920,
    screen_height: int = 1080,
):
    """Detect UI elements in an uploaded screenshot and return their coordinates.

    Args:
        request: The incoming request (not used by the handler).
        file: The uploaded image.
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img``.
        screen_width: Width in pixels used to scale normalized coordinates.
        screen_height: Height in pixels used to scale normalized coordinates.

    Returns:
        Status 200 with ``message``, ``filename``, ``processed_image`` and
        ``elements``. Status 400 when the upload cannot be decoded as an image.
        Status 500 with ``error`` and a traceback in ``details`` on any other
        failure.
    """
    import os
    import tempfile

    try:
        print(f"Processing file: {file.filename}")
        contents = await file.read()
        print("File read successfully")

        # Decode eagerly so a corrupt or non-image upload is reported as a
        # client error instead of failing later as a 500.
        try:
            image = Image.open(io.BytesIO(contents))
            image.load()
        except OSError as e:  # includes PIL.UnidentifiedImageError
            return JSONResponse(
                status_code=400,
                content={"error": f"Uploaded file is not a readable image: {e}"},
            )

        # PNG cannot store CMYK (common in print-oriented JPEGs).
        if image.mode == "CMYK":
            image = image.convert("RGB")

        # Use a private directory per request rather than one fixed /tmp path.
        # No request can overwrite another's file, and the upload is deleted
        # as soon as processing finishes.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_save_path = os.path.join(tmp_dir, "saved_image_demo.png")
            image.save(image_save_path)
            dino_labled_img, label_coordinates, parsed_content_list = _parse_screenshot(
                image_save_path, box_threshold, iou_threshold
            )

        elements = _format_elements(
            label_coordinates, parsed_content_list, screen_width, screen_height
        )
        return JSONResponse(
            status_code=200,
            content={
                "message": "Success",
                "filename": file.filename,
                # NOTE(review): presumably a base64 string produced by utils; confirm.
                "processed_image": dino_labled_img,
                "elements": elements,
            },
        )
    except Exception as e:
        # Top-level request boundary: log the traceback and report it.
        # NOTE(review): returning tracebacks to clients exposes internals;
        # consider omitting "details" outside of debugging.
        error_details = traceback.format_exc()
        print(f"Error processing request: {error_details}")
        return JSONResponse(
            status_code=500,
            content={
                "error": str(e),
                "details": error_details,
            },
        )
@app.function()
@modal.asgi_app()
def fastapi_app():
    """Expose the FastAPI ``web_app`` (its routes and CORS middleware) as a Modal ASGI app."""
    # NOTE(review): this function is declared without ``gpu=`` or ``image=``.
    # Since FastAPI registers the undecorated /process coroutine, that route is
    # presumably served from this container. Confirm it has the GPU and the
    # dependencies the models need.
    return web_app
if __name__ == "__main__":
    # Entry point when the file is executed directly with ``python``.
    # NOTE(review): confirm the installed modal version provides ``App.serve()``.
    # The usual dev workflow is the ``modal serve <file>`` CLI, which
    # presumably imports this module without running this block.
    app.serve()