Tonic's picture
testing 123
31a7207
import multiprocessing
import argparse
import threading
import ssl
import time
import sys
import functools
from multiprocessing import Process, Manager, Value, Queue
from whisper_live.trt_server import TranscriptionServer
from llm_service import TensorRTLLMEngine
from tts_service import WhisperSpeechTTS
def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the pipeline launcher.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the options from ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace`` with the attributes
        ``whisper_tensorrt_path``, ``mistral``, ``mistral_tensorrt_path``,
        ``mistral_tokenizer_path``, ``phi``, ``phi_tensorrt_path`` and
        ``phi_tokenizer_path``.
    """
    parser = argparse.ArgumentParser()

    # Speech-to-text model.
    parser.add_argument(
        '--whisper_tensorrt_path',
        type=str,
        default="/root/TensorRT-LLM/examples/whisper/whisper_small_en",
        help='Whisper TensorRT model path',
    )

    # Mistral options. The engine path has no default, so it has to be
    # supplied explicitly whenever --mistral is used.
    parser.add_argument(
        '--mistral',
        action="store_true",
        help='Mistral',
    )
    parser.add_argument(
        '--mistral_tensorrt_path',
        type=str,
        default=None,
        help='Mistral TensorRT model path',
    )
    parser.add_argument(
        '--mistral_tokenizer_path',
        type=str,
        default="teknium/OpenHermes-2.5-Mistral-7B",
        help='Mistral Tokenizer path',
    )

    # Phi options.
    parser.add_argument(
        '--phi',
        action="store_true",
        help='Phi',
    )
    parser.add_argument(
        '--phi_tensorrt_path',
        type=str,
        default="/root/TensorRT-LLM/examples/phi/phi_engine",
        help='Phi TensorRT model path',
    )
    parser.add_argument(
        '--phi_tokenizer_path',
        type=str,
        default="/root/TensorRT-LLM/examples/phi/phi-2",
        help='Phi Tokenizer path',
    )

    return parser.parse_args(argv)
if __name__ == "__main__":
    # Entry point: validate the CLI options, then start the transcription,
    # LLM and text-to-speech services as three separate processes that share
    # multiprocessing queues.
    args = parse_arguments()

    # Fail fast on missing model locations before any process is started.
    if not args.whisper_tensorrt_path:
        raise ValueError("Please provide whisper_tensorrt_path to run the pipeline.")
    if args.mistral and not (args.mistral_tensorrt_path and args.mistral_tokenizer_path):
        raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.")
    if args.phi and not (args.phi_tensorrt_path and args.phi_tokenizer_path):
        raise ValueError("Please provide phi_tensorrt_path and phi_tokenizer_path to run the pipeline.")

    # Select the start method before any queue or process is created.
    # NOTE(review): 'spawn' is presumably required because the GPU-backed
    # services cannot be forked safely - confirm.
    multiprocessing.set_start_method('spawn')

    # Queues shared between the stages:
    #   transcription_queue, llm_queue -> whisper server and LLM engine
    #   audio_queue                    -> LLM engine and TTS service
    transcription_queue = Queue()
    llm_queue = Queue()
    audio_queue = Queue()

    # Speech-to-text server.
    whisper_server = TranscriptionServer()
    whisper_process = multiprocessing.Process(
        target=whisper_server.run,
        args=(
            "0.0.0.0",
            6006,
            transcription_queue,
            llm_queue,
            args.whisper_tensorrt_path,
        ),
    )
    whisper_process.start()

    # LLM engine.
    # NOTE(review): --mistral is validated above, but this stage is always
    # started with the phi engine/tokenizer paths; the mistral paths are
    # currently not passed anywhere. Confirm whether that is intended.
    llm_provider = TensorRTLLMEngine()
    llm_process = multiprocessing.Process(
        target=llm_provider.run,
        args=(
            args.phi_tensorrt_path,
            args.phi_tokenizer_path,
            transcription_queue,
            llm_queue,
            audio_queue,
        ),
    )
    llm_process.start()

    # Text-to-speech service.
    tts_runner = WhisperSpeechTTS()
    tts_process = multiprocessing.Process(
        target=tts_runner.run,
        args=("0.0.0.0", 8888, audio_queue),
    )
    tts_process.start()

    # Block until every stage has exited.
    llm_process.join()
    whisper_process.join()
    tts_process.join()