jays009 commited on
Commit
5cadf06
·
verified ·
1 Parent(s): ee2271a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -24
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import torch
3
  from torch import nn
4
  from torchvision import models, transforms
 
5
  from PIL import Image
6
  import requests
7
  import base64
@@ -11,19 +12,30 @@ import os
11
  # Define the number of classes
12
  num_classes = 2 # Update with the actual number of classes in your dataset
13
 
14
- # Load the model (assuming you've already downloaded it)
15
- def load_model():
 
 
 
 
 
 
 
 
 
16
  try:
17
  model = models.resnet50(pretrained=False)
18
  model.fc = nn.Linear(model.fc.in_features, num_classes)
19
- model.load_state_dict(torch.load("path_to_your_model.pth", map_location=torch.device("cpu")))
20
  model.eval()
21
  return model
22
  except Exception as e:
23
  print(f"Error loading model: {e}")
24
  return None
25
 
26
- model = load_model()
 
 
27
 
28
  # Define the transformation for the input image
29
  transform = transforms.Compose([
@@ -33,41 +45,54 @@ transform = transforms.Compose([
33
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
34
  ])
35
 
36
- # Prediction function
37
- def process_image(image, image_url=None):
38
  try:
39
- # Ensure that the image is not None
40
- if image is None and not image_url:
41
- return "No image or URL provided."
42
 
43
- # Handle URL-based image loading
44
- if image_url:
45
  try:
46
- response = requests.get(image_url)
47
- response.raise_for_status() # Raise an error if the request fails
 
 
 
 
 
 
 
 
 
48
  image = Image.open(BytesIO(response.content))
 
49
  except Exception as e:
 
50
  return f"Error fetching image from URL: {e}"
51
 
52
- # Handle local file path image loading (Gradio File input)
53
  elif isinstance(image, str) and os.path.isfile(image):
54
  try:
55
  image = Image.open(image)
 
56
  except Exception as e:
 
57
  return f"Error loading image from local path: {e}"
58
 
59
- # Validate that the image is loaded correctly
60
  if not isinstance(image, Image.Image):
 
61
  return "Invalid image format received."
62
 
63
  # Apply transformations
64
  image = transform(image).unsqueeze(0)
 
65
 
66
- # Prediction
67
  image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
68
  with torch.no_grad():
69
  outputs = model(image)
70
  predicted_class = torch.argmax(outputs, dim=1).item()
 
71
 
72
  if predicted_class == 0:
73
  return "The photo you've sent is of fall army worm with problem ID 126."
@@ -76,20 +101,18 @@ def process_image(image, image_url=None):
76
  else:
77
  return "Unexpected class prediction."
78
  except Exception as e:
 
79
  return f"Error processing image: {e}"
80
 
81
  # Create the Gradio interface
82
  iface = gr.Interface(
83
- fn=process_image,
84
- inputs=[
85
- gr.File(label="Upload an image (Local File Path)"), # Input: Local file
86
- gr.Textbox(label="Enter Image URL", placeholder="Enter image URL here", lines=1) # Input: Image URL
87
- ],
88
- outputs=gr.Textbox(label="Prediction Result"), # Output: Prediction result
89
  live=True,
90
  title="Maize Anomaly Detection",
91
- description="Upload an image of maize to detect anomalies like disease or pest infestation. You can upload local images or provide an image URL."
92
  )
93
 
94
  # Launch the Gradio interface
95
- iface.launch(share=True, show_error=True)
 
2
  import torch
3
  from torch import nn
4
  from torchvision import models, transforms
5
+ from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import requests
8
  import base64
 
12
  # Define the number of classes
13
  num_classes = 2 # Update with the actual number of classes in your dataset
14
 
15
+ # Download model from Hugging Face
16
+ def download_model():
17
+ try:
18
+ model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
19
+ return model_path
20
+ except Exception as e:
21
+ print(f"Error downloading model: {e}")
22
+ return None
23
+
24
+ # Load the model from Hugging Face
25
+ def load_model(model_path):
26
  try:
27
  model = models.resnet50(pretrained=False)
28
  model.fc = nn.Linear(model.fc.in_features, num_classes)
29
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
30
  model.eval()
31
  return model
32
  except Exception as e:
33
  print(f"Error loading model: {e}")
34
  return None
35
 
36
+ # Download the model and load it
37
+ model_path = download_model()
38
+ model = load_model(model_path) if model_path else None
39
 
40
  # Define the transformation for the input image
41
  transform = transforms.Compose([
 
45
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
46
  ])
47
 
48
+ def predict(image):
 
49
  try:
50
+ print(f"Received image input: {image}")
 
 
51
 
52
+ # Check if the input contains a base64-encoded string
53
+ if isinstance(image, dict) and image.get("data"):
54
  try:
55
+ image_data = base64.b64decode(image["data"])
56
+ image = Image.open(BytesIO(image_data))
57
+ print(f"Decoded base64 image: {image}")
58
+ except Exception as e:
59
+ print(f"Error decoding base64 image: {e}")
60
+ return f"Error decoding base64 image: {e}"
61
+
62
+ # Check if the input is a URL
63
+ elif isinstance(image, str) and image.startswith("http"):
64
+ try:
65
+ response = requests.get(image)
66
  image = Image.open(BytesIO(response.content))
67
+ print(f"Fetched image from URL: {image}")
68
  except Exception as e:
69
+ print(f"Error fetching image from URL: {e}")
70
  return f"Error fetching image from URL: {e}"
71
 
72
+ # Check if the input is a local file path
73
  elif isinstance(image, str) and os.path.isfile(image):
74
  try:
75
  image = Image.open(image)
76
+ print(f"Loaded image from local path: {image}")
77
  except Exception as e:
78
+ print(f"Error loading image from local path: {e}")
79
  return f"Error loading image from local path: {e}"
80
 
81
+ # Validate that the image is correctly loaded
82
  if not isinstance(image, Image.Image):
83
+ print("Invalid image format received.")
84
  return "Invalid image format received."
85
 
86
  # Apply transformations
87
  image = transform(image).unsqueeze(0)
88
+ print(f"Transformed image tensor: {image.shape}")
89
 
 
90
  image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
91
+
92
  with torch.no_grad():
93
  outputs = model(image)
94
  predicted_class = torch.argmax(outputs, dim=1).item()
95
+ print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")
96
 
97
  if predicted_class == 0:
98
  return "The photo you've sent is of fall army worm with problem ID 126."
 
101
  else:
102
  return "Unexpected class prediction."
103
  except Exception as e:
104
+ print(f"Error processing image: {e}")
105
  return f"Error processing image: {e}"
106
 
107
  # Create the Gradio interface
108
  iface = gr.Interface(
109
+ fn=predict,
110
+ inputs=gr.Image(type="pil", label="Upload an image or provide a URL or local path"), # Input: Image, URL, or Local Path
111
+ outputs=gr.Textbox(label="Prediction Result"), # Output: Predicted class
 
 
 
112
  live=True,
113
  title="Maize Anomaly Detection",
114
+ description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
115
  )
116
 
117
  # Launch the Gradio interface
118
+ iface.launch(share=True, show_error=True)