yigagilbert commited on
Commit
0e247ee
·
verified ·
1 Parent(s): 2187529

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +28 -0
inference.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
3
+ from PIL import Image
4
+ import json
5
+
6
+ # Load metadata from the JSON file
7
+ with open('metadata.json') as f:
8
+ metadata = json.load(f)
9
+
10
+ def predict(image_path: str):
11
+ # Load the fine-tuned model and feature extractor from Hugging Face Hub
12
+ model = ViTForImageClassification.from_pretrained("yigagilbert/image-quality-model")
13
+ feature_extractor = ViTFeatureExtractor.from_pretrained("yigagilbert/image-quality-model")
14
+
15
+ # Open and preprocess the image
16
+ image = Image.open(image_path)
17
+ inputs = feature_extractor(images=image, return_tensors="pt")
18
+
19
+ # Perform inference
20
+ with torch.no_grad():
21
+ outputs = model(**inputs)
22
+ predicted_value = outputs.logits.squeeze().item()
23
+
24
+ # Scale the predicted value to match the dataset's max value
25
+ max_value = metadata.get('max_value', 1.0) # Default to 1.0 if not found
26
+ predicted_value_scaled = predicted_value * max_value
27
+
28
+ return predicted_value_scaled