# VistaDream/ops/sky.py
# Uploaded by hpwang -- commit fd5e0f7 ("[Init]"), 964 bytes.
import torch
import numpy as np
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
class Sky_Seg_Tool:
    """Sky segmentation built on a pretrained OneFormer semantic model.

    Calling an instance on an RGB image returns a boolean mask that is True
    wherever the predicted semantic label equals ``SKY_CLASS_ID``.
    """

    # Hugging Face checkpoint used for both the processor and the model.
    MODEL_NAME = "shi-labs/oneformer_ade20k_swin_large"
    # Label index compared against the prediction to build the mask.
    # NOTE(review): 2 is expected to be "sky" in this checkpoint's ADE20K
    # label set -- confirm against ``model.config.id2label``.
    SKY_CLASS_ID = 2

    def __init__(self, cfg):
        """Load the OneFormer processor and segmentation model.

        Args:
            cfg: Accepted for interface compatibility; not read here.
        """
        self.processor = OneFormerProcessor.from_pretrained(self.MODEL_NAME)
        self.model = OneFormerForUniversalSegmentation.from_pretrained(self.MODEL_NAME)

    def __call__(self, img):
        """Return a boolean sky mask for ``img``.

        Args:
            img: RGB image, either a numpy array laid out as
                (height, width, channels) with values in 0-1 or 0-255, or a
                PIL-style image exposing ``.size`` as (width, height).

        Returns:
            A boolean numpy array with the image's (height, width), True
            where the predicted label equals ``SKY_CLASS_ID``.
        """
        # Heuristic range detection: a maximum below 2 is treated as a 0-1
        # image and rescaled to 0-255. Going through np.asarray ensures the
        # multiplication is numeric even for non-array inputs.
        if np.amax(img) < 2:
            img = np.asarray(img) * 255

        # post_process_semantic_segmentation expects (height, width).
        # ndarray.size is the total element count (an int), so arrays must
        # use .shape; PIL-style images report .size as (width, height).
        if isinstance(img, np.ndarray):
            height, width = img.shape[:2]
        else:
            width, height = img.size

        inputs = self.processor(images=img, task_inputs=["semantic"], return_tensors="pt")
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Let the processor resize the logits back to the input resolution
        # and reduce them to a per-pixel label map.
        predicted_semantic_map = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(height, width)]
        )[0]
        sky_msk = predicted_semantic_map.cpu().numpy() == self.SKY_CLASS_ID
        return sky_msk