top001 commited on
Commit
c2d6d13
·
verified ·
1 Parent(s): 6cc18e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -6
app.py CHANGED
@@ -15,7 +15,6 @@ import tensorflow as tf
15
 
16
  DESCRIPTION = "# [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)"
17
 
18
-
19
  def load_sample_image_paths() -> list[pathlib.Path]:
20
  image_dir = pathlib.Path("images")
21
  if not image_dir.exists():
@@ -24,24 +23,20 @@ def load_sample_image_paths() -> list[pathlib.Path]:
24
  f.extractall()
25
  return sorted(image_dir.glob("*"))
26
 
27
-
28
  def load_model() -> tf.keras.Model:
29
  path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "model-resnet_custom_v3.h5")
30
  model = tf.keras.models.load_model(path)
31
  return model
32
 
33
-
34
  def load_labels() -> list[str]:
35
  path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "tags.txt")
36
  with open(path) as f:
37
  labels = [line.strip() for line in f.readlines()]
38
  return labels
39
 
40
-
41
  model = load_model()
42
  labels = load_labels()
43
 
44
-
45
  def predict(image: PIL.Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
46
  _, height, width, _ = model.input_shape
47
  image = np.asarray(image)
@@ -65,12 +60,12 @@ def predict(image: PIL.Image.Image, score_threshold: float) -> tuple[dict[str, f
65
  result_text = ", ".join(result_all.keys())
66
  return result_threshold, result_all, result_text
67
 
68
-
69
  image_paths = load_sample_image_paths()
70
  examples = [[path.as_posix(), 0.5] for path in image_paths]
71
 
72
  with gr.Blocks(css="style.css") as demo:
73
  gr.Markdown(DESCRIPTION)
 
74
  with gr.Row():
75
  with gr.Column():
76
  image = gr.Image(label="Input", type="pil")
@@ -84,6 +79,7 @@ with gr.Blocks(css="style.css") as demo:
84
  result_json = gr.JSON(label="JSON output", show_label=False)
85
  with gr.Tab(label="Text"):
86
  result_text = gr.Text(label="Text output", show_label=False, lines=5)
 
87
  gr.Examples(
88
  examples=examples,
89
  inputs=[image, score_threshold],
 
15
 
16
  DESCRIPTION = "# [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)"
17
 
 
18
  def load_sample_image_paths() -> list[pathlib.Path]:
19
  image_dir = pathlib.Path("images")
20
  if not image_dir.exists():
 
23
  f.extractall()
24
  return sorted(image_dir.glob("*"))
25
 
 
26
  def load_model() -> tf.keras.Model:
27
  path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "model-resnet_custom_v3.h5")
28
  model = tf.keras.models.load_model(path)
29
  return model
30
 
 
31
  def load_labels() -> list[str]:
32
  path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "tags.txt")
33
  with open(path) as f:
34
  labels = [line.strip() for line in f.readlines()]
35
  return labels
36
 
 
37
  model = load_model()
38
  labels = load_labels()
39
 
 
40
  def predict(image: PIL.Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
41
  _, height, width, _ = model.input_shape
42
  image = np.asarray(image)
 
60
  result_text = ", ".join(result_all.keys())
61
  return result_threshold, result_all, result_text
62
 
 
63
  image_paths = load_sample_image_paths()
64
  examples = [[path.as_posix(), 0.5] for path in image_paths]
65
 
66
  with gr.Blocks(css="style.css") as demo:
67
  gr.Markdown(DESCRIPTION)
68
+
69
  with gr.Row():
70
  with gr.Column():
71
  image = gr.Image(label="Input", type="pil")
 
79
  result_json = gr.JSON(label="JSON output", show_label=False)
80
  with gr.Tab(label="Text"):
81
  result_text = gr.Text(label="Text output", show_label=False, lines=5)
82
+
83
  gr.Examples(
84
  examples=examples,
85
  inputs=[image, score_threshold],