jays009 commited on
Commit
fa3ae41
·
verified ·
1 Parent(s): 2eeccb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -56
app.py CHANGED
@@ -3,33 +3,13 @@ import json
3
  import torch
4
  from torch import nn
5
  from torchvision import models, transforms
6
- from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
- import requests
9
  import os
10
- from io import BytesIO
11
 
12
  # Define the number of classes
13
  num_classes = 2
14
 
15
- # Download model from Hugging Face
16
- def download_model():
17
- model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
18
- return model_path
19
-
20
- # Load the model from Hugging Face
21
- def load_model(model_path):
22
- model = models.resnet50(pretrained=False)
23
- model.fc = nn.Linear(model.fc.in_features, num_classes)
24
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
25
- model.eval()
26
- return model
27
-
28
- # Download the model and load it
29
- model_path = download_model()
30
- model = load_model(model_path)
31
-
32
- # Define the transformation for the input image
33
  transform = transforms.Compose([
34
  transforms.Resize(256),
35
  transforms.CenterCrop(224),
@@ -37,53 +17,53 @@ transform = transforms.Compose([
37
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
38
  ])
39
 
40
- # Function to predict from image content
41
- def predict_from_image(image):
42
- print(f"Processing image: {image}")
43
- if not isinstance(image, Image.Image):
44
- raise ValueError("Invalid image format received. Please provide a valid image.")
 
 
45
 
46
- # Apply transformations
47
- image_tensor = transform(image).unsqueeze(0)
 
 
 
 
 
48
 
49
- # Predict
 
50
  with torch.no_grad():
51
  outputs = model(image_tensor)
52
  predicted_class = torch.argmax(outputs, dim=1).item()
 
53
 
54
- # Interpret the result
55
- if predicted_class == 0:
56
- return {"result": "The photo is of fall army worm with problem ID 126."}
57
- elif predicted_class == 1:
58
- return {"result": "The photo is of a healthy maize image."}
59
- else:
60
- return {"error": "Unexpected class prediction."}
61
 
62
- # Function to predict from path or URL
63
- def predict_from_path_or_url(path_or_url):
64
  try:
65
- if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
66
- response = requests.get(path_or_url)
67
- response.raise_for_status() # Ensure the request was successful
68
- image = Image.open(BytesIO(response.content))
69
- elif os.path.isfile(path_or_url):
70
- image = Image.open(path_or_url)
71
- else:
72
- return {"error": "Invalid path or URL. Please provide a valid URL or local file path."}
73
-
74
- return predict_from_image(image)
75
  except Exception as e:
76
- return {"error": f"Failed to process the path or URL: {str(e)}"}
77
 
78
- # Gradio interface
79
  iface = gr.Interface(
80
- fn=predict_from_image, # Adjust to handle images only
81
- inputs=[gr.Image(type="pil", label="Upload an Image")],
82
- outputs=gr.JSON(label="Prediction Result"),
83
  live=True,
84
  title="Maize Anomaly Detection",
85
- description="Upload an image to detect anomalies in maize crops.",
86
  )
87
 
88
- # Launch the interface
89
- iface.launch(share=True, show_error=True)
 
3
  import torch
4
  from torch import nn
5
  from torchvision import models, transforms
 
6
  from PIL import Image
 
7
  import os
 
8
 
9
  # Define the number of classes
10
  num_classes = 2
11
 
12
+ # Define transformation for image processing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  transform = transforms.Compose([
14
  transforms.Resize(256),
15
  transforms.CenterCrop(224),
 
17
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
18
  ])
19
 
20
+ # Function to load and preprocess image
21
+ def load_image_from_path(image_path):
22
+ if not os.path.exists(image_path):
23
+ raise FileNotFoundError(f"Image file not found at {image_path}")
24
+ image = Image.open(image_path)
25
+ image = transform(image).unsqueeze(0) # Convert to tensor and add batch dimension
26
+ return image
27
 
28
+ # Load the model (Example: ResNet50)
29
+ def load_model():
30
+ model = models.resnet50(pretrained=True)
31
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
32
+ model.load_state_dict(torch.load("model.pth"))
33
+ model.eval()
34
+ return model
35
 
36
+ # Predict from image tensor
37
+ def predict(image_tensor):
38
  with torch.no_grad():
39
  outputs = model(image_tensor)
40
  predicted_class = torch.argmax(outputs, dim=1).item()
41
+ return predicted_class
42
 
43
+ # Initialize model
44
+ model = load_model()
 
 
 
 
 
45
 
46
+ # Define the Gradio interface function
47
+ def predict_from_file(file_path):
48
  try:
49
+ # Load image from path
50
+ image_tensor = load_image_from_path(file_path)
51
+ # Get prediction
52
+ predicted_class = predict(image_tensor)
53
+ result = {"result": "Fall armyworm" if predicted_class == 0 else "Healthy maize"}
54
+ return result
 
 
 
 
55
  except Exception as e:
56
+ return {"error": str(e)}
57
 
58
+ # Gradio Interface
59
  iface = gr.Interface(
60
+ fn=predict_from_file,
61
+ inputs=gr.Textbox(label="Image Path (Local)"),
62
+ outputs=gr.JSON(),
63
  live=True,
64
  title="Maize Anomaly Detection",
65
+ description="Send a local file path via POST request to trigger prediction.",
66
  )
67
 
68
+ # Launch the Gradio app
69
+ iface.launch(share=True, server_name="0.0.0.0", server_port=7860)