qgyd2021 committed
Commit 1a4c139
1 Parent(s): a38ecb0

add language
.gitignore CHANGED
@@ -3,3 +3,5 @@
 .idea/
 
 **/__pycache__/
+
+trained_models/
Dockerfile CHANGED
@@ -17,6 +17,8 @@ RUN useradd -m -u 1000 user
 # Switch to the "user" user
 USER user
 
+RUN apt-get install -y git
+
 # Set home to the user's home directory
 ENV HOME=/home/user \
     PATH=/home/user/.local/bin:$PATH
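
Note: the git client is needed at runtime because the updated main.py (next hunk) fetches model repositories from the Hugging Face Hub with git clone.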
main.py CHANGED
@@ -2,9 +2,15 @@
 # -*- coding: utf-8 -*-
 import argparse
 
+from allennlp.models.archival import archive_model, load_archive
+from allennlp.predictors.text_classifier import TextClassifierPredictor
 import gradio as gr
 import platform
 
+from project_settings import project_path
+from toolbox.allennlp.data.dataset_readers.text_classification_json import TextClassificationJsonReader
+from toolbox.os.command import Command
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -20,10 +26,35 @@ model_names = {
 }
 
 
+trained_model_dir = project_path / "trained_models/huggingface"
+trained_model_dir.mkdir(parents=True, exist_ok=True)
+
+
 def click_button_allennlp_text_classification(text: str, model_name: str):
-    print(text)
-    print(model_name)
-    return "label", 0.0
+    model_path = trained_model_dir / model_name
+    if not model_path.exists():
+        model_path.parent.mkdir(exist_ok=True)
+        Command.cd(model_path.parent.as_posix())
+        Command.popen("git clone https://huggingface.co/{}".format(model_name))
+
+    archive = load_archive(archive_file=model_path.as_posix())
+
+    predictor = TextClassifierPredictor(
+        model=archive.model,
+        dataset_reader=archive.dataset_reader,
+    )
+
+    json_dict = {
+        "sentence": text
+    }
+
+    outputs = predictor.predict_json(
+        json_dict
+    )
+    label = outputs["label"]
+    probs = outputs["probs"]
+
+    return label, round(max(probs), 4)
 
 
 def main():
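
With this change, the click handler clones the requested model repository from the Hugging Face Hub on first use, caches it under trained_models/huggingface/, loads the AllenNLP archive, and returns the predicted label with its rounded top probability. The diff does not show how main() wires the handler into the UI; a minimal Gradio sketch under that assumption (component names and layout are hypothetical, not the author's code):

    # Hypothetical wiring sketch; main() is not part of this diff, so the
    # layout and component names below are assumptions.
    import gradio as gr

    with gr.Blocks() as demo:
        text = gr.Textbox(label="text")
        model_name = gr.Dropdown(choices=list(model_names.keys()), label="model_name")
        label = gr.Textbox(label="label")
        prob = gr.Number(label="prob")
        button = gr.Button("predict")
        button.click(
            click_button_allennlp_text_classification,
            inputs=[text, model_name],
            outputs=[label, prob],
        )

    demo.launch()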
toolbox/allennlp/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/allennlp/data/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/allennlp/data/dataset_readers/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/allennlp/data/dataset_readers/text_classification_json.py ADDED
@@ -0,0 +1,95 @@
+from typing import Dict, List, Union
+import logging
+import json
+
+from allennlp.common.file_utils import cached_path
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.fields import LabelField, TextField, Field, ListField
+from allennlp.data.instance import Instance
+from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
+from allennlp.data.tokenizers import Tokenizer, SpacyTokenizer
+from allennlp.data.tokenizers.sentence_splitter import SpacySentenceSplitter
+
+logger = logging.getLogger(__name__)
+
+
+@DatasetReader.register("text_classification_json_utf8")
+class TextClassificationJsonReader(DatasetReader):
+
+    def __init__(
+        self,
+        token_indexers: Dict[str, TokenIndexer] = None,
+        tokenizer: Tokenizer = None,
+        segment_sentences: bool = False,
+        max_sequence_length: int = None,
+        skip_label_indexing: bool = False,
+        text_key: str = "text",
+        label_key: str = "label",
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
+        )
+        self._tokenizer = tokenizer or SpacyTokenizer()
+        self._segment_sentences = segment_sentences
+        self._max_sequence_length = max_sequence_length
+        self._skip_label_indexing = skip_label_indexing
+        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
+        self._text_key = text_key
+        self._label_key = label_key
+        if self._segment_sentences:
+            self._sentence_segmenter = SpacySentenceSplitter()
+
+    def _read(self, file_path):
+        with open(cached_path(file_path), "r", encoding="utf-8") as data_file:
+            for line in self.shard_iterable(data_file.readlines()):
+                if not line:
+                    continue
+                items = json.loads(line)
+                text = items[self._text_key]
+                label = items.get(self._label_key)
+                if label is not None:
+                    if self._skip_label_indexing:
+                        try:
+                            label = int(label)
+                        except ValueError:
+                            raise ValueError(
+                                "Labels must be integers if skip_label_indexing is True."
+                            )
+                    else:
+                        label = str(label)
+                yield self.text_to_instance(text=text, label=label)
+
+    def _truncate(self, tokens):
+        if len(tokens) > self._max_sequence_length:
+            tokens = tokens[: self._max_sequence_length]
+        return tokens
+
+    def text_to_instance(  # type: ignore
+        self, text: str, label: Union[str, int] = None
+    ) -> Instance:
+        fields: Dict[str, Field] = {}
+        if self._segment_sentences:
+            sentences: List[Field] = []
+            sentence_splits = self._sentence_segmenter.split_sentences(text)
+            for sentence in sentence_splits:
+                word_tokens = self._tokenizer.tokenize(sentence)
+                if self._max_sequence_length is not None:
+                    word_tokens = self._truncate(word_tokens)
+                sentences.append(TextField(word_tokens))
+            fields["tokens"] = ListField(sentences)
+        else:
+            tokens = self._tokenizer.tokenize(text)
+            if self._max_sequence_length is not None:
+                tokens = self._truncate(tokens)
+            fields["tokens"] = TextField(tokens)
+        if label is not None:
+            fields["label"] = LabelField(label, skip_indexing=self._skip_label_indexing)
+        return Instance(fields)
+
+    def apply_token_indexers(self, instance: Instance) -> None:
+        if self._segment_sentences:
+            for text_field in instance.fields["tokens"]:  # type: ignore
+                text_field._token_indexers = self._token_indexers
+        else:
+            instance.fields["tokens"]._token_indexers = self._token_indexers  # type: ignore
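
The reader consumes JSON-lines input, one object per line, keyed by text_key and label_key; the "_utf8" registration name suggests the point of this near-copy of AllenNLP's stock reader is the explicit encoding="utf-8" in _read. A minimal usage sketch (the data file and its contents are hypothetical):

    # Hypothetical sketch; data/train.jsonl is an invented path where each
    # line looks like {"text": "some sentence", "label": "positive"}.
    from toolbox.allennlp.data.dataset_readers.text_classification_json import TextClassificationJsonReader

    reader = TextClassificationJsonReader(max_sequence_length=128)
    for instance in reader.read("data/train.jsonl"):
        print(instance.fields["tokens"], instance.fields["label"])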
toolbox/allennlp/training/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/allennlp/training/optimizers.py ADDED
@@ -0,0 +1,31 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import Any, Dict, List, Tuple
+from allennlp.training.optimizers import Registrable, Optimizer, make_parameter_groups
+from pytorch_pretrained_bert.optimization import BertAdam
+import torch
+
+
+@Optimizer.register("bert_adam")
+class BertAdamOptimizer(Optimizer, BertAdam):
+
+    def __init__(
+        self,
+        model_parameters: List[Tuple[str, torch.nn.Parameter]],
+        parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
+        lr: float = 5e-5,
+        warmup: float = 0.1,
+        t_total: int = 50000,
+        schedule: str = 'warmup_linear',
+    ):
+        super().__init__(
+            params=make_parameter_groups(model_parameters, parameter_groups),
+            lr=lr,
+            warmup=warmup,
+            t_total=t_total,
+            schedule=schedule,
+        )
+
+
+if __name__ == '__main__':
+    pass
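
Registering under "bert_adam" makes the optimizer addressable from an AllenNLP training config; it can also be constructed directly. A minimal sketch, assuming pytorch_pretrained_bert is installed (the linear model is a stand-in):

    # Hypothetical direct construction; a real run would pass the trained
    # model's named parameters instead of this toy module.
    import torch
    from toolbox.allennlp.training.optimizers import BertAdamOptimizer

    model = torch.nn.Linear(10, 2)
    optimizer = BertAdamOptimizer(
        model_parameters=list(model.named_parameters()),
        lr=5e-5,
        warmup=0.1,
        t_total=1000,
    )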
toolbox/os/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/os/command.py ADDED
@@ -0,0 +1,53 @@
+import os
+
+
+class Command(object):
+    custom_command = [
+        'cd'
+    ]
+
+    @staticmethod
+    def _get_cmd(command):
+        command = str(command).strip()
+        if command == '':
+            return None
+        cmd_and_args = command.split(sep=' ')
+        cmd = cmd_and_args[0]
+        args = ' '.join(cmd_and_args[1:])
+        return cmd, args
+
+    @classmethod
+    def popen(cls, command):
+        cmd, args = cls._get_cmd(command)
+        if cmd in cls.custom_command:
+            method = getattr(cls, cmd)
+            return method(args)
+        else:
+            resp = os.popen(command)
+            result = resp.read()
+            resp.close()
+            return result
+
+    @classmethod
+    def cd(cls, args):
+        if args.startswith('/'):
+            os.chdir(args)
+        else:
+            pwd = os.getcwd()
+            path = os.path.join(pwd, args)
+            os.chdir(path)
+
+    @classmethod
+    def system(cls, command):
+        return os.system(command)
+
+    def __init__(self):
+        pass
+
+
+def ps_ef_grep(keyword: str):
+    cmd = 'ps -ef | grep {}'.format(keyword)
+    rows = Command.popen(cmd)
+    rows = str(rows).split('\n')
+    rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__('grep')]
+    return rows
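
Command.popen routes anything listed in custom_command (currently only 'cd') to the matching classmethod, so the directory change persists in-process rather than dying with a subshell; everything else is handed to os.popen. A quick sketch:

    from toolbox.os.command import Command

    Command.popen("cd /tmp")     # dispatched to Command.cd; os.chdir persists in-process
    print(Command.popen("pwd"))  # falls through to os.popen and returns captured stdout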
toolbox/os/environment.py ADDED
@@ -0,0 +1,114 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import json
+import os
+
+from dotenv import load_dotenv
+from dotenv.main import DotEnv
+
+from smart.json.misc import traverse
+
+
+class EnvironmentManager(object):
+    def __init__(self, path, env, override=False):
+        filename = os.path.join(path, '{}.env'.format(env))
+        self.filename = filename
+
+        load_dotenv(
+            dotenv_path=filename,
+            override=override
+        )
+
+        self._environ = dict()
+
+    def open_dotenv(self, filename: str = None):
+        filename = filename or self.filename
+        dotenv = DotEnv(
+            dotenv_path=filename,
+            stream=None,
+            verbose=False,
+            interpolate=False,
+            override=False,
+            encoding="utf-8",
+        )
+        result = dotenv.dict()
+        return result
+
+    def get(self, key, default=None, dtype=str):
+        result = os.environ.get(key)
+        if result is None:
+            if default is None:
+                result = None
+            else:
+                result = default
+        else:
+            result = dtype(result)
+        self._environ[key] = result
+        return result
+
+
+_DEFAULT_DTYPE_MAP = {
+    'int': int,
+    'float': float,
+    'str': str,
+    'json.loads': json.loads
+}
+
+
+class JsonConfig(object):
+    """
+    Resolve JSON values of the form `$float:threshold` by looking up
+    `threshold` in the environment variables and casting it to float.
+    """
+    def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
+        self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
+        self.environment = environment or os.environ
+
+    def sanitize_by_filename(self, filename: str):
+        with open(filename, 'r', encoding='utf-8') as f:
+            js = json.load(f)
+
+        return self.sanitize_by_json(js)
+
+    def sanitize_by_json(self, js):
+        js = traverse(
+            js,
+            callback=self.sanitize,
+            environment=self.environment
+        )
+        return js
+
+    def sanitize(self, string, environment):
+        """Resolve environment-variable references that start with `$`."""
+        if isinstance(string, str) and string.startswith('$'):
+            dtype, key = string[1:].split(':')
+            dtype = self.dtype_map[dtype]
+
+            value = environment.get(key)
+            if value is None:
+                raise AssertionError('environment not exist. key: {}'.format(key))
+
+            value = dtype(value)
+            result = value
+        else:
+            result = string
+        return result
+
+
+def demo1():
+    import json
+
+    from project_settings import project_path
+
+    environment = EnvironmentManager(
+        path=os.path.join(project_path, 'server/callbot_server/dotenv'),
+        env='dev',
+    )
+    init_scenes = environment.get(key='init_scenes', dtype=json.loads)
+    print(init_scenes)
+    print(environment._environ)
+    return
+
+
+if __name__ == '__main__':
+    demo1()
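
JsonConfig resolves string values of the form $dtype:key by reading key from the environment and casting it with the matching entry in the dtype map. A minimal sketch (the variable name is invented, and traverse from smart.json.misc must be importable):

    # Hypothetical sketch of the $dtype:key convention; "threshold" is an
    # invented environment variable.
    import os
    from toolbox.os.environment import JsonConfig

    os.environ["threshold"] = "0.5"
    config = JsonConfig()
    print(config.sanitize_by_json({"threshold": "$float:threshold"}))
    # expected: {'threshold': 0.5}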
toolbox/os/other.py ADDED
@@ -0,0 +1,9 @@
+import os
+import inspect
+
+
+def pwd():
+    """Return the directory of the file from which this function is called."""
+    frame = inspect.stack()[1]
+    module = inspect.getmodule(frame[0])
+    return os.path.dirname(os.path.abspath(module.__file__))
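
Unlike os.getcwd(), pwd() is anchored to the source file of its caller, so the result stays stable no matter where the process was launched from. A one-line sketch (the calling module is hypothetical):

    # Hypothetical caller; prints the directory containing this file, not the CWD.
    from toolbox.os.other import pwd

    print(pwd())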