import logging
import os

import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from torch.utils.data import DataLoader, Dataset

from .utils import Accuracy

logger = logging.getLogger(__name__)


def save_ckpt(model, path, model_class):
    ckpt = {
        "state_dict": model.state_dict(),
        "padding_token": model.padding_token,
        "model_class": model_class,
    }
    torch.save(ckpt, path)


def load_ckpt(path):
    ckpt = torch.load(path)
    # rewrite the hydra target so the checkpoint instantiates from this
    # module path, regardless of the package layout it was saved under
    ckpt["model_class"]["_target_"] = "emotion_models.duration_predictor.CnnPredictor"
    model = hydra.utils.instantiate(ckpt["model_class"])
    model.load_state_dict(ckpt["state_dict"])
    model.padding_token = ckpt["padding_token"]
    model = model.cpu()
    model.eval()
    return model
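

# Round-trip usage sketch (hypothetical path and values; assumes a
# checkpoint written by save_ckpt above):
#
#   model = load_ckpt("neutral.ckpt")
#   units = torch.LongTensor([[12, 12, 7, 3]])
#   frames = model.inflate_input(units)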


class Collator:
    def __init__(self, padding_idx):
        self.padding_idx = padding_idx

    def __call__(self, batch):
        x = [item[0] for item in batch]
        lengths = [len(item) for item in x]
        x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=self.padding_idx)
        y = [item[1] for item in batch]
        y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=self.padding_idx)
        mask = (x != self.padding_idx)
        return x, y, mask, lengths
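

# Shape sketch (illustrative): two items of 3 and 5 tokens collate to
# x and y of shape (2, 5), right-padded with padding_idx, a boolean mask
# of the same shape that is False over the padding, and lengths == [3, 5].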


class Predictor(nn.Module):
    def __init__(self, n_tokens, emb_dim):
        super().__init__()
        self.n_tokens = n_tokens
        self.emb_dim = emb_dim
        self.padding_token = n_tokens

        # the padding token is appended to the vocabulary, hence n_tokens + 1
        self.emb = nn.Embedding(n_tokens + 1, emb_dim, padding_idx=self.padding_token)

    def inflate_input(self, batch):
        """Take a batch of token sequences, predict their durations
        (subclasses implement forward), and repeat each token accordingly."""
        batch_durs = self.forward(batch)
        # undo the log(1 + d) scaling used by the training loss
        batch_durs = torch.exp(batch_durs) - 1
        batch_durs = batch_durs.round()
        output = []
        for seq, durs in zip(batch, batch_durs):
            inflated_seq = []
            for token, n in zip(seq, durs):
                if token == self.padding_token:
                    break
                n = int(n.item())
                token = int(token.item())
                inflated_seq.extend([token for _ in range(n)])
            output.append(inflated_seq)
        # note: assumes all inflated sequences in the batch end up the same
        # length, otherwise torch.LongTensor raises on the ragged list
        output = torch.LongTensor(output)
        return output
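

# Worked example (illustrative numbers): if forward predicts log-durations
# [log(4), log(3)] for the sequence [7, 12], each token is repeated
# round(exp(p) - 1) times, giving [7, 7, 7, 12, 12].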


class CnnPredictor(Predictor):
    def __init__(self, n_tokens, emb_dim, channels, kernel, output_dim, dropout, n_layers):
        super().__init__(n_tokens=n_tokens, emb_dim=emb_dim)
        layers = [
            Rearrange("b t c -> b c t"),
            nn.Conv1d(emb_dim, channels, kernel_size=kernel, padding=(kernel - 1) // 2),
            Rearrange("b c t -> b t c"),
            nn.ReLU(),
            nn.LayerNorm(channels),
            nn.Dropout(dropout),
        ]
        for _ in range(n_layers - 1):
            layers += [
                Rearrange("b t c -> b c t"),
                nn.Conv1d(channels, channels, kernel_size=kernel, padding=(kernel - 1) // 2),
                Rearrange("b c t -> b t c"),
                nn.ReLU(),
                nn.LayerNorm(channels),
                nn.Dropout(dropout),
            ]
        self.conv_layer = nn.Sequential(*layers)
        self.proj = nn.Linear(channels, output_dim)

    def forward(self, x):
        x = self.emb(x)          # (b, t) -> (b, t, emb_dim)
        x = self.conv_layer(x)   # (b, t, channels); length preserved for odd kernels
        x = self.proj(x)         # (b, t, output_dim)
        x = x.squeeze(-1)        # (b, t) when output_dim == 1
        return x
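

# Instantiation sketch; these hyperparameter values are illustrative, not
# the repository's config:
#
#   model = CnnPredictor(n_tokens=100, emb_dim=128, channels=256,
#                        kernel=5, output_dim=1, dropout=0.1, n_layers=3)
#   log_durs = model(torch.randint(0, 100, (2, 50)))  # shape (2, 50)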


def l2_log_loss(input, target):
    # elementwise (input - log(1 + target))^2; masked and reduced at the call site
    return F.mse_loss(
        input=input.float(),
        target=torch.log(target.float() + 1),
        reduction="none",
    )
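

# Sanity check (illustrative): a duration target y = 2 has regression
# target log(1 + 2) ≈ 1.0986, and a perfect prediction maps back via
# round(exp(1.0986) - 1) = 2 in inflate_input.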


class DurationDataset(Dataset):
    def __init__(self, tsv_path, km_path, substring=""):
        lines = open(tsv_path, "r").readlines()
        self.root, self.tsv = lines[0], lines[1:]
        self.km = open(km_path, "r").readlines()
        logger.info(f"loaded {len(self.km)} files")

        if substring != "":
            tsv, km = [], []
            for tsv_line, km_line in zip(self.tsv, self.km):
                if substring.lower() in tsv_line.lower():
                    tsv.append(tsv_line)
                    km.append(km_line)
            self.tsv, self.km = tsv, km
            logger.info(f"after filtering: {len(self.km)} files")

    def __len__(self):
        return len(self.km)

    def __getitem__(self, i):
        x = self.km[i]
        x = x.split(" ")
        x = list(map(int, x))

        # run-length encode the unit sequence: xd holds the distinct units,
        # y holds how many consecutive frames each unit spans
        y = []
        xd = []
        count = 1
        for x1, x2 in zip(x[:-1], x[1:]):
            if x1 == x2:
                count += 1
            else:
                y.append(count)
                xd.append(x1)
                count = 1
        # close the final run, which the pairwise loop leaves open
        y.append(count)
        xd.append(x[-1])

        xd = torch.LongTensor(xd)
        y = torch.LongTensor(y)
        return xd, y
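

# For example, a km line "7 7 7 12 12 3" run-length encodes to
# xd = [7, 12, 3] with durations y = [3, 2, 1].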


def train(cfg):
    device = "cuda:0"
    model = hydra.utils.instantiate(cfg[cfg.model]).to(device)
    optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())

    collate_fn = Collator(padding_idx=model.padding_token)
    logger.info(f"data: {cfg.train_tsv}")
    train_ds = DurationDataset(cfg.train_tsv, cfg.train_km, substring=cfg.substring)
    valid_ds = DurationDataset(cfg.valid_tsv, cfg.valid_km, substring=cfg.substring)
    train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn)
    valid_dl = DataLoader(valid_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)

    best_loss = float("inf")
    for epoch in range(cfg.epochs):
        train_loss, train_loss_scaled = train_epoch(model, train_dl, l2_log_loss, optimizer, device)
        valid_loss, valid_loss_scaled, *acc = valid_epoch(model, valid_dl, l2_log_loss, device)
        acc0, acc1, acc2, acc3 = acc
        # keep the checkpoint with the best scaled (frame-domain) validation loss
        if valid_loss_scaled < best_loss:
            path = f"{os.getcwd()}/{cfg.substring}.ckpt"
            save_ckpt(model, path, cfg[cfg.model])
            best_loss = valid_loss_scaled
            logger.info(f"saved checkpoint: {path}")
        logger.info(f"[epoch {epoch}] train loss: {train_loss:.3f}, train scaled: {train_loss_scaled:.3f}")
        logger.info(f"[epoch {epoch}] valid loss: {valid_loss:.3f}, valid scaled: {valid_loss_scaled:.3f}")
        logger.info(f"acc: {acc0,acc1,acc2,acc3}")


def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0
    epoch_loss_scaled = 0
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        yhat = model(x)
        loss = criterion(yhat, y) * mask
        loss = torch.mean(loss)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        epoch_loss += loss.item()

        # also track the error in the original frame-count domain
        yhat_scaled = torch.exp(yhat) - 1
        yhat_scaled = torch.round(yhat_scaled)
        scaled_loss = torch.mean(torch.abs(yhat_scaled - y) * mask)
        epoch_loss_scaled += scaled_loss.item()
    return epoch_loss / len(loader), epoch_loss_scaled / len(loader)


@torch.no_grad()
def valid_epoch(model, loader, criterion, device):
    model.eval()
    epoch_loss = 0
    epoch_loss_scaled = 0
    acc = Accuracy()
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        yhat = model(x)
        loss = criterion(yhat, y) * mask
        loss = torch.mean(loss)
        epoch_loss += loss.item()

        yhat_scaled = torch.exp(yhat) - 1
        yhat_scaled = torch.round(yhat_scaled)
        scaled_loss = torch.sum(torch.abs(yhat_scaled - y) * mask) / mask.sum()
        acc.update(yhat_scaled[mask].view(-1).float(), y[mask].view(-1).float())
        epoch_loss_scaled += scaled_loss.item()
        logger.info(f"example y: {y[0, :10].tolist()}")
        logger.info(f"example yhat: {yhat_scaled[0, :10].tolist()}")
    acc0 = acc.acc(tol=0)
    acc1 = acc.acc(tol=1)
    acc2 = acc.acc(tol=2)
    acc3 = acc.acc(tol=3)
    logger.info(f"accs: {acc0,acc1,acc2,acc3}")
    return epoch_loss / len(loader), epoch_loss_scaled / len(loader), acc0, acc1, acc2, acc3


@hydra.main(config_path=".", config_name="duration_predictor.yaml")
def main(cfg):
    logger.info(f"{cfg}")
    train(cfg)


if __name__ == "__main__":
    main()
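
# Typical hydra launch (paths and override values are illustrative):
#   python duration_predictor.py train_tsv=data/train.tsv train_km=data/train.km \
#       valid_tsv=data/valid.tsv valid_km=data/valid.km substring=neutral epochs=50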