import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from unet import UNet
from torchvision import transforms
from PIL import Image
from cvzone.FaceDetectionModule import FaceDetector
import cv2

# Extra pixels of context kept around the detected face box on every side.
CROP_MARGIN = 5
# Sigmoid probability above which a pixel is marked in the output mask.
MASK_THRESHOLD = 0.5

detector_face = FaceDetector()

model = UNet(3, 1)
# map_location="cpu" lets the checkpoint load on machines without CUDA even
# if it was saved from a GPU run; the model is never moved to a GPU here.
model.load_state_dict(torch.load("specs_det.pth", map_location="cpu"))
model.eval()

# Built once at import time instead of on every request.
transform_input = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
transform_output = transforms.Compose([
    transforms.Resize((256, 256)),
])


def face_detect(full_image, margin=CROP_MARGIN):
    """Crop the first face found by the cvzone detector out of ``full_image``.

    Args:
        full_image: PIL image as delivered by the Gradio input.
        margin: extra pixels kept around the detected box on each side.

    Returns:
        PIL RGB image of the cropped face region.

    Raises:
        gr.Error: if no face is detected or the box lies outside the image.
    """
    rgb = np.array(full_image.convert("RGB"))
    # The detector is fed an OpenCV-style BGR array.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    # draw=False: with the default draw=True cvzone paints the box and score
    # text onto the array in place, and those pixels would leak into the crop.
    _, bboxs = detector_face.findFaces(bgr, draw=False)
    if not bboxs:
        raise gr.Error("No face detected in the input image.")
    x, y, w, h = bboxs[0]["bbox"]
    # Clamp the top-left corner: the box can start outside the frame, and a
    # negative slice start would wrap around to the opposite edge.
    top = max(y - margin, 0)
    left = max(x - margin, 0)
    # Slicing past the bottom/right edge is clamped by numpy automatically.
    crop = rgb[top:y + h + margin, left:x + w + margin]
    if crop.size == 0:
        raise gr.Error("Detected face box lies outside the image.")
    return Image.fromarray(np.ascontiguousarray(crop))


def predict(image):
    """Run the segmentation model on the face found in ``image``.

    Args:
        image: PIL image from the Gradio input (``None`` if nothing was given).

    Returns:
        256x256 single-channel PIL mask: 255 where the model's sigmoid output
        exceeds ``MASK_THRESHOLD`` (presumably glasses), 0 elsewhere.

    Raises:
        gr.Error: if no image is given or no face is detected.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    face = face_detect(image)
    batch = transform_input(face).unsqueeze(0)
    # Send the input to whatever device the model weights live on.
    batch = batch.to(next(model.parameters()).device)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (probs > MASK_THRESHOLD).astype(np.uint8) * 255
    return transform_output(Image.fromarray(mask))


# Create the Gradio app
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=gr.Image(label="Image with Segmentation", type="pil"),
    title="Specs Segmenter",
    description="Segment image on the basis of glasses of a person",
    examples=[
        '000368.jpg',
        '000411.jpg',
        '099899.jpg',
    ],
)

# Run the app
app.launch()