# Spaces — status: Build error
import numpy as np
import torch
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation


class Sky_Seg_Tool():
    """Compute a boolean sky mask for an RGB image using OneFormer semantic segmentation.

    Both the processor and the model are loaded from the Hugging Face Hub
    checkpoint ``MODEL_NAME`` and run with the ``"semantic"`` task input.
    """

    # Hub checkpoint used for both the processor and the model.
    MODEL_NAME = "shi-labs/oneformer_ade20k_swin_large"
    # Predicted class id treated as sky. In the ADE20K label map this id is
    # "sky" — NOTE(review): confirm against self.model.config.id2label.
    SKY_LABEL = 2

    def __init__(self, cfg):
        """Load the OneFormer processor and model.

        Args:
            cfg: Accepted for interface compatibility; not used by this class.
        """
        self.processor = OneFormerProcessor.from_pretrained(self.MODEL_NAME)
        self.model = OneFormerForUniversalSegmentation.from_pretrained(self.MODEL_NAME)

    def __call__(self, img):
        """Return a boolean mask that is True where the predicted class is sky.

        Args:
            img: RGB image as a numpy array with values in 0-1 or 0-255. If its
                maximum is below 2 it is treated as 0-1 and scaled by 255.
                Anything ``np.asarray`` converts to such an array (e.g. a PIL
                image) is accepted as well. Assumes channels-last layout
                (H, W, C) — TODO confirm for all callers.

        Returns:
            numpy bool array of shape ``(H, W)``, i.e. ``img.shape[:2]``.
        """
        img = np.asarray(img)
        if np.amax(img) < 2:
            img = img * 255
        # post_process_semantic_segmentation wants (height, width) per image.
        # Use .shape, not .size: on a numpy array .size is the total element
        # count (an int), not a (width, height) pair.
        height, width = img.shape[:2]
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            inputs = self.processor(images=img, task_inputs=["semantic"], return_tensors="pt")
            outputs = self.model(**inputs)
        predicted_semantic_map = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(height, width)]
        )[0]
        # .cpu() so the conversion also works if the model was moved to a GPU.
        return predicted_semantic_map.cpu().numpy() == self.SKY_LABEL