File size: 1,605 Bytes
9dfb0a4
 
 
 
8264471
 
 
 
9dfb0a4
 
 
8264471
 
9dfb0a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os

import gradio as gr
import numpy as np
import torch
from huggingface_hub import login
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation

# Optional Hugging Face access token, read from the HF_TOKEN environment
# variable. It is also passed to from_pretrained below.
access_token = os.getenv('HF_TOKEN')
if access_token:
    # login(token=None) falls back to prompting for a token, which cannot be
    # answered in a non-interactive deployment, so only log in when one is set.
    login(token=access_token)

# Load the model and its matching image processor from Hugging Face
model_name = "gdurkin/cdl_mask2former_hi_res_v3"
processor = Mask2FormerImageProcessor.from_pretrained(model_name, token=access_token)
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name, token=access_token)
device = torch.device('cpu')
# Keep the model on the same device that inference inputs are moved to.
model.to(device)

# Define the inference function
def predict(img):
    """Run semantic segmentation on one image or a batch of images.

    Inputs are preprocessed with the module-level ``processor`` (the
    ``Mask2FormerImageProcessor`` loaded from the same checkpoint as
    ``model``). That processor applies the checkpoint's configured
    rescaling, normalisation, resizing and padding, instead of feeding raw
    pixel values straight into the network.

    NOTE(review): this assumes the checkpoint was trained on inputs prepared
    by its own image processor (the standard Transformers workflow). Confirm
    this against the training pipeline.

    Args:
        img: Channels-last image data as a ``numpy.ndarray`` or
            ``torch.Tensor``, either a single image ``(H, W, C)`` or a batch
            ``(N, H, W, C)``. ``None`` is also accepted; Gradio sends it when
            the input image is cleared.

    Returns:
        ``None`` if ``img`` is ``None``. Otherwise a ``numpy.ndarray`` of
        per-pixel class ids at the input resolution: shape ``(H, W)`` for a
        single image or a batch of one, and ``(N, H, W)`` for a batch of
        N > 1.

    Raises:
        ValueError: If ``img`` is not an ndarray/tensor, is not 3D or 4D,
            or is an empty batch.
    """
    if img is None:
        # With live=True, clearing the input triggers a call with None;
        # return an empty result rather than erroring in the UI.
        return None
    if torch.is_tensor(img):
        img = img.detach().cpu().numpy()
    if not isinstance(img, np.ndarray):
        raise ValueError("Unsupported image format")

    if img.ndim == 3:
        images = [img]
    elif img.ndim == 4:
        images = list(img)
    else:
        raise ValueError("Input tensor must be 3D or 4D")
    if not images:
        raise ValueError("Input batch is empty")

    # One (H, W) per image so post-processing can upsample each mask back
    # to its original resolution (the processor may resize/pad the input).
    target_sizes = [image.shape[:2] for image in images]

    inputs = processor(images=images, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)

    segmentation_maps = processor.post_process_semantic_segmentation(
        outputs, target_sizes=target_sizes
    )
    segmentation_maps = [seg.cpu().numpy() for seg in segmentation_maps]
    # A single image (or batch of one) keeps the original (H, W) return shape.
    if len(segmentation_maps) == 1:
        return segmentation_maps[0]
    return np.stack(segmentation_maps)

# Create a Gradio interface. With live=True the prediction re-runs
# automatically whenever the input image changes (no submit button).
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", image_mode='RGB'),
    # NOTE(review): predict returns a raw 2D class-id array; "numpy" is a
    # Gradio component shortcut — confirm it renders as intended.
    outputs="numpy",
    live=True,
)

# Launch the interface only when executed as a script (e.g. `python app.py`),
# not when this module is imported.
if __name__ == "__main__":
    iface.launch()