Sreekanth Tangirala commited on
Commit
f173f8e
·
1 Parent(s): 0ab6cdd
Files changed (1) hide show
  1. app.py +47 -18
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
2
  import torch
3
  from PIL import Image
 
4
 
5
  def load_model_from_hub(repo_id: str):
6
  """
@@ -39,26 +40,54 @@ def predict(image_path: str, model, processor):
39
 
40
  return predictions
41
 
42
- # Example usage in your Flask/FastAPI app
43
- from flask import Flask, request
44
- app = Flask(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Load model at startup
47
  model, processor = load_model_from_hub("srtangirala/resnet50-exp")
48
 
49
- @app.route('/predict', methods=['POST'])
50
- def predict_endpoint():
51
- if 'file' not in request.files:
52
- return {'error': 'No file provided'}, 400
53
-
54
- file = request.files['file']
55
- image_path = "temp_image.jpg" # You might want to generate a unique filename
56
- file.save(image_path)
57
-
58
- predictions = predict(image_path, model, processor)
59
-
60
- # Convert predictions to list and return
61
- return {'predictions': predictions.tolist()[0]}
62
 
63
- if __name__ == '__main__':
64
- app.run(debug=True)
 
1
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
2
  import torch
3
  from PIL import Image
4
+ import gradio as gr
5
 
6
  def load_model_from_hub(repo_id: str):
7
  """
 
40
 
41
  return predictions
42
 
43
+ def predict_image(image):
44
+ """
45
+ Gradio interface function for prediction
46
+
47
+ Args:
48
+ image: Image uploaded through Gradio interface
49
+ Returns:
50
+ str: Prediction result with confidence score
51
+ """
52
+ # Convert from numpy array to PIL Image
53
+ if not isinstance(image, Image.Image):
54
+ image = Image.fromarray(image)
55
+
56
+ # Process image and get prediction
57
+ inputs = processor(images=image, return_tensors="pt")
58
+ with torch.no_grad():
59
+ outputs = model(**inputs)
60
+ predictions = outputs.logits.softmax(-1)
61
+
62
+ # Get the top prediction
63
+ pred_scores = predictions[0].tolist()
64
+ top_pred_idx = max(range(len(pred_scores)), key=pred_scores.__getitem__)
65
+ confidence = pred_scores[top_pred_idx]
66
+
67
+ # Get class label
68
+ if hasattr(model.config, 'id2label'):
69
+ label = model.config.id2label[top_pred_idx]
70
+ else:
71
+ label = f"Class {top_pred_idx}"
72
+
73
+ return f"{label} (Confidence: {confidence:.2%})"
74
 
75
  # Load model at startup
76
  model, processor = load_model_from_hub("srtangirala/resnet50-exp")
77
 
78
+ # Create Gradio interface
79
+ iface = gr.Interface(
80
+ fn=predict_image,
81
+ inputs=gr.Image(),
82
+ outputs=gr.Text(),
83
+ title="Image Classification",
84
+ description="Upload an image to classify it!",
85
+ examples=[
86
+ # You can add example images here
87
+ # ["path/to/example1.jpg"],
88
+ # ["path/to/example2.jpg"]
89
+ ]
90
+ )
91
 
92
+ if __name__ == "__main__":
93
+ iface.launch()