File size: 2,113 Bytes
223aff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

import json
import pathlib
# import tqdm

from typing import Optional
import os
import threading

from loguru import logger
# from app.common import HParams
# from __ini import HParams
from pathlib import Path
import requests

from app import HParams


def find_path_by_suffix(dir_path: Path, suffix: str) -> Optional[Path]:
    """Return the first entry directly inside ``dir_path`` named ``*.<suffix>``.

    Only the top level of ``dir_path`` is searched (no recursion). Matches are
    sorted by path so the result is deterministic when several entries share
    the suffix; ``Path.glob`` on its own yields entries in filesystem order,
    which differs between platforms and runs.

    Args:
        dir_path: Directory to search. A ``str`` is accepted and converted.
        suffix: Extension without the dot, e.g. ``"onnx"``. A leading ``"."``
            (``".onnx"``) is tolerated and stripped.

    Returns:
        The lexicographically first matching path, or ``None`` if nothing
        matches.

    Raises:
        NotADirectoryError: if ``dir_path`` is not an existing directory.
    """
    dir_path = Path(dir_path)
    # Explicit check rather than ``assert``: asserts are removed under
    # ``python -O``, which would turn a bad directory into a silent ``None``.
    if not dir_path.is_dir():
        raise NotADirectoryError(f"not a directory: {dir_path}")

    pattern = f"*.{str(suffix).lstrip('.')}"
    matches = sorted(dir_path.glob(pattern))
    return matches[0] if matches else None


def get_hparams_from_file(config_path):
    """Load a JSON config file and wrap its contents in an ``HParams`` object.

    Args:
        config_path: Path (``str`` or path-like) of the JSON config file.

    Returns:
        ``HParams`` constructed with the top-level keys of the JSON document
        passed as keyword arguments.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        TypeError: if the top-level JSON value is not an object (``**``
            unpacking requires a mapping).
    """
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # relying on the platform's locale default, which can reject non-ASCII
    # config content, e.g. on Windows.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    return HParams(**config)


def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after each element.

    For an input of length ``n`` the result has ``2 * n + 1`` entries, and
    ``item`` occupies every even index:

    >>> intersperse([1, 2], 0)
    [0, 1, 0, 2, 0]
    """
    interleaved = [item]
    for element in lst:
        interleaved.append(element)
        interleaved.append(item)
    return interleaved


def time_it(func: callable):
    """Decorate ``func`` so that each call logs its wall-clock duration.

    Elapsed time is measured with ``time.perf_counter`` and reported through
    ``logger.info`` after the call returns. The wrapped function's return
    value is passed through unchanged. If ``func`` raises, the exception
    propagates and no timing line is logged.

    ``functools.wraps`` preserves ``func``'s ``__name__``, ``__doc__`` and
    ``__wrapped__`` on the returned wrapper, so introspection, other
    decorators and log output still see the original function.
    """
    import functools
    import time

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        res = func(*args, **kwargs)
        end = time.perf_counter()
        logger.info(f"func {func.__name__} cost {end-start} seconds")
        return res

    return wrapper





# def download_defaults(model_path: pathlib.Path, config_path: pathlib.Path):

#     config = requests.get(config_url,  timeout=10).content
#     with open(str(config_path), 'wb') as f:
#         f.write(config)

#     t = threading.Thread(target=pdownload, args=(model_url, str(model_path)))
#     t.start()


def get_paths(dir_path: Path):
    """Locate the ONNX model file and the JSON config file in ``dir_path``.

    Both lookups delegate to ``find_path_by_suffix``, the model first and the
    config second.

    Returns:
        A ``(model_path, config_path)`` tuple. Either element is ``None``
        when ``dir_path`` holds no file with the matching suffix.
    """
    # NOTE: a fallback that downloaded a default model/config when either
    # file was missing used to run here; it is currently disabled.
    return tuple(find_path_by_suffix(dir_path, ext) for ext in ("onnx", "json"))