File size: 1,862 Bytes
9c8dfb8
de2aabe
 
 
9c8dfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
de2aabe
9c8dfb8
 
 
de2aabe
9c8dfb8
 
 
 
 
 
 
 
 
 
de2aabe
9c8dfb8
de2aabe
9c8dfb8
 
 
 
 
 
 
 
de2aabe
9c8dfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de2aabe
9c8dfb8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image

def load_model_from_hub(repo_id: str):
    """Fetch an image-classification model and its processor from the Hub.

    Args:
        repo_id: Hub repository identifier, e.g. ``'username/model-name'``.

    Returns:
        A ``(model, processor)`` tuple. ``model`` comes from
        ``AutoModelForImageClassification`` and ``processor`` from
        ``AutoFeatureExtractor``, both loaded from ``repo_id``.
    """
    # Tuple elements are evaluated left to right: the model is loaded first,
    # then the processor from the same repository.
    return (
        AutoModelForImageClassification.from_pretrained(repo_id),
        AutoFeatureExtractor.from_pretrained(repo_id),
    )

def predict(image_path: str, model, processor):
    """Run the classifier on one image and return softmax probabilities.

    Args:
        image_path: Path to the input image. Anything ``PIL.Image.open``
            accepts also works, including a binary file-like object such as
            ``io.BytesIO``.
        model: Loaded model. It is called as ``model(**inputs)`` and its
            output must expose a ``logits`` attribute.
        processor: Feature extractor/processor that turns a PIL image into
            the model's keyword inputs (``return_tensors="pt"``).

    Returns:
        A torch tensor holding ``logits.softmax(-1)``, i.e. probabilities over
        the last axis of the logits. Presumably shaped ``(1, num_labels)`` for a
        single image; confirm against the processor in use.

    Raises:
        PIL.UnidentifiedImageError: if the input cannot be decoded as an image.
    """
    # Use a context manager so the file handle PIL opens is released even if
    # preprocessing raises. The original code never closed it.
    with Image.open(image_path) as image:
        # Image.open is lazy, so pixels are read here while the file is still
        # open. Converting to RGB normalizes grayscale, palette and RGBA inputs
        # to 3 channels. NOTE(review): assumes the model expects RGB input,
        # as is typical for the ResNet-style checkpoint loaded at startup.
        inputs = processor(images=image.convert("RGB"), return_tensors="pt")

    # Inference only: disable autograd tracking.
    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = logits.softmax(-1)

    return probabilities

# Example usage: serve the model from a Flask app.
import os

from flask import Flask, request

app = Flask(__name__)

# Load the model once at startup so every request reuses the same instance.
# Set the MODEL_REPO_ID environment variable to serve a different Hub
# repository without editing this file; the default is unchanged.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "srtangirala/resnet50-exp")
model, processor = load_model_from_hub(MODEL_REPO_ID)

@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """Classify an image uploaded as the multipart form field ``file``.

    Returns:
        ``{'predictions': [...]}`` containing the first row of ``predict``'s
        output, or ``({'error': ...}, 400)`` when no file was sent or the upload
        cannot be decoded as an image.
    """
    # Function-scope imports keep this block self-contained.
    import io

    from PIL import UnidentifiedImageError

    if 'file' not in request.files:
        return {'error': 'No file provided'}, 400

    upload = request.files['file']
    # Decode from memory instead of a fixed on-disk path. The previous shared
    # "temp_image.jpg" could be overwritten by a concurrent request between
    # save and predict, which returns another client's result. It was also
    # never deleted. PIL's Image.open accepts file-like objects, so predict()
    # works with the buffer directly.
    image_buffer = io.BytesIO(upload.stream.read())

    try:
        predictions = predict(image_buffer, model, processor)
    except UnidentifiedImageError:
        # Client sent something that is not a decodable image: a 400, not a 500.
        return {'error': 'Uploaded file is not a valid image'}, 400

    # Convert the tensor to plain lists and return the row for the single
    # uploaded image.
    return {'predictions': predictions.tolist()[0]}

if __name__ == '__main__':
    # Debug mode is no longer forced on. The Werkzeug debugger lets anyone who
    # can reach the server execute arbitrary Python, and its reloader re-runs
    # this module, which loads the model a second time. app.run() without a
    # debug argument honors the FLASK_DEBUG environment variable, so local
    # development can opt in with FLASK_DEBUG=1.
    app.run()