# resnet-train / app.py (1.86 kB)
# Author: Sreekanth Tangirala
# Commit 9c8dfb8: "remove pth from tracking"
import os
import tempfile

import torch
from PIL import Image, UnidentifiedImageError
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
def load_model_from_hub(repo_id: str):
    """
    Fetch an image-classification model and its preprocessor from the Hugging Face Hub.

    Args:
        repo_id: Hub repository identifier, e.g. 'username/model-name'.

    Returns:
        A ``(model, processor)`` tuple: the image-classification model and the
        feature extractor used to prepare its inputs. Both are resolved from
        the same repository.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(repo_id)
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    return loaded_model, feature_extractor
def predict(image_path: str, model, processor):
    """
    Run a single image through the classifier and return class probabilities.

    Args:
        image_path: Path to the input image (anything ``PIL.Image.open`` accepts).
        model: Image-classification model whose output exposes ``.logits``.
        processor: Feature extractor/processor that turns the image into
            model inputs (called with ``return_tensors="pt"``).

    Returns:
        Tensor of probabilities: the softmax of the model logits over the last
        dimension. The processor is given a single image, so the leading
        dimension is presumably a batch of size 1 (callers index ``[0]``).
    """
    # Open via a context manager so the file handle is released promptly,
    # even if decoding fails. convert() fully loads the pixels into a new
    # image, so it stays usable after the file is closed. Converting to RGB
    # keeps grayscale, palette or RGBA inputs (e.g. PNGs) from reaching an
    # RGB model with the wrong channel count.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")

    inputs = processor(images=rgb_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.softmax(-1)
# --- Flask application -------------------------------------------------------
from flask import Flask, request

app = Flask(__name__)

# Hugging Face Hub repository the classifier and its processor are loaded from.
MODEL_REPO_ID = "srtangirala/resnet50-exp"

# Loaded once at import time so every request reuses the same model and
# processor instead of reloading them per call.
model, processor = load_model_from_hub(MODEL_REPO_ID)
@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """
    Classify an uploaded image.

    Expects a multipart form upload with the image in the ``file`` field.

    Returns:
        ``{'predictions': [...]}`` with the class probabilities for the
        uploaded image, or ``{'error': ...}`` with HTTP 400 when no file was
        sent or the upload is not a readable image.
    """
    file = request.files.get('file')
    # Browsers submit an empty part (filename '') when no file is chosen.
    if file is None or not file.filename:
        return {'error': 'No file provided'}, 400

    # Give each request its own private temporary directory. A fixed shared
    # filename lets concurrent requests (the dev server is threaded)
    # overwrite each other's upload and receive each other's predictions.
    # It also left the upload on disk forever. The directory and its
    # contents are removed when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, 'upload')
        file.save(image_path)
        try:
            predictions = predict(image_path, model, processor)
        except UnidentifiedImageError:
            # Client sent something PIL cannot decode: a client error, not a 500.
            return {'error': 'Uploaded file is not a valid image'}, 400

    # predict() processes a single image, so take the first (only) row.
    return {'predictions': predictions.tolist()[0]}
if __name__ == '__main__':
    # Debug mode enables the interactive Werkzeug debugger, which can execute
    # arbitrary Python from the browser. It also enables the auto-reloader,
    # which re-runs this script in a child process and loads the model a
    # second time. Make it opt-in via FLASK_DEBUG instead of always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled)