#!/usr/bin/env python

import functools
import json
import os
import pathlib
import tarfile
from collections.abc import Callable

import gradio as gr
import huggingface_hub
import PIL.Image
import torch
import torchvision.transforms as T  # noqa: N812

DESCRIPTION = "# [RF5/danbooru-pretrained](https://github.com/RF5/danbooru-pretrained)"

MODEL_REPO = "public-data/danbooru-pretrained"


def load_sample_image_paths() -> list[pathlib.Path]:
    """Download and extract the sample images on first run, then list them."""
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(path) as f:
            f.extractall()  # noqa: S202
    return sorted(image_dir.glob("*"))


def load_model(device: torch.device) -> torch.nn.Module:
    """Build the ResNet-50 from torch.hub and restore the pretrained weights."""
    path = huggingface_hub.hf_hub_download(MODEL_REPO, "resnet50-13306192.pth")
    # Load the checkpoint onto the CPU first so this also works on CPU-only machines.
    state_dict = torch.load(path, map_location="cpu")
    model = torch.hub.load("RF5/danbooru-pretrained", "resnet50", pretrained=False)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def load_labels() -> list[str]:
    """Load the 6,000 tag names in the order matching the model's outputs."""
    path = huggingface_hub.hf_hub_download(MODEL_REPO, "class_names_6000.json")
    with pathlib.Path(path).open() as f:
        return json.load(f)


@torch.inference_mode()
def predict(
    image: PIL.Image.Image,
    score_threshold: float,
    transform: Callable,
    device: torch.device,
    model: torch.nn.Module,
    labels: list[str],
) -> dict[str, float]:
    data = transform(image)
    data = data.to(device).unsqueeze(0)  # add a batch dimension
    preds = model(data)[0]
    # Multi-label classification: sigmoid gives an independent probability per tag.
    preds = torch.sigmoid(preds)
    preds = preds.cpu().numpy().astype(float)
    res = {}
    for prob, label in zip(preds.tolist(), labels, strict=True):
        if prob < score_threshold:
            continue
        res[label] = prob
    return res
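# The returned dict maps tag name -> probability, which is the format gr.Label
# renders directly. An invented example for illustration: {"1girl": 0.98, "solo": 0.95}.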


image_paths = load_sample_image_paths()
examples = [[path.as_posix(), 0.4] for path in image_paths]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = load_model(device)
labels = load_labels()

# Mean/std are the Danbooru-specific normalization statistics published by the
# upstream RF5/danbooru-pretrained repository.
transform = T.Compose(
    [
        T.Resize(360),
        T.ToTensor(),
        T.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]),
    ]
)

fn = functools.partial(predict, transform=transform, device=device, model=model, labels=labels)
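# Illustrative direct call of the partially-applied predictor; the image path
# below is hypothetical:
#
#   tags = fn(PIL.Image.open("images/sample.png"), 0.4)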

with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input", type="pil")
            threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.05, value=0.4)
            run_button = gr.Button()
        with gr.Column():
            result = gr.Label(label="Output")

    inputs = [image, threshold]
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=result,
        fn=fn,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    run_button.click(
        fn=fn,
        inputs=inputs,
        outputs=result,
        api_name="predict",
    )

if __name__ == "__main__":
    demo.queue(max_size=15).launch()
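# Because the click handler sets api_name="predict", a deployed Space can also
# be queried programmatically, e.g. with gradio_client; the Space id and image
# path below are placeholders:
#
#   from gradio_client import Client, handle_file
#   client = Client("user/space-name")
#   client.predict(handle_file("sample.png"), 0.4, api_name="/predict")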