jays009 commited on
Commit
07f42c0
·
verified ·
1 Parent(s): dd0dfd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -16
app.py CHANGED
@@ -8,7 +8,6 @@ from PIL import Image
8
  import requests
9
  from io import BytesIO
10
  from fastapi import FastAPI
11
- from gradio.routes import App
12
 
13
  # Define the number of classes
14
  num_classes = 2
@@ -18,11 +17,13 @@ results_cache = {}
18
 
19
  # Download model from Hugging Face
20
  def download_model():
 
21
  model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
22
  return model_path
23
 
24
  # Load the model from Hugging Face
25
  def load_model(model_path):
 
26
  model = models.resnet50(pretrained=False)
27
  model.fc = nn.Linear(model.fc.in_features, num_classes)
28
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
@@ -44,22 +45,16 @@ transform = transforms.Compose([
44
  # Function to predict from image content
45
  def predict_from_image(image):
46
  try:
47
- # Log the image processing
48
  print(f"Processing image: {image}")
49
-
50
- # Ensure the image is a PIL Image
51
  if not isinstance(image, Image.Image):
52
  raise ValueError("Invalid image format received. Please provide a valid image.")
53
 
54
  # Apply transformations
55
  image_tensor = transform(image).unsqueeze(0)
56
-
57
- # Predict
58
  with torch.no_grad():
59
  outputs = model(image_tensor)
60
  predicted_class = torch.argmax(outputs, dim=1).item()
61
 
62
- # Interpret the result
63
  if predicted_class == 0:
64
  return {"result": "The photo is of fall army worm with problem ID 126."}
65
  elif predicted_class == 1:
@@ -74,11 +69,10 @@ def predict_from_image(image):
74
  # Function to predict from URL
75
  def predict_from_url(url):
76
  try:
77
- # Fetch the image from the URL
78
  response = requests.get(url)
79
- response.raise_for_status() # Ensure the request was successful
80
  image = Image.open(BytesIO(response.content))
81
- print(f"Fetched image from URL: {url}")
82
  return predict_from_image(image)
83
  except Exception as e:
84
  print(f"Error during URL processing: {e}")
@@ -87,21 +81,19 @@ def predict_from_url(url):
87
  # Main prediction function with caching
88
  def predict(image, url):
89
  try:
 
90
  if image:
91
  result = predict_from_image(image)
92
  elif url:
93
  result = predict_from_url(url)
94
  else:
95
  result = {"error": "No input provided. Please upload an image or provide a URL."}
96
-
97
- # Generate and store the event ID
98
  event_id = id(result) # Use Python's id() function to generate a unique identifier
99
  results_cache[event_id] = result
100
-
101
- # Log the result
102
  print(f"Event ID: {event_id}, Result: {result}")
103
  return {"event_id": event_id, "result": result}
104
-
105
  except Exception as e:
106
  print(f"Error in prediction function: {e}")
107
  return {"error": str(e)}
@@ -109,7 +101,7 @@ def predict(image, url):
109
  # Function to retrieve result by event_id
110
  def get_result(event_id):
111
  try:
112
- # Convert event_id from string to int
113
  event_id = int(event_id)
114
  result = results_cache.get(event_id)
115
  if result:
 
8
  import requests
9
  from io import BytesIO
10
  from fastapi import FastAPI
 
11
 
12
  # Define the number of classes
13
  num_classes = 2
 
17
 
18
  # Download model from Hugging Face
19
  def download_model():
20
+ print("Downloading model...")
21
  model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
22
  return model_path
23
 
24
  # Load the model from Hugging Face
25
  def load_model(model_path):
26
+ print("Loading model...")
27
  model = models.resnet50(pretrained=False)
28
  model.fc = nn.Linear(model.fc.in_features, num_classes)
29
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
 
45
  # Function to predict from image content
46
  def predict_from_image(image):
47
  try:
 
48
  print(f"Processing image: {image}")
 
 
49
  if not isinstance(image, Image.Image):
50
  raise ValueError("Invalid image format received. Please provide a valid image.")
51
 
52
  # Apply transformations
53
  image_tensor = transform(image).unsqueeze(0)
 
 
54
  with torch.no_grad():
55
  outputs = model(image_tensor)
56
  predicted_class = torch.argmax(outputs, dim=1).item()
57
 
 
58
  if predicted_class == 0:
59
  return {"result": "The photo is of fall army worm with problem ID 126."}
60
  elif predicted_class == 1:
 
69
  # Function to predict from URL
70
  def predict_from_url(url):
71
  try:
72
+ print(f"Fetching image from URL: {url}")
73
  response = requests.get(url)
74
+ response.raise_for_status()
75
  image = Image.open(BytesIO(response.content))
 
76
  return predict_from_image(image)
77
  except Exception as e:
78
  print(f"Error during URL processing: {e}")
 
81
  # Main prediction function with caching
82
  def predict(image, url):
83
  try:
84
+ print("Starting prediction...")
85
  if image:
86
  result = predict_from_image(image)
87
  elif url:
88
  result = predict_from_url(url)
89
  else:
90
  result = {"error": "No input provided. Please upload an image or provide a URL."}
91
+
 
92
  event_id = id(result) # Use Python's id() function to generate a unique identifier
93
  results_cache[event_id] = result
94
+
 
95
  print(f"Event ID: {event_id}, Result: {result}")
96
  return {"event_id": event_id, "result": result}
 
97
  except Exception as e:
98
  print(f"Error in prediction function: {e}")
99
  return {"error": str(e)}
 
101
  # Function to retrieve result by event_id
102
  def get_result(event_id):
103
  try:
104
+ print(f"Retrieving result for event ID: {event_id}")
105
  event_id = int(event_id)
106
  result = results_cache.get(event_id)
107
  if result: