willwade committed
Commit e9b6593 · verified · 1 Parent(s): b6db772

Delete model.py

Files changed (1)
  1. model.py +0 -99
model.py DELETED
@@ -1,99 +0,0 @@
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
-
- # Copyright 2023 Imperial College London (Pingchuan Ma)
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
- import os
- import json
- import torch
- import argparse
- import numpy as np
-
- from espnet.asr.asr_utils import torch_load
- from espnet.asr.asr_utils import get_model_conf
- from espnet.asr.asr_utils import add_results_to_json
- from espnet.nets.batch_beam_search import BatchBeamSearch
- from espnet.nets.lm_interface import dynamic_import_lm
- from espnet.nets.scorers.length_bonus import LengthBonus
- from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E
-
-
- class AVSR(torch.nn.Module):
-     def __init__(self, modality, model_path, model_conf, rnnlm=None, rnnlm_conf=None,
-                  penalty=0., ctc_weight=0.1, lm_weight=0., beam_size=40, device="cuda:0"):
-         super(AVSR, self).__init__()
-         self.device = device
-
-         if modality == "audiovisual":
-             from espnet.nets.pytorch_backend.e2e_asr_transformer_av import E2E
-         else:
-             from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E
-
-         with open(model_conf, "rb") as f:
-             confs = json.load(f)
-         args = confs if isinstance(confs, dict) else confs[2]
-         self.train_args = argparse.Namespace(**args)
-
-         labels_type = getattr(self.train_args, "labels_type", "char")
-         if labels_type == "char":
-             self.token_list = self.train_args.char_list
-         elif labels_type == "unigram5000":
-             file_path = os.path.join(os.path.dirname(__file__), "tokens", "unigram5000_units.txt")
-             self.token_list = ['<blank>'] + [word.split()[0] for word in open(file_path).read().splitlines()] + ['<eos>']
-         self.odim = len(self.token_list)
-
-         self.model = E2E(self.odim, self.train_args)
-         self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
-         self.model.to(device=self.device).eval()
-
-         self.beam_search = get_beam_search_decoder(self.model, self.token_list, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size)
-         self.beam_search.to(device=self.device).eval()
-
-     def infer(self, data):
-         with torch.no_grad():
-             if isinstance(data, tuple):
-                 enc_feats = self.model.encode(data[0].to(self.device), data[1].to(self.device))
-             else:
-                 enc_feats = self.model.encode(data.to(self.device))
-             nbest_hyps = self.beam_search(enc_feats)
-             nbest_hyps = [h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), 1)]]
-             transcription = add_results_to_json(nbest_hyps, self.token_list)
-             transcription = transcription.replace("▁", " ").strip()
-             return transcription.replace("<eos>", "")
-
-
- def get_beam_search_decoder(model, token_list, rnnlm=None, rnnlm_conf=None, penalty=0, ctc_weight=0.1, lm_weight=0., beam_size=40):
-     sos = model.odim - 1
-     eos = model.odim - 1
-     scorers = model.scorers()
-
-     if not rnnlm:
-         lm = None
-     else:
-         lm_args = get_model_conf(rnnlm, rnnlm_conf)
-         lm_model_module = getattr(lm_args, "model_module", "default")
-         lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
-         lm = lm_class(len(token_list), lm_args)
-         torch_load(rnnlm, lm)
-         lm.eval()
-
-     scorers["lm"] = lm
-     scorers["length_bonus"] = LengthBonus(len(token_list))
-     weights = dict(
-         decoder=1.0 - ctc_weight,
-         ctc=ctc_weight,
-         lm=lm_weight,
-         length_bonus=penalty,
-     )
-
-     return BatchBeamSearch(
-         beam_size=beam_size,
-         vocab_size=len(token_list),
-         weights=weights,
-         scorers=scorers,
-         sos=sos,
-         eos=eos,
-         token_list=token_list,
-         pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
-     )
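
For context, the deleted model.py wrapped ESPnet's transformer end-to-end ASR models (single-stream or audio-visual) in an AVSR class that loads a checkpoint plus its training config, rebuilds the token list, and decodes encoder features with a hybrid CTC/attention BatchBeamSearch (attention decoder weighted 1.0 - ctc_weight, CTC weighted ctc_weight, with optional RNN-LM shallow fusion and a length bonus). Below is a minimal usage sketch of that removed API; the checkpoint path, config path, and input shape are illustrative assumptions, not files shipped with this repo.

import torch
from model import AVSR  # the module removed in this commit

# Hypothetical paths: these must point at a real ESPnet checkpoint and its model.json.
avsr = AVSR(
    modality="video",            # any value other than "audiovisual" loads the single-stream E2E model
    model_path="exp/model.pth",
    model_conf="exp/model.json",
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)

# Placeholder input: infer() takes a single feature tensor, or a tuple of the two
# modality streams for the audiovisual model; real features come from a preprocessing pipeline.
feats = torch.randn(100, 88, 88)
print(avsr.infer(feats))

The single scorer-weight dict is what makes the decoding hybrid: ctc_weight=0 gives pure attention decoding, ctc_weight=1 gives pure CTC (in which case pre_beam_score_key switches to None), and values in between interpolate the two scores at every beam step.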