TheArchitect416 commited on
Commit
bdb9d5d
·
verified ·
1 Parent(s): c4967a5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ # Load model
8
+ class MySegmentationModel(torch.nn.Module):
9
+ def __init__(self):
10
+ super(MySegmentationModel, self).__init__()
11
+ # Define your model architecture here
12
+ self.dummy_layer = torch.nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1)
13
+
14
+ def forward(self, x):
15
+ return self.dummy_layer(x)
16
+
17
+ model_path = hf_hub_download(repo_id="TheArchitect416/oil-spill-segmentation-model", filename="model.pth")
18
+ model = MySegmentationModel()
19
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
20
+ model.eval()
21
+
22
+ # Define preprocessing
23
+ transform = transforms.Compose([
24
+ transforms.Resize((256, 256)),
25
+ transforms.ToTensor()
26
+ ])
27
+
28
+ # Define inference function
29
+ def predict(image):
30
+ image = transform(image).unsqueeze(0) # Add batch dimension
31
+ with torch.no_grad():
32
+ output = model(image)
33
+ return output.squeeze(0).numpy() # Convert to numpy for visualization
34
+
35
+ # Create Gradio interface
36
+ iface = gr.Interface(
37
+ fn=predict,
38
+ inputs=gr.Image(type="pil"),
39
+ outputs="image"
40
+ )
41
+
42
+ iface.launch()