File size: 3,935 Bytes
2201868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aab569a
2255b93
b4d05af
aab569a
 
 
b4d05af
aab569a
 
b4d05af
66345ab
aab569a
5649d80
b4d05af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc29cbf
b4d05af
95250f9
5cadf06
95250f9
5cadf06
95250f9
 
 
5cadf06
95250f9
b4d05af
95250f9
5649d80
95250f9
5649d80
2255b93
5649d80
b4d05af
2255b93
5cadf06
5649d80
9dfc63c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import base64
from io import BytesIO
import os

# Size of the classifier's output layer; predict() maps class 0 to fall
# armyworm and class 1 to a healthy maize plant.
num_classes: int = 2

# Download model from Hugging Face
def download_model(repo_id="jays009/Restnet50", filename="pytorch_model.bin"):
    """Fetch the classifier checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds the checkpoint. Defaults to the
            repository this app was built against.
        filename: Name of the checkpoint file inside ``repo_id``.

    Returns:
        The local file path returned by ``hf_hub_download``, or ``None`` if the
        download failed. The error is printed, not raised.
    """
    try:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        # Best-effort at import time: the caller treats None as "no model".
        print(f"Error downloading model: {e}")
        return None

# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        model_path: Path to a ``state_dict`` saved for this exact architecture
            (ResNet-50 backbone whose ``fc`` layer outputs ``num_classes`` logits).

    Returns:
        The model in eval mode on CPU, or ``None`` if construction or loading
        failed. The error is printed, not raised.
    """
    try:
        # weights=None gives a randomly initialised backbone (the checkpoint below
        # supplies the real weights); it replaces the deprecated pretrained=False.
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        # The checkpoint comes from a remote repository, so refuse to unpickle
        # anything other than tensors/containers (weights_only=True).
        state_dict = torch.load(
            model_path,
            map_location=torch.device("cpu"),
            weights_only=True,
        )
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Fetch the checkpoint once at import time and build the model from it;
# when no checkpoint path is available, `model` is left as None.
model_path = download_model()
if model_path:
    model = load_model(model_path)
else:
    model = None

# Input preprocessing:
# 1. Resize so the shorter side is 256.
# 2. Take a 224x224 centre crop.
# 3. Convert to a tensor.
# 4. Normalise each channel with the usual ImageNet mean/std statistics.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def _load_image(image_input, timeout=10):
    """Open the image referenced by *image_input* as an RGB PIL image.

    Args:
        image_input: An ``http://``/``https://`` URL or a local filesystem path.
        timeout: Seconds to wait on the URL fetch before giving up.

    Returns:
        A ``(image, error_json)`` pair. On success ``error_json`` is ``None``;
        on failure ``image`` is ``None`` and ``error_json`` is the JSON error
        string to hand back to the caller.
    """
    if isinstance(image_input, str) and image_input.startswith(("http://", "https://")):
        try:
            # Without a timeout a stalled server would block this call forever.
            response = requests.get(image_input, timeout=timeout)
            response.raise_for_status()  # Check for HTTP errors
            # convert("RGB") forces a full decode and normalises RGBA/grayscale
            # inputs to 3 channels, which the 3-channel Normalize requires.
            image = Image.open(BytesIO(response.content)).convert("RGB")
            print(f"Fetched image from URL: {image}")
            return image, None
        except Exception as e:
            print(f"Error fetching image from URL: {e}")
            return None, json.dumps({"error": f"Error fetching image from URL: {e}"})

    if isinstance(image_input, str) and os.path.exists(image_input):
        try:
            # Image.open keeps the file handle open lazily; convert() returns a
            # fully loaded copy, so the handle can be closed by the with-block.
            with Image.open(image_input) as opened:
                image = opened.convert("RGB")
            print(f"Loaded image from local path: {image}")
            return image, None
        except Exception as e:
            return None, json.dumps({"error": f"Error loading image from local path: {e}"})

    return None, json.dumps({"error": "Invalid image path. Ensure it's a valid URL or local path."})


def predict(data):
    """Classify a maize photo as fall armyworm (class 0) or healthy (class 1).

    Args:
        data: A non-empty list whose first element is a dict with a ``"path"``
            key holding an image URL or local file path. NOTE(review): presumably
            the file payload Gradio sends; confirm against the UI definition.

    Returns:
        A JSON string, either ``{"result": ...}`` on success or
        ``{"error": ...}`` on any failure. Errors are reported, never raised.
    """
    try:
        if not isinstance(data, list) or not data:
            return json.dumps({"error": "Input data should be a non-empty list."})

        # Guard against non-dict entries so they get a clear message rather
        # than an AttributeError from .get().
        first = data[0]
        image_input = first.get("path") if isinstance(first, dict) else None
        if not image_input:
            return json.dumps({"error": "No image path provided."})

        print(f"Received image input: {image_input}")

        image, error = _load_image(image_input)
        if error is not None:
            return error

        # `model` is None when download/loading failed at import time; say so
        # explicitly instead of failing with "'NoneType' object is not callable".
        if model is None:
            return json.dumps({"error": "Model is not loaded."})

        tensor = transform(image).unsqueeze(0)
        print(f"Transformed image tensor: {tensor.shape}")
        # Put the input on the same device as the model's weights. The model is
        # loaded onto CPU, so sending the input to CUDA would raise a
        # device-mismatch error.
        device = next(model.parameters()).device
        tensor = tensor.to(device)

        with torch.no_grad():
            outputs = model(tensor)
            predicted_class = torch.argmax(outputs, dim=1).item()
            print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")

        # Return the result based on the predicted class
        if predicted_class == 0:
            return json.dumps({"result": "The photo you've sent is of fall army worm with problem ID 126."})
        elif predicted_class == 1:
            return json.dumps({"result": "The photo you've sent is of a healthy maize image."})
        else:
            return json.dumps({"error": "Unexpected class prediction."})

    except Exception as e:
        print(f"Error processing image: {e}")
        return json.dumps({"error": f"Error processing image: {e}"})