yolac commited on
Commit
853350a
·
verified ·
1 Parent(s): f44608a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -20
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
 
6
  import logging
7
 
8
  # Set up logging for debugging
@@ -37,14 +38,11 @@ class BacterialMorphologyClassifier(nn.Module):
37
  # Load the model and weights
38
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
39
  logging.debug("Starting model loading...")
40
- try:
41
- model = BacterialMorphologyClassifier()
42
- state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
43
- model.load_state_dict(state_dict, strict=False)
44
- model.eval()
45
- logging.debug("Model loaded successfully.")
46
- except Exception as e:
47
- logging.error(f"Error loading the model: {str(e)}")
48
 
49
  # Define image preprocessing transformations
50
  transform = transforms.Compose([
@@ -56,10 +54,8 @@ transform = transforms.Compose([
56
  # Define the prediction function
57
  def predict(image):
58
  try:
59
- logging.debug("Starting prediction...")
60
  # Preprocess the image
61
  image_tensor = transform(image).unsqueeze(0)
62
- logging.debug("Image preprocessing completed.")
63
 
64
  # Make prediction
65
  output = model(image_tensor)
@@ -78,14 +74,8 @@ def predict(image):
78
  return "Error", 0.0
79
 
80
  # Create a Gradio interface
81
- gr.Interface(
82
- fn=predict,
83
- inputs=gr.Image(type="pil", label="Upload an image"),
84
- outputs=["text", "number"],
85
- examples=[
86
- ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
87
- ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
88
- ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"]
89
- ]
90
- ).launch(debug=True)
91
 
 
 
 
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
6
+ import io
7
  import logging
8
 
9
  # Set up logging for debugging
 
38
  # Load the model and weights
39
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
40
  logging.debug("Starting model loading...")
41
+ model = BacterialMorphologyClassifier()
42
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
43
+ model.load_state_dict(state_dict, strict=False)
44
+ model.eval()
45
+ logging.debug("Model loaded successfully.")
 
 
 
46
 
47
  # Define image preprocessing transformations
48
  transform = transforms.Compose([
 
54
  # Define the prediction function
55
  def predict(image):
56
  try:
 
57
  # Preprocess the image
58
  image_tensor = transform(image).unsqueeze(0)
 
59
 
60
  # Make prediction
61
  output = model(image_tensor)
 
74
  return "Error", 0.0
75
 
76
  # Create a Gradio interface
77
+ inputs = gr.Image(type="pil", label="Upload an image")
78
+ outputs = gr.Label(num_top_classes=3, label="Predicted Class")
 
 
 
 
 
 
 
 
79
 
80
+ # Launch the Gradio app
81
+ gr.Interface(fn=predict, inputs=inputs, outputs=outputs, live=True).launch(server_name="0.0.0.0", server_port=7861, debug=True)