Spaces:
Running
on
Zero
Running
on
Zero
import io | |
import os | |
import time | |
import traceback | |
from dataclasses import dataclass, field | |
import gradio as gr | |
import librosa | |
import numpy as np | |
import pvorca | |
import soundfile as sf | |
import spaces | |
import torch | |
import xxhash | |
from datasets import Audio | |
from transformers import AutoModel | |
from transformers.modeling_outputs import CausalLMOutputWithPast | |
# Picovoice Orca text-to-speech engine shared by every request.
# The access key is read from the ORCA_KEY environment variable (None when it
# is unset) and the voice model is loaded from the bundled parameter file.
orca = pvorca.create(
    model_path="./static/orca_params_masculine.pv",
    access_key=os.getenv("ORCA_KEY"),
)
# Placeholder text shown as the assistant message until the first model output
# replaces it: the decorative bar three times, separated by the word "loading",
# followed by one trailing note.
_LOADER_BAR = "♫♪.ılılıll|̲̅̅●̲̅̅|̲̅̅=̲̅̅|̲̅̅●̲̅̅|llılılı.♫♪"
LOADER_STR = "loading".join([_LOADER_BAR] * 3) + "♫"
if gr.NO_RELOAD:
    # Heavy, one-time setup. Per the Gradio docs, code guarded by
    # gr.NO_RELOAD is not re-executed when reload mode re-runs the file.

    # `datasets.Audio` feature used to resample incoming audio to 16 kHz.
    resampler = Audio(sampling_rate=16000)

    # DiVA checkpoint, pinned to an exact revision because loading it
    # executes code shipped with the model repo (trust_remote_code=True).
    diva_model = AutoModel.from_pretrained(
        "WillHeld/DiVA-llama-3-v0-8b",
        revision="79fec82089dd6cde6eca15ef6f48dd99993e1a90",
        trust_remote_code=True,
    )
# NOTE(review): no decorator (e.g. a GPU / no-grad wrapper) is visible on this
# function in this view of the file — confirm none was lost upstream.
def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
    """Stream DiVA's text reply for one spoken utterance.

    Args:
        audio_input: ``(sample_rate, samples)`` pair where ``samples`` is a
            NumPy array.
        do_sample: Forwarded to ``diva_model.generate_stream``.
        temperature: Accepted for interface compatibility; this function does
            not forward it to the model.
        prev_outs: Model outputs of the previous turn, forwarded as
            ``init_outputs``. When ``None`` the conversation is new and the
            system prompt is sent; otherwise no prompt is passed.

    Yields:
        Whatever ``diva_model.generate_stream`` yields with
        ``return_outputs=True`` (callers unpack it as ``(text, outputs)``).
    """
    sr, y = audio_input
    # astype() copies, so the caller's array is never modified below.
    y = y.astype(np.float32)

    # Peak-normalise, but skip silent or empty input: dividing by a zero peak
    # would fill the signal with NaNs.
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y /= peak

    # Round-trip through the datasets Audio feature to resample to its rate.
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )

    system_prompt = (
        "Your name is DiVA, which stands for Distilled Voice Assistant. You were trained with early-fusion training to merge OpenAI's Whisper and Meta AI's Llama 3 8B to provide end-to-end voice processing. You should respond in a conversational style. The user is talking to you with their voice and you are responding with text. Use fewer than 20 words."
        if prev_outs is None
        else None
    )
    yield from diva_model.generate_stream(
        a["array"],
        system_prompt,
        do_sample=do_sample,
        max_new_tokens=256,
        init_outputs=prev_outs,
        return_outputs=True,
    )
@dataclass
class AppState:
    """Per-session state carried between Gradio events.

    Must be a dataclass: the fields rely on ``field(default_factory=...)`` and
    instances are built with keyword arguments
    (``AppState(conversation=..., model_outs=...)``).
    """

    # Chat history as a list of message dicts with "role" and "content" keys.
    conversation: list = field(default_factory=list)
    # Flag that is not read anywhere in the visible code.
    stopped: bool = False
    # Model outputs kept from the previous turn; None for a fresh conversation.
    model_outs: object = None
def process_audio(audio: tuple, state: AppState):
    """Hand the recorded audio and the session state back unchanged.

    Returns:
        The ``(audio, state)`` pair exactly as received.
    """
    passthrough = (audio, state)
    return passthrough
def _encode_mp3(samples, sample_rate):
    """Encode 16-bit PCM ``samples`` as MP3 and return the encoded bytes."""
    with io.BytesIO() as mp3_io:
        sf.write(
            mp3_io,
            np.asarray(samples).astype(np.int16),
            sample_rate,
            format="mp3",
        )
        return mp3_io.getvalue()


def response(state: AppState, audio: tuple):
    """Generate DiVA's reply to a finished recording, streaming text and speech.

    Args:
        state: Current session state (conversation history and the model
            outputs kept from the previous turn).
        audio: ``(sample_rate, samples)`` recording, or a falsy value when
            nothing was recorded.

    Yields:
        ``(state, conversation, audio_chunk)`` triples. ``audio_chunk`` is
        MP3-encoded bytes of synthesized speech, or ``None`` when no audio is
        ready yet. The last triple carries a new ``AppState`` holding this
        turn's model outputs.
    """
    if not audio:
        # NOTE(review): this is a generator, so the value returned here is not
        # yielded as an output — confirm that is the intended behaviour.
        return AppState()

    # Persist the recording so the chat history can reference it by path.
    file_name = f"/tmp/{xxhash.xxh32(bytes(audio[1])).hexdigest()}.wav"
    sf.write(file_name, audio[1], audio[0], format="wav")
    state.conversation.append(
        {"role": "user", "content": {"path": file_name, "mime_type": "audio/wav"}}
    )

    if state.model_outs is None:
        gr.Warning(
            "The first response might take a second to generate as DiVA is loaded from Disk to the ZeroGPU!"
        )

    # Show a placeholder assistant message until the first tokens arrive.
    state.conversation.append(
        {
            "role": "assistant",
            "content": LOADER_STR,
        }
    )
    yield state, state.conversation, None

    if spaces.config.Config.zero_gpu:
        # On ZeroGPU the cache is kept in the state as nested tuples of NumPy
        # arrays (see the conversion after generation); move it back to CUDA.
        if state.model_outs is not None:
            state.model_outs = tuple(
                tuple(torch.tensor(vec).cuda() for vec in tup)
                for tup in state.model_outs
            )
        causal_outs = (
            CausalLMOutputWithPast(past_key_values=state.model_outs)
            if state.model_outs
            else None
        )
    else:
        causal_outs = state.model_outs
    state.model_outs = None

    stream = orca.stream_open()
    buff = []
    try:
        for resp, outs in diva_audio((audio[0], audio[1]), prev_outs=causal_outs):
            prev_resp = state.conversation[-1]["content"]
            if prev_resp == LOADER_STR:
                prev_resp = ""
            state.conversation[-1]["content"] = resp

            # Only the newly generated suffix is sent to the synthesizer.
            pcm = stream.synthesize(resp[len(prev_resp) :])
            audio_chunk = None
            if pcm is not None:
                buff.extend(pcm)
                # Once more than two seconds are buffered, emit the first
                # second and keep the remainder for the next chunk.
                if len(buff) > (orca.sample_rate * 2):
                    audio_chunk = _encode_mp3(
                        buff[: orca.sample_rate], orca.sample_rate
                    )
                    buff = buff[orca.sample_rate :]
            yield state, state.conversation, audio_chunk

        pcm = stream.flush()
    finally:
        # Release the synthesizer stream even if generation fails or the
        # event is cancelled mid-stream.
        stream.close()

    # Keep only the key/value cache from the final outputs.
    del outs.logits
    del outs.hidden_states
    if spaces.config.Config.zero_gpu:
        outs = tuple(
            tuple(vec.cpu().numpy() for vec in tup) for tup in outs.past_key_values
        )

    # Emit everything still buffered. flush() may return None while earlier
    # samples are still waiting in `buff`; those must not be dropped.
    if pcm is not None:
        buff.extend(pcm)
    audio_chunk = _encode_mp3(buff, orca.sample_rate) if buff else None

    yield (
        AppState(conversation=state.conversation, model_outs=outs),
        state.conversation,
        audio_chunk,
    )
def start_recording_user(state: AppState):
    """Return ``None`` regardless of ``state``.

    The argument is accepted but never read.
    """
    cleared_value = None
    return cleared_value
# Primary palette: the same base colour (#820000) at increasing alpha values.
# NOTE(review): c50 carries the same value as c500 — confirm this is intended.
_primary_palette = gr.themes.Color(
    c50="#8200007f",
    c100="#82000019",
    c200="#82000033",
    c300="#8200004c",
    c400="#82000066",
    c500="#8200007f",
    c600="#82000099",
    c700="#820000b2",
    c800="#820000cc",
    c900="#820000e5",
    c950="#820000f2",
)

# Soft base theme with the custom primary palette and stock secondary/neutral hues.
theme = gr.themes.Soft(
    primary_hue=_primary_palette,
    secondary_hue="rose",
    neutral_hue="stone",
)
# Page-load script passed to gr.Blocks: loads onnxruntime-web, then the
# voice-activity-detection bundle, and uses the detector to click the record
# button when speech starts (unless the output player is playing) and the stop
# button when speech ends.
# NOTE(review): the package@version specifiers in both CDN URLs were restored
# after being mangled into "[email protected]"; confirm the pinned versions.
js = """
async function main() {
    const script1 = document.createElement("script");
    script1.src = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.14.0/dist/ort.js";
    document.head.appendChild(script1)
    const script2 = document.createElement("script");
    script2.onload = async () => {
        console.log("vad loaded") ;
        var record = document.querySelector('.record-button');
        record.textContent = "Just Start Talking!"
        record.style = "width: fit-content; padding-right: 0.5vw;"
        const myvad = await vad.MicVAD.new({
            onSpeechStart: () => {
                var record = document.querySelector('.record-button');
                var player = document.getElementById("streaming_out").querySelector(".standard-player")
                if (record != null && (player == null || player.paused)) {
                    console.log(record);
                    record.click();
                }
            },
            onSpeechEnd: (audio) => {
                var stop = document.querySelector('.stop-button');
                if (stop != null) {
                    console.log(stop);
                    stop.click();
                }
            }
        })
        myvad.start()
    }
    script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/vad-web@0.0.7/dist/bundle.min.js";
    script1.onload = () => {
        console.log("onnx loaded")
        document.head.appendChild(script2)
    };
}
"""
# Client-side snippet run after each response (passed as `js=` to the final
# chained event): re-applies the custom label and inline style to the record
# button.
js_reset = """
() => {
var record = document.querySelector('.record-button');
record.textContent = "Just Start Talking!"
record.style = "width: fit-content; padding-right: 0.5vw;"
}
"""
with gr.Blocks(theme=theme, js=js) as demo:
    # --- Layout: microphone input, chat transcript, streamed speech output ---
    with gr.Row():
        input_audio = gr.Audio(
            type="numpy",
            sources=["microphone"],
            streaming=False,
            label="Input Audio",
            waveform_options=gr.WaveformOptions(waveform_color="#B83A4B"),
        )
    with gr.Row():
        chatbot = gr.Chatbot(type="messages", label="Conversation")
    with gr.Row(max_height="50vh"):
        output_audio = gr.Audio(
            elem_id="streaming_out",
            label="Output Audio",
            streaming=True,
            autoplay=True,
            visible=True,
        )

    # Per-session state, seeded with an empty AppState.
    state = gr.State(value=AppState())

    # --- Event wiring ---
    # Recording started: pass audio and state straight through.
    stream = input_audio.start_recording(
        fn=process_audio,
        inputs=[input_audio, state],
        outputs=[input_audio, state],
    )

    # Recording stopped: stream the reply into state, transcript and player.
    respond = input_audio.stop_recording(
        fn=response,
        inputs=[state, input_audio],
        outputs=[state, chatbot, output_audio],
    )

    # After the reply: clear the input component, then run the client-side
    # reset snippet while passing the state through unchanged.
    restart = respond.then(
        fn=start_recording_user,
        inputs=[state],
        outputs=[input_audio],
    ).then(
        lambda current: current,
        inputs=state,
        outputs=state,
        js=js_reset,
    )

    # Restart button: cancel any in-flight reply and reset the session.
    cancel = gr.Button(value="Restart Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(), gr.Audio(recording=False)),
        inputs=None,
        outputs=[state, input_audio],
        cancels=[respond, restart],
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()