jays009 commited on
Commit
2255b93
·
verified ·
1 Parent(s): 6981478

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -38
app.py CHANGED
@@ -9,67 +9,84 @@ import base64
9
  from io import BytesIO
10
 
11
  # Define the number of classes
12
- num_classes = 2 # Update with the actual number of classes in your dataset (e.g., 2 for healthy and anomalous)
13
 
14
  # Download model from Hugging Face
15
  def download_model():
16
- model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
17
- return model_path
 
 
 
 
18
 
19
  # Load the model from Hugging Face
20
  def load_model(model_path):
21
- model = models.resnet50(pretrained=False) # Set pretrained=False because you're loading custom weights
22
- model.fc = nn.Linear(model.fc.in_features, num_classes) # Adjust for the number of classes in your dataset
23
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) # Load model on CPU for compatibility
24
- model.eval() # Set to evaluation mode
25
- return model
 
 
 
 
26
 
27
  # Download the model and load it
28
- model_path = download_model() # Downloads the model from Hugging Face Hub
29
- model = load_model(model_path)
30
 
31
  # Define the transformation for the input image
32
  transform = transforms.Compose([
33
- transforms.Resize(256), # Resize the image to 256x256
34
- transforms.CenterCrop(224), # Crop the image to 224x224
35
- transforms.ToTensor(), # Convert the image to a Tensor
36
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize the image (ImageNet mean and std)
37
  ])
38
 
39
-
40
  def predict(image):
41
  # Check if the input contains a base64-encoded string
42
  if isinstance(image, dict) and image.get("data"):
43
- # Decode the base64 string into a PIL image
44
- image_data = base64.b64decode(image["data"])
45
- image = Image.open(BytesIO(image_data))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Apply your existing transformations
48
- image = transform(image).unsqueeze(0) # Transform and add batch dimension
49
- image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
50
 
51
- # Perform inference
52
- with torch.no_grad():
53
- outputs = model(image)
54
- predicted_class = torch.argmax(outputs, dim=1).item()
55
-
56
- # Create a response based on the predicted class
57
- if predicted_class == 0:
58
- return "The photo you've sent is of fall army worm with problem ID 126."
59
- elif predicted_class == 1:
60
- return "The photo you've sent is of a healthy maize image."
61
- else:
62
- return "Unexpected class prediction."
63
 
64
  # Create the Gradio interface
65
  iface = gr.Interface(
66
- fn=predict, # Function for prediction
67
- inputs=gr.Image(type="pil"), # Image input
68
- outputs=gr.Textbox(), # Output: Predicted class
69
- live=True, # Updates as the user uploads an image
70
  title="Maize Anomaly Detection",
71
  description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
72
  )
73
 
74
  # Launch the Gradio interface
75
- iface.launch(share=True)
 
9
  from io import BytesIO
10
 
11
  # Define the number of classes
12
+ num_classes = 2 # Update with the actual number of classes in your dataset
13
 
14
  # Download model from Hugging Face
15
  def download_model():
16
+ try:
17
+ model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
18
+ return model_path
19
+ except Exception as e:
20
+ print(f"Error downloading model: {e}")
21
+ return None
22
 
23
  # Load the model from Hugging Face
24
  def load_model(model_path):
25
+ try:
26
+ model = models.resnet50(pretrained=False)
27
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
28
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
29
+ model.eval()
30
+ return model
31
+ except Exception as e:
32
+ print(f"Error loading model: {e}")
33
+ return None
34
 
35
  # Download the model and load it
36
+ model_path = download_model()
37
+ model = load_model(model_path) if model_path else None
38
 
39
  # Define the transformation for the input image
40
  transform = transforms.Compose([
41
+ transforms.Resize(256),
42
+ transforms.CenterCrop(224),
43
+ transforms.ToTensor(),
44
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
45
  ])
46
 
 
47
  def predict(image):
48
  # Check if the input contains a base64-encoded string
49
  if isinstance(image, dict) and image.get("data"):
50
+ try:
51
+ image_data = base64.b64decode(image["data"])
52
+ image = Image.open(BytesIO(image_data))
53
+ except Exception as e:
54
+ return f"Error decoding base64 image: {e}"
55
+
56
+ elif isinstance(image, str):
57
+ try:
58
+ response = requests.get(image)
59
+ image = Image.open(BytesIO(response.content))
60
+ except Exception as e:
61
+ return f"Error fetching image from URL: {e}"
62
+
63
+ # Apply transformations
64
+ try:
65
+ image = transform(image).unsqueeze(0)
66
+ image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
67
 
68
+ with torch.no_grad():
69
+ outputs = model(image)
70
+ predicted_class = torch.argmax(outputs, dim=1).item()
71
 
72
+ if predicted_class == 0:
73
+ return "The photo you've sent is of fall army worm with problem ID 126."
74
+ elif predicted_class == 1:
75
+ return "The photo you've sent is of a healthy maize image."
76
+ else:
77
+ return "Unexpected class prediction."
78
+ except Exception as e:
79
+ return f"Error processing image: {e}"
 
 
 
 
80
 
81
  # Create the Gradio interface
82
  iface = gr.Interface(
83
+ fn=predict,
84
+ inputs=gr.Image(type="pil"),
85
+ outputs=gr.Textbox(),
86
+ live=True,
87
  title="Maize Anomaly Detection",
88
  description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
89
  )
90
 
91
  # Launch the Gradio interface
92
+ iface.launch(share=True)