qgyd2021 committed
Commit 71e8d3d · 1 parent: 0038ad1

first commit

Files changed (6):
  1. .gitignore +10 -0
  2. README.md +6 -5
  3. examples.json +6 -0
  4. main.py +191 -0
  5. project_settings.py +12 -0
  6. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
+
+.git/
+.idea/
+
+**/flagged/
+**/__pycache__/
+
+cache/
+flagged/
+trained_models/
README.md CHANGED
@@ -1,11 +1,12 @@
 ---
 title: Generate Similar Question
-emoji: 🐨
-colorFrom: yellow
-colorTo: blue
-sdk: docker
+emoji: 🐠
+colorFrom: purple
+colorTo: green
+sdk: gradio
+sdk_version: 3.41.2
+app_file: main.py
 pinned: false
-license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
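Note: the front matter now declares sdk: gradio with sdk_version: 3.41.2, matching the gradio==3.41.2 pin in requirements.txt below, and app_file: main.py names the script the Space runs.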
examples.json ADDED
@@ -0,0 +1,6 @@
+[
+    [
+        "如何开通微信支付?",
+        128, 0.75, 0.35, 1.2, "qgyd2021/similar_question_generation", true
+    ]
+]
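Each row of examples.json maps positionally onto the `inputs` list that main.py passes to gr.Interface (the sample question means "How do I enable WeChat Pay?"). A minimal sketch of that mapping, not part of this commit:

import json

# field order must match the `inputs` list in main.py's gr.Interface call
fields = ["text", "max_new_tokens", "top_p", "temperature",
          "repetition_penalty", "model_name", "is_chat"]

with open("examples.json", "r", encoding="utf-8") as f:
    examples = json.load(f)

for row in examples:
    print(dict(zip(fields, row)))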
main.py ADDED
@@ -0,0 +1,191 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+from collections import defaultdict
+import json
+import os
+import platform
+import re
+import string
+from typing import List
+
+from project_settings import project_path
+
+os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()  # must be set before transformers is imported
+
+import gradio as gr
+from threading import Thread
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+from transformers.models.bert.tokenization_bert import BertTokenizer
+from transformers.generation.streamers import TextIteratorStreamer
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--max_new_tokens", default=512, type=int)
+    parser.add_argument("--top_p", default=0.9, type=float)
+    parser.add_argument("--temperature", default=0.35, type=float)
+    parser.add_argument("--repetition_penalty", default=1.0, type=float)
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=str)
+
+    parser.add_argument(
+        "--examples_json_file",
+        default="examples.json",
+        type=str
+    )
+    args = parser.parse_args()
+    return args
+
+
+def repl1(match):
+    result = "{}{}".format(match.group(1), match.group(2))
+    return result
+
+
+def repl2(match):
+    result = "{}".format(match.group(1))
+    return result
+
+
+def remove_space_between_cn_en(text):  # drop tokenizer-inserted spaces between CJK chars, keep spaces between ASCII words
+    splits = re.split(" ", text)
+    if len(splits) < 2:
+        return text
+
+    result = ""
+    for t in splits:
+        if t == "":
+            continue
+        if re.search(f"[a-zA-Z0-9{string.punctuation}]$", result) and re.search("^[a-zA-Z0-9]", t):
+            result += " "
+            result += t
+        else:
+            if not result == "":
+                result += t
+            else:
+                result = t
+
+    if text.endswith(" "):
+        result += " "
+    return result
+
+
+def main():
+    args = get_args()
+
+    description = """
+## GPT2 Chat
+"""
+
+    # example json
+    with open(args.examples_json_file, "r", encoding="utf-8") as f:
+        examples = json.load(f)
+
+    if args.device == "auto":
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    else:
+        device = args.device
+
+    input_text_box = gr.Text(label="text")
+    output_text_box = gr.Text(lines=4, label="generated_content")
+
+    def fn_stream(text: str,
+                  max_new_tokens: int = 200,
+                  top_p: float = 0.85,
+                  temperature: float = 0.35,
+                  repetition_penalty: float = 1.2,
+                  model_name: str = "qgyd2021/lip_service_4chan",
+                  is_chat: bool = True,
+                  ):
+        tokenizer = BertTokenizer.from_pretrained(model_name)
+        model = GPT2LMHeadModel.from_pretrained(model_name)
+        model = model.to(device).eval()
+
+        text_encoded = tokenizer(text, add_special_tokens=False)
+        input_ids_ = text_encoded["input_ids"]
+
+        input_ids = [tokenizer.cls_token_id]
+        input_ids.extend(input_ids_)
+        if is_chat:
+            input_ids.append(tokenizer.sep_token_id)
+
+        input_ids = torch.tensor([input_ids], dtype=torch.long)
+        input_ids = input_ids.to(device)
+
+        streamer = TextIteratorStreamer(tokenizer=tokenizer)
+
+        generation_kwargs = dict(
+            inputs=input_ids,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            eos_token_id=tokenizer.sep_token_id if is_chat else None,
+            pad_token_id=tokenizer.pad_token_id,
+            streamer=streamer,
+        )
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)  # generate in the background; the streamer yields text as it is produced
+        thread.start()
+
+        output: str = ""
+        first_answer = True
+        for output_ in streamer:
+            if first_answer:  # the first chunk from the streamer echoes the prompt; skip it
+                first_answer = False
+                continue
+
+            output_ = output_.replace("[UNK] ", "")
+            output_ = output_.replace("[UNK]", "")
+            output_ = output_.replace("[CLS] ", "")
+            output_ = output_.replace("[CLS]", "")
+
+            output += output_
+            if output.startswith("[SEP]"):
+                output = output[5:]  # len("[SEP]") == 5
+
+            output = output.lstrip(" ,.!?")
+            output = remove_space_between_cn_en(output)
+            # output = re.sub(r"([,。!?\u4e00-\u9fa5]) ([,。!?\u4e00-\u9fa5])", repl1, output)
+            # output = re.sub(r"([,。!?\u4e00-\u9fa5]) ", repl2, output)
+
+            output = output.replace("[SEP] ", "\n")
+            output = output.replace("[SEP]", "\n")
+
+            yield output
+
+    model_name_choices = ["trained_models/lip_service_4chan", "trained_models/chinese_porn_novel"] \
+        if platform.system() == "Windows" else \
+        [
+            "qgyd2021/lip_service_4chan", "qgyd2021/chinese_chitchat",
+            "qgyd2021/chinese_porn_novel", "qgyd2021/few_shot_intent",
+            "qgyd2021/similar_question_generation"
+        ]
+
+    demo = gr.Interface(
+        fn=fn_stream,
+        inputs=[
+            input_text_box,
+            gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens"),
+            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
+            gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
+            gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
+            gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
+            gr.Checkbox(value=True, label="is_chat")
+        ],
+        outputs=[output_text_box],
+        examples=examples,
+        cache_examples=False,
+        examples_per_page=50,
+        title="GPT2 Chat",
+        description=description,
+    )
+    demo.queue().launch()
+
+    return
+
+
+if __name__ == '__main__':
+    main()
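For reference, the streaming pattern inside fn_stream can be exercised without Gradio. A minimal sketch, not part of this commit, assuming the qgyd2021/similar_question_generation checkpoint is reachable; the sampling parameters mirror the row in examples.json:

from threading import Thread

import torch
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.generation.streamers import TextIteratorStreamer

model_name = "qgyd2021/similar_question_generation"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name).eval()

# [CLS] question [SEP] -> the model continues with similar questions
input_ids = [tokenizer.cls_token_id]
input_ids.extend(tokenizer("如何开通微信支付?", add_special_tokens=False)["input_ids"])
input_ids.append(tokenizer.sep_token_id)

streamer = TextIteratorStreamer(tokenizer=tokenizer)
thread = Thread(target=model.generate, kwargs=dict(
    inputs=torch.tensor([input_ids], dtype=torch.long),
    max_new_tokens=128, do_sample=True, top_p=0.75, temperature=0.35,
    repetition_penalty=1.2, eos_token_id=tokenizer.sep_token_id,
    pad_token_id=tokenizer.pad_token_id, streamer=streamer,
))
thread.start()

for chunk in streamer:  # the first chunk echoes the prompt, as fn_stream's skip logic notes
    print(chunk, end="", flush=True)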
project_settings.py ADDED
@@ -0,0 +1,12 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import os
+from pathlib import Path
+
+
+project_path = os.path.abspath(os.path.dirname(__file__))
+project_path = Path(project_path)
+
+
+if __name__ == '__main__':
+    pass
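Note: main.py imports project_path before transformers so that HUGGINGFACE_HUB_CACHE can point at cache/huggingface/hub inside the repository; that cache/ directory is one of the paths the new .gitignore excludes.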
requirements.txt ADDED
@@ -0,0 +1,9 @@
+gradio==3.41.2
+pydantic==1.10.12
+thinc==7.4.6
+spacy==2.3.9
+transformers==4.30.2
+numpy==1.21.4
+tqdm==4.62.3
+torch==1.13.0
+datasets