# NOTE: the following lines are page residue from the web export of this file
# (status text, file size, commit hashes); kept as comments so the file parses.
# Spaces: Sleeping
# File size: 2,081 Bytes
# ce4c34e e6d39f4 ce4c34e 11680ba ce4c34e 28ec3f5 6e32795 ce4c34e
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from unet import UNet
from torchvision import transforms
from PIL import Image
from cvzone.FaceDetectionModule import FaceDetector
import cv2
# Face detector used by face_detect() to locate the face region.
detector_face = FaceDetector()

# Segmentation network whose weights are read from "specs_det.pth".
# NOTE(review): UNet(3, 1) presumably means 3 input channels and 1 output
# channel -- confirm against unet.UNet.
model = UNet(3, 1)
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on
# a CPU-only host; the model itself is constructed on the CPU, so the loaded
# weights end up in the same place either way.
model.load_state_dict(torch.load("specs_det.pth", map_location="cpu"))
# Switch to inference mode (affects layers such as dropout / batch norm).
model.eval()
def face_detect(full_image):
    """Crop the first detected face, plus a 5-pixel margin, out of an image.

    Args:
        full_image: RGB image (a PIL image, or anything ``np.array`` turns
            into an ``H x W x 3`` array).

    Returns:
        PIL.Image.Image: RGB crop around the first face reported by
        ``detector_face``.

    Raises:
        gr.Error: If no face is detected, or the reported box lies outside
            the image. Gradio shows the message to the user instead of a
            bare traceback.
    """
    margin = 5

    # Convert RGB to BGR, the channel order OpenCV-based code expects.
    bgr_image = np.array(full_image)[:, :, ::-1].copy()

    # draw=False: by default findFaces paints the box and score onto the
    # array it is given, and those overlay pixels would end up in the crop.
    _, bboxs = detector_face.findFaces(bgr_image, draw=False)
    if not bboxs:
        raise gr.Error("No face detected in the input image.")

    x, y, w, h = bboxs[0]['bbox']
    img_height, img_width = bgr_image.shape[:2]

    # Clamp to the image: a negative start index would be interpreted by
    # numpy as "count from the end" and yield an empty or wrong crop.
    top = max(y - margin, 0)
    left = max(x - margin, 0)
    bottom = min(y + h + margin, img_height)
    right = min(x + w + margin, img_width)

    face_bgr = bgr_image[top:bottom, left:right]
    if face_bgr.size == 0:
        raise gr.Error("Detected face region lies outside the input image.")

    face_rgb = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(face_rgb)
def predict(image):
    """Return a binary mask for the face found in ``image``.

    The face is cropped with ``face_detect``, resized to 256x256, run through
    ``model``, squashed with a sigmoid and thresholded at 0.5.

    Args:
        image: Input picture, forwarded unchanged to ``face_detect``.

    Returns:
        PIL.Image.Image: Mask whose pixels are 0 or 255, resized to 256x256.
    """
    to_model_input = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    to_display_size = transforms.Compose([
        transforms.Resize((256, 256)),
    ])

    face = face_detect(image)

    with torch.no_grad():
        # Add a leading batch dimension and move the tensor to whichever
        # device the model's parameters live on.
        batch = to_model_input(face).unsqueeze(0)
        batch = batch.to(next(model.parameters()).device)

        probabilities = torch.sigmoid(model(batch))

        # Threshold to {0, 1}, then scale to {0, 255} for an 8-bit image.
        mask = (probabilities.squeeze().cpu().numpy() > 0.5).astype(np.uint8)
        mask_image = Image.fromarray(mask * 255)
        mask_image = to_display_size(mask_image)

    return mask_image
# Create the Gradio app: one image in, one image out, both handed to and
# received from predict() as PIL images.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=gr.Image(label="Image with Segmentation", type="pil"),
    title="Specs Segmenter",
    description="Segment image on the basis of glasses of a person",
    # Sample inputs offered in the UI; paths are relative to the working
    # directory the app is started from.
    examples=[
        '000368.jpg',
        '000411.jpg',
        '099899.jpg',
    ],
)

# Run the app
app.launch()