import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms

import data_transforms
from u2net import U2NET


# Load the model
def load_model():
    model = U2NET(3, 1)
    model.load_state_dict(torch.load("u2net.pth", map_location="cpu"))
    model.eval()
    return model


# Preprocessing function (same as you defined locally)
def preprocess(image):
    transform = transforms.Compose(
        [data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)]
    )
    label_3 = np.zeros(image.shape)
    label = np.zeros(label_3.shape[0:2])
    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
    return sample


# Inference function: takes an HxWx3 uint8 numpy array, returns an HxW mask in [0, 1]
def infer(model, image):
    input_size = [1024, 1024]
    im_shp = image.shape[0:2]
    im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
    # F.upsample is deprecated; F.interpolate is the current API
    im_tensor = F.interpolate(
        torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear"
    ).type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    with torch.no_grad():
        result = model(image)
    # U2NET returns seven side outputs; result[0] is the fused map of shape (1, 1, H, W),
    # which is the 4D input that bilinear interpolation expects
    result = torch.squeeze(F.interpolate(result[0], im_shp, mode="bilinear"))
    # Min-max normalize the prediction to [0, 1]
    result = (result - result.min()) / (result.max() - result.min())
    return result.numpy()
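
# Example usage: a minimal sketch of wiring the functions above together.
# "input.jpg" and "mask.png" are placeholder filenames, not part of the original script.
if __name__ == "__main__":
    model = load_model()
    img = np.array(Image.open("input.jpg").convert("RGB"))
    mask = infer(model, img)
    # Scale the [0, 1] mask to 8-bit grayscale and save it next to the input
    Image.fromarray((mask * 255).astype(np.uint8)).save("mask.png")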