KoichiYasuoka
commited on
Commit
·
c4029e3
1
Parent(s):
c62c8b1
algorithm improved
Browse files
upos.py
CHANGED
@@ -6,7 +6,7 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
|
|
6 |
super().__init__(**kwargs)
|
7 |
x=self.model.config.label2id
|
8 |
y=[k for k in x if not k.startswith("I-")]
|
9 |
-
self.transition=numpy.full((len(x),len(x))
|
10 |
for k,v in x.items():
|
11 |
for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
|
12 |
self.transition[v,x[j]]=0
|
@@ -20,10 +20,10 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
|
|
20 |
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
21 |
z=e/e.sum(axis=-1,keepdims=True)
|
22 |
for i in range(m.shape[0]-1,0,-1):
|
23 |
-
m[i-1]+=numpy.
|
24 |
-
k=[
|
25 |
for i in range(1,m.shape[0]):
|
26 |
-
k.append(numpy.
|
27 |
w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<e]
|
28 |
if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
|
29 |
for i,t in reversed(list(enumerate(w))):
|
|
|
6 |
super().__init__(**kwargs)
|
7 |
x=self.model.config.label2id
|
8 |
y=[k for k in x if not k.startswith("I-")]
|
9 |
+
self.transition=numpy.full((len(x),len(x)),-numpy.inf)
|
10 |
for k,v in x.items():
|
11 |
for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
|
12 |
self.transition[v,x[j]]=0
|
|
|
20 |
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
21 |
z=e/e.sum(axis=-1,keepdims=True)
|
22 |
for i in range(m.shape[0]-1,0,-1):
|
23 |
+
m[i-1]+=numpy.max(m[i]+self.transition,axis=1)
|
24 |
+
k=[self.model.config.label2id["SYM"]]
|
25 |
for i in range(1,m.shape[0]):
|
26 |
+
k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
|
27 |
w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<e]
|
28 |
if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
|
29 |
for i,t in reversed(list(enumerate(w))):
|