import torch
import numpy as np
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation


class Sky_Seg_Tool():
    """Binary sky-mask extractor built on a pretrained OneFormer model.

    The processor and model are loaded from the
    ``shi-labs/oneformer_ade20k_swin_large`` checkpoint. Calling the
    instance runs semantic segmentation and returns a boolean mask of the
    pixels whose predicted label equals ``SKY_CLASS_ID``.
    """

    # Checkpoint used for both the processor and the model.
    MODEL_NAME = "shi-labs/oneformer_ade20k_swin_large"
    # Label id that this tool treats as "sky" in the predicted semantic map.
    SKY_CLASS_ID = 2

    def __init__(self, cfg):
        """Load the pretrained processor and model.

        Args:
            cfg: Accepted for interface compatibility; currently unused.
        """
        self.processor = OneFormerProcessor.from_pretrained(self.MODEL_NAME)
        self.model = OneFormerForUniversalSegmentation.from_pretrained(self.MODEL_NAME)

    def __call__(self, img):
        """Return a boolean sky mask for ``img``.

        Args:
            img: RGB image, either a numpy array with values in 0-1 or
                0-255 (assumed to be laid out as height x width x channels),
                or a PIL-style image exposing ``.size`` as (width, height).

        Returns:
            Boolean numpy array at the input's spatial resolution, True
            where the predicted semantic label equals ``SKY_CLASS_ID``.
        """
        # Heuristic from the original code: a maximum below 2 means the
        # image is normalised to 0-1, so bring it to the 0-255 range.
        if np.amax(img) < 2:
            img = img * 255

        # post_process_semantic_segmentation expects (height, width).
        # ndarray.size is a scalar element count, so arrays must use
        # .shape; PIL-style images report .size as (width, height).
        if isinstance(img, np.ndarray):
            target_size = tuple(img.shape[:2])
        else:
            target_size = tuple(img.size[::-1])

        inputs = self.processor(images=img, task_inputs=["semantic"], return_tensors="pt")

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Resize the prediction back to the input resolution.
        predicted_semantic_map = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[target_size]
        )[0]

        sky_msk = predicted_semantic_map.cpu().numpy() == self.SKY_CLASS_ID
        return sky_msk