sachin committed · 3e58bef
Parent(s): 998f798
auto-segmt
merged_code.py  CHANGED  (+179 -0)
@@ -455,6 +455,185 @@ async def fit_image_to_mask_endpoint(
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}")
 
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import StreamingResponse, JSONResponse
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+import io
+
+
+
+# Set up model and device
+model_id_segment = "IDEA-Research/grounding-dino-base"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load processor and model at startup
+processor_segment = AutoProcessor.from_pretrained(model_id_segment)
+model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id_segment).to(device)
+
+# Default text query (can be overridden via endpoint parameters)
+DEFAULT_TEXT_QUERY = "a tank."  # Adjust based on your use case
+
+
+
+def process_image(image: Image.Image, text_query: str = DEFAULT_TEXT_QUERY):
+    """Process the image with Grounding DINO and return detection results."""
+    # Prepare inputs for the model
+    inputs = processor_segment(images=image, text=text_query, return_tensors="pt").to(device)
+
+    # Perform inference
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    # Post-process results
+    results = processor_segment.post_process_grounded_object_detection(
+        outputs,
+        inputs.input_ids,
+        threshold=0.4,
+        text_threshold=0.3,
+        target_sizes=[image.size[::-1]]  # (height, width)
+    )
+    return results
+
+def draw_detections(image: Image.Image, results: list) -> Image.Image:
+    """Draw bounding boxes and labels on the image."""
+    output_image = image.copy()
+    draw = ImageDraw.Draw(output_image)
+
+    # Try to load a font, fall back to default
+    try:
+        font = ImageFont.truetype("arial.ttf", 20)
+    except OSError:
+        font = ImageFont.load_default()
+
+    # Colors for different objects
+    colors = {"a tank": "red"}  # Add more as needed, e.g., {"a cat": "red", "a remote control": "blue"}
+
+    # Draw bounding boxes and labels
+    for detection in results:
+        boxes = detection["boxes"]
+        labels = detection["labels"]
+        scores = detection["scores"]
+
+        for box, label, score in zip(boxes, labels, scores):
+            x_min, y_min, x_max, y_max = box.tolist()
+
+            # Draw rectangle
+            draw.rectangle(
+                [(x_min, y_min), (x_max, y_max)],
+                outline=colors.get(label, "green"),
+                width=2
+            )
+
+            # Draw label with score
+            label_text = f"{label} {score:.2f}"
+            bbox = draw.textbbox((x_min, y_min - 20), label_text, font=font)
+            text_width = bbox[2] - bbox[0]
+            text_height = bbox[3] - bbox[1]
+
+            # Draw background rectangle for text
+            draw.rectangle(
+                [(x_min, y_min - text_height - 5), (x_min + text_width, y_min)],
+                fill=colors.get(label, "green")
+            )
+
+            # Draw text
+            draw.text(
+                (x_min, y_min - text_height - 5),
+                label_text,
+                fill="white",
+                font=font
+            )
+
+    return output_image
+
+@app.post("/detect-image/")
+async def detect_image(
+    file: UploadFile = File(..., description="Image file to process"),
+    text_query: str = DEFAULT_TEXT_QUERY
+):
+    """
+    Endpoint to detect objects in an image and return the annotated image.
+
+    Args:
+        file: Uploaded image file.
+        text_query: Text query for objects to detect (e.g., "a tank.").
+
+    Returns:
+        StreamingResponse with the annotated image.
+    """
+    try:
+        # Read and convert the uploaded image
+        image_data = await file.read()
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        # Process the image
+        results = process_image(image, text_query)
+
+        # Draw detections on the image
+        output_image = draw_detections(image, results)
+
+        # Convert to bytes for response
+        img_byte_arr = io.BytesIO()
+        output_image.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+
+        return StreamingResponse(
+            img_byte_arr,
+            media_type="image/png",
+            headers={"Content-Disposition": "attachment; filename=detected_objects.png"}
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+@app.post("/detect-json/")
+async def detect_json(
+    file: UploadFile = File(..., description="Image file to process"),
+    text_query: str = DEFAULT_TEXT_QUERY
+):
+    """
+    Endpoint to detect objects in an image and return bounding box information as JSON.
+
+    Args:
+        file: Uploaded image file.
+        text_query: Text query for objects to detect (e.g., "a tank.").
+
+    Returns:
+        JSONResponse with bounding box coordinates, labels, and scores.
+    """
+    try:
+        # Read and convert the uploaded image
+        image_data = await file.read()
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        # Process the image
+        results = process_image(image, text_query)
+
+        # Format results as JSON-compatible data
+        detections = []
+        for detection in results:
+            boxes = detection["boxes"]
+            labels = detection["labels"]
+            scores = detection["scores"]
+
+            for box, label, score in zip(boxes, labels, scores):
+                x_min, y_min, x_max, y_max = box.tolist()
+                detections.append({
+                    "label": label,
+                    "score": float(score),  # Convert tensor to float
+                    "box": {
+                        "x_min": x_min,
+                        "y_min": y_min,
+                        "x_max": x_max,
+                        "y_max": y_max
+                    }
+                })
+
+        return JSONResponse(content={"detections": detections})
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
 
 if __name__ == "__main__":
     import uvicorn
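
For quick manual testing of the two new endpoints, a minimal client sketch follows. It assumes the merged app is running locally via uvicorn on port 8000; the base URL, the sample file name "tank.jpg", and the output path are illustrative assumptions, not part of the commit.

import requests

BASE = "http://localhost:8000"  # assumed local uvicorn host/port
params = {"text_query": "a tank."}  # text_query is sent as a query parameter

# JSON endpoint: returns labels, scores, and pixel-space boxes
with open("tank.jpg", "rb") as f:
    resp = requests.post(f"{BASE}/detect-json/", params=params,
                         files={"file": ("tank.jpg", f, "image/jpeg")})
resp.raise_for_status()
print(resp.json()["detections"])

# Image endpoint: returns the annotated PNG, written to disk here
with open("tank.jpg", "rb") as f:
    resp = requests.post(f"{BASE}/detect-image/", params=params,
                         files={"file": ("tank.jpg", f, "image/jpeg")})
resp.raise_for_status()
with open("detected_objects.png", "wb") as out:
    out.write(resp.content)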