import asyncio
import os
from pathlib import Path
from uuid import uuid4

from langchain_community.callbacks import get_openai_callback
from pydantic import BaseModel
from pydub import AudioSegment

from src import tts, utils
from src.config import (
    CONTEXT_CHAR_LEN_FOR_TTS,
    ELEVENLABS_MAX_PARALLEL,
    OPENAI_MAX_PARALLEL,
    logger,
)
from src.lc_callbacks import LCMessageLoggerAsync
from src.preprocess_tts_emotions_chain import TTSParamProcessor
from src.schemas import (
    SoundEffectsParams,
    TTSParams,
    TTSTimestampsAlignment,
    TTSTimestampsResponse,
)
from src.select_voice_chain import (
    CharacterPropertiesNullable,
    SelectVoiceChainOutput,
    VoiceSelector,
)
from src.sound_effects_design import (
    SoundEffectDescription,
    SoundEffectsDesignOutput,
    create_sound_effects_design_chain,
)
from src.text_modification_chain import modify_text_chain
from src.text_split_chain import SplitTextOutput, create_split_text_chain
from src.utils import GPTModels, prettify_unknown_character_label
from src.web.constructor import HTMLGenerator
from src.web.utils import (
    create_status_html,
    generate_text_split_inner_html_no_effect,
    generate_text_split_inner_html_with_effects,
    generate_voice_mapping_inner_html,
)


class TTSPhrasesGenerationOutput(BaseModel):
    audio_fps: list[str]
    char2time: TTSTimestampsAlignment


class AudiobookBuilder:

    def __init__(self, rm_artifacts: bool = False):
        self.voice_selector = VoiceSelector()
        self.params_tts_processor = TTSParamProcessor()
        self.rm_artifacts = rm_artifacts
        self.min_sound_effect_duration_sec = 1
        self.sound_effects_prompt_influence = 0.75  # seems to work nicely
        self.html_generator = HTMLGenerator()
        self.name = type(self).__name__

    @staticmethod
    async def _prepare_text_for_tts(text: str) -> str:
        chain = modify_text_chain(llm_model=GPTModels.GPT_4o)
        with get_openai_callback() as cb:
            result = await chain.ainvoke(
                {"text": text}, config={"callbacks": [LCMessageLoggerAsync()]}
            )
        logger.info(
            f'end of modifying text with caps and symbols (?, !, ...). '
            f'openai callback stats: {cb}'
        )
        return result.text_modified

    @staticmethod
    async def _split_text(text: str) -> SplitTextOutput:
        chain = create_split_text_chain(llm_model=GPTModels.GPT_4o)
        with get_openai_callback() as cb:
            chain_out = await chain.ainvoke(
                {"text": text}, config={"callbacks": [LCMessageLoggerAsync()]}
            )
        logger.info(f'end of splitting text into characters. openai callback stats: {cb}')
        return chain_out

    @staticmethod
    async def _design_sound_effects(text: str) -> SoundEffectsDesignOutput:
        chain = create_sound_effects_design_chain(llm_model=GPTModels.GPT_4o)
        with get_openai_callback() as cb:
            res = await chain.ainvoke(
                {"text": text}, config={"callbacks": [LCMessageLoggerAsync()]}
            )
        logger.info(
            f'designed {len(res.sound_effects_descriptions)} sound effects. '
            f'openai callback stats: {cb}'
        )
        return res

    async def _map_characters_to_voices(
        self, text_split: SplitTextOutput
    ) -> SelectVoiceChainOutput:
        chain = self.voice_selector.create_voice_mapping_chain(llm_model=GPTModels.GPT_4o)
        with get_openai_callback() as cb:
            chain_out = await chain.ainvoke(
                {
                    "text": text_split.text_annotated,
                    "characters": text_split.characters,
                },
                config={"callbacks": [LCMessageLoggerAsync()]},
            )
        logger.info(f'end of mapping characters to voices. openai callback stats: {cb}')
        return chain_out
    async def _prepare_params_for_tts(self, text_split: SplitTextOutput) -> list[TTSParams]:
        semaphore = asyncio.Semaphore(OPENAI_MAX_PARALLEL)

        async def run_task_with_semaphore(func, **params):
            async with semaphore:
                outputs = await func(**params)
                return outputs

        tasks = []
        for character_phrase in text_split.phrases:
            tasks.append(
                run_task_with_semaphore(
                    func=self.params_tts_processor.run,
                    text=character_phrase.text,
                )
            )

        tts_tasks_results = await asyncio.gather(*tasks)
        return tts_tasks_results

    @staticmethod
    def _add_voice_ids_to_tts_params(
        text_split: SplitTextOutput,
        tts_params_list: list[TTSParams],
        character2voice: dict[str, str],
    ) -> list[TTSParams]:
        for character_phrase, params in zip(text_split.phrases, tts_params_list):
            params.voice_id = character2voice[character_phrase.character]
        return tts_params_list

    @staticmethod
    def _get_left_and_right_contexts_for_each_phrase(
        phrases, context_length=CONTEXT_CHAR_LEN_FOR_TTS
    ):
        """
        For each phrase, collect the concatenated text of its neighboring phrases
        on the left and right sides, each side capped at `context_length` characters.
        A rough word/token count can be estimated as `context_length / 5`.
        """
        # TODO: split the first context phrase if it exceeds `context_length`;
        #  currently it is simply not added.
        # TODO: optimize the algorithm to linear time using a sliding window
        #  on top of cumulative length sums (see the sketch below).
        left_right_contexts = []
        for i in range(len(phrases)):
            left_text, right_text = '', ''
            for j in range(i - 1, -1, -1):
                if len(left_text) + len(phrases[j].text) < context_length:
                    left_text = phrases[j].text + left_text
                else:
                    break
            for phrase in phrases[i + 1:]:
                if len(right_text) + len(phrase.text) < context_length:
                    right_text += phrase.text
                else:
                    break
            left_right_contexts.append((left_text, right_text))
        return left_right_contexts
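
    # A minimal sketch of the linear-time TODO above, assuming `phrases` expose the
    # same `.text` attribute. Prefix sums plus two monotonic pointers locate each
    # window boundary in O(n) total instead of re-scanning neighbors per phrase;
    # assembling the context strings still costs O(total characters). Illustrative
    # only — this method is not wired into the pipeline.
    @staticmethod
    def _get_left_and_right_contexts_linear_sketch(
        phrases, context_length=CONTEXT_CHAR_LEN_FOR_TTS
    ):
        texts = [p.text for p in phrases]
        # cum_len[k] holds the total length of texts[0..k-1]
        cum_len = [0]
        for t in texts:
            cum_len.append(cum_len[-1] + len(t))

        left_right_contexts = []
        lo = 0  # left edge of the left-context window; only moves forward
        hi = 0  # right edge (exclusive) of the right-context window; only moves forward
        for i in range(len(texts)):
            # shrink the left window until texts[lo:i] fits within context_length
            while lo < i and cum_len[i] - cum_len[lo] >= context_length:
                lo += 1
            # grow the right window while the next phrase still fits
            hi = max(hi, i + 1)
            while hi < len(texts) and cum_len[hi + 1] - cum_len[i + 1] < context_length:
                hi += 1
            left_right_contexts.append((''.join(texts[lo:i]), ''.join(texts[i + 1:hi])))
        return left_right_contexts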
    def _add_previous_and_next_context_to_tts_params(
        self,
        text_split: SplitTextOutput,
        tts_params_list: list[TTSParams],
    ) -> list[TTSParams]:
        left_right_contexts = self._get_left_and_right_contexts_for_each_phrase(text_split.phrases)
        for cur_contexts, params in zip(left_right_contexts, tts_params_list):
            left_context, right_context = cur_contexts
            params.previous_text = left_context
            params.next_text = right_context
        return tts_params_list

    @staticmethod
    async def _generate_tts_audio(
        tts_params_list: list[TTSParams],
        out_dp: str,
    ) -> TTSPhrasesGenerationOutput:
        semaphore = asyncio.Semaphore(ELEVENLABS_MAX_PARALLEL)

        async def _tts_with_semaphore(params: TTSParams) -> TTSTimestampsResponse:
            async with semaphore:
                return await tts.tts_w_timestamps(params=params)

        tasks = [_tts_with_semaphore(params=params) for params in tts_params_list]
        tts_responses: list[TTSTimestampsResponse] = await asyncio.gather(*tasks)

        tts_audio_fps = []
        for ix, (params, res) in enumerate(zip(tts_params_list, tts_responses), start=1):
            out_fp_no_ext = os.path.join(out_dp, f'tts_output_{ix}')
            out_fp = res.write_audio_to_file(
                filepath_no_ext=out_fp_no_ext, audio_format=params.output_format
            )
            tts_audio_fps.append(out_fp)

        # combine alignments
        alignments = [response.alignment for response in tts_responses]
        char2time = TTSTimestampsAlignment.combine_alignments(alignments=alignments)
        # filter alignments
        char2time = char2time.filter_chars_without_duration()

        return TTSPhrasesGenerationOutput(audio_fps=tts_audio_fps, char2time=char2time)

    def _update_sound_effects_descriptions_with_durations(
        self,
        sound_effects_descriptions: list[SoundEffectDescription],
        char2time: TTSTimestampsAlignment,
    ) -> list[SoundEffectDescription]:
        for sed in sound_effects_descriptions:
            ix_start, ix_end = sed.ix_start_orig_text, sed.ix_end_orig_text
            time_start = char2time.get_start_time_by_char_ix(ix_start, safe=True)
            time_end = char2time.get_end_time_by_char_ix(ix_end, safe=True)
            duration = time_end - time_start
            # enforce the minimal effect duration
            duration = max(self.min_sound_effect_duration_sec, duration)
            # update in place
            sed.start_sec = time_start
            sed.duration_sec = duration
        return sound_effects_descriptions

    # def _filter_short_sound_effects(
    #     self,
    #     sound_effects_descriptions: list[SoundEffectDescription],
    # ) -> list[SoundEffectDescription]:
    #     filtered = [
    #         sed
    #         for sed in sound_effects_descriptions
    #         if sed.duration_sec > self.min_sound_effect_duration_sec
    #     ]
    #     len_orig = len(sound_effects_descriptions)
    #     len_new = len(filtered)
    #     logger.info(
    #         f'{len_new} out of {len_orig} original sound effects are kept '
    #         f'after filtering by min duration: {self.min_sound_effect_duration_sec}'
    #     )
    #     return filtered

    def _sound_effects_description_2_generation_params(
        self,
        sound_effects_descriptions: list[SoundEffectDescription],
    ) -> list[SoundEffectsParams]:
        params = [
            SoundEffectsParams(
                text=sed.prompt,
                duration_seconds=sed.duration_sec,
                prompt_influence=self.sound_effects_prompt_influence,
            )
            for sed in sound_effects_descriptions
        ]
        return params

    @staticmethod
    async def _generate_sound_effects(
        sound_effects_params: list[SoundEffectsParams],
        out_dp: str,
    ) -> list[str]:
        semaphore = asyncio.Semaphore(ELEVENLABS_MAX_PARALLEL)

        async def _se_gen_with_semaphore(params: SoundEffectsParams) -> list[bytes]:
            async with semaphore:
                return await tts.sound_generation_consumed(params=params)

        tasks = [_se_gen_with_semaphore(params=params) for params in sound_effects_params]
        results = await asyncio.gather(*tasks)

        se_fps = []
        for ix, task_res in enumerate(results, start=1):
            out_fp = os.path.join(out_dp, f'sound_effect_{ix}.wav')
            utils.write_chunked_bytes(data=task_res, fp=out_fp)
            se_fps.append(out_fp)

        return se_fps

    @staticmethod
    def _save_text_split_debug_data(
        text_split: SplitTextOutput,
        out_dp: str,
    ):
        out_fp = os.path.join(out_dp, 'text_split.json')
        data = text_split.model_dump()
        utils.write_json(data, fp=out_fp)

    @staticmethod
    def _save_tts_debug_data(
        tts_params_list: list[TTSParams],
        tts_out: TTSPhrasesGenerationOutput,
        out_dp: str,
    ):
        out_fp = os.path.join(out_dp, 'tts.json')
        # NOTE: use `to_dict()` for correct conversion
        data = [param.to_dict() for param in tts_params_list]
        utils.write_json(data, fp=out_fp)

        out_fp = os.path.join(out_dp, 'tts_char2time.csv')
        df_char2time = tts_out.char2time.to_dataframe()
        df_char2time.to_csv(out_fp, index=True)

    @staticmethod
    def _save_sound_effects_debug_data(
        sound_effect_design_output: SoundEffectsDesignOutput,
        sound_effect_descriptions: list[SoundEffectDescription],
        out_dp: str,
    ):
        out_fp = os.path.join(out_dp, 'sound_effects_raw_llm_output.txt')
        utils.write_txt(sound_effect_design_output.text_annotated, fp=out_fp)

        out_fp = os.path.join(out_dp, 'sound_effects_descriptions.json')
        data = [sed.model_dump() for sed in sound_effect_descriptions]
        utils.write_json(data, fp=out_fp)

    @staticmethod
    def _postprocess_tts_audio(audio_fps: list[str], out_dp: str, target_dBFS: float) -> list[str]:
        fps = []
        for in_fp in audio_fps:
            audio_segment = AudioSegment.from_file(in_fp)
            normalized_audio = utils.normalize_audio(audio_segment, target_dBFS)
            out_fp = os.path.join(out_dp, f"{Path(in_fp).stem}.normalized.wav")
            normalized_audio.export(out_fp, format="wav")
            fps.append(out_fp)
        return fps
    @staticmethod
    def _postprocess_sound_effects(
        audio_fps: list[str], out_dp: str, target_dBFS: float, fade_ms: int
    ) -> list[str]:
        fps = []
        for in_fp in audio_fps:
            audio_segment = AudioSegment.from_file(in_fp)
            processed = utils.normalize_audio(audio_segment, target_dBFS)
            processed = processed.fade_in(duration=fade_ms)
            processed = processed.fade_out(duration=fade_ms)
            out_fp = os.path.join(out_dp, f"{Path(in_fp).stem}.postprocessed.wav")
            processed.export(out_fp, format="wav")
            fps.append(out_fp)
        return fps

    @staticmethod
    def _concatenate_audiofiles(audio_fps: list[str], out_wav_fp: str):
        concat = AudioSegment.from_file(audio_fps[0])
        for filename in audio_fps[1:]:
            next_audio = AudioSegment.from_file(filename)
            concat += next_audio
        logger.info(f'saving concatenated audiobook to: "{out_wav_fp}"')
        concat.export(out_wav_fp, format="wav")

    def _get_text_split_html(
        self,
        text_split: SplitTextOutput,
        sound_effects_descriptions: list[SoundEffectDescription] | None,
    ):
        # modify copies of the original phrases, keeping the originals intact
        character_phrases = [p.model_copy(deep=True) for p in text_split.phrases]
        for phrase in character_phrases:
            phrase.character = prettify_unknown_character_label(phrase.character)

        if not sound_effects_descriptions:
            inner = generate_text_split_inner_html_no_effect(character_phrases=character_phrases)
        else:
            inner = generate_text_split_inner_html_with_effects(
                character_phrases=character_phrases,
                sound_effects_descriptions=sound_effects_descriptions,
            )

        final = self.html_generator.generate_text_split(inner)
        return final

    def _get_voice_mapping_html(
        self, use_user_voice: bool, select_voice_chain_out: SelectVoiceChainOutput
    ):
        if use_user_voice:
            return ''
        inner = generate_voice_mapping_inner_html(select_voice_chain_out)
        final = self.html_generator.generate_voice_assignments(inner)
        return final

    STAGE_1 = 'Text Analysis'
    STAGE_2 = 'Voices Selection'
    STAGE_3 = 'Audio Generation'

    def _get_yield_data_stage_0(self):
        status = self.html_generator.generate_status("Starting", [("Analyzing Text...", False)])
        return None, "", status

    def _get_yield_data_stage_1(self, text_split_html: str):
        status_html = create_status_html(
            "Text Analysis Complete",
            [(self.STAGE_1, True), ("Selecting Voices...", False)],
        )
        html = status_html + text_split_html
        return None, "", html

    def _get_yield_data_stage_2(self, text_split_html: str, voice_mapping_html: str):
        status_html = create_status_html(
            "Voice Selection Complete",
            [(self.STAGE_1, True), (self.STAGE_2, True), ("Generating Audio...", False)],
        )
        html = status_html + text_split_html + voice_mapping_html
        return None, "", html

    def _get_yield_data_stage_3(
        self, final_audio_fp: str, text_split_html: str, voice_mapping_html: str
    ):
        status_html = create_status_html(
            "Audiobook is ready ✨",
            [(self.STAGE_1, True), (self.STAGE_2, True), (self.STAGE_3, True)],
        )
        third_stage_result_html = (
            status_html
            + text_split_html
            + voice_mapping_html
            + self.html_generator.generate_final_message()
        )
        return final_audio_fp, "", third_stage_result_html

    async def run(
        self,
        text: str,
        generate_effects: bool,
        use_user_voice: bool = False,
        voice_id: str | None = None,
    ):
        now_str = utils.get_utc_now_str()
        uuid_trimmed = str(uuid4()).split('-')[0]
        dir_name = f'{now_str}-{uuid_trimmed}'
        out_dp_root = os.path.join('data', 'audiobooks', dir_name)
        os.makedirs(out_dp_root, exist_ok=False)

        debug_dp = os.path.join(out_dp_root, 'debug')
        os.makedirs(debug_dp)

        # TODO: currently, we are constantly writing and reading audio segments
        #  from files. It would likely be more efficient to keep all audio in memory.
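        # A minimal sketch of that TODO (illustrative, not wired in), assuming pydub
        # stays the audio backend: segments can round-trip through in-memory buffers
        # instead of files, since pydub accepts file-like objects, e.g.
        #   import io
        #   buf = io.BytesIO()
        #   segment.export(buf, format='wav')
        #   buf.seek(0)
        #   segment = AudioSegment.from_file(buf, format='wav')
        # Debug artifacts would then be dumped to disk only on demand.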
        # zero stage
        if use_user_voice and not voice_id:
            yield None, "", self.html_generator.generate_message_without_voice_id()
        else:
            yield self._get_yield_data_stage_0()

            text_for_tts = await self._prepare_text_for_tts(text=text)

            # TODO: call the sound effects chain in parallel with the text split chain
            text_split = await self._split_text(text=text_for_tts)
            self._save_text_split_debug_data(text_split=text_split, out_dp=debug_dp)

            # yield stage 1
            text_split_html = self._get_text_split_html(
                text_split=text_split, sound_effects_descriptions=None
            )
            yield self._get_yield_data_stage_1(text_split_html=text_split_html)

            if generate_effects:
                se_design_output = await self._design_sound_effects(text=text_for_tts)
                se_descriptions = se_design_output.sound_effects_descriptions
                text_split_html = self._get_text_split_html(
                    text_split=text_split, sound_effects_descriptions=se_descriptions
                )

            # TODO: run voice mapping and tts params selection in parallel
            if not use_user_voice:
                select_voice_chain_out = await self._map_characters_to_voices(text_split=text_split)
            else:
                if voice_id is None:
                    raise ValueError('voice_id is None')
                select_voice_chain_out = SelectVoiceChainOutput(
                    character2props={
                        char: CharacterPropertiesNullable(gender=None, age_group=None)
                        for char in text_split.characters
                    },
                    character2voice={char: voice_id for char in text_split.characters},
                )
            tts_params_list = await self._prepare_params_for_tts(text_split=text_split)

            # yield stage 2
            voice_mapping_html = self._get_voice_mapping_html(
                use_user_voice=use_user_voice, select_voice_chain_out=select_voice_chain_out
            )
            yield self._get_yield_data_stage_2(
                text_split_html=text_split_html, voice_mapping_html=voice_mapping_html
            )

            tts_params_list = self._add_voice_ids_to_tts_params(
                text_split=text_split,
                tts_params_list=tts_params_list,
                character2voice=select_voice_chain_out.character2voice,
            )
            tts_params_list = self._add_previous_and_next_context_to_tts_params(
                text_split=text_split,
                tts_params_list=tts_params_list,
            )

            tts_dp = os.path.join(out_dp_root, 'tts')
            os.makedirs(tts_dp)
            tts_out = await self._generate_tts_audio(tts_params_list=tts_params_list, out_dp=tts_dp)
            self._save_tts_debug_data(
                tts_params_list=tts_params_list, tts_out=tts_out, out_dp=debug_dp
            )

            if generate_effects:
                se_descriptions = self._update_sound_effects_descriptions_with_durations(
                    sound_effects_descriptions=se_descriptions, char2time=tts_out.char2time
                )
                # no need for filtering, since we enforce the min duration above
                # se_descriptions = self._filter_short_sound_effects(
                #     sound_effects_descriptions=se_descriptions
                # )
                se_params = self._sound_effects_description_2_generation_params(
                    sound_effects_descriptions=se_descriptions
                )

                if len(se_descriptions) != len(se_params):
                    raise ValueError(
                        f'expected {len(se_descriptions)} sound effects params, got: {len(se_params)}'
                    )

                effects_dp = os.path.join(out_dp_root, 'sound_effects')
                os.makedirs(effects_dp)
                se_fps = await self._generate_sound_effects(
                    sound_effects_params=se_params, out_dp=effects_dp
                )

                if len(se_descriptions) != len(se_fps):
                    raise ValueError(
                        f'expected {len(se_descriptions)} generated sound effects, got: {len(se_fps)}'
                    )

                self._save_sound_effects_debug_data(
                    sound_effect_design_output=se_design_output,
                    sound_effect_descriptions=se_descriptions,
                    out_dp=debug_dp,
                )

            tts_normalized_dp = os.path.join(out_dp_root, 'tts_normalized')
            os.makedirs(tts_normalized_dp)
            tts_norm_fps = self._postprocess_tts_audio(
                audio_fps=tts_out.audio_fps,
                out_dp=tts_normalized_dp,
                target_dBFS=-20,
            )

            if generate_effects:
                se_normalized_dp = os.path.join(out_dp_root, 'sound_effects_postprocessed')
                os.makedirs(se_normalized_dp)
                se_norm_fps = self._postprocess_sound_effects(
                    audio_fps=se_fps,
                    out_dp=se_normalized_dp,
                    target_dBFS=-27,
                    fade_ms=500,
                )

            tts_concat_fp = os.path.join(out_dp_root, f'audiobook_{now_str}.wav')
            self._concatenate_audiofiles(audio_fps=tts_norm_fps, out_wav_fp=tts_concat_fp)

            if not generate_effects:
                final_audio_fp = tts_concat_fp
            else:
                tts_concat_with_effects_fp = os.path.join(
                    out_dp_root, f'audiobook_with_effects_{now_str}.wav'
                )
                se_starts_sec = [sed.start_sec for sed in se_descriptions]
                utils.overlay_multiple_audio(
                    main_audio_fp=tts_concat_fp,
                    audios_to_overlay_fps=se_norm_fps,
                    starts_sec=se_starts_sec,
                    out_fp=tts_concat_with_effects_fp,
                )
                final_audio_fp = tts_concat_with_effects_fp

            utils.rm_dir_conditional(dp=out_dp_root, to_remove=self.rm_artifacts)

            # yield stage 3
            yield self._get_yield_data_stage_3(
                final_audio_fp=final_audio_fp,
                text_split_html=text_split_html,
                voice_mapping_html=voice_mapping_html,
            )

        logger.info(f'end of {self.name}.run()')
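

# A minimal usage sketch, not part of the original module: `run()` is an async
# generator yielding `(final_audio_fp, text, html)` tuples for a UI, so a caller
# (e.g. a Gradio handler) consumes it with `async for`. The sample text is
# hypothetical, and running this requires the OpenAI/ElevenLabs credentials the
# pipeline depends on.
async def _demo():
    builder = AudiobookBuilder(rm_artifacts=False)
    async for audio_fp, _, html in builder.run(
        text='"Hello," she said. The door creaked open.',
        generate_effects=True,
    ):
        # intermediate yields carry status HTML; the final one carries the filepath
        print(audio_fp)


if __name__ == '__main__':
    asyncio.run(_demo())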