Spaces:
Runtime error
Runtime error
Update
Browse files
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
|
|
4 |
colorFrom: green
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
4 |
colorFrom: green
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.34.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
app.py
CHANGED
@@ -10,11 +10,11 @@ import subprocess
|
|
10 |
import tarfile
|
11 |
|
12 |
if os.getenv('SYSTEM') == 'spaces':
|
13 |
-
subprocess.
|
14 |
shlex.split(
|
15 |
'pip install git+https://github.com/facebookresearch/[email protected]'
|
16 |
))
|
17 |
-
subprocess.
|
18 |
shlex.split(
|
19 |
'pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87'))
|
20 |
|
@@ -27,13 +27,9 @@ from detectron2.data.detection_utils import read_image
|
|
27 |
from detectron2.engine.defaults import DefaultPredictor
|
28 |
from detectron2.utils.visualizer import Visualizer
|
29 |
|
30 |
-
|
31 |
-
DESCRIPTION = 'This is an unofficial demo for https://github.com/zymk9/Yet-Another-Anime-Segmenter.'
|
32 |
|
33 |
-
|
34 |
-
MODEL_REPO = 'hysts/Yet-Another-Anime-Segmenter'
|
35 |
-
MODEL_FILENAME = 'SOLOv2.pth'
|
36 |
-
CONFIG_FILENAME = 'SOLOv2.yaml'
|
37 |
|
38 |
|
39 |
def load_sample_image_paths() -> list[pathlib.Path]:
|
@@ -42,20 +38,15 @@ def load_sample_image_paths() -> list[pathlib.Path]:
|
|
42 |
dataset_repo = 'hysts/sample-images-TADNE'
|
43 |
path = huggingface_hub.hf_hub_download(dataset_repo,
|
44 |
'images.tar.gz',
|
45 |
-
repo_type='dataset'
|
46 |
-
use_auth_token=HF_TOKEN)
|
47 |
with tarfile.open(path) as f:
|
48 |
f.extractall()
|
49 |
return sorted(image_dir.glob('*'))
|
50 |
|
51 |
|
52 |
def load_model(device: torch.device) -> DefaultPredictor:
|
53 |
-
config_path = huggingface_hub.hf_hub_download(MODEL_REPO,
|
54 |
-
|
55 |
-
use_auth_token=HF_TOKEN)
|
56 |
-
model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
|
57 |
-
MODEL_FILENAME,
|
58 |
-
use_auth_token=HF_TOKEN)
|
59 |
cfg = get_cfg()
|
60 |
cfg.merge_from_file(config_path)
|
61 |
cfg.MODEL.WEIGHTS = model_path
|
@@ -90,28 +81,34 @@ examples = [[path.as_posix(), 0.1, 0.5] for path in image_paths]
|
|
90 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
91 |
model = load_model(device)
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
gr.
|
96 |
-
|
97 |
-
|
98 |
-
gr.
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
gr.
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
import tarfile
|
11 |
|
12 |
if os.getenv('SYSTEM') == 'spaces':
|
13 |
+
subprocess.run(
|
14 |
shlex.split(
|
15 |
'pip install git+https://github.com/facebookresearch/[email protected]'
|
16 |
))
|
17 |
+
subprocess.run(
|
18 |
shlex.split(
|
19 |
'pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87'))
|
20 |
|
|
|
27 |
from detectron2.engine.defaults import DefaultPredictor
|
28 |
from detectron2.utils.visualizer import Visualizer
|
29 |
|
30 |
+
DESCRIPTION = '# [Yet-Another-Anime-Segmenter](https://github.com/zymk9/Yet-Another-Anime-Segmenter)'
|
|
|
31 |
|
32 |
+
MODEL_REPO = 'public-data/Yet-Another-Anime-Segmenter'
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
def load_sample_image_paths() -> list[pathlib.Path]:
|
|
|
38 |
dataset_repo = 'hysts/sample-images-TADNE'
|
39 |
path = huggingface_hub.hf_hub_download(dataset_repo,
|
40 |
'images.tar.gz',
|
41 |
+
repo_type='dataset')
|
|
|
42 |
with tarfile.open(path) as f:
|
43 |
f.extractall()
|
44 |
return sorted(image_dir.glob('*'))
|
45 |
|
46 |
|
47 |
def load_model(device: torch.device) -> DefaultPredictor:
|
48 |
+
config_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'SOLOv2.yaml')
|
49 |
+
model_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'SOLOv2.pth')
|
|
|
|
|
|
|
|
|
50 |
cfg = get_cfg()
|
51 |
cfg.merge_from_file(config_path)
|
52 |
cfg.MODEL.WEIGHTS = model_path
|
|
|
81 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
82 |
model = load_model(device)
|
83 |
|
84 |
+
fn = functools.partial(predict, model=model)
|
85 |
+
|
86 |
+
with gr.Blocks(css='style.css') as demo:
|
87 |
+
gr.Markdown(DESCRIPTION)
|
88 |
+
with gr.Row():
|
89 |
+
with gr.Column():
|
90 |
+
image = gr.Image(label='Input', type='filepath')
|
91 |
+
class_score_threshold = gr.Slider(label='Score Threshold',
|
92 |
+
minimum=0,
|
93 |
+
maximum=1,
|
94 |
+
step=0.05,
|
95 |
+
value=0.1)
|
96 |
+
mask_score_threshold = gr.Slider(label='Mask Score Threshold',
|
97 |
+
minimum=0,
|
98 |
+
maximum=1,
|
99 |
+
step=0.05,
|
100 |
+
value=0.5)
|
101 |
+
run_button = gr.Button('Run')
|
102 |
+
with gr.Column():
|
103 |
+
result_instances = gr.Image(label='Instances')
|
104 |
+
result_masked = gr.Image(label='Masked')
|
105 |
+
|
106 |
+
inputs = [image, class_score_threshold, mask_score_threshold]
|
107 |
+
outputs = [result_instances, result_masked]
|
108 |
+
gr.Examples(examples=examples,
|
109 |
+
inputs=inputs,
|
110 |
+
outputs=outputs,
|
111 |
+
fn=fn,
|
112 |
+
cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
|
113 |
+
run_button.click(fn=fn, inputs=inputs, outputs=outputs, api_name='predict')
|
114 |
+
demo.queue(max_size=15).launch()
|
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
opencv-python-headless==4.
|
2 |
-
torch==1.
|
3 |
-
torchvision==0.
|
|
|
1 |
+
opencv-python-headless==4.7.0.72
|
2 |
+
torch==1.13.1
|
3 |
+
torchvision==0.14.1
|
style.css
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
h1 {
|
2 |
+
text-align: center;
|
3 |
+
}
|