bbsnet / inference.py
thinh-researcher's picture
Fixed
6e2081d
raw
history blame
1.09 kB
from typing import Dict

import numpy as np
import torch
from matplotlib import cm
from PIL import Image
from torch import Tensor
from transformers import AutoImageProcessor, AutoModel
# The same Hugging Face Hub repository supplies both the model and its image
# processor. `trust_remote_code=True` allows custom code shipped in that
# repository to be executed when loading.
_REPO_ID = "RGBD-SOD/bbsnet"

# Loaded once at import time and shared by every `inference` call; downloaded
# files are cached in separate directories (relative to the working directory).
model = AutoModel.from_pretrained(
    _REPO_ID,
    trust_remote_code=True,
    cache_dir="model_cache",
)
image_processor = AutoImageProcessor.from_pretrained(
    _REPO_ID,
    trust_remote_code=True,
    cache_dir="image_processor_cache",
)
def inference(rgb: Image.Image, depth: Image.Image) -> Image.Image:
    """Run BBSNet on an RGB image and its depth map and return a colorized prediction.

    Args:
        rgb: Color input image; converted to 3-channel "RGB" mode before use.
        depth: Depth input image; converted to single-channel "L" mode before use.

    Returns:
        The post-processed model output, rendered through matplotlib's
        ``gist_earth`` colormap and converted to an 8-bit PIL image.
    """
    rgb_image = rgb.convert(mode="RGB")
    depth_image = depth.convert(mode="L")

    preprocessed_sample: Dict[str, Tensor] = image_processor.preprocess(
        {
            "rgb": rgb_image,
            "depth": depth_image,
        }
    )

    # Pure inference: disable autograd so the forward pass does not build a
    # computation graph that is never used (saves memory and time per call).
    with torch.no_grad():
        output: Dict[str, Tensor] = model(
            preprocessed_sample["rgb"], preprocessed_sample["depth"]
        )

    # PIL's `size` is (width, height); the processor is handed [height, width].
    # NOTE(review): presumably this is the target output size — confirm against
    # the remote image processor's `postprocess`.
    postprocessed_sample: np.ndarray = image_processor.postprocess(
        output["logits"], [rgb_image.size[1], rgb_image.size[0]]
    )

    # The colormap returns RGBA floats in [0, 1]; scale to 0..255 and truncate
    # to uint8 for PIL. Assumes `postprocessed_sample` is a 2-D map — TODO confirm.
    colored = cm.gist_earth(postprocessed_sample)
    prediction = Image.fromarray(np.uint8(colored * 255))
    return prediction
if __name__ == "__main__":
    # Running this file directly does nothing; import `inference` instead.
    ...