Spaces: Running on Zero
#!/usr/bin/env python
from __future__ import annotations
import pathlib
import sys
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
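# Make the bundled CelebAMask-HQ face_parsing code importable (provides `unet` and `generate_label`).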
sys.path.insert(0, "CelebAMask-HQ/face_parsing")
from unet import unet
from utils import generate_label
TITLE = "CelebAMask-HQ Face Parsing"
DESCRIPTION = "This is an unofficial demo for the model provided in https://github.com/switchablenorms/CelebAMask-HQ."
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
transform = T.Compose(
    [
        T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
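# Download the pretrained face-parsing weights from the Hugging Face Hub and prepare the U-Net for inference.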
path = hf_hub_download("public-data/CelebAMask-HQ-Face-Parsing", "models/model.pth")
state_dict = torch.load(path, map_location="cpu")
model = unet()
model.load_state_dict(state_dict)
model.eval()
model.to(device)
@spaces.GPU
@torch.inference_mode()
def predict(image: PIL.Image.Image) -> tuple[np.ndarray, np.ndarray]:
    data = transform(image)
    data = data.unsqueeze(0).to(device)
    out = model(data)
    # Convert the raw network output into a color-coded 512x512 label map.
    out = generate_label(out, 512)
    out = out[0].cpu().numpy().transpose(1, 2, 0)
    out = np.clip(np.round(out * 255), 0, 255).astype(np.uint8)
    # Blend the label map with the resized input image for visualization.
    res = np.asarray(image.resize((512, 512))).astype(float) * 0.5 + out.astype(float) * 0.5
    res = np.clip(np.round(res), 0, 255).astype(np.uint8)
    return out, res
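# Use the JPEG images bundled with the Space as Gradio examples.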
examples = sorted(pathlib.Path("images").glob("*.jpg"))
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input", type="pil"),
    outputs=[
        gr.Image(label="Predicted Labels"),
        gr.Image(label="Masked"),
    ],
    examples=examples,
    title=TITLE,
    description=DESCRIPTION,
)
if __name__ == "__main__":
    demo.queue(max_size=20).launch()