---
language:
- ko
tags:
- generated_from_keras_callback
model-index:
- name: t5-base-korean-chit-chat
  results: []
---

# t5-base-korean-chit-chat

This model is a fine-tuned version of paust/pko-t5-base, trained on the AIHUB "한국어 SNS" (Korean SNS) dataset. Given a social-media-style conversation, it generates the next turn of the dialogue.

## Usage

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import nltk

# Sentence-splitting data for nltk.sent_tokenize
# (newer NLTK releases may need 'punkt_tab' instead of 'punkt').
nltk.download('punkt')

model_dir = "lcw99/t5-base-korean-chit-chat"
max_input_length = 1024

# Speakers are marked "A:" and "B:"; the trailing "B:" asks the model for B's reply.
text = """
A: 쇼핑하러 갈까? B: 응 좋아. A: 언제 갈까? B:
"""

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

# One-shot generation: produce three candidate replies for the prompt above.
inputs = tokenizer([text], max_length=max_input_length, truncation=True, return_tensors="pt")
output = model.generate(
    **inputs,
    num_beams=3,
    do_sample=True,
    min_length=20,
    max_length=500,
    num_return_sequences=3,
)
for i in range(3):
    print("---", i)
    decoded_output = tokenizer.decode(output[i], skip_special_tokens=True)
    sentences = nltk.sent_tokenize(decoded_output)
    print(sentences)

# Interactive chat: up to 100 turns, keeping only the last 5 utterances as context.
chat_history = []
for step in range(100):
    print("")
    user_input = input(">> User: ")
    chat_history.append("A: " + user_input)
    while len(chat_history) > 5:
        chat_history.pop(0)

    # Build the prompt: one utterance per line, ending with an open "B: " cue.
    hist = ""
    for chat in chat_history:
        hist += "\n" + chat
    hist += "\nB: "
    bot_input_ids = tokenizer.encode(hist, return_tensors="pt")

    # Sample a reply of at most 200 tokens.
    reply_ids = model.generate(
        bot_input_ids,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        # Optional sampling controls:
        # top_k=100,
        # top_p=0.7,
        # temperature=0.1,
    )

    # Mask the anonymized-name placeholder and flatten multi-line replies.
    bot_text = tokenizer.decode(reply_ids[0], skip_special_tokens=True).replace("#@이름#", "OOO")
    bot_text = bot_text.replace("\n", " / ")
    chat_history.append("B: " + bot_text)

    # Print the bot's reply.
    print("Bot: {}".format(bot_text))
```
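
The chat loop above builds its prompt by joining the most recent turns as `A: ...` / `B: ...` lines and ending with an open `B: ` cue, then masks a name placeholder in the reply. As a minimal sketch, those two steps can be pulled into small helpers that run without loading the model. The names `build_prompt`, `clean_reply`, and `NAME_PLACEHOLDER` are hypothetical, introduced here for illustration; they are not part of the model or of Transformers.

```python
# Hypothetical helpers mirroring the prompt format used in the chat loop.
NAME_PLACEHOLDER = "#@이름#"  # anonymized-name token replaced in the usage example


def build_prompt(history, max_turns=5):
    """Join the last `max_turns` utterances, one per line, and cue speaker B."""
    recent = list(history)[-max_turns:] if max_turns > 0 else []
    return "".join("\n" + turn for turn in recent) + "\nB: "


def clean_reply(text, mask="OOO"):
    """Mask the name placeholder and flatten newlines into ' / ' separators."""
    return text.replace(NAME_PLACEHOLDER, mask).replace("\n", " / ")


history = ["A: 쇼핑하러 갈까?", "B: 응 좋아.", "A: 언제 갈까?"]
prompt = build_prompt(history)
print(repr(prompt))
```

The returned string could then be passed to `tokenizer.encode(prompt, return_tensors="pt")` in place of the inline `hist` construction. Keeping the window size as a parameter makes it easy to experiment with longer or shorter context.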

### Framework versions

- Transformers 4.22.1
- TensorFlow 2.10.0
- Datasets 2.5.1
- Tokenizers 0.12.1