# Spaces: Running on Zero
import io | |
import os | |
import time | |
import traceback | |
from dataclasses import dataclass, field | |
import gradio as gr | |
import librosa | |
import numpy as np | |
import pvorca | |
import soundfile as sf | |
import spaces | |
import torch | |
import xxhash | |
from datasets import Audio | |
from transformers import AutoModel | |
from transformers.modeling_outputs import CausalLMOutputWithPast | |
# Picovoice Orca text-to-speech client; its access key is read from the
# ORCA_KEY environment variable (None when unset).
orca = pvorca.create(access_key=os.getenv("ORCA_KEY"))

# Expensive model loading is skipped when gr.NO_RELOAD is false (Gradio's
# reload mode), so hot reloads do not re-download / re-instantiate it.
if gr.NO_RELOAD:
    # trust_remote_code lets the checkpoint ship its own model class; the
    # code below relies on its non-standard `generate_stream` method.
    diva_model = AutoModel.from_pretrained(
        "WillHeld/DiVA-llama-3-v0-8b",
        trust_remote_code=True,
    )

# A `datasets` Audio feature used purely as a resampler to 16 kHz.
resampler = Audio(sampling_rate=16000)
def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None): | |
sr, y = audio_input | |
x = xxhash.xxh32(bytes(y)).hexdigest() | |
y = y.astype(np.float32) | |
y /= np.max(np.abs(y)) | |
a = resampler.decode_example( | |
resampler.encode_example({"array": y, "sampling_rate": sr}) | |
) | |
yield from diva_model.generate_stream( | |
a["array"], | |
( | |
"Your name is DiVA, which stands for Distilled Voice Assistant. You were trained with early-fusion training to merge OpenAI's Whisper and Meta AI's Llama 3 8B to provide end-to-end voice processing. You should give brief and helpful answers, in a conversational style. The user is talking to you with their voice and you are responding with text." | |
if prev_outs == None | |
else None | |
), | |
do_sample=do_sample, | |
max_new_tokens=256, | |
init_outputs=prev_outs, | |
return_outputs=True, | |
) | |
@dataclass
class AppState:
    """Per-session state held in the app's ``gr.State`` component.

    Declared as a dataclass so that the ``field(default_factory=...)``
    default works and the keyword constructor
    ``AppState(conversation=..., model_outs=...)`` is available.
    """

    # Chat history in the Chatbot "messages" format: dicts with "role" and
    # "content" keys. default_factory gives every instance its own list.
    conversation: list = field(default_factory=list)
    # Not read by the visible code. NOTE(review): possibly vestigial.
    stopped: bool = False
    # Outputs cached from the previous turn so generation can resume from
    # them: nested tuples of numpy arrays on ZeroGPU, otherwise the model's
    # output object. None before the first completed turn.
    model_outs: object = None
def process_audio(audio: tuple, state: AppState):
    """Hand the recording and session state back unchanged.

    Bound to ``input_audio.start_recording`` with the same components as
    inputs and outputs, so it leaves both untouched.
    """
    passthrough = (audio, state)
    return passthrough
def _pcm_to_mp3(pcm):
    """Encode Orca PCM output as MP3 bytes, or return None when *pcm* is None."""
    if pcm is None:
        return None
    with io.BytesIO() as mp3_io:
        sf.write(
            mp3_io, np.asarray(pcm).astype(np.int16), orca.sample_rate, format="mp3"
        )
        return mp3_io.getvalue()


def _take_prev_outs(state):
    """Detach the previous turn's cached outputs from *state* for reuse.

    On ZeroGPU the cache is held as nested tuples of CPU arrays; it is moved
    onto the GPU and wrapped in ``CausalLMOutputWithPast`` (or becomes None
    when empty). Elsewhere the stored object is returned as-is.
    ``state.model_outs`` is cleared afterwards in both cases.
    """
    cached = state.model_outs
    if spaces.config.Config.zero_gpu:
        cached = (
            CausalLMOutputWithPast(
                past_key_values=tuple(
                    tuple(torch.tensor(vec).cuda() for vec in entry)
                    for entry in cached
                )
            )
            if cached
            else None
        )
    state.model_outs = None
    return cached


def _pack_outs_for_state(outs):
    """Prepare the final generation outputs for caching in ``AppState``.

    Drops ``logits`` and ``hidden_states``. On ZeroGPU only the past
    key/values are kept, copied to CPU numpy arrays (presumably because the
    GPU is not held between requests).
    """
    del outs.logits
    del outs.hidden_states
    if spaces.config.Config.zero_gpu:
        return tuple(
            tuple(vec.cpu().numpy() for vec in entry)
            for entry in outs.past_key_values
        )
    return outs


def response(state: AppState, audio: tuple):
    """Answer one recorded utterance with streamed text and synthesized speech.

    Generator bound to ``input_audio.stop_recording``. Every yield is
    ``(state, chatbot_messages, audio_chunk)`` for the outputs
    ``[state, chatbot, output_audio]``. ``audio_chunk`` is MP3 bytes or None.

    Args:
        state: Session state; its conversation is extended in place.
        audio: ``(sample_rate, samples)`` from the numpy-typed input audio
            component, or a falsy value when nothing was recorded.

    NOTE(review): the ``.cuda()`` calls on the ZeroGPU path normally need a
    ``@spaces.GPU``-decorated entry point; none is visible here — confirm.
    """
    if not audio:
        # Nothing recorded: stop without yielding, leaving every output as is.
        # A generator's return value never reaches Gradio's outputs, so
        # returning a fresh AppState here would have no effect.
        return

    sample_rate, samples = audio
    # Save the user's clip so the chatbot can display it. The content hash
    # names the file.
    file_name = f"/tmp/{xxhash.xxh32(bytes(samples)).hexdigest()}.wav"
    sf.write(file_name, samples, sample_rate, format="wav")
    state.conversation.append(
        {"role": "user", "content": {"path": file_name, "mime_type": "audio/wav"}}
    )
    state.conversation.append({"role": "assistant", "content": ""})
    yield state, state.conversation, None

    prev_outs = _take_prev_outs(state)
    outs = None
    stream = orca.stream_open()
    try:
        for resp, outs in diva_audio((sample_rate, samples), prev_outs=prev_outs):
            prev_resp = state.conversation[-1]["content"]
            state.conversation[-1]["content"] = resp
            # resp appears to be the cumulative text so far; only the newly
            # added suffix is sent to the TTS stream.
            pcm = stream.synthesize(resp[len(prev_resp):])
            yield state, state.conversation, _pcm_to_mp3(pcm)
        # Emit whatever audio Orca is still buffering.
        audio_chunk = _pcm_to_mp3(stream.flush())
    finally:
        # Also runs on errors and when Gradio cancels this generator
        # mid-stream, so the Orca stream is always released.
        stream.close()

    model_outs = _pack_outs_for_state(outs) if outs is not None else None
    yield (
        AppState(conversation=state.conversation, model_outs=model_outs),
        state.conversation,
        audio_chunk,
    )
def start_recording_user(state: AppState):
    """Clear the input-audio component.

    Bound to ``output_audio.stop`` with ``[input_audio]`` as its output.
    ``state`` is accepted to match the event's inputs and is not used.
    """
    cleared_value = None
    return cleared_value
# Soft theme with a custom primary palette. Every shade is the same dark red
# (#820000) with an alpha channel rising from c100 (0x19) to c950 (0xf2).
# NOTE(review): c50 reuses c500's alpha (0x7f) instead of continuing the ramp
# below c100 — confirm whether a lighter value was intended.
theme = gr.themes.Soft(
    neutral_hue="stone",
    secondary_hue="rose",
    primary_hue=gr.themes.Color(
        **{
            "c50": "#8200007f",
            "c100": "#82000019",
            "c200": "#82000033",
            "c300": "#8200004c",
            "c400": "#82000066",
            "c500": "#8200007f",
            "c600": "#82000099",
            "c700": "#820000b2",
            "c800": "#820000cc",
            "c900": "#820000e5",
            "c950": "#820000f2",
        }
    ),
)
# Page-load script passed to gr.Blocks(js=...). It loads onnxruntime-web and
# then @ricky0123/vad-web from jsDelivr. It then starts a MicVAD whose
# callbacks click the recorder's ".record-button" when speech starts and its
# ".stop-button" when speech ends, so turns are taken hands-free.
# The package@version specifiers in the CDN URLs follow the vad-web README's
# script-tag instructions — confirm they match the intended versions.
js = """
async function main() {
  const script1 = document.createElement("script");
  script1.src = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.14.0/dist/ort.js";
  document.head.appendChild(script1)
  const script2 = document.createElement("script");
  script2.onload = async () =>  {
    console.log("vad loaded") ;
    var record = document.querySelector('.record-button');
    // The button may not be rendered; an unguarded null here would throw
    // and abort this handler before the VAD is started.
    if (record != null) {
      record.textContent = "Just Start Talking!"
      record.style = "width: 11vw"
    }
    const myvad = await vad.MicVAD.new({
      onSpeechStart: () => {
        var record = document.querySelector('.record-button');
        if (record != null) {
          console.log(record);
          record.click();
        }
      },
      onSpeechEnd: (audio) => {
        var stop = document.querySelector('.stop-button');
        if (stop != null) {
          console.log(stop);
          stop.click();
        }
      }
    })
    myvad.start()
  }
  script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/vad-web@0.0.7/dist/bundle.min.js";
  script1.onload = () =>  {
    console.log("onnx loaded")
    document.head.appendChild(script2)
  };
}
"""
# Client-side callback chained after output playback stops. It restores the
# record button's prompt text and width, skipping the update when the button
# is not currently in the page instead of throwing on null.
js_reset = """
() => {
  var record = document.querySelector('.record-button');
  if (record != null) {
    record.textContent = "Just Start Talking!"
    record.style = "width: 11vw"
  }
}
"""
# UI layout and event wiring. The injected `js` auto-clicks the recorder's
# record/stop buttons on detected speech; stopping a recording runs
# `response`, which streams text into the chatbot and speech into
# output_audio.
with gr.Blocks(theme=theme, js=js) as demo:
    with gr.Row():
        input_audio = gr.Audio(
            label="Input Audio",
            sources=["microphone"],
            type="numpy",
            streaming=False,
        )
    with gr.Row():
        chatbot = gr.Chatbot(label="Conversation", type="messages")
    with gr.Row():
        output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
    state = gr.State(value=AppState())

    stream = input_audio.start_recording(
        process_audio,
        [input_audio, state],
        [input_audio, state],
    )
    respond = input_audio.stop_recording(
        response, [state, input_audio], [state, chatbot, output_audio]
    )
    # After playback ends, clear the input and reset the record button text.
    restart = output_audio.stop(start_recording_user, [state], [input_audio]).then(
        lambda state: state, state, state, js=js_reset
    )

    cancel = gr.Button("Restart Conversation", variant="stop")
    # Reset the session, stop any recording, and clear the transcript so the
    # chatbot does not keep showing the previous conversation.
    cancel.click(
        lambda: (AppState(), gr.Audio(recording=False), []),
        None,
        [state, input_audio, chatbot],
        cancels=[respond, restart],
    )

if __name__ == "__main__":
    demo.launch()