Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from unet import UNet | |
from torchvision import transforms | |
from PIL import Image | |
from cvzone.FaceDetectionModule import FaceDetector | |
import cv2 | |
# Face detector used by face_detect() to locate the face to crop.
detector_face = FaceDetector()

# UNet(3, 1): presumably (in_channels=3 for RGB, out_channels=1 for a single
# mask channel) — confirm against unet.UNet.
model = UNet(3, 1)
# The model is constructed on the CPU and never moved, so load the checkpoint
# onto the CPU as well. Without map_location, a checkpoint saved from a GPU
# fails to load on CUDA-less hosts.
model.load_state_dict(torch.load("specs_det.pth", map_location="cpu"))
# Inference mode (affects layers such as Dropout/BatchNorm, if UNet has them).
model.eval()
def face_detect(full_image, margin=10):
    """Crop the first detected face out of an image.

    The image is converted to a BGR numpy array and passed to the
    module-level ``detector_face``. The first returned bounding box is
    grown by ``margin`` pixels on each side, clamped to the image bounds,
    and cropped.

    Args:
        full_image: Input image (a PIL image from the Gradio ``type="pil"``
            input component).
        margin: Extra pixels kept around the detected box on every side.

    Returns:
        The cropped face as an RGB ``PIL.Image``. If no face is detected
        (or the clamped box is empty), the whole image is returned instead,
        so the caller can still run the model.
    """
    if isinstance(full_image, Image.Image):
        # Normalize RGBA / palette / grayscale inputs to exactly three
        # channels so the BGR flip and BGR->RGB conversion below are valid.
        full_image = full_image.convert("RGB")
    # Convert RGB to BGR (channel order used by the OpenCV/cvzone pipeline).
    open_cv_image = np.array(full_image)[:, :, ::-1].copy()

    # draw=False: by default findFaces draws the box and score onto the
    # array in place, and those overlays would end up in the model input.
    _, bboxs = detector_face.findFaces(open_cv_image, draw=False)

    crop = open_cv_image
    if bboxs:
        x, y, w, h = bboxs[0]["bbox"]
        img_h, img_w = open_cv_image.shape[:2]
        # Clamp the start indices: negative values would wrap around in numpy
        # slicing and give an empty or wrong crop near the top/left edge.
        top = max(y - margin, 0)
        bottom = min(y + h + margin, img_h)
        left = max(x - margin, 0)
        right = min(x + w + margin, img_w)
        candidate = open_cv_image[top:bottom, left:right]
        if candidate.size:
            crop = candidate

    return Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
def predict(image):
    """Segment glasses in an image and return a binary mask.

    The face is cropped with ``face_detect``, resized to 256x256, and run
    through the module-level ``model``. Pixels whose sigmoid probability is
    above 0.5 become foreground.

    Args:
        image: Input PIL image from the Gradio input component, or ``None``
            when the user submits without an image.

    Returns:
        A 256x256 grayscale ``PIL.Image`` (0 = background, 255 = predicted
        glasses) covering the cropped face region, which is not aligned to
        the full input frame. Returns ``None`` if no image was given.
    """
    if image is None:
        # Gradio passes None for an empty input; clear the output instead
        # of crashing inside face detection.
        return None

    transform_input = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # Guarantees a 256x256 mask even if the UNet changes spatial size
    # (not verifiable from this file).
    transform_output = transforms.Resize((256, 256))

    face = face_detect(image)
    device = next(model.parameters()).device
    with torch.no_grad():
        batch = transform_input(face).unsqueeze(0).to(device)
        probs = torch.sigmoid(model(batch)).squeeze().cpu().numpy()

    mask = (probs > 0.5).astype(np.uint8) * 255
    return transform_output(Image.fromarray(mask))
# Create the Gradio app: one PIL image in, the predicted glasses mask out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=gr.Image(label="Image with Segmentation", type="pil"),
    title="Kamehamehaa",
    description="Segment image on the basis of glasses of a person",
    examples=[
        'face-synthetics-glasses/test/images/000368.jpg',
        'face-synthetics-glasses/test/images/000411.jpg',
    ],
)

# Run the app only when executed as a script, so importing this module does
# not start a blocking web server as a side effect.
if __name__ == "__main__":
    app.launch()