Spaces:
Sleeping
Sleeping
from transformers import AutoModelForImageClassification, AutoFeatureExtractor | |
import torch | |
from PIL import Image | |
def load_model_from_hub(repo_id: str):
    """Fetch an image-classification model and its preprocessor from the Hub.

    Args:
        repo_id: Hub repository identifier, e.g. ``'username/model-name'``.

    Returns:
        A ``(model, processor)`` tuple. Both objects are built from the same
        repository via the ``Auto*`` factory classes.
    """
    # NOTE(review): recent transformers releases deprecate AutoFeatureExtractor
    # for vision models in favour of AutoImageProcessor — consider migrating
    # once the installed transformers version is confirmed to provide it.
    # The two loads happen left-to-right: model first, then its preprocessor.
    return (
        AutoModelForImageClassification.from_pretrained(repo_id),
        AutoFeatureExtractor.from_pretrained(repo_id),
    )
def predict(image_path: str, model, processor):
    """Run the classifier on one image and return class probabilities.

    Args:
        image_path: Path to the input image. Any path or binary file-like
            object accepted by ``PIL.Image.open`` works as well.
        model: Image-classification model (e.g. as returned by
            ``load_model_from_hub``).
        processor: Matching preprocessor (e.g. as returned by
            ``load_model_from_hub``).

    Returns:
        Softmax probabilities computed over the last dimension of the model's
        logits (presumably shape ``(1, num_labels)`` for a single image).

    Raises:
        FileNotFoundError: if ``image_path`` is a path that does not exist.
        PIL.UnidentifiedImageError: if the data cannot be decoded as an image.
    """
    # Decode inside a context manager so the file handle is released right
    # away; convert("RGB") forces a full pixel load (so the image outlives the
    # handle) and normalises grayscale/RGBA/palette images to 3 channels.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits.softmax(dim=-1)
# Example usage: serve predictions from a Flask app.
from flask import Flask, request

app = Flask(__name__)

# Hub repository the classifier and its preprocessor are loaded from.
MODEL_REPO_ID = "srtangirala/resnet50-exp"

# Load once at import time so every request reuses the same model objects.
model, processor = load_model_from_hub(MODEL_REPO_ID)
# NOTE(review): the original function was never registered as a route, so it
# was unreachable; "/predict" (POST) is an assumed path — align with clients.
@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """Classify an uploaded image and return its class probabilities.

    Expects a multipart/form-data request with the image under the ``file``
    field. On success responds with ``{'predictions': [...]}`` (the first —
    and only — row of probabilities); otherwise ``{'error': ...}`` with 400.
    """
    from PIL import UnidentifiedImageError  # PIL is already a file dependency

    if 'file' not in request.files:
        return {'error': 'No file provided'}, 400
    file = request.files['file']
    # Browsers send an empty part with an empty filename when nothing is chosen.
    if file.filename == '':
        return {'error': 'No file provided'}, 400

    # Decode straight from the upload stream rather than a shared
    # "temp_image.jpg": a fixed path lets concurrent requests overwrite each
    # other's image and leaves stray files in the working directory.
    try:
        predictions = predict(file.stream, model, processor)
    except UnidentifiedImageError:
        return {'error': 'Uploaded file is not a valid image'}, 400

    # Single-image batch: return only the first row as a plain list.
    return {'predictions': predictions.tolist()[0]}
if __name__ == '__main__':
    # Start Flask's development server. Debug mode is no longer forced on:
    # the Werkzeug debugger can execute arbitrary Python from the browser and
    # must never be exposed. When `debug` is not passed, app.run() reads the
    # FLASK_DEBUG environment variable, so set FLASK_DEBUG=1 to opt in locally.
    app.run()