jays009 commited on
Commit
342396f
·
verified ·
1 Parent(s): 0f694f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -36
app.py CHANGED
@@ -3,8 +3,11 @@ import json
3
  import torch
4
  from torch import nn
5
  from torchvision import models, transforms
 
6
  from PIL import Image
 
7
  import os
 
8
 
9
  # Define the number of classes
10
  num_classes = 2
@@ -26,7 +29,7 @@ def load_model(model_path):
26
  model_path = download_model()
27
  model = load_model(model_path)
28
 
29
- # Define transformation for image processing
30
  transform = transforms.Compose([
31
  transforms.Resize(256),
32
  transforms.CenterCrop(224),
@@ -34,53 +37,61 @@ transform = transforms.Compose([
34
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
35
  ])
36
 
37
- # Function to load and preprocess image
38
- def load_image_from_path(image_path):
39
- if not os.path.exists(image_path):
40
- raise FileNotFoundError(f"Image file not found at {image_path}")
41
- image = Image.open(image_path)
42
- image = transform(image).unsqueeze(0) # Convert to tensor and add batch dimension
43
- return image
44
 
45
- # Load the model (Example: ResNet50)
46
- def load_model():
47
- model = models.resnet50(pretrained=True)
48
- model.fc = nn.Linear(model.fc.in_features, num_classes)
49
- model.load_state_dict(torch.load("model.pth"))
50
- model.eval()
51
- return model
52
 
53
- # Predict from image tensor
54
- def predict(image_tensor):
55
  with torch.no_grad():
56
  outputs = model(image_tensor)
57
  predicted_class = torch.argmax(outputs, dim=1).item()
58
- return predicted_class
59
 
60
- # Initialize model
61
- model = load_model()
 
 
 
 
 
62
 
63
- # Define the Gradio interface function
64
- def predict_from_file(file_path):
65
  try:
66
- # Load image from path
67
- image_tensor = load_image_from_path(file_path)
68
- # Get prediction
69
- predicted_class = predict(image_tensor)
70
- result = {"result": "Fall armyworm" if predicted_class == 0 else "Healthy maize"}
71
- return result
 
 
 
 
 
 
 
 
 
72
  except Exception as e:
73
- return {"error": str(e)}
74
 
75
- # Gradio Interface
76
  iface = gr.Interface(
77
- fn=predict_from_file,
78
- inputs=gr.Textbox(label="Image Path (Local)"),
79
- outputs=gr.JSON(),
 
 
 
80
  live=True,
81
  title="Maize Anomaly Detection",
82
- description="Send a local file path via POST request to trigger prediction.",
83
  )
84
 
85
- # Launch the Gradio app
86
- iface.launch(share=True, server_name="0.0.0.0", server_port=7860)
 
3
  import torch
4
  from torch import nn
5
  from torchvision import models, transforms
6
+ from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
+ import requests
9
  import os
10
+ from io import BytesIO
11
 
12
  # Define the number of classes
13
  num_classes = 2
 
29
  model_path = download_model()
30
  model = load_model(model_path)
31
 
32
+ # Define the transformation for the input image
33
  transform = transforms.Compose([
34
  transforms.Resize(256),
35
  transforms.CenterCrop(224),
 
37
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
38
  ])
39
 
40
+ # Function to predict from image content
41
+ def predict_from_image(image):
42
+ # Ensure the image is a PIL Image
43
+ if not isinstance(image, Image.Image):
44
+ raise ValueError("Invalid image format received. Please provide a valid image.")
 
 
45
 
46
+ # Apply transformations
47
+ image_tensor = transform(image).unsqueeze(0)
 
 
 
 
 
48
 
49
+ # Predict
 
50
  with torch.no_grad():
51
  outputs = model(image_tensor)
52
  predicted_class = torch.argmax(outputs, dim=1).item()
 
53
 
54
+ # Interpret the result
55
+ if predicted_class == 0:
56
+ return {"result": "The photo is of fall army worm with problem ID 126."}
57
+ elif predicted_class == 1:
58
+ return {"result": "The photo is of a healthy maize image."}
59
+ else:
60
+ return {"error": "Unexpected class prediction."}
61
 
62
+ # Function to handle image from URL or file path
63
+ def predict_from_url_or_path(url=None, path=None):
64
  try:
65
+ # If URL is provided, fetch and process image
66
+ if url:
67
+ response = requests.get(url)
68
+ response.raise_for_status() # Ensure the request was successful
69
+ image = Image.open(BytesIO(response.content))
70
+ return predict_from_image(image)
71
+
72
+ # If path is provided, open the image from the local path
73
+ elif path:
74
+ if not os.path.exists(path):
75
+ return {"error": f"File not found at {path}"}
76
+ image = Image.open(path)
77
+ return predict_from_image(image)
78
+ else:
79
+ return {"error": "No valid input provided."}
80
  except Exception as e:
81
+ return {"error": f"Failed to process the input: {str(e)}"}
82
 
83
+ # Gradio interface
84
  iface = gr.Interface(
85
+ fn=lambda url, path: predict_from_url_or_path(url=url, path=path),
86
+ inputs=[
87
+ gr.Textbox(label="Enter Image URL", placeholder="Provide a valid image URL (optional)", optional=True),
88
+ gr.Textbox(label="Or Enter Local Image Path", placeholder="Provide the local image path (optional)", optional=True),
89
+ ],
90
+ outputs=gr.JSON(label="Prediction Result"),
91
  live=True,
92
  title="Maize Anomaly Detection",
93
+ description="Provide either an image URL or a local file path to detect anomalies in maize crops.",
94
  )
95
 
96
+ # Launch the interface
97
+ iface.launch(share=True, show_error=True)