jays009 commited on
Commit
bf44ad8
·
verified ·
1 Parent(s): 99f6f46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -4
app.py CHANGED
@@ -4,6 +4,9 @@ from torch import nn
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
 
 
 
7
 
8
  # Define the number of classes
9
  num_classes = 2 # Update with the actual number of classes in your dataset (e.g., 2 for healthy and anomalous)
@@ -33,8 +36,30 @@ transform = transforms.Compose([
33
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize the image (ImageNet mean and std)
34
  ])
35
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Define the prediction function
37
- def predict(image):
 
 
 
 
 
 
 
 
 
 
 
38
  # Apply the necessary transformations to the image
39
  image = transform(image).unsqueeze(0) # Add batch dimension
40
  image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Move to GPU if available
@@ -57,9 +82,9 @@ iface = gr.Interface(
57
  inputs=gr.Image(type="pil"), # Image input
58
  outputs=gr.Textbox(), # Output: Predicted class
59
  live=True, # Updates as the user uploads an image
60
- title="maize Anomaly Detection",
61
- description="Upload an image of maize to detect anomalies like disease or pest infestation."
62
  )
63
 
64
  # Launch the Gradio interface
65
- iface.launch(share=True)
 
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
+ import requests
8
+ import base64
9
+ from io import BytesIO
10
 
11
  # Define the number of classes
12
  num_classes = 2 # Update with the actual number of classes in your dataset (e.g., 2 for healthy and anomalous)
 
36
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize the image (ImageNet mean and std)
37
  ])
38
 
39
+ # Function to convert image from URL to PIL image
40
+ def url_to_image(image_url):
41
+ response = requests.get(image_url)
42
+ img = Image.open(BytesIO(response.content))
43
+ return img
44
+
45
+ # Function to convert base64 string to PIL image
46
+ def base64_to_pil(base64_string):
47
+ img_data = base64.b64decode(base64_string)
48
+ return Image.open(BytesIO(img_data))
49
+
50
  # Define the prediction function
51
+ def predict(image_input):
52
+ # If input is a string (URL or base64 encoded), handle accordingly
53
+ if isinstance(image_input, str):
54
+ if image_input.startswith("http"): # If URL
55
+ image = url_to_image(image_input)
56
+ elif image_input.startswith("data:image"): # If base64 string
57
+ image = base64_to_pil(image_input)
58
+ else: # Local image path
59
+ image = Image.open(image_input)
60
+ else:
61
+ image = image_input # If the input is already a PIL image
62
+
63
  # Apply the necessary transformations to the image
64
  image = transform(image).unsqueeze(0) # Add batch dimension
65
  image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Move to GPU if available
 
82
  inputs=gr.Image(type="pil"), # Image input
83
  outputs=gr.Textbox(), # Output: Predicted class
84
  live=True, # Updates as the user uploads an image
85
+ title="Maize Anomaly Detection",
86
+ description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
87
  )
88
 
89
  # Launch the Gradio interface
90
+ iface.launch(share=True) # Share=True to make the app accessible online