yolac commited on
Commit
c8491fd
·
verified ·
1 Parent(s): d9d91ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -28
app.py CHANGED
@@ -1,10 +1,10 @@
1
- from flask import Flask, request, jsonify
2
  import torch
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
6
- import io
7
- from torch.hub import load_state_dict_from_url
 
8
 
9
  # Define the model architecture
10
  class BacterialMorphologyClassifier(nn.Module):
@@ -32,48 +32,61 @@ class BacterialMorphologyClassifier(nn.Module):
32
  x = self.fc(x)
33
  return x
34
 
35
- # Load the model and weights
 
36
  model = BacterialMorphologyClassifier()
37
- MODEL_URL = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
38
- state_dict = load_state_dict_from_url(MODEL_URL, map_location=torch.device('cpu'))
39
- model.load_state_dict(state_dict, strict=False)
40
- model.eval()
41
 
42
- # Set up Flask app
43
- app = Flask(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # Define image preprocessing transformations
46
  transform = transforms.Compose([
47
  transforms.Resize((224, 224)),
48
  transforms.ToTensor(),
49
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
50
  ])
51
 
52
- @app.route('/predict', methods=['POST'])
53
- def predict():
54
  try:
55
- # Get image from request
56
- image_file = request.files['image']
57
- image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
58
-
59
- # Preprocess the image
60
  image_tensor = transform(image).unsqueeze(0)
61
 
62
- # Make prediction
63
  output = model(image_tensor)
64
  prediction = output.argmax().item()
65
 
66
  # Class mapping
67
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
68
 
69
- # Return prediction result
70
- response = {
71
- 'predicted_class': class_labels[prediction],
72
- 'confidence': output.max().item()
73
- }
74
- return jsonify(response)
75
  except Exception as e:
76
- return jsonify({'error': str(e)})
 
 
 
 
 
 
 
 
 
77
 
78
- if __name__ == '__main__':
79
- app.run(host='0.0.0.0', port=5000, debug=False)
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import transforms
4
  from PIL import Image
5
+ import requests
6
+ import gradio as gr
7
+ import os
8
 
9
  # Define the model architecture
10
  class BacterialMorphologyClassifier(nn.Module):
 
32
  x = self.fc(x)
33
  return x
34
 
35
+ # Load the model
36
+ MODEL_PATH = "model.pth"
37
  model = BacterialMorphologyClassifier()
 
 
 
 
38
 
39
+ try:
40
+ # Download the model if it doesn't exist
41
+ if not os.path.exists(MODEL_PATH):
42
+ print("Downloading the model...")
43
+ url = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
44
+ response = requests.get(url)
45
+ with open(MODEL_PATH, "wb") as f:
46
+ f.write(response.content)
47
+ # Load the model weights
48
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')), strict=False)
49
+ model.eval()
50
+ print("Model loaded successfully.")
51
+ except Exception as e:
52
+ print(f"Error loading the model: {e}")
53
 
54
+ # Define image preprocessing
55
  transform = transforms.Compose([
56
  transforms.Resize((224, 224)),
57
  transforms.ToTensor(),
58
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
59
  ])
60
 
61
+ # Prediction function
62
+ def predict(image):
63
  try:
64
+ # Convert the image to a tensor
 
 
 
 
65
  image_tensor = transform(image).unsqueeze(0)
66
 
67
+ # Perform prediction
68
  output = model(image_tensor)
69
  prediction = output.argmax().item()
70
 
71
  # Class mapping
72
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
73
 
74
+ # Return the predicted class and confidence
75
+ predicted_class = class_labels[prediction]
76
+ confidence = output.max().item()
77
+ return f"Predicted Class: {predicted_class}\nConfidence: {confidence:.2f}"
 
 
78
  except Exception as e:
79
+ return f"Error: {str(e)}"
80
+
81
+ # Set up Gradio interface
82
+ interface = gr.Interface(
83
+ fn=predict,
84
+ inputs=gr.inputs.Image(type="pil"),
85
+ outputs="text",
86
+ title="Bacterial Morphology Classification",
87
+ description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
88
+ )
89
 
90
+ # Launch the app
91
+ if __name__ == "__main__":
92
+ interface.launch()