import torch | |
class Stat:
    """Accumulate running sums to report mean/std of values and of their logs.

    Each call to ``update`` adds the sum and sum of squares of the new values
    (and of their element-wise logs) to running totals, so statistics can be
    queried at any time without keeping the data around. Optionally the raw
    inputs are retained as well.
    """

    def __init__(self, keep_raw=False):
        """Create an empty accumulator.

        Args:
            keep_raw: when true, every input passed to ``update`` is also
                stored so it can be retrieved with ``raw_data``.
        """
        self.x = 0.0  # running sum of values
        self.x2 = 0.0  # running sum of squared values
        self.z = 0.0  # running sum of log-values (z = log x)
        self.z2 = 0.0  # running sum of squared log-values
        self.n = 0.0  # running total of len(new_x) over all updates
        self.u = 0.0  # number of update() calls
        self.keep_raw = keep_raw
        self.raw = []

    def update(self, new_x):
        """Fold a new batch of values into the running totals.

        Args:
            new_x: object exposing ``log()``, ``sum()``, ``**`` and ``len()``;
                presumably a 1-D ``torch.Tensor`` of positive values
                (TODO confirm with callers). Non-positive entries would make
                the log statistics inf/NaN.
        """
        new_z = new_x.log()
        self.x += new_x.sum()
        self.x2 += (new_x**2).sum()
        self.z += new_z.sum()
        self.z2 += (new_z**2).sum()
        # NOTE(review): len() counts along the first dimension only, while the
        # sums above cover every element -- consistent only for 1-D input.
        self.n += len(new_x)
        self.u += 1
        if self.keep_raw:
            self.raw.append(new_x)

    def mean(self):
        """Return the mean of all values seen so far (sum / count)."""
        return self.x / self.n

    def std(self):
        """Return the population standard deviation of the values seen so far."""
        # mean is a method here, so it must be called before squaring.
        variance = self.x2 / self.n - self.mean() ** 2
        # Rounding in E[x^2] - E[x]^2 can yield a tiny negative value, which
        # would turn the square root into NaN; clamp it at zero.
        return variance.clamp(min=0) ** 0.5

    def mean_log(self):
        """Return the mean of the log-values seen so far."""
        return self.z / self.n

    def std_log(self):
        """Return the population standard deviation of the log-values."""
        # mean_log is a method here, so it must be called before squaring.
        variance = self.z2 / self.n - self.mean_log() ** 2
        # Same rounding guard as in std().
        return variance.clamp(min=0) ** 0.5

    def n_frms(self):
        """Return the accumulated element count (sum of ``len`` of each input)."""
        return self.n

    def n_utts(self):
        """Return how many times ``update`` has been called."""
        return self.u

    def raw_data(self):
        """Return all stored inputs concatenated with ``torch.cat``.

        Raises:
            AssertionError: if the accumulator was created with
                ``keep_raw=False``.
        """
        assert self.keep_raw, "does not support storing raw data!"
        return torch.cat(self.raw)
class F0Stat(Stat):
    """Variant of ``Stat`` that only accumulates non-zero entries."""

    def update(self, new_x):
        """Accumulate the non-zero entries of ``new_x``; ignore ``None``.

        Unvoiced frames are assumed to be encoded as 0, so only voiced
        frames contribute to the statistics.
        """
        if new_x is None:
            return
        voiced = new_x[new_x != 0]
        super().update(voiced)
class Accuracy:
    """Collect prediction/target batches and report tolerance-based accuracy."""

    def __init__(self):
        """Start with empty lists of targets and predictions."""
        self.y = []
        self.yhat = []

    def update(self, yhat, y):
        """Store one batch of predictions ``yhat`` and matching targets ``y``."""
        self.yhat.append(yhat)
        self.y.append(y)

    def acc(self, tol):
        """Return the fraction of entries with ``|yhat - y| <= tol``.

        All stored batches are concatenated with ``torch.cat`` before the
        comparison; the result is a Python float.
        """
        predictions = torch.cat(self.yhat)
        targets = torch.cat(self.y)
        within_tol = (predictions - targets).abs() <= tol
        return within_tol.float().mean().item()