SiddharthAK committed on
Commit 8fb581d · verified · 1 Parent(s): ff45445

moving share link under input field on the sparse reps tab

Files changed (1)
  1. app.py +29 -69
app.py CHANGED
@@ -9,56 +9,9 @@ import os
 
 # Add this CSS at the top of your file, after the imports
 css = """
-/* Move share button to top-right corner */
-.share-button {
-    position: fixed !important;
-    top: 20px !important;
-    right: 20px !important;
-    z-index: 1000 !important;
-    background: #4CAF50 !important;
-    color: white !important;
-    border-radius: 8px !important;
-    padding: 8px 16px !important;
-    font-weight: bold !important;
-    box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
-}
-
-.share-button:hover {
-    background: #45a049 !important;
-    transform: translateY(-1px) !important;
-}
-
-/* Alternative positions - uncomment the one you want instead */
-
-/* Top-left corner */
-/*
-.share-button {
-    position: fixed !important;
-    top: 20px !important;
-    left: 20px !important;
-    z-index: 1000 !important;
-}
-*/
-
-/* Bottom-right corner (mobile-friendly) */
-/*
-.share-button {
-    position: fixed !important;
-    bottom: 20px !important;
-    right: 20px !important;
-    z-index: 1000 !important;
-}
-*/
-
-/* Bottom-center */
-/*
-.share-button {
-    position: fixed !important;
-    bottom: 20px !important;
-    left: 50% !important;
-    transform: translateX(-50%) !important;
-    z-index: 1000 !important;
-}
+/* The global fixed positioning for the share button is no longer needed
+   because we'll place gr.ShareButton directly in the UI.
+   You can remove or comment out any previous .share-button CSS if it was there.
 */
 """
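For readers unfamiliar with the pattern being removed here: Gradio injects the `css` string globally, but a rule like `.share-button` only takes effect on a component tagged with that class. A minimal sketch of that mechanism, with a hypothetical button rather than this app's actual layout:

```python
import gradio as gr

# Global CSS: applies to any component carrying the `share-button` class.
css = """
.share-button { position: fixed !important; top: 20px; right: 20px; }
"""

with gr.Blocks(css=css) as demo:
    gr.Textbox(label="Some input")
    # elem_classes attaches the CSS class, so the fixed positioning applies.
    gr.Button("Share", elem_classes="share-button")

demo.launch()
```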
 
@@ -130,11 +83,11 @@ def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
                 tokenizer.unk_token_id
             ]:
                 meaningful_token_ids.append(token_id)
-
+
         if meaningful_token_ids:
             # Apply mask to the current row in the batch
             bow_masks[i, list(set(meaningful_token_ids))] = 1
-
+
     return bow_masks
 
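The hunk above shows only the tail of `create_lexical_bow_mask`. A hedged reconstruction consistent with the visible lines, for context; the excluded special tokens beyond `unk_token_id` are an assumption:

```python
import torch

def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    # One binary bag-of-words row per sequence in the batch.
    bow_masks = torch.zeros(input_ids_batch.shape[0], vocab_size,
                            device=input_ids_batch.device)
    for i, input_ids in enumerate(input_ids_batch):
        meaningful_token_ids = []
        for token_id in input_ids.tolist():
            # Skip special tokens; the diff confirms unk_token_id is excluded,
            # the rest of this list is assumed.
            if token_id not in [
                tokenizer.pad_token_id,
                tokenizer.cls_token_id,
                tokenizer.sep_token_id,
                tokenizer.unk_token_id,
            ]:
                meaningful_token_ids.append(token_id)
        if meaningful_token_ids:
            # Apply mask to the current row in the batch
            bow_masks[i, list(set(meaningful_token_ids))] = 1
    return bow_masks
```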
 
@@ -185,7 +138,7 @@ def get_splade_cocondenser_representation(text):
 
     info_output = f"" # Line 1
     info_output += f"Total non-zero terms in vector: {len(indices)}\n" # Line 2 (and onwards for sparsity)
-
+
 
     return formatted_output, info_output
 
@@ -243,7 +196,7 @@ def get_splade_lexical_representation(text):
 
     info_output = f"" # Line 1
     info_output += f"Total non-zero terms in vector: {len(indices)}\n" # Line 2 (and onwards for sparsity)
-
+
 
     return formatted_output, info_output
 
@@ -260,11 +213,11 @@ def get_splade_doc_representation(text):
     binary_bow_vector = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_doc
     ).squeeze() # Squeeze back for single output
-
+
     indices = torch.nonzero(binary_bow_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
        indices = [indices] if indices else []
-
+
     values = [1.0] * len(indices) # All values are 1 for binary representation
     token_weights = dict(zip(indices, values))
 
@@ -338,12 +291,12 @@ def get_splade_lexical_vector(text):
             torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
             dim=1
         )[0].squeeze()
-
+
         vocab_size = tokenizer_splade_lexical.vocab_size
         bow_mask = create_lexical_bow_mask(
             inputs['input_ids'], vocab_size, tokenizer_splade_lexical
         ).squeeze()
-
+
         splade_vector = splade_vector * bow_mask
         return splade_vector
     return None
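
The expression in this hunk is the standard SPLADE pooling, w_j = max_i log(1 + ReLU(logit_ij)), with padding positions zeroed by the attention mask before the max. A self-contained sketch; the checkpoint name is an assumption, since the app only displays "SPLADE-cocondenser-distil" and its loading code is not in this diff:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Assumed checkpoint for illustration.
name = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForMaskedLM.from_pretrained(name)

def splade_vector(text: str) -> torch.Tensor:
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits                   # (1, seq_len, vocab_size)
    acts = torch.log1p(torch.relu(logits))                # log(1 + ReLU(x))
    acts = acts * inputs["attention_mask"].unsqueeze(-1)  # zero out padding
    return torch.max(acts, dim=1)[0].squeeze()            # max-pool over positions
```
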
@@ -377,7 +330,7 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
         values = [1.0] * len(indices)
     else:
         values = splade_vector[indices].cpu().tolist()
-
+
     token_weights = dict(zip(indices, values))
 
     meaningful_tokens = {}
@@ -408,8 +361,8 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
 
     # This is the line that will now always be split into two
     info_output = f"Total non-zero terms: {len(indices)}\n" # Line 1
-
-
+
+
     return formatted_output, info_output
 
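For context, what `format_sparse_vector_output` is working towards with `indices`, `values`, and `token_weights`: decoding vocabulary indices back to token strings and ranking them by weight. A hedged sketch; the helper name is illustrative, not from the repo:

```python
import torch

def top_terms(splade_vector: torch.Tensor, tokenizer, k: int = 20):
    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):  # a single non-zero entry decodes to an int
        indices = [indices] if indices else []
    values = splade_vector[indices].cpu().tolist()
    # Map each vocabulary index back to its token string.
    terms = {tokenizer.decode([i]): v for i, v in zip(indices, values)}
    return sorted(terms.items(), key=lambda kv: kv[1], reverse=True)[:k]
```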
 
@@ -451,7 +404,7 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
     # Combine output into a single string for the Markdown component
     full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
     full_output += "---\n\n"
-
+
     # Query Representation
     full_output += f"Query Representation ({query_model_name_display}):\n\n"
     full_output += query_main_rep_str + "\n\n" + query_info_str # Added an extra newline for better spacing
@@ -460,7 +413,7 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
     # Document Representation
     full_output += f"Document Representation ({doc_model_name_display}):\n\n"
     full_output += doc_main_rep_str + "\n\n" + doc_info_str # Added an extra newline for better spacing
-
+
     return full_output
 
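The score assembled above is a plain dot product between two vocabulary-sized vectors, so only terms activated in both the query and the document contribute. A minimal sketch:

```python
import torch

def dot_product_score(query_vec: torch.Tensor, doc_vec: torch.Tensor) -> float:
    # Both vectors have shape (vocab_size,); overlap on shared terms drives the score.
    return torch.dot(query_vec, doc_vec).item()
```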
 
@@ -488,7 +441,13 @@ with gr.Blocks(title="SPLADE Demos", css=css) as demo:
                     label="Enter your query or document text here:",
                     placeholder="e.g., Why is Padua the nicest city in Italy?"
                 )
-                # New Markdown component for the info output
+                # --- NEW: Place the gr.ShareButton here ---
+                gr.ShareButton(
+                    value="Share My Sparse Representation",
+                    components=[input_text, model_radio], # You can specify components to share
+                    visible=True # Make sure it's visible
+                )
+                # --- End New ---
                 info_output_display = gr.Markdown(
                     value="",
                     label="Vector Information",
@@ -496,7 +455,7 @@ with gr.Blocks(title="SPLADE Demos", css=css) as demo:
                 )
             with gr.Column(scale=2): # Right column for the main representation output
                 main_representation_output = gr.Markdown()
-
+
             # Connect the interface elements
             model_radio.change(
                 fn=predict_representation_explorer,
@@ -508,15 +467,16 @@ with gr.Blocks(title="SPLADE Demos", css=css) as demo:
                 inputs=[model_radio, input_text],
                 outputs=[main_representation_output, info_output_display]
             )
-
+
             # Initial call to populate on load (optional, but good for demo)
             demo.load(
                 fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
                 outputs=[main_representation_output, info_output_display]
             )
 
-        with gr.TabItem("Compare Encoders"): # NEW TAB
-
+        with gr.TabItem("Compare Encoders"): # Reverted to original gr.Interface setup
+            gr.Markdown("### Calculate Dot Product Similarity Between Encoded Query and Document")
+
             # Define the common model choices for cleaner code
             model_choices = [
                 "MLM encoder (SPLADE-cocondenser-distil)",
@@ -549,7 +509,7 @@ with gr.Blocks(title="SPLADE Demos", css=css) as demo:
                     )
                 ],
                 outputs=gr.Markdown(),
-                allow_flagging="never"
+                allow_flagging="never" # Keep this to keep the share button at the bottom of THIS interface
             )
 
     demo.launch()
 