from huggingface_hub import from_pretrained_fastai
import gradio as gr
from fastai.vision.all import *
import torchvision.transforms as transforms

# Raw mask value -> class index, applied in this order.
_MASK_VALUE_TO_CLASS = (
    (255, 1),  # grapes
    (150, 2),  # leaves
    (76, 3),   # post
    (74, 3),   # post
    (29, 4),   # wood
    (25, 4),   # wood
)

# Class index -> RGB colour used when rendering a prediction.
# Class 0 is not listed, so it stays black.
_CLASS_COLORS = (
    (1, [255, 255, 255]),  # grapes
    (2, [0, 255, 0]),      # leaves
    (3, [0, 0, 255]),      # post
    (4, [255, 0, 0]),      # wood
)

# (height, width) the input image is resized to before inference.
INPUT_SIZE = (480, 640)


class TargetMaskConvertTransform(ItemTransform):
    """Remap the raw values of a segmentation mask to class indices.

    NOTE(review): presumably this class has to be defined here so that the
    learner loaded below can resolve it -- confirm before removing.
    """

    def __init__(self):
        pass

    def encodes(self, x):
        """Return ``(img, mask)`` with the mask values replaced by class indices.

        ``x`` is an ``(image, mask)`` pair; the image is passed through
        untouched.
        """
        img, mask = x
        # Work on a NumPy copy so the values can be rewritten in place.
        mask = np.array(mask)
        for raw_value, class_idx in _MASK_VALUE_TO_CLASS:
            mask[mask == raw_value] = class_idx
        # Back to PILMask.
        return img, PILMask.create(mask)


repo_id = "ancebuc/grapes-segmentation"
learner = from_pretrained_fastai(repo_id)
labels = learner.dls.vocab

# The model is kept on the CPU, so input tensors must be created on the CPU as
# well; sending them to CUDA while the model stays on the CPU makes the forward
# pass fail with a device-mismatch error. `map_location` also lets a checkpoint
# that was saved from a GPU load on a CPU-only host.
device = torch.device("cpu")
model = torch.jit.load("unet.pth", map_location=device)
model.eval()

# Built once instead of on every request: tensor conversion followed by
# per-channel (mean, std) normalisation.
_PREPROCESS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def transform_image(image):
    """Turn an image into a normalised, batched tensor placed on ``device``."""
    return _PREPROCESS(image).unsqueeze(0).to(device)


def predict(img):
    """Segment ``img`` and return the predicted classes as a colour PIL image.

    The input is resized to ``INPUT_SIZE``, run through the TorchScript model,
    and each pixel's arg-max class is painted with its ``_CLASS_COLORS`` entry.
    """
    img = PILImage.create(img)
    image = transforms.Resize(INPUT_SIZE)(img)
    tensor = transform_image(image=image)

    with torch.no_grad():
        outputs = model(tensor)

    # Most likely class per pixel, as a 2-D (height, width) array.
    class_map = np.reshape(np.array(torch.argmax(outputs, 1).cpu()), INPUT_SIZE)

    # Replicate the class index over three channels, then recolour every known
    # class. Selection uses the untouched class map, so a colour written for
    # one class can never be mistaken for another class index.
    rgb = np.repeat(np.expand_dims(class_map, axis=2), 3, axis=2)
    for class_idx, color in _CLASS_COLORS:
        rgb[class_map == class_idx] = color

    return Image.fromarray(rgb.astype('uint8'))


# Build the interface and launch it.
# `gr.outputs.inputs` is not a Gradio namespace, so the output side has to use
# a component from `gr.outputs`. `predict` returns a PIL image, hence
# type="pil".
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(128, 128)),
    outputs=gr.outputs.Image(type="pil"),
    examples=['color_158.jpg', 'color_157.jpg'],
).launch(share=False)