aadnk commited on
Commit
74b7d77
·
1 Parent(s): 67b8308

Adding JSON initial prompt

Browse files

By selecting "json_prompt_mode", you can
customize the prompt to each segment.

For instance:
[
{"segment_index": 0, "prompt": "Hello, how are you?"},
{"segment_index": 1, "prompt": "I'm doing well, how are you?"},
{"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
]

app.py CHANGED
@@ -13,12 +13,14 @@ import numpy as np
13
 
14
  import torch
15
 
16
- from src.config import ApplicationConfig, VadInitialPromptMode
17
  from src.hooks.progressListener import ProgressListener
18
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
19
  from src.hooks.whisperProgressHook import create_progress_listener_handle
20
  from src.languages import get_language_names
21
  from src.modelCache import ModelCache
 
 
22
  from src.source import get_audio_source_collection
23
  from src.vadParallel import ParallelContext, ParallelTranscription
24
 
@@ -271,8 +273,18 @@ class WhisperTranscriber:
271
  if ('task' in decodeOptions):
272
  task = decodeOptions.pop('task')
273
 
 
 
 
 
 
 
 
 
 
 
274
  # Callable for processing an audio file
275
- whisperCallable = model.create_callback(language, task, initial_prompt, initial_prompt_mode=vadOptions.vadInitialPromptMode, **decodeOptions)
276
 
277
  # The results
278
  if (vadOptions.vad == 'silero-vad'):
@@ -519,7 +531,7 @@ def create_ui(app_config: ApplicationConfig):
519
  *common_vad_inputs(),
520
  gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
521
  gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
522
- gr.Dropdown(choices=["prepend_first_segment", "prepend_all_segments"], value=app_config.vad_initial_prompt_mode, label="VAD - Initial Prompt Mode"),
523
 
524
  *common_word_timestamps_inputs(),
525
  gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
@@ -580,7 +592,7 @@ if __name__ == '__main__':
580
  help="The default model name.") # medium
581
  parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
582
  help="The default VAD.") # silero-vad
583
- parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=["prepend_all_segments", "prepend_first_segment"], \
584
  help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
585
  parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
586
  help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
 
13
 
14
  import torch
15
 
16
+ from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
  from src.hooks.progressListener import ProgressListener
18
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
19
  from src.hooks.whisperProgressHook import create_progress_listener_handle
20
  from src.languages import get_language_names
21
  from src.modelCache import ModelCache
22
+ from src.prompts.jsonPromptStrategy import JsonPromptStrategy
23
+ from src.prompts.prependPromptStrategy import PrependPromptStrategy
24
  from src.source import get_audio_source_collection
25
  from src.vadParallel import ParallelContext, ParallelTranscription
26
 
 
273
  if ('task' in decodeOptions):
274
  task = decodeOptions.pop('task')
275
 
276
+ if (vadOptions.vadInitialPromptMode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS or
277
+ vadOptions.vadInitialPromptMode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
278
+ # Prepend initial prompt
279
+ prompt_strategy = PrependPromptStrategy(initial_prompt, vadOptions.vadInitialPromptMode)
280
+ elif (vadOptions.vadInitialPromptMode == VadInitialPromptMode.JSON_PROMPT_MODE):
281
+ # Use a JSON format to specify the prompt for each segment
282
+ prompt_strategy = JsonPromptStrategy(initial_prompt)
283
+ else:
284
+ raise ValueError("Invalid vadInitialPromptMode: " + vadOptions.vadInitialPromptMode)
285
+
286
  # Callable for processing an audio file
287
+ whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)
288
 
289
  # The results
290
  if (vadOptions.vad == 'silero-vad'):
 
531
  *common_vad_inputs(),
532
  gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
533
  gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
534
+ gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode"),
535
 
536
  *common_word_timestamps_inputs(),
537
  gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
 
592
  help="The default model name.") # medium
593
  parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
594
  help="The default VAD.") # silero-vad
595
+ parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
596
  help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
597
  parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
598
  help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
cli.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
7
 
8
  import torch
9
  from app import VadOptions, WhisperTranscriber
10
- from src.config import ApplicationConfig, VadInitialPromptMode
11
  from src.download import download_url
12
  from src.languages import get_language_names
13
 
@@ -47,7 +47,7 @@ def cli():
47
 
48
  parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
49
  help="The voice activity detection algorithm to use") # silero-vad
50
- parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=["prepend_all_segments", "prepend_first_segment"], \
51
  help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
52
  parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
53
  help="The window size (in seconds) to merge voice segments")
 
7
 
8
  import torch
9
  from app import VadOptions, WhisperTranscriber
10
+ from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
11
  from src.download import download_url
12
  from src.languages import get_language_names
13
 
 
47
 
48
  parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
49
  help="The voice activity detection algorithm to use") # silero-vad
50
+ parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
51
  help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
52
  parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
53
  help="The window size (in seconds) to merge voice segments")
src/config.py CHANGED
@@ -24,9 +24,12 @@ class ModelConfig:
24
  self.path = path
25
  self.type = type
26
 
 
 
27
  class VadInitialPromptMode(Enum):
28
  PREPEND_ALL_SEGMENTS = 1
29
  PREPREND_FIRST_SEGMENT = 2
 
30
 
31
  @staticmethod
32
  def from_string(s: str):
@@ -36,6 +39,8 @@ class VadInitialPromptMode(Enum):
36
  return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
37
  elif normalized == "prepend_first_segment":
38
  return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
 
 
39
  else:
40
  raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
41
 
 
24
  self.path = path
25
  self.type = type
26
 
27
+ VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
28
+
29
  class VadInitialPromptMode(Enum):
30
  PREPEND_ALL_SEGMENTS = 1
31
  PREPREND_FIRST_SEGMENT = 2
32
+ JSON_PROMPT_MODE = 3
33
 
34
  @staticmethod
35
  def from_string(s: str):
 
39
  return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
40
  elif normalized == "prepend_first_segment":
41
  return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
42
+ elif normalized == "json_prompt_mode":
43
+ return VadInitialPromptMode.JSON_PROMPT_MODE
44
  else:
45
  raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
46
 
src/prompts/abstractPromptStrategy.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+
4
+ class AbstractPromptStrategy:
5
+ """
6
+ Represents a strategy for generating prompts for a given audio segment.
7
+
8
+ Note that the strategy must be picklable, as it will be serialized and sent to the workers.
9
+ """
10
+
11
+ @abc.abstractmethod
12
+ def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
13
+ """
14
+ Retrieves the prompt for a given segment.
15
+
16
+ Parameters
17
+ ----------
18
+ segment_index: int
19
+ The index of the segment.
20
+ whisper_prompt: str
21
+ The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
22
+ detected_language: str
23
+ The language detected for the segment.
24
+ """
25
+ pass
26
+
27
+ @abc.abstractmethod
28
+ def on_segment_finished(self, segment_index: int, whisper_prompt: str, detected_language: str, result: dict):
29
+ """
30
+ Called when a segment has finished processing.
31
+
32
+ Parameters
33
+ ----------
34
+ segment_index: int
35
+ The index of the segment.
36
+ whisper_prompt: str
37
+ The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
38
+ detected_language: str
39
+ The language detected for the segment.
40
+ result: dict
41
+ The result of the segment. It has the following format:
42
+ {
43
+ "text": str,
44
+ "segments": [
45
+ {
46
+ "text": str,
47
+ "start": float,
48
+ "end": float,
49
+ "words": [words],
50
+ }
51
+ ],
52
+ "language": str,
53
+ }
54
+ """
55
+ pass
56
+
57
+ def _concat_prompt(self, prompt1, prompt2):
58
+ """
59
+ Concatenates two prompts.
60
+
61
+ Parameters
62
+ ----------
63
+ prompt1: str
64
+ The first prompt.
65
+ prompt2: str
66
+ The second prompt.
67
+ """
68
+ if (prompt1 is None):
69
+ return prompt2
70
+ elif (prompt2 is None):
71
+ return prompt1
72
+ else:
73
+ return prompt1 + " " + prompt2
src/prompts/jsonPromptStrategy.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
3
+
4
+
5
+ class JsonPromptSegment():
6
+ def __init__(self, segment_index: int, prompt: str, format_prompt: bool = False):
7
+ self.prompt = prompt
8
+ self.segment_index = segment_index
9
+ self.format_prompt = format_prompt
10
+
11
+ class JsonPromptStrategy(AbstractPromptStrategy):
12
+ def __init__(self, initial_json_prompt: str):
13
+ """
14
+ Parameters
15
+ ----------
16
+ initial_json_prompt: str
17
+ The initial prompts for each segment in JSON form.
18
+
19
+ Format:
20
+ [
21
+ {"segment_index": 0, "prompt": "Hello, how are you?"},
22
+ {"segment_index": 1, "prompt": "I'm doing well, how are you?"},
23
+ {"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
24
+ ]
25
+
26
+ """
27
+ parsed_json = json.loads(initial_json_prompt)
28
+ self.segment_lookup = dict[str, JsonPromptSegment]()
29
+
30
+ for prompt_entry in parsed_json:
31
+ segment_index = prompt_entry["segment_index"]
32
+ prompt = prompt_entry["prompt"]
33
+ format_prompt = prompt_entry.get("format_prompt", False)
34
+ self.segment_lookup[str(segment_index)] = JsonPromptSegment(segment_index, prompt, format_prompt)
35
+
36
+ def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
37
+ # Lookup prompt
38
+ prompt = self.segment_lookup.get(str(segment_index), None)
39
+
40
+ if (prompt is None):
41
+ # No prompt found, return whisper prompt
42
+ print(f"Could not find prompt for segment {segment_index}, returning whisper prompt")
43
+ return whisper_prompt
44
+
45
+ if (prompt.format_prompt):
46
+ return prompt.prompt.format(whisper_prompt)
47
+ else:
48
+ return self._concat_prompt(prompt.prompt, whisper_prompt)
src/prompts/prependPromptStrategy.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.config import VadInitialPromptMode
2
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
3
+
4
+ class PrependPromptStrategy(AbstractPromptStrategy):
5
+ """
6
+ A simple prompt strategy that prepends a single prompt to all segments of audio, or prepends the prompt to the first segment of audio.
7
+ """
8
+ def __init__(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode):
9
+ """
10
+ Parameters
11
+ ----------
12
+ initial_prompt: str
13
+ The initial prompt to use for the transcription.
14
+ initial_prompt_mode: VadInitialPromptMode
15
+ The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
16
+ If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
17
+ """
18
+ self.initial_prompt = initial_prompt
19
+ self.initial_prompt_mode = initial_prompt_mode
20
+
21
+ # This is a simple prompt strategy, so we only support these two modes
22
+ if initial_prompt_mode not in [VadInitialPromptMode.PREPEND_ALL_SEGMENTS, VadInitialPromptMode.PREPREND_FIRST_SEGMENT]:
23
+ raise ValueError(f"Unsupported initial prompt mode {initial_prompt_mode}")
24
+
25
+ def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
26
+ if (self.initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
27
+ return self._concat_prompt(self.initial_prompt, whisper_prompt)
28
+ elif (self.initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
29
+ return self._concat_prompt(self.initial_prompt, whisper_prompt) if segment_index == 0 else whisper_prompt
30
+ else:
31
+ raise ValueError(f"Unknown initial prompt mode {self.initial_prompt_mode}")
src/whisper/abstractWhisperContainer.py CHANGED
@@ -1,11 +1,16 @@
1
  import abc
2
  from typing import List
 
3
  from src.config import ModelConfig, VadInitialPromptMode
4
 
5
  from src.hooks.progressListener import ProgressListener
6
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
7
 
8
  class AbstractWhisperCallback:
 
 
 
9
  @abc.abstractmethod
10
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
11
  """
@@ -24,23 +29,6 @@ class AbstractWhisperCallback:
24
  """
25
  raise NotImplementedError()
26
 
27
- def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode,
28
- prompt: str, segment_index: int):
29
- if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
30
- return self._concat_prompt(initial_prompt, prompt)
31
- elif (initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
32
- return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
33
- else:
34
- raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")
35
-
36
- def _concat_prompt(self, prompt1, prompt2):
37
- if (prompt1 is None):
38
- return prompt2
39
- elif (prompt2 is None):
40
- return prompt1
41
- else:
42
- return prompt1 + " " + prompt2
43
-
44
  class AbstractWhisperContainer:
45
  def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
46
  download_root: str = None,
@@ -75,8 +63,8 @@ class AbstractWhisperContainer:
75
  pass
76
 
77
  @abc.abstractmethod
78
- def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
79
- initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
80
  **decodeOptions: dict) -> AbstractWhisperCallback:
81
  """
82
  Create a WhisperCallback object that can be used to transcript audio files.
@@ -87,11 +75,8 @@ class AbstractWhisperContainer:
87
  The target language of the transcription. If not specified, the language will be inferred from the audio content.
88
  task: str
89
  The task - either translate or transcribe.
90
- initial_prompt: str
91
- The initial prompt to use for the transcription.
92
- initial_prompt_mode: VadInitialPromptMode
93
- The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
94
- If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
95
  decodeOptions: dict
96
  Additional options to pass to the decoder. Must be pickleable.
97
 
 
1
  import abc
2
  from typing import List
3
+
4
  from src.config import ModelConfig, VadInitialPromptMode
5
 
6
  from src.hooks.progressListener import ProgressListener
7
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
8
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
9
 
10
  class AbstractWhisperCallback:
11
+ def __init__(self):
12
+ self.__prompt_mode_gpt = None
13
+
14
  @abc.abstractmethod
15
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
16
  """
 
29
  """
30
  raise NotImplementedError()
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class AbstractWhisperContainer:
33
  def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
34
  download_root: str = None,
 
63
  pass
64
 
65
  @abc.abstractmethod
66
+ def create_callback(self, language: str = None, task: str = None,
67
+ prompt_strategy: AbstractPromptStrategy = None,
68
  **decodeOptions: dict) -> AbstractWhisperCallback:
69
  """
70
  Create a WhisperCallback object that can be used to transcript audio files.
 
75
  The target language of the transcription. If not specified, the language will be inferred from the audio content.
76
  task: str
77
  The task - either translate or transcribe.
78
+ prompt_strategy: AbstractPromptStrategy
79
+ The prompt strategy to use for the transcription.
 
 
 
80
  decodeOptions: dict
81
  Additional options to pass to the decoder. Must be pickleable.
82
 
src/whisper/fasterWhisperContainer.py CHANGED
@@ -6,6 +6,7 @@ from src.config import ModelConfig, VadInitialPromptMode
6
  from src.hooks.progressListener import ProgressListener
7
  from src.languages import get_language_from_name
8
  from src.modelCache import ModelCache
 
9
  from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
10
  from src.utils import format_timestamp
11
 
@@ -56,8 +57,8 @@ class FasterWhisperContainer(AbstractWhisperContainer):
56
  model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
57
  return model
58
 
59
- def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
60
- initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
61
  **decodeOptions: dict) -> AbstractWhisperCallback:
62
  """
63
  Create a WhisperCallback object that can be used to transcript audio files.
@@ -68,11 +69,8 @@ class FasterWhisperContainer(AbstractWhisperContainer):
68
  The target language of the transcription. If not specified, the language will be inferred from the audio content.
69
  task: str
70
  The task - either translate or transcribe.
71
- initial_prompt: str
72
- The initial prompt to use for the transcription.
73
- initial_prompt_mode: VadInitialPromptMode
74
- The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
75
- If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
76
  decodeOptions: dict
77
  Additional options to pass to the decoder. Must be pickleable.
78
 
@@ -80,17 +78,16 @@ class FasterWhisperContainer(AbstractWhisperContainer):
80
  -------
81
  A WhisperCallback object.
82
  """
83
- return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode, **decodeOptions)
84
 
85
  class FasterWhisperCallback(AbstractWhisperCallback):
86
  def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
87
- initial_prompt: str = None, initial_prompt_mode: VadInitialPromptMode=VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
88
  **decodeOptions: dict):
89
  self.model_container = model_container
90
  self.language = language
91
  self.task = task
92
- self.initial_prompt = initial_prompt
93
- self.initial_prompt_mode = initial_prompt_mode
94
  self.decodeOptions = decodeOptions
95
 
96
  self._printed_warning = False
@@ -138,7 +135,8 @@ class FasterWhisperCallback(AbstractWhisperCallback):
138
  # See if supress_tokens is a string - if so, convert it to a list of ints
139
  decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
140
 
141
- initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
 
142
 
143
  segments_generator, info = model.transcribe(audio, \
144
  language=language_code if language_code else detected_language, task=self.task, \
@@ -184,6 +182,10 @@ class FasterWhisperCallback(AbstractWhisperCallback):
184
  "duration": info.duration if info else None
185
  }
186
 
 
 
 
 
187
  if progress_listener is not None:
188
  progress_listener.on_finished()
189
  return result
 
6
  from src.hooks.progressListener import ProgressListener
7
  from src.languages import get_language_from_name
8
  from src.modelCache import ModelCache
9
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
10
  from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
11
  from src.utils import format_timestamp
12
 
 
57
  model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
58
  return model
59
 
60
+ def create_callback(self, language: str = None, task: str = None,
61
+ prompt_strategy: AbstractPromptStrategy = None,
62
  **decodeOptions: dict) -> AbstractWhisperCallback:
63
  """
64
  Create a WhisperCallback object that can be used to transcript audio files.
 
69
  The target language of the transcription. If not specified, the language will be inferred from the audio content.
70
  task: str
71
  The task - either translate or transcribe.
72
+ prompt_strategy: AbstractPromptStrategy
73
+ The prompt strategy to use. If not specified, the prompt from Whisper will be used.
 
 
 
74
  decodeOptions: dict
75
  Additional options to pass to the decoder. Must be pickleable.
76
 
 
78
  -------
79
  A WhisperCallback object.
80
  """
81
+ return FasterWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
82
 
83
  class FasterWhisperCallback(AbstractWhisperCallback):
84
  def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
85
+ prompt_strategy: AbstractPromptStrategy = None,
86
  **decodeOptions: dict):
87
  self.model_container = model_container
88
  self.language = language
89
  self.task = task
90
+ self.prompt_strategy = prompt_strategy
 
91
  self.decodeOptions = decodeOptions
92
 
93
  self._printed_warning = False
 
135
  # See if supress_tokens is a string - if so, convert it to a list of ints
136
  decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
137
 
138
+ initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
139
+ if self.prompt_strategy else prompt
140
 
141
  segments_generator, info = model.transcribe(audio, \
142
  language=language_code if language_code else detected_language, task=self.task, \
 
182
  "duration": info.duration if info else None
183
  }
184
 
185
+ # If we have a prompt strategy, we need to increment the current prompt
186
+ if self.prompt_strategy:
187
+ self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
188
+
189
  if progress_listener is not None:
190
  progress_listener.on_finished()
191
  return result
src/whisper/whisperContainer.py CHANGED
@@ -15,6 +15,7 @@ from src.config import ModelConfig, VadInitialPromptMode
15
  from src.hooks.whisperProgressHook import create_progress_listener_handle
16
 
17
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
18
  from src.utils import download_file
19
  from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
20
 
@@ -69,8 +70,8 @@ class WhisperContainer(AbstractWhisperContainer):
69
 
70
  return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
71
 
72
- def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
73
- initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
74
  **decodeOptions: dict) -> AbstractWhisperCallback:
75
  """
76
  Create a WhisperCallback object that can be used to transcript audio files.
@@ -81,11 +82,8 @@ class WhisperContainer(AbstractWhisperContainer):
81
  The target language of the transcription. If not specified, the language will be inferred from the audio content.
82
  task: str
83
  The task - either translate or transcribe.
84
- initial_prompt: str
85
- The initial prompt to use for the transcription.
86
- initial_prompt_mode: VadInitialPromptMode
87
- The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
88
- If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
89
  decodeOptions: dict
90
  Additional options to pass to the decoder. Must be pickleable.
91
 
@@ -93,7 +91,7 @@ class WhisperContainer(AbstractWhisperContainer):
93
  -------
94
  A WhisperCallback object.
95
  """
96
- return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode, **decodeOptions)
97
 
98
  def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
99
  from src.conversion.hf_converter import convert_hf_whisper
@@ -162,13 +160,14 @@ class WhisperContainer(AbstractWhisperContainer):
162
  return model_config.path
163
 
164
  class WhisperCallback(AbstractWhisperCallback):
165
- def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None,
166
- initial_prompt_mode: VadInitialPromptMode=VadInitialPromptMode.PREPREND_FIRST_SEGMENT, **decodeOptions: dict):
 
167
  self.model_container = model_container
168
  self.language = language
169
  self.task = task
170
- self.initial_prompt = initial_prompt
171
- self.initial_prompt_mode = initial_prompt_mode
172
  self.decodeOptions = decodeOptions
173
 
174
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
@@ -201,11 +200,17 @@ class WhisperCallback(AbstractWhisperCallback):
201
  if self.model_container.compute_type in ["fp16", "float16"]:
202
  decodeOptions["fp16"] = True
203
 
204
- initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
 
205
 
206
  result = model.transcribe(audio, \
207
  language=self.language if self.language else detected_language, task=self.task, \
208
  initial_prompt=initial_prompt, \
209
  **decodeOptions
210
  )
 
 
 
 
 
211
  return result
 
15
  from src.hooks.whisperProgressHook import create_progress_listener_handle
16
 
17
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
18
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
19
  from src.utils import download_file
20
  from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
21
 
 
70
 
71
  return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
72
 
73
+ def create_callback(self, language: str = None, task: str = None,
74
+ prompt_strategy: AbstractPromptStrategy = None,
75
  **decodeOptions: dict) -> AbstractWhisperCallback:
76
  """
77
  Create a WhisperCallback object that can be used to transcript audio files.
 
82
  The target language of the transcription. If not specified, the language will be inferred from the audio content.
83
  task: str
84
  The task - either translate or transcribe.
85
+ prompt_strategy: AbstractPromptStrategy
86
+ The prompt strategy to use. If not specified, the prompt from Whisper will be used.
 
 
 
87
  decodeOptions: dict
88
  Additional options to pass to the decoder. Must be pickleable.
89
 
 
91
  -------
92
  A WhisperCallback object.
93
  """
94
+ return WhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
95
 
96
  def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
97
  from src.conversion.hf_converter import convert_hf_whisper
 
160
  return model_config.path
161
 
162
  class WhisperCallback(AbstractWhisperCallback):
163
+ def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None,
164
+ prompt_strategy: AbstractPromptStrategy = None,
165
+ **decodeOptions: dict):
166
  self.model_container = model_container
167
  self.language = language
168
  self.task = task
169
+ self.prompt_strategy = prompt_strategy
170
+
171
  self.decodeOptions = decodeOptions
172
 
173
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
 
200
  if self.model_container.compute_type in ["fp16", "float16"]:
201
  decodeOptions["fp16"] = True
202
 
203
+ initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
204
+ if self.prompt_strategy else prompt
205
 
206
  result = model.transcribe(audio, \
207
  language=self.language if self.language else detected_language, task=self.task, \
208
  initial_prompt=initial_prompt, \
209
  **decodeOptions
210
  )
211
+
212
+ # If we have a prompt strategy, we need to increment the current prompt
213
+ if self.prompt_strategy:
214
+ self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
215
+
216
  return result