File size: 2,081 Bytes
ce4c34e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6d39f4
ce4c34e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11680ba
ce4c34e
 
28ec3f5
6e32795
 
ce4c34e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from unet import UNet
from torchvision import transforms
from PIL import Image
from cvzone.FaceDetectionModule import FaceDetector
import cv2

# Module-level singletons shared by every request: the cvzone face detector
# and the glasses-segmentation UNet (constructed as UNet(3, 1); presumably
# 3 input channels -> 1 mask channel, confirm against unet.py).
detector_face = FaceDetector()

model = UNet(3, 1)
# map_location="cpu": without it, a checkpoint saved from a CUDA device fails
# to deserialize on a CPU-only host. load_state_dict copies the tensors into
# the (CPU-resident) model parameters either way.
model.load_state_dict(torch.load("specs_det.pth", map_location="cpu"))
model.eval()  # inference mode (dropout off / batch-norm running stats, if any)

def face_detect(full_image, padding=5):
    """Crop the first detected face out of *full_image*.

    The image is handed to cvzone's ``FaceDetector`` in BGR order. The first
    reported bounding box is grown by ``padding`` pixels on every side,
    clamped to the image bounds, and cut out of the original RGB pixels.

    Args:
        full_image: Input image; expected to be a PIL image (the Gradio input
            component uses ``type="pil"``).
        padding: Pixels of context to keep around the face box on each side.

    Returns:
        A PIL RGB image of the padded face region. If no face is detected, or
        the clamped box is empty, the whole image is returned instead so the
        caller still receives a usable input rather than an ``IndexError``.
    """
    if isinstance(full_image, Image.Image):
        # Grayscale / RGBA / palette images would break the 3-channel flip below.
        full_image = full_image.convert("RGB")
    rgb = np.array(full_image)
    # The detector (OpenCV-based) expects BGR channel order.
    bgr = rgb[:, :, ::-1].copy()

    # draw=False: by default findFaces paints the box and score onto the array
    # it is given, which would otherwise end up inside the crop fed to the model.
    _, bboxs = detector_face.findFaces(bgr, draw=False)
    if not bboxs:
        return Image.fromarray(rgb)

    x, y, w, h = bboxs[0]["bbox"]
    img_h, img_w = rgb.shape[:2]
    # Clamp to the image: a negative start index would wrap around to the far
    # edge instead of clipping at 0.
    top = max(y - padding, 0)
    left = max(x - padding, 0)
    bottom = min(y + h + padding, img_h)
    right = min(x + w + padding, img_w)
    if bottom <= top or right <= left:
        return Image.fromarray(rgb)

    # Crop from the untouched RGB pixels; no BGR->RGB round-trip needed.
    return Image.fromarray(np.ascontiguousarray(rgb[top:bottom, left:right]))




def predict(image):
    """Segment eyeglasses on the face found in *image*.

    The face is cropped with ``face_detect``, resized to 256x256, run through
    the UNet, passed through a sigmoid and thresholded at 0.5.

    Args:
        image: Input image from Gradio (PIL, since the input uses
            ``type="pil"``), or ``None`` when the user submits without one.

    Returns:
        A 256x256 PIL image with 255 where glasses were predicted and 0
        elsewhere (assumes the model emits a single-channel map that
        ``squeeze()`` reduces to 2-D), or ``None`` if no image was given.
    """
    # Gradio invokes the function with None when the input is left empty.
    if image is None:
        return None

    transform_input = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # Nearest-neighbour keeps the 0/255 mask strictly binary if a resize is
    # actually needed (bilinear would introduce intermediate grey values).
    transform_output = transforms.Resize(
        (256, 256), interpolation=transforms.InterpolationMode.NEAREST
    )

    face = face_detect(image)

    with torch.no_grad():
        batch = transform_input(face).unsqueeze(0)
        # Run on whichever device the model weights live on.
        batch = batch.to(next(model.parameters()).device)
        probs = torch.sigmoid(model(batch))

    mask = (probs.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
    return transform_output(Image.fromarray(mask))


# Create the Gradio app: a single image in, the predicted glasses mask out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=gr.Image(label="Image with Segmentation", type="pil"),
    title="Specs Segmenter",
    description="Segment image on the basis of glasses of a person",
    examples=[
        '000368.jpg',
        '000411.jpg',
        '099899.jpg',
    ],
)

# Run the app only when executed as a script, so importing this module
# (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    app.launch()