hysts HF staff committed on
Commit
08eb34b
·
1 Parent(s): 49ff668
Files changed (4) hide show
  1. README.md +1 -1
  2. app.py +38 -41
  3. requirements.txt +3 -3
  4. style.css +3 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
4
  colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.34.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -10,11 +10,11 @@ import subprocess
10
  import tarfile
11
 
12
  if os.getenv('SYSTEM') == 'spaces':
13
- subprocess.call(
14
  shlex.split(
15
  'pip install git+https://github.com/facebookresearch/[email protected]'
16
  ))
17
- subprocess.call(
18
  shlex.split(
19
  'pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87'))
20
 
@@ -27,13 +27,9 @@ from detectron2.data.detection_utils import read_image
27
  from detectron2.engine.defaults import DefaultPredictor
28
  from detectron2.utils.visualizer import Visualizer
29
 
30
- TITLE = 'Yet-Another-Anime-Segmenter'
31
- DESCRIPTION = 'This is an unofficial demo for https://github.com/zymk9/Yet-Another-Anime-Segmenter.'
32
 
33
- HF_TOKEN = os.getenv('HF_TOKEN')
34
- MODEL_REPO = 'hysts/Yet-Another-Anime-Segmenter'
35
- MODEL_FILENAME = 'SOLOv2.pth'
36
- CONFIG_FILENAME = 'SOLOv2.yaml'
37
 
38
 
39
  def load_sample_image_paths() -> list[pathlib.Path]:
@@ -42,20 +38,15 @@ def load_sample_image_paths() -> list[pathlib.Path]:
42
  dataset_repo = 'hysts/sample-images-TADNE'
43
  path = huggingface_hub.hf_hub_download(dataset_repo,
44
  'images.tar.gz',
45
- repo_type='dataset',
46
- use_auth_token=HF_TOKEN)
47
  with tarfile.open(path) as f:
48
  f.extractall()
49
  return sorted(image_dir.glob('*'))
50
 
51
 
52
  def load_model(device: torch.device) -> DefaultPredictor:
53
- config_path = huggingface_hub.hf_hub_download(MODEL_REPO,
54
- CONFIG_FILENAME,
55
- use_auth_token=HF_TOKEN)
56
- model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
57
- MODEL_FILENAME,
58
- use_auth_token=HF_TOKEN)
59
  cfg = get_cfg()
60
  cfg.merge_from_file(config_path)
61
  cfg.MODEL.WEIGHTS = model_path
@@ -90,28 +81,34 @@ examples = [[path.as_posix(), 0.1, 0.5] for path in image_paths]
90
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
91
  model = load_model(device)
92
 
93
- func = functools.partial(predict, model=model)
94
-
95
- gr.Interface(
96
- fn=func,
97
- inputs=[
98
- gr.Image(label='Input', type='filepath'),
99
- gr.Slider(label='Class Score Threshold',
100
- minimum=0,
101
- maximum=1,
102
- step=0.05,
103
- value=0.1),
104
- gr.Slider(label='Mask Score Threshold',
105
- minimum=0,
106
- maximum=1,
107
- step=0.05,
108
- default=0.5),
109
- ],
110
- outputs=[
111
- gr.Image(label='Instances'),
112
- gr.Image(label='Masked'),
113
- ],
114
- examples=examples,
115
- title=TITLE,
116
- description=DESCRIPTION,
117
- ).queue().launch(show_api=False)
 
 
 
 
 
 
 
10
  import tarfile
11
 
12
  if os.getenv('SYSTEM') == 'spaces':
13
+ subprocess.run(
14
  shlex.split(
15
  'pip install git+https://github.com/facebookresearch/[email protected]'
16
  ))
17
+ subprocess.run(
18
  shlex.split(
19
  'pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87'))
20
 
 
27
  from detectron2.engine.defaults import DefaultPredictor
28
  from detectron2.utils.visualizer import Visualizer
29
 
30
+ DESCRIPTION = '# [Yet-Another-Anime-Segmenter](https://github.com/zymk9/Yet-Another-Anime-Segmenter)'
 
31
 
32
+ MODEL_REPO = 'public-data/Yet-Another-Anime-Segmenter'
 
 
 
33
 
34
 
35
  def load_sample_image_paths() -> list[pathlib.Path]:
 
38
  dataset_repo = 'hysts/sample-images-TADNE'
39
  path = huggingface_hub.hf_hub_download(dataset_repo,
40
  'images.tar.gz',
41
+ repo_type='dataset')
 
42
  with tarfile.open(path) as f:
43
  f.extractall()
44
  return sorted(image_dir.glob('*'))
45
 
46
 
47
  def load_model(device: torch.device) -> DefaultPredictor:
48
+ config_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'SOLOv2.yaml')
49
+ model_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'SOLOv2.pth')
 
 
 
 
50
  cfg = get_cfg()
51
  cfg.merge_from_file(config_path)
52
  cfg.MODEL.WEIGHTS = model_path
 
81
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
82
  model = load_model(device)
83
 
84
+ fn = functools.partial(predict, model=model)
85
+
86
+ with gr.Blocks(css='style.css') as demo:
87
+ gr.Markdown(DESCRIPTION)
88
+ with gr.Row():
89
+ with gr.Column():
90
+ image = gr.Image(label='Input', type='filepath')
91
+ class_score_threshold = gr.Slider(label='Score Threshold',
92
+ minimum=0,
93
+ maximum=1,
94
+ step=0.05,
95
+ value=0.1)
96
+ mask_score_threshold = gr.Slider(label='Mask Score Threshold',
97
+ minimum=0,
98
+ maximum=1,
99
+ step=0.05,
100
+ value=0.5)
101
+ run_button = gr.Button('Run')
102
+ with gr.Column():
103
+ result_instances = gr.Image(label='Instances')
104
+ result_masked = gr.Image(label='Masked')
105
+
106
+ inputs = [image, class_score_threshold, mask_score_threshold]
107
+ outputs = [result_instances, result_masked]
108
+ gr.Examples(examples=examples,
109
+ inputs=inputs,
110
+ outputs=outputs,
111
+ fn=fn,
112
+ cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
113
+ run_button.click(fn=fn, inputs=inputs, outputs=outputs, api_name='predict')
114
+ demo.queue(max_size=15).launch()
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- opencv-python-headless==4.5.5.62
2
- torch==1.10.1
3
- torchvision==0.11.2
 
1
+ opencv-python-headless==4.7.0.72
2
+ torch==1.13.1
3
+ torchvision==0.14.1
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }