Srujan111 committed on
Commit
6919142
·
1 Parent(s): f9a4936

Upload import torch.py

Browse files
Files changed (1) hide show
  1. import torch.py +70 -0
import torch.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+
6
# Checkpoint shared by the model, the image processor and the tokenizer.
# Kept in one place so the three components cannot drift apart.
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

# Load the captioning model together with its image processor and tokenizer,
# all from the same checkpoint.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generation settings forwarded to ``model.generate``.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
16
+
17
def predict_step(image_paths):
    """Generate captions for the given images.

    NOTE: a second ``predict_step`` is defined further down in this file and
    rebinds the name, so this caption-returning version is not the one the
    script ends up calling.

    Args:
        image_paths: Iterable of image locations accepted by
            ``PIL.Image.open``.

    Returns:
        A list of whitespace-stripped caption strings decoded from the
        model output. An empty list is returned when no images are given.
    """
    images = []
    for image_path in image_paths:
        # ``Image.open`` keeps the underlying file open. ``convert`` returns
        # an independent in-memory image (a copy when the mode is already
        # RGB), so the handle can be released as soon as it returns.
        with Image.open(image_path) as source:
            images.append(source.convert("RGB"))

    if not images:
        # Nothing to caption; skip the processor/model call entirely.
        return []

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [caption.strip() for caption in captions]
34
+
35
# Default object name to look for in the generated captions
# (the comparison below is case-insensitive).
target_object = 'Desk'


def predict_step(image_paths, target=None):
    """Report whether an object is mentioned in any generated caption.

    This definition rebinds ``predict_step`` and therefore replaces the
    caption-returning function of the same name defined earlier in the file.

    Args:
        image_paths: Iterable of image locations accepted by
            ``PIL.Image.open``.
        target: Object name to search for. When ``None`` (the default) the
            module-level ``target_object`` is used, which matches the
            original single-argument behavior.

    Returns:
        ``True`` if the lowercased target occurs as a substring of at least
        one lowercased caption, ``False`` otherwise (including when no
        images are given).
    """
    if target is None:
        target = target_object

    images = []
    for image_path in image_paths:
        # ``Image.open`` keeps the underlying file open. ``convert`` returns
        # an independent in-memory image (a copy when the mode is already
        # RGB), so the handle can be released as soon as it returns.
        with Image.open(image_path) as source:
            images.append(source.convert("RGB"))

    if not images:
        # No captions can be produced, so the object cannot be mentioned.
        return False

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    # Check if the target object is mentioned in the captions. This is a
    # plain substring test, so e.g. "desk" also matches "desktop".
    needle = target.lower()
    return any(needle in caption.strip().lower() for caption in captions)
61
+
62
# Check if the target object is present in the image.
# The path is a raw string: "\S" is not a valid escape sequence, and a plain
# literal containing it triggers an invalid-escape warning on modern Python.
result = predict_step([r'D:\Sushant.jpg'])

if result:
    print(f"The object {target_object} is present in the image.")
else:
    print(f"The object {target_object} is not present in the image.")