hysts HF Staff commited on
Commit
77bce91
·
1 Parent(s): f75cd10
Files changed (5) hide show
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -29
  4. app.py +38 -74
  5. requirements.txt +2 -2
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,35 +4,7 @@ emoji: 🐠
4
  colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
11
-
12
- # Configuration
13
-
14
- `title`: _string_
15
- Display title for the Space
16
-
17
- `emoji`: _string_
18
- Space emoji (emoji-only character allowed)
19
-
20
- `colorFrom`: _string_
21
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
-
23
- `colorTo`: _string_
24
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
-
26
- `sdk`: _string_
27
- Can be either `gradio`, `streamlit`, or `static`
28
-
29
- `sdk_version` : _string_
30
- Only applicable for `streamlit` SDK.
31
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
-
33
- `app_file`: _string_
34
- Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
35
- Path is relative to the root of the repository.
36
-
37
- `pinned`: _boolean_
38
- Whether the Space stays on top of your list.
 
4
  colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import functools
7
  import json
8
  import os
@@ -18,30 +17,13 @@ import torchvision.transforms as T
18
 
19
  TITLE = 'RF5/danbooru-pretrained'
20
  DESCRIPTION = 'This is an unofficial demo for https://github.com/RF5/danbooru-pretrained.'
21
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.danbooru_pretrained" alt="visitor badge"/></center>'
22
 
23
- TOKEN = os.environ['TOKEN']
24
  MODEL_REPO = 'hysts/danbooru-pretrained'
25
  MODEL_FILENAME = 'resnet50-13306192.pth'
26
  LABEL_FILENAME = 'class_names_6000.json'
27
 
28
 
29
- def parse_args() -> argparse.Namespace:
30
- parser = argparse.ArgumentParser()
31
- parser.add_argument('--device', type=str, default='cpu')
32
- parser.add_argument('--score-slider-step', type=float, default=0.05)
33
- parser.add_argument('--score-threshold', type=float, default=0.4)
34
- parser.add_argument('--theme', type=str, default='dark-grass')
35
- parser.add_argument('--live', action='store_true')
36
- parser.add_argument('--share', action='store_true')
37
- parser.add_argument('--port', type=int)
38
- parser.add_argument('--disable-queue',
39
- dest='enable_queue',
40
- action='store_false')
41
- parser.add_argument('--allow-flagging', type=str, default='never')
42
- return parser.parse_args()
43
-
44
-
45
  def load_sample_image_paths() -> list[pathlib.Path]:
46
  image_dir = pathlib.Path('images')
47
  if not image_dir.exists():
@@ -49,7 +31,7 @@ def load_sample_image_paths() -> list[pathlib.Path]:
49
  path = huggingface_hub.hf_hub_download(dataset_repo,
50
  'images.tar.gz',
51
  repo_type='dataset',
52
- use_auth_token=TOKEN)
53
  with tarfile.open(path) as f:
54
  f.extractall()
55
  return sorted(image_dir.glob('*'))
@@ -58,7 +40,7 @@ def load_sample_image_paths() -> list[pathlib.Path]:
58
  def load_model(device: torch.device) -> torch.nn.Module:
59
  path = huggingface_hub.hf_hub_download(MODEL_REPO,
60
  MODEL_FILENAME,
61
- use_auth_token=TOKEN)
62
  state_dict = torch.load(path)
63
  model = torch.hub.load('RF5/danbooru-pretrained',
64
  'resnet50',
@@ -72,7 +54,7 @@ def load_model(device: torch.device) -> torch.nn.Module:
72
  def load_labels() -> list[str]:
73
  path = huggingface_hub.hf_hub_download(MODEL_REPO,
74
  LABEL_FILENAME,
75
- use_auth_token=TOKEN)
76
  with open(path) as f:
77
  labels = json.load(f)
78
  return labels
@@ -96,55 +78,37 @@ def predict(image: PIL.Image.Image, score_threshold: float,
96
  return res
97
 
98
 
99
- def main():
100
- args = parse_args()
101
- device = torch.device(args.device)
102
-
103
- image_paths = load_sample_image_paths()
104
- examples = [[path.as_posix(), args.score_threshold]
105
- for path in image_paths]
106
-
107
- model = load_model(device)
108
- labels = load_labels()
109
-
110
- transform = T.Compose([
111
- T.Resize(360),
112
- T.ToTensor(),
113
- T.Normalize(mean=[0.7137, 0.6628, 0.6519],
114
- std=[0.2970, 0.3017, 0.2979]),
115
- ])
116
-
117
- func = functools.partial(predict,
118
- transform=transform,
119
- device=device,
120
- model=model,
121
- labels=labels)
122
- func = functools.update_wrapper(func, predict)
123
-
124
- gr.Interface(
125
- func,
126
- [
127
- gr.inputs.Image(type='pil', label='Input'),
128
- gr.inputs.Slider(0,
129
- 1,
130
- step=args.score_slider_step,
131
- default=args.score_threshold,
132
- label='Score Threshold'),
133
- ],
134
- gr.outputs.Label(label='Output'),
135
- examples=examples,
136
- title=TITLE,
137
- description=DESCRIPTION,
138
- article=ARTICLE,
139
- theme=args.theme,
140
- allow_flagging=args.allow_flagging,
141
- live=args.live,
142
- ).launch(
143
- enable_queue=args.enable_queue,
144
- server_port=args.port,
145
- share=args.share,
146
- )
147
-
148
-
149
- if __name__ == '__main__':
150
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import functools
6
  import json
7
  import os
 
17
 
18
  TITLE = 'RF5/danbooru-pretrained'
19
  DESCRIPTION = 'This is an unofficial demo for https://github.com/RF5/danbooru-pretrained.'
 
20
 
21
+ HF_TOKEN = os.getenv('HF_TOKEN')
22
  MODEL_REPO = 'hysts/danbooru-pretrained'
23
  MODEL_FILENAME = 'resnet50-13306192.pth'
24
  LABEL_FILENAME = 'class_names_6000.json'
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def load_sample_image_paths() -> list[pathlib.Path]:
28
  image_dir = pathlib.Path('images')
29
  if not image_dir.exists():
 
31
  path = huggingface_hub.hf_hub_download(dataset_repo,
32
  'images.tar.gz',
33
  repo_type='dataset',
34
+ use_auth_token=HF_TOKEN)
35
  with tarfile.open(path) as f:
36
  f.extractall()
37
  return sorted(image_dir.glob('*'))
 
40
  def load_model(device: torch.device) -> torch.nn.Module:
41
  path = huggingface_hub.hf_hub_download(MODEL_REPO,
42
  MODEL_FILENAME,
43
+ use_auth_token=HF_TOKEN)
44
  state_dict = torch.load(path)
45
  model = torch.hub.load('RF5/danbooru-pretrained',
46
  'resnet50',
 
54
  def load_labels() -> list[str]:
55
  path = huggingface_hub.hf_hub_download(MODEL_REPO,
56
  LABEL_FILENAME,
57
+ use_auth_token=HF_TOKEN)
58
  with open(path) as f:
59
  labels = json.load(f)
60
  return labels
 
78
  return res
79
 
80
 
81
+ image_paths = load_sample_image_paths()
82
+ examples = [[path.as_posix(), 0.4] for path in image_paths]
83
+
84
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
85
+ model = load_model(device)
86
+ labels = load_labels()
87
+
88
+ transform = T.Compose([
89
+ T.Resize(360),
90
+ T.ToTensor(),
91
+ T.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]),
92
+ ])
93
+
94
+ func = functools.partial(predict,
95
+ transform=transform,
96
+ device=device,
97
+ model=model,
98
+ labels=labels)
99
+
100
+ gr.Interface(
101
+ fn=func,
102
+ inputs=[
103
+ gr.Image(label='Input', type='pil'),
104
+ gr.Slider(label='Score Threshold',
105
+ minimum=0,
106
+ maximum=1,
107
+ step=0.05,
108
+ value=0.4),
109
+ ],
110
+ outputs=gr.Label(label='Output'),
111
+ examples=examples,
112
+ title=TITLE,
113
+ description=DESCRIPTION,
114
+ ).queue().launch(show_api=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- torch>=1.10.1
2
- torchvision>=0.11.2
 
1
+ torch==1.13.1
2
+ torchvision==0.14.1