jays009 commited on
Commit
95250f9
·
verified ·
1 Parent(s): fc29cbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -48
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  import torch
3
  from torch import nn
4
  from torchvision import models, transforms
5
- from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import requests
8
  import base64
@@ -12,30 +11,19 @@ import os
12
  # Define the number of classes
13
  num_classes = 2 # Update with the actual number of classes in your dataset
14
 
15
- # Download model from Hugging Face
16
- def download_model():
17
- try:
18
- model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
19
- return model_path
20
- except Exception as e:
21
- print(f"Error downloading model: {e}")
22
- return None
23
-
24
- # Load the model from Hugging Face
25
- def load_model(model_path):
26
  try:
27
  model = models.resnet50(pretrained=False)
28
  model.fc = nn.Linear(model.fc.in_features, num_classes)
29
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
30
  model.eval()
31
  return model
32
  except Exception as e:
33
  print(f"Error loading model: {e}")
34
  return None
35
 
36
- # Download the model and load it
37
- model_path = download_model()
38
- model = load_model(model_path) if model_path else None
39
 
40
  # Define the transformation for the input image
41
  transform = transforms.Compose([
@@ -45,46 +33,45 @@ transform = transforms.Compose([
45
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
46
  ])
47
 
48
- def process_image(image_input):
 
49
  try:
50
- # Process the image input (URL, local file, or base64)
51
- if isinstance(image_input, dict):
52
- # Check if the input contains a URL
53
- if image_input.get("url"):
54
- image_url = image_input["url"]
55
- response = requests.get(image_url)
56
- image = Image.open(BytesIO(response.content))
57
- # Check if the input contains a file path
58
- elif image_input.get("path"):
59
- image_path = image_input["path"]
60
- image = Image.open(image_path)
61
- # Handle base64 if it's included
62
- elif image_input.get("data"):
63
- image_data = base64.b64decode(image_input["data"])
64
  image = Image.open(BytesIO(image_data))
 
 
 
 
 
 
 
65
  else:
66
- return "Invalid input data format. Please provide a URL or path."
67
 
68
- # Apply transformations
69
- image = transform(image).unsqueeze(0)
70
- image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
71
 
72
- # Make the prediction
73
- with torch.no_grad():
74
- outputs = model(image)
75
- predicted_class = torch.argmax(outputs, dim=1).item()
76
 
77
- # Return prediction result
78
- if predicted_class == 0:
79
- return "The photo you've sent is of fall army worm with problem ID 126."
80
- elif predicted_class == 1:
81
- return "The photo you've sent is of a healthy maize image."
82
- else:
83
- return "Unexpected class prediction."
 
 
 
84
  else:
85
- return "Invalid input. Please provide a dictionary with 'url' or 'path'."
86
  except Exception as e:
87
- print(f"Error processing image: {e}")
88
  return f"Error processing image: {e}"
89
 
90
  # Create the Gradio interface
 
2
  import torch
3
  from torch import nn
4
  from torchvision import models, transforms
 
5
  from PIL import Image
6
  import requests
7
  import base64
 
11
  # Define the number of classes
12
  num_classes = 2 # Update with the actual number of classes in your dataset
13
 
14
+ # Load the model (assuming you've already downloaded it)
15
+ def load_model():
 
 
 
 
 
 
 
 
 
16
  try:
17
  model = models.resnet50(pretrained=False)
18
  model.fc = nn.Linear(model.fc.in_features, num_classes)
19
+ model.load_state_dict(torch.load("path_to_your_model.pth", map_location=torch.device("cpu")))
20
  model.eval()
21
  return model
22
  except Exception as e:
23
  print(f"Error loading model: {e}")
24
  return None
25
 
26
+ model = load_model()
 
 
27
 
28
  # Define the transformation for the input image
29
  transform = transforms.Compose([
 
33
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
34
  ])
35
 
36
+ # Prediction function
37
+ def process_image(data):
38
  try:
39
+ # Check if the input contains a base64-encoded string
40
+ if isinstance(data, dict):
41
+ if "data" in data:
42
+ # Base64 decoding
43
+ image_data = base64.b64decode(data["data"])
 
 
 
 
 
 
 
 
 
44
  image = Image.open(BytesIO(image_data))
45
+ elif "url" in data:
46
+ # URL-based image loading
47
+ response = requests.get(data["url"])
48
+ image = Image.open(BytesIO(response.content))
49
+ elif "path" in data:
50
+ # Local path image loading
51
+ image = Image.open(data["path"])
52
  else:
53
+ return "Invalid input data structure."
54
 
55
+ # Validate image
56
+ if not isinstance(image, Image.Image):
57
+ return "Invalid image format received."
58
 
59
+ # Apply transformations
60
+ image = transform(image).unsqueeze(0)
 
 
61
 
62
+ # Prediction
63
+ image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
64
+ with torch.no_grad():
65
+ outputs = model(image)
66
+ predicted_class = torch.argmax(outputs, dim=1).item()
67
+
68
+ if predicted_class == 0:
69
+ return "The photo you've sent is of fall army worm with problem ID 126."
70
+ elif predicted_class == 1:
71
+ return "The photo you've sent is of a healthy maize image."
72
  else:
73
+ return "Unexpected class prediction."
74
  except Exception as e:
 
75
  return f"Error processing image: {e}"
76
 
77
  # Create the Gradio interface