# Provenance (header captured from the Hugging Face file page):
# author ShAnSantosh, commit 160cb15 "updated the model", file size 1.78 kB.
import functools
import os
import random

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch

# All inference in this app runs on the CPU.
device = torch.device('cpu')

# Model output position -> disease name, listed in output-index order.
labels = dict(enumerate([
    'bacterial_leaf_blight',
    'bacterial_leaf_streak',
    'bacterial_panicle_blight',
    'blast',
    'brown_spot',
    'dead_heart',
    'downy_mildew',
    'hispa',
    'normal',
    'tungro',
]))
def inference_fn(model, image=None):
    """Score a single preprocessed image with ``model``.

    Puts ``model`` in eval mode, adds a leading batch dimension to ``image``,
    runs one forward pass without gradient tracking and applies an
    element-wise sigmoid to the outputs.

    Args:
        model: A ``torch.nn.Module`` to evaluate.
        image: One image tensor without a batch dimension (``predict`` passes
            a CHW float32 tensor). Required; the ``None`` default is kept only
            for signature compatibility.

    Returns:
        A flat NumPy array holding one sigmoid score per model output.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("inference_fn() requires an image tensor")
    model.eval()
    # Put the input on the same device as the model's weights so the forward
    # pass never mixes devices; a parameter-free model leaves it where it is.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else image.device
    batch = image.to(target_device).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Sigmoid scores are independent per class and need not sum to 1.
    # NOTE(review): if the checkpoint was trained with cross-entropy, softmax
    # would give proper class probabilities -- confirm against training code.
    return logits.sigmoid().cpu().numpy().flatten()
@functools.lru_cache(maxsize=1)
def _build_transform():
    """Return the deterministic resize + ImageNet-normalize pipeline (built once)."""
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    return albumentations.Compose(
        [
            albumentations.Resize(256, 256),
            # Normalize's default p=1.0 already applies it every time, so the
            # deprecated ``always_apply`` flag is unnecessary.
            albumentations.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ]
    )


@functools.lru_cache(maxsize=1)
def _load_model():
    """Create EfficientNet-B0 and load ``paddy_model.pth`` once; later calls reuse it."""
    model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=len(labels))
    # NOTE(review): torch.load unpickles the checkpoint -- only ever point it
    # at a trusted local file.
    model.load_state_dict(torch.load("paddy_model.pth", map_location=device))
    model.to(device)
    return model


def predict(image=None) -> dict:
    """Classify a paddy-leaf image and return a score for every disease label.

    The image is resized to 256x256, normalized with ImageNet mean/std,
    converted to a CHW float32 tensor and scored by the cached model.
    No random flips are applied: augmentation at inference time would make
    repeated predictions for the same image differ.

    Args:
        image: The input picture as an HxWxC array -- presumably uint8 RGB as
            delivered by ``gr.inputs.Image``; confirm for other callers.

    Returns:
        Dict mapping each label name to its score from ``inference_fn``.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("predict() requires an input image")
    transformed = _build_transform()(image=image)["image"]
    # HWC -> CHW, the channel-first layout the network consumes.
    chw = np.transpose(transformed, (2, 0, 1))
    tensor = torch.tensor(chw, dtype=torch.float32)
    scores = inference_fn(_load_model(), tensor)
    return {labels[i]: float(score) for i, score in enumerate(scores)}
# Build the Gradio UI around ``predict`` and start serving it right away
# (this runs at import time, requesting a public share link).
# NOTE(review): gr.inputs / gr.outputs, ``interpretation`` and
# ``capture_session`` are legacy Gradio API (deprecated in 3.x, removed in
# 4.x) -- pin the gradio version or migrate to gr.Image / gr.Label.
example_paths = ["200005.jpg", "200006.jpg"]
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Label(num_top_classes=10),
    examples=example_paths,
    interpretation='default',
    capture_session=True,
)
demo.launch(share=True)