Spaces:
Sleeping
Sleeping
Gabriela Nicole Gonzalez Saez
commited on
Commit
·
e4bccbf
1
Parent(s):
9e85aff
topk
Browse files- app.py +35 -7
- plotsjs.js +140 -4
app.py
CHANGED
@@ -16,8 +16,6 @@ from functools import partial
|
|
16 |
|
17 |
from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
|
18 |
|
19 |
-
|
20 |
-
|
21 |
model_es = "Helsinki-NLP/opus-mt-en-es"
|
22 |
model_fr = "Helsinki-NLP/opus-mt-en-fr"
|
23 |
model_zh = "Helsinki-NLP/opus-mt-en-zh"
|
@@ -75,6 +73,28 @@ contrastive_examples = [
|
|
75 |
]
|
76 |
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
def split_token_from_sequences(sequences, model) -> dict :
|
79 |
n_sentences = len(sequences)
|
80 |
|
@@ -138,7 +158,8 @@ def split_token_from_sequences(sequences, model) -> dict :
|
|
138 |
return dict_parent
|
139 |
|
140 |
|
141 |
-
|
|
|
142 |
|
143 |
html = """
|
144 |
<html>
|
@@ -149,9 +170,13 @@ html = """
|
|
149 |
<p id="viz"></p>
|
150 |
|
151 |
<p id="demo2"></p>
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
|
154 |
-
<div id="d3_beam_search"></div>
|
155 |
|
156 |
</body>
|
157 |
</html>
|
@@ -175,16 +200,19 @@ def sentence_maker(w1, model, var2={}):
|
|
175 |
beam_dict = split_token_from_sequences(translated.sequences,model )
|
176 |
|
177 |
tgt_text = dict_tokenizer_tr[model].decode(translated.sequences[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
178 |
|
179 |
-
return [tgt_text,beam_dict]
|
180 |
|
181 |
def sentence_maker2(w1,j2):
|
182 |
-
# json_value = {'one':1}
|
183 |
-
# return f"{w1['two']} in sentence22..."
|
184 |
print(w1,j2)
|
185 |
return "in sentence22..."
|
186 |
|
187 |
|
|
|
188 |
with gr.Blocks(js="plotsjs.js") as demo:
|
189 |
gr.Markdown(
|
190 |
"""
|
|
|
16 |
|
17 |
from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
|
18 |
|
|
|
|
|
19 |
model_es = "Helsinki-NLP/opus-mt-en-es"
|
20 |
model_fr = "Helsinki-NLP/opus-mt-en-fr"
|
21 |
model_zh = "Helsinki-NLP/opus-mt-en-zh"
|
|
|
73 |
]
|
74 |
|
75 |
|
76 |
+
def get_k_prob_tokens(transition_scores, result, model, k_values=5):
|
77 |
+
tokenizer_tr = dict_tokenizer_tr[model]
|
78 |
+
gen_sequences = result.sequences[:, 1:]
|
79 |
+
|
80 |
+
result_output = []
|
81 |
+
# bs_alt = []
|
82 |
+
# bs_alt_scores = []
|
83 |
+
|
84 |
+
# First beam only...
|
85 |
+
bs = 0
|
86 |
+
text = ' '
|
87 |
+
for tok, score, i_step in zip(gen_sequences[bs], transition_scores[bs],range(len(gen_sequences[bs]))):
|
88 |
+
# bs_alt.append([tokenizer_tr.decode(tok) for tok in result.scores[i_step][bs].topk(k_values).indices ] )
|
89 |
+
# bs_alt_scores.append(np.exp(result.scores[i_step][bs].topk(k_values).values))
|
90 |
+
|
91 |
+
bs_alt = [tokenizer_tr.decode(tok) for tok in result.scores[i_step][bs].topk(k_values).indices ]
|
92 |
+
bs_alt_scores = np.exp(result.scores[i_step][bs].topk(k_values).values)
|
93 |
+
result_output.append([np.array(result.scores[i_step][bs].topk(k_values).indices), np.array(bs_alt_scores),bs_alt])
|
94 |
+
|
95 |
+
return result_output
|
96 |
+
|
97 |
+
|
98 |
def split_token_from_sequences(sequences, model) -> dict :
|
99 |
n_sentences = len(sequences)
|
100 |
|
|
|
158 |
return dict_parent
|
159 |
|
160 |
|
161 |
+
|
162 |
+
|
163 |
|
164 |
html = """
|
165 |
<html>
|
|
|
170 |
<p id="viz"></p>
|
171 |
|
172 |
<p id="demo2"></p>
|
173 |
+
<h4> Exploring top-k probable tokens </h4>
|
174 |
+
<div id="d3_text_grid">... top 10 tokens generated at each step ...</div>
|
175 |
+
|
176 |
+
<h4> Exploring the Beam Search sequence generation</h4>
|
177 |
+
<div id="d3_beam_search">... top 4 generated sequences using Beam Search...</div>
|
178 |
|
179 |
|
|
|
180 |
|
181 |
</body>
|
182 |
</html>
|
|
|
200 |
beam_dict = split_token_from_sequences(translated.sequences,model )
|
201 |
|
202 |
tgt_text = dict_tokenizer_tr[model].decode(translated.sequences[0], skip_special_tokens=True)
|
203 |
+
transition_scores = dict_models_tr[model].compute_transition_scores(
|
204 |
+
translated.sequences, translated.scores, translated.beam_indices , normalize_logits=True
|
205 |
+
)
|
206 |
+
prob_tokens = get_k_prob_tokens(transition_scores, translated, model, k_values=10)
|
207 |
|
208 |
+
return [tgt_text,[beam_dict,prob_tokens]]
|
209 |
|
210 |
def sentence_maker2(w1,j2):
|
|
|
|
|
211 |
print(w1,j2)
|
212 |
return "in sentence22..."
|
213 |
|
214 |
|
215 |
+
|
216 |
with gr.Blocks(js="plotsjs.js") as demo:
|
217 |
gr.Markdown(
|
218 |
"""
|
plotsjs.js
CHANGED
@@ -41,20 +41,24 @@ async () => {
|
|
41 |
|
42 |
|
43 |
globalThis.testFn_out_json = (data) => {
|
44 |
-
|
|
|
|
|
|
|
|
|
45 |
acc[el.id] = i;
|
46 |
return acc;
|
47 |
}, {});
|
48 |
|
49 |
let root;
|
50 |
-
|
51 |
// Handle the root element
|
52 |
if (el.parentId === null) {
|
53 |
root = el;
|
54 |
return;
|
55 |
}
|
56 |
-
// Use our mapping to locate the parent element in our
|
57 |
-
const parentEl =
|
58 |
// Add our current el to its parent's `children` array
|
59 |
parentEl.children = [...(parentEl.children || []), el];
|
60 |
});
|
@@ -63,6 +67,14 @@ async () => {
|
|
63 |
// document.getElementById('d3_beam_search').innerHTML = Tree(root)
|
64 |
d3.select('#d3_beam_search').html("");
|
65 |
d3.select('#d3_beam_search').append(function(){return Tree(root);});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
// $('#d3_beam_search').html(Tree(root)) ;
|
67 |
|
68 |
return(['string', {}])
|
@@ -206,6 +218,130 @@ function Tree(data, { // data is either tabular (array of objects) or hierarchy
|
|
206 |
return svg.node();
|
207 |
}
|
208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
|
211 |
|
|
|
41 |
|
42 |
|
43 |
globalThis.testFn_out_json = (data) => {
|
44 |
+
console.log(data);
|
45 |
+
data_beam = data[0];
|
46 |
+
data_probs = data[1];
|
47 |
+
|
48 |
+
const idMapping = data_beam.reduce((acc, el, i) => {
|
49 |
acc[el.id] = i;
|
50 |
return acc;
|
51 |
}, {});
|
52 |
|
53 |
let root;
|
54 |
+
data_beam.forEach(el => {
|
55 |
// Handle the root element
|
56 |
if (el.parentId === null) {
|
57 |
root = el;
|
58 |
return;
|
59 |
}
|
60 |
+
// Use our mapping to locate the parent element in our data_beam array
|
61 |
+
const parentEl = data_beam[idMapping[el.parentId]];
|
62 |
// Add our current el to its parent's `children` array
|
63 |
parentEl.children = [...(parentEl.children || []), el];
|
64 |
});
|
|
|
67 |
// document.getElementById('d3_beam_search').innerHTML = Tree(root)
|
68 |
d3.select('#d3_beam_search').html("");
|
69 |
d3.select('#d3_beam_search').append(function(){return Tree(root);});
|
70 |
+
|
71 |
+
//probabilities;
|
72 |
+
//
|
73 |
+
d3.select('#d3_text_grid').html("");
|
74 |
+
d3.select('#d3_text_grid').append(function(){return TextGrid(data_probs);});
|
75 |
+
// $('#d3_text_grid').html(TextGrid(data)) ;
|
76 |
+
|
77 |
+
|
78 |
// $('#d3_beam_search').html(Tree(root)) ;
|
79 |
|
80 |
return(['string', {}])
|
|
|
218 |
return svg.node();
|
219 |
}
|
220 |
|
221 |
+
function TextGrid(data, div_name, {
|
222 |
+
width = 640, // outer width, in pixels
|
223 |
+
height , // outer height, in pixels
|
224 |
+
r = 3, // radius of nodes
|
225 |
+
padding = 1, // horizontal padding for first and last column
|
226 |
+
// text = d => d[2],
|
227 |
+
} = {}){
|
228 |
+
// console.log("TextGrid", data);
|
229 |
+
|
230 |
+
// Compute the layout.
|
231 |
+
const dx = 10;
|
232 |
+
const dy = 10; //width / (root.height + padding);
|
233 |
+
|
234 |
+
const marginTop = 20;
|
235 |
+
const marginRight = 20;
|
236 |
+
const marginBottom = 30;
|
237 |
+
const marginLeft = 30;
|
238 |
+
|
239 |
+
// Center the tree.
|
240 |
+
let x0 = Infinity;
|
241 |
+
let x1 = -x0;
|
242 |
+
topk = 10;
|
243 |
+
word_length = 20;
|
244 |
+
const rectWidth = 60;
|
245 |
+
const rectTotal = 70;
|
246 |
+
|
247 |
+
wval = 0
|
248 |
+
|
249 |
+
const realWidth = rectTotal * data.length
|
250 |
+
const totalWidth = (realWidth > width) ? realWidth : width;
|
251 |
+
// root.each(d => {
|
252 |
+
// if (d.x > x1) x1 = d.x;
|
253 |
+
// if (d.x < x0) x0 = d.x;
|
254 |
+
// });
|
255 |
+
|
256 |
+
// Compute the default height.
|
257 |
+
// if (height === undefined) height = x1 - x0 + dx * 2;
|
258 |
+
if (height === undefined) height = topk * word_length + 10;
|
259 |
+
|
260 |
+
const parent = d3.create("div");
|
261 |
+
|
262 |
+
// parent.append("svg")
|
263 |
+
// .attr("width", width)
|
264 |
+
// .attr("height", height)
|
265 |
+
// .style("position", "absolute")
|
266 |
+
// .style("pointer-events", "none")
|
267 |
+
// .style("z-index", 1);
|
268 |
+
|
269 |
+
|
270 |
+
// const svg = d3.create("svg")
|
271 |
+
// // svg = parent.append("svg")
|
272 |
+
// .attr("viewBox", [-dy * padding / 2, x0 - dx, width, height])
|
273 |
+
// .attr("width", width)
|
274 |
+
// .attr("height", height)
|
275 |
+
// .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
|
276 |
+
// .attr("font-family", "sans-serif")
|
277 |
+
// .attr("font-size", 10);
|
278 |
+
|
279 |
+
// div.data([1, 2, 4, 8, 16, 32], d => d);
|
280 |
+
// div.enter().append("div").text(d => d);
|
281 |
+
|
282 |
+
const body = parent.append("div")
|
283 |
+
.style("overflow-x", "scroll")
|
284 |
+
.style("-webkit-overflow-scrolling", "touch");
|
285 |
+
|
286 |
+
const svg = body.append("svg")
|
287 |
+
.attr("width", totalWidth)
|
288 |
+
.attr("height", height)
|
289 |
+
.style("display", "block")
|
290 |
+
.attr("font-family", "sans-serif")
|
291 |
+
.attr("font-size", 10);
|
292 |
+
|
293 |
+
|
294 |
+
data.forEach(words_list => {
|
295 |
+
// console.log(wval, words_list);
|
296 |
+
words = words_list[2]; // {'t': words_list[2], 'p': words_list[1]};
|
297 |
+
scores = words_list[1];
|
298 |
+
words_score = words.map( (x,i) => {return {t: x, p: scores[i]}})
|
299 |
+
// console.log(words_score);
|
300 |
+
// svg.selectAll("text").enter()
|
301 |
+
// .data(words)
|
302 |
+
// .join("text")
|
303 |
+
// .text((d,i) => (d))
|
304 |
+
// .attr("x", wval)
|
305 |
+
// .attr("y", ((d,i) => (20 + i*20)))
|
306 |
+
|
307 |
+
var probs = svg.selectAll("text").enter()
|
308 |
+
.data(words_score).join('g');
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
probs.append("rect")
|
313 |
+
// .data(words)
|
314 |
+
.attr("x", wval)
|
315 |
+
.attr("y", ((d,i) => ( 10+ i*20)))
|
316 |
+
.attr('width', rectWidth)
|
317 |
+
.attr('height', 15)
|
318 |
+
.attr("color", 'gray')
|
319 |
+
.attr("fill", "gray")
|
320 |
+
// .attr("fill-opacity", "0.2")
|
321 |
+
.attr("fill-opacity", (d) => (d.p))
|
322 |
+
.attr("stroke-opacity", 0.8)
|
323 |
+
.append("svg:title")
|
324 |
+
.text(function(d){return d.t+":"+d.p;});
|
325 |
+
|
326 |
+
|
327 |
+
probs.append("text")
|
328 |
+
// .data(words)
|
329 |
+
.text((d,i) => (d.t))
|
330 |
+
.attr("x", wval)
|
331 |
+
.attr("y", ((d,i) => (20 + i*20)))
|
332 |
+
// .attr("fill", 'white')
|
333 |
+
.attr("font-weight", 700);
|
334 |
+
|
335 |
+
wval = wval + rectTotal;
|
336 |
+
});
|
337 |
+
|
338 |
+
|
339 |
+
body.node().scrollBy(totalWidth, 0);
|
340 |
+
// return svg.node();
|
341 |
+
return parent.node();
|
342 |
+
}
|
343 |
+
|
344 |
+
|
345 |
|
346 |
|
347 |
|