ombhojane commited on
Commit
2f7285e
·
verified ·
1 Parent(s): 741f5c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -32
app.py CHANGED
@@ -1,37 +1,56 @@
1
  import streamlit as st
2
-
3
- from transformers import pipeline
4
-
5
- pipe = pipeline("image-classification", model="ombhojane/healthyPlantsModel", backend="pytorch")
6
-
7
- def predict_disease(image):
8
- """Predicts the disease of a plant from an image.
9
-
10
- Args:
11
- image: A NumPy array representing the image.
12
-
13
- Returns:
14
- A string representing the predicted disease.
15
- """
16
-
17
- prediction = pipe(image)
18
- label = prediction[0]["label"]
19
- return label
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def main():
22
- """Creates the Streamlit app."""
23
-
24
- st.title("Plant Disease Detection App")
25
-
26
- # Upload an image
27
- image = st.file_uploader("Upload an image of a plant")
28
-
29
- # Predict the disease
30
- if image is not None:
31
- disease = predict_disease(image)
32
-
33
- # Display the prediction
34
- st.markdown(f"Predicted disease: {disease}")
 
 
 
 
 
 
35
 
36
  if __name__ == "__main__":
37
- main()
 
1
  import streamlit as st
2
+ from PIL import Image
3
+ import numpy as np
4
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
5
+
6
+ def predict_disease(image_file):
7
+ """Predicts the disease of a plant from an image.
8
+
9
+ Args:
10
+ image_file: The uploaded image file.
11
+
12
+ Returns:
13
+ A string representing the predicted disease.
14
+ """
15
+ # Load the model and feature extractor
16
+ model_name = "ombhojane/healthyPlantsModel"
17
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
18
+ model = AutoModelForImageClassification.from_pretrained(model_name)
19
+
20
+ # Process the image
21
+ image = Image.open(image_file).convert("RGB")
22
+ inputs = feature_extractor(images=image, return_tensors="pt")
23
+
24
+ # Make prediction
25
+ outputs = model(**inputs)
26
+ logits = outputs.logits
27
+ predicted_class_idx = logits.argmax(-1).item()
28
+
29
+ # Get the label
30
+ predicted_label = model.config.id2label[predicted_class_idx]
31
+
32
+ return predicted_label
33
 
34
  def main():
35
+ """Creates the Streamlit app."""
36
+ st.title("Plant Disease Detection App")
37
+
38
+ # Upload an image
39
+ image_file = st.file_uploader("Upload an image of a plant", type=["jpg", "jpeg", "png"])
40
+
41
+ # Predict the disease
42
+ if image_file is not None:
43
+ # Display the image
44
+ image = Image.open(image_file)
45
+ st.image(image, caption="Uploaded Plant Image", use_column_width=True)
46
+
47
+ # Add a prediction button
48
+ if st.button("Detect Disease"):
49
+ with st.spinner("Analyzing image..."):
50
+ disease = predict_disease(image_file)
51
+
52
+ # Display the prediction
53
+ st.success(f"Predicted disease: {disease}")
54
 
55
  if __name__ == "__main__":
56
+ main()