sachin committed
Commit 3e58bef · 1 Parent(s): 998f798

auto-segmt

Files changed (1)
  1. merged_code.py +179 -0
merged_code.py CHANGED
@@ -455,6 +455,185 @@ async def fit_image_to_mask_endpoint(
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}")
 
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import StreamingResponse, JSONResponse
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+import io
+
+
+# Set up model and device
+model_id_segment = "IDEA-Research/grounding-dino-base"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load processor and model at startup
+processor_segment = AutoProcessor.from_pretrained(model_id_segment)
+model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id_segment).to(device)
+
+# Default text query (can be overridden via endpoint parameters)
+DEFAULT_TEXT_QUERY = "a tank."  # Adjust based on your use case
+
+
+def process_image(image: Image.Image, text_query: str = DEFAULT_TEXT_QUERY):
+    """Process the image with Grounding DINO and return detection results."""
+    # Prepare inputs for the model
+    inputs = processor_segment(images=image, text=text_query, return_tensors="pt").to(device)
+
+    # Perform inference
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    # Post-process results
+    results = processor_segment.post_process_grounded_object_detection(
+        outputs,
+        inputs.input_ids,
+        threshold=0.4,
+        text_threshold=0.3,
+        target_sizes=[image.size[::-1]]  # (height, width)
+    )
+    return results
+
+def draw_detections(image: Image.Image, results: list) -> Image.Image:
+    """Draw bounding boxes and labels on the image."""
+    output_image = image.copy()
+    draw = ImageDraw.Draw(output_image)
+
+    # Try to load a font, fall back to default
+    try:
+        font = ImageFont.truetype("arial.ttf", 20)
+    except OSError:
+        font = ImageFont.load_default()
+
+    # Colors for different objects
+    colors = {"a tank": "red"}  # Add more as needed, e.g., {"a cat": "red", "a remote control": "blue"}
+
+    # Draw bounding boxes and labels
+    for detection in results:
+        boxes = detection["boxes"]
+        labels = detection["labels"]
+        scores = detection["scores"]
+
+        for box, label, score in zip(boxes, labels, scores):
+            x_min, y_min, x_max, y_max = box.tolist()
+
+            # Draw rectangle
+            draw.rectangle(
+                [(x_min, y_min), (x_max, y_max)],
+                outline=colors.get(label, "green"),
+                width=2
+            )
+
+            # Draw label with score
+            label_text = f"{label} {score:.2f}"
+            bbox = draw.textbbox((x_min, y_min - 20), label_text, font=font)
+            text_width = bbox[2] - bbox[0]
+            text_height = bbox[3] - bbox[1]
+
+            # Draw background rectangle for text
+            draw.rectangle(
+                [(x_min, y_min - text_height - 5), (x_min + text_width, y_min)],
+                fill=colors.get(label, "green")
+            )
+
+            # Draw text
+            draw.text(
+                (x_min, y_min - text_height - 5),
+                label_text,
+                fill="white",
+                font=font
+            )
+
+    return output_image
+
+@app.post("/detect-image/")
+async def detect_image(
+    file: UploadFile = File(..., description="Image file to process"),
+    text_query: str = DEFAULT_TEXT_QUERY
+):
+    """
+    Endpoint to detect objects in an image and return the annotated image.
+
+    Args:
+        file: Uploaded image file.
+        text_query: Text query for objects to detect (e.g., "a tank.").
+
+    Returns:
+        StreamingResponse with the annotated image.
+    """
+    try:
+        # Read and convert the uploaded image
+        image_data = await file.read()
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        # Process the image
+        results = process_image(image, text_query)
+
+        # Draw detections on the image
+        output_image = draw_detections(image, results)
+
+        # Convert to bytes for response
+        img_byte_arr = io.BytesIO()
+        output_image.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+
+        return StreamingResponse(
+            img_byte_arr,
+            media_type="image/png",
+            headers={"Content-Disposition": "attachment; filename=detected_objects.png"}
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+@app.post("/detect-json/")
+async def detect_json(
+    file: UploadFile = File(..., description="Image file to process"),
+    text_query: str = DEFAULT_TEXT_QUERY
+):
+    """
+    Endpoint to detect objects in an image and return bounding box information as JSON.
+
+    Args:
+        file: Uploaded image file.
+        text_query: Text query for objects to detect (e.g., "a tank.").
+
+    Returns:
+        JSONResponse with bounding box coordinates, labels, and scores.
+    """
+    try:
+        # Read and convert the uploaded image
+        image_data = await file.read()
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        # Process the image
+        results = process_image(image, text_query)
+
+        # Format results as JSON-compatible data
+        detections = []
+        for detection in results:
+            boxes = detection["boxes"]
+            labels = detection["labels"]
+            scores = detection["scores"]
+
+            for box, label, score in zip(boxes, labels, scores):
+                x_min, y_min, x_max, y_max = box.tolist()
+                detections.append({
+                    "label": label,
+                    "score": float(score),  # Convert tensor to float
+                    "box": {
+                        "x_min": x_min,
+                        "y_min": y_min,
+                        "x_max": x_max,
+                        "y_max": y_max
+                    }
+                })
+
+        return JSONResponse(content={"detections": detections})
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
 
 if __name__ == "__main__":
     import uvicorn
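
For reference, a minimal sketch of how the two new endpoints might be exercised once the app is running. The host/port (localhost:8000) and the image path (sample.jpg) are placeholders, not part of the commit; `text_query` is sent as a query parameter because FastAPI treats simple-typed parameters with defaults that way. Note that Grounding DINO prompts are conventionally lowercase and end with a period, hence the trailing dot in "a tank.".

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port; adjust to your deployment

# /detect-image/ returns a PNG with boxes and labels drawn on it
with open("sample.jpg", "rb") as f:  # placeholder image path
    resp = requests.post(
        f"{BASE_URL}/detect-image/",
        files={"file": ("sample.jpg", f, "image/jpeg")},
        params={"text_query": "a tank."},
    )
resp.raise_for_status()
with open("detected_objects.png", "wb") as out:
    out.write(resp.content)

# /detect-json/ returns labels, scores, and pixel-space box coordinates
with open("sample.jpg", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/detect-json/",
        files={"file": ("sample.jpg", f, "image/jpeg")},
        params={"text_query": "a tank."},
    )
resp.raise_for_status()
for det in resp.json()["detections"]:
    print(det["label"], det["score"], det["box"])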