#!/usr/bin/env python

from __future__ import annotations

import pathlib
import sys

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download

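# Make the face-parsing modules from the CelebAMask-HQ repository importable.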
sys.path.insert(0, "CelebAMask-HQ/face_parsing")

from unet import unet
from utils import generate_label

TITLE = "CelebAMask-HQ Face Parsing"
DESCRIPTION = "An unofficial demo of the face-parsing model from https://github.com/switchablenorms/CelebAMask-HQ."

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
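# Preprocessing: nearest-neighbor resize to the model's 512x512 input, then scale pixels to [-1, 1].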
transform = T.Compose(
    [
        T.Resize((512, 512), interpolation=T.InterpolationMode.NEAREST),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

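# Download the pretrained weights from the Hugging Face Hub and load the U-Net once at startup.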
path = hf_hub_download("public-data/CelebAMask-HQ-Face-Parsing", "models/model.pth")
state_dict = torch.load(path, map_location="cpu")
model = unet()
model.load_state_dict(state_dict)
model.eval()
model.to(device)


# On ZeroGPU Spaces, `spaces.GPU` allocates a GPU for the duration of each call.
@spaces.GPU
@torch.inference_mode()
def predict(image: PIL.Image.Image) -> tuple[np.ndarray, np.ndarray]:
    # Preprocess and add a batch dimension.
    data = transform(image)
    data = data.unsqueeze(0).to(device)
    out = model(data)
    # Convert the raw output to a color-coded label image (values in [0, 1]),
    # then to an 8-bit image.
    out = generate_label(out, 512)
    out = out[0].cpu().numpy().transpose(1, 2, 0)
    out = np.clip(np.round(out * 255), 0, 255).astype(np.uint8)

    # Blend the resized input and the label map 50/50 for an overlay view.
    res = np.asarray(image.resize((512, 512))).astype(float) * 0.5 + out.astype(float) * 0.5
    res = np.clip(np.round(res), 0, 255).astype(np.uint8)
    return out, res


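# Example images shipped in the local "images" directory.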
examples = sorted(pathlib.Path("images").glob("*.jpg"))

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input", type="pil"),
    outputs=[
        gr.Image(label="Predicted Labels"),
        gr.Image(label="Masked"),
    ],
    examples=examples,
    title=TITLE,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()