radames commited on
Commit
f4070ba
·
1 Parent(s): 64bfae5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +369 -40
app.py CHANGED
@@ -1,45 +1,374 @@
1
- import gradio as gr
 
 
 
 
 
 
2
 
3
- def predict(text, url_params):
4
- print(url_params)
5
- return ["Hello " + text + "!!", url_params]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
 
 
 
7
 
8
- get_window_url_params = """
9
- function(text_input, url_params) {
10
- console.log(text_input, url_params);
11
- const params = new URLSearchParams(window.location.search);
12
- url_params = Object.fromEntries(params);
13
- return [text_input, url_params];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  }
15
- """
16
- set_window_url_params = """
17
- function(text_input, url_params) {
18
- const params = new URLSearchParams(window.location.search);
19
- params.set("text_input", text_input)
20
- url_params = Object.fromEntries(params);
21
- const queryString = '?' + params.toString();
22
- // this next line is only needed inside Spaces, so the child frame updates parent
23
- window.parent.postMessage({ queryString: queryString }, "*")
24
- return [text_input, url_params];
 
 
 
 
 
 
 
 
 
 
25
  }
26
- """
27
- with gr.Blocks() as block:
28
- url_params = gr.JSON({}, visible=True, label="URL Params")
29
- text_input = gr.Text(label="Input")
30
- text_output = gr.Text(label="Output")
31
-
32
- btn = gr.Button("Get Params")
33
- btn.click(fn=predict, inputs=[text_input, url_params],
34
- outputs=[text_output, url_params], _js=get_window_url_params)
35
-
36
- btn2 = gr.Button("Set Params")
37
- btn2.click(fn=predict, inputs=[text_input, url_params],
38
- outputs=[text_output, url_params], _js=set_window_url_params)
39
- block.load(
40
- fn=predict,
41
- inputs=[text_input, url_params],
42
- outputs=[text_output, url_params],
43
- _js=get_window_url_params
44
- )
45
- block.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <html>
2
+ <head>
3
+ <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
4
+ <title>Candle Bert</title>
5
+ </head>
6
+ <body></body>
7
+ </html>
8
 
9
+ <!DOCTYPE html>
10
+ <html>
11
+ <head>
12
+ <meta charset="UTF-8" />
13
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
14
+ <style>
15
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
16
+ html,
17
+ body {
18
+ font-family: "Source Sans 3", sans-serif;
19
+ }
20
+ </style>
21
+ <script src="https://cdn.tailwindcss.com"></script>
22
+ <script type="module" src="./code.js"></script>
23
+ <script type="module">
24
+ import { hcl } from "https://cdn.skypack.dev/d3-color@3";
25
+ import { interpolateReds } from "https://cdn.skypack.dev/d3-scale-chromatic@3";
26
+ import { scaleLinear } from "https://cdn.skypack.dev/d3-scale@4";
27
+ import {
28
+ getModelInfo,
29
+ getEmbeddings,
30
+ getWikiText,
31
+ cosineSimilarity,
32
+ } from "./utils.js";
33
 
34
+ const bertWorker = new Worker("./bertWorker.js", {
35
+ type: "module",
36
+ });
37
 
38
+ const inputContainerEL = document.querySelector("#input-container");
39
+ const textAreaEl = document.querySelector("#input-area");
40
+ const outputAreaEl = document.querySelector("#output-area");
41
+ const formEl = document.querySelector("#form");
42
+ const searchInputEl = document.querySelector("#search-input");
43
+ const formWikiEl = document.querySelector("#form-wiki");
44
+ const searchWikiEl = document.querySelector("#search-wiki");
45
+ const outputStatusEl = document.querySelector("#output-status");
46
+ const modelSelectEl = document.querySelector("#model");
47
+
48
+ const sentencesRegex =
49
+ /(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z]\.)(?<=\.|\?)\s/gm;
50
+
51
+ let sentenceEmbeddings = [];
52
+ let currInputText = "";
53
+ let isCalculating = false;
54
+
55
+ function toggleTextArea(state) {
56
+ if (state) {
57
+ textAreaEl.hidden = false;
58
+ textAreaEl.focus();
59
+ } else {
60
+ textAreaEl.hidden = true;
61
+ }
62
+ }
63
+ inputContainerEL.addEventListener("focus", (e) => {
64
+ toggleTextArea(true);
65
+ });
66
+ textAreaEl.addEventListener("blur", (e) => {
67
+ toggleTextArea(false);
68
+ });
69
+ textAreaEl.addEventListener("focusout", (e) => {
70
+ toggleTextArea(false);
71
+ if (currInputText === textAreaEl.value || isCalculating) return;
72
+ populateOutputArea(textAreaEl.value);
73
+ calculateEmbeddings(textAreaEl.value);
74
+ });
75
+
76
+ modelSelectEl.addEventListener("change", (e) => {
77
+ const query = new URLSearchParams(window.location.search);
78
+ query.set("model", modelSelectEl.value);
79
+ window.history.replaceState(
80
+ {},
81
+ "",
82
+ `${window.location.pathname}?${query}`
83
+ );
84
+ window.parent.postMessage({ queryString: "?" + queryString }, "*")
85
+ if (currInputText === "" || isCalculating) return;
86
+ populateOutputArea(textAreaEl.value);
87
+ calculateEmbeddings(textAreaEl.value);
88
+ });
89
+
90
+ function populateOutputArea(text) {
91
+ currInputText = text;
92
+ const sentences = text.split(sentencesRegex);
93
+
94
+ outputAreaEl.innerHTML = "";
95
+ for (const [id, sentence] of sentences.entries()) {
96
+ const sentenceEl = document.createElement("span");
97
+ sentenceEl.id = `sentence-${id}`;
98
+ sentenceEl.innerText = sentence + " ";
99
+ outputAreaEl.appendChild(sentenceEl);
100
+ }
101
+ }
102
+ formEl.addEventListener("submit", async (e) => {
103
+ e.preventDefault();
104
+ if (isCalculating || currInputText === "") return;
105
+ toggleInputs(true);
106
+ const modelID = modelSelectEl.value;
107
+ const { modelURL, tokenizerURL, configURL, search_prefix } =
108
+ getModelInfo(modelID);
109
+
110
+ const text = searchInputEl.value;
111
+ const query = search_prefix + searchInputEl.value;
112
+ outputStatusEl.classList.remove("invisible");
113
+ outputStatusEl.innerText = "Calculating embeddings for query...";
114
+ isCalculating = true;
115
+ const out = await getEmbeddings(
116
+ bertWorker,
117
+ modelURL,
118
+ tokenizerURL,
119
+ configURL,
120
+ modelID,
121
+ [query]
122
+ );
123
+ outputStatusEl.classList.add("invisible");
124
+ const queryEmbeddings = out.output[0];
125
+ // calculate cosine similarity with all sentences given the query
126
+ const distances = sentenceEmbeddings
127
+ .map((embedding, id) => ({
128
+ id,
129
+ similarity: cosineSimilarity(queryEmbeddings, embedding),
130
+ }))
131
+ .sort((a, b) => b.similarity - a.similarity)
132
+ // getting top 10 most similar sentences
133
+ .slice(0, 10);
134
+
135
+ const colorScale = scaleLinear()
136
+ .domain([
137
+ distances[distances.length - 1].similarity,
138
+ distances[0].similarity,
139
+ ])
140
+ .range([0, 1])
141
+ .interpolate(() => interpolateReds);
142
+ outputAreaEl.querySelectorAll("span").forEach((el) => {
143
+ el.style.color = "unset";
144
+ el.style.backgroundColor = "unset";
145
+ });
146
+ distances.forEach((d) => {
147
+ const el = outputAreaEl.querySelector(`#sentence-${d.id}`);
148
+ const color = colorScale(d.similarity);
149
+ const fontColor = hcl(color).l < 70 ? "white" : "black";
150
+ el.style.color = fontColor;
151
+ el.style.backgroundColor = color;
152
+ });
153
+
154
+ outputAreaEl
155
+ .querySelector(`#sentence-${distances[0].id}`)
156
+ .scrollIntoView({
157
+ behavior: "smooth",
158
+ block: "center",
159
+ inline: "nearest",
160
+ });
161
+
162
+ isCalculating = false;
163
+ toggleInputs(false);
164
+ });
165
+ async function calculateEmbeddings(text) {
166
+ isCalculating = true;
167
+ toggleInputs(true);
168
+ const modelID = modelSelectEl.value;
169
+ const { modelURL, tokenizerURL, configURL, document_prefix } =
170
+ getModelInfo(modelID);
171
+
172
+ const sentences = text.split(sentencesRegex);
173
+ const allEmbeddings = [];
174
+ outputStatusEl.classList.remove("invisible");
175
+ for (const [id, sentence] of sentences.entries()) {
176
+ const query = document_prefix + sentence;
177
+ outputStatusEl.innerText = `Calculating embeddings: sentence ${
178
+ id + 1
179
+ } of ${sentences.length}`;
180
+ const embeddings = await getEmbeddings(
181
+ bertWorker,
182
+ modelURL,
183
+ tokenizerURL,
184
+ configURL,
185
+ modelID,
186
+ [query],
187
+ updateStatus
188
+ );
189
+ allEmbeddings.push(embeddings);
190
+ }
191
+ outputStatusEl.classList.add("invisible");
192
+ sentenceEmbeddings = allEmbeddings.map((e) => e.output[0]);
193
+ isCalculating = false;
194
+ toggleInputs(false);
195
+ }
196
+
197
+ function updateStatus(data) {
198
+ if ("status" in data) {
199
+ if (data.status === "loading") {
200
+ outputStatusEl.innerText = data.message;
201
+ outputStatusEl.classList.remove("invisible");
202
+ }
203
  }
204
+ }
205
+ function toggleInputs(state) {
206
+ const interactive = document.querySelectorAll(".interactive");
207
+ interactive.forEach((el) => {
208
+ if (state) {
209
+ el.disabled = true;
210
+ } else {
211
+ el.disabled = false;
212
+ }
213
+ });
214
+ }
215
+
216
+ searchWikiEl.addEventListener("input", () => {
217
+ searchWikiEl.setCustomValidity("");
218
+ });
219
+
220
+ formWikiEl.addEventListener("submit", async (e) => {
221
+ e.preventDefault();
222
+ if ("example" in e.submitter.dataset) {
223
+ searchWikiEl.value = e.submitter.innerText;
224
  }
225
+ const text = searchWikiEl.value;
226
+
227
+ if (isCalculating || text === "") return;
228
+ try {
229
+ const wikiText = await getWikiText(text);
230
+ searchWikiEl.setCustomValidity("");
231
+ textAreaEl.innerHTML = wikiText;
232
+ populateOutputArea(wikiText);
233
+ calculateEmbeddings(wikiText);
234
+ searchWikiEl.value = "";
235
+ } catch {
236
+ searchWikiEl.setCustomValidity("Invalid Wikipedia article name");
237
+ searchWikiEl.reportValidity();
238
+ }
239
+ });
240
+ document.addEventListener("DOMContentLoaded", () => {
241
+ const query = new URLSearchParams(window.location.search);
242
+ const modelID = query.get("model");
243
+ if (modelID) {
244
+ modelSelectEl.value = modelID;
245
+ modelSelectEl.dispatchEvent(new Event("change"));
246
+ }
247
+ });
248
+ </script>
249
+ </head>
250
+ <body class="container max-w-4xl mx-auto p-4">
251
+ <main class="grid grid-cols-1 gap-5 relative">
252
+ <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
253
+ <div>
254
+ <h1 class="text-5xl font-bold">Candle BERT</h1>
255
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
256
+ <p class="max-w-lg">
257
+ Running sentence embeddings and similarity search in the browser using
258
+ the Bert Model written with
259
+ <a
260
+ href="https://github.com/huggingface/candle/"
261
+ target="_blank"
262
+ class="underline hover:text-blue-500 hover:no-underline"
263
+ >Candle
264
+ </a>
265
+ and compiled to Wasm. Embeddings models from are from
266
+ <a
267
+ href="https://huggingface.co/sentence-transformers/"
268
+ target="_blank"
269
+ class="underline hover:text-blue-500 hover:no-underline">
270
+ Sentence Transformers
271
+ </a>
272
+ and
273
+ <a
274
+ href="https://huggingface.co/intfloat/"
275
+ target="_blank"
276
+ class="underline hover:text-blue-500 hover:no-underline">
277
+ Liang Wang - e5 Models
278
+ </a>
279
+ </p>
280
+ </div>
281
+
282
+ <div>
283
+ <label for="model" class="font-medium block">Models Options: </label>
284
+ <select
285
+ id="model"
286
+ class="border-2 border-gray-500 rounded-md font-light interactive disabled:cursor-not-allowed w-full max-w-max">
287
+ <option value="gte_tiny">gte_tiny (45.5 MB)</option>
288
+ <option value="intfloat_e5_small_v2" selected>
289
+ intfloat/e5-small-v2 (133 MB)
290
+ </option>
291
+ <option value="intfloat_e5_base_v2">
292
+ intfloat/e5-base-v2 (438 MB)
293
+ </option>
294
+ <option value="intfloat_multilingual_e5_small">
295
+ intfloat/multilingual-e5-small (471 MB)
296
+ </option>
297
+ <option value="sentence_transformers_all_MiniLM_L6_v2">
298
+ sentence-transformers/all-MiniLM-L6-v2 (90.9 MB)
299
+ </option>
300
+ <option value="sentence_transformers_all_MiniLM_L12_v2">
301
+ sentence-transformers/all-MiniLM-L12-v2 (133 MB)
302
+ </option>
303
+ </select>
304
+ </div>
305
+ <div>
306
+ <h3 class="font-medium">Examples:</h3>
307
+ <form
308
+ id="form-wiki"
309
+ class="flex text-xs rounded-md justify-between w-min gap-3">
310
+ <input type="submit" hidden />
311
+
312
+ <button data-example class="disabled:cursor-not-allowed interactive">
313
+ Pizza
314
+ </button>
315
+ <button data-example class="disabled:cursor-not-allowed interactive">
316
+ Paris
317
+ </button>
318
+ <button data-example class="disabled:cursor-not-allowed interactive">
319
+ Physics
320
+ </button>
321
+ <input
322
+ type="text"
323
+ id="search-wiki"
324
+ title="Search Wikipedia article by title"
325
+ class="font-light py-0 mx-1 resize-none outline-none w-32 disabled:cursor-not-allowed interactive"
326
+ placeholder="Load Wikipedia article..." />
327
+ <button
328
+ title="Search Wikipedia article and load into input"
329
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal px-2 py-1 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive">
330
+ Load
331
+ </button>
332
+ </form>
333
+ </div>
334
+ <form
335
+ id="form"
336
+ class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
337
+ <input type="submit" hidden />
338
+ <input
339
+ type="text"
340
+ id="search-input"
341
+ class="font-light w-full px-3 py-2 mx-1 resize-none outline-none interactive disabled:cursor-not-allowed"
342
+ placeholder="Search query here..." />
343
+ <button
344
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive">
345
+ Search
346
+ </button>
347
+ </form>
348
+ <div>
349
+ <h3 class="font-medium">Input text:</h3>
350
+ <div class="flex justify-between items-center">
351
+ <div class="rounded-md inline text-xs">
352
+ <span id="output-status" class="m-auto font-light invisible"
353
+ >C</span
354
+ >
355
+ </div>
356
+ </div>
357
+ <div
358
+ id="input-container"
359
+ tabindex="0"
360
+ class="min-h-[250px] bg-slate-100 text-gray-500 rounded-md p-4 flex flex-col gap-2 relative">
361
+ <textarea
362
+ id="input-area"
363
+ hidden
364
+ value=""
365
+ placeholder="Input text to perform semantic similarity search..."
366
+ class="flex-1 resize-none outline-none left-0 right-0 top-0 bottom-0 m-4 absolute interactive disabled:invisible"></textarea>
367
+ <p id="output-area" class="grid-rows-2">
368
+ Input text to perform semantic similarity search...
369
+ </p>
370
+ </div>
371
+ </div>
372
+ </main>
373
+ </body>
374
+ </html>