Fedir Zadniprovskyi committed
Commit ec4d8ae · 1 parent: 6b7295b
refactor: update response model names and module name
Files changed:
- src/faster_whisper_server/{server_models.py → api_models.py} +102 -16
- src/faster_whisper_server/asr.py +11 -5
- src/faster_whisper_server/routers/list_models.py +9 -9
- src/faster_whisper_server/routers/stt.py +25 -20
- src/faster_whisper_server/{core.py → text_utils.py} +13 -188
- src/faster_whisper_server/text_utils_test.py +111 -0
- src/faster_whisper_server/transcriber.py +3 -2
- tests/api_timestamp_granularities_test.py +1 -1
- tests/openai_timestamp_granularities_test.py +1 -1
- tests/sse_test.py +5 -5
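In short, the commit renames `server_models.py` to `api_models.py`, renames `core.py` to `text_utils.py`, and renames the response models to match the OpenAI OpenAPI spec. A hedged before/after sketch of the import changes downstream code would make (names taken from the diffs below):

# Before this commit (old module and model names):
# from faster_whisper_server.server_models import TranscriptionJsonResponse, TranscriptionVerboseJsonResponse
# from faster_whisper_server.core import Segment, Word, segments_to_text

# After this commit:
from faster_whisper_server.api_models import (
    CreateTranscriptionResponseJson,
    CreateTranscriptionResponseVerboseJson,
    TranscriptionSegment,
    TranscriptionWord,
)
from faster_whisper_server.text_utils import segments_to_text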
src/faster_whisper_server/{server_models.py → api_models.py}
RENAMED
@@ -4,36 +4,117 @@ from typing import TYPE_CHECKING, Literal
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from faster_whisper_server.core import …
+from faster_whisper_server.text_utils import Transcription, canonicalize_word, segments_to_text
 
 if TYPE_CHECKING:
-    from …
+    from collections.abc import Iterable
+
+    import faster_whisper.transcribe
+
+
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10909
+class TranscriptionWord(BaseModel):
+    start: float
+    end: float
+    word: str
+    probability: float
+
+    @classmethod
+    def from_segments(cls, segments: Iterable[TranscriptionSegment]) -> list[TranscriptionWord]:
+        words: list[TranscriptionWord] = []
+        for segment in segments:
+            # NOTE: a temporary "fix" for https://github.com/fedirz/faster-whisper-server/issues/58.
+            # TODO: properly address the issue
+            assert (
+                segment.words is not None
+            ), "Segment must have words. If you are using an API ensure `timestamp_granularities[]=word` is set"
+            words.extend(segment.words)
+        return words
+
+    def offset(self, seconds: float) -> None:
+        self.start += seconds
+        self.end += seconds
+
+    @classmethod
+    def common_prefix(cls, a: list[TranscriptionWord], b: list[TranscriptionWord]) -> list[TranscriptionWord]:
+        i = 0
+        while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word):
+            i += 1
+        return a[:i]
+
+
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10938
+class TranscriptionSegment(BaseModel):
+    id: int
+    seek: int
+    start: float
+    end: float
+    text: str
+    tokens: list[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
+    no_speech_prob: float
+    words: list[TranscriptionWord] | None
+
+    @classmethod
+    def from_faster_whisper_segments(
+        cls, segments: Iterable[faster_whisper.transcribe.Segment]
+    ) -> Iterable[TranscriptionSegment]:
+        for segment in segments:
+            yield cls(
+                id=segment.id,
+                seek=segment.seek,
+                start=segment.start,
+                end=segment.end,
+                text=segment.text,
+                tokens=segment.tokens,
+                temperature=segment.temperature,
+                avg_logprob=segment.avg_logprob,
+                compression_ratio=segment.compression_ratio,
+                no_speech_prob=segment.no_speech_prob,
+                words=[
+                    TranscriptionWord(
+                        start=word.start,
+                        end=word.end,
+                        word=word.word,
+                        probability=word.probability,
+                    )
+                    for word in segment.words
+                ]
+                if segment.words is not None
+                else None,
+            )
 
 
 # https://platform.openai.com/docs/api-reference/audio/json-object
-class TranscriptionJsonResponse(BaseModel):
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10924
+class CreateTranscriptionResponseJson(BaseModel):
     text: str
 
     @classmethod
-    def from_segments(cls, segments: list[Segment]) -> TranscriptionJsonResponse:
+    def from_segments(cls, segments: list[TranscriptionSegment]) -> CreateTranscriptionResponseJson:
         return cls(text=segments_to_text(segments))
 
     @classmethod
-    def from_transcription(cls, transcription: Transcription) -> TranscriptionJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> CreateTranscriptionResponseJson:
         return cls(text=transcription.text)
 
 
 # https://platform.openai.com/docs/api-reference/audio/verbose-json-object
-class TranscriptionVerboseJsonResponse(BaseModel):
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L11007
+class CreateTranscriptionResponseVerboseJson(BaseModel):
     task: str = "transcribe"
     language: str
     duration: float
     text: str
-    words: list[Word]
-    segments: list[Segment]
+    words: list[TranscriptionWord] | None
+    segments: list[TranscriptionSegment]
 
     @classmethod
-    def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -> TranscriptionVerboseJsonResponse:
+    def from_segment(
+        cls, segment: TranscriptionSegment, transcription_info: faster_whisper.transcribe.TranscriptionInfo
+    ) -> CreateTranscriptionResponseVerboseJson:
         return cls(
             language=transcription_info.language,
             duration=segment.end - segment.start,
@@ -44,18 +125,20 @@ class TranscriptionVerboseJsonResponse(BaseModel):
 
     @classmethod
     def from_segments(
-        cls, segments: list[Segment], transcription_info: TranscriptionInfo
-    ) -> TranscriptionVerboseJsonResponse:
+        cls, segments: list[TranscriptionSegment], transcription_info: faster_whisper.transcribe.TranscriptionInfo
+    ) -> CreateTranscriptionResponseVerboseJson:
         return cls(
             language=transcription_info.language,
            duration=transcription_info.duration,
             text=segments_to_text(segments),
             segments=segments,
-            words=Word.from_segments(segments),
+            words=TranscriptionWord.from_segments(segments)
+            if transcription_info.transcription_options.word_timestamps
+            else None,
         )
 
     @classmethod
-    def from_transcription(cls, transcription: Transcription) -> TranscriptionVerboseJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> CreateTranscriptionResponseVerboseJson:
         return cls(
             language="english",  # FIX: hardcoded
             duration=transcription.duration,
@@ -65,12 +148,14 @@ class TranscriptionVerboseJsonResponse(BaseModel):
         )
 
 
-class ModelListResponse(BaseModel):
-    data: list[ModelObject]
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8730
+class ListModelsResponse(BaseModel):
+    data: list[Model]
     object: Literal["list"] = "list"
 
 
-class ModelObject(BaseModel):
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L11146
+class Model(BaseModel):
     id: str
     """The model identifier, which can be referenced in the API endpoints."""
     created: int
@@ -109,6 +194,7 @@ class ModelObject(BaseModel):
     )
 
 
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10909
 TimestampGranularities = list[Literal["segment", "word"]]
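To see how the renamed models compose, here is a minimal hedged sketch, assuming `model` is a loaded `faster_whisper.WhisperModel` and a local `audio.wav` exists (neither comes from the diff):

# Sketch: turn faster-whisper output into the OpenAI-style verbose response.
from faster_whisper import WhisperModel

from faster_whisper_server.api_models import (
    CreateTranscriptionResponseVerboseJson,
    TranscriptionSegment,
)

model = WhisperModel("Systran/faster-distil-whisper-large-v3")
segments, info = model.transcribe("audio.wav", word_timestamps=True)
api_segments = list(TranscriptionSegment.from_faster_whisper_segments(segments))
response = CreateTranscriptionResponseVerboseJson.from_segments(api_segments, info)
print(response.model_dump_json())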
src/faster_whisper_server/asr.py
CHANGED
@@ -1,11 +1,17 @@
+from __future__ import annotations
+
 import asyncio
 import logging
 import time
+from typing import TYPE_CHECKING
 
-from …
+from faster_whisper_server.api_models import TranscriptionSegment, TranscriptionWord
+from faster_whisper_server.text_utils import Transcription
 
-from faster_whisper_server.audio import Audio
-from faster_whisper_server.core import Segment, Transcription, Word
+if TYPE_CHECKING:
+    from faster_whisper import transcribe
+
+    from faster_whisper_server.audio import Audio
 
 logger = logging.getLogger(__name__)
@@ -31,8 +37,8 @@ class FasterWhisperASR:
             word_timestamps=True,
             **self.transcribe_opts,
         )
-        segments = Segment.from_faster_whisper_segments(segments)
-        words = Word.from_segments(segments)
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+        words = TranscriptionWord.from_segments(segments)
         for word in words:
             word.offset(audio.start)
         transcription = Transcription(words)
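`TranscriptionWord.offset` (used just above) shifts a word's timestamps from chunk-relative to stream-relative time. A tiny sketch of its effect, with illustrative values:

# Sketch: offsetting a word by the audio chunk's start time.
from faster_whisper_server.api_models import TranscriptionWord

word = TranscriptionWord(start=0.5, end=0.9, word="hello", probability=0.99)
word.offset(2.0)  # the chunk began 2 s into the stream
assert (word.start, word.end) == (2.5, 2.9)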
src/faster_whisper_server/routers/list_models.py
CHANGED
@@ -9,9 +9,9 @@ from fastapi import (
 )
 import huggingface_hub
 
-from faster_whisper_server.server_models import (
-    ModelListResponse,
-    ModelObject,
+from faster_whisper_server.api_models import (
+    ListModelsResponse,
+    Model,
 )
 
 if TYPE_CHECKING:
@@ -21,11 +21,11 @@ router = APIRouter()
 
 
 @router.get("/v1/models")
-def get_models() -> ModelListResponse:
+def get_models() -> ListModelsResponse:
     models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
     models = list(models)
     models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore  # noqa: PGH003
-    transformed_models: list[ModelObject] = []
+    transformed_models: list[Model] = []
     for model in models:
         assert model.created_at is not None
         assert model.card_data is not None
@@ -36,7 +36,7 @@ def get_models() -> ModelListResponse:
             language = [model.card_data.language]
         else:
             language = model.card_data.language
-        transformed_model = ModelObject(
+        transformed_model = Model(
             id=model.id,
             created=int(model.created_at.timestamp()),
             object_="model",
@@ -44,14 +44,14 @@ def get_models() -> ModelListResponse:
             language=language,
         )
         transformed_models.append(transformed_model)
-    return ModelListResponse(data=transformed_models)
+    return ListModelsResponse(data=transformed_models)
 
 
 @router.get("/v1/models/{model_name:path}")
 # NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
 def get_model(
     model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
-) -> ModelObject:
+) -> Model:
     models = huggingface_hub.list_models(
         model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
     )
@@ -78,7 +78,7 @@ def get_model(
         language = [exact_match.card_data.language]
     else:
         language = exact_match.card_data.language
-    return ModelObject(
+    return Model(
         id=exact_match.id,
         created=int(exact_match.created_at.timestamp()),
         object_="model",
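A hedged sketch of exercising the renamed models route; the `TestClient` wiring and the app-factory import are assumptions, while the route and model names come from the diff above:

# Sketch: fetch /v1/models and validate the body with the new response model.
from fastapi.testclient import TestClient

from faster_whisper_server.api_models import ListModelsResponse
from faster_whisper_server.main import create_app  # app factory name assumed

client = TestClient(create_app())
models = ListModelsResponse(**client.get("/v1/models").json())
assert models.object == "list"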
src/faster_whisper_server/routers/stt.py
CHANGED
@@ -20,6 +20,14 @@ from fastapi.websockets import WebSocketState
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 from pydantic import AfterValidator
 
+from faster_whisper_server.api_models import (
+    DEFAULT_TIMESTAMP_GRANULARITIES,
+    TIMESTAMP_GRANULARITIES_COMBINATIONS,
+    CreateTranscriptionResponseJson,
+    CreateTranscriptionResponseVerboseJson,
+    TimestampGranularities,
+    TranscriptionSegment,
+)
 from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
@@ -28,15 +36,8 @@ from faster_whisper_server.config import (
     ResponseFormat,
     Task,
 )
-from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
-from faster_whisper_server.server_models import (
-    DEFAULT_TIMESTAMP_GRANULARITIES,
-    TIMESTAMP_GRANULARITIES_COMBINATIONS,
-    TimestampGranularities,
-    TranscriptionJsonResponse,
-    TranscriptionVerboseJsonResponse,
-)
+from faster_whisper_server.text_utils import segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.transcriber import audio_transcriber
 
 if TYPE_CHECKING:
@@ -51,7 +52,7 @@ router = APIRouter()
 
 
 def segments_to_response(
-    segments: Iterable[Segment],
+    segments: Iterable[TranscriptionSegment],
     transcription_info: TranscriptionInfo,
     response_format: ResponseFormat,
 ) -> Response:
@@ -60,12 +61,12 @@ def segments_to_response(
         return Response(segments_to_text(segments), media_type="text/plain")
     elif response_format == ResponseFormat.JSON:
         return Response(
-            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
+            CreateTranscriptionResponseJson.from_segments(segments).model_dump_json(),
             media_type="application/json",
         )
     elif response_format == ResponseFormat.VERBOSE_JSON:
         return Response(
-            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
+            CreateTranscriptionResponseVerboseJson.from_segments(segments, transcription_info).model_dump_json(),
             media_type="application/json",
         )
     elif response_format == ResponseFormat.VTT:
@@ -83,7 +84,7 @@ def format_as_sse(data: str) -> str:
 
 
 def segments_to_streaming_response(
-    segments: Iterable[Segment],
+    segments: Iterable[TranscriptionSegment],
     transcription_info: TranscriptionInfo,
     response_format: ResponseFormat,
 ) -> StreamingResponse:
@@ -92,9 +93,11 @@ def segments_to_streaming_response(
         if response_format == ResponseFormat.TEXT:
             data = segment.text
         elif response_format == ResponseFormat.JSON:
-            data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
+            data = CreateTranscriptionResponseJson.from_segments([segment]).model_dump_json()
         elif response_format == ResponseFormat.VERBOSE_JSON:
-            data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
+            data = CreateTranscriptionResponseVerboseJson.from_segment(
+                segment, transcription_info
+            ).model_dump_json()
         elif response_format == ResponseFormat.VTT:
             data = segments_to_vtt(segment, i)
         elif response_format == ResponseFormat.SRT:
@@ -121,7 +124,7 @@ ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
 
 @router.post(
     "/v1/audio/translations",
-    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
+    response_model=str | CreateTranscriptionResponseJson | CreateTranscriptionResponseVerboseJson,
 )
 def translate_file(
     config: ConfigDependency,
@@ -145,7 +148,7 @@ def translate_file(
         temperature=temperature,
         vad_filter=True,
     )
-    segments = Segment.from_faster_whisper_segments(segments)
+    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
 
     if stream:
         return segments_to_streaming_response(segments, transcription_info, response_format)
@@ -169,7 +172,7 @@ async def get_timestamp_granularities(request: Request) -> TimestampGranularities:
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @router.post(
     "/v1/audio/transcriptions",
-    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
+    response_model=str | CreateTranscriptionResponseJson | CreateTranscriptionResponseVerboseJson,
 )
 def transcribe_file(
     config: ConfigDependency,
@@ -211,7 +214,7 @@ def transcribe_file(
         vad_filter=True,
         hotwords=hotwords,
     )
-    segments = Segment.from_faster_whisper_segments(segments)
+    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
 
     if stream:
         return segments_to_streaming_response(segments, transcription_info, response_format)
@@ -286,9 +289,11 @@ async def transcribe_stream(
         if response_format == ResponseFormat.TEXT:
             await ws.send_text(transcription.text)
         elif response_format == ResponseFormat.JSON:
-            await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
+            await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
         elif response_format == ResponseFormat.VERBOSE_JSON:
-            await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
+            await ws.send_json(
+                CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
+            )
 
     if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
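For orientation, a hedged client-side sketch of the non-streaming JSON path through `/v1/audio/transcriptions`; the server URL and file name are assumptions, not from the diff:

# Sketch: POST an audio file and parse the body with the renamed model.
import httpx

from faster_whisper_server.api_models import CreateTranscriptionResponseJson

with open("audio.wav", "rb") as f:  # audio.wav is an assumption
    res = httpx.post(
        "http://localhost:8000/v1/audio/transcriptions",  # host/port assumed
        files={"file": f},
        data={"response_format": "json"},
    )
response = CreateTranscriptionResponseJson(**res.json())
print(response.text)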
src/faster_whisper_server/{core.py → text_utils.py}
RENAMED
@@ -3,90 +3,17 @@ from __future__ import annotations
 import re
 from typing import TYPE_CHECKING
 
-from pydantic import BaseModel
-
 from faster_whisper_server.dependencies import get_config
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
-    import faster_whisper.transcribe
-
-
-class Word(BaseModel):
-    start: float
-    end: float
-    word: str
-    probability: float
-
-    @classmethod
-    def from_segments(cls, segments: Iterable[Segment]) -> list[Word]:
-        words: list[Word] = []
-        for segment in segments:
-            # NOTE: a temporary "fix" for https://github.com/fedirz/faster-whisper-server/issues/58.
-            # TODO: properly address the issue
-            assert (
-                segment.words is not None
-            ), "Segment must have words. If you are using an API ensure `timestamp_granularities[]=word` is set"
-            words.extend(segment.words)
-        return words
-
-    def offset(self, seconds: float) -> None:
-        self.start += seconds
-        self.end += seconds
-
-    @classmethod
-    def common_prefix(cls, a: list[Word], b: list[Word]) -> list[Word]:
-        i = 0
-        while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word):
-            i += 1
-        return a[:i]
-
-
-class Segment(BaseModel):
-    id: int
-    seek: int
-    start: float
-    end: float
-    text: str
-    tokens: list[int]
-    temperature: float
-    avg_logprob: float
-    compression_ratio: float
-    no_speech_prob: float
-    words: list[Word] | None
-
-    @classmethod
-    def from_faster_whisper_segments(cls, segments: Iterable[faster_whisper.transcribe.Segment]) -> Iterable[Segment]:
-        for segment in segments:
-            yield cls(
-                id=segment.id,
-                seek=segment.seek,
-                start=segment.start,
-                end=segment.end,
-                text=segment.text,
-                tokens=segment.tokens,
-                temperature=segment.temperature,
-                avg_logprob=segment.avg_logprob,
-                compression_ratio=segment.compression_ratio,
-                no_speech_prob=segment.no_speech_prob,
-                words=[
-                    Word(
-                        start=word.start,
-                        end=word.end,
-                        word=word.word,
-                        probability=word.probability,
-                    )
-                    for word in segment.words
-                ]
-                if segment.words is not None
-                else None,
-            )
+    from faster_whisper_server.api_models import TranscriptionSegment, TranscriptionWord
 
 
 class Transcription:
-    def __init__(self, words: list[Word] = []) -> None:
-        self.words: list[Word] = []
+    def __init__(self, words: list[TranscriptionWord] = []) -> None:
+        self.words: list[TranscriptionWord] = []
         self.extend(words)
 
     @property
@@ -108,11 +35,11 @@ class Transcription:
     def after(self, seconds: float) -> Transcription:
         return Transcription(words=[word for word in self.words if word.start > seconds])
 
-    def extend(self, words: list[Word]) -> None:
+    def extend(self, words: list[TranscriptionWord]) -> None:
         self._ensure_no_word_overlap(words)
         self.words.extend(words)
 
-    def _ensure_no_word_overlap(self, words: list[Word]) -> None:
+    def _ensure_no_word_overlap(self, words: list[TranscriptionWord]) -> None:
         config = get_config()  # HACK
         if len(self.words) > 0 and len(words) > 0:
             if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
@@ -130,19 +57,8 @@ def is_eos(text: str) -> bool:
     return any(text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")
 
 
-def test_is_eos() -> None:
-    assert not is_eos("Hello")
-    assert not is_eos("Hello...")
-    assert is_eos("Hello.")
-    assert is_eos("Hello!")
-    assert is_eos("Hello?")
-    assert not is_eos("Hello. Yo")
-    assert not is_eos("Hello. Yo...")
-    assert is_eos("Hello. Yo.")
-
-
-def to_full_sentences(words: list[Word]) -> list[list[Word]]:
-    sentences: list[list[Word]] = [[]]
+def to_full_sentences(words: list[TranscriptionWord]) -> list[list[TranscriptionWord]]:
+    sentences: list[list[TranscriptionWord]] = [[]]
     for word in words:
         sentences[-1].append(word)
         if is_eos(word.word):
@@ -152,28 +68,15 @@ def to_full_sentences(words: list[Word]) -> list[list[Word]]:
     return sentences
 
 
-def tests_to_full_sentences() -> None:
-    def word(text: str) -> Word:
-        return Word(word=text, start=0.0, end=0.0, probability=0.0)
-
-    assert to_full_sentences([]) == []
-    assert to_full_sentences([word(text="Hello")]) == []
-    assert to_full_sentences([word(text="Hello..."), word(" world")]) == []
-    assert to_full_sentences([word(text="Hello..."), word(" world.")]) == [[word("Hello..."), word(" world.")]]
-    assert to_full_sentences([word(text="Hello..."), word(" world."), word(" How")]) == [
-        [word("Hello..."), word(" world.")],
-    ]
-
-
-def word_to_text(words: list[Word]) -> str:
+def word_to_text(words: list[TranscriptionWord]) -> str:
     return "".join(word.word for word in words)
 
 
-def words_to_text_w_ts(words: list[Word]) -> str:
+def words_to_text_w_ts(words: list[TranscriptionWord]) -> str:
     return "".join(f"{word.word}({word.start:.2f}-{word.end:.2f})" for word in words)
 
 
-def segments_to_text(segments: Iterable[Segment]) -> str:
+def segments_to_text(segments: Iterable[TranscriptionSegment]) -> str:
     return "".join(segment.text for segment in segments).strip()
 
 
@@ -185,19 +88,6 @@ def srt_format_timestamp(ts: float) -> str:
     return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
 
 
-def test_srt_format_timestamp() -> None:
-    assert srt_format_timestamp(0.0) == "00:00:00,000"
-    assert srt_format_timestamp(1.0) == "00:00:01,000"
-    assert srt_format_timestamp(1.234) == "00:00:01,234"
-    assert srt_format_timestamp(60.0) == "00:01:00,000"
-    assert srt_format_timestamp(61.0) == "00:01:01,000"
-    assert srt_format_timestamp(61.234) == "00:01:01,234"
-    assert srt_format_timestamp(3600.0) == "01:00:00,000"
-    assert srt_format_timestamp(3601.0) == "01:00:01,000"
-    assert srt_format_timestamp(3601.234) == "01:00:01,234"
-    assert srt_format_timestamp(23423.4234) == "06:30:23,423"
-
-
 def vtt_format_timestamp(ts: float) -> str:
     hours = ts // 3600
     minutes = (ts % 3600) // 60
@@ -206,20 +96,7 @@ def vtt_format_timestamp(ts: float) -> str:
     return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
 
 
-def test_vtt_format_timestamp() -> None:
-    assert vtt_format_timestamp(0.0) == "00:00:00.000"
-    assert vtt_format_timestamp(1.0) == "00:00:01.000"
-    assert vtt_format_timestamp(1.234) == "00:00:01.234"
-    assert vtt_format_timestamp(60.0) == "00:01:00.000"
-    assert vtt_format_timestamp(61.0) == "00:01:01.000"
-    assert vtt_format_timestamp(61.234) == "00:01:01.234"
-    assert vtt_format_timestamp(3600.0) == "01:00:00.000"
-    assert vtt_format_timestamp(3601.0) == "01:00:01.000"
-    assert vtt_format_timestamp(3601.234) == "01:00:01.234"
-    assert vtt_format_timestamp(23423.4234) == "06:30:23.423"
-
-
-def segments_to_vtt(segment: Segment, i: int) -> str:
+def segments_to_vtt(segment: TranscriptionSegment, i: int) -> str:
     start = segment.start if i > 0 else 0.0
     result = f"{vtt_format_timestamp(start)} --> {vtt_format_timestamp(segment.end)}\n{segment.text}\n\n"
 
@@ -229,7 +106,7 @@ def segments_to_vtt(segment: Segment, i: int) -> str:
     return result
 
 
-def segments_to_srt(segment: Segment, i: int) -> str:
+def segments_to_srt(segment: TranscriptionSegment, i: int) -> str:
     return f"{i + 1}\n{srt_format_timestamp(segment.start)} --> {srt_format_timestamp(segment.end)}\n{segment.text}\n\n"
 
 
@@ -240,60 +117,8 @@ def canonicalize_word(text: str) -> str:
     return text.lower().strip().strip(".,?!")
 
 
-def test_canonicalize_word() -> None:
-    assert canonicalize_word("ABC") == "abc"
-    assert canonicalize_word("...ABC?") == "abc"
-    assert canonicalize_word("... AbC ...") == "abc"
-
-
-def common_prefix(a: list[Word], b: list[Word]) -> list[Word]:
+def common_prefix(a: list[TranscriptionWord], b: list[TranscriptionWord]) -> list[TranscriptionWord]:
     i = 0
     while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word):
         i += 1
     return a[:i]
-
-
-def test_common_prefix() -> None:
-    def word(text: str) -> Word:
-        return Word(word=text, start=0.0, end=0.0, probability=0.0)
-
-    a = [word("a"), word("b"), word("c")]
-    b = [word("a"), word("b"), word("c")]
-    assert common_prefix(a, b) == [word("a"), word("b"), word("c")]
-
-    a = [word("a"), word("b"), word("c")]
-    b = [word("a"), word("b"), word("d")]
-    assert common_prefix(a, b) == [word("a"), word("b")]
-
-    a = [word("a"), word("b"), word("c")]
-    b = [word("a")]
-    assert common_prefix(a, b) == [word("a")]
-
-    a = [word("a")]
-    b = [word("a"), word("b"), word("c")]
-    assert common_prefix(a, b) == [word("a")]
-
-    a = [word("a")]
-    b = []
-    assert common_prefix(a, b) == []
-
-    a = []
-    b = [word("a")]
-    assert common_prefix(a, b) == []
-
-    a = [word("a"), word("b"), word("c")]
-    b = [word("b"), word("c")]
-    assert common_prefix(a, b) == []
-
-
-def test_common_prefix_and_canonicalization() -> None:
-    def word(text: str) -> Word:
-        return Word(word=text, start=0.0, end=0.0, probability=0.0)
-
-    a = [word("A...")]
-    b = [word("a?"), word("b"), word("c")]
-    assert common_prefix(a, b) == [word("A...")]
-
-    a = [word("A..."), word("B?"), word("C,")]
-    b = [word("a??"), word(" b"), word(" ,c")]
-    assert common_prefix(a, b) == [word("A..."), word("B?"), word("C,")]
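As a quick orientation to the relocated helpers, a hedged sketch of what `segments_to_srt` produces for one segment (field values illustrative):

from faster_whisper_server.api_models import TranscriptionSegment
from faster_whisper_server.text_utils import segments_to_srt

seg = TranscriptionSegment(
    id=1, seek=0, start=0.0, end=1.5, text="Hello world.",
    tokens=[], temperature=0.0, avg_logprob=0.0,
    compression_ratio=0.0, no_speech_prob=0.0, words=None,
)
print(segments_to_srt(seg, 0), end="")
# 1
# 00:00:00,000 --> 00:00:01,500
# Hello world.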
src/faster_whisper_server/text_utils_test.py
ADDED
@@ -0,0 +1,111 @@
+from faster_whisper_server.api_models import TranscriptionWord
+from faster_whisper_server.text_utils import (
+    canonicalize_word,
+    common_prefix,
+    is_eos,
+    srt_format_timestamp,
+    to_full_sentences,
+    vtt_format_timestamp,
+)
+
+
+def test_is_eos() -> None:
+    assert not is_eos("Hello")
+    assert not is_eos("Hello...")
+    assert is_eos("Hello.")
+    assert is_eos("Hello!")
+    assert is_eos("Hello?")
+    assert not is_eos("Hello. Yo")
+    assert not is_eos("Hello. Yo...")
+    assert is_eos("Hello. Yo.")
+
+
+def tests_to_full_sentences() -> None:
+    def word(text: str) -> TranscriptionWord:
+        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)
+
+    assert to_full_sentences([]) == []
+    assert to_full_sentences([word(text="Hello")]) == []
+    assert to_full_sentences([word(text="Hello..."), word(" world")]) == []
+    assert to_full_sentences([word(text="Hello..."), word(" world.")]) == [[word("Hello..."), word(" world.")]]
+    assert to_full_sentences([word(text="Hello..."), word(" world."), word(" How")]) == [
+        [word("Hello..."), word(" world.")],
+    ]
+
+
+def test_srt_format_timestamp() -> None:
+    assert srt_format_timestamp(0.0) == "00:00:00,000"
+    assert srt_format_timestamp(1.0) == "00:00:01,000"
+    assert srt_format_timestamp(1.234) == "00:00:01,234"
+    assert srt_format_timestamp(60.0) == "00:01:00,000"
+    assert srt_format_timestamp(61.0) == "00:01:01,000"
+    assert srt_format_timestamp(61.234) == "00:01:01,234"
+    assert srt_format_timestamp(3600.0) == "01:00:00,000"
+    assert srt_format_timestamp(3601.0) == "01:00:01,000"
+    assert srt_format_timestamp(3601.234) == "01:00:01,234"
+    assert srt_format_timestamp(23423.4234) == "06:30:23,423"
+
+
+def test_vtt_format_timestamp() -> None:
+    assert vtt_format_timestamp(0.0) == "00:00:00.000"
+    assert vtt_format_timestamp(1.0) == "00:00:01.000"
+    assert vtt_format_timestamp(1.234) == "00:00:01.234"
+    assert vtt_format_timestamp(60.0) == "00:01:00.000"
+    assert vtt_format_timestamp(61.0) == "00:01:01.000"
+    assert vtt_format_timestamp(61.234) == "00:01:01.234"
+    assert vtt_format_timestamp(3600.0) == "01:00:00.000"
+    assert vtt_format_timestamp(3601.0) == "01:00:01.000"
+    assert vtt_format_timestamp(3601.234) == "01:00:01.234"
+    assert vtt_format_timestamp(23423.4234) == "06:30:23.423"
+
+
+def test_canonicalize_word() -> None:
+    assert canonicalize_word("ABC") == "abc"
+    assert canonicalize_word("...ABC?") == "abc"
+    assert canonicalize_word("... AbC ...") == "abc"
+
+
+def test_common_prefix() -> None:
+    def word(text: str) -> TranscriptionWord:
+        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)
+
+    a = [word("a"), word("b"), word("c")]
+    b = [word("a"), word("b"), word("c")]
+    assert common_prefix(a, b) == [word("a"), word("b"), word("c")]
+
+    a = [word("a"), word("b"), word("c")]
+    b = [word("a"), word("b"), word("d")]
+    assert common_prefix(a, b) == [word("a"), word("b")]
+
+    a = [word("a"), word("b"), word("c")]
+    b = [word("a")]
+    assert common_prefix(a, b) == [word("a")]
+
+    a = [word("a")]
+    b = [word("a"), word("b"), word("c")]
+    assert common_prefix(a, b) == [word("a")]
+
+    a = [word("a")]
+    b = []
+    assert common_prefix(a, b) == []
+
+    a = []
+    b = [word("a")]
+    assert common_prefix(a, b) == []
+
+    a = [word("a"), word("b"), word("c")]
+    b = [word("b"), word("c")]
+    assert common_prefix(a, b) == []
+
+
+def test_common_prefix_and_canonicalization() -> None:
+    def word(text: str) -> TranscriptionWord:
+        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)
+
+    a = [word("A...")]
+    b = [word("a?"), word("b"), word("c")]
+    assert common_prefix(a, b) == [word("A...")]
+
+    a = [word("A..."), word("B?"), word("C,")]
+    b = [word("a??"), word(" b"), word(" ,c")]
+    assert common_prefix(a, b) == [word("A..."), word("B?"), word("C,")]
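Moving these assertions out of the library module and into `text_utils_test.py` means test code no longer ships inside `text_utils.py` itself and gets collected by pytest alongside the rest of the suite.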
src/faster_whisper_server/transcriber.py
CHANGED
@@ -4,11 +4,12 @@ import logging
 from typing import TYPE_CHECKING
 
 from faster_whisper_server.audio import Audio, AudioStream
-from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text
+from faster_whisper_server.text_utils import Transcription, common_prefix, to_full_sentences, word_to_text
 
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
 
+    from faster_whisper_server.api_models import TranscriptionWord
     from faster_whisper_server.asr import FasterWhisperASR
 
 logger = logging.getLogger(__name__)
@@ -18,7 +19,7 @@ class LocalAgreement:
     def __init__(self) -> None:
         self.unconfirmed = Transcription()
 
-    def merge(self, confirmed: Transcription, incoming: Transcription) -> list[Word]:
+    def merge(self, confirmed: Transcription, incoming: Transcription) -> list[TranscriptionWord]:
         # https://github.com/ufal/whisper_streaming/blob/main/whisper_online.py#L264
         incoming = incoming.after(confirmed.end - 0.1)
         prefix = common_prefix(incoming.words, self.unconfirmed.words)
|
tests/api_timestamp_granularities_test.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`.""" # noqa: E501
|
2 |
|
3 |
-
from faster_whisper_server.
|
4 |
from openai import AsyncOpenAI
|
5 |
import pytest
|
6 |
|
|
|
1 |
"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`.""" # noqa: E501
|
2 |
|
3 |
+
from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
|
4 |
from openai import AsyncOpenAI
|
5 |
import pytest
|
6 |
|
tests/openai_timestamp_granularities_test.py
CHANGED
@@ -1,6 +1,6 @@
 """OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters."""  # noqa: E501
 
-from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
+from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
 from openai import AsyncOpenAI, BadRequestError
 import pytest
 
tests/sse_test.py
CHANGED
@@ -2,9 +2,9 @@ import json
 import os
 
 from fastapi.testclient import TestClient
-from faster_whisper_server.server_models import (
-    TranscriptionJsonResponse,
-    TranscriptionVerboseJsonResponse,
+from faster_whisper_server.api_models import (
+    CreateTranscriptionResponseJson,
+    CreateTranscriptionResponseVerboseJson,
 )
 from httpx_sse import connect_sse
 import pytest
@@ -48,7 +48,7 @@ def test_streaming_transcription_json(client: TestClient, file_path: str, endpoint: str) -> None:
     }
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
-            TranscriptionJsonResponse(**json.loads(event.data))
+            CreateTranscriptionResponseJson(**json.loads(event.data))
 
 
 @pytest.mark.parametrize(("file_path", "endpoint"), parameters)
@@ -62,7 +62,7 @@ def test_streaming_transcription_verbose_json(client: TestClient, file_path: str, endpoint: str) -> None:
     }
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
-            TranscriptionVerboseJsonResponse(**json.loads(event.data))
+            CreateTranscriptionResponseVerboseJson(**json.loads(event.data))
 
 
 def test_transcription_vtt(client: TestClient) -> None: