Delete model.py
model.py
DELETED
@@ -1,99 +0,0 @@
-#! /usr/bin/env python
-# -*- coding: utf-8 -*-
-
-# Copyright 2023 Imperial College London (Pingchuan Ma)
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-import os
-import json
-import torch
-import argparse
-import numpy as np
-
-from espnet.asr.asr_utils import torch_load
-from espnet.asr.asr_utils import get_model_conf
-from espnet.asr.asr_utils import add_results_to_json
-from espnet.nets.batch_beam_search import BatchBeamSearch
-from espnet.nets.lm_interface import dynamic_import_lm
-from espnet.nets.scorers.length_bonus import LengthBonus
-from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E
-
-
-class AVSR(torch.nn.Module):
-    def __init__(self, modality, model_path, model_conf, rnnlm=None, rnnlm_conf=None,
-                 penalty=0., ctc_weight=0.1, lm_weight=0., beam_size=40, device="cuda:0"):
-        super(AVSR, self).__init__()
-        self.device = device
-
-        if modality == "audiovisual":
-            from espnet.nets.pytorch_backend.e2e_asr_transformer_av import E2E
-        else:
-            from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E
-
-        with open(model_conf, "rb") as f:
-            confs = json.load(f)
-        args = confs if isinstance(confs, dict) else confs[2]
-        self.train_args = argparse.Namespace(**args)
-
-        labels_type = getattr(self.train_args, "labels_type", "char")
-        if labels_type == "char":
-            self.token_list = self.train_args.char_list
-        elif labels_type == "unigram5000":
-            file_path = os.path.join(os.path.dirname(__file__), "tokens", "unigram5000_units.txt")
-            self.token_list = ['<blank>'] + [word.split()[0] for word in open(file_path).read().splitlines()] + ['<eos>']
-        self.odim = len(self.token_list)
-
-        self.model = E2E(self.odim, self.train_args)
-        self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
-        self.model.to(device=self.device).eval()
-
-        self.beam_search = get_beam_search_decoder(self.model, self.token_list, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size)
-        self.beam_search.to(device=self.device).eval()
-
-    def infer(self, data):
-        with torch.no_grad():
-            if isinstance(data, tuple):
-                enc_feats = self.model.encode(data[0].to(self.device), data[1].to(self.device))
-            else:
-                enc_feats = self.model.encode(data.to(self.device))
-            nbest_hyps = self.beam_search(enc_feats)
-            nbest_hyps = [h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), 1)]]
-            transcription = add_results_to_json(nbest_hyps, self.token_list)
-            transcription = transcription.replace("▁", " ").strip()
-        return transcription.replace("<eos>", "")
-
-
-def get_beam_search_decoder(model, token_list, rnnlm=None, rnnlm_conf=None, penalty=0, ctc_weight=0.1, lm_weight=0., beam_size=40):
-    sos = model.odim - 1
-    eos = model.odim - 1
-    scorers = model.scorers()
-
-    if not rnnlm:
-        lm = None
-    else:
-        lm_args = get_model_conf(rnnlm, rnnlm_conf)
-        lm_model_module = getattr(lm_args, "model_module", "default")
-        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
-        lm = lm_class(len(token_list), lm_args)
-        torch_load(rnnlm, lm)
-        lm.eval()
-
-    scorers["lm"] = lm
-    scorers["length_bonus"] = LengthBonus(len(token_list))
-    weights = dict(
-        decoder=1.0 - ctc_weight,
-        ctc=ctc_weight,
-        lm=lm_weight,
-        length_bonus=penalty,
-    )
-
-    return BatchBeamSearch(
-        beam_size=beam_size,
-        vocab_size=len(token_list),
-        weights=weights,
-        scorers=scorers,
-        sos=sos,
-        eos=eos,
-        token_list=token_list,
-        pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
-    )
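
The commit itself does not show how the Space drove this class before deletion. For context, a minimal usage sketch follows; the checkpoint/config paths and the input tensor shape are assumptions for illustration, not taken from this repository.

# Hypothetical usage of the deleted AVSR wrapper. Paths below are
# placeholders; the real Space would point at its own downloaded assets.
import torch
from model import AVSR

model = AVSR(
    modality="video",                      # "audio" / "video" / "audiovisual"
    model_path="models/visual/model.pth",  # placeholder checkpoint path
    model_conf="models/visual/model.json", # placeholder training config
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)

# Assumed input: a preprocessed mouth-ROI sequence of shape (T, H, W);
# a real pipeline would supply cropped video frames, not random data.
video = torch.randn(100, 88, 88)
print(model.infer(video))  # beam-search decode -> transcription string

Passing a single tensor exercises the else-branch of infer(); an (audio, video) tuple would instead route both streams through the audiovisual encoder.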