import io
import tempfile
import time
import traceback
from dataclasses import dataclass, field
from typing import Any

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import spaces
import torch
import xxhash
from datasets import Audio
from pydub import AudioSegment
from transformers import AutoModel

from utils.vad import VadOptions, collect_chunks, get_speech_timestamps

# Model loading is guarded by gr.NO_RELOAD (presumably so the checkpoint is
# not loaded again when Gradio hot-reloads this file -- confirm against the
# Gradio docs).
if gr.NO_RELOAD:
    diva_model = AutoModel.from_pretrained(
        "WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True
    )

# Used by diva_audio() to bring incoming audio to 16 kHz.
resampler = Audio(sampling_rate=16_000)


@spaces.GPU(duration=20)
@torch.no_grad
def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
    """Stream responses from the DiVA model for one audio clip.

    Args:
        audio_input: ``(sampling_rate, samples)`` tuple; ``samples`` is a
            numpy array.
        do_sample: Forwarded to ``diva_model.generate_stream``.
        temperature: Accepted for interface compatibility; currently not
            forwarded to the model.
        prev_outs: Forwarded as ``init_outputs`` to ``generate_stream``.

    Yields:
        Whatever ``diva_model.generate_stream(..., return_outputs=True)``
        yields (``response()`` unpacks each item as ``(resp, outs)``).
    """
    sr, y = audio_input
    y = y.astype(np.float32)

    # Peak-normalise, but leave an all-zero (silent) clip untouched: dividing
    # by a zero peak would fill the array with NaN.
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak

    # Round-trip through the datasets Audio feature configured for 16 kHz.
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    yield from diva_model.generate_stream(
        a["array"],
        None,
        do_sample=do_sample,
        max_new_tokens=256,
        init_outputs=prev_outs,
        return_outputs=True,
    )


def run_vad(ori_audio, sr, duration):
    """Run voice-activity detection and keep only the speech portions.

    Args:
        ori_audio: numpy array of samples. It is scaled by 1/32768, so it is
            assumed to be in the int16 PCM range -- TODO confirm with callers.
        sr: Sampling rate of ``ori_audio``.
        duration: Duration in seconds supplied by the caller; when it is
            below 1 the VAD is skipped.

    Returns:
        ``(duration_after_vad, audio, elapsed_seconds)``. On success
        ``audio`` is the speech-only signal as int16 bytes at the original
        sampling rate. When VAD is skipped or fails, the result is
        ``(-1, ori_audio, elapsed_seconds)`` with the input returned as-is.
    """
    _st = time.time()
    try:
        if duration < 1:
            return -1, ori_audio, round(time.time() - _st, 4)

        audio = ori_audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = VadOptions()
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio

        # Clip before the int16 cast: a sample at (or, after resampling,
        # slightly above) +1.0 scales to >= 32768 and would otherwise wrap
        # around to a large negative value.
        vad_audio = np.clip(np.round(vad_audio * 32768.0), -32768, 32767).astype(
            np.int16
        )
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        # Best-effort: log the failure and hand the input back unchanged.
        msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)


def warm_up():
    """Call run_vad() once on dummy input and print how long it took."""
    # 2048 float64 samples of value 1.0; duration=10 is passed so the
    # short-audio early return in run_vad() is not taken.
    frames = np.ones(2048)
    dur, frames, tcost = run_vad(frames, 16000, 10)
    print(f"warm up done, time_cost: {tcost:.3f} s")


warm_up()


@dataclass
class AppState:
    """Per-session conversation state held in ``gr.State``."""

    # Audio accumulated so far by process_audio(); None until the first chunk.
    stream: np.ndarray | None = None
    # Sampling rate taken from the first chunk.
    sampling_rate: int = 0
    # Result of the most recent determine_pause() call.
    pause_detected: bool = False
    # Set by determine_pause() once enough speech has been seen.
    started_talking: bool = False
    # Set by the "Stop Conversation" button; checked by start_recording_user().
    stopped: bool = False
    # Chat history in the "messages" format shown by the Chatbot component.
    conversation: list = field(default_factory=list)
    # Last outputs from diva_audio(), passed back in as prev_outs next turn.
    model_outs: Any = None


def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Take in the stream, determine if a pause happened.

    VAD runs on the last two seconds of ``audio``. The first time more than
    0.25 s of speech is found, ``state.started_talking`` is set and ``False``
    is returned. Otherwise the result is ``True`` when the VAD duration is
    below 0.5 s (this includes the ``-1`` that run_vad() reports when it
    skips or fails).
    """
    temp_audio = audio[-2 * sampling_rate :]
    # Duration of the whole stream, not just the two-second window.
    duration = len(audio) / sampling_rate
    dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate, duration)

    if dur_vad > 0.25 and not state.started_talking:
        print("started talking")
        state.started_talking = True
        return False

    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")

    return dur_vad < 0.5


def process_audio(audio: tuple, state: AppState):
    """Append one streamed microphone chunk to the state and check for a pause.

    Args:
        audio: ``(sampling_rate, samples)`` from the Audio component, or
            ``None``.
        state: Session state; mutated in place.

    Returns:
        ``(audio_update, state)``. ``audio_update`` is
        ``gr.Audio(recording=False)`` when a pause is detected after the user
        started talking, otherwise ``None``.
    """
    # Check for a missing chunk first: indexing ``audio`` before this check
    # raised TypeError when the very first chunk was None.
    if audio is None or audio[1] is None:
        return None, state

    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    state.pause_detected = determine_pause(state.stream, state.sampling_rate, state)

    if state.pause_detected and state.started_talking:
        return gr.Audio(recording=False), state
    return None, state


def response(state: AppState):
    """Generate the assistant reply for the recorded audio, streaming updates.

    Writes the recorded audio to a wav file under ``/tmp``, appends it to the
    conversation as the user turn, then updates the assistant turn with each
    item streamed from diva_audio().

    Yields:
        ``(state, conversation)`` pairs. The final pair carries a fresh
        AppState that keeps only the conversation and the model outputs.
    """
    if not state.pause_detected and not state.started_talking:
        # Nothing to respond to. This is a generator, so a returned value
        # would be discarded; end without yielding any update.
        return

    file_name = f"/tmp/{xxhash.xxh32(bytes(state.stream)).hexdigest()}.wav"

    sf.write(file_name, state.stream, state.sampling_rate, format="wav")

    state.conversation.append(
        {"role": "user", "content": {"path": file_name, "mime_type": "audio/wav"}}
    )

    # Fall back to the previous outputs so the final yield below cannot hit
    # an unbound name if the model stream produces nothing.
    outs = state.model_outs
    start = False
    for resp, outs in diva_audio(
        (state.sampling_rate, state.stream), prev_outs=state.model_outs
    ):
        if not start:
            state.conversation.append({"role": "assistant", "content": resp})
            start = True
        else:
            state.conversation[-1]["content"] = resp
        yield state, state.conversation

    yield AppState(conversation=state.conversation, model_outs=outs), state.conversation


def start_recording_user(state: AppState):
    """Return an update that restarts recording unless the state is stopped."""
    if not state.stopped:
        return gr.Audio(recording=True)
    return None


theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        c100="#82000019",
        c200="#82000033",
        c300="#8200004c",
        c400="#82000066",
        c50="#8200007f",
        c500="#8200007f",
        c600="#82000099",
        c700="#820000b2",
        c800="#820000cc",
        c900="#820000e5",
        c950="#820000f2",
    ),
    secondary_hue="rose",
    neutral_hue="stone",
)

with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
    with gr.Row():
        chatbot = gr.Chatbot(label="Conversation", type="messages")
    state = gr.State(value=AppState())

    # Feed microphone chunks into process_audio() while recording.
    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [input_audio, state],
        stream_every=0.25,
        time_limit=10,
    )
    # When recording stops, generate the reply, then try to resume recording.
    respond = input_audio.stop_recording(response, [state], [state, chatbot])
    respond.then(start_recording_user, [state], [input_audio])

    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None,
        [state, input_audio],
        cancels=[respond, stream],
    )

if __name__ == "__main__":
    demo.launch()