lcw99 committed
Commit 57fd2d5 · 1 Parent(s): f3af4a9

Update README.md

Files changed (1)
  1. README.md +78 -27
README.md CHANGED
@@ -11,33 +11,84 @@ probably proofread and complete it, then remove this comment. -->

# t5-base-korean-chit-chat

- This model was trained from scratch on an unknown dataset.
- It achieves the following results on the evaluation set:
-
-
- ## Model description
-
- More information needed
-
- ## Intended uses & limitations
-
- More information needed
-
- ## Training and evaluation data
-
- More information needed
-
- ## Training procedure
-
- ### Training hyperparameters
-
- The following hyperparameters were used during training:
- - optimizer: None
- - training_precision: float32
-
- ### Training results
-
-
+ This model is a fine-tuned version of paust/pko-t5-large on the AIHUB "한국어 SNS" (Korean social media conversations) dataset. Given the conversation so far, it infers the next utterance in the dialogue.
+
+ ## Usage
+ ```python
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ import nltk
+ nltk.download('punkt')
+
+ model_dir = "lcw99/t5-base-korean-chit-chat"
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+
+ max_input_length = 1024
+
+ # Example dialogue: "A: Shall we go shopping? B: Sure. A: When should we go? B:"
+ text = """
+ A: 쇼핑하러 갈까? B: 응 좋아. A: 언제 갈까? B:
+ """
+
+ # Sample three candidate continuations for B's next utterance.
+ inputs = tokenizer([text], max_length=max_input_length, truncation=True, return_tensors="pt")
+ output = model.generate(**inputs, num_beams=3, do_sample=True, min_length=20, max_length=500, num_return_sequences=3)
+ for i in range(3):
+     print("---", i)
+     decoded_output = tokenizer.decode(output[i], skip_special_tokens=True)
+     predicted_title = nltk.sent_tokenize(decoded_output)
+     print(predicted_title)
+
+ # Interactive chat loop: keep only the last five utterances as context.
+ chat_history = []
+ for step in range(100):
+     print("")
+     user_input = input(">> User: ")
+     chat_history.append("A: " + user_input)
+     while len(chat_history) > 5:
+         chat_history.pop(0)
+
+     # Serialize the history as alternating "A:"/"B:" turns and prompt for B's reply.
+     hist = ""
+     for chat in chat_history:
+         hist += "\n" + chat
+     hist += "\nB: "
+
+     # Encode the serialized history and return a PyTorch tensor.
+     bot_input_ids = tokenizer.encode(hist, return_tensors='pt')
+
+     # Generate a sampled response of up to 200 tokens.
+     chat_history_ids = model.generate(
+         bot_input_ids, max_length=200,
+         pad_token_id=tokenizer.eos_token_id,
+         do_sample=True,
+         # Optional sampling knobs from the original code:
+         # no_repeat_ngram_size=3, top_k=100, top_p=0.7, temperature=0.1
+     )
+
+     # "#@이름#" is the name-anonymization token used in the AIHUB corpus.
+     bot_text = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True).replace("#@이름#", "OOO")
+     bot_text = bot_text.replace("\n", " / ")
+     chat_history.append("B: " + bot_text)
+
+     # Pretty-print the bot's latest reply.
+     print("Bot: {}".format(bot_text))
+ ```
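+
+ For single-turn use, the sketch below distills the usage code into one helper. The `next_utterance` function name, the device handling, and the generation settings are illustrative assumptions, not part of the original card.
+
+ ```python
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ model_dir = "lcw99/t5-base-korean-chit-chat"
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+
+ # Assumption: run on GPU when available; the original card does not specify a device.
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ def next_utterance(history: str) -> str:
+     """Return the model's guess at the next utterance for an "A: ... B: ..." history."""
+     inputs = tokenizer(history, max_length=1024, truncation=True, return_tensors="pt").to(device)
+     output = model.generate(**inputs, do_sample=True, max_length=200)
+     return tokenizer.decode(output[0], skip_special_tokens=True)
+
+ # "A: Shall we go shopping? B: Sure. A: When should we go? B:"
+ print(next_utterance("A: 쇼핑하러 갈까? B: 응 좋아. A: 언제 갈까? B: "))
+ ```
+
+ Sampling (`do_sample=True`) follows the original usage code; swap in beam search for more deterministic replies.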

### Framework versions