zahraanaji commited on
Commit
e045c15
·
verified ·
1 Parent(s): 346bc31

Upload 3_object_skeleton.py

Browse files
Files changed (1) hide show
  1. 3_object_skeleton.py +43 -0
3_object_skeleton.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Zero-shot object detection with OWL-ViT.

Detects objects described by free-text queries in an image and draws
labeled bounding boxes on it.
"""
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from PIL import Image, ImageDraw

# Checkpoint identifier for the pre-trained OWL-ViT model.
CHECKPOINT = "google/owlvit-base-patch32"


def detect_and_draw(image, text_queries, threshold=0.1):
    """Run zero-shot object detection on *image* and draw the results.

    Args:
        image: a ``PIL.Image.Image`` or a filesystem path to an image file.
        text_queries: list of free-text object descriptions to detect
            (e.g. ``["hat", "book"]``).
        threshold: minimum confidence score for a detection to be kept.

    Returns:
        The PIL image with red bounding boxes and "<label>: <score>"
        annotations drawn on it (the image is modified in place).
    """
    # Accept either an already-open PIL image or a path string.
    # (The original opened a hard-coded path first, which made this
    # isinstance check dead code — now the guard is meaningful.)
    if isinstance(image, str):
        image = Image.open(image)

    # Initialize the pre-trained model and processor.
    model = AutoModelForZeroShotObjectDetection.from_pretrained(CHECKPOINT)
    processor = AutoProcessor.from_pretrained(CHECKPOINT)

    # Prepare joint image+text inputs for zero-shot detection.
    inputs = processor(images=image, text=text_queries, return_tensors="pt")

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
        # post-processing expects (height, width); PIL's .size is (width, height).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

    draw = ImageDraw.Draw(image)

    # Extract detection results (scores, label indices, and bounding boxes).
    scores = results["scores"].tolist()
    labels = results["labels"].tolist()
    boxes = results["boxes"].tolist()

    # Draw a bounding box and "<query>: <score>" label for each detection.
    for box, score, label in zip(boxes, scores, labels):
        xmin, ymin, xmax, ymax = box
        draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
        draw.text((xmin, ymin), f"{text_queries[label]}: {round(score, 2)}", fill="black")

    # NOTE: the original script had `return image` at module level, which is a
    # SyntaxError — the pipeline is now a proper function so the return is legal.
    return image


if __name__ == "__main__":
    # Reproduce the original script's hard-coded run.
    annotated = detect_and_draw(
        "/content/drive/MyDrive/img3.jpg",
        ["hat", "book", "sunglasses"],
    )