# remotewith's picture
# Update app.py
# 6e32795 verified
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from unet import UNet
from torchvision import transforms
from PIL import Image
from cvzone.FaceDetectionModule import FaceDetector
import cv2
# Module-level singletons shared by every request.
detector_face = FaceDetector()
# NOTE(review): UNet(3, 1) presumably means 3 input channels (RGB) and
# 1 output channel (the glasses mask) — confirm against unet.py.
model = UNet(3, 1)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load
# on a CPU-only host; load_state_dict then copies the weights into the
# model's parameters, which were created on the CPU above.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load("specs_det.pth", map_location="cpu"))
# Inference mode (affects layers such as dropout / batch-norm).
model.eval()
def face_detect(full_image, margin=5):
    """Crop the first detected face, plus a small margin, out of an image.

    Args:
        full_image: PIL image (as delivered by the Gradio input) or an
            H x W x 3 RGB NumPy array.
        margin: Extra pixels kept around the detector's bounding box on each
            side. Defaults to 5, the value that was previously hard-coded.

    Returns:
        A PIL image of the face crop in RGB order. If no face is detected,
        or the clamped crop would be empty, the whole image is returned
        instead so the caller can still run segmentation on it.
    """
    if isinstance(full_image, Image.Image):
        # RGBA / grayscale uploads would break the channel flip below and the
        # three-channel model input downstream.
        full_image = full_image.convert("RGB")
    rgb = np.array(full_image)
    # Convert RGB to BGR, the channel order the OpenCV-based detector expects.
    bgr = rgb[:, :, ::-1].copy()

    # draw=False: with its default draw=True, cvzone's findFaces paints the
    # bounding box and score onto the image it is given, and those pixels
    # would end up inside the crop that is fed to the model.
    _, bboxs = detector_face.findFaces(bgr, draw=False)
    if not bboxs:
        return Image.fromarray(rgb)

    x, y, w, h = bboxs[0]["bbox"]
    # Clamp the start indices at 0: a negative start index would wrap around
    # to the far end of the array instead of stopping at the image border.
    # Stop indices past the edge are already truncated safely by slicing.
    top = max(y - margin, 0)
    left = max(x - margin, 0)
    face = rgb[top:y + h + margin, left:x + w + margin]
    if face.size == 0:
        return Image.fromarray(rgb)
    return Image.fromarray(np.ascontiguousarray(face))
# Pre/post-processing pipelines, built once instead of on every request.
_INPUT_TRANSFORM = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
_OUTPUT_TRANSFORM = transforms.Compose([
    transforms.Resize((256, 256)),
])


def predict(image):
    """Segment the glasses on the face found in *image*.

    Args:
        image: PIL image from the Gradio input, or None when no image has
            been provided.

    Returns:
        A 256x256 single-channel PIL image in which pixels whose sigmoid
        score exceeds 0.5 are 255 and all others are 0, or None if *image*
        is None.
    """
    if image is None:
        # Nothing uploaded: clear the output instead of crashing.
        return None

    face = face_detect(image)
    with torch.no_grad():
        # Add a batch dimension and move the input to wherever the model lives.
        batch = _INPUT_TRANSFORM(face).unsqueeze(0)
        batch = batch.to(next(model.parameters()).device)
        probs = torch.sigmoid(model(batch))
        scores = probs.squeeze().cpu().numpy()
    mask = (scores > 0.5).astype(np.uint8)
    return _OUTPUT_TRANSFORM(Image.fromarray(mask * 255))
# Create the Gradio app: one PIL image in, one segmentation mask image out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=gr.Image(label="Image with Segmentation", type="pil"),
    title="Specs Segmenter",
    description="Segment image on the basis of glasses of a person",
    examples=[
        "000368.jpg",
        "000411.jpg",
        "099899.jpg",
    ],
)

# Run the app only when executed as a script (e.g. `python app.py`), so that
# importing this module does not start a server as a side effect.
if __name__ == "__main__":
    app.launch()