File size: 5,894 Bytes
5896126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ebd785
5896126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f156762
5896126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#! /bin/sh
# Fetch the training treebank and the Japanese core-kanji list unless they are
# already present.  --fail (-f) stops curl from saving an HTTP error page under
# the expected name (which "test -f" would then accept on every later run); on
# any download failure the partial file is removed and the script stops.
test -f ja_gsd_modern.conllu || curl -fLO https://github.com/KoichiYasuoka/SuPar-UniDic/raw/main/suparunidic/suparmodels/ja_gsd_modern.conllu || { rm -f ja_gsd_modern.conllu; exit 1; }
test -f JapaneseCoreKanji.txt || curl -fLO https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt || { rm -f JapaneseCoreKanji.txt; exit 1; }

# Stage 1: build exSwallow-7b-plus-hf, a copy of Swallow-7b-plus-hf whose
# vocabulary gains one new token per CJK character that the original tokenizer
# splits into several pieces.  Skipped when the directory already exists.
if [ ! -d exSwallow-7b-plus-hf ]
then TMPA=./maker$$a.py
     cat << 'EOF' > "$TMPA"
#! /usr/bin/python3
src="tokyotech-llm/Swallow-7b-plus-hf"
tgt="exSwallow-7b-plus-hf"
import json,torch,unicodedata
from transformers import LlamaTokenizerFast,LlamaForCausalLM
# Candidate characters: the core-kanji list (one hex code point per line;
# lines starting with "#" and blank lines are skipped) ...
with open("JapaneseCoreKanji.txt","r",encoding="utf-8") as r:
  cjk=[chr(int(t,16)) for t in (u.strip() for u in r) if t and not t.startswith("#")]
# ... plus every CJK character occurring in a FORM column of the treebank.
with open("ja_gsd_modern.conllu","r",encoding="utf-8") as r:
  for s in r:
    t=s.split("\t")
    if len(t)==10:
      for c in t[1]:
        # name() raises ValueError for unnamed code points; treat those as non-CJK.
        if unicodedata.name(c,"").startswith("CJK "):
          cjk.append(c)
# sorted() rather than list(set()): set order of str depends on the per-process
# hash seed, so new token ids would otherwise differ from run to run.
cjk=sorted(set(cjk))
tkz=LlamaTokenizerFast.from_pretrained(src,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>",add_prefix_space=False)
# Map each character whose encoding has more than one id after the first two
# (presumably BOS and a leading piece -- TODO confirm for this tokenizer) to
# the ids of those pieces.
c={i:j[2:] for i,j in zip(cjk,tkz(cjk)["input_ids"]) if len(j)>3}
d=json.loads(tkz.backend_tokenizer.to_str())
# Append those characters to the vocabulary with fresh ids after the existing ones.
for i,j in enumerate(c,len(tkz)):
  d["model"]["vocab"][j]=i
tkz.backend_tokenizer.from_str(json.dumps(d)).save("tokenizer.json")
mdl=LlamaForCausalLM.from_pretrained(src)
tkz=LlamaTokenizerFast(tokenizer_file="tokenizer.json",model_max_length=mdl.config.max_position_embeddings,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
e=mdl.resize_token_embeddings(len(tkz))
f=mdl.get_output_embeddings()
# Initialise each new input/output embedding row as the sum of the rows of the
# pieces the character used to be split into.
with torch.no_grad():
  for k,v in c.items():
    e.weight[d["model"]["vocab"][k],:]=e.weight[v,:].sum(0)
    f.weight[d["model"]["vocab"][k],:]=f.weight[v,:].sum(0)
mdl.set_input_embeddings(e)
mdl.set_output_embeddings(f)
mdl.save_pretrained(tgt)
tkz.save_pretrained(tgt)
EOF
     chmod 755 "$TMPA"
     "$TMPA"
     STATUS=$?
     # The generated script is only needed for this run.
     rm -f "$TMPA"
     # On failure drop any half-written output (otherwise the -d guard above
     # would skip the rebuild next time) and stop: stage 2 needs this model.
     test $STATUS -eq 0 || { rm -rf exSwallow-7b-plus-hf; exit $STATUS; }
fi

# Stage 2: write a DeepSpeed-launched Python training script (token
# classification on the extended model) to a temporary file.
TMPB=./maker$$b.py
cat << 'EOF' > "$TMPB"
#! /usr/bin/env deepspeed
src="exSwallow-7b-plus-hf"
tgt="KoichiYasuoka/Swallow-7b-plus-upos"
from transformers import (LlamaTokenizerFast,LlamaForTokenClassification,AutoConfig,
  DataCollatorForTokenClassification,TrainingArguments,Trainer)

class UPOSFileDataset(object):
  """Token-classification dataset read lazily from a CoNLL-U file.

  Each sentence (lines up to a blank line) is one item.  A word's label is its
  UPOS, or "UPOS|FEATS" when FEATS is not "_"; a multiword token gets the
  "+"-joined labels of its syntactic words.  A form that the tokenizer splits
  into several ids is labelled "B-<label>" followed by "I-<label>"s, and the
  CLS/SEP positions are labelled "SYM".

  Args:
    conllu: path of the CoNLL-U file.  It stays open for the dataset's
      lifetime because items are read on demand by seeking.
    tokenizer: callable tokenizer providing ``cls_token_id``,
      ``sep_token_id`` and ``model_max_length``.
  """
  @staticmethod
  def _label(w):
    """Return the label of a tab-split CoNLL-U word line: UPOS or UPOS|FEATS."""
    return w[3] if w[5]=="_" else w[3]+"|"+w[5]
  def __init__(self,conllu,tokenizer):
    self.conllu=open(conllu,"r",encoding="utf-8")
    self.tokenizer=tokenizer
    self.seeks=[0]   # file position where each sentence starts (plus end marker)
    self.multiword={}   # {"+"-joined labels: {surface form: [word forms]}}
    label=set(["SYM"])   # "SYM" labels the CLS/SEP positions
    s=self.conllu.readline()
    while s!="":
      if s=="\n":
        self.seeks.append(self.conllu.tell())
      else:
        w=s.split("\t")
        if len(w)==10:
          if w[0].isdecimal():
            label.add(self._label(w))
          elif w[0].find("-")>0:
            # Multiword token "a-b": consume its syntactic-word lines a..b.
            t=w[0].split("-")
            f,j,k=w[1],[],[]
            for _ in range(int(t[0]),int(t[1])+1):
              w=self.conllu.readline().split("\t")
              j.append(self._label(w))
              k.append(w[1])
            p="+".join(j)
            label.add(p)
            self.multiword.setdefault(p,{})[f]=k
      s=self.conllu.readline()
    lid={}
    for i,l in enumerate(sorted(label)):
      lid[l],lid["B-"+l],lid["I-"+l]=i*3,i*3+1,i*3+2
    self.label2id=lid
  def __call__(*args):
    """Give all datasets in ``args`` one shared label2id covering all their labels.

    Called as ``a(b, c, ...)``, so ``args[0]`` is the instance itself.  Every
    dataset's ``label2id`` is replaced by the merged mapping, which is returned.
    """
    lid={l:i for i,l in enumerate(sorted(set().union(*(t.label2id for t in args))))}
    for t in args:
      t.label2id=lid
    return lid
  def __del__(self):
    # __init__ may have raised before the file was opened.
    if hasattr(self,"conllu"):
      self.conllu.close()
  __len__=lambda self:len(self.seeks)-1
  def __getitem__(self,i):
    """Return {"input_ids": [...], "labels": [...]} for sentence ``i``."""
    self.conllu.seek(self.seeks[i])
    form,upos=[],[]
    while self.conllu.tell()<self.seeks[i+1]:
      w=self.conllu.readline().split("\t")
      if len(w)==10:
        # Record a form only together with its label, so that form and upos
        # stay aligned; other ID kinds (e.g. empty nodes such as "8.1") are
        # skipped, exactly as when collecting labels in __init__.
        if w[0].isdecimal():
          form.append(w[1])
          upos.append(self._label(w))
        elif w[0].find("-")>0:
          form.append(w[1])
          t=w[0].split("-")
          upos.append("+".join(self._label(self.conllu.readline().split("\t")) for _ in range(int(t[0]),int(t[1])+1)))
    v=self.tokenizer(form,add_special_tokens=False)
    ids,labels=[],[]
    for x,y in zip(v["input_ids"],upos):
      if x!=[]:   # forms that tokenize to nothing are dropped
        ids+=x
        labels+=[y] if len(x)==1 else ["B-"+y]+["I-"+y]*(len(x)-1)
    m=self.tokenizer.model_max_length
    if len(ids)<m-3:
      ids=[self.tokenizer.cls_token_id]+ids+[self.tokenizer.sep_token_id]
      labels=["SYM"]+labels+["SYM"]
    else:
      # NOTE(review): the truncated branch adds no CLS/SEP, unlike the short
      # branch -- confirm this is intended.
      ids=ids[0:m-2]
      labels=labels[0:m-2]
    return {"input_ids":ids,"labels":[self.label2id[t] for t in labels]}

tkz=LlamaTokenizerFast.from_pretrained(src)
trainDS=UPOSFileDataset("ja_gsd_modern.conllu",tkz)
lid=trainDS.label2id
cfg=AutoConfig.from_pretrained(src,num_labels=len(lid),label2id=lid,id2label={i:l for l,i in lid.items()},ignore_mismatched_sizes=True)
# DeepSpeed config: ZeRO stage 3 with optimizer state and parameters offloaded
# to (pinned) CPU memory; "auto" entries are resolved by the Transformers
# DeepSpeed integration from the TrainingArguments below.
dsp={
  "fp16":{"enabled":"auto"},
  "optimizer":{"type":"AdamW"},
  "scheduler":{"type":"WarmupLR","params":{}},
  "train_batch_size":"auto",
  "train_micro_batch_size_per_gpu":"auto",
  "zero_optimization":{
    "stage":3,
    "offload_optimizer":{"device":"cpu","pin_memory":True},
    "offload_param":{"device":"cpu","pin_memory":True},
    "overlap_comm":True,
    "contiguous_gradients":True,
    "reduce_bucket_size":"auto",
    "stage3_prefetch_bucket_size":"auto",
    "stage3_param_persistence_threshold":"auto",
    "stage3_gather_16bit_weights_on_model_save":True}}
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=8,deepspeed=dsp,output_dir=tgt,overwrite_output_dir=True,save_total_limit=2,learning_rate=5e-05,warmup_ratio=0.1,save_safetensors=False)
# ignore_mismatched_sizes lets checkpoint weights whose shapes differ from this
# config (e.g. the classification head sized for len(lid) labels) be re-initialised.
trn=Trainer(args=arg,data_collator=DataCollatorForTokenClassification(tkz),model=LlamaForTokenClassification.from_pretrained(src,config=cfg,ignore_mismatched_sizes=True),train_dataset=trainDS)
trn.train()
trn.save_model(tgt)
tkz.save_pretrained(tgt)
EOF
chmod 755 "$TMPB"
"$TMPB"
STATUS=$?
# Remove the generated training script, then exit with its status.
rm -f "$TMPB"
exit $STATUS