from io import BytesIO
from typing import Union

import torch
from fastapi import FastAPI, File, HTTPException, Query
from PIL import Image
from transformers import pipeline
from transformers import DetrImageProcessor, DetrForObjectDetection


# Application instance. The interactive API docs are mounted at the site
# root ("/") instead of FastAPI's default "/docs" path.
app = FastAPI(
    title="Object Detection",
    description="Object detection in Image",
    docs_url="/",
)


# The preprocessor and the detection model are loaded from the same
# checkpoint so their configurations agree. Both are created once at import
# time and reused by every request handled in this module.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)



@app.post('/image')
def read_image(
    image_file: bytes = File(...),
    threshold: float = Query(0.9, ge=0.0, le=1.0),
):
    """Run DETR object detection on an uploaded image.

    Args:
        image_file: Raw bytes of the uploaded file (multipart form field
            ``image_file``).
        threshold: Minimum confidence a detection must reach to be returned.
            Optional query parameter in [0, 1]; defaults to 0.9, the value
            that was previously hard-coded.

    Returns:
        A JSON-serializable dict of parallel lists, one entry per detection:
        ``scores`` (floats), ``labels`` (integer class ids), ``label_names``
        (class names looked up in ``model.config.id2label``) and ``boxes``
        (four floats per box, in pixel coordinates of the uploaded image as
        produced by ``post_process_object_detection``).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        # Image.open is lazy; convert() forces decoding so corrupt or
        # truncated uploads fail here rather than deep inside the processor.
        # Converting to RGB also normalises grayscale/palette/RGBA inputs to
        # the 3-channel layout the model is fed.
        image = Image.open(BytesIO(image_file)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    # Raw tensors are not JSON-serializable by FastAPI; return plain lists.
    labels = results["labels"].tolist()
    id2label = model.config.id2label
    return {
        "scores": results["scores"].tolist(),
        "labels": labels,
        "label_names": [id2label.get(label, str(label)) for label in labels],
        "boxes": results["boxes"].tolist(),
    }