jays009 commited on
Commit
aae3560
·
verified ·
1 Parent(s): be90e66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -28
app.py CHANGED
@@ -36,43 +36,28 @@ transform = transforms.Compose([
36
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize the image (ImageNet mean and std)
37
  ])
38
 
39
- # Function to convert image from URL to PIL image
40
- def url_to_image(image_url):
41
- response = requests.get(image_url)
42
- img = Image.open(BytesIO(response.content))
43
- return img
44
 
45
- # Function to convert base64 string to PIL image
46
- def base64_to_pil(base64_string):
47
- img_data = base64.b64decode(base64_string)
48
- return Image.open(BytesIO(img_data))
 
 
49
 
50
- # Define the prediction function
51
- def predict(image_input):
52
- # If input is a string (URL or base64 encoded), handle accordingly
53
- if isinstance(image_input, str):
54
- if image_input.startswith("http"): # If URL
55
- image = url_to_image(image_input)
56
- elif image_input.startswith("data:image"): # If base64 string
57
- image = base64_to_pil(image_input)
58
- else: # Local image path
59
- image = Image.open(image_input)
60
- else:
61
- image = image_input # If the input is already a PIL image
62
 
63
- # Apply the necessary transformations to the image
64
- image = transform(image).unsqueeze(0) # Add batch dimension
65
- image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Move to GPU if available
66
-
67
  with torch.no_grad():
68
- outputs = model(image) # Perform forward pass
69
- predicted_class = torch.argmax(outputs, dim=1).item() # Get the predicted class
70
 
71
  # Create a response based on the predicted class
72
  if predicted_class == 0:
73
  return "The photo you've sent is of fall army worm with problem ID 126."
74
  elif predicted_class == 1:
75
- return "The photo you've sent is of a healthy wheat image."
76
  else:
77
  return "Unexpected class prediction."
78
 
 
36
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # Normalize the image (ImageNet mean and std)
37
  ])
38
 
 
 
 
 
 
39
 
40
+ def predict(image):
41
+ # Check if the input contains a base64-encoded string
42
+ if isinstance(image, dict) and image.get("data"):
43
+ # Decode the base64 string into a PIL image
44
+ image_data = base64.b64decode(image["data"])
45
+ image = Image.open(BytesIO(image_data))
46
 
47
+ # Apply your existing transformations
48
+ image = transform(image).unsqueeze(0) # Transform and add batch dimension
49
+ image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
 
 
 
 
 
 
 
 
50
 
51
+ # Perform inference
 
 
 
52
  with torch.no_grad():
53
+ outputs = model(image)
54
+ predicted_class = torch.argmax(outputs, dim=1).item()
55
 
56
  # Create a response based on the predicted class
57
  if predicted_class == 0:
58
  return "The photo you've sent is of fall army worm with problem ID 126."
59
  elif predicted_class == 1:
60
+ return "The photo you've sent is of a healthy maize image."
61
  else:
62
  return "Unexpected class prediction."
63