LuoFengBit commited on
Commit
6f6df07
·
1 Parent(s): 26bbddb

Create image_classification.py

Browse files
Files changed (1) hide show
  1. image_classification.py +35 -0
image_classification.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Dec 17 20:35:41 2023
4
+
5
+ @author: luofeng
6
+ """
7
+
8
+ from transformers import ViTImageProcessor, ViTForImageClassification
9
+ from PIL import Image
10
+ #import requests
11
+
12
+ #url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
13
+ #image = Image.open(requests.get(url, stream=True).raw)
14
+
15
+ processor = ViTImageProcessor.from_pretrained('E:\\workspaces_python\\model_local\\vit-base-patch16-224')
16
+ model = ViTForImageClassification.from_pretrained('E:\\workspaces_python\\model_local\\vit-base-patch16-224')
17
+
18
+ #inputs = processor(images=image, return_tensors="pt")
19
+ #outputs = model(**inputs)
20
+ #logits = outputs.logits
21
+ # model predicts one of the 1000 ImageNet classes
22
+ #predicted_class_idx = logits.argmax(-1).item()
23
+ #print("Predicted class:", model.config.id2label[predicted_class_idx])
24
+
25
+
26
+ def imageClassification(image_path):
27
+ image = Image.open(image_path)
28
+ inputs = processor(images=image, return_tensors="pt")
29
+ outputs = model(**inputs)
30
+ logits = outputs.logits
31
+ # model predicts one of the 1000 ImageNet classes
32
+ predicted_class_idx = logits.argmax(-1).item()
33
+ classification_result = model.config.id2label[predicted_class_idx]
34
+ print("Predicted class:", classification_result)
35
+ return classification_result