Spaces:
Running
Running
File size: 1,905 Bytes
c24ff9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import math
import os
import pandas as pd
import torch
import whisper
from jiwer import wer
from zeno import (
ZenoOptions,
distill,
metric,
model,
DistillReturn,
ModelReturn,
MetricReturn,
)
@model
def load_model(model_path):
    """Build a Zeno prediction function for the speech model named by *model_path*.

    The back-end is chosen by substring match on ``model_path``:

    * contains ``"sst"`` -- the Silero English speech-to-text model, loaded
      through ``torch.hub`` from ``snakers4/silero-models`` on the CPU.
    * contains ``"whisper"`` -- OpenAI Whisper, always the ``"tiny"`` checkpoint.

    The returned ``pred(df, ops)`` joins each entry of ``df[ops.data_column]``
    onto ``ops.data_path`` and returns one transcript per row wrapped in a
    ``ModelReturn``.

    Raises:
        ValueError: if ``model_path`` matches neither back-end (previously the
            function silently returned ``None``).
    """
    # NOTE(review): the test below is for "sst" while the hub model is named
    # "silero_stt" -- confirm which spelling the model paths actually use.
    if "sst" in model_path:
        device = torch.device("cpu")
        stt_model, decoder, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_stt",
            language="en",
            device=device,
        )
        read_batch, _, _, prepare_model_input = utils

        def pred(df, ops: ZenoOptions):
            files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]
            batch = prepare_model_input(read_batch(files), device=device)
            # Inference only: skip autograd graph construction to save memory/time.
            with torch.no_grad():
                outputs = stt_model(batch)
            return ModelReturn(model_output=[decoder(x.cpu()) for x in outputs])

        return pred

    if "whisper" in model_path:
        whisper_model = whisper.load_model("tiny")

        def pred(df, ops: ZenoOptions):
            files = [os.path.join(ops.data_path, f) for f in df[ops.data_column]]
            return ModelReturn(
                model_output=[whisper_model.transcribe(f)["text"] for f in files]
            )

        return pred

    raise ValueError(
        f"Unrecognised model_path {model_path!r}: expected it to contain "
        "'sst' or 'whisper'"
    )
@distill
def country(df, ops: ZenoOptions):
    """Extract the country from each row's ``birthplace`` value.

    The country is taken to be the last ``", "``-separated component of the
    birthplace string (e.g. ``"Lyon, France"`` -> ``"France"``). Rows whose
    birthplace is missing (NaN/None/NA) or otherwise not a string map to ``""``.

    Returns a ``DistillReturn`` holding one value per row of ``df``, in row
    order. An empty ``df`` yields an empty list.
    """

    def _last_component(place):
        # Missing entries in a pandas column are NaN/None/NA, not strings.
        if not isinstance(place, str):
            return ""
        return place.split(", ")[-1]

    # Iterate values positionally so the result does not depend on df's index.
    return DistillReturn(
        distill_output=[_last_component(p) for p in df["birthplace"]]
    )
@distill
def wer_m(df, ops: ZenoOptions):
    """Score every row with its word error rate.

    For each row, ``jiwer.wer`` compares the reference transcript stored in
    ``ops.label_column`` with the model hypothesis stored in
    ``ops.output_column``. The per-row scores are exactly what
    ``DataFrame.apply(..., axis=1)`` produces.

    NOTE(review): jiwer may reject empty reference strings -- confirm labels
    are always non-empty.
    """

    def _row_wer(row):
        reference = row[ops.label_column]
        hypothesis = row[ops.output_column]
        return wer(reference, hypothesis)

    scores = df.apply(_row_wer, axis=1)
    return DistillReturn(distill_output=scores)
@metric
def avg_wer(df, ops: ZenoOptions):
    """Average the per-row word error rates produced by the ``wer_m`` distill.

    Returns ``0`` when the mean is undefined so the metric always has a
    numeric value. This covers an empty slice, or a column that is entirely
    NaN, both of which make ``Series.mean()`` return NaN.
    """
    avg = df[ops.distill_columns["wer_m"]].mean()
    # pd.isna already detects float NaN. A separate math.isnan check is
    # redundant and would raise TypeError on a non-float result.
    if pd.isna(avg):
        return MetricReturn(metric=0)
    return MetricReturn(metric=avg)
|