umyuu commited on
Commit
46c8ce2
·
1 Parent(s): b76540e

リファクタリング

Browse files
Files changed (5) hide show
  1. src/__init__.py +1 -1
  2. src/args_parser.py +1 -1
  3. src/reporter.py +1 -1
  4. src/saliency.py +23 -9
  5. src/utils.py +17 -10
src/__init__.py CHANGED
@@ -9,7 +9,7 @@ from src.utils import get_package_version
9
 
10
  __all__ = ["LocalTimeFormatter"]
11
 
12
- PROGRAM_NAME = 'SaliencyMapDemo'
13
  __version__ = get_package_version()
14
 
15
 
 
9
 
10
  __all__ = ["LocalTimeFormatter"]
11
 
12
+ PROGRAM_NAME: str = 'SaliencyMapDemo'
13
  __version__ = get_package_version()
14
 
15
 
src/args_parser.py CHANGED
@@ -1,7 +1,7 @@
1
  # -*- coding: utf-8 -*-
2
  """コマンドライン引数の解析"""
3
  from argparse import ArgumentParser, BooleanOptionalAction
4
- from src import PROGRAM_NAME, get_package_version
5
 
6
 
7
  def parse_args():
 
1
  # -*- coding: utf-8 -*-
2
  """コマンドライン引数の解析"""
3
  from argparse import ArgumentParser, BooleanOptionalAction
4
+ from . import PROGRAM_NAME, get_package_version
5
 
6
 
7
  def parse_args():
src/reporter.py CHANGED
@@ -1,6 +1,5 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Reporter
4
  ログハンドラーが重複登録されるのを防ぐために1箇所で生成してログハンドラーを返します。
5
  Example:
6
  from src.reporter import log
@@ -27,6 +26,7 @@ class Reporter:
27
 
28
  def __new__(cls):
29
  """
 
30
  """
31
  # インスタンスがまだ存在しない場合は新たに作成します。
32
  if not cls._instance:
 
1
  # -*- coding: utf-8 -*-
2
  """
 
3
  ログハンドラーが重複登録されるのを防ぐために1箇所で生成してログハンドラーを返します。
4
  Example:
5
  from src.reporter import log
 
26
 
27
  def __new__(cls):
28
  """
29
+ インスタンスの生成を制御します。
30
  """
31
  # インスタンスがまだ存在しない場合は新たに作成します。
32
  if not cls._instance:
src/saliency.py CHANGED
@@ -8,7 +8,8 @@ import cv2
8
 
9
  class SaliencyMap:
10
  """
11
- 顕著性マップを計算するクラス。
 
12
  Example:
13
  from src.saliency import SaliencyMap
14
 
@@ -19,6 +20,15 @@ class SaliencyMap:
19
  self,
20
  algorithm: Literal["SpectralResidual", "FineGrained"] = "SpectralResidual",
21
  ):
 
 
 
 
 
 
 
 
 
22
  self.algorithm = algorithm
23
  # OpenCVのsaliencyを作成します。
24
  if algorithm == "SpectralResidual":
@@ -28,17 +38,16 @@ class SaliencyMap:
28
 
29
  def compute(self, image: np.ndarray) -> Tuple[bool, Any]:
30
  """
31
- 入力画像から顕著性マップを作成します。
32
 
33
  Parameters:
34
  image: 入力画像
35
 
36
  Returns:
37
- bool:
38
- true: SaliencyMap computed, false:NG
39
- np.ndarray: 顕著性マップ
40
  """
41
- # 画像の顕著性を計算します。
42
  return self.saliency.computeSaliency(image)
43
 
44
 
@@ -48,19 +57,24 @@ def convert_colormap(
48
  colormap_name: Literal["jet", "hot", "turbo"] = "jet"
49
  ):
50
  """
51
- 顕著性マップをカラーマップに変換後に、入力画像に重ね合わせします。
52
 
53
  Parameters:
54
  image: 入力画像
55
  saliency_map: 顕著性マップ
56
  colormap_name: カラーマップの種類
 
 
 
57
 
58
  Returns:
59
- np.ndarray: 重ね合わせた画像(RGBA形式)
60
  """
61
  maps = {"jet": cv2.COLORMAP_JET, "hot": cv2.COLORMAP_HOT, "turbo": cv2.COLORMAP_TURBO}
 
 
62
  if colormap_name not in maps:
63
- raise ValueError(colormap_name)
64
 
65
  # 顕著性マップをカラーマップに変換
66
  saliency_map = (saliency_map * 255).astype("uint8")
 
8
 
9
  class SaliencyMap:
10
  """
11
+ 画像から顕著性マップを計算するクラス。
12
+
13
  Example:
14
  from src.saliency import SaliencyMap
15
 
 
20
  self,
21
  algorithm: Literal["SpectralResidual", "FineGrained"] = "SpectralResidual",
22
  ):
23
+ """
24
+ SaliencyMapオブジェクトを初期化します。
25
+
26
+ Parameters:
27
+ algorithm: 使用する顕著性マップアルゴリズムの種類。
28
+ 有効なアルゴリズムについてはOpenCVのドキュメントを参照してください。
29
+ https://docs.opencv.org/4.9.0/d8/d65/group__saliency.html
30
+
31
+ """
32
  self.algorithm = algorithm
33
  # OpenCVのsaliencyを作成します。
34
  if algorithm == "SpectralResidual":
 
38
 
39
  def compute(self, image: np.ndarray) -> Tuple[bool, Any]:
40
  """
41
+ 入力画像から顕著性マップを計算します。
42
 
43
  Parameters:
44
  image: 入力画像
45
 
46
  Returns:
47
+ Tuple[bool, Any]: 顕著性マップの計算結果。
48
+ bool値がTrueの場合は計算成功、Falseの場合は失敗。
49
+ 顕著性マップのデータ。
50
  """
 
51
  return self.saliency.computeSaliency(image)
52
 
53
 
 
57
  colormap_name: Literal["jet", "hot", "turbo"] = "jet"
58
  ):
59
  """
60
+ 入力画像と顕著性マップを合成し、指定されたカラーマップを適用します。
61
 
62
  Parameters:
63
  image: 入力画像
64
  saliency_map: 顕著性マップ
65
  colormap_name: カラーマップの種類
66
+ "jet": Jetカラーマップ
67
+ "hot": Hotカラーマップ
68
+ "turbo": Turboカラーマップ
69
 
70
  Returns:
71
+ np.ndarray: 合成された画像 (RGBA形式)
72
  """
73
  maps = {"jet": cv2.COLORMAP_JET, "hot": cv2.COLORMAP_HOT, "turbo": cv2.COLORMAP_TURBO}
74
+
75
+ # colormap_nameが有効かどうかをチェック
76
  if colormap_name not in maps:
77
+ raise ValueError(f"Invalid colormap name: {colormap_name}")
78
 
79
  # 顕著性マップをカラーマップに変換
80
  saliency_map = (saliency_map * 255).astype("uint8")
src/utils.py CHANGED
@@ -6,7 +6,7 @@ import time
6
 
7
  def get_package_version() -> str:
8
  """
9
- バージョン情報
10
  """
11
  return '0.0.8'
12
 
@@ -29,18 +29,20 @@ class Stopwatch:
29
  @property
30
  def elapsed(self) -> float:
31
  """
32
- 経過時間を取得します。
 
 
 
33
  """
34
  if self._is_running:
35
- end_time = time.perf_counter()
36
- self._elapsed = end_time - self._start_time
37
 
38
  return self._elapsed
39
 
40
  @property
41
  def is_running(self) -> bool:
42
  """
43
- 実行中かどうかを取得します。
44
  """
45
  return self._is_running
46
 
@@ -53,9 +55,12 @@ class Stopwatch:
53
  self._is_running = True
54
 
55
  @classmethod
56
- def start_new(cls):
57
  """
58
- ストップウォッチを生成し計測を開始します。
 
 
 
59
  """
60
  stopwatch = Stopwatch()
61
  stopwatch.start()
@@ -63,10 +68,12 @@ class Stopwatch:
63
 
64
  def stop(self) -> float:
65
  """
66
- 計測を終了します。
 
 
 
67
  """
68
  if self._is_running:
69
- end_time = time.perf_counter()
70
- self._elapsed = end_time - self._start_time
71
  self._is_running = False
72
  return self._elapsed
 
6
 
7
  def get_package_version() -> str:
8
  """
9
+ バージョン情報を取得します。
10
  """
11
  return '0.0.8'
12
 
 
29
  @property
30
  def elapsed(self) -> float:
31
  """
32
+ 計測中の経過時間を取得します。
33
+
34
+ Returns:
35
+ float: 計測中の経過時間(小数秒)
36
  """
37
  if self._is_running:
38
+ self._elapsed = time.perf_counter() - self._start_time
 
39
 
40
  return self._elapsed
41
 
42
  @property
43
  def is_running(self) -> bool:
44
  """
45
+ 計測が実行中であるかどうかを取得します
46
  """
47
  return self._is_running
48
 
 
55
  self._is_running = True
56
 
57
  @classmethod
58
+ def start_new(cls) -> 'Stopwatch':
59
  """
60
+ 新しいストップウォッチを生成し、計測を開始します。
61
+
62
+ Returns:
63
+ Stopwatch: 新しいストップウォッチオブジェクト
64
  """
65
  stopwatch = Stopwatch()
66
  stopwatch.start()
 
68
 
69
  def stop(self) -> float:
70
  """
71
+ 計測を終了し、経過時間を返します。
72
+
73
+ Returns:
74
+ float: 計測中の経過時間
75
  """
76
  if self._is_running:
77
+ self._elapsed = time.perf_counter() - self._start_time
 
78
  self._is_running = False
79
  return self._elapsed