saritha commited on
Commit
b4563e1
·
verified ·
1 Parent(s): bb713bd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import warnings
6
+ import sys
7
+ import os
8
+ import contextlib
9
+ from transformers import ViTForImageClassification
10
+
11
+ # Suppress warnings related to the model weights initialization, FutureWarning and UserWarnings
12
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
13
+ warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
14
+
15
+ # Suppress output for copying files and verbose model initialization messages
16
+ @contextlib.contextmanager
17
+ def suppress_stdout():
18
+ with open(os.devnull, 'w') as devnull:
19
+ old_stdout = sys.stdout
20
+ sys.stdout = devnull
21
+ try:
22
+ yield
23
+ finally:
24
+ sys.stdout = old_stdout
25
+
26
+ # Load the saved model and suppress the warnings
27
+ with suppress_stdout():
28
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
29
+ model.load_state_dict(torch.load('/kaggle/working/vit_sugarcane_disease_detection.pth'))
30
+ model.eval()
31
+
32
+ # Define the same transformation used during training
33
+ transform = transforms.Compose([
34
+ transforms.Resize((224, 224)),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
37
+ ])
38
+
39
+ # Load the class names (disease types)
40
+ class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
41
+
42
+ # Function to predict disease type from an image
43
+ def predict_disease(image_path):
44
+ # Open the image file
45
+ img = Image.open(image_path)
46
+
47
+ # Apply transformations to the image
48
+ img_tensor = transform(img).unsqueeze(0) # Add batch dimension
49
+
50
+ # Make prediction
51
+ with torch.no_grad():
52
+ outputs = model(img_tensor)
53
+ _, predicted_class = torch.max(outputs.logits, 1)
54
+
55
+ # Get the predicted label
56
+ predicted_label = class_names[predicted_class.item()]
57
+
58
+ return predicted_label
59
+
60
+ # Test with a new image
61
+ image_path = '/kaggle/input/sugarcane-test-images/zoomed_healthy (9).jpeg' # Replace with your image path
62
+ predicted_label = predict_disease(image_path)
63
+ print(f'The predicted disease type is: {predicted_label}')