# Zero-shot object detection (OWLv2) + box-prompted segmentation (RobustSAM),
# served as a Gradio Space.
# TODO: move the models and tensors to CUDA so the @spaces.GPU allocation is actually used.
from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
from PIL import Image
import spaces
import gradio as gr

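# OWLv2 maps free-text labels to bounding boxes; RobustSAM (a SAM-compatible
# checkpoint, loadable with SamModel/SamProcessor) maps those boxes to masks.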
checkpoint = "google/owlv2-base-patch16-ensemble"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base")
sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")

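# @spaces.GPU asks Hugging Face Spaces (ZeroGPU) to attach a GPU for the duration
# of each call; as written, the models are never moved to CUDA, so inference still
# runs on CPU tensors.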
@spaces.GPU
def query(image, texts, threshold):
  texts = [text.strip() for text in texts.split(",")]  # tolerate spaces after commas

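  # OWLv2 zero-shot detection: returns a list of dicts with "box", "score" and "label".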
  predictions = detector(
    image,
    candidate_labels=texts,
    threshold=threshold
  )

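  # Collect (mask, label) pairs in the format gr.AnnotatedImage expects.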
  result_labels = []
  for pred in predictions:

    label = pred["label"]
    # xyxy box from the OWLv2 prediction, used below to prompt SAM
    box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
           round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]

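    # Prompt SAM with the detected box; the processor resizes the image and
    # rescales the box coordinates to SAM's input resolution.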
    inputs = sam_processor(
            image,
            input_boxes=[[[box]]],
            return_tensors="pt"
        )

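    # Inference only; no gradients needed.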
    with torch.no_grad():
        outputs = sam_model(**inputs)

    # Upscale the low-resolution mask logits back to the original image size and
    # take the first predicted mask as a boolean (H, W) array.
    mask = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0][0][0].numpy()

    # Build a grayscale overlay of the mask and save it for inspection; paste onto
    # a copy so the base image returned to Gradio stays unmodified.
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    alpha_mask = Image.new("L", mask_image.size, 128)  # adjust overlay transparency here
    annotated = image.copy()
    annotated.paste(mask_image, (0, 0), alpha_mask)
    annotated.save("annotated_image.png")

    result_labels.append((mask, label))
  return image, result_labels


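# gr.AnnotatedImage expects (base_image, [(mask_or_box, label), ...]), which matches
# the (image, result_labels) tuple returned by query().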
description = "This Space combines OWLv2, a state-of-the-art zero-shot object detection model, with RobustSAM, a state-of-the-art mask generation model. SAM cannot be prompted with text on its own; chaining it with OWLv2 makes it text-promptable. Try the example, or input an image and comma-separated candidate labels to segment."
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Textbox(label="Candidate Labels (comma-separated)"),
        gr.Slider(minimum=0, maximum=1, value=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.AnnotatedImage(label="Segmented Image"),
    title="OWL 🤝 SAM",
    description=description,
)
demo.launch(debug=True)