File size: 5,597 Bytes
8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
from fengshen.pipelines.tcbert import TCBertPipelines
from pytorch_lightning import seed_everything

def main():
    seed_everything(123)
    total_parser = argparse.ArgumentParser("Topic Classification")
    total_parser = TCBertPipelines.piplines_args(total_parser)
    args = total_parser.parse_args()

    pretrained_model_path = 'IDEA-CCNL/Erlangshen-TCBert-110M-Classification-Chinese'
    args.learning_rate = 2e-5
    args.max_length = 512
    args.max_epochs = 5
    args.batchsize = 4
    args.train = 'train'
    args.default_root_dir = './'
    # args.gpus = 1   #注意:目前使用CPU进行训练,取消注释会使用GPU,但需要配置相应GPU环境版本
    args.fixed_lablen = 2 #注意:可以设置固定标签长度,由于样本对应的标签长度可能不一致,建议选择适中的数值表示标签长度

    train_data = [    # 训练数据
        {"content": "真正的放养教育,放的是孩子的思维,养的是孩子的习惯", "label": "故事"},
        {"content": "《唐人街探案》捧红了王宝强跟刘昊然,唯独戏份不少的他发展最差", "label": "娱乐"},
        {"content": "油价攀升 阿曼经济加速增长", "label": "财经"},
        {"content": "日本男篮近期动作频频,中国队的未来劲敌会是他们吗?", "label": "体育"},
        {"content": "教育部:坚决防止因撤并乡村小规模学校导致学生上学困难", "label": "教育"},
        {"content": "LOL设计最完美的三个英雄,玩家们都很认可!", "label": "电竞"},
        {"content": "上联:浅看红楼终是梦,怎么对下联?", "label": "文化"},
        {"content": "楼市再出新政!北京部分限房价项目或转为共有产权房", "label": "房产"},
        {"content": "企业怎样选云服务器?云服务器哪家比较好?", "label": "科技"},
        {"content": "贝纳利的三缸车TRE899K、TRE1130K华丽转身", "label": "汽车"},
        {"content": "如何评价:刘姝威的《严惩做空中国股市者》?", "label": "股票"},
        {"content": "宁夏邀深圳市民共赴“寻找穿越”之旅", "label": "旅游"},
        {"content": "日本自民党又一派系力挺安倍 称会竭尽全力", "label": "国际"},
        {"content": "农村养老保险每年交5000,交满15年退休后能每月领多少钱?", "label": "农业"},
        {"content": "国产舰载机首次现身,进度超过预期,将率先在滑跃航母测试", "label": "军事"}
    ]

    dev_data = [     # 验证数据
        {"content": "西游记后传中,灵儿最爱的女人是谁?不是碧游!", "label": "故事"},
        {"content": "小李子莱奥纳多有特别的提袋子技能,这些年他还有过哪些神奇的造型?", "label": "娱乐"},
        {"content": "现在手上有钱是投资买房还是存钱,为什么?", "label": "财经"},
        {"content": "迪卡侬的衣服值得购买吗?", "label": "体育"},
        {"content": "黑龙江省旅游委在齐齐哈尔组织举办导游培训班", "label": "教育"},
        {"content": "《王者荣耀》中,哪些英雄的大招最“废柴”?", "label": "电竞"},
        {"content": "上交演绎马勒《复活》,用音乐带来抚慰和希望", "label": "文化"},
        {"content": "All in服务业,58集团在租房、住房市场的全力以赋", "label": "房产"},
        {"content": "为什么有的人宁愿选择骁龙660的X21,也不买骁龙845的小米MIX2S?", "label": "科技"},
        {"content": "众泰大型SUV来袭,售13.98万,2.0T榨出231马力,汉兰达要危险了", "label": "汽车"},
        {"content": "股票放量下趺,大资金出逃谁在接盘?", "label": "股票"},
        {"content": "广西博白最大的特色是什么?", "label": "旅游"},
        {"content": "特朗普退出《伊朗核协议》,对此你怎么看?", "label": "国际"},
        {"content": "卖水果利润怎么样?", "label": "农业"},
        {"content": "特种兵都是身材高大的猛男么?别再被电视骗了,超过1米8都不合格", "label": "军事"}
    ]

    test_data = [    # 测试数据
        {"content": "廖凡重出“江湖”再争影帝 亮相戛纳红毯霸气有型"},
        {"content": "《绝地求生: 刺激战场》越玩越卡?竟是手机厂商没交“保护费”!"},
        {"content": "买涡轮增压还是自然吸气车?今天终于有答案了!"},
    ]

    #标签映射  将真实标签可以映射为更合适prompt的标签 
    prompt_label = {  
                    "体育":"体育", "军事":"军事", "农业":"农业",  "国际":"国际", 
                    "娱乐":"娱乐", "房产":"房产", "故事":"故事",  "教育":"教育",
                    "文化":"文化", "旅游":"旅游", "汽车":"汽车",  "电竞":"电竞", 
                    "科技":"科技", "股票":"股票", "财经":"财经"
                    }
    
    #不同的prompt会影响模型效果
    #prompt = "这一句描述{}的内容如下:"
    prompt = "下面是一则关于{}的新闻:"
                    
    model = TCBertPipelines(args, model_path=pretrained_model_path, nlabels=len(prompt_label))

    if args.train:
        model.train(train_data, dev_data, prompt, prompt_label)
    result = model.predict(test_data, prompt, prompt_label)

    for i, line in enumerate(result):
        print({"content":test_data[i]["content"], "label":list(prompt_label.keys())[line]})


if __name__ == "__main__":
    main()