ombhojane commited on
Commit
bfc813d
·
verified ·
1 Parent(s): 2f7285e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -21
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import streamlit as st
2
  from PIL import Image
3
  import numpy as np
4
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 
5
 
6
  def predict_disease(image_file):
7
- """Predicts the disease of a plant from an image.
8
 
9
  Args:
10
  image_file: The uploaded image file.
@@ -12,24 +13,27 @@ def predict_disease(image_file):
12
  Returns:
13
  A string representing the predicted disease.
14
  """
15
- # Load the model and feature extractor
16
- model_name = "ombhojane/healthyPlantsModel"
17
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
18
- model = AutoModelForImageClassification.from_pretrained(model_name)
19
-
20
- # Process the image
21
- image = Image.open(image_file).convert("RGB")
22
- inputs = feature_extractor(images=image, return_tensors="pt")
23
-
24
- # Make prediction
25
- outputs = model(**inputs)
26
- logits = outputs.logits
27
- predicted_class_idx = logits.argmax(-1).item()
28
-
29
- # Get the label
30
- predicted_label = model.config.id2label[predicted_class_idx]
31
-
32
- return predicted_label
 
 
 
33
 
34
  def main():
35
  """Creates the Streamlit app."""
@@ -50,7 +54,10 @@ def main():
50
  disease = predict_disease(image_file)
51
 
52
  # Display the prediction
53
- st.success(f"Predicted disease: {disease}")
 
 
 
54
 
55
  if __name__ == "__main__":
56
  main()
 
1
  import streamlit as st
2
  from PIL import Image
3
  import numpy as np
4
+ import tensorflow as tf
5
+ from transformers import AutoFeatureExtractor, TFAutoModelForImageClassification
6
 
7
  def predict_disease(image_file):
8
+ """Predicts the disease of a plant from an image using TensorFlow.
9
 
10
  Args:
11
  image_file: The uploaded image file.
 
13
  Returns:
14
  A string representing the predicted disease.
15
  """
16
+ try:
17
+ # Load the model and feature extractor
18
+ model_name = "ombhojane/healthyPlantsModel"
19
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
20
+ model = TFAutoModelForImageClassification.from_pretrained(model_name)
21
+
22
+ # Process the image
23
+ image = Image.open(image_file).convert("RGB")
24
+ inputs = feature_extractor(images=image, return_tensors="tf")
25
+
26
+ # Make prediction
27
+ outputs = model(**inputs)
28
+ logits = outputs.logits
29
+ predicted_class_idx = tf.argmax(logits, axis=-1).numpy()[0]
30
+
31
+ # Get the label
32
+ predicted_label = model.config.id2label[predicted_class_idx]
33
+
34
+ return predicted_label
35
+ except Exception as e:
36
+ return f"Error: {str(e)}"
37
 
38
  def main():
39
  """Creates the Streamlit app."""
 
54
  disease = predict_disease(image_file)
55
 
56
  # Display the prediction
57
+ if disease.startswith("Error"):
58
+ st.error(disease)
59
+ else:
60
+ st.success(f"Predicted disease: {disease}")
61
 
62
  if __name__ == "__main__":
63
  main()