Fedir Zadniprovskyi committed
Commit ec4d8ae · 1 Parent(s): 6b7295b

refactor: update response model names and module name
src/faster_whisper_server/{server_models.py → api_models.py} RENAMED
@@ -4,36 +4,117 @@ from typing import TYPE_CHECKING, Literal
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from faster_whisper_server.core import Segment, Transcription, Word, segments_to_text
+from faster_whisper_server.text_utils import Transcription, canonicalize_word, segments_to_text
 
 if TYPE_CHECKING:
-    from faster_whisper.transcribe import TranscriptionInfo
+    from collections.abc import Iterable
+
+    import faster_whisper.transcribe
+
+
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10909
+class TranscriptionWord(BaseModel):
+    start: float
+    end: float
+    word: str
+    probability: float
+
+    @classmethod
+    def from_segments(cls, segments: Iterable[TranscriptionSegment]) -> list[TranscriptionWord]:
+        words: list[TranscriptionWord] = []
+        for segment in segments:
+            # NOTE: a temporary "fix" for https://github.com/fedirz/faster-whisper-server/issues/58.
+            # TODO: properly address the issue
+            assert (
+                segment.words is not None
+            ), "Segment must have words. If you are using an API ensure `timestamp_granularities[]=word` is set"
+            words.extend(segment.words)
+        return words
+
+    def offset(self, seconds: float) -> None:
+        self.start += seconds
+        self.end += seconds
+
+    @classmethod
+    def common_prefix(cls, a: list[TranscriptionWord], b: list[TranscriptionWord]) -> list[TranscriptionWord]:
+        i = 0
+        while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word):
+            i += 1
+        return a[:i]
+
+
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10938
+class TranscriptionSegment(BaseModel):
+    id: int
+    seek: int
+    start: float
+    end: float
+    text: str
+    tokens: list[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
+    no_speech_prob: float
+    words: list[TranscriptionWord] | None
+
+    @classmethod
+    def from_faster_whisper_segments(
+        cls, segments: Iterable[faster_whisper.transcribe.Segment]
+    ) -> Iterable[TranscriptionSegment]:
+        for segment in segments:
+            yield cls(
+                id=segment.id,
+                seek=segment.seek,
+                start=segment.start,
+                end=segment.end,
+                text=segment.text,
+                tokens=segment.tokens,
+                temperature=segment.temperature,
+                avg_logprob=segment.avg_logprob,
+                compression_ratio=segment.compression_ratio,
+                no_speech_prob=segment.no_speech_prob,
+                words=[
+                    TranscriptionWord(
+                        start=word.start,
+                        end=word.end,
+                        word=word.word,
+                        probability=word.probability,
+                    )
+                    for word in segment.words
+                ]
+                if segment.words is not None
+                else None,
+            )
 
 
 # https://platform.openai.com/docs/api-reference/audio/json-object
-class TranscriptionJsonResponse(BaseModel):
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10924
+class CreateTranscriptionResponseJson(BaseModel):
     text: str
 
     @classmethod
-    def from_segments(cls, segments: list[Segment]) -> TranscriptionJsonResponse:
+    def from_segments(cls, segments: list[TranscriptionSegment]) -> CreateTranscriptionResponseJson:
         return cls(text=segments_to_text(segments))
 
     @classmethod
-    def from_transcription(cls, transcription: Transcription) -> TranscriptionJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> CreateTranscriptionResponseJson:
         return cls(text=transcription.text)
 
 
 # https://platform.openai.com/docs/api-reference/audio/verbose-json-object
-class TranscriptionVerboseJsonResponse(BaseModel):
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L11007
+class CreateTranscriptionResponseVerboseJson(BaseModel):
     task: str = "transcribe"
     language: str
     duration: float
     text: str
-    words: list[Word] | None
-    segments: list[Segment]
+    words: list[TranscriptionWord] | None
+    segments: list[TranscriptionSegment]
 
     @classmethod
-    def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -> TranscriptionVerboseJsonResponse:
+    def from_segment(
+        cls, segment: TranscriptionSegment, transcription_info: faster_whisper.transcribe.TranscriptionInfo
+    ) -> CreateTranscriptionResponseVerboseJson:
         return cls(
             language=transcription_info.language,
             duration=segment.end - segment.start,
@@ -44,18 +125,20 @@ class TranscriptionVerboseJsonResponse(BaseModel):
 
     @classmethod
     def from_segments(
-        cls, segments: list[Segment], transcription_info: TranscriptionInfo
-    ) -> TranscriptionVerboseJsonResponse:
+        cls, segments: list[TranscriptionSegment], transcription_info: faster_whisper.transcribe.TranscriptionInfo
+    ) -> CreateTranscriptionResponseVerboseJson:
         return cls(
             language=transcription_info.language,
             duration=transcription_info.duration,
             text=segments_to_text(segments),
            segments=segments,
-            words=Word.from_segments(segments) if transcription_info.transcription_options.word_timestamps else None,
+            words=TranscriptionWord.from_segments(segments)
+            if transcription_info.transcription_options.word_timestamps
+            else None,
         )
 
     @classmethod
-    def from_transcription(cls, transcription: Transcription) -> TranscriptionVerboseJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> CreateTranscriptionResponseVerboseJson:
         return cls(
             language="english",  # FIX: hardcoded
             duration=transcription.duration,
@@ -65,12 +148,14 @@ class TranscriptionVerboseJsonResponse(BaseModel):
         )
 
 
-class ModelListResponse(BaseModel):
-    data: list[ModelObject]
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8730
+class ListModelsResponse(BaseModel):
+    data: list[Model]
     object: Literal["list"] = "list"
 
 
-class ModelObject(BaseModel):
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L11146
+class Model(BaseModel):
     id: str
     """The model identifier, which can be referenced in the API endpoints."""
     created: int
@@ -109,6 +194,7 @@ class ModelObject(BaseModel):
     )
 
 
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10909
 TimestampGranularities = list[Literal["segment", "word"]]
 
 
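The renamed models keep their `from_*` constructors, so call sites only swap class names. A minimal sketch of the new names in use (all field values are illustrative, not from the commit):

```python
# Minimal sketch, assuming the api_models module above; values are illustrative.
from faster_whisper_server.api_models import (
    CreateTranscriptionResponseJson,
    TranscriptionSegment,
    TranscriptionWord,
)

segment = TranscriptionSegment(
    id=0,
    seek=0,
    start=0.0,
    end=1.2,
    text="Hello world.",
    tokens=[],
    temperature=0.0,
    avg_logprob=-0.1,
    compression_ratio=1.0,
    no_speech_prob=0.01,
    words=[TranscriptionWord(start=0.0, end=0.5, word="Hello", probability=0.9)],
)

# `from_segments` delegates to `segments_to_text`, which joins and strips segment texts.
response = CreateTranscriptionResponseJson.from_segments([segment])
print(response.model_dump_json())  # -> {"text":"Hello world."}
```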
src/faster_whisper_server/asr.py CHANGED
@@ -1,11 +1,17 @@
+from __future__ import annotations
+
 import asyncio
 import logging
 import time
+from typing import TYPE_CHECKING
+
+from faster_whisper_server.api_models import TranscriptionSegment, TranscriptionWord
+from faster_whisper_server.text_utils import Transcription
 
-from faster_whisper import transcribe
+if TYPE_CHECKING:
+    from faster_whisper import transcribe
 
-from faster_whisper_server.audio import Audio
-from faster_whisper_server.core import Segment, Transcription, Word
+    from faster_whisper_server.audio import Audio
 
 logger = logging.getLogger(__name__)
 
@@ -31,8 +37,8 @@ class FasterWhisperASR:
             word_timestamps=True,
             **self.transcribe_opts,
         )
-        segments = Segment.from_faster_whisper_segments(segments)
-        words = Word.from_segments(segments)
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+        words = TranscriptionWord.from_segments(segments)
         for word in words:
             word.offset(audio.start)
         transcription = Transcription(words)
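The `faster_whisper` and `Audio` imports move under `if TYPE_CHECKING:`; together with `from __future__ import annotations` they are needed only for type hints and are no longer imported at runtime. A generic sketch of the pattern (the function below is hypothetical, not part of the commit):

```python
from __future__ import annotations  # annotations become strings, evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers, never imported at runtime.
    from faster_whisper import transcribe


def describe(info: transcribe.TranscriptionInfo) -> str:
    # The annotation above is fine at runtime because it is never evaluated.
    return f"{info.language} ({info.duration:.1f}s)"
```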
src/faster_whisper_server/routers/list_models.py CHANGED
@@ -9,9 +9,9 @@ from fastapi import (
 )
 import huggingface_hub
 
-from faster_whisper_server.server_models import (
-    ModelListResponse,
-    ModelObject,
+from faster_whisper_server.api_models import (
+    ListModelsResponse,
+    Model,
 )
 
 if TYPE_CHECKING:
@@ -21,11 +21,11 @@ router = APIRouter()
 
 
 @router.get("/v1/models")
-def get_models() -> ModelListResponse:
+def get_models() -> ListModelsResponse:
     models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
     models = list(models)
     models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore  # noqa: PGH003
-    transformed_models: list[ModelObject] = []
+    transformed_models: list[Model] = []
     for model in models:
         assert model.created_at is not None
         assert model.card_data is not None
@@ -36,7 +36,7 @@ def get_models() -> ModelListResponse:
             language = [model.card_data.language]
         else:
             language = model.card_data.language
-        transformed_model = ModelObject(
+        transformed_model = Model(
             id=model.id,
             created=int(model.created_at.timestamp()),
             object_="model",
@@ -44,14 +44,14 @@
             language=language,
         )
         transformed_models.append(transformed_model)
-    return ModelListResponse(data=transformed_models)
+    return ListModelsResponse(data=transformed_models)
 
 
 @router.get("/v1/models/{model_name:path}")
 # NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
 def get_model(
     model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
-) -> ModelObject:
+) -> Model:
     models = huggingface_hub.list_models(
         model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
     )
@@ -78,7 +78,7 @@
             language = [exact_match.card_data.language]
         else:
             language = exact_match.card_data.language
-    return ModelObject(
+    return Model(
         id=exact_match.id,
         created=int(exact_match.created_at.timestamp()),
         object_="model",
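The endpoint signatures now return the OpenAI-style names. A usage sketch against the renamed payloads (host and port are assumptions, not part of the commit):

```python
# Usage sketch, assuming a server running locally; the base URL is illustrative.
import httpx

payload = httpx.get("http://localhost:8000/v1/models").json()
assert payload["object"] == "list"  # shape of ListModelsResponse
for model in payload["data"][:3]:
    print(model["id"], model["created"])  # fields of Model
```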
src/faster_whisper_server/routers/stt.py CHANGED
@@ -20,6 +20,14 @@ from fastapi.websockets import WebSocketState
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 from pydantic import AfterValidator
 
+from faster_whisper_server.api_models import (
+    DEFAULT_TIMESTAMP_GRANULARITIES,
+    TIMESTAMP_GRANULARITIES_COMBINATIONS,
+    CreateTranscriptionResponseJson,
+    CreateTranscriptionResponseVerboseJson,
+    TimestampGranularities,
+    TranscriptionSegment,
+)
 from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
@@ -28,15 +36,8 @@ from faster_whisper_server.config import (
     ResponseFormat,
     Task,
 )
-from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
-from faster_whisper_server.server_models import (
-    DEFAULT_TIMESTAMP_GRANULARITIES,
-    TIMESTAMP_GRANULARITIES_COMBINATIONS,
-    TimestampGranularities,
-    TranscriptionJsonResponse,
-    TranscriptionVerboseJsonResponse,
-)
+from faster_whisper_server.text_utils import segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.transcriber import audio_transcriber
 
 if TYPE_CHECKING:
@@ -51,7 +52,7 @@ router = APIRouter()
 
 
 def segments_to_response(
-    segments: Iterable[Segment],
+    segments: Iterable[TranscriptionSegment],
     transcription_info: TranscriptionInfo,
     response_format: ResponseFormat,
 ) -> Response:
@@ -60,12 +61,12 @@
         return Response(segments_to_text(segments), media_type="text/plain")
     elif response_format == ResponseFormat.JSON:
         return Response(
-            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
+            CreateTranscriptionResponseJson.from_segments(segments).model_dump_json(),
             media_type="application/json",
         )
     elif response_format == ResponseFormat.VERBOSE_JSON:
         return Response(
-            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
+            CreateTranscriptionResponseVerboseJson.from_segments(segments, transcription_info).model_dump_json(),
             media_type="application/json",
         )
     elif response_format == ResponseFormat.VTT:
@@ -83,7 +84,7 @@ def format_as_sse(data: str) -> str:
 
 
 def segments_to_streaming_response(
-    segments: Iterable[Segment],
+    segments: Iterable[TranscriptionSegment],
     transcription_info: TranscriptionInfo,
     response_format: ResponseFormat,
 ) -> StreamingResponse:
@@ -92,9 +93,11 @@
             if response_format == ResponseFormat.TEXT:
                 data = segment.text
             elif response_format == ResponseFormat.JSON:
-                data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
+                data = CreateTranscriptionResponseJson.from_segments([segment]).model_dump_json()
             elif response_format == ResponseFormat.VERBOSE_JSON:
-                data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
+                data = CreateTranscriptionResponseVerboseJson.from_segment(
+                    segment, transcription_info
+                ).model_dump_json()
             elif response_format == ResponseFormat.VTT:
                 data = segments_to_vtt(segment, i)
             elif response_format == ResponseFormat.SRT:
@@ -121,7 +124,7 @@ ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
 
 @router.post(
     "/v1/audio/translations",
-    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
+    response_model=str | CreateTranscriptionResponseJson | CreateTranscriptionResponseVerboseJson,
 )
 def translate_file(
     config: ConfigDependency,
@@ -145,7 +148,7 @@ def translate_file(
         temperature=temperature,
         vad_filter=True,
     )
-    segments = Segment.from_faster_whisper_segments(segments)
+    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
 
     if stream:
         return segments_to_streaming_response(segments, transcription_info, response_format)
@@ -169,7 +172,7 @@ async def get_timestamp_granularities(request: Request) -> TimestampGranularitie
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @router.post(
     "/v1/audio/transcriptions",
-    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
+    response_model=str | CreateTranscriptionResponseJson | CreateTranscriptionResponseVerboseJson,
 )
 def transcribe_file(
     config: ConfigDependency,
@@ -211,7 +214,7 @@ def transcribe_file(
         vad_filter=True,
         hotwords=hotwords,
     )
-    segments = Segment.from_faster_whisper_segments(segments)
+    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
 
     if stream:
         return segments_to_streaming_response(segments, transcription_info, response_format)
@@ -286,9 +289,11 @@ async def transcribe_stream(
         if response_format == ResponseFormat.TEXT:
             await ws.send_text(transcription.text)
         elif response_format == ResponseFormat.JSON:
-            await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
+            await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
         elif response_format == ResponseFormat.VERBOSE_JSON:
-            await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
+            await ws.send_json(
+                CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
+            )
 
     if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
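The endpoints remain drop-in compatible with the OpenAI audio API, so the official client keeps working against the renamed response models. A sketch mirroring the test setup (base URL, API key, model name, and audio path are assumptions, not from the commit):

```python
# Usage sketch; base_url, api_key, and "audio.wav" are illustrative assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="does-not-matter")
with open("audio.wav", "rb") as f:
    transcript = client.audio.transcriptions.create(
        file=f,
        model="Systran/faster-distil-whisper-large-v3",
        response_format="verbose_json",
        timestamp_granularities=["word", "segment"],  # needed for word timestamps
    )
print(transcript.text)
```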
src/faster_whisper_server/{core.py → text_utils.py} RENAMED
@@ -3,90 +3,17 @@ from __future__ import annotations
 import re
 from typing import TYPE_CHECKING
 
-from pydantic import BaseModel
-
 from faster_whisper_server.dependencies import get_config
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
-    import faster_whisper.transcribe
-
-
-class Word(BaseModel):
-    start: float
-    end: float
-    word: str
-    probability: float
-
-    @classmethod
-    def from_segments(cls, segments: Iterable[Segment]) -> list[Word]:
-        words: list[Word] = []
-        for segment in segments:
-            # NOTE: a temporary "fix" for https://github.com/fedirz/faster-whisper-server/issues/58.
-            # TODO: properly address the issue
-            assert (
-                segment.words is not None
-            ), "Segment must have words. If you are using an API ensure `timestamp_granularities[]=word` is set"
-            words.extend(segment.words)
-        return words
-
-    def offset(self, seconds: float) -> None:
-        self.start += seconds
-        self.end += seconds
-
-    @classmethod
-    def common_prefix(cls, a: list[Word], b: list[Word]) -> list[Word]:
-        i = 0
-        while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word):
-            i += 1
-        return a[:i]
-
-
-class Segment(BaseModel):
-    id: int
-    seek: int
-    start: float
-    end: float
-    text: str
-    tokens: list[int]
-    temperature: float
-    avg_logprob: float
-    compression_ratio: float
-    no_speech_prob: float
-    words: list[Word] | None
-
-    @classmethod
-    def from_faster_whisper_segments(cls, segments: Iterable[faster_whisper.transcribe.Segment]) -> Iterable[Segment]:
-        for segment in segments:
-            yield cls(
-                id=segment.id,
-                seek=segment.seek,
-                start=segment.start,
-                end=segment.end,
-                text=segment.text,
-                tokens=segment.tokens,
-                temperature=segment.temperature,
-                avg_logprob=segment.avg_logprob,
-                compression_ratio=segment.compression_ratio,
-                no_speech_prob=segment.no_speech_prob,
-                words=[
-                    Word(
-                        start=word.start,
-                        end=word.end,
-                        word=word.word,
-                        probability=word.probability,
-                    )
-                    for word in segment.words
-                ]
-                if segment.words is not None
-                else None,
-            )
+    from faster_whisper_server.api_models import TranscriptionSegment, TranscriptionWord
 
 
 class Transcription:
-    def __init__(self, words: list[Word] = []) -> None:
-        self.words: list[Word] = []
+    def __init__(self, words: list[TranscriptionWord] = []) -> None:
+        self.words: list[TranscriptionWord] = []
         self.extend(words)
 
     @property
@@ -108,11 +35,11 @@ class Transcription:
     def after(self, seconds: float) -> Transcription:
         return Transcription(words=[word for word in self.words if word.start > seconds])
 
-    def extend(self, words: list[Word]) -> None:
+    def extend(self, words: list[TranscriptionWord]) -> None:
         self._ensure_no_word_overlap(words)
         self.words.extend(words)
 
-    def _ensure_no_word_overlap(self, words: list[Word]) -> None:
+    def _ensure_no_word_overlap(self, words: list[TranscriptionWord]) -> None:
         config = get_config()  # HACK
         if len(self.words) > 0 and len(words) > 0:
             if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
@@ -130,19 +57,8 @@ def is_eos(text: str) -> bool:
     return any(text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")
 
 
-def test_is_eos() -> None:
-    assert not is_eos("Hello")
-    assert not is_eos("Hello...")
-    assert is_eos("Hello.")
-    assert is_eos("Hello!")
-    assert is_eos("Hello?")
-    assert not is_eos("Hello. Yo")
-    assert not is_eos("Hello. Yo...")
-    assert is_eos("Hello. Yo.")
-
-
-def to_full_sentences(words: list[Word]) -> list[list[Word]]:
-    sentences: list[list[Word]] = [[]]
+def to_full_sentences(words: list[TranscriptionWord]) -> list[list[TranscriptionWord]]:
+    sentences: list[list[TranscriptionWord]] = [[]]
     for word in words:
         sentences[-1].append(word)
         if is_eos(word.word):
@@ -152,28 +68,15 @@ def to_full_sentences(words: list[Word]) -> list[list[Word]]:
     return sentences
 
 
-def tests_to_full_sentences() -> None:
-    def word(text: str) -> Word:
-        return Word(word=text, start=0.0, end=0.0, probability=0.0)
-
-    assert to_full_sentences([]) == []
-    assert to_full_sentences([word(text="Hello")]) == []
-    assert to_full_sentences([word(text="Hello..."), word(" world")]) == []
-    assert to_full_sentences([word(text="Hello..."), word(" world.")]) == [[word("Hello..."), word(" world.")]]
-    assert to_full_sentences([word(text="Hello..."), word(" world."), word(" How")]) == [
-        [word("Hello..."), word(" world.")],
-    ]
-
-
-def word_to_text(words: list[Word]) -> str:
+def word_to_text(words: list[TranscriptionWord]) -> str:
     return "".join(word.word for word in words)
 
 
-def words_to_text_w_ts(words: list[Word]) -> str:
+def words_to_text_w_ts(words: list[TranscriptionWord]) -> str:
     return "".join(f"{word.word}({word.start:.2f}-{word.end:.2f})" for word in words)
 
 
-def segments_to_text(segments: Iterable[Segment]) -> str:
+def segments_to_text(segments: Iterable[TranscriptionSegment]) -> str:
     return "".join(segment.text for segment in segments).strip()
 
 
@@ -185,19 +88,6 @@ def srt_format_timestamp(ts: float) -> str:
     return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
 
 
-def test_srt_format_timestamp() -> None:
-    assert srt_format_timestamp(0.0) == "00:00:00,000"
-    assert srt_format_timestamp(1.0) == "00:00:01,000"
-    assert srt_format_timestamp(1.234) == "00:00:01,234"
-    assert srt_format_timestamp(60.0) == "00:01:00,000"
-    assert srt_format_timestamp(61.0) == "00:01:01,000"
-    assert srt_format_timestamp(61.234) == "00:01:01,234"
-    assert srt_format_timestamp(3600.0) == "01:00:00,000"
-    assert srt_format_timestamp(3601.0) == "01:00:01,000"
-    assert srt_format_timestamp(3601.234) == "01:00:01,234"
-    assert srt_format_timestamp(23423.4234) == "06:30:23,423"
-
-
 def vtt_format_timestamp(ts: float) -> str:
     hours = ts // 3600
     minutes = (ts % 3600) // 60
@@ -206,20 +96,7 @@ def vtt_format_timestamp(ts: float) -> str:
     return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
 
 
-def test_vtt_format_timestamp() -> None:
-    assert vtt_format_timestamp(0.0) == "00:00:00.000"
-    assert vtt_format_timestamp(1.0) == "00:00:01.000"
-    assert vtt_format_timestamp(1.234) == "00:00:01.234"
-    assert vtt_format_timestamp(60.0) == "00:01:00.000"
-    assert vtt_format_timestamp(61.0) == "00:01:01.000"
-    assert vtt_format_timestamp(61.234) == "00:01:01.234"
-    assert vtt_format_timestamp(3600.0) == "01:00:00.000"
-    assert vtt_format_timestamp(3601.0) == "01:00:01.000"
-    assert vtt_format_timestamp(3601.234) == "01:00:01.234"
-    assert vtt_format_timestamp(23423.4234) == "06:30:23.423"
-
-
-def segments_to_vtt(segment: Segment, i: int) -> str:
+def segments_to_vtt(segment: TranscriptionSegment, i: int) -> str:
     start = segment.start if i > 0 else 0.0
     result = f"{vtt_format_timestamp(start)} --> {vtt_format_timestamp(segment.end)}\n{segment.text}\n\n"
 
@@ -229,7 +106,7 @@ def segments_to_vtt(segment: Segment, i: int) -> str:
     return result
 
 
-def segments_to_srt(segment: Segment, i: int) -> str:
+def segments_to_srt(segment: TranscriptionSegment, i: int) -> str:
    return f"{i + 1}\n{srt_format_timestamp(segment.start)} --> {srt_format_timestamp(segment.end)}\n{segment.text}\n\n"
 
 
@@ -240,60 +117,8 @@ def canonicalize_word(text: str) -> str:
     return text.lower().strip().strip(".,?!")
 
 
-def test_canonicalize_word() -> None:
-    assert canonicalize_word("ABC") == "abc"
-    assert canonicalize_word("...ABC?") == "abc"
-    assert canonicalize_word("... AbC ...") == "abc"
-
-
-def common_prefix(a: list[Word], b: list[Word]) -> list[Word]:
+def common_prefix(a: list[TranscriptionWord], b: list[TranscriptionWord]) -> list[TranscriptionWord]:
     i = 0
     while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word):
         i += 1
     return a[:i]
-
-
-def test_common_prefix() -> None:
-    def word(text: str) -> Word:
-        return Word(word=text, start=0.0, end=0.0, probability=0.0)
-
-    a = [word("a"), word("b"), word("c")]
-    b = [word("a"), word("b"), word("c")]
-    assert common_prefix(a, b) == [word("a"), word("b"), word("c")]
-
-    a = [word("a"), word("b"), word("c")]
-    b = [word("a"), word("b"), word("d")]
-    assert common_prefix(a, b) == [word("a"), word("b")]
-
-    a = [word("a"), word("b"), word("c")]
-    b = [word("a")]
-    assert common_prefix(a, b) == [word("a")]
-
-    a = [word("a")]
-    b = [word("a"), word("b"), word("c")]
-    assert common_prefix(a, b) == [word("a")]
-
-    a = [word("a")]
-    b = []
-    assert common_prefix(a, b) == []
-
-    a = []
-    b = [word("a")]
-    assert common_prefix(a, b) == []
-
-    a = [word("a"), word("b"), word("c")]
-    b = [word("b"), word("c")]
-    assert common_prefix(a, b) == []
-
-
-def test_common_prefix_and_canonicalization() -> None:
-    def word(text: str) -> Word:
-        return Word(word=text, start=0.0, end=0.0, probability=0.0)
-
-    a = [word("A...")]
-    b = [word("a?"), word("b"), word("c")]
-    assert common_prefix(a, b) == [word("A...")]
-
-    a = [word("A..."), word("B?"), word("C,")]
-    b = [word("a??"), word(" b"), word(" ,c")]
-    assert common_prefix(a, b) == [word("A..."), word("B?"), word("C,")]
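With the pydantic models moved out, `text_utils` now holds only pure helpers. A quick sketch of the two timestamp formatters (expected outputs taken from the tests added below):

```python
# Quick sketch of the subtitle timestamp helpers kept in text_utils.
from faster_whisper_server.text_utils import srt_format_timestamp, vtt_format_timestamp

print(srt_format_timestamp(3601.234))  # 01:00:01,234 (SRT separates milliseconds with a comma)
print(vtt_format_timestamp(3601.234))  # 01:00:01.234 (VTT uses a period)
```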
src/faster_whisper_server/text_utils_test.py ADDED
@@ -0,0 +1,111 @@
+from faster_whisper_server.api_models import TranscriptionWord
+from faster_whisper_server.text_utils import (
+    canonicalize_word,
+    common_prefix,
+    is_eos,
+    srt_format_timestamp,
+    to_full_sentences,
+    vtt_format_timestamp,
+)
+
+
+def test_is_eos() -> None:
+    assert not is_eos("Hello")
+    assert not is_eos("Hello...")
+    assert is_eos("Hello.")
+    assert is_eos("Hello!")
+    assert is_eos("Hello?")
+    assert not is_eos("Hello. Yo")
+    assert not is_eos("Hello. Yo...")
+    assert is_eos("Hello. Yo.")
+
+
+def tests_to_full_sentences() -> None:
+    def word(text: str) -> TranscriptionWord:
+        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)
+
+    assert to_full_sentences([]) == []
+    assert to_full_sentences([word(text="Hello")]) == []
+    assert to_full_sentences([word(text="Hello..."), word(" world")]) == []
+    assert to_full_sentences([word(text="Hello..."), word(" world.")]) == [[word("Hello..."), word(" world.")]]
+    assert to_full_sentences([word(text="Hello..."), word(" world."), word(" How")]) == [
+        [word("Hello..."), word(" world.")],
+    ]
+
+
+def test_srt_format_timestamp() -> None:
+    assert srt_format_timestamp(0.0) == "00:00:00,000"
+    assert srt_format_timestamp(1.0) == "00:00:01,000"
+    assert srt_format_timestamp(1.234) == "00:00:01,234"
+    assert srt_format_timestamp(60.0) == "00:01:00,000"
+    assert srt_format_timestamp(61.0) == "00:01:01,000"
+    assert srt_format_timestamp(61.234) == "00:01:01,234"
+    assert srt_format_timestamp(3600.0) == "01:00:00,000"
+    assert srt_format_timestamp(3601.0) == "01:00:01,000"
+    assert srt_format_timestamp(3601.234) == "01:00:01,234"
+    assert srt_format_timestamp(23423.4234) == "06:30:23,423"
+
+
+def test_vtt_format_timestamp() -> None:
+    assert vtt_format_timestamp(0.0) == "00:00:00.000"
+    assert vtt_format_timestamp(1.0) == "00:00:01.000"
+    assert vtt_format_timestamp(1.234) == "00:00:01.234"
+    assert vtt_format_timestamp(60.0) == "00:01:00.000"
+    assert vtt_format_timestamp(61.0) == "00:01:01.000"
+    assert vtt_format_timestamp(61.234) == "00:01:01.234"
+    assert vtt_format_timestamp(3600.0) == "01:00:00.000"
+    assert vtt_format_timestamp(3601.0) == "01:00:01.000"
+    assert vtt_format_timestamp(3601.234) == "01:00:01.234"
+    assert vtt_format_timestamp(23423.4234) == "06:30:23.423"
+
+
+def test_canonicalize_word() -> None:
+    assert canonicalize_word("ABC") == "abc"
+    assert canonicalize_word("...ABC?") == "abc"
+    assert canonicalize_word("... AbC ...") == "abc"
+
+
+def test_common_prefix() -> None:
+    def word(text: str) -> TranscriptionWord:
+        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)
+
+    a = [word("a"), word("b"), word("c")]
+    b = [word("a"), word("b"), word("c")]
+    assert common_prefix(a, b) == [word("a"), word("b"), word("c")]
+
+    a = [word("a"), word("b"), word("c")]
+    b = [word("a"), word("b"), word("d")]
+    assert common_prefix(a, b) == [word("a"), word("b")]
+
+    a = [word("a"), word("b"), word("c")]
+    b = [word("a")]
+    assert common_prefix(a, b) == [word("a")]
+
+    a = [word("a")]
+    b = [word("a"), word("b"), word("c")]
+    assert common_prefix(a, b) == [word("a")]
+
+    a = [word("a")]
+    b = []
+    assert common_prefix(a, b) == []
+
+    a = []
+    b = [word("a")]
+    assert common_prefix(a, b) == []
+
+    a = [word("a"), word("b"), word("c")]
+    b = [word("b"), word("c")]
+    assert common_prefix(a, b) == []
+
+
+def test_common_prefix_and_canonicalization() -> None:
+    def word(text: str) -> TranscriptionWord:
+        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)
+
+    a = [word("A...")]
+    b = [word("a?"), word("b"), word("c")]
+    assert common_prefix(a, b) == [word("A...")]
+
+    a = [word("A..."), word("B?"), word("C,")]
+    b = [word("a??"), word(" b"), word(" ,c")]
+    assert common_prefix(a, b) == [word("A..."), word("B?"), word("C,")]
src/faster_whisper_server/transcriber.py CHANGED
@@ -4,11 +4,12 @@ import logging
 from typing import TYPE_CHECKING
 
 from faster_whisper_server.audio import Audio, AudioStream
-from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text
+from faster_whisper_server.text_utils import Transcription, common_prefix, to_full_sentences, word_to_text
 
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
 
+    from faster_whisper_server.api_models import TranscriptionWord
     from faster_whisper_server.asr import FasterWhisperASR
 
 logger = logging.getLogger(__name__)
@@ -18,7 +19,7 @@ class LocalAgreement:
     def __init__(self) -> None:
         self.unconfirmed = Transcription()
 
-    def merge(self, confirmed: Transcription, incoming: Transcription) -> list[Word]:
+    def merge(self, confirmed: Transcription, incoming: Transcription) -> list[TranscriptionWord]:
         # https://github.com/ufal/whisper_streaming/blob/main/whisper_online.py#L264
         incoming = incoming.after(confirmed.end - 0.1)
         prefix = common_prefix(incoming.words, self.unconfirmed.words)
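`LocalAgreement.merge` confirms words via `common_prefix`, which compares canonicalized text so case and punctuation differences still count as agreement. A small sketch (word values are illustrative):

```python
# Small sketch of the common_prefix helper LocalAgreement relies on.
from faster_whisper_server.api_models import TranscriptionWord
from faster_whisper_server.text_utils import common_prefix


def word(text: str) -> TranscriptionWord:
    return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)


# "A..." and "a" canonicalize to "a", so they match; "B?" vs "c" stops the prefix.
assert common_prefix([word("A..."), word("B?")], [word("a"), word("c")]) == [word("A...")]
```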
tests/api_timestamp_granularities_test.py CHANGED
@@ -1,6 +1,6 @@
 """See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`."""  # noqa: E501
 
-from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
+from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
 from openai import AsyncOpenAI
 import pytest
 
tests/openai_timestamp_granularities_test.py CHANGED
@@ -1,6 +1,6 @@
 """OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters."""  # noqa: E501
 
-from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
+from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
 from openai import AsyncOpenAI, BadRequestError
 import pytest
 
tests/sse_test.py CHANGED
@@ -2,9 +2,9 @@ import json
 import os
 
 from fastapi.testclient import TestClient
-from faster_whisper_server.server_models import (
-    TranscriptionJsonResponse,
-    TranscriptionVerboseJsonResponse,
+from faster_whisper_server.api_models import (
+    CreateTranscriptionResponseJson,
+    CreateTranscriptionResponseVerboseJson,
 )
 from httpx_sse import connect_sse
 import pytest
@@ -48,7 +48,7 @@ def test_streaming_transcription_json(client: TestClient, file_path: str, endpoi
     }
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
-            TranscriptionJsonResponse(**json.loads(event.data))
+            CreateTranscriptionResponseJson(**json.loads(event.data))
 
 
 @pytest.mark.parametrize(("file_path", "endpoint"), parameters)
@@ -62,7 +62,7 @@ def test_streaming_transcription_verbose_json(client: TestClient, file_path: str
     }
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
-            TranscriptionVerboseJsonResponse(**json.loads(event.data))
+            CreateTranscriptionResponseVerboseJson(**json.loads(event.data))
 
 
 def test_transcription_vtt(client: TestClient) -> None: