# NOTE(review): the lines that were here were scraped Hugging Face page chrome
# (Space status text, file size, commit-hash row, line-number gutter) and were
# never part of the Python source; removed so the module parses.
import os
import tempfile

import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
def load_model_from_hub(repo_id: str):
    """Download a model and its matching preprocessor from the Hugging Face Hub.

    Args:
        repo_id: Hub repository identifier (e.g. 'username/model-name').

    Returns:
        A ``(model, processor)`` pair: the image-classification model and the
        feature extractor / processor published alongside it.
    """
    # Both artifacts live in the same repo, so one repo_id resolves both.
    classifier = AutoModelForImageClassification.from_pretrained(repo_id)
    extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    return classifier, extractor
def predict(image_path: str, model, processor):
    """Classify the image at *image_path* with a loaded model/processor pair.

    Args:
        image_path: Path to the input image file.
        model: A loaded ``AutoModelForImageClassification`` (or compatible).
        processor: The matching feature extractor / image processor.

    Returns:
        A ``(1, num_labels)`` tensor of class probabilities (softmax of logits).
    """
    # Context manager releases the file handle; the original Image.open call
    # leaked it. convert("RGB") keeps grayscale/RGBA/palette inputs from
    # crashing processors that expect 3 channels.
    with Image.open(image_path) as image:
        inputs = processor(images=image.convert("RGB"), return_tensors="pt")

    # Inference only — no_grad avoids building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = outputs.logits.softmax(-1)
    return predictions
# Example usage in your Flask/FastAPI app
from flask import Flask, request
app = Flask(__name__)
# Load model at startup
# NOTE(review): this runs at import time and downloads from the Hub on first
# use — confirm that startup latency / network access is acceptable here.
model, processor = load_model_from_hub("srtangirala/resnet50-exp")
@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """POST /predict — classify an uploaded image.

    Expects a multipart form upload under the key ``file``.

    Returns:
        JSON ``{'predictions': [...]}`` with per-class probabilities,
        or ``{'error': ...}`` with HTTP 400 when no file was sent.
    """
    if 'file' not in request.files:
        return {'error': 'No file provided'}, 400
    file = request.files['file']

    # A unique temp file replaces the original fixed "temp_image.jpg", which
    # raced when concurrent requests overwrote each other's upload and was
    # never cleaned up.
    fd, image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)  # file.save reopens the path itself
    try:
        file.save(image_path)
        predictions = predict(image_path, model, processor)
    finally:
        os.remove(image_path)  # always remove the upload, even on failure

    # Convert predictions to a plain list for JSON serialization
    return {'predictions': predictions.tolist()[0]}
if __name__ == '__main__':
    # Stray trailing "|" (page-extraction artifact) removed — it was a syntax
    # error. debug=True enables the reloader and the interactive debugger;
    # never leave it on in production.
    app.run(debug=True)