jays009 commited on
Commit
38d7439
·
verified ·
1 Parent(s): 763fc6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -27
app.py CHANGED
@@ -5,43 +5,45 @@ from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
 
8
- num_classes = 2 # Number of classes for your dataset
 
9
 
10
- # Download model weights from Hugging Face
11
  def download_model():
12
- model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
13
  return model_path
14
 
15
- # Load the model from the downloaded weights
16
  def load_model(model_path):
17
- model = models.resnet50(pretrained=False) # Set pretrained=False for custom weights
18
- model.fc = nn.Linear(model.fc.in_features, num_classes) # Adjust final layer for your number of classes
19
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) # Load model weights
20
- model.eval() # Set model to evaluation mode
21
  return model
22
 
23
- # Download and load the model
24
- model_path = download_model()
25
  model = load_model(model_path)
26
 
27
- # Image transformation pipeline
28
  transform = transforms.Compose([
29
  transforms.Resize(256), # Resize the image to 256x256
30
  transforms.CenterCrop(224), # Crop the image to 224x224
31
  transforms.ToTensor(), # Convert the image to a Tensor
32
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize for ImageNet
33
  ])
34
 
35
- # Prediction function
36
  def predict(image):
 
37
  image = transform(image).unsqueeze(0) # Add batch dimension
38
- image = image.to(torch.device("cpu")) # Move the image to CPU (adjust if you want to use GPU)
39
-
40
  with torch.no_grad():
41
  outputs = model(image) # Perform forward pass
42
- predicted_class = torch.argmax(outputs, dim=1).item() # Get the predicted class ID
43
-
44
- # Return appropriate response based on predicted class
45
  if predicted_class == 0:
46
  return "The photo you've sent is of fall army worm with problem ID 126."
47
  elif predicted_class == 1:
@@ -51,14 +53,13 @@ def predict(image):
51
 
52
  # Create the Gradio interface
53
  iface = gr.Interface(
54
- fn=predict, # Prediction function
55
- inputs=gr.Image(type="pil"), # Image input (PIL format)
56
- outputs=gr.Textbox(), # Text output (Predicted class description)
57
- live=True, # Update predictions as the user uploads an image
58
- title="Maize Anomaly Detection",
59
- description="Upload an image of maize to detect anomalies like disease or pest infestation."
60
  )
61
 
62
- # Expose Gradio interface as API endpoint
63
- iface.launch
64
-
 
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
 
8
+ # Define the number of classes
9
+ num_classes = 2 # Update with the actual number of classes in your dataset (e.g., 2 for healthy and anomalous)
10
 
11
+ # Download model from Hugging Face
12
  def download_model():
13
+ model_path = hf_hub_download(repo_id="your_huggingface_username/your_model_name", filename="pytorch_model.bin")
14
  return model_path
15
 
16
+ # Load the model from Hugging Face
17
  def load_model(model_path):
18
+ model = models.resnet50(pretrained=False) # Set pretrained=False because you're loading custom weights
19
+ model.fc = nn.Linear(model.fc.in_features, num_classes) # Adjust for the number of classes in your dataset
20
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) # Load model on CPU for compatibility
21
+ model.eval() # Set to evaluation mode
22
  return model
23
 
24
+ # Download the model and load it
25
+ model_path = download_model() # Downloads the model from Hugging Face Hub
26
  model = load_model(model_path)
27
 
28
+ # Define the transformation for the input image
29
  transform = transforms.Compose([
30
  transforms.Resize(256), # Resize the image to 256x256
31
  transforms.CenterCrop(224), # Crop the image to 224x224
32
  transforms.ToTensor(), # Convert the image to a Tensor
33
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize the image (ImageNet mean and std)
34
  ])
35
 
36
+ # Define the prediction function
37
  def predict(image):
38
+ # Apply the necessary transformations to the image
39
  image = transform(image).unsqueeze(0) # Add batch dimension
40
+ image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Move to GPU if available
41
+
42
  with torch.no_grad():
43
  outputs = model(image) # Perform forward pass
44
+ predicted_class = torch.argmax(outputs, dim=1).item() # Get the predicted class
45
+
46
+ # Create a response based on the predicted class
47
  if predicted_class == 0:
48
  return "The photo you've sent is of fall army worm with problem ID 126."
49
  elif predicted_class == 1:
 
53
 
54
  # Create the Gradio interface
55
  iface = gr.Interface(
56
+ fn=predict, # Function for prediction
57
+ inputs=gr.Image(type="pil"), # Image input
58
+ outputs=gr.Textbox(), # Output: Predicted class
59
+ live=True, # Updates as the user uploads an image
60
+ title="Wheat Anomaly Detection",
61
+ description="Upload an image of wheat to detect anomalies like disease or pest infestation."
62
  )
63
 
64
+ # Launch the Gradio interface
65
+ iface.launch()