File size: 4,393 Bytes
3a52f83
a7caa81
 
 
 
 
 
 
 
 
 
641e3b0
 
a7caa81
 
3a52f83
a7caa81
3a52f83
a7caa81
 
3a52f83
 
 
 
 
a7caa81
 
 
3a52f83
a7caa81
 
3d0cddb
 
a7caa81
24769be
 
a7caa81
3d0cddb
a7caa81
 
 
 
 
68e4882
a7caa81
 
a91df99
 
a7caa81
 
 
 
 
 
 
 
 
dd79cf8
 
eb1db6e
dd79cf8
 
 
 
 
24769be
a7caa81
 
3a52f83
a7caa81
 
 
 
 
 
 
 
 
 
3a52f83
 
 
a7caa81
3a52f83
a7caa81
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy
from transformers import TokenClassificationPipeline

class UniversalDependenciesPipeline(TokenClassificationPipeline):
  def _forward(self,model_inputs):
    import torch
    v=model_inputs["input_ids"][0].tolist()
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
    return {"logits":e.logits[:,1:-2,:],**model_inputs}
  def postprocess(self,model_outputs,**kwargs):
    if "logits" not in model_outputs:
      return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
    e=model_outputs["logits"].numpy()
    r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
    e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,-numpy.inf)
    g=self.model.config.label2id["X|_|goeswith"]
    m,r=numpy.max(e,axis=2),numpy.tri(e.shape[0])
    for i in range(e.shape[0]):
      for j in range(i+2,e.shape[1]):
        r[i,j]=1
        if numpy.argmax(e[i,j-1])==g and numpy.argmax(m[:,j-1])==i:
          r[i,j]=r[i,j-1]
    e[:,:,g]+=numpy.where(r==0,0,-numpy.inf)
    m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    t=model_outputs["sentence"].replace("\n"," ")
    v=[(s,e,c if c!=self.tokenizer.unk_token else t[s:e]) for (s,e),c in zip(model_outputs["offset_mapping"][0].tolist(),self.tokenizer.convert_ids_to_tokens(model_outputs["input_ids"][0].tolist())) if s<e]
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    g="aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none"
    if g:
      for i,j in reversed(list(enumerate(q[1:],1))):
        if j[-1]=="goeswith" and set([k[-1] for k in q[h[i]+1:i+1]])=={"goeswith"}:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          s,e,c=v.pop(i)
          v[i-1]=(v[i-1][0],e,v[i-1][2]+c)
          q.pop(i)
    u="\n"
    z={"a":"ア","i":"イ","u":"ウ","e":"エ","o":"オ","k":"ㇰ","s":"ㇱ","t":"ㇳ","n":"ㇴ","h":"ㇷ","m":"ㇺ","r":"ㇽ","p":"ㇷ゚"}
    f=-1
    for i,(s,e,c) in reversed(list(enumerate(v))):
      if t[s]=="\u309a":
        s-=1
      w,x=[j for j in t[s:e]],""
      if i>0 and s<v[i-1][1]:
        w[0]=z[c[0]] if c[0] in z else "ッ"
        f=max(f,i)
      elif f>0:
        x="{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t{}\n".format(i+1,f+1,t[s:v[f][1]],"_" if f+1<len(v) and v[f][1]<v[f+1][0] else "SpaceAfter=No")
        f=-1
      if i+1<len(v) and e>v[i+1][0]:
        w[-1]=z[c[-1]] if c[-1] in z else "ッ"
      if g:
        l="".join(w).replace(" ","") if max(w)<"z" else c
        l=l.replace("sh","s").replace("ch","c").replace("au","aw").replace("iu","iw").replace("eu","ew").replace("uu","uw").replace("ou","ow").replace("ai","ay").replace("ui","uy").replace("ei","ey").replace("oi","oy")
        if q[i][1]=="人称接辞":
          if l.find("=")<0:
            l="="+l if i>h[i] else l+"="
      else:
        l="_"
      u=x+"\t".join([str(i+1),"".join(w),l,q[i][0],"|".join(q[i][1:-1]),"_",str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"+u
    return "# text = "+t+"\n"+u
  def chu_liu_edmonds(self,matrix):
    h=numpy.argmax(matrix,axis=0)
    x=[-1 if i==j else j for i,j in enumerate(h)]
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      if max(x)<0:
        return h
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    z=matrix-numpy.max(matrix,axis=0)
    m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h