File size: 4,278 Bytes
0ab2524
4c1cdff
 
0ab2524
4c1cdff
 
 
 
 
 
0ab2524
4c1cdff
 
 
 
 
 
e42fd12
4c1cdff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e42fd12
4c1cdff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
---
language: 
  - zh
license: apache-2.0
tags:
  - chinese poem
  - 中文
  - 对联
widget:
- text: "对联:北国风光,千里冰封,万里雪飘"
---

# 一个好玩的中文AI对联模型
- 输入格式
  - `对联:您的上联`,比如 `对联:北国风光,千里冰封,万里雪飘`
- 如果你想尝试
  - 如果自己有GPU环境,可以参考我放在huggingface的[示例代码](https://huggingface.co/hululuzhu/chinese-couplet-t5-mengzi-finetune#%E8%BF%90%E8%A1%8C%E4%BB%A3%E7%A0%81%E7%A4%BA%E4%BE%8B)
  - 或者使用Google colab可以用这个[简单的colab notebook](https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/inference/2022_simple_couplet_inference_huggingface.ipynb)
- 训练代码请参考[我的github链接](https://github.com/hululuzhu/chinese-ai-writing-share)
- 如果想了解一些背景和讨论,可以看我的[slides](https://github.com/hululuzhu/chinese-ai-writing-share/tree/main/slides)

## 架构
- 预训练使用 [澜舟科技的孟子 T5](https://huggingface.co/Langboat/mengzi-t5-base)

## 数据来源
- 对联数据集 https://github.com/wb14123/couplet-dataset
  - 标准输入输出seq2seq,T5使用`对联:`前缀,长度限制32字符

## 语言支持
- 默认简体中文
- 支持繁体中文,参考下面代码标记 `is_input_traditional_chinese=True`

## 训练
- 我是用 Google Colab Pro(推荐,16G的GPU一个月才9.99!)

## 运行代码示例
```python
# 安装以下2个包方便文字处理和模型生成
# !pip install -q simplet5
# !pip install -q chinese-converter

# 具体代码
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration
import chinese_converter

MODEL_PATH = "hululuzhu/chinese-couplet-t5-mengzi-finetune"
class PoemModel(SimpleT5):
  def __init__(self) -> None:
    super().__init__()
    self.device = torch.device("cuda")

  def load_my_model(self):
    self.tokenizer = T5Tokenizer.from_pretrained(MODEL_PATH)
    self.model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)

COUPLET_PROMPOT = '对联:'
MAX_SEQ_LEN = 32
MAX_OUT_TOKENS = MAX_SEQ_LEN

def couplet(in_str, model=couplet_model,
            is_input_traditional_chinese=False,
            num_beams=2):
  model.model = model.model.to('cuda')
  in_request = f"{COUPLET_PROMPOT}{in_str[:MAX_SEQ_LEN]}"
  if is_input_traditional_chinese:
    # model only knows s chinese
    in_request = chinese_converter.to_simplified(in_request)
  # Note default sampling is turned off for consistent result
  out = model.predict(in_request,
                      max_length=MAX_OUT_TOKENS,
                      num_beams=num_beams)[0].replace(",", ",")
  if is_input_traditional_chinese:
    out = chinese_converter.to_traditional(out)
  print(f"上: {in_str}\n下: {out}")
```


## 简体中文示例
```python
for pre in ['欢天喜地度佳节',
            '不待鸣钟已汗颜,重来试手竟何艰',
            '当年欲跃龙门去,今日真披马革还',
            '北国风光,千里冰封,万里雪飘',
            '寂寞寒窗空守寡',
            '烟锁池塘柳',
            '五科五状元,金木水火土',
            '望江楼,望江流,望江楼上望江流,江楼千古,江流千古']:
  couplet(pre)

上: 欢天喜地度佳节
下: 笑语欢歌迎新春
上: 不待鸣钟已汗颜,重来试手竟何艰
下: 何堪击鼓频催泪?一别伤心更枉然
上: 当年欲跃龙门去,今日真披马革还
下: 此日当登虎榜来,他年又见龙图新
上: 北国风光,千里冰封,万里雪飘
下: 南疆气象,五湖浪涌,三江潮来
上: 寂寞寒窗空守寡
下: 逍遥野渡醉吟诗
上: 烟锁池塘柳
下: 云封岭上松
上: 五科五状元,金木水火土
下: 三才三进士,诗书礼乐诗
上: 望江楼,望江流,望江楼上望江流,江楼千古,江流千古
下: 听雨阁,听雨落,听雨阁中听雨落,雨阁万重,雨落万重
```

# 繁体中文
```python
for pre in ['飛龍在天', '臺北風光好']:
  couplet(pre, is_input_traditional_chinese=True, num_beams=10)

上: 飛龍在天
下: 飛鳳於天
上: 臺北風光好
下: 神州氣象新
```