sachin committed
Commit 0d35999 · 1 Parent(s): 3e58bef
Files changed (2)
  1. Dockerfile +1 -1
  2. intruct.py +184 -0
Dockerfile CHANGED
@@ -35,4 +35,4 @@ USER appuser
 EXPOSE 7860
 
 # Run the server
-CMD ["python", "/app/merged_code.py"]
+CMD ["python", "/app/intruct.py"]
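Since the container now launches intruct.py directly, a quick way to confirm the server came up is to fetch FastAPI's auto-generated OpenAPI schema. A minimal sketch, assuming the container was started with port 7860 published to localhost (the host/port mapping is an assumption):

```python
# Smoke test, assuming the container is running with port 7860 published
# to localhost (the host and port mapping are assumptions).
import urllib.request

# FastAPI serves its OpenAPI schema at /openapi.json by default.
with urllib.request.urlopen("http://localhost:7860/openapi.json") as resp:
    print(resp.status)  # expect 200 once the models finish loading
```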
intruct.py CHANGED
@@ -30,6 +30,72 @@ import cv2
 app = FastAPI()
 
 
+# Select the compute device before loading models (it is assigned again
+# below next to model_id_runway, but must exist before the .to(device) calls)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load Grounding DINO model and processor at startup
+dino_model_id = "IDEA-Research/grounding-dino-base"
+dino_processor = AutoProcessor.from_pretrained(dino_model_id)
+dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(dino_model_id).to(device)
+
+# Load SAM 2 model at startup
+#sam_checkpoint = "sam2.1_hiera_tiny.pt"  # Replace with your checkpoint path
+sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
+sam_predictor.model.to(device)
+
+# Default text query
+DEFAULT_TEXT_QUERY = "a tank."
+
+def process_image_with_dino(image: Image.Image, text_query: str = DEFAULT_TEXT_QUERY):
+    """Detect objects using Grounding DINO."""
+    inputs = dino_processor(images=image, text=text_query, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = dino_model(**inputs)
+
+    # Post-process results
+    results = dino_processor.post_process_grounded_object_detection(
+        outputs,
+        inputs.input_ids,
+        threshold=0.4,
+        text_threshold=0.3,
+        target_sizes=[image.size[::-1]]  # PIL size is (width, height); reversed to (height, width)
+    )
+    return results[0]  # Single image result
+
+def segment_with_sam(image: Image.Image, boxes: list):
+    """Segment detected objects using SAM 2 and return a mask."""
+    image_np = np.array(image)
+    sam_predictor.set_image(image_np)
+
+    if not boxes:
+        return np.zeros(image_np.shape[:2], dtype=bool)  # Empty mask if no boxes
+
+    # Convert boxes to an [x_min, y_min, x_max, y_max] tensor and move to device
+    boxes_tensor = torch.tensor(
+        [[box["x_min"], box["y_min"], box["x_max"], box["y_max"]] for box in boxes],
+        dtype=torch.float32
+    ).to(device)
+
+    # Predict with SAM 2 using boxes directly ('box' is the argument name, not 'boxes')
+    masks, _, _ = sam_predictor.predict(
+        point_coords=None,
+        point_labels=None,
+        box=boxes_tensor,
+        multimask_output=False
+    )
+    return masks[0]  # Return the first mask directly (already a NumPy array)
+
+def create_background_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    """Create an RGB mask for background removal (object preserved)."""
+    # Object region is white (kept by bitwise_and), background black (removed)
+    mask_rgb = cv2.cvtColor(mask.astype(np.uint8) * 255, cv2.COLOR_GRAY2RGB)
+    return mask_rgb
+
+def create_object_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    """Create an RGB mask for object removal (background preserved)."""
+    # Inverted: background is white (kept), object black (removed)
+    mask_inv = np.logical_not(mask).astype(np.uint8) * 255
+    mask_rgb = cv2.cvtColor(mask_inv, cv2.COLOR_GRAY2RGB)
+    return mask_rgb
+
+
 model_id_runway = "runwayml/stable-diffusion-inpainting"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
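For reference, a minimal sketch of how the helpers added above compose outside the HTTP layer, mirroring the box-extraction step the endpoints in the next hunk perform; "example.jpg" is a hypothetical input file:

```python
# Standalone sketch of the detect-then-segment pipeline, assuming the
# module-level models above are loaded; "example.jpg" is hypothetical.
from PIL import Image
import numpy as np

image = Image.open("example.jpg").convert("RGB")
results = process_image_with_dino(image, "a tank.")

# Same box format the endpoints build before calling segment_with_sam
boxes = [
    {"x_min": b[0].item(), "y_min": b[1].item(),
     "x_max": b[2].item(), "y_max": b[3].item()}
    for b in results["boxes"].cpu()
]
mask = segment_with_sam(image, boxes)  # mask array over the input image
print(mask.shape, np.count_nonzero(mask))
```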
@@ -456,6 +522,124 @@ async def fit_image_to_mask_endpoint(
         raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}")
 
 
+@app.post("/detect-json/")
+async def detect_json(
+    file: UploadFile = File(..., description="Image file to process"),
+    text_query: str = DEFAULT_TEXT_QUERY
+):
+    """Endpoint to detect objects and return bounding box information as JSON."""
+    try:
+        # Read and convert the uploaded image
+        image_data = await file.read()
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        # Process with Grounding DINO
+        results = process_image_with_dino(image, text_query)
+
+        # Format results as JSON-compatible data
+        detections = []
+        for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
+            x_min, y_min, x_max, y_max = box.tolist()
+            detections.append({
+                "label": label,
+                "score": float(score),  # Convert tensor to float
+                "box": {
+                    "x_min": x_min,
+                    "y_min": y_min,
+                    "x_max": x_max,
+                    "y_max": y_max
+                }
+            })
+
+        return JSONResponse(content={"detections": detections})
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+@app.post("/segment-image/")
+async def segment_image(
+    file: UploadFile = File(..., description="Image file to process"),
+    text_query: str = DEFAULT_TEXT_QUERY
+):
+    """Endpoint to segment objects and return the image with background removed."""
+    try:
+        # Read and convert the uploaded image
+        image_data = await file.read()
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        # Detect objects with Grounding DINO
+        results = process_image_with_dino(image, text_query)
+
+        # Extract boxes for segmentation, move to CPU
+        boxes = [
+            {"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
+            for box in results["boxes"].cpu()  # Move tensor to CPU here
+        ]
+
+        # Segment with SAM 2
+        mask = segment_with_sam(image, boxes)
+
+        # Create background mask and apply it
+        image_np = np.array(image)
+        background_mask = create_background_mask(image_np, mask)
+        segmented_image = cv2.bitwise_and(image_np, background_mask)
+
+        # Convert to PIL Image and save to bytes
+        output_image = Image.fromarray(segmented_image)
+        img_byte_arr = io.BytesIO()
+        output_image.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+
+        return StreamingResponse(
+            img_byte_arr,
+            media_type="image/png",
+            headers={"Content-Disposition": "attachment; filename=segmented_image.png"}
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+@app.post("/mask-object/")
+async def mask_object(
+    file: UploadFile = File(..., description="Image file to process"),
+    text_query: str = DEFAULT_TEXT_QUERY
+):
+    """Endpoint to mask the detected object and return the image with the object removed."""
+    try:
+        # Read and convert the uploaded image
+        image_data = await file.read()
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        # Detect objects with Grounding DINO
+        results = process_image_with_dino(image, text_query)
+
+        # Extract boxes for segmentation, move to CPU
+        boxes = [
+            {"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
+            for box in results["boxes"].cpu()  # Move tensor to CPU here
+        ]
+
+        # Segment with SAM 2
+        mask = segment_with_sam(image, boxes)
+
+        # Create object mask and apply it
+        image_np = np.array(image)
+        object_mask = create_object_mask(image_np, mask)
+        masked_image = cv2.bitwise_and(image_np, object_mask)
+
+        # Convert to PIL Image and save to bytes
+        output_image = Image.fromarray(masked_image)
+        img_byte_arr = io.BytesIO()
+        output_image.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+
+        return StreamingResponse(
+            img_byte_arr,
+            media_type="image/png",
+            headers={"Content-Disposition": "attachment; filename=masked_object_image.png"}
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
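A hedged sketch of exercising the three new endpoints with the `requests` library, assuming the server is reachable on localhost:7860; the input and output file names are hypothetical. Note that `text_query` is a plain string parameter, so it travels as a query parameter rather than form data:

```python
# Example client calls, assuming the server runs locally on port 7860;
# "tank.jpg" is a hypothetical input file.
import requests

BASE = "http://localhost:7860"

# 1. JSON detections from Grounding DINO
with open("tank.jpg", "rb") as f:
    r = requests.post(f"{BASE}/detect-json/", files={"file": f},
                      params={"text_query": "a tank."})
print(r.json())

# 2. PNG with the background removed (detected object kept)
with open("tank.jpg", "rb") as f:
    r = requests.post(f"{BASE}/segment-image/", files={"file": f})
with open("segmented.png", "wb") as out:
    out.write(r.content)

# 3. PNG with the detected object blacked out (background kept)
with open("tank.jpg", "rb") as f:
    r = requests.post(f"{BASE}/mask-object/", files={"file": f})
with open("masked.png", "wb") as out:
    out.write(r.content)
```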