In [1]:
%load_ext autoreload
%autoreload 2

from zeno import zeno
import math
import os
import pandas as pd

In [9]:
# Read the metadata table from metadata.csv (path is relative to the
# notebook's working directory).
df = pd.read_csv(filepath_or_buffer="metadata.csv")

In [10]:
# Index the table by 'id' while keeping 'id' available as a regular column
# (drop=False), so it can still be looked up by column name.
df = df.set_index("id", drop=False)

In [12]:
# After set_index('id', drop=False), 'id' names both the index and a column,
# so df.groupby('id') fails with:
#   ValueError: 'id' is both an index level and a column label, which is ambiguous.
# Grouping on the index level explicitly removes the ambiguity.
df.groupby(level='id')

In [None]:
# Launch Zeno on the first 10 metadata rows only, with no models or
# functions configured.
# NOTE(review): data_path is an absolute path on the author's machine and
# must be adjusted to run elsewhere.
zeno(
    dict(
        metadata=df[0:10],
        view="audio-transcription",
        data_path="/Users/acabrera/dev/data/speech-accent-archive/recordings/recordings/",
        label_column="label",
        data_column="id",
    )
)

In [None]:
import torch
import whisper
from jiwer import wer
from zeno import ZenoOptions, distill, metric, model
import numpy as np
from zeno import ZenoOptions, distill

In [None]:
@model
def load_model(model_path):
    """Return a batch prediction function for the model named by ``model_path``.

    The backend is chosen by substring match on ``model_path``:

    * contains ``"sst"``: the English ``silero_stt`` model loaded through
      ``torch.hub`` from ``snakers4/silero-models``, placed on the CPU.
    * contains ``"whisper"``: the Whisper ``"tiny"`` checkpoint.

    The returned ``pred(df, ops)`` joins ``ops.data_path`` with every value of
    ``df[ops.data_column]`` to obtain the audio file paths and returns a list
    holding one transcription per file, in row order.

    Raises:
        ValueError: if ``model_path`` matches neither backend (previously the
            function fell through and implicitly returned ``None``).
    """

    def audio_files(df, ops: ZenoOptions):
        # Shared by both backends: resolve each entry of the data column
        # against the configured data directory.
        return [os.path.join(ops.data_path, f) for f in df[ops.data_column]]

    if "sst" in model_path:
        device = torch.device("cpu")
        stt_model, decoder, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_stt",
            language="en",
            device=device,
        )
        # Only the batch reader and the input-preparation helper are needed.
        (read_batch, _, _, prepare_model_input) = utils

        def pred(df, ops: ZenoOptions):
            batch = prepare_model_input(
                read_batch(audio_files(df, ops)), device=device
            )
            return [decoder(x.cpu()) for x in stt_model(batch)]

        return pred

    if "whisper" in model_path:
        whisper_model = whisper.load_model("tiny")

        def pred(df, ops: ZenoOptions):
            # Whisper transcribes one file at a time; keep only the text.
            return [
                whisper_model.transcribe(f)["text"] for f in audio_files(df, ops)
            ]

        return pred

    raise ValueError(
        f"Unsupported model {model_path!r}: expected a name containing "
        "'sst' or 'whisper'"
    )


@distill
def country(df, ops: ZenoOptions):
    """Extract a country for every row from the ``0birthplace`` column.

    Each birthplace string is split on ``", "`` and its last piece is used as
    the country (presumably values look like "city, region, country" --
    confirm against the metadata). Rows whose birthplace is missing or not a
    string get ``""``.

    Returns:
        A Series aligned with ``df`` holding one country string per row.
        The earlier version returned a single scalar taken from the last row
        and only inspected the first row for missing data.
    """

    def last_component(place):
        # Missing values are read by pandas as float NaN, not str.
        return place.split(", ")[-1] if isinstance(place, str) else ""

    return df["0birthplace"].map(last_component)


@distill
def wer_m(df, ops: ZenoOptions):
    """Compute a per-row word error rate.

    For each row, ``wer`` is called with the value in ``ops.label_column`` as
    its first argument and the value in ``ops.output_column`` as its second.
    Returns the result of the row-wise ``DataFrame.apply``.
    """

    def row_wer(row):
        reference = row[ops.label_column]
        hypothesis = row[ops.output_column]
        return wer(reference, hypothesis)

    return df.apply(row_wer, axis=1)


@metric
def avg_wer(df, ops: ZenoOptions):
    """Return the mean of the ``wer_m`` distill column, or 0 if it is NaN.

    The column name is looked up in ``ops.distill_columns`` under the key
    ``"wer_m"``.
    """
    wer_column = ops.distill_columns["wer_m"]
    mean_wer = df[wer_column].mean()
    # A NaN mean is reported as 0 rather than propagated.
    return 0 if math.isnan(mean_wer) else mean_wer

# @distill
# def amplitude(df, ops: ZenoOptions):
#     files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]
#     amps = []
#     for audio in files:
#         y, _ = librosa.load(audio)
#         amps.append(float(np.abs(y).mean()))
#     return amps


# @distill
# def length(df, ops: ZenoOptions):
#     files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]
#     amps = []
#     for audio in files:
#         y, _ = librosa.load(audio)
#         amps.append(len(y))
#     return amps

In [None]:
# Launch Zeno on the full metadata table with both models and the
# distill/metric functions defined in this notebook.
# NOTE(review): data_path is an absolute path on the author's machine and
# must be adjusted to run elsewhere.
zeno(
    dict(
        metadata=df,
        functions=[load_model, country, wer_m, avg_wer],
        view="audio-transcription",
        models=["silero_sst", "whisper"],
        data_path="/Users/acabrera/dev/data/speech-accent-archive/recordings/recordings/",
        data_column="id",
        label_column="label",
        samples=10,
    )
)
# metadata = "metadata.csv"
# # data_path = "https://zenoml.s3.amazonaws.com/accents/"