4kasha committed
Commit 527e550 · 1 Parent(s): b4fe012
Files changed (4)
  1. aligner.py +132 -0
  2. app.py +170 -0
  3. requirements.txt +8 -0
  4. utils.py +106 -0
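The commit adds a self-contained Streamlit demo of OT-based word alignment: `aligner.py` wraps balanced, partial, and unbalanced optimal-transport solvers from POT; `app.py` builds the UI and turns subword hidden states into word-level embeddings; `utils.py` supplies the cost matrices, mass weights, and heatmap plotting; `requirements.txt` pins the dependencies.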
aligner.py ADDED
@@ -0,0 +1,132 @@
+ import numpy as np
+ import torch
+ import ot
+ from utils import (
+     compute_distance_matrix_cosine,
+     compute_distance_matrix_l2,
+     compute_weights_norm,
+     compute_weights_uniform,
+     min_max_scaling
+ )
+
+ class Aligner:
+     def __init__(self, ot_type, sinkhorn, chimera, dist_type, weight_type, distortion, thresh, tau, **kwargs):
+         self.ot_type = ot_type
+         self.sinkhorn = sinkhorn
+         self.chimera = chimera
+         self.dist_type = dist_type
+         self.weight_type = weight_type
+         self.distortion = distortion
+         self.thresh = thresh
+         self.tau = tau
+         self.epsilon = 0.1
+         self.stopThr = 1e-6
+         self.numItermax = 1000
+         self.div_type = kwargs['div_type']
+
+         self.dist_func = compute_distance_matrix_cosine if dist_type == 'cos' else compute_distance_matrix_l2
+         if weight_type == 'uniform':
+             self.weight_func = compute_weights_uniform
+         else:
+             self.weight_func = compute_weights_norm
+
+     def compute_alignment_matrixes(self, s1_vecs, s2_vecs):
+         self.align_matrixes = []
+         for vecX, vecY in zip(s1_vecs, s2_vecs):
+             P = self.compute_optimal_transport(vecX, vecY)
+             if torch.is_tensor(P):
+                 P = P.to('cpu').numpy()
+             self.align_matrixes.append(P)
+
+     def get_alignments(self, thresh, assign_cost=False):
+         assert len(self.align_matrixes) > 0
+         self.thresh = thresh
+         all_alignments = []
+         for P in self.align_matrixes:
+             alignments = self.matrix_to_alignments(P, assign_cost)
+             all_alignments.append(alignments)
+         return all_alignments
+
+     def matrix_to_alignments(self, P, assign_cost):
+         alignments = set()
+         align_pairs = np.transpose(np.nonzero(P > self.thresh))
+         if assign_cost:
+             for i_j in align_pairs:
+                 alignments.add('{0}-{1}-{2:.4f}'.format(i_j[0], i_j[1], P[i_j[0], i_j[1]]))
+         else:
+             for i_j in align_pairs:
+                 alignments.add('{0}-{1}'.format(i_j[0], i_j[1]))
+         return alignments
+
+     def compute_optimal_transport(self, s1_word_embeddings, s2_word_embeddings):
+         s1_word_embeddings = s1_word_embeddings.to(torch.float64)
+         s2_word_embeddings = s2_word_embeddings.to(torch.float64)
+
+         C = self.dist_func(s1_word_embeddings, s2_word_embeddings, self.distortion)
+         s1_weights, s2_weights = self.weight_func(s1_word_embeddings, s2_word_embeddings)
+
+         if self.ot_type == 'ot':
+             s1_weights = s1_weights / s1_weights.sum()
+             s2_weights = s2_weights / s2_weights.sum()
+             s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
+
+             if self.sinkhorn:
+                 P = ot.bregman.sinkhorn_log(s1_weights, s2_weights, C, reg=self.epsilon,
+                                             stopThr=self.stopThr, numItermax=self.numItermax)
+             else:
+                 P = ot.emd(s1_weights, s2_weights, C)
+             # Min-max normalization
+             P = min_max_scaling(P)
+
+         elif self.ot_type == 'pot':
+             if self.chimera:
+                 # bertscore_F1 is not defined in this file; the demo always passes chimera=False
+                 m = self.tau * self.bertscore_F1(s1_word_embeddings, s2_word_embeddings)
+                 m = min(1.0, m.item())
+             else:
+                 m = self.tau
+
+             s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
+             # To cope with rounding errors: m must not exceed the smaller total mass
+             m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * m
+
+             if self.sinkhorn:
+                 P = ot.partial.entropic_partial_wasserstein(s1_weights, s2_weights, C, reg=self.epsilon, m=m,
+                                                             stopThr=self.stopThr, numItermax=self.numItermax)
+             else:
+                 P = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m)
+             # Min-max normalization
+             P = min_max_scaling(P)
+
+         elif 'uot' in self.ot_type:
+             if self.chimera:
+                 tau = self.tau * self.bertscore_F1(s1_word_embeddings, s2_word_embeddings)
+             else:
+                 tau = self.tau
+
+             if self.ot_type == 'uot':
+                 P = ot.unbalanced.sinkhorn_stabilized_unbalanced(s1_weights, s2_weights, C, reg=self.epsilon,
+                                                                  reg_m=tau, stopThr=self.stopThr,
+                                                                  numItermax=self.numItermax)
+             elif self.ot_type == 'uot-mm':
+                 P = ot.unbalanced.mm_unbalanced(s1_weights, s2_weights, C, reg_m=tau, div=self.div_type,
+                                                 stopThr=self.stopThr, numItermax=self.numItermax)
+             # Min-max normalization
+             P = min_max_scaling(P)
+
+         elif self.ot_type == 'none':
+             P = 1 - C
+
+         return P
+
+     def convert_to_numpy(self, s1_weights, s2_weights, C):
+         if torch.is_tensor(s1_weights):
+             s1_weights = s1_weights.to('cpu').numpy()
+             s2_weights = s2_weights.to('cpu').numpy()
+         if torch.is_tensor(C):
+             C = C.to('cpu').numpy()
+         return s1_weights, s2_weights, C
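For reference, a minimal usage sketch of `Aligner` outside the app; the random tensors stand in for real word embeddings, and the `uot` settings loosely mirror the demo's defaults rather than anything prescribed by the commit:

```python
import torch
from aligner import Aligner
from utils import device  # the helpers pin their tensors to this device

aligner = Aligner(ot_type='uot', sinkhorn=True, chimera=False, dist_type='cos',
                  weight_type='uniform', distortion=0.2, thresh=0.1, tau=0.98,
                  div_type='--')

# One "sentence pair": 5 vs. 7 toy word vectors of dimension 768
s1_vecs = [torch.randn(5, 768, device=device)]
s2_vecs = [torch.randn(7, 768, device=device)]

aligner.compute_alignment_matrixes(s1_vecs, s2_vecs)
print(aligner.align_matrixes[0].shape)     # (5, 7) soft alignment matrix
print(aligner.get_alignments(thresh=0.1))  # e.g. [{'0-3', '2-5', ...}]
```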
app.py ADDED
@@ -0,0 +1,170 @@
+ import streamlit as st
+ import random
+ import numpy as np
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ from aligner import Aligner
+ from utils import plot_align_matrix_heatmap
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch.manual_seed(42)
+ np.random.seed(42)
+ random.seed(42)
+
+
+ @st.cache_resource
+ def init_model(model_name: str):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(model_name, output_hidden_states=True).to(device).eval()
+     return tokenizer, model
+
+
+ @st.cache_resource(max_entries=100)
+ def init_aligner(ot_type: str, sinkhorn: bool, distortion: float, threshold: float, tau: float):
+     return Aligner(
+         ot_type=ot_type,
+         sinkhorn=sinkhorn,
+         chimera=False,
+         dist_type="cos",
+         weight_type="uniform",
+         distortion=distortion,
+         thresh=threshold,
+         tau=tau,
+         div_type="--"
+     )
+
+
+ def encode_sentence(sent, pair, tokenizer, model, layer: int):
+     if pair is None:
+         inputs = tokenizer(sent, padding=False, truncation=False, is_split_into_words=True,
+                            return_offsets_mapping=True, return_tensors="pt")
+     else:
+         inputs = tokenizer(text=sent, text_pair=pair, padding=False, truncation=True, is_split_into_words=True,
+                            return_offsets_mapping=True, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device),
+                         inputs['token_type_ids'].to(device))
+
+     return outputs.hidden_states[layer][0], inputs['input_ids'][0], inputs['offset_mapping'][0]
+
+
+ def centering(hidden_outputs):
+     """
+     hidden_outputs : [tokens, hidden_size]
+     """
+     # Subtract the mean vector over all token embeddings
+     mean_vec = torch.sum(hidden_outputs, dim=0) / hidden_outputs.shape[0]
+     hidden_outputs = hidden_outputs - mean_vec
+     return hidden_outputs
+
+
+ def convert_to_word_embeddings(offset_mapping, token_ids, hidden_tensors, tokenizer, pair):
+     word_idx = -1
+     subword_to_word_conv = np.full(hidden_tensors.shape[0], -1)
+     # Bug in the Hugging Face tokenizer? Sometimes a Metaspace token is inserted
+     metaspace = getattr(tokenizer.decoder, "replacement", None)
+     metaspace = tokenizer.decoder.prefix if metaspace is None else metaspace
+     tokenizer_bug_idxes = [i for i, x in enumerate(tokenizer.convert_ids_to_tokens(token_ids))
+                            if x == metaspace]
+
+     for subw_idx, offset in enumerate(offset_mapping):
+         if subw_idx in tokenizer_bug_idxes:
+             continue
+         elif offset[0] == offset[1]:  # Special token
+             continue
+         elif offset[0] == 0:  # First subword of a new word
+             word_idx += 1
+             subword_to_word_conv[subw_idx] = word_idx
+         else:
+             subword_to_word_conv[subw_idx] = word_idx
+
+     # Average subword embeddings to obtain one vector per word
+     word_embeddings = torch.vstack(
+         [torch.mean(hidden_tensors[subword_to_word_conv == word_idx], dim=0) for word_idx in range(word_idx + 1)])
+
+     if pair:
+         sep_tok_indices = [i for i, x in enumerate(token_ids) if x == tokenizer.sep_token_id]
+         s2_start_idx = subword_to_word_conv[
+             sep_tok_indices[0] + np.argmax(subword_to_word_conv[sep_tok_indices[0]:] > -1)]
+
+         s1_word_embeddings = word_embeddings[0:s2_start_idx, :]
+         s2_word_embeddings = word_embeddings[s2_start_idx:, :]
+
+         return s1_word_embeddings, s2_word_embeddings
+     else:
+         return word_embeddings
+
+
+ def main():
+     st.set_page_config(layout="wide")
+
+     # Sidebar
+     st.sidebar.markdown("## Settings & Parameters")
+     model_name = st.sidebar.selectbox('model', ['microsoft/deberta-v3-base', 'bert-base-uncased'])
+     layer = st.sidebar.slider('layer number for embeddings', 0, 11, value=9)
+     is_centering = st.sidebar.checkbox('centering embeddings', value=True)
+     ot_type = st.sidebar.selectbox('ot_type', ['OT', 'POT', 'UOT'])
+     ot_type = ot_type.lower()
+     sinkhorn = st.sidebar.checkbox('sinkhorn', value=True)
+     distortion = st.sidebar.slider(r'distortion: $\kappa$', 0.0, 1.0, value=0.20)
+     tau = st.sidebar.slider(r'tau: $\tau$', 0.0, 1.0, value=0.98)
+     threshold = st.sidebar.slider(r'threshold: $\lambda$', 0.0, 1.0)
+
+     # Content
+     st.markdown('## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment')
+
+     col1, col2 = st.columns(2)
+
+     with col1:
+         sent1 = st.text_area(
+             'sentence 1',
+             'By one estimate , fewer than 20,000 lions exist in the wild , a drop of about 40 percent in the past two decades .'
+         )
+     with col2:
+         sent2 = st.text_area(
+             'sentence 2',
+             'Today there are only around 20,000 wild lions left in the world .'
+         )
+
+     tokenizer, model = init_model(model_name)
+     aligner = init_aligner(ot_type, sinkhorn, distortion, threshold, tau)
+
+     with st.container():
+         st.write("word alignment matrix")
+
+         if sent1 != '' and sent2 != '':
+             sent1 = sent1.lower().split()
+             sent2 = sent2.lower().split()
+             hidden_output, input_id, offset_map = encode_sentence(sent1, sent2, tokenizer, model, layer=layer)
+             if is_centering:
+                 hidden_output = centering(hidden_output)
+             s1_vec, s2_vec = convert_to_word_embeddings(offset_map, input_id, hidden_output, tokenizer, pair=True)
+             aligner.compute_alignment_matrixes([s1_vec], [s2_vec])
+             align_matrix = aligner.align_matrixes[0]
+
+             fig = plot_align_matrix_heatmap(align_matrix.T, sent1, sent2, threshold)
+             st.pyplot(fig, dpi=300)
+
+     st.divider()
+     st.markdown("Note that the centering in this demo is applied only to the input sentences, so the variance may be large.")
+     st.subheader('Refs')
+     st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL 2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")
+
+
+ if __name__ == '__main__':
+     main()
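The embedding pipeline also works outside Streamlit; a sketch assuming `bert-base-uncased` and pre-split, lower-cased input as in `main` (the app itself is launched with `streamlit run app.py`):

```python
import torch
from transformers import AutoTokenizer, AutoModel
from app import encode_sentence, convert_to_word_embeddings, centering, device

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True).to(device).eval()

sent1 = 'fewer than 20,000 lions exist in the wild .'.split()
sent2 = 'around 20,000 wild lions are left .'.split()

# Encode the pair jointly, then pool subwords into word vectors
hidden, ids, offsets = encode_sentence(sent1, sent2, tokenizer, model, layer=9)
hidden = centering(hidden)
s1_vec, s2_vec = convert_to_word_embeddings(offsets, ids, hidden, tokenizer, pair=True)
print(s1_vec.shape, s2_vec.shape)  # one vector per word in each sentence
```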
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ POT==0.9.0
+ sentencepiece==0.1.99
+ streamlit==1.24.0
+ tokenizers==0.13.3
+ transformers==4.30.2
+ matplotlib==3.7.1
+ torch  # imported by app.py and utils.py
+ seaborn  # imported by utils.py
utils.py ADDED
@@ -0,0 +1,106 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
+ from ot.backend import get_backend
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def compute_distance_matrix_cosine(s1_word_embeddings, s2_word_embeddings, distortion_ratio):
+     C = (torch.matmul(F.normalize(s1_word_embeddings), F.normalize(s2_word_embeddings).t()) + 1.0) / 2  # Range 0-1
+     C = apply_distortion(C, distortion_ratio)
+     C = min_max_scaling(C)  # Range 0-1
+     C = 1.0 - C  # Convert to distance
+     return C
+
+
+ def compute_distance_matrix_l2(s1_word_embeddings, s2_word_embeddings, distortion_ratio):
+     C = torch.cdist(s1_word_embeddings, s2_word_embeddings, p=2)
+     C = min_max_scaling(C)  # Range 0-1
+     C = 1.0 - C  # Convert to similarity
+     C = apply_distortion(C, distortion_ratio)
+     C = min_max_scaling(C)  # Range 0-1
+     C = 1.0 - C  # Convert back to distance
+     return C
+
+
+ def apply_distortion(sim_matrix, ratio):
+     shape = sim_matrix.shape
+     if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
+         return sim_matrix
+
+     # Penalize pairs whose relative positions in the two sentences differ
+     pos_x = torch.tensor([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])],
+                          device=device)
+     pos_y = torch.tensor([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])],
+                          device=device)
+     distortion_mask = 1.0 - ((pos_x - pos_y.T) ** 2) * ratio
+
+     sim_matrix = torch.mul(sim_matrix, distortion_mask)
+     return sim_matrix
+
+
+ def compute_weights_norm(s1_word_embeddings, s2_word_embeddings):
+     s1_weights = torch.norm(s1_word_embeddings, dim=1)
+     s2_weights = torch.norm(s2_word_embeddings, dim=1)
+     return s1_weights, s2_weights
+
+
+ def compute_weights_uniform(s1_word_embeddings, s2_word_embeddings):
+     s1_weights = torch.ones(s1_word_embeddings.shape[0], dtype=torch.float64, device=device)
+     s2_weights = torch.ones(s2_word_embeddings.shape[0], dtype=torch.float64, device=device)
+
+     # # Uniform weights to make L2 norm=1
+     # s1_weights /= torch.linalg.norm(s1_weights)
+     # s2_weights /= torch.linalg.norm(s2_weights)
+
+     return s1_weights, s2_weights
+
+
+ def min_max_scaling(C):
+     # Min-max scaling for stabilization
+     eps = 1e-10
+     nx = get_backend(C)
+     C_min = nx.min(C)
+     C_max = nx.max(C)
+     C = (C - C_min + eps) / (C_max - C_min + eps)
+     return C
+
+
+ def plot_align_matrix_heatmap(align_matrix, sent1, sent2, thresh, **kwargs):
+     # Zero out cells at or below the threshold so they render as background
+     align_matrix = np.where(align_matrix <= thresh, 0, align_matrix)
+
+     fig, ax = plt.subplots(figsize=(10, 6))
+     sns.set(font='sans-serif', style="ticks")
+
+     _color = ['#F2F2F2', '#E0F4FA', '#BEE4F0', '#88CCE5', '#33b7df', '#1B88A6', '#105264', '#092E39']
+     _ticks = [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
+
+     divider = make_axes_locatable(ax)
+     cbar_ax = divider.append_axes("right", size="2.5%", pad=0.1)
+     fig.add_axes(cbar_ax)
+     ax = sns.heatmap(
+         align_matrix,
+         xticklabels=sent1,
+         yticklabels=sent2,
+         cmap=_color,
+         linewidths=1,
+         square=True,
+         ax=ax,
+         cbar_ax=cbar_ax,
+         **kwargs
+     )
+     ax.collections[0].colorbar.ax.yaxis.set_ticks(_ticks, minor=False)
+     ax.collections[0].colorbar.set_ticklabels(_ticks)
+     cax = ax.collections[0].colorbar.ax
+     cax.tick_params(which='major', length=3, labelsize=5)
+     ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
+     ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
+     return fig
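A quick check of the distance helpers on toy inputs; the shapes and the roughly [0, 1] range are the point, the numbers themselves are random:

```python
import torch
from utils import (compute_distance_matrix_cosine, compute_weights_uniform,
                   min_max_scaling, device)

s1 = torch.randn(4, 16, dtype=torch.float64, device=device)
s2 = torch.randn(6, 16, dtype=torch.float64, device=device)

C = compute_distance_matrix_cosine(s1, s2, distortion_ratio=0.2)
print(C.shape)                         # torch.Size([4, 6])
print(C.min().item(), C.max().item())  # distances min-max scaled into ~[0, 1]

w1, w2 = compute_weights_uniform(s1, s2)  # one unit of mass per word
print(w1.sum().item(), w2.sum().item())   # 4.0, 6.0
```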