thinh-researcher commited on
Commit
c2bfa14
·
1 Parent(s): 6d933f3
Files changed (5) hide show
  1. .gitignore +2 -1
  2. app.py +2 -17
  3. inference.py +36 -0
  4. prepare_samples.py +5 -3
  5. requirements.txt +5 -0
.gitignore CHANGED
@@ -2,4 +2,5 @@ env
2
  __pycache__
3
  data
4
  samples/*
5
- !samples/.gitkeep
 
 
2
  __pycache__
3
  data
4
  samples/*
5
+ !samples/.gitkeep
6
+ model_cache
app.py CHANGED
@@ -1,22 +1,7 @@
1
- from typing import Tuple
2
  import gradio as gr
3
- import os
4
- from PIL import Image
5
- from datasets import load_dataset
6
- from prepare_samples import prepare_samples
7
-
8
- DIR_PATH = os.path.dirname(__file__)
9
-
10
-
11
- def inference(rgb: Image.Image, depth: Image.Image) -> Tuple[Image.Image]:
12
- return rgb
13
 
14
-
15
- dataset = load_dataset("RGBD-SOD/test", "v1", split="train", cache_dir="data")
16
-
17
- # with gr.Blocks() as demo:
18
- # with gr.Row(elem_id="center"):
19
- # gr.Markdown("# BBS-Net Demo")
20
 
21
  TITLE = "BBS-Net Demo"
22
  DESCRIPTION = "Gradio demo for BBS-Net: RGB-D salient object detection with a bifurcated backbone strategy network."
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
2
 
3
+ from inference import inference
4
+ from prepare_samples import prepare_samples
 
 
 
 
5
 
6
  TITLE = "BBS-Net Demo"
7
  DESCRIPTION = "Gradio demo for BBS-Net: RGB-D salient object detection with a bifurcated backbone strategy network."
inference.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoImageProcessor, AutoModel
2
+ from typing import Dict
3
+
4
+ import numpy as np
5
+ from matplotlib import cm
6
+ from PIL import Image
7
+ from torch import Tensor
8
+
9
+ model = AutoModel.from_pretrained(
10
+ "RGBD-SOD/bbsnet", trust_remote_code=True, cache_dir="model_cache"
11
+ )
12
+ image_processor = AutoImageProcessor.from_pretrained(
13
+ "RGBD-SOD/bbsnet", trust_remote_code=True, cache_dir="image_processor_cache"
14
+ )
15
+
16
+
17
+ def inference(rgb: Image.Image, depth: Image.Image) -> Image.Image:
18
+ preprocessed_sample: Dict[str, Tensor] = image_processor.preprocess(
19
+ {
20
+ "rgb": rgb,
21
+ "depth": depth,
22
+ }
23
+ )
24
+
25
+ output: Dict[str, Tensor] = model(
26
+ preprocessed_sample["rgb"], preprocessed_sample["depth"]
27
+ )
28
+ postprocessed_sample: np.ndarray = image_processor.postprocess(
29
+ output["logits"], [rgb.size[1], rgb.size[0]]
30
+ )
31
+ prediction = Image.fromarray(np.uint8(cm.gist_earth(postprocessed_sample) * 255))
32
+ return prediction
33
+
34
+
35
+ if __name__ == "__main__":
36
+ pass
prepare_samples.py CHANGED
@@ -1,8 +1,10 @@
1
- from typing import List, Tuple
2
- from datasets import load_dataset
3
- from PIL import Image
4
  import os
5
  import shutil
 
 
 
 
 
6
 
7
  dataset = load_dataset("RGBD-SOD/test", "v1", split="train", cache_dir="data")
8
  SAMPLES_DIR = "samples"
 
 
 
 
1
  import os
2
  import shutil
3
+ from typing import List, Tuple
4
+
5
+ from PIL import Image
6
+ from datasets import load_dataset
7
+
8
 
9
  dataset = load_dataset("RGBD-SOD/test", "v1", split="train", cache_dir="data")
10
  SAMPLES_DIR = "samples"
requirements.txt CHANGED
@@ -1,2 +1,7 @@
1
  gradio
 
 
 
 
2
  datasets
 
 
1
  gradio
2
+ torch
3
+ opencv-python
4
+ transformers[torch]
5
+ torchvision
6
  datasets
7
+ matplotlib