TheArchitect416 committed on
Commit 7b8e27e · verified · 1 Parent(s): 25d01c9

Update app.py

Files changed (1)
app.py +38 -20
app.py CHANGED
@@ -1,42 +1,60 @@
  import torch
  import gradio as gr
  from torchvision import transforms
- from PIL import Image
  from huggingface_hub import hf_hub_download
+ import segmentation_models_pytorch as smp
+ import numpy as np

- # Load model
- class MySegmentationModel(torch.nn.Module):
-     def __init__(self):
-         super(MySegmentationModel, self).__init__()
-         # Define your model architecture here
-         self.dummy_layer = torch.nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1)
-
-     def forward(self, x):
-         return self.dummy_layer(x)
+ # Set the number of output classes (from your label_colors.txt, you have 4 classes)
+ NUM_CLASSES = 4

+ # Download the model state dictionary from your Hugging Face repository
  model_path = hf_hub_download(repo_id="TheArchitect416/oil-spill-segmentation-model", filename="model.pth")
- model = MySegmentationModel()
+
+ # Create the model using segmentation_models_pytorch.
+ # This should match the architecture you used during training.
+ model = smp.Unet(
+     encoder_name="resnet34",     # for example, resnet34 was used in training
+     encoder_weights="imagenet",  # or you might have used pretrained weights from ImageNet
+     in_channels=3,               # RGB images
+     classes=NUM_CLASSES          # number of segmentation classes
+ )
+ # Load the state dict (mapping the keys appropriately)
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
  model.eval()

- # Define preprocessing
- transform = transforms.Compose([
+ # Define preprocessing transforms (should match what was used during training)
+ preprocess = transforms.Compose([
      transforms.Resize((256, 256)),
-     transforms.ToTensor()
+     transforms.ToTensor(),
+     transforms.Normalize(mean=(0.485, 0.456, 0.406),  # ImageNet means
+                          std=(0.229, 0.224, 0.225))
  ])

- # Define inference function
+ # Define the inference function
  def predict(image):
-     image = transform(image).unsqueeze(0)  # Add batch dimension
+     """
+     Accepts a PIL image, preprocesses it, runs the model,
+     and returns the predicted mask.
+     """
+     # Preprocess the image
+     input_tensor = preprocess(image).unsqueeze(0)  # add batch dimension; shape: [1, 3, 256, 256]
+
      with torch.no_grad():
-         output = model(image)
-     return output.squeeze(0).numpy()  # Convert to numpy for visualization
+         output = model(input_tensor)
+
+     # The output is typically raw logits.
+     # Take argmax along the channel dimension to get the predicted class per pixel.
+     pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
+     return pred_mask

- # Create Gradio interface
+ # Create a Gradio interface
  iface = gr.Interface(
      fn=predict,
      inputs=gr.Image(type="pil"),
-     outputs="image"
+     outputs=gr.Image(type="numpy"),
+     title="Oil Spill Segmentation",
+     description="Segment oil spills in aerial images."
  )

  iface.launch()
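
Two notes for anyone adapting this Space. First, the new predict returns raw class indices (0 to NUM_CLASSES - 1) as uint8, so the gr.Image(type="numpy") output renders as a nearly black image. Below is a minimal sketch of one way to colorize the mask before returning it, reusing the preprocess and model objects defined above; the RGB palette here is hypothetical, since the real values live in label_colors.txt (not shown in this commit):

# Sketch: map each predicted class index to an RGB color so the Gradio
# output is visible. The palette values are hypothetical placeholders;
# substitute the real entries from label_colors.txt.
PALETTE = np.array([
    [0, 0, 0],        # class 0 (hypothetical: background)
    [0, 255, 255],    # class 1 (hypothetical: oil spill)
    [255, 0, 0],      # class 2 (hypothetical)
    [153, 76, 0],     # class 3 (hypothetical)
], dtype=np.uint8)

def predict(image):
    input_tensor = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)
    pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
    return PALETTE[pred_mask]  # shape [256, 256, 3]: one RGB color per pixel

Second, because model.load_state_dict(...) overwrites every parameter immediately after construction, passing encoder_weights=None to smp.Unet(...) would skip the ImageNet weight download at startup without changing the deployed model.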