File size: 964 Bytes
fd5e0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import numpy as np
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation

class Sky_Seg_Tool():
    """Predict a boolean sky mask for an RGB image using OneFormer.

    Runs the ``shi-labs/oneformer_ade20k_swin_large`` checkpoint in
    semantic-segmentation mode and keeps only pixels whose predicted label
    equals ``SKY_LABEL``.
    """

    # Hugging Face checkpoint used for both the processor and the model.
    _CHECKPOINT = "shi-labs/oneformer_ade20k_swin_large"

    # Semantic label id treated as "sky".
    # NOTE(review): 2 is assumed to be "sky" in this checkpoint's ADE20K
    # id2label mapping — confirm via self.model.config.id2label if the
    # checkpoint is ever changed.
    SKY_LABEL = 2

    def __init__(self, cfg):
        """Load the OneFormer processor and segmentation model.

        Args:
            cfg: Accepted for interface compatibility; not used here.
        """
        self.processor = OneFormerProcessor.from_pretrained(self._CHECKPOINT)
        self.model = OneFormerForUniversalSegmentation.from_pretrained(self._CHECKPOINT)

    @staticmethod
    def _prepare_image(img):
        """Return ``(array, (height, width))`` ready for the processor.

        *img* is converted with ``np.asarray`` so both NumPy arrays and PIL
        images are accepted. If its maximum value is below 2 it is assumed
        to be in the 0-1 range and is scaled up to 0-255.

        NOTE(review): assumes a channels-last (H, W, C) or (H, W) layout, as
        produced by PIL / typical image loaders — confirm for other callers.
        """
        arr = np.asarray(img)
        if np.amax(arr) < 2:
            arr = arr * 255
        # post_process_semantic_segmentation expects (height, width). For a
        # NumPy array that is shape[:2]; `.size` is the element count (an
        # int), so the old `img.size[::-1]` raised TypeError on arrays.
        return arr, tuple(arr.shape[:2])

    def __call__(self, img):
        """Segment *img* and return a boolean sky mask.

        Args:
            img: RGB image as a NumPy array (values in 0-1 or 0-255) or a
                PIL image.

        Returns:
            np.ndarray of bool with shape (height, width); True where the
            predicted semantic label is ``SKY_LABEL``.
        """
        arr, target_size = self._prepare_image(img)
        inputs = self.processor(images=arr, task_inputs=["semantic"], return_tensors="pt")
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Upsample the prediction back to the input resolution.
        predicted_semantic_map = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[target_size]
        )[0]
        # .cpu() keeps this working if the model is moved to a GPU.
        return predicted_semantic_map.cpu().numpy() == self.SKY_LABEL