"""
TODO:
+ [x] Load Configuration
+ [ ] Checking
+ [ ] Better saving directory
"""
import numpy as np
from pathlib import Path
import jiwer
import pdb
import torch.nn as nn
import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import yaml
from transformers import pipeline
import librosa
import librosa.display
import matplotlib.pyplot as plt
from local.convert_metrics import nat2avaMOS, WER2INTELI

# Google cloud service
from googleapiclient.discovery import build
from google.oauth2 import service_account
from googleapiclient.http import MediaFileUpload
import datetime
# JSON credentials file from the Google Cloud console
credentials_file = "./src/peerless-window-254907-b386b71c0d99.json"
# Google Drive API version
api_version = 'v3'
# Create the Drive service object
credentials = service_account.Credentials.from_service_account_file(
    credentials_file, scopes=['https://www.googleapis.com/auth/drive'])
service = build('drive', api_version, credentials=credentials)
# local import
import sys

sys.path.append("src")
import lightning_module

# Load automos
# config_yaml = sys.argv[1]
config_yaml = "config/Arthur.yaml"
with open(config_yaml, "r") as f:
    # pdb.set_trace()
    try:
        config = yaml.safe_load(f)
    except yaml.YAMLError:
        print("Config file loading error")
        exit()
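# Sketch of the expected config keys, inferred from their use in this script
# (illustrative values, not the real config):
#   ref_txt: data/ref.txt
#   ref_feature: data/ref.csv
#   ref_wavs: data/wavs
#   exp_id: null          # falls back to the YAML file stem
#   thre: {AUTOMOS: 3.0, WER: 0.3, minppm: 100, maxppm: 100}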
# Auto load examples
with open(config['ref_txt'], "r") as f:
    refs = f.readlines()
# refs = np.loadtxt(config["ref_txt"], delimiter="\n", dtype="str")
refs_ids = [x.split()[0] for x in refs]
refs_txt = [" ".join(x.split()[1:]) for x in refs]
ref_feature = np.loadtxt(config["ref_feature"], delimiter=",", dtype="str")
ref_wavs = [str(x) for x in sorted(Path(config["ref_wavs"]).glob("**/*.wav"))]
dummy_wavs = [None] * len(ref_wavs)
refs_ppm = np.array(ref_feature[:, -1][1:], dtype="str")
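# ref_txt holds one "<utt_id> <transcript>" line per utterance; the slice
# ref_feature[:, -1][1:] skips the CSV header row and keeps the last (PPM) column.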
# Set up interface
# remove dummy wavs, use the same ref_wavs for eval wavs
print("Preparing Examples")
examples = [
    [w, w_, i, x, y] for w, w_, i, x, y in zip(ref_wavs, ref_wavs, refs_ids, refs_txt, refs_ppm)
]
p = pipeline("automatic-speech-recognition")

# WER part
transformation = jiwer.Compose(
    [
        jiwer.RemovePunctuation(),
        jiwer.ToLowerCase(),
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        jiwer.RemoveMultipleSpaces(),
        jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
    ]
)
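# Sketch of what this normalization does, assuming jiwer's Compose is called on a string:
# transformation("Hello, World!") should yield [["hello", "world"]]; it is applied to
# both reference and hypothesis before scoring.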
# WPM part
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
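# Note: this checkpoint decodes speech into espeak-style phone labels rather than
# words, which is what the phonemes-per-minute (PPM) measure below relies on.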
class ChangeSampleRate(nn.Module):
    """Resample a waveform by linear interpolation between neighboring samples."""

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional source index for every output sample
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        # Linearly interpolate between the two neighboring samples
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + (
            round_up * indices.fmod(1.0).unsqueeze(0)
        )
        return output
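# Usage sketch: downsample a 22.05 kHz mono clip to the 16 kHz the models expect.
# csr = ChangeSampleRate(22050, 16000)
# wav_16k = csr(torch.randn(1, 22050))  # -> shape (1, 16000)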
# MOS model
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "src/epoch=3-step=7459.ckpt"
).eval()
# Get Speech Interval
def get_speech_interval(signal, db):
    """Split signal into voiced intervals (above top_db) and the pauses between them."""
    audio_interv = librosa.effects.split(signal, top_db=db)
    pause_end = [x[0] for x in audio_interv[1:]]
    pause_start = [x[1] for x in audio_interv[0:-1]]
    pause_interv = [[x, y] for x, y in zip(pause_start, pause_end)]
    return audio_interv, pause_interv
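# e.g. voiced chunks [[0, 8000], [12000, 20000]] yield a single pause interval
# [[8000, 12000]] (all indices in samples).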
# plot UV
def plot_UV(signal, audio_interv, sr):
    fig, ax = plt.subplots(nrows=2, sharex=True)
    librosa.display.waveshow(signal, sr=sr, ax=ax[0])
    uv_flag = np.zeros(len(signal))
    for i in audio_interv:
        uv_flag[i[0] : i[1]] = 1
    ax[1].plot(np.arange(len(signal)) / sr, uv_flag, "r")
    ax[1].set_ylim([-0.1, 1.1])
    return fig
# Evaluation model
def calc_mos(ref_audio_path, audio_path, id, ref, pre_ppm, fig=None):
    if audio_path is None:
        audio_path = ref_audio_path
        print("using ref audio as eval audio since it's empty")
    wav, sr = torchaudio.load(audio_path)
    if wav.shape[0] != 1:
        # keep only the first channel, preserving the (1, T) shape
        wav = wav[0:1, :]
    osr = 16000
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    # ASR
    trans = jiwer.ToLowerCase()(p(audio_path)["text"])
    # WER
    wer = jiwer.wer(
        ref,
        trans,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )
    # round to 1 decimal
    wer = np.round(wer, 1)
    # Convert WER to an intelligibility score
    INTELI_score = WER2INTELI(wer * 100)
    # MOS
    batch = {
        "wav": out_wavs,
        "domains": torch.tensor([0]),
        "judge_id": torch.tensor([288]),
    }
    with torch.no_grad():
        output = model(batch)
    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
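    # The checkpoint predicts a normalized score roughly in [-1, 1]; "* 2 + 3"
    # rescales it onto the 1-5 MOS range (assumed UTMOS-style convention).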
    # round to 1 decimal
    predic_mos = np.round(predic_mos, 1)
    # MOS to AVA MOS
    AVA_MOS = nat2avaMOS(predic_mos)
    # Phonemes per minute (PPM)
    with torch.no_grad():
        logits = phoneme_model(out_wavs).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    lst_phonemes = phone_transcription[0].split(" ")
    # VAD for pause detection
    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
    # pdb.set_trace()
    a_h, p_h = get_speech_interval(wav_vad.numpy(), db=40)
    # print(a_h)
    # print(len(a_h))
    fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
    ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
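    # e.g. 150 decoded phones over a 60-second VAD-trimmed signal -> 150 PPM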
    ppm = np.round(ppm, 1)

    error_msg = "!!! ERROR MESSAGE !!!\n"
    if audio_path == ref_audio_path or audio_path is None:
        error_msg += "ERROR: Recording failed, please start from the beginning again."
        return (
            fig_h,
            predic_mos,
            trans,
            wer,
            phone_transcription,
            ppm,
            error_msg,
            None,  # no Google Drive id: nothing was saved
        )
    # if ppm >= float(pre_ppm) + float(config["thre"]["maxppm"]):
    #     error_msg += "ERROR: Please speak slower.\n"
    # elif ppm <= float(pre_ppm) - float(config["thre"]["minppm"]):
    #     error_msg += "ERROR: Please speak faster.\n"
    if predic_mos <= float(config["thre"]["AUTOMOS"]):
        error_msg += "ERROR: Naturalness is too low, please try again.\n"
    elif wer >= float(config["thre"]["WER"]):
        error_msg += "ERROR: Intelligibility is too low, please try again.\n"
    else:
        error_msg = (
            "GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample."
        )

    # Google Drive saving
    saved_google_id = None
    if error_msg.startswith("GOOD JOB!"):
        saved_google_id = click_google_saving(audio_path)
        # TODO: add saved_google_id to the csv file
    # else:
    #     TODO: clear all outputs to start recording again
    #     print("Saving Failed")
    return (
        fig_h,
        predic_mos,
        trans,
        wer,
        phone_transcription,
        ppm,
        error_msg,
        saved_google_id,
    )
with open("src/description.html", "r", encoding="utf-8") as f:
    description = f.read()
# description
reference_id = gr.Textbox(value="ID", placeholder="Utter ID", label="Reference_ID", visible=False)
reference_textbox = gr.Textbox(
    value="Input reference here",
    placeholder="Input reference here",
    label="Reference",
)
reference_PPM = gr.Textbox(placeholder="Pneumatic Voice's PPM", label="Ref PPM", visible=False)
# Flagging setup
# Interface
# Participant Information
def record_part_info(name, gender, first_lng):
    message = "Participant information is successfully collected."
    # Validate before building the id: indexing an empty selection would crash
    if not name:
        return "ERROR: Name information incomplete!", "ERROR"
    if not gender:
        return "ERROR: Please select gender", "ERROR"
    if len(gender) > 1:
        return "ERROR: Please select one gender only", "ERROR"
    if not first_lng:
        return "ERROR: Please select your English proficiency", "ERROR"
    if len(first_lng) > 1:
        return "ERROR: Please select one English proficiency only", "ERROR"
    id_str = "%s_%s_%s" % (name, gender[0], first_lng[0])
    return message, id_str
# Information page (not in use now)
name = gr.Textbox(placeholder="Name", label="Name")
gender = gr.CheckboxGroup(["Male", "Female"], label="gender")
first_lng = gr.CheckboxGroup(
    [
        "B1: Intermediate",
        "B2: Upper Intermediate",
        "C1: Advanced",
        "C2: Proficient",
    ],
    label="English Proficiency (CEFR)",
)
msg = gr.Textbox(placeholder="Evaluation for valid participant", label="message")
id_str = gr.Textbox(placeholder="participant id", label="participant_id")
info = gr.Interface(
    fn=record_part_info,
    inputs=[name, gender, first_lng],
    outputs=[msg, id_str],
    title="Participant Information Page",
    allow_flagging="never",
    css="body {background-color: blue}",
)
# Experiment
if config["exp_id"] is None:
    config["exp_id"] = Path(config_yaml).stem

## Theme
css = """
.ref_text textarea {font-size: 40px !important}
.message textarea {font-size: 40px !important}
"""
my_theme = gr.themes.Default().set(
    button_primary_background_fill="#75DA99",
    button_primary_background_fill_dark="#DEF2D7",
    button_primary_text_color="black",
    button_secondary_text_color="black",
)
# Callback for saving the recording
callback = gr.CSVLogger()


def generate_now_time_wav():
    # Get the current date and time
    current_time = datetime.datetime.now()
    # Format the date and time as a string
    time_string = current_time.strftime("%Y-%m-%d_%H-%M-%S")
    # Create the WAV file name with the formatted time
    wavfile_name = f"audio_{time_string}.wav"
    return wavfile_name
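# e.g. generate_now_time_wav() -> "audio_2024-01-31_13-45-07.wav"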
# Add google drive cloud saving
def click_google_saving(audio_file):
    # TODO: also save the metadata (reference_id, reference_textbox, reference_PPM,
    # predict_mos, hyp, wer, ppm, msg) together with the audio file
    name = generate_now_time_wav()
    # Upload the file
    media = MediaFileUpload(audio_file, mimetype='audio/wav')
    request = service.files().create(
        media_body=media,
        body={'name': name},
    )
    response = request.execute()
    # get saved file id
    return response.get('id')
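# The returned value is the Drive file id (the "id" field of the Files.create
# response), which is what gets shown and logged in the UI below.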
with gr.Blocks(css=css, theme=my_theme) as demo:
    with gr.Column():
        with gr.Row():
            ref_audio = gr.Audio(
                source="microphone",
                type="filepath",
                label="Reference_Audio",
                container=True,
                interactive=False,
                visible=False,
            )
        with gr.Row():
            eval_audio = gr.Audio(
                source="microphone",
                type="filepath",
                container=True,
                label="Audio_to_Evaluate",
            )
            b_redo = gr.ClearButton(
                value="Redo", variant="stop", components=[eval_audio], size="sm"
            )
        reference_textbox = gr.Textbox(
            value="Input reference here",
            placeholder="Input reference here",
            label="Reference",
            interactive=True,
            elem_classes="ref_text",
        )
        with gr.Row():
            with gr.Accordion("Input for Development", open=False):
                reference_id = gr.Textbox(
                    value="ID",
                    placeholder="Utter ID",
                    label="Reference_ID",
                    visible=True,
                )
                reference_PPM = gr.Textbox(
                    placeholder="Pneumatic Voice's PPM",
                    label="Ref PPM",
                    visible=True,
                )
        with gr.Row():
            b = gr.Button(value="1. Submit", variant="primary", elem_classes="submit")
            # TODO
            # b_more = gr.Button(value="Show More", elem_classes="verbose")
        with gr.Row():
            inputs = [
                ref_audio,
                eval_audio,
                reference_id,
                reference_textbox,
                reference_PPM,
            ]
            e = gr.Examples(examples, inputs, examples_per_page=5)
    with gr.Column():
        with gr.Row():
            ## output block
            msg = gr.Textbox(
                placeholder="Recording Feedback",
                label="Message",
                interactive=False,
                elem_classes="message",
            )
        with gr.Accordion("Output for Development", open=False):
            # gr.Plot takes no placeholder argument, so only label/visible are set
            wav_plot = gr.Plot(label="wav_pause_plot", visible=True)
            predict_mos = gr.Textbox(
                placeholder="Predicted MOS",
                label="Predicted MOS",
                visible=True,
            )
            hyp = gr.Textbox(placeholder="Hypothesis", label="Hypothesis", visible=True)
            wer = gr.Textbox(placeholder="Word Error Rate", label="WER", visible=True)
            predict_pho = gr.Textbox(
                placeholder="Predicted Phonemes",
                label="Predicted Phonemes",
                visible=True,
            )
            ppm = gr.Textbox(
                placeholder="Phonemes per minute",
                label="PPM",
                visible=True,
            )
            saved_google_drive_id = gr.Textbox(
                placeholder="Saved Google Drive ID",
                label="Saved Google Drive ID",
                visible=True,
            )

    outputs = [
        wav_plot,
        predict_mos,
        hyp,
        wer,
        predict_pho,
        ppm,
        msg,
        saved_google_drive_id,
    ]
    # b = gr.Button("Submit")
    b.click(fn=calc_mos, inputs=inputs, outputs=outputs, api_name="Submit")

    # Logger
    callback.setup(
        components=[
            eval_audio,
            reference_id,
            reference_textbox,
            reference_PPM,
            predict_mos,
            hyp,
            wer,
            ppm,
            msg,
            saved_google_drive_id,
        ],
        flagging_dir="./exp/%s" % config["exp_id"],
    )
    # Saving the Recording to CSV Logger (TO BE DELETED)
    with gr.Row():
        b2 = gr.Button("2. Save the Recording", variant="primary", elem_id="save")
        js_confirmed_saving = "(x) => confirm('Recording Saved!')"
        # eval_audio,
        b2.click(
            lambda *args: callback.flag(args),
            inputs=[
                eval_audio,
                reference_id,
                reference_textbox,
                reference_PPM,
                predict_mos,
                hyp,
                wer,
                ppm,
                msg,
                saved_google_drive_id,
            ],
            outputs=None,
            preprocess=False,
            api_name="flagging",
        )
    with gr.Row():
        b3 = gr.ClearButton(
            [
                ref_audio,
                eval_audio,
                reference_id,
                reference_textbox,
                reference_PPM,
                predict_mos,
                hyp,
                wer,
                ppm,
                msg,
                saved_google_drive_id,
            ],
            value="3. Clear All",
            elem_id="clear",
        )

demo.launch(share=False)