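"""Gradio demo: crop a detected face and segment the eyeglasses with a UNet.

Pipeline: cvzone face detection -> padded crop -> 256x256 UNet -> binary mask.
Assumes a local `unet.py` defining UNet(in_channels, out_channels) and trained
weights saved as `specs_det.pth`.
"""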
import gradio as gr
import torch
import numpy as np
from unet import UNet
from torchvision import transforms
from PIL import Image
from cvzone.FaceDetectionModule import FaceDetector
import cv2

# cvzone's FaceDetector (MediaPipe under the hood) returns face bounding boxes
detector_face = FaceDetector()

# Load the trained segmentation weights; map_location lets the app start on
# CPU-only machines even if the checkpoint was saved from a GPU run.
model = UNet(3, 1)  # 3 input channels (RGB), 1 output channel (mask logits)
model.load_state_dict(torch.load("specs_det.pth", map_location="cpu"))
model.eval()

def face_detect(full_image):
    """Detect a face in a PIL image and return a padded crop around it."""
    open_cv_image = np.array(full_image)
    # Convert RGB (PIL) to BGR (OpenCV)
    open_cv_image = open_cv_image[:, :, ::-1].copy()

    _, bboxs = detector_face.findFaces(open_cv_image)

    # If no face is found, fall back to segmenting the full image
    if not bboxs:
        return full_image

    x, y, w, h = bboxs[0]['bbox']

    # Pad the box by 10 px, clamping at 0 so negative indices don't wrap around
    y0, x0 = max(y - 10, 0), max(x - 10, 0)
    cropped_image = open_cv_image[y0:y + h + 10, x0:x + w + 10]

    # Back to RGB for PIL / torchvision
    img = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(img)




def predict(image):
    # Preprocess: resize to the 256x256 input size the UNet expects
    transform_input = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    # Crop to the detected face before segmenting
    image = face_detect(image)

    with torch.no_grad():
        image = transform_input(image).unsqueeze(0)  # add batch dimension
        image = image.to(next(model.parameters()).device)

        output = model(image)

        # Sigmoid -> per-pixel probabilities, then threshold to a binary mask
        output = torch.sigmoid(output).squeeze().cpu().numpy()
        output = (output > 0.5).astype(np.uint8)

        # Scale {0, 1} to {0, 255} so the mask renders as a visible image
        output = Image.fromarray(output * 255)

    return output
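
# Optional sketch (not wired into the interface): one way to overlay the
# predicted mask on the face crop for a richer visualization. The function
# name, the red highlight colour, and `alpha` are illustrative choices, not
# part of the original app.
def overlay_mask(face_img, mask_img, alpha=0.5):
    # Match the crop to the mask resolution before blending
    face_img = face_img.convert("RGB").resize(mask_img.size)
    face = np.array(face_img).astype(np.float32)
    mask = np.array(mask_img.convert("L")) > 127
    # Blend a solid red into the masked (eyeglasses) pixels
    face[mask] = (1 - alpha) * face[mask] + alpha * np.array([255.0, 0.0, 0.0])
    return Image.fromarray(face.astype(np.uint8))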


# Create the Gradio app
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=gr.Image(label="Segmentation Mask", type="pil"),
    title="Kamehamehaa",
    description="Segments the eyeglasses a person is wearing in a face image",
    examples=[
        'face-synthetics-glasses/test/images/000368.jpg',
        'face-synthetics-glasses/test/images/000411.jpg'
    ]
)

# Run the app (Gradio serves locally, at http://127.0.0.1:7860 by default)
app.launch()