import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
from io import BytesIO
from fastapi import FastAPI

# Define the number of classes
num_classes = 2

# In-memory storage for results
results_cache = {}
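# Note: this cache lives only in process memory, so stored results are lost
# whenever the process restarts and are not shared between separate processes.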

# Download model from Hugging Face
def download_model():
    print("Downloading model...")
    model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
    return model_path

# Build a ResNet-50 and load the downloaded weights
def load_model(model_path):
    print("Loading model...")
    model = models.resnet50(weights=None)  # no pretrained weights; we load our own state dict below
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.eval()
    return model

# Download the model and load it
model_path = download_model()
model = load_model(model_path)

# Preprocessing pipeline: resize, center-crop to 224x224, and normalize with the
# standard ImageNet mean/std that ResNet-50 expects
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Function to predict from image content
def predict_from_image(image):
    try:
        print(f"Processing image: {image}")
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")

        # Convert to RGB (handles grayscale/RGBA inputs) and apply the transforms
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            outputs = model(image_tensor)
            predicted_class = torch.argmax(outputs, dim=1).item()

        if predicted_class == 0:
            return {"result": "The photo shows fall armyworm (problem ID 126)."}
        elif predicted_class == 1:
            return {"result": "The photo shows a healthy maize plant."}
        else:
            return {"error": "Unexpected class prediction."}

    except Exception as e:
        print(f"Error during image processing: {e}")
        return {"error": str(e)}

# Function to predict from URL
def predict_from_url(url):
    try:
        print(f"Fetching image from URL: {url}")
        response = requests.get(url, timeout=10)  # avoid hanging on unresponsive hosts
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return predict_from_image(image)
    except Exception as e:
        print(f"Error during URL processing: {e}")
        return {"error": f"Failed to process the URL: {str(e)}"}

# Main prediction function with caching
def predict(image, url):
    try:
        print("Starting prediction...")
        
        # Check if image is provided (direct upload)
        if image is not None:
            print("Processing uploaded image...")
            result = predict_from_image(image)
        
        # Check if URL is provided
        elif url:
            print("Processing image from URL...")
            result = predict_from_url(url)
        
        else:
            result = {"error": "No input provided. Please upload an image or provide a URL."}

        # id() of the result dict serves as a unique key for as long as the dict
        # is kept alive in results_cache; these IDs are not stable across restarts.
        event_id = id(result)
        results_cache[event_id] = result

        print(f"Event ID: {event_id}, Result: {result}")
        return {"event_id": event_id, "result": result}
    except Exception as e:
        print(f"Error in prediction function: {e}")
        return {"error": str(e)}


# Function to retrieve result by event_id
def get_result(event_id):
    try:
        print(f"Retrieving result for event ID: {event_id}")
        event_id = int(event_id)
        result = results_cache.get(event_id)
        if result:
            return result
        else:
            return {"error": "No result found for the provided event ID."}
    except Exception as e:
        return {"error": f"Invalid event ID: {str(e)}"}

# Create a FastAPI app for handling the GET request
app = FastAPI()

@app.get("/result/{event_id}")
def get_result_api(event_id: int):
    return get_result(event_id)
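# Note: the route above is only reachable if this FastAPI app is actually served;
# iface.launch() below starts Gradio's own server and never exposes `app`. A
# minimal sketch for serving both from one process (assuming uvicorn is used in
# place of iface.launch()) would be:
#
#     import uvicorn
#     app = gr.mount_gradio_app(app, iface, path="/")
#     uvicorn.run(app, host="0.0.0.0", port=7860)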

# Gradio interface setup
iface = gr.Blocks()

with iface:
    gr.Markdown("# Maize Anomaly Detection")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        url_input = gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL")
    output = gr.JSON(label="Prediction Result")
    submit_button = gr.Button("Submit")
    
    submit_button.click(
        fn=predict,
        inputs=[image_input, url_input],
        outputs=output
    )
    
    # Event ID retrieval section
    with gr.Row():
        event_id_input = gr.Textbox(label="Event ID", placeholder="Enter Event ID")
        event_output = gr.JSON(label="Retrieved Result")
        retrieve_button = gr.Button("Get Result")
        
        retrieve_button.click(
            fn=get_result,
            inputs=[event_id_input],
            outputs=event_output
        )

# Launch the Gradio interface
iface.launch(share=True, show_error=True, server_name="0.0.0.0", server_port=7860)