jonathanlehner commited on
Commit
8c7c98a
·
1 Parent(s): 2ff1e50

added dialoggpt

Browse files
Files changed (9) hide show
  1. .gitignore +36 -0
  2. Pipfile +21 -0
  3. README 2.md +38 -0
  4. ai_single_response.py +278 -0
  5. app.py +196 -0
  6. config.json +34 -0
  7. file_test.py +3 -0
  8. requirements.txt +101 -0
  9. utils.py +282 -0
.gitignore ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python basics
2
+ /__pycache__/
3
+ /.idea/
4
+ /scratch/
5
+
6
+
7
+ # local model folders for testing / running bots / deploy
8
+
9
+ /gpt2_std_gpu_774M_120ksteps/
10
+ /gpt2_std_gpu_774M_60ksteps/
11
+ /gpt2_dailydialogue_355M_75Ksteps/
12
+ /gp2_DDandPeterTexts_14kPeter_774M/
13
+ /gp2_DDandPeterTexts_41kPeter-774M/
14
+ /gp2_DDandPeterTexts_774M_73Ksteps/
15
+ /gp2_DDandPeterTexts_gpu_774M_175Ksteps/
16
+ *checkpoint*
17
+ *GPT2*
18
+ *GPTneo*
19
+ *GPTpeter*
20
+ *1pt3B*
21
+
22
+ # most of ^ can be downloaded through `download_models.py`
23
+
24
+ # gradio things
25
+ *.db
26
+ *.db-journal
27
+ *gradio_queue*
28
+ gradio_data
29
+ deploy-as-bot/flagged
30
+ deploy-as-bot/gradio_data
31
+ deploy-as-bot/gradio_queue.db
32
+
33
+
34
+ # notebooks containing personal data
35
+ .DS_Store
36
+ aitextgen
Pipfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[source]]
2
+ url = "https://pypi.org/simple"
3
+ verify_ssl = true
4
+ name = "pypi"
5
+
6
+ [packages]
7
+ natsort = "==7.1.1"
8
+ pandas = "==1.3.0"
9
+ symspellpy = "==6.7.0"
10
+ requests = "==2.24.0"
11
+ transformers = "==4.8.2"
12
+ gradio = "==1.7.7"
13
+ tqdm = "==4.43.0"
14
+ aitextgen = "==0.5.2"
15
+ cleantext = "==1.1.3"
16
+ telegram = "==0.0.1"
17
+
18
+ [dev-packages]
19
+
20
+ [requires]
21
+ python_version = "3.8"
README 2.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Ai Msgbot Gpt2 M XL
3
+ emoji: 📉
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # Configuration
12
+
13
+ `title`: _string_
14
+ Display title for the Space
15
+
16
+ `emoji`: _string_
17
+ Space emoji (emoji-only character allowed)
18
+
19
+ `colorFrom`: _string_
20
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
+
22
+ `colorTo`: _string_
23
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
+
25
+ `sdk`: _string_
26
+ Can be either `gradio` or `streamlit`
27
+
28
+ `sdk_version` : _string_
29
+ Only applicable for `streamlit` SDK.
30
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
+
32
+ `app_file`: _string_
33
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
34
+ Path is relative to the root of the repository.
35
+
36
+ `pinned`: _boolean_
37
+ Whether the Space stays on top of your list.
38
+
ai_single_response.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ai_single_response.py
3
+
4
+ An executable way to call the model. example:
5
+ *\gpt2_chatbot> python .\ai_single_response.py --prompt "where is the grocery store?" --time
6
+
7
+ extended-summary:
8
+
9
+ A system and method for interacting with a virtual machine using a series of messages , each message having associated otherwise one or more actions to be taken by the machine. The speaker participates in a chat with a responder , and the response from the responder is returned.
10
+
11
+ """
12
+ import argparse
13
+ import pprint as pp
14
+ import time
15
+ import warnings
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+ from cleantext import clean
19
+
20
+ warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
21
+
22
+ from aitextgen import aitextgen
23
+
24
+
25
+ def query_gpt_model(
26
+ folder_path,
27
+ prompt_msg: str,
28
+ speaker=None,
29
+ responder="person beta",
30
+ kparam=150,
31
+ temp=0.75,
32
+ top_p=0.65,
33
+ verbose=False,
34
+ use_gpu=False,
35
+ ):
36
+ """
37
+ query_gpt_model [pass a prompt in to model, get a response. Does NOT "remember" past conversation]
38
+
39
+ Args:
40
+ folder_path ([type]): [description]
41
+ prompt_msg (str): [description]
42
+ speaker ([type], optional): [description]. Defaults to None.
43
+ responder (str, optional): [description]. Defaults to "person beta".
44
+ kparam (int, optional): [description]. Defaults to 125.
45
+ temp (float, optional): [description]. Defaults to 0.75.
46
+ top_p (float, optional): [description]. Defaults to 0.65.
47
+ verbose (bool, optional): [description]. Defaults to False.
48
+ use_gpu (bool, optional): [description]. Defaults to False.
49
+
50
+ Returns:
51
+ [dict]: [returns a dict with A) just model response as str B) total conversation]
52
+ """
53
+ ai = aitextgen(
54
+ model="microsoft/DialoGPT-medium",
55
+ #model_folder=folder_path,
56
+ to_gpu=False,
57
+ )
58
+ print("loaded model")
59
+ p_list = []
60
+ if "natqa" in str(folder_path).lower():
61
+ speaker = "person alpha" # manual correction
62
+ responder = "person beta"
63
+ if "wow" in str(folder_path).lower():
64
+ speaker = "person alpha" # manual correction
65
+ responder = "person beta"
66
+ if "peter" in str(folder_path).lower():
67
+ speaker = None # manual correction
68
+ responder = "peter szemraj"
69
+ if speaker is not None:
70
+ p_list.append(speaker.lower() + ":" + "\n") # write prompt as the speaker
71
+ p_list.append(prompt_msg.lower() + "\n")
72
+ p_list.append("\n")
73
+ p_list.append(responder.lower() + ":" + "\n")
74
+ this_prompt = "".join(p_list)
75
+ if verbose:
76
+ print("overall prompt:\n")
77
+ pp.pprint(this_prompt, indent=4)
78
+ print("\n... generating... \n")
79
+ this_result = ai.generate(
80
+ n=1,
81
+ top_k=kparam,
82
+ batch_size=512,
83
+ max_length=128,
84
+ min_length=16,
85
+ prompt=this_prompt,
86
+ temperature=temp,
87
+ top_p=top_p,
88
+ do_sample=True,
89
+ return_as_list=True,
90
+ use_cache=True,
91
+ )
92
+ if verbose:
93
+ pp.pprint(this_result) # to see what is going on
94
+ try:
95
+ this_result = str(this_result[0]).split("\n")
96
+ res_out = [clean(ele) for ele in this_result]
97
+ p_out = [clean(ele) for ele in p_list]
98
+ if verbose:
99
+ pp.pprint(res_out) # to see what is going on
100
+ pp.pprint(p_out) # to see what is going on
101
+
102
+ diff_list = []
103
+ name_counter = 0
104
+ break_safe = False
105
+ for resline in res_out:
106
+
107
+ if (responder + ":") in resline:
108
+ name_counter += 1
109
+ break_safe = True # next line a response from bot
110
+ continue
111
+ if ":" in resline and name_counter > 0:
112
+ if break_safe:
113
+ diff_list.append(resline)
114
+ break_safe = False
115
+ else:
116
+ break
117
+ if resline in p_out:
118
+ break_safe = False
119
+ continue
120
+
121
+ else:
122
+ diff_list.append(resline)
123
+ break_safe = False
124
+
125
+ if verbose:
126
+ print("------------------------diff list: ")
127
+ pp.pprint(diff_list) # to see what is going on
128
+ print("---------------------------------")
129
+
130
+ output = ", ".join(diff_list)
131
+
132
+ except:
133
+ output = "oops, there was an error. try again"
134
+
135
+ p_list.append(output + "\n")
136
+ p_list.append("\n")
137
+
138
+ model_responses = {"out_text": output, "full_conv": p_list}
139
+ print("finished!\n")
140
+
141
+ return model_responses
142
+
143
+
144
+ # Set up the parsing of command-line arguments
145
+ def get_parser():
146
+ """
147
+ get_parser [a helper function for the argparse module]
148
+
149
+ Returns:
150
+ [argparse.ArgumentParser]: [the argparser relevant for this script]
151
+ """
152
+
153
+ parser = argparse.ArgumentParser(
154
+ description="submit a message and have a 774M parameter GPT model respond"
155
+ )
156
+ parser.add_argument(
157
+ "--prompt",
158
+ required=True, # MUST HAVE A PROMPT
159
+ type=str,
160
+ help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
161
+ )
162
+ parser.add_argument(
163
+ "--model",
164
+ required=False,
165
+ type=str,
166
+ # "gp2_DDandPeterTexts_774M_73Ksteps", - from GPT-Peter
167
+ default="GPT2_trivNatQAdailydia_774M_175Ksteps",
168
+ help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
169
+ "config.json). No models? Run the script download_models.py",
170
+ )
171
+
172
+ parser.add_argument(
173
+ "--speaker",
174
+ required=False,
175
+ default=None,
176
+ help="Who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
177
+ )
178
+ parser.add_argument(
179
+ "--responder",
180
+ required=False,
181
+ default="person beta",
182
+ help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
183
+ )
184
+
185
+ parser.add_argument(
186
+ "--topk",
187
+ required=False,
188
+ type=int,
189
+ default=150,
190
+ help="how many responses to sample (positive integer). lower = more random responses",
191
+ )
192
+
193
+ parser.add_argument(
194
+ "--temp",
195
+ required=False,
196
+ type=float,
197
+ default=0.75,
198
+ help="specify temperature hyperparam (0-1). roughly considered as 'model creativity'",
199
+ )
200
+
201
+ parser.add_argument(
202
+ "--topp",
203
+ required=False,
204
+ type=float,
205
+ default=0.65,
206
+ help="nucleus sampling frac (0-1). aka: what fraction of possible options are considered?",
207
+ )
208
+
209
+ parser.add_argument(
210
+ "--verbose",
211
+ default=False,
212
+ action="store_true",
213
+ help="pass this argument if you want all the printouts",
214
+ )
215
+ parser.add_argument(
216
+ "--time",
217
+ default=False,
218
+ action="store_true",
219
+ help="pass this argument if you want to know runtime",
220
+ )
221
+ return parser
222
+
223
+
224
+ if __name__ == "__main__":
225
+ args = get_parser().parse_args()
226
+ query = args.prompt
227
+ model_dir = str(args.model)
228
+ model_loc = Path.cwd() / model_dir
229
+ spkr = args.speaker
230
+ rspndr = args.responder
231
+ k_results = args.topk
232
+ my_temp = args.temp
233
+ my_top_p = args.topp
234
+ want_verbose = args.verbose
235
+ want_rt = args.time
236
+
237
+ # force-update the speaker+responder params for the generic model case
238
+ if "dailydialogue" in model_dir.lower():
239
+ spkr = "john smith"
240
+ rspndr = "nancy sellers"
241
+ # ^ arbitrary people created when parsing Daily Dialogue dataset
242
+ # # force-update the speaker+responder params
243
+ # for the generic model case
244
+ if "natqa" in model_dir.lower():
245
+ spkr = "person alpha"
246
+ rspndr = "person beta"
247
+ # ^ arbitrary people created when parsing NatQA + TriviaQA + Daily Dialogue datasets
248
+
249
+ st = time.time()
250
+
251
+ resp = query_gpt_model(
252
+ folder_path=model_loc,
253
+ prompt_msg=query,
254
+ speaker=spkr,
255
+ responder=rspndr,
256
+ kparam=k_results,
257
+ temp=my_temp,
258
+ top_p=my_top_p,
259
+ verbose=want_verbose,
260
+ use_gpu=False,
261
+ )
262
+
263
+ output = resp["out_text"]
264
+ pp.pprint(output, indent=4)
265
+
266
+ # pp.pprint(this_result[3].strip(), indent=4)
267
+ rt = round(time.time() - st, 1)
268
+
269
+ if want_rt:
270
+ print("took {runtime} seconds to generate. \n".format(runtime=rt))
271
+
272
+ if want_verbose:
273
+ print("finished - ", datetime.now())
274
+ if want_verbose:
275
+ p_list = resp["full_conv"]
276
+ print("A transcript of your chat is as follows: \n")
277
+ p_list = [item.strip() for item in p_list]
278
+ pp.pprint(p_list)
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ deploy-as-bot\gradio_chatbot.py
4
+
5
+ A system, method for deploying to Gradio. Gradio is a basic "deploy" interface which allows for other users to test your model from a web URL. It also enables some basic functionality like user flagging for weird responses.
6
+ Note that the URL is displayed once the script is run.
7
+
8
+ Set the working directory to */deploy-as-bot in terminal before running.
9
+
10
+ """
11
+ import os
12
+ import sys
13
+ from os.path import dirname
14
+
15
+ sys.path.append(dirname(dirname(os.path.abspath(__file__))))
16
+
17
+ import gradio as gr
18
+ import logging
19
+ import argparse
20
+ import time
21
+ import warnings
22
+ from pathlib import Path
23
+ from cleantext import clean
24
+ from transformers import pipeline
25
+ from datetime import datetime
26
+ from ai_single_response import query_gpt_model
27
+ #from gradio.networking import get_state, set_state
28
+ from flask import Flask, request, session, jsonify, abort, send_file, render_template, redirect
29
+
30
+ import nltk
31
+ nltk.download('stopwords')
32
+
33
+ warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
34
+
35
+ logging.basicConfig()
36
+ cwd = Path.cwd()
37
+ my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
38
+
39
+
40
+ def gramformer_correct(corrector, qphrase: str):
41
+ """
42
+ gramformer_correct - correct a string using a text2textgen pipeline model from transformers
43
+
44
+ Args:
45
+ corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
46
+ qphrase (str): [text to be corrected]
47
+
48
+ Returns:
49
+ [str]: [corrected text]
50
+ """
51
+
52
+ try:
53
+ corrected = corrector(
54
+ clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
55
+ )
56
+ return corrected[0]["generated_text"]
57
+ except:
58
+ print("NOTE - failed to correct with gramformer")
59
+ return clean(qphrase)
60
+
61
+
62
+ def ask_gpt(message: str, sender: str = ""):
63
+ """
64
+ ask_gpt - queries the relevant model with a prompt message and (optional) speaker name
65
+
66
+ Args:
67
+ message (str): prompt message to respond to
68
+ sender (str, optional): speaker aka who said the message. Defaults to "".
69
+
70
+ Returns:
71
+ [str]: [model response as a string]
72
+ """
73
+ st = time.time()
74
+ prompt = clean(message) # clean user input
75
+ prompt = prompt.strip() # get rid of any extra whitespace
76
+ if len(prompt) > 200:
77
+ prompt = prompt[-200:] # truncate
78
+ sender = clean(sender.strip())
79
+ if len(sender) > 2:
80
+ try:
81
+ prompt_speaker = clean(sender)
82
+ except:
83
+ # there was some issue getting that info, whatever
84
+ prompt_speaker = None
85
+ else:
86
+ prompt_speaker = None
87
+
88
+ resp = query_gpt_model(
89
+ folder_path=model_loc,
90
+ prompt_msg=prompt,
91
+ speaker=prompt_speaker,
92
+ kparam=150,
93
+ temp=0.75,
94
+ top_p=0.65, # optimize this with hyperparam search
95
+ )
96
+ bot_resp = gramformer_correct(corrector, qphrase=resp["out_text"])
97
+ rt = round(time.time() - st, 2)
98
+ print(f"took {rt} sec to respond")
99
+
100
+ return bot_resp
101
+
102
+
103
+ def chat(first_and_last_name, message):
104
+ """
105
+ chat - helper function that makes the whole gradio thing work.
106
+
107
+ Args:
108
+ first_and_last_name (str or None): [speaker of the prompt, if provided]
109
+ message (str): [description]
110
+
111
+ Returns:
112
+ [str]: [returns an html string to display]
113
+ """
114
+ history = session.get("my_state") or []
115
+ response = ask_gpt(message, sender=first_and_last_name)
116
+ history.append((f"{first_and_last_name}: " + message, " GPT-Model: " + response)) #+ " [end] "))
117
+ session["my_state"] = history
118
+ session.modified = True
119
+ #html = "<div class='chatbot'>"
120
+ #for user_msg, resp_msg in history:
121
+ # html += f"<div class='user_msg'>{user_msg}</div>"
122
+ # html += f"<div class='resp_msg' style='color: black'>{resp_msg}</div>"
123
+ #html += "</div>"
124
+ return response
125
+
126
+
127
+ def get_parser():
128
+ """
129
+ get_parser - a helper function for the argparse module
130
+
131
+ Returns:
132
+ [argparse.ArgumentParser]: [the argparser relevant for this script]
133
+ """
134
+
135
+ parser = argparse.ArgumentParser(
136
+ description="submit a message and have a 774M parameter GPT model respond"
137
+ )
138
+ parser.add_argument(
139
+ "--model",
140
+ required=False,
141
+ type=str,
142
+ # "gp2_DDandPeterTexts_774M_73Ksteps", - from GPT-Peter
143
+ default="GPT2_trivNatQAdailydia_774M_175Ksteps",
144
+ help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
145
+ "config.json). No models? Run the script download_models.py",
146
+ )
147
+
148
+ parser.add_argument(
149
+ "--gram-model",
150
+ required=False,
151
+ type=str,
152
+ default="prithivida/grammar_error_correcter_v1",
153
+ help="text2text generation model ID from huggingface for the model to correct grammar",
154
+ )
155
+
156
+ return parser
157
+
158
+
159
+ if __name__ == "__main__":
160
+ args = get_parser().parse_args()
161
+ default_model = str(args.model)
162
+ model_loc = cwd.parent / default_model
163
+ model_loc = str(model_loc.resolve())
164
+ gram_model = args.gram_model
165
+ print(f"using model stored here: \n {model_loc} \n")
166
+ corrector = pipeline("text2text-generation", model=gram_model, device=-1)
167
+ print("Finished loading the gramformer model - ", datetime.now())
168
+ iface = gr.Interface(
169
+ chat,
170
+ inputs=["text", "text"],
171
+ outputs="html",
172
+ title="Real-Impact English Chat Demo 英语聊天演示",
173
+ description="A basic interface with a neural network model trained on general Q&A and conversation. Treat it like a friend! 带有模型的基本界面,进行了一般问答和对话训练。 请像朋友一样与他对话! \n first and last name 姓名 \n message 信息 \n Clear 清除 \nSubmit 确认 \n Screenshot 截屏",
174
+ article="**Important Notes & About: 重要说明 & 关于我们**\n"
175
+ "1. the model can take up to 200 seconds to respond sometimes, patience is a virtue. 该模型有时可能需要长达 60 秒的响应时间,请耐心等待。\n"
176
+ "2. entering a username is completely optional. 姓名输入是可选的。\n "
177
+ "3. the model was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement. 该模型在几个不同的数据集上训练而成,它所说的任何内容都应该经过事实核查,然后才能被视为真实陈述。\n ",
178
+ css="""
179
+ .chatbox {display:flex;flex-direction:column}
180
+ .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
181
+ .user_msg {background-color:cornflowerblue;color:white;align-self:start}
182
+ .resp_msg {background-color:lightgray;align-self:self-end}
183
+ """,
184
+ allow_screenshot=True,
185
+ allow_flagging=False,
186
+ flagging_dir="gradio_data",
187
+ flagging_options=[
188
+ "great response",
189
+ "doesn't make sense",
190
+ "bad/offensive response",
191
+ ],
192
+ enable_queue=True, # allows for dealing with multiple users simultaneously
193
+ #theme="darkhuggingface",
194
+ #server_name="0.0.0.0",
195
+ )
196
+ iface.launch(share=True)
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/content/drive/MyDrive/Programming/AI_peter/gpt2_dailydialogue_gpu_355M",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPT2LMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.1,
8
+ "bos_token_id": 50256,
9
+ "embd_pdrop": 0.1,
10
+ "eos_token_id": 50256,
11
+ "gradient_checkpointing": true,
12
+ "initializer_range": 0.02,
13
+ "layer_norm_epsilon": 1e-05,
14
+ "line_by_line": false,
15
+ "model_type": "gpt2",
16
+ "n_ctx": 1024,
17
+ "n_embd": 1024,
18
+ "n_head": 16,
19
+ "n_inner": null,
20
+ "n_layer": 24,
21
+ "n_positions": 1024,
22
+ "n_vocab": 50257,
23
+ "resid_pdrop": 0.1,
24
+ "scale_attn_weights": true,
25
+ "summary_activation": null,
26
+ "summary_first_dropout": 0.1,
27
+ "summary_proj_to_labels": true,
28
+ "summary_type": "cls_index",
29
+ "summary_use_proj": true,
30
+ "torch_dtype": "float32",
31
+ "transformers_version": "4.11.3",
32
+ "use_cache": false,
33
+ "vocab_size": 50257
34
+ }
file_test.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import os
2
+
3
+ print(os.path.exists("/Users/jonathan/ai-msgbot/gpt2_dailydialogue_355M_150Ksteps/pytorch_model.bin"))
requirements.txt ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ aitextgen==0.5.2
5
+ analytics-python==1.4.0
6
+ APScheduler==3.6.3
7
+ async-timeout==4.0.2
8
+ attrs==21.2.0
9
+ backoff==1.10.0
10
+ backports.zoneinfo==0.2.1
11
+ bcrypt==3.2.0
12
+ cachetools==4.2.2
13
+ certifi==2021.10.8
14
+ cffi==1.15.0
15
+ chardet==3.0.4
16
+ charset-normalizer==2.0.9
17
+ cleantext==1.1.3
18
+ click==8.0.3
19
+ cryptography==36.0.1
20
+ cycler==0.11.0
21
+ editdistpy==0.1.3
22
+ ffmpy==0.3.0
23
+ filelock==3.4.2
24
+ fire==0.4.0
25
+ Flask==2.0.2
26
+ Flask-CacheBuster==1.0.0
27
+ Flask-Cors==3.0.10
28
+ Flask-Login==0.5.0
29
+ fonttools==4.28.5
30
+ frozenlist==1.2.0
31
+ fsspec==2021.11.1
32
+ future==0.18.2
33
+ google-auth==2.3.3
34
+ google-auth-oauthlib==0.4.6
35
+ gradio==2.4.6
36
+ grpcio==1.43.0
37
+ huggingface-hub==0.2.1
38
+ idna==2.10
39
+ importlib-metadata==4.10.0
40
+ itsdangerous==2.0.1
41
+ Jinja2==3.0.3
42
+ joblib==1.1.0
43
+ kiwisolver==1.3.2
44
+ Markdown==3.3.6
45
+ markdown2==2.4.2
46
+ MarkupSafe==2.0.1
47
+ matplotlib==3.5.1
48
+ monotonic==1.6
49
+ multidict==5.2.0
50
+ natsort==7.1.1
51
+ nltk==3.6.6
52
+ numpy==1.21.5
53
+ oauthlib==3.1.1
54
+ openwa==1.3.16
55
+ packaging==21.3
56
+ pandas==1.3.5
57
+ paramiko==2.9.1
58
+ Pillow==8.4.0
59
+ protobuf==3.19.1
60
+ pyasn1==0.4.8
61
+ pyasn1-modules==0.2.8
62
+ pycparser==2.21
63
+ pycryptodome==3.12.0
64
+ pyDeprecate==0.3.1
65
+ pydub==0.25.1
66
+ PyNaCl==1.4.0
67
+ pyparsing==3.0.6
68
+ python-axolotl==0.2.3
69
+ python-axolotl-curve25519==0.4.1.post2
70
+ python-dateutil==2.8.2
71
+ python-telegram-bot==13.8.1
72
+ pytorch-lightning==1.5.7
73
+ pytz==2021.3
74
+ pytz-deprecation-shim==0.1.0.post0
75
+ PyYAML==6.0
76
+ regex==2021.11.10
77
+ requests==2.24.0
78
+ requests-oauthlib==1.3.0
79
+ rsa==4.8
80
+ sacremoses==0.0.46
81
+ selenium==3.141.0
82
+ six==1.16.0
83
+ symspellpy==6.7.6
84
+ tensorboard==2.7.0
85
+ tensorboard-data-server==0.6.1
86
+ tensorboard-plugin-wit==1.8.0
87
+ termcolor==1.1.0
88
+ tokenizers==0.10.3
89
+ torch==1.10.1
90
+ torchmetrics==0.6.2
91
+ tornado==6.1
92
+ tqdm==4.43.0
93
+ transformers==4.12.5
94
+ typing_extensions==4.0.1
95
+ tzdata==2021.5
96
+ tzlocal==4.1
97
+ urllib3==1.25.11
98
+ webwhatsapi==2.0.5
99
+ Werkzeug==2.0.2
100
+ yarl==1.7.2
101
+ zipp==3.6.0
utils.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ general utility functions for loading, saving, etc
3
+ """
4
+ import os
5
+ from pathlib import Path
6
+ import pprint as pp
7
+ import re
8
+ import shutil # zipfile formats
9
+ from datetime import datetime
10
+ from os.path import basename
11
+ from os.path import getsize, join
12
+
13
+ import requests
14
+ from cleantext import clean
15
+ from natsort import natsorted
16
+ from symspellpy import SymSpell
17
+ import pandas as pd
18
+ from tqdm.auto import tqdm
19
+
20
+
21
+ def get_timestamp():
22
+ return datetime.now().strftime("%b-%d-%Y_t-%H")
23
+
24
+
25
+ def correct_phrase_load(my_string: str):
26
+ """
27
+ correct_phrase_load [basic / unoptimized implementation of SymSpell to correct a string]
28
+
29
+ Args:
30
+ my_string (str): [text to be corrected]
31
+
32
+ Returns:
33
+ [type]: [description]
34
+ """
35
+ sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
36
+
37
+ dictionary_path = (
38
+ r"symspell_rsc/frequency_dictionary_en_82_765.txt" # from repo root
39
+ )
40
+ bigram_path = (
41
+ r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt" # from repo root
42
+ )
43
+ # term_index is the column of the term and count_index is the
44
+ # column of the term frequency
45
+ sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)
46
+ sym_spell.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
47
+
48
+ # max edit distance per lookup (per single word, not per whole input string)
49
+ suggestions = sym_spell.lookup_compound(
50
+ clean(my_string), max_edit_distance=2, ignore_non_words=True
51
+ )
52
+ if len(suggestions) < 1:
53
+ return my_string
54
+ else:
55
+ first_result = suggestions[0]
56
+ return first_result._term
57
+
58
+
59
+ def fast_scandir(dirname: str):
60
+ """
61
+ fast_scandir [an os.path-based means to return all subfolders in a given filepath]
62
+
63
+ Args:
64
+ dirname (str): [description]
65
+
66
+ Returns:
67
+ [list]: [description]
68
+ """
69
+
70
+ subfolders = [f.path for f in os.scandir(dirname) if f.is_dir()]
71
+ for dirname in list(subfolders):
72
+ subfolders.extend(fast_scandir(dirname))
73
+ return subfolders # list
74
+
75
+
76
+ def create_folder(directory: str):
77
+
78
+ os.makedirs(directory, exist_ok=True)
79
+
80
+
81
+ def chunks(lst: list, n: int):
82
+ """
83
+ chunks - Yield successive n-sized chunks from lst
84
+ Args:
85
+ lst (list): [description]
86
+ n (int): [description]
87
+
88
+ Yields:
89
+ [type]: [description]
90
+ """
91
+
92
+ for i in range(0, len(lst), n):
93
+ yield lst[i : i + n]
94
+
95
+
96
+ def chunky_pandas(my_df, num_chunks: int = 4):
97
+ """
98
+ chunky_pandas [split dataframe into `num_chunks` equal chunks, return each inside a list]
99
+
100
+ Args:
101
+ my_df (pd.DataFrame): [description]
102
+ num_chunks (int, optional): [description]. Defaults to 4.
103
+
104
+ Returns:
105
+ [type]: [description]
106
+ """
107
+ n = int(len(my_df) // num_chunks)
108
+ list_df = [my_df[i : i + n] for i in range(0, my_df.shape[0], n)]
109
+
110
+ return list_df
111
+
112
+
113
+ def load_dir_files(
114
+ directory: str, req_extension=".txt", return_type="list", verbose=False
115
+ ):
116
+ """
117
+ load_dir_files - an os.path based method of returning all files with extension `req_extension` in a given directory and subdirectories
118
+
119
+ Args:
120
+ directory (str): [description]
121
+ req_extension (str, optional): [description]. Defaults to ".txt".
122
+ return_type (str, optional): [description]. Defaults to "list".
123
+ verbose (bool, optional): [description]. Defaults to False.
124
+
125
+ Returns:
126
+ [type]: [description]
127
+ """
128
+ appr_files = []
129
+ # r=root, d=directories, f = files
130
+ for r, d, f in os.walk(directory):
131
+ for prefile in f:
132
+ if prefile.endswith(req_extension):
133
+ fullpath = os.path.join(r, prefile)
134
+ appr_files.append(fullpath)
135
+
136
+ appr_files = natsorted(appr_files)
137
+
138
+ if verbose:
139
+ print("A list of files in the {} directory are: \n".format(directory))
140
+ if len(appr_files) < 10:
141
+ pp.pprint(appr_files)
142
+ else:
143
+ pp.pprint(appr_files[:10])
144
+ print("\n and more. There are a total of {} files".format(len(appr_files)))
145
+
146
+ if return_type.lower() == "list":
147
+ return appr_files
148
+ else:
149
+ if verbose:
150
+ print("returning dictionary")
151
+
152
+ appr_file_dict = {}
153
+ for this_file in appr_files:
154
+ appr_file_dict[basename(this_file)] = this_file
155
+
156
+ return appr_file_dict
157
+
158
+
159
+ def URL_string_filter(text):
160
+ """
161
+ URL_string_filter - filter out nonstandard "text" characters
162
+
163
+ Args:
164
+ text ([type]): [description]
165
+
166
+ Returns:
167
+ [str]: [description]
168
+ """
169
+ custom_printable = (
170
+ "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ._"
171
+ )
172
+
173
+ filtered = "".join((filter(lambda i: i in custom_printable, text)))
174
+
175
+ return filtered
176
+
177
+
178
+ def getFilename_fromCd(cd):
179
+ if not cd:
180
+ return None
181
+ fname = re.findall("filename=(.+)", cd)
182
+ if len(fname) > 0:
183
+ output = fname[0]
184
+ elif cd.find("/"):
185
+ possible_fname = cd.rsplit("/", 1)[1]
186
+ output = URL_string_filter(possible_fname)
187
+ else:
188
+ output = None
189
+ return output
190
+
191
+
192
+ def get_zip_URL(
193
+ URLtoget: str,
194
+ extract_loc: str = None,
195
+ file_header: str = "dropboxexport_",
196
+ verbose: bool = False,
197
+ ):
198
+ """
199
+ get_zip_URL [summary]
200
+
201
+ Args:
202
+ URLtoget (str): [description]
203
+ extract_loc (str, optional): [description]. Defaults to None.
204
+ file_header (str, optional): [description]. Defaults to "dropboxexport_".
205
+ verbose (bool, optional): [description]. Defaults to False.
206
+
207
+ Returns:
208
+ [type]: [description]
209
+ """
210
+ r = requests.get(URLtoget, allow_redirects=True)
211
+ names = getFilename_fromCd(r.headers.get("content-disposition"))
212
+ fixed_fnames = names.split(";") # split the multiple results
213
+ this_filename = file_header + URL_string_filter(fixed_fnames[0])
214
+
215
+ # define paths and save the zip file
216
+ if extract_loc is None:
217
+ extract_loc = "dropbox_dl"
218
+ dl_place = join(os.getcwd(), extract_loc)
219
+ create_folder(dl_place)
220
+ save_loc = join(os.getcwd(), this_filename)
221
+ open(save_loc, "wb").write(r.content)
222
+ if verbose:
223
+ print("downloaded file size was {} MB".format(getsize(save_loc) / 1000000))
224
+
225
+ # unpack the archive
226
+ shutil.unpack_archive(save_loc, extract_dir=dl_place)
227
+ if verbose:
228
+ print("extracted zip file - ", datetime.now())
229
+ x = load_dir_files(dl_place, req_extension="", verbose=verbose)
230
+
231
+ # remove original
232
+ try:
233
+ os.remove(save_loc)
234
+ del save_loc
235
+ except:
236
+ print("unable to delete original zipfile - check if exists", datetime.now())
237
+
238
+ print("finished extracting zip - ", datetime.now())
239
+
240
+ return dl_place
241
+
242
+
243
+ def merge_dataframes(data_dir: str, ext=".xlsx", verbose=False):
244
+ """
245
+ merge_dataframes - given a filepath, loads and attempts to merge all files as dataframes
246
+
247
+ Args:
248
+ data_dir (str): [root directory to search in]
249
+ ext (str, optional): [anticipate file extension for the dataframes ]. Defaults to '.xlsx'.
250
+
251
+ Returns:
252
+ pd.DataFrame(): merged dataframe
253
+ """
254
+
255
+ src = Path(data_dir)
256
+ src_str = str(src.resolve())
257
+ mrg_df = pd.DataFrame()
258
+
259
+ all_reports = load_dir_files(directory=src_str, req_extension=ext, verbose=verbose)
260
+
261
+ failed = []
262
+
263
+ for df_path in tqdm(all_reports, total=len(all_reports), desc="joining data..."):
264
+
265
+ try:
266
+ this_df = pd.read_excel(df_path).convert_dtypes()
267
+
268
+ mrg_df = pd.concat([mrg_df, this_df], axis=0)
269
+ except:
270
+ short_p = os.path.basename(df_path)
271
+ print(
272
+ f"WARNING - file with extension {ext} and name {short_p} could not be read."
273
+ )
274
+ failed.append(short_p)
275
+
276
+ if len(failed) > 0:
277
+ print("failed to merge {} files, investigate as needed")
278
+
279
+ if verbose:
280
+ pp.pprint(mrg_df.info(True))
281
+
282
+ return mrg_df