# image-qaulity-model / inference.py
# (Hugging Face page residue) yigagilbert's picture
# Commit: Create inference.py
# Commit 0e247ee (verified)
import functools
import json

import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor
# Load dataset metadata from the JSON file; predict() reads its 'max_value'
# entry to rescale model outputs.
# NOTE(review): the relative path is resolved against the current working
# directory, not this file's directory — run from the folder containing
# metadata.json (or make the path absolute).
# JSON is UTF-8 by specification, so decode it explicitly rather than with the
# platform's locale encoding.
with open('metadata.json', encoding='utf-8') as f:
    metadata = json.load(f)
_MODEL_ID = "yigagilbert/image-quality-model"


@functools.lru_cache(maxsize=1)
def _load_model_and_extractor():
    """Load the fine-tuned ViT model and its feature extractor once per process.

    Cached so repeated ``predict`` calls reuse the same objects instead of
    re-loading (and possibly re-downloading) the weights on every call.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    model = ViTForImageClassification.from_pretrained(_MODEL_ID)
    feature_extractor = ViTFeatureExtractor.from_pretrained(_MODEL_ID)
    return model, feature_extractor


def predict(image_path: str) -> float:
    """Predict a quality score for the image stored at *image_path*.

    The model output is expected to be a single value; it is multiplied by
    ``metadata['max_value']`` (1.0 when that key is absent) to map it back to
    the dataset's scale.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        The scaled prediction as a Python float.

    Raises:
        Errors from ``Image.open`` (e.g. ``FileNotFoundError`` for a missing
        path) propagate unchanged, as does the error from ``.item()`` if the
        model produces more than one output value.
    """
    model, feature_extractor = _load_model_and_extractor()

    # Use a context manager so the underlying file handle is closed. convert()
    # forces the pixel data to load before the file closes and turns grayscale,
    # RGBA or palette images into 3-channel RGB.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # .item() requires a single-element tensor, i.e. a one-output regression head.
    predicted_value = outputs.logits.squeeze().item()

    # NOTE(review): assumes training targets were normalized by max_value —
    # confirm against the training script.
    max_value = metadata.get('max_value', 1.0)  # Default to 1.0 if not found
    return predicted_value * max_value