daihui.zhang committed on
Commit
dfb349e
·
1 Parent(s): 7e7b241

update whisper config

Browse files
config.py CHANGED
@@ -1,5 +1,5 @@
1
  import pathlib
2
-
3
  import logging
4
 
5
  logging.basicConfig(
@@ -18,6 +18,12 @@ ASSERT_DIR = BASE_DIR / "assets"
18
  SENTENCE_END_MARKERS = ['.', '!', '?', '。', '!', '?', ';', ';', ':', ':']
19
  PAUSE_END_MARKERS = [',', ',', '、']
20
 
 
 
 
 
 
 
21
  # whisper推理参数
22
  WHISPER_PROMPT_ZH = "以下是简体中文普通话的句子。"
23
  MAX_LENTH_ZH = 4
 
1
  import pathlib
2
+ import re
3
  import logging
4
 
5
  logging.basicConfig(
 
18
  SENTENCE_END_MARKERS = ['.', '!', '?', '。', '!', '?', ';', ';', ':', ':']
19
  PAUSE_END_MARKERS = [',', ',', '、']
20
 
21
+ sentence_end_chars = ''.join([re.escape(char) for char in SENTENCE_END_MARKERS])
22
+ SENTENCE_END_PATTERN = re.compile(f'[{sentence_end_chars}]')
23
+
24
+ # Method 2: Alternative approach with a character class
25
+ pattern_string = '[' + ''.join([re.escape(char) for char in PAUSE_END_MARKERS]) + ']'
26
+ PAUSEE_END_PATTERN = re.compile(pattern_string)
27
  # whisper推理参数
28
  WHISPER_PROMPT_ZH = "以下是简体中文普通话的句子。"
29
  MAX_LENTH_ZH = 4
transcribe/helpers/whisper.py CHANGED
@@ -17,7 +17,10 @@ class WhisperCPP:
17
  print_realtime=False,
18
  print_progress=False,
19
  print_timestamps=False,
20
- translate=False
 
 
 
21
  )
22
  if warmup:
23
  self.warmup()
 
17
  print_realtime=False,
18
  print_progress=False,
19
  print_timestamps=False,
20
+ translate=False,
21
+ # beam_search=1,
22
+ temperature=0.,
23
+ no_context=True
24
  )
25
  if warmup:
26
  self.warmup()
transcribe/strategy.py CHANGED
@@ -98,7 +98,7 @@ def segement_merge(segments):
98
 
99
  for seg in segments:
100
  temp_seq.append(seg)
101
- if any([mk in seg.text for mk in config.SENTENCE_END_MARKERS + config.PAUSE_END_MARKERS]):
102
  sequences.append(temp_seq.copy())
103
  temp_seq = []
104
  if temp_seq:
@@ -123,7 +123,8 @@ def segments_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
123
  if seg.text and seg.text[-1] in markers:
124
  seg_index = int(seg.t1 / 100 * sample_rate)
125
  # rest_buffer_duration = (len(audio_buffer) - seg_index) / sample_rate
126
- # is_end = any(i in seg.text for i in config.SENTENCE_END_MARKERS)
 
127
  right_watch_sequences = segments[min(idx+1, len(segments)):]
128
  # if rest_buffer_duration >= 1.5:
129
  left_watch_idx = seg_index
 
98
 
99
  for seg in segments:
100
  temp_seq.append(seg)
101
+ if any([mk in seg.text for mk in config.SENTENCE_END_MARKERS]):
102
  sequences.append(temp_seq.copy())
103
  temp_seq = []
104
  if temp_seq:
 
123
  if seg.text and seg.text[-1] in markers:
124
  seg_index = int(seg.t1 / 100 * sample_rate)
125
  # rest_buffer_duration = (len(audio_buffer) - seg_index) / sample_rate
126
+ is_end = config.SENTENCE_END_PATTERN.search(seg.text)
127
+
128
  right_watch_sequences = segments[min(idx+1, len(segments)):]
129
  # if rest_buffer_duration >= 1.5:
130
  left_watch_idx = seg_index