# bbsnet / inference.py
# Page metadata (not code): thinh-researcher's picture · Update · c2bfa14 · raw · history blame · 1.02 kB
from typing import Dict

import numpy as np
import torch
from matplotlib import cm
from PIL import Image
from torch import Tensor
from transformers import AutoImageProcessor, AutoModel
# Identifier handed to both ``from_pretrained`` calls below.
_REPO_ID = "RGBD-SOD/bbsnet"

# Both objects are created once, at import time, and are used by ``inference``.
# ``trust_remote_code=True`` allows transformers to execute Python code shipped
# with the referenced repository, so that repository must be trusted.
model = AutoModel.from_pretrained(
    _REPO_ID,
    trust_remote_code=True,
    cache_dir="model_cache",
)
image_processor = AutoImageProcessor.from_pretrained(
    _REPO_ID,
    trust_remote_code=True,
    cache_dir="image_processor_cache",
)
def inference(rgb: Image.Image, depth: Image.Image) -> Image.Image:
    """Run the module-level model on an RGB/depth pair and colour-map the result.

    Args:
        rgb: RGB input image. Its size is also passed to the post-processing
            step.
        depth: Depth image paired with ``rgb``.

    Returns:
        A PIL image built from the post-processed prediction after it has been
        passed through matplotlib's ``gist_earth`` colormap, multiplied by 255
        and cast to ``uint8``.
    """
    preprocessed_sample: Dict[str, Tensor] = image_processor.preprocess(
        {
            "rgb": rgb,
            "depth": depth,
        }
    )

    # Inference only: disable autograd so no graph or intermediate activations
    # are kept alive for a backward pass that never happens.
    with torch.no_grad():
        output: Dict[str, Tensor] = model(
            preprocessed_sample["rgb"], preprocessed_sample["depth"]
        )

    # PIL's ``size`` is (width, height); the processor is handed [height, width].
    width, height = rgb.size
    postprocessed_sample: np.ndarray = image_processor.postprocess(
        output["logits"], [height, width]
    )

    # The colormap output is scaled by 255 and cast to uint8 before PIL wraps it.
    # NOTE(review): assumes postprocessed_sample holds floats in [0, 1] —
    # TODO confirm against the image processor.
    prediction = Image.fromarray(np.uint8(cm.gist_earth(postprocessed_sample) * 255))
    return prediction
# Script entry point: running this file directly performs no further action.
if __name__ == "__main__":
    pass