jays009 commited on
Commit
163e73a
·
verified ·
1 Parent(s): 334fb08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -109
app.py CHANGED
@@ -6,24 +6,19 @@ from torchvision import models, transforms
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
 
9
  from io import BytesIO
10
- from fastapi import FastAPI
11
 
12
  # Define the number of classes
13
  num_classes = 2
14
 
15
- # In-memory storage for results
16
- results_cache = {}
17
-
18
  # Download model from Hugging Face
19
  def download_model():
20
- print("Downloading model...")
21
  model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
22
  return model_path
23
 
24
  # Load the model from Hugging Face
25
  def load_model(model_path):
26
- print("Loading model...")
27
  model = models.resnet50(pretrained=False)
28
  model.fc = nn.Linear(model.fc.in_features, num_classes)
29
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
@@ -44,116 +39,48 @@ transform = transforms.Compose([
44
 
45
  # Function to predict from image content
46
  def predict_from_image(image):
47
- try:
48
- print(f"Processing image: {image}")
49
- if not isinstance(image, Image.Image):
50
- raise ValueError("Invalid image format received. Please provide a valid image.")
51
-
52
- # Apply transformations
53
- image_tensor = transform(image).unsqueeze(0)
54
- with torch.no_grad():
55
- outputs = model(image_tensor)
56
- predicted_class = torch.argmax(outputs, dim=1).item()
57
-
58
- if predicted_class == 0:
59
- return {"result": "The photo is of fall army worm with problem ID 126."}
60
- elif predicted_class == 1:
61
- return {"result": "The photo is of a healthy maize image."}
62
- else:
63
- return {"error": "Unexpected class prediction."}
64
-
65
- except Exception as e:
66
- print(f"Error during image processing: {e}")
67
- return {"error": str(e)}
68
 
69
  # Function to predict from URL
70
  def predict_from_url(url):
71
  try:
72
- print(f"Fetching image from URL: {url}")
73
  response = requests.get(url)
74
- response.raise_for_status()
75
  image = Image.open(BytesIO(response.content))
76
  return predict_from_image(image)
77
  except Exception as e:
78
- print(f"Error during URL processing: {e}")
79
  return {"error": f"Failed to process the URL: {str(e)}"}
80
 
81
- # Main prediction function with caching
82
- def predict(image, url):
83
- try:
84
- print("Starting prediction...")
85
-
86
- # Check if image is provided (direct upload)
87
- if image:
88
- print("Processing uploaded image...")
89
- result = predict_from_image(image)
90
-
91
- # Check if URL is provided
92
- elif url:
93
- print("Processing image from URL...")
94
- result = predict_from_url(url)
95
-
96
- else:
97
- result = {"error": "No input provided. Please upload an image or provide a URL."}
98
-
99
- event_id = id(result) # Use Python's id() function to generate a unique identifier
100
- results_cache[event_id] = result
101
-
102
- print(f"Event ID: {event_id}, Result: {result}")
103
- return {"event_id": event_id, "result": result}
104
- except Exception as e:
105
- print(f"Error in prediction function: {e}")
106
- return {"error": str(e)}
107
-
108
-
109
- # Function to retrieve result by event_id
110
- def get_result(event_id):
111
- try:
112
- print(f"Retrieving result for event ID: {event_id}")
113
- event_id = int(event_id)
114
- result = results_cache.get(event_id)
115
- if result:
116
- return result
117
- else:
118
- return {"error": "No result found for the provided event ID."}
119
- except Exception as e:
120
- return {"error": f"Invalid event ID: {str(e)}"}
121
-
122
- # Create a FastAPI app for handling the GET request
123
- app = FastAPI()
124
-
125
- @app.get("/result/{event_id}")
126
- def get_result_api(event_id: int):
127
- return get_result(event_id)
128
-
129
- # Gradio interface setup
130
- iface = gr.Blocks()
131
-
132
- with iface:
133
- gr.Markdown("# Maize Anomaly Detection")
134
- with gr.Row():
135
- image_input = gr.Image(type="pil", label="Upload an Image")
136
- url_input = gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL")
137
- output = gr.JSON(label="Prediction Result")
138
- submit_button = gr.Button("Submit")
139
-
140
- submit_button.click(
141
- fn=predict,
142
- inputs=[image_input, url_input],
143
- outputs=output
144
- )
145
-
146
- # Event ID retrieval section
147
- with gr.Row():
148
- event_id_input = gr.Textbox(label="Event ID", placeholder="Enter Event ID")
149
- event_output = gr.JSON(label="Retrieved Result")
150
- retrieve_button = gr.Button("Get Result")
151
-
152
- retrieve_button.click(
153
- fn=get_result,
154
- inputs=[event_id_input],
155
- outputs=event_output
156
- )
157
-
158
- # Launch the Gradio interface
159
- iface.launch(share=True, show_error=True, server_name="0.0.0.0", server_port=7860)
 
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
9
+ import os
10
  from io import BytesIO
 
11
 
12
  # Define the number of classes
13
  num_classes = 2
14
 
 
 
 
15
  # Download model from Hugging Face
16
  def download_model():
 
17
  model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
18
  return model_path
19
 
20
  # Load the model from Hugging Face
21
  def load_model(model_path):
 
22
  model = models.resnet50(pretrained=False)
23
  model.fc = nn.Linear(model.fc.in_features, num_classes)
24
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
 
39
 
40
  # Function to predict from image content
41
  def predict_from_image(image):
42
+ # Ensure the image is a PIL Image
43
+ if not isinstance(image, Image.Image):
44
+ raise ValueError("Invalid image format received. Please provide a valid image.")
45
+
46
+ # Apply transformations
47
+ image_tensor = transform(image).unsqueeze(0)
48
+
49
+ # Predict
50
+ with torch.no_grad():
51
+ outputs = model(image_tensor)
52
+ predicted_class = torch.argmax(outputs, dim=1).item()
53
+
54
+ # Interpret the result
55
+ if predicted_class == 0:
56
+ return {"result": "The photo is of fall army worm with problem ID 126."}
57
+ elif predicted_class == 1:
58
+ return {"result": "The photo is of a healthy maize image."}
59
+ else:
60
+ return {"error": "Unexpected class prediction."}
 
 
61
 
62
  # Function to predict from URL
63
  def predict_from_url(url):
64
  try:
 
65
  response = requests.get(url)
66
+ response.raise_for_status() # Ensure the request was successful
67
  image = Image.open(BytesIO(response.content))
68
  return predict_from_image(image)
69
  except Exception as e:
 
70
  return {"error": f"Failed to process the URL: {str(e)}"}
71
 
72
+ # Gradio interface
73
+ iface = gr.Interface(
74
+ fn=lambda image, url: predict_from_image(image) if image else predict_from_url(url),
75
+ inputs=[
76
+ gr.Image(type="pil", label="Upload an Image"),
77
+ gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL"),
78
+ ],
79
+ outputs=gr.JSON(label="Prediction Result"),
80
+ live=True,
81
+ title="Maize Anomaly Detection",
82
+ description="Upload an image or provide a URL to detect anomalies in maize crops.",
83
+ )
84
+
85
+ # Launch the interface
86
+ iface.launch(share=True, show_error=True)