jays009 commited on
Commit
b77b937
·
verified ·
1 Parent(s): 0c6fbcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -50
app.py CHANGED
@@ -7,7 +7,6 @@ from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
9
  import os
10
- import base64
11
  from io import BytesIO
12
 
13
  # Define the number of classes
@@ -15,28 +14,20 @@ num_classes = 2
15
 
16
  # Download model from Hugging Face
17
  def download_model():
18
- try:
19
- model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
20
- return model_path
21
- except Exception as e:
22
- print(f"Error downloading model: {e}")
23
- return None
24
 
25
  # Load the model from Hugging Face
26
  def load_model(model_path):
27
- try:
28
- model = models.resnet50(pretrained=False)
29
- model.fc = nn.Linear(model.fc.in_features, num_classes)
30
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
31
- model.eval()
32
- return model
33
- except Exception as e:
34
- print(f"Error loading model: {e}")
35
- return None
36
 
37
  # Download the model and load it
38
  model_path = download_model()
39
- model = load_model(model_path) if model_path else None
40
 
41
  # Define the transformation for the input image
42
  transform = transforms.Compose([
@@ -46,48 +37,50 @@ transform = transforms.Compose([
46
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
47
  ])
48
 
49
- def predict(image):
50
- try:
51
- print(f"Received image input: {image}")
52
-
53
- # Check if the input is a PIL Image type (Gradio automatically provides a PIL image)
54
- if not isinstance(image, Image.Image):
55
- return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
56
 
57
- # Apply transformations to the image
58
- image = transform(image).unsqueeze(0)
59
- print(f"Transformed image tensor: {image.shape}")
60
 
61
- # Move the image to the correct device
62
- image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
 
63
 
64
- # Make predictions
65
- with torch.no_grad():
66
- outputs = model(image)
67
- predicted_class = torch.argmax(outputs, dim=1).item()
68
- print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")
69
-
70
- # Return the result based on the predicted class
71
- if predicted_class == 0:
72
- return json.dumps({"result": "The photo you've sent is of fall army worm with problem ID 126."})
73
- elif predicted_class == 1:
74
- return json.dumps({"result": "The photo you've sent is of a healthy maize image."})
75
- else:
76
- return json.dumps({"error": "Unexpected class prediction."})
77
 
 
 
 
 
 
 
 
78
  except Exception as e:
79
- print(f"Error processing image: {e}")
80
- return json.dumps({"error": f"Error processing image: {e}"})
81
 
82
- # Create the Gradio interface
83
  iface = gr.Interface(
84
- fn=predict,
85
- inputs=gr.Image(type="pil", label="Upload an image or provide a URL or local path"),
86
- outputs=gr.Textbox(label="Prediction Result"),
 
 
 
87
  live=True,
88
  title="Maize Anomaly Detection",
89
- description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
90
  )
91
 
92
- # Launch the Gradio interface
93
  iface.launch(share=True, show_error=True)
 
7
  from PIL import Image
8
  import requests
9
  import os
 
10
  from io import BytesIO
11
 
12
  # Define the number of classes
 
14
 
15
  # Download model from Hugging Face
16
  def download_model():
17
+ model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
18
+ return model_path
 
 
 
 
19
 
20
  # Load the model from Hugging Face
21
  def load_model(model_path):
22
+ model = models.resnet50(pretrained=False)
23
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
24
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
25
+ model.eval()
26
+ return model
 
 
 
 
27
 
28
  # Download the model and load it
29
  model_path = download_model()
30
+ model = load_model(model_path)
31
 
32
  # Define the transformation for the input image
33
  transform = transforms.Compose([
 
37
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
38
  ])
39
 
40
+ # Function to predict from image content
41
+ def predict_from_image(image):
42
+ # Ensure the image is a PIL Image
43
+ if not isinstance(image, Image.Image):
44
+ raise ValueError("Invalid image format received. Please provide a valid image.")
 
 
45
 
46
+ # Apply transformations
47
+ image_tensor = transform(image).unsqueeze(0)
 
48
 
49
+ # Predict
50
+ with torch.no_grad():
51
+ outputs = model(image_tensor)
52
+ predicted_class = torch.argmax(outputs, dim=1).item()
53
 
54
+ # Interpret the result
55
+ if predicted_class == 0:
56
+ return {"result": "The photo is of fall army worm with problem ID 126."}
57
+ elif predicted_class == 1:
58
+ return {"result": "The photo is of a healthy maize image."}
59
+ else:
60
+ return {"error": "Unexpected class prediction."}
 
 
 
 
 
 
61
 
62
+ # Function to predict from URL
63
+ def predict_from_url(url):
64
+ try:
65
+ response = requests.get(url)
66
+ response.raise_for_status() # Ensure the request was successful
67
+ image = Image.open(BytesIO(response.content))
68
+ return predict_from_image(image)
69
  except Exception as e:
70
+ return {"error": f"Failed to process the URL: {str(e)}"}
 
71
 
72
+ # Gradio interface
73
  iface = gr.Interface(
74
+ fn=lambda image, url: predict_from_image(image) if image else predict_from_url(url),
75
+ inputs=[
76
+ gr.Image(type="pil", label="Upload an Image"),
77
+ gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL"),
78
+ ],
79
+ outputs=gr.JSON(label="Prediction Result"),
80
  live=True,
81
  title="Maize Anomaly Detection",
82
+ description="Upload an image or provide a URL to detect anomalies in maize crops.",
83
  )
84
 
85
+ # Launch the interface
86
  iface.launch(share=True, show_error=True)