File size: 5,488 Bytes
2201868
 
 
 
 
 
 
 
 
dd0dfd1
2201868
 
 
 
dd0dfd1
 
 
2201868
 
b77b937
 
2201868
 
 
b77b937
 
 
 
 
2201868
 
 
b77b937
2201868
 
 
 
 
 
 
 
 
b77b937
 
5b86dff
a53a305
5b86dff
 
fc29cbf
5b86dff
 
 
 
 
5cadf06
5b86dff
a53a305
5b86dff
a53a305
5b86dff
a53a305
 
5b86dff
a53a305
 
52fd9c2
b77b937
 
 
07f42c0
b77b937
07f42c0
b77b937
 
2255b93
a53a305
 
5b86dff
20e6ace
 
a53a305
 
 
20e6ace
a53a305
20e6ace
 
 
 
07f42c0
20e6ace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a53a305
20e6ace
a53a305
9dfc63c
dd0dfd1
 
 
07f42c0
dd0dfd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a53a305
dd0dfd1
 
 
 
 
 
 
 
 
 
 
20e6ace
 
dd0dfd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import json
import os
from io import BytesIO

import gradio as gr
import requests
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import models, transforms

# Width of the classifier head on the ResNet-50 (one logit per label).
num_classes: int = 2

# Process-wide store of prediction results keyed by event ID.
# NOTE(review): entries are never evicted, so this grows for the process lifetime.
results_cache: dict = dict()

# Download model from Hugging Face
def download_model():
    model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
    return model_path

# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        model_path: Path to a saved state dict, e.g. the file returned by
            ``download_model``.

    Returns:
        The network with the loaded weights, switched to evaluation mode.
    """
    net = models.resnet50(pretrained=False)
    # Replace the default classifier with one sized to our label set so the
    # parameter shapes match the fine-tuned checkpoint's state dict.
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, num_classes)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted repo (or pass weights_only=True on torch versions that support it).
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()
    return net

# Fetch the checkpoint once at import time and build the inference model.
model_path = download_model()
model = load_model(model_path)

# Preprocessing for every input image: resize the shorter side to 256, take the
# central 224x224 crop, convert to a float tensor, then normalize each channel
# with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Function to predict from image content
def predict_from_image(image):
    """Classify a PIL image as fall armyworm (class 0) or healthy maize (class 1).

    Args:
        image: A ``PIL.Image.Image`` in any mode. Non-RGB images are converted
            to RGB before preprocessing.

    Returns:
        ``{"result": <message>}`` on success, or ``{"error": <message>}`` when
        the input is not a PIL image, preprocessing or inference raises, or the
        model predicts an unexpected class index.
    """
    try:
        print(f"Processing image: {image}")
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")

        # The 3-value Normalize and the ResNet input layer need exactly three
        # channels; RGBA PNGs, grayscale ("L") and palette ("P") images would
        # otherwise fail during preprocessing or inference.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Apply transformations and add a batch dimension: (C, H, W) -> (1, C, H, W).
        image_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            outputs = model(image_tensor)
            predicted_class = torch.argmax(outputs, dim=1).item()

        if predicted_class == 0:
            return {"result": "The photo is of fall army worm with problem ID 126."}
        elif predicted_class == 1:
            return {"result": "The photo is of a healthy maize image."}
        else:
            return {"error": "Unexpected class prediction."}

    except Exception as e:
        print(f"Error during image processing: {e}")
        return {"error": str(e)}

# Function to predict from URL
def predict_from_url(url, timeout=10):
    """Download an image from *url* and classify it.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server to connect/respond, passed to
            ``requests.get``. Without a timeout, ``requests`` can block
            indefinitely on an unresponsive host.

    Returns:
        The dict produced by ``predict_from_image``, or
        ``{"error": "Failed to process the URL: ..."}`` if the download or
        image decoding fails.
    """
    try:
        print(f"Fetching image from URL: {url}")
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as image:
            return predict_from_image(image)
    except Exception as e:
        print(f"Error during URL processing: {e}")
        return {"error": f"Failed to process the URL: {str(e)}"}

# Main prediction function without event ID for direct uploads
def predict_direct_upload(image):
    """Classify an image uploaded through the UI, without issuing an event ID.

    Args:
        image: The uploaded image, or a falsy value (e.g. ``None``) when
            nothing was uploaded.

    Returns:
        The prediction dict from ``predict_from_image``, or an
        ``{"error": ...}`` dict when no image was given or an unexpected
        exception occurs.
    """
    try:
        if not image:
            return {"error": "No image provided. Please upload an image."}
        return predict_from_image(image)
    except Exception as err:
        print(f"Error in direct upload prediction function: {err}")
        return {"error": str(err)}

# Main prediction function with caching for paths via Postman
def predict_with_event_id(data):
    """Classify an image referenced by path and cache the result under an event ID.

    Args:
        data: A sequence whose first element is a mapping with a ``"path"``
            key holding an ``http(s)://`` URL or a local file path.

    Returns:
        ``{"event_id": <int>, "result": <prediction dict>}`` when a path is
        supplied (the result may itself be an ``{"error": ...}`` dict), or
        ``{"error": ...}`` when no path is given or the payload is malformed.
    """
    try:
        path = data[0].get('path', None)
        if not path:
            return {"error": "No path provided. Please provide a valid path."}

        if path.startswith(("http://", "https://")):
            result = predict_from_url(path)
        elif os.path.isfile(path):
            # NOTE(review): callers can point this at any image file the server
            # process can read; restrict to an allowed directory if exposed.
            # The context manager releases the file handle once
            # predict_from_image has finished with the image.
            with Image.open(path) as image:
                result = predict_from_image(image)
        else:
            result = {"error": "Invalid path format. Please provide a valid URL or local file path."}

        # id() is unique among live objects, and the cache keeps every stored
        # result alive, so a new key cannot collide with an existing entry.
        event_id = id(result)
        results_cache[event_id] = result

        print(f"Event ID: {event_id}, Result: {result}")
        return {"event_id": event_id, "result": result}
    except Exception as e:
        print(f"Error in prediction function with event ID: {e}")
        return {"error": str(e)}

# Function to retrieve result by event_id
def get_result(event_id):
    """Look up a cached prediction by its event ID.

    Args:
        event_id: An int, or an int-parsable string such as the value of the
            "Event ID" textbox.

    Returns:
        The cached result dict, or an ``{"error": ...}`` dict when nothing is
        cached under that ID or the ID cannot be parsed as an integer.
    """
    try:
        print(f"Retrieving result for event ID: {event_id}")
        key = int(event_id)
        cached = results_cache.get(key)
        if not cached:
            return {"error": "No result found for the provided event ID."}
        return cached
    except Exception as exc:
        return {"error": f"Invalid event ID: {str(exc)}"}

# FastAPI app exposing a JSON endpoint that returns cached results by event ID.
# NOTE(review): nothing in this file mounts or serves `app` (iface.launch starts
# Gradio's own server), so this route looks unreachable as written. Confirm the
# deployment, or mount Gradio onto it with gr.mount_gradio_app and serve `app`.
app = FastAPI()


@app.get("/result/{event_id}")
def get_result_api(event_id: int):
    """Return the cached result for *event_id* (path parameter parsed as int)."""
    result = get_result(event_id)
    return result

# Gradio interface setup
def predict_upload_or_url(image, url):
    """Handle the Submit button: prefer the uploaded image, else the URL textbox.

    Args:
        image: The uploaded image, or a falsy value when nothing was uploaded.
        url: Text of the URL textbox (may be empty or ``None``).

    Returns:
        The result of ``predict_from_url`` when only a non-blank URL was given.
        Otherwise returns the result of ``predict_direct_upload``, including its
        "No image provided" error when neither input was supplied.
    """
    if not image and url and url.strip():
        return predict_from_url(url.strip())
    return predict_direct_upload(image)


iface = gr.Blocks()

with iface:
    gr.Markdown("# Maize Anomaly Detection")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        url_input = gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL")
    output = gr.JSON(label="Prediction Result")
    submit_button = gr.Button("Submit")

    # Pass both inputs so an image URL is classified when nothing is uploaded.
    submit_button.click(
        fn=predict_upload_or_url,
        inputs=[image_input, url_input],
        outputs=output,
    )

    # Event ID retrieval section
    # NOTE(review): event IDs are only created by predict_with_event_id, which no
    # event in this UI calls. Confirm how clients are expected to obtain them.
    with gr.Row():
        event_id_input = gr.Textbox(label="Event ID", placeholder="Enter Event ID")
        event_output = gr.JSON(label="Retrieved Result")
        retrieve_button = gr.Button("Get Result")

    retrieve_button.click(
        fn=get_result,
        inputs=[event_id_input],
        outputs=event_output,
    )

# Start the Gradio server on all interfaces, port 7860.
# share=True also requests a public share link; show_error=True surfaces
# handler exceptions in the browser UI.
iface.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True,
    show_error=True,
)